1use std::time::Duration;
2
3use futures::{StreamExt, future::AbortRegistration, pin_mut};
4use hopr_async_runtime::AbortHandle;
5use tracing::{Instrument, debug, error, instrument, trace};
6
7use super::{
8 BalancerControllerBounds, MIN_BALANCER_SAMPLING_INTERVAL, SimpleSurbFlowEstimator, SurbBalancerController,
9 SurbFlowController, SurbFlowEstimator,
10};
11use crate::SessionId;
12
// Per-session gauges exported by the SURB balancer.
// Only compiled when the `prometheus` feature is enabled, and never in test builds.
// Every gauge carries a single `session_id` label (the `Display` form of the `SessionId`).
#[cfg(all(feature = "prometheus", not(test)))]
lazy_static::lazy_static! {
    // Signed difference `current_buffer - target` computed in `SurbBalancer::update`
    // (initialized to 0 in `SurbBalancer::new`).
    static ref METRIC_TARGET_ERROR_ESTIMATE: hopr_metrics::metrics::MultiGauge =
        hopr_metrics::metrics::MultiGauge::new(
            "hopr_surb_balancer_target_error_estimate",
            "Target error estimation by the SURB balancer",
            &["session_id"]
        ).unwrap();
    // Last value returned by the controller's `next_control_output`
    // (initialized to 0 in `SurbBalancer::new`).
    static ref METRIC_CONTROL_OUTPUT: hopr_metrics::metrics::MultiGauge =
        hopr_metrics::metrics::MultiGauge::new(
            "hopr_surb_balancer_control_output",
            "Control output of the SURB balancer",
            &["session_id"]
        ).unwrap();
    // Current estimate of the SURB buffer level (after decay, if any).
    static ref METRIC_CURRENT_BUFFER: hopr_metrics::metrics::MultiGauge =
        hopr_metrics::metrics::MultiGauge::new(
            "hopr_surb_balancer_current_buffer_estimate",
            "Estimated number of SURBs in the buffer",
            &["session_id"]
        ).unwrap();
    // Controller target (setpoint) at the time of the update.
    static ref METRIC_CURRENT_TARGET: hopr_metrics::metrics::MultiGauge =
        hopr_metrics::metrics::MultiGauge::new(
            "hopr_surb_balancer_current_buffer_target",
            "Current target (setpoint) number of SURBs in the buffer",
            &["session_id"]
        ).unwrap();
    // Estimator buffer change divided by the seconds elapsed since the previous update.
    static ref METRIC_SURB_RATE: hopr_metrics::metrics::MultiGauge =
        hopr_metrics::metrics::MultiGauge::new(
            "hopr_surb_balancer_surbs_rate",
            "Estimation of SURB rate per second (positive is buffer surplus, negative is buffer loss)",
            &["session_id"]
        ).unwrap();
}
46
/// Source and sink of the [`SurbBalancerConfig`] used by the [`SurbBalancer`] control loop.
///
/// The control loop started by `SurbBalancer::start_control_loop` reads the configuration through this trait
/// on every tick, and reports back bounds that the controller changed on its own.
#[cfg_attr(test, mockall::automock)]
#[async_trait::async_trait]
pub trait BalancerConfigFeedback {
    /// Returns the current balancer configuration for the session `id`.
    ///
    /// Called once per sampling tick. If this returns an error, the control loop terminates.
    async fn get_config(&self, id: &SessionId) -> crate::errors::Result<SurbBalancerConfig>;
    /// Notifies that the controller changed its bounds during an update of session `id`.
    ///
    /// `cfg` is the configuration previously returned by `get_config`, with `target_surb_buffer_size` and
    /// `max_surbs_per_sec` replaced by the controller's new target and output limit.
    /// An error returned here is only logged by the control loop.
    async fn on_config_update(&self, id: &SessionId, cfg: SurbBalancerConfig) -> crate::errors::Result<()>;
}
56
/// Configuration of a [`SurbBalancer`].
#[derive(Clone, Copy, Debug, PartialEq, smart_default::SmartDefault)]
pub struct SurbBalancerConfig {
    /// Target (setpoint) number of SURBs the balancer tries to keep in the buffer.
    ///
    /// Passed to the controller as the target of its bounds. Defaults to 7000.
    #[default(7_000)]
    pub target_surb_buffer_size: u64,
    /// Limit of the controller output.
    ///
    /// Passed to the controller as the output limit of its bounds. Presumably measured in SURBs per second,
    /// as the name suggests. Defaults to 5000.
    #[default(5_000)]
    pub max_surbs_per_sec: u64,

    /// Optional automatic decay of the buffer estimate, as `(decay_window, decay_coefficient)`.
    ///
    /// Whenever at least `decay_window` has elapsed since the last decay, `round(target * decay_coefficient)`
    /// SURBs are subtracted (saturating) from the estimated buffer level. `None` disables the decay.
    /// Defaults to a 60-second window with a coefficient of 0.05.
    #[default(_code = "Some((Duration::from_secs(60), 0.05))")]
    pub surb_decay: Option<(Duration, f64)>,
}
94
/// Estimates the SURB buffer level of a single session and keeps it near a target.
///
/// On every update, the buffer estimate derived from `E` is fed into the controller `C`.
/// The controller's output is then passed to the flow controller `F`.
pub struct SurbBalancer<C, E, F> {
    // Session this balancer belongs to (used for logging, metrics and config feedback).
    session_id: SessionId,
    // Computes the control output and holds the target / output-limit bounds.
    controller: C,
    // Source of SURB production/consumption estimates. Presumably cumulative counters, since a
    // non-monotonic change is treated as an error.
    surb_estimator: E,
    // Receives the control output via `adjust_surb_flow`.
    flow_control: F,
    // Current estimate of the number of SURBs in the buffer.
    current_buffer: u64,
    // Estimator snapshot from the last successful update; the next snapshot is compared against it.
    last_estimator_state: SimpleSurbFlowEstimator,
    // Instant of the last update that passed the minimum-interval guard.
    last_update: std::time::Instant,
    // Instant of the last decay (initially the creation time of the balancer).
    last_decay: std::time::Instant,
    // Whether the estimate was below target at the previous update; used only to log target crossings.
    was_below_target: bool,
}
124
125impl<C, E, F> SurbBalancer<C, E, F>
126where
127 C: SurbBalancerController + Send + Sync + 'static,
128 E: SurbFlowEstimator + Send + Sync + 'static,
129 F: SurbFlowController + Send + Sync + 'static,
130{
131 pub fn new(
132 session_id: SessionId,
133 mut controller: C,
134 surb_estimator: E,
135 flow_control: F,
136 cfg: SurbBalancerConfig,
137 ) -> Self {
138 #[cfg(all(feature = "prometheus", not(test)))]
139 {
140 let sid = session_id.to_string();
141 METRIC_TARGET_ERROR_ESTIMATE.set(&[&sid], 0.0);
142 METRIC_CONTROL_OUTPUT.set(&[&sid], 0.0);
143 }
144
145 controller.set_target_and_limit(BalancerControllerBounds::new(
146 cfg.target_surb_buffer_size,
147 cfg.max_surbs_per_sec,
148 ));
149
150 Self {
151 surb_estimator,
152 flow_control,
153 controller,
154 session_id,
155 current_buffer: 0,
156 last_estimator_state: Default::default(),
157 last_update: std::time::Instant::now(),
158 last_decay: std::time::Instant::now(),
159 was_below_target: true,
160 }
161 }
162
163 #[tracing::instrument(level = "trace", skip_all)]
165 fn update(&mut self, surb_decay: Option<&(Duration, f64)>) -> u64 {
166 let dt = self.last_update.elapsed();
167 if dt < Duration::from_millis(10) {
168 debug!("time elapsed since last update is too short, skipping update");
169 return self.current_buffer;
170 }
171
172 self.last_update = std::time::Instant::now();
173
174 let snapshot = SimpleSurbFlowEstimator::from(&self.surb_estimator);
176 let Some(target_buffer_change) = snapshot.estimated_surb_buffer_change(&self.last_estimator_state) else {
177 error!("non-monotonic change in SURB estimators");
178 return self.current_buffer;
179 };
180
181 self.last_estimator_state = snapshot;
182 self.current_buffer = self.current_buffer.saturating_add_signed(target_buffer_change);
183
184 if let Some(num_decayed_surbs) = surb_decay
187 .as_ref()
188 .filter(|(decay_window, _)| &self.last_decay.elapsed() >= decay_window)
189 .map(|(_, decay_coeff)| (self.controller.bounds().target() as f64 * *decay_coeff).round() as u64)
190 {
191 self.current_buffer = self.current_buffer.saturating_sub(num_decayed_surbs);
192 self.last_decay = std::time::Instant::now();
193 trace!(num_decayed_surbs, "SURBs were discarded due to automatic decay");
194 }
195
196 let error = self.current_buffer as i64 - self.controller.bounds().target() as i64;
198
199 if self.was_below_target && error >= 0 {
200 trace!(current = self.current_buffer, "reached target SURB buffer size");
201 self.was_below_target = false;
202 } else if !self.was_below_target && error < 0 {
203 trace!(current = self.current_buffer, "SURB buffer size is below target");
204 self.was_below_target = true;
205 }
206
207 trace!(
208 ?dt,
209 delta = target_buffer_change,
210 rate = target_buffer_change as f64 / dt.as_secs_f64(),
211 current = self.current_buffer,
212 error,
213 "estimated SURB buffer change"
214 );
215
216 let output = self.controller.next_control_output(self.current_buffer);
217 trace!(output, "next balancer control output for session");
218
219 self.flow_control.adjust_surb_flow(output as usize);
220
221 #[cfg(all(feature = "prometheus", not(test)))]
222 {
223 let sid = self.session_id.to_string();
224 METRIC_CURRENT_BUFFER.set(&[&sid], self.current_buffer as f64);
225 METRIC_CURRENT_TARGET.set(&[&sid], self.controller.bounds().target() as f64);
226 METRIC_TARGET_ERROR_ESTIMATE.set(&[&sid], error as f64);
227 METRIC_CONTROL_OUTPUT.set(&[&sid], output as f64);
228 METRIC_SURB_RATE.set(&[&sid], target_buffer_change as f64 / dt.as_secs_f64());
229 }
230
231 self.current_buffer
232 }
233
234 #[instrument(level = "debug", skip_all, fields(session_id = %self.session_id))]
244 pub fn start_control_loop<B>(
245 mut self,
246 sampling_interval: Duration,
247 cfg_feedback: B,
248 abort_reg: Option<AbortRegistration>,
249 ) -> (impl futures::Stream<Item = u64>, AbortHandle)
250 where
251 B: BalancerConfigFeedback + Send + Sync + 'static,
252 {
253 let (abort_handle, abort_reg) = abort_reg
255 .map(|reg| (reg.handle(), reg))
256 .unwrap_or_else(AbortHandle::new_pair);
257
258 let sampling_stream = futures::stream::Abortable::new(
260 futures_time::stream::interval(sampling_interval.max(MIN_BALANCER_SAMPLING_INTERVAL).into()),
261 abort_reg,
262 );
263
264 let (level_tx, level_rx) = futures::channel::mpsc::unbounded();
265 hopr_async_runtime::prelude::spawn(
266 async move {
267 pin_mut!(sampling_stream);
268 while sampling_stream.next().await.is_some() {
269 let Ok(mut current_cfg) = cfg_feedback.get_config(&self.session_id).await else {
270 error!("cannot get config for session");
271 break;
272 };
273
274 let current_bounds = BalancerControllerBounds::new(
275 current_cfg.target_surb_buffer_size,
276 current_cfg.max_surbs_per_sec,
277 );
278
279 if current_bounds != self.controller.bounds() {
281 self.controller.set_target_and_limit(current_bounds);
282 debug!(?current_cfg, "surb balancer has been reconfigured");
283 }
284
285 let bounds_before_update = self.controller.bounds();
286
287 let level = self.update(current_cfg.surb_decay.as_ref());
290 let _ = level_tx.unbounded_send(level);
291
292 let bounds_after_update = self.controller.bounds();
296 if bounds_before_update != bounds_after_update {
297 current_cfg.target_surb_buffer_size = bounds_after_update.target();
298 current_cfg.max_surbs_per_sec = bounds_after_update.output_limit();
299 match cfg_feedback.on_config_update(&self.session_id, current_cfg).await {
300 Ok(_) => debug!(
301 ?bounds_before_update,
302 ?bounds_after_update,
303 "controller bounds has changed after update"
304 ),
305 Err(error) => error!(%error, "failed to update controller bounds after it changed"),
306 }
307 }
308 }
309
310 debug!("balancer done");
311 }
312 .in_current_span(),
313 );
314
315 (level_rx, abort_handle)
316 }
317}
318
#[cfg(test)]
mod tests {
    use std::sync::{Arc, atomic::AtomicU64};

    use hopr_crypto_random::Randomizable;
    use hopr_internal_types::prelude::HoprPseudonym;

    use super::*;
    use crate::balancer::{
        AtomicSurbFlowEstimator, MockSurbFlowController, pid::PidBalancerController, simple::SimpleBalancerController,
    };

    #[test_log::test]
    fn surb_balancer_should_start_increase_level_when_below_target() {
        let production_rate = Arc::new(AtomicU64::new(0));
        let consumption_rate = 100;
        let steps = 3;
        let step_duration = std::time::Duration::from_millis(1000);

        // The mocked flow controller feeds the control output back as the production rate
        let mut controller = MockSurbFlowController::new();
        let production_rate_clone = production_rate.clone();
        controller
            .expect_adjust_surb_flow()
            .times(steps)
            .with(mockall::predicate::ge(100))
            .returning(move |r| {
                production_rate_clone.store(r as u64, std::sync::atomic::Ordering::Relaxed);
            });

        let surb_estimator = AtomicSurbFlowEstimator::default();
        let mut balancer = SurbBalancer::new(
            SessionId::new(1234_u64, HoprPseudonym::random()),
            PidBalancerController::default(),
            surb_estimator.clone(),
            controller,
            SurbBalancerConfig {
                target_surb_buffer_size: 5_000,
                max_surbs_per_sec: 2500,
                ..Default::default()
            },
        );

        let mut last_update = 0;
        for i in 0..steps {
            std::thread::sleep(step_duration);
            surb_estimator.produced.fetch_add(
                production_rate.load(std::sync::atomic::Ordering::Relaxed) * step_duration.as_secs(),
                std::sync::atomic::Ordering::Relaxed,
            );
            surb_estimator.consumed.fetch_add(
                consumption_rate * step_duration.as_secs(),
                std::sync::atomic::Ordering::Relaxed,
            );

            let next_update = balancer.update(None);
            assert!(
                i == 0 || next_update > last_update,
                "{next_update} should be greater than {last_update}"
            );
            last_update = next_update;
        }
    }

    #[test_log::test]
    fn surb_balancer_should_start_decrease_level_when_above_target() {
        let production_rate = Arc::new(AtomicU64::new(11_000));
        let consumption_rate = 100;
        let steps = 3;
        let step_duration = std::time::Duration::from_millis(1000);

        // The mocked flow controller feeds the control output back as the production rate
        let mut controller = MockSurbFlowController::new();
        let production_rate_clone = production_rate.clone();
        controller
            .expect_adjust_surb_flow()
            .times(steps)
            .with(mockall::predicate::ge(0))
            .returning(move |r| {
                production_rate_clone.store(r as u64, std::sync::atomic::Ordering::Relaxed);
            });

        let surb_estimator = AtomicSurbFlowEstimator::default();
        let mut balancer = SurbBalancer::new(
            SessionId::new(1234_u64, HoprPseudonym::random()),
            PidBalancerController::default(),
            surb_estimator.clone(),
            controller,
            SurbBalancerConfig::default(),
        );

        let mut last_update = 0;
        for i in 0..steps {
            std::thread::sleep(step_duration);
            surb_estimator.produced.fetch_add(
                production_rate.load(std::sync::atomic::Ordering::Relaxed) * step_duration.as_secs(),
                std::sync::atomic::Ordering::Relaxed,
            );
            surb_estimator.consumed.fetch_add(
                consumption_rate * step_duration.as_secs(),
                std::sync::atomic::Ordering::Relaxed,
            );

            let next_update = balancer.update(None);
            assert!(
                i == 0 || next_update < last_update,
                "{next_update} should be less than {last_update}"
            );
            last_update = next_update;
        }
    }

    #[test_log::test(tokio::test)]
    async fn surb_balancer_should_start_decrease_level_when_above_target_and_decay_enabled() {
        const NUM_STEPS: usize = 5;
        let session_id = SessionId::new(1234_u64, HoprPseudonym::random());
        let cfg = SurbBalancerConfig {
            target_surb_buffer_size: 5_000,
            max_surbs_per_sec: 2500,
            surb_decay: Some((Duration::from_millis(200), 0.05)),
            ..Default::default()
        };

        let mut mock_balancer_feedback = MockBalancerConfigFeedback::new();
        mock_balancer_feedback
            .expect_get_config()
            .with(mockall::predicate::eq(session_id))
            .times(NUM_STEPS)
            .returning(move |_| Ok(cfg));

        mock_balancer_feedback.expect_on_config_update().never();

        let mut mock_flow_ctl = MockSurbFlowController::new();
        mock_flow_ctl
            .expect_adjust_surb_flow()
            .times(NUM_STEPS)
            .returning(|_| ());

        let mut balancer = SurbBalancer::new(
            session_id,
            PidBalancerController::default(),
            SimpleSurbFlowEstimator::default(),
            mock_flow_ctl,
            cfg,
        );

        balancer.current_buffer = 5000;

        let (stream, handle) = balancer.start_control_loop(Duration::from_millis(100), mock_balancer_feedback, None);
        let levels = stream.take(NUM_STEPS).collect::<Vec<_>>().await;
        handle.abort();

        // Sampling every 100 ms with a 200 ms decay window: every other tick discards
        // round(5000 * 0.05) = 250 SURBs, while the estimator itself reports no change.
        assert_eq!(levels, vec![5000, 4750, 4750, 4500, 4500]);
    }

    /// Test estimator that yields successive values from the given iterators on each query,
    /// falling back to 0 once an iterator is exhausted.
    struct IterSurbFlowEstimator<P, C>(std::sync::Mutex<P>, std::sync::Mutex<C>);

    impl<P, C> IterSurbFlowEstimator<P, C> {
        fn new<I1, I2>(production: I1, consumption: I2) -> Self
        where
            I1: IntoIterator<Item = u64, IntoIter = P>,
            I2: IntoIterator<Item = u64, IntoIter = C>,
        {
            Self(
                std::sync::Mutex::new(production.into_iter()),
                std::sync::Mutex::new(consumption.into_iter()),
            )
        }
    }

    impl<P, C> SurbFlowEstimator for IterSurbFlowEstimator<P, C>
    where
        P: Iterator<Item = u64>,
        C: Iterator<Item = u64>,
    {
        fn estimate_surbs_consumed(&self) -> u64 {
            self.1.lock().ok().and_then(|mut it| it.next()).unwrap_or(0)
        }

        fn estimate_surbs_produced(&self) -> u64 {
            self.0.lock().ok().and_then(|mut it| it.next()).unwrap_or(0)
        }
    }

    #[test_log::test(tokio::test)]
    async fn surb_balancer_should_increase_target_when_using_simple_controller() {
        let session_id = SessionId::new(1234_u64, HoprPseudonym::random());
        let cfg_1 = SurbBalancerConfig {
            target_surb_buffer_size: 4500,
            max_surbs_per_sec: 2500,
            surb_decay: None,
            ..Default::default()
        };

        let cfg_2 = SurbBalancerConfig {
            target_surb_buffer_size: 5500,
            max_surbs_per_sec: 3055,
            surb_decay: None,
            ..Default::default()
        };

        let mut seq = mockall::Sequence::new();

        let mut mock_balancer_feedback = MockBalancerConfigFeedback::new();
        mock_balancer_feedback
            .expect_get_config()
            .times(3)
            .in_sequence(&mut seq)
            .with(mockall::predicate::eq(session_id))
            .returning(move |_| {
                tracing::trace!("get config 1");
                Ok(cfg_1)
            });

        mock_balancer_feedback
            .expect_on_config_update()
            .once()
            .in_sequence(&mut seq)
            .with(mockall::predicate::eq(session_id), mockall::predicate::eq(cfg_2))
            .returning(|_, _| {
                tracing::trace!("on config update");
                Ok(())
            });

        mock_balancer_feedback
            .expect_get_config()
            .times(2)
            .in_sequence(&mut seq)
            .with(mockall::predicate::eq(session_id))
            .returning(move |_| {
                tracing::trace!("get config 2");
                Ok(cfg_2)
            });

        let mut mock_flow_ctl = MockSurbFlowController::new();
        mock_flow_ctl.expect_adjust_surb_flow().times(5).returning(|_| ());

        let mut balancer = SurbBalancer::new(
            session_id,
            SimpleBalancerController::with_increasing_setpoint(0.2, 5),
            IterSurbFlowEstimator::new([1000, 2000, 3000, 3000, 3000], vec![500, 1000, 1500, 1500, 1500]),
            mock_flow_ctl,
            cfg_1,
        );

        balancer.current_buffer = 4500;

        let (stream, handle) = balancer.start_control_loop(Duration::from_millis(100), mock_balancer_feedback, None);

        let levels = stream.take(5).collect::<Vec<_>>().await;
        handle.abort();

        // The buffer grows by (produced - consumed) deltas of 500 for three ticks, then stays flat
        assert_eq!(levels, vec![5000, 5500, 6000, 6000, 6000]);
    }
}