1use std::{
2 sync::{
3 Arc,
4 atomic::{AtomicU8, AtomicU64},
5 },
6 time::Duration,
7};
8
9use futures::{StreamExt, pin_mut};
10use hopr_utils::runtime::AbortHandle;
11use tracing::{Instrument, instrument};
12
13use super::{
14 BalancerControllerBounds, MIN_BALANCER_SAMPLING_INTERVAL, SimpleSurbFlowEstimator, SurbBalancerController,
15 SurbFlowController, SurbFlowEstimator,
16};
17use crate::SessionId;
18
// Per-session SURB balancer gauges, each labeled by `session_id`.
//
// They are only compiled with the `telemetry` feature outside of tests and are created lazily on
// first access; a failing `MultiGauge::new` panics at that point (via `unwrap`).

/// Difference between the estimated SURB buffer level and the target level.
#[cfg(all(feature = "telemetry", not(test)))]
static METRIC_TARGET_ERROR_ESTIMATE: std::sync::LazyLock<hopr_types::telemetry::MultiGauge> =
    std::sync::LazyLock::new(|| {
        hopr_types::telemetry::MultiGauge::new(
            "hopr_surb_balancer_target_error_estimate",
            "Target error estimation by the SURB balancer",
            &["session_id"],
        )
        .unwrap()
    });

/// Latest output of the balancer controller.
#[cfg(all(feature = "telemetry", not(test)))]
static METRIC_CONTROL_OUTPUT: std::sync::LazyLock<hopr_types::telemetry::MultiGauge> =
    std::sync::LazyLock::new(|| {
        hopr_types::telemetry::MultiGauge::new(
            "hopr_surb_balancer_control_output",
            "Control output of the SURB balancer",
            &["session_id"],
        )
        .unwrap()
    });

/// Estimated number of SURBs currently in the buffer.
#[cfg(all(feature = "telemetry", not(test)))]
static METRIC_CURRENT_BUFFER: std::sync::LazyLock<hopr_types::telemetry::MultiGauge> =
    std::sync::LazyLock::new(|| {
        hopr_types::telemetry::MultiGauge::new(
            "hopr_surb_balancer_current_buffer_estimate",
            "Estimated number of SURBs in the buffer",
            &["session_id"],
        )
        .unwrap()
    });

/// Current target (setpoint) of the SURB buffer.
#[cfg(all(feature = "telemetry", not(test)))]
static METRIC_CURRENT_TARGET: std::sync::LazyLock<hopr_types::telemetry::MultiGauge> =
    std::sync::LazyLock::new(|| {
        hopr_types::telemetry::MultiGauge::new(
            "hopr_surb_balancer_current_buffer_target",
            "Current target (setpoint) number of SURBs in the buffer",
            &["session_id"],
        )
        .unwrap()
    });

/// Estimated SURB buffer change per second.
#[cfg(all(feature = "telemetry", not(test)))]
static METRIC_SURB_RATE: std::sync::LazyLock<hopr_types::telemetry::MultiGauge> =
    std::sync::LazyLock::new(|| {
        hopr_types::telemetry::MultiGauge::new(
            "hopr_surb_balancer_surbs_rate",
            "Estimation of SURB rate per second (positive is buffer surplus, negative is buffer loss)",
            &["session_id"],
        )
        .unwrap()
    });
52
#[derive(Clone, Copy, Debug, PartialEq, smart_default::SmartDefault)]
/// Configuration of the SURB balancer.
///
/// It can be converted into [`BalancerStateValues`], which is the form shared with a running
/// [`SurbBalancer`].
pub struct SurbBalancerConfig {
    /// Target (setpoint) number of SURBs the balancer tries to keep in the buffer.
    ///
    /// Used as the controller target (see [`SurbBalancerConfig::as_controller_bounds`]).
    /// When this is `0`, [`BalancerStateValues::is_disabled`] returns `true`.
    ///
    /// Defaults to 7 000.
    #[default(7_000)]
    pub target_surb_buffer_size: u64,
    /// Upper limit of the controller output that is handed to the SURB flow controller
    /// (the field name indicates SURBs per second).
    ///
    /// Defaults to 5 000.
    #[default(5_000)]
    pub max_surbs_per_sec: u64,

    /// Optional automatic decay of the estimated SURB buffer level, as `(window, fraction)`.
    ///
    /// On the first balancing step after at least `window` has elapsed since the previous decay,
    /// `round(target * fraction)` SURBs are subtracted from the estimated buffer level.
    /// When stored in [`BalancerStateValues`], `window` keeps millisecond precision and `fraction`
    /// is clamped to `[0, 1]` and rounded to a whole percentage. A zero window, or a fraction that
    /// rounds to 0 %, disables the decay.
    ///
    /// Defaults to 5 % every 60 seconds.
    #[default(_code = "Some((Duration::from_secs(60), 0.05))")]
    pub surb_decay: Option<(Duration, f64)>,
}
91
92impl SurbBalancerConfig {
93 #[inline]
95 pub fn as_controller_bounds(&self) -> BalancerControllerBounds {
96 BalancerControllerBounds::new(self.target_surb_buffer_size, self.max_surbs_per_sec)
97 }
98}
99
#[derive(Debug, Default)]
/// Lock-free, shareable form of [`SurbBalancerConfig`], extended with the current SURB buffer
/// level estimate.
///
/// A [`SurbBalancer`] holds this behind an [`Arc`], and its control loop re-reads the controller
/// bounds from here on every tick. The configuration can therefore be changed while the balancer
/// runs (see [`BalancerStateValues::update`]).
///
/// The `Default` value has every field set to zero. That means a disabled balancer
/// (see [`BalancerStateValues::is_disabled`]) with no decay.
pub struct BalancerStateValues {
    /// Target SURB buffer size; `0` means the balancer is disabled.
    pub target_surb_buffer_size: AtomicU64,
    /// Maximum SURBs per second, used as the controller output limit.
    pub max_surbs_per_sec: AtomicU64,
    /// SURB decay window in milliseconds; `0` disables the decay.
    pub decay_duration_msec: AtomicU64,
    /// SURB decay volume as a whole percentage of the target (`0..=100` when written via
    /// `update`); `0` disables the decay.
    pub decay_volume_pct: AtomicU8,
    /// Current estimated number of SURBs in the buffer, written by the balancer on each
    /// evaluated balancing step.
    pub buffer_level: AtomicU64,
}
109
110impl BalancerStateValues {
111 pub fn new(cfg: SurbBalancerConfig) -> Self {
113 let state = Self::default();
114 state.update(&cfg);
115 state
116 }
117
118 pub fn update(&self, cfg: &SurbBalancerConfig) {
121 self.target_surb_buffer_size
122 .store(cfg.target_surb_buffer_size, std::sync::atomic::Ordering::Relaxed);
123 self.max_surbs_per_sec
124 .store(cfg.max_surbs_per_sec, std::sync::atomic::Ordering::Relaxed);
125 self.decay_duration_msec.store(
126 cfg.surb_decay
127 .map(|(d, _)| d.as_millis().min(u64::MAX as u128) as u64)
128 .unwrap_or_default(),
129 std::sync::atomic::Ordering::Relaxed,
130 );
131 self.decay_volume_pct.store(
132 cfg.surb_decay
133 .map(|(_, p)| (p.clamp(0.0, 1.0) * 100.0).round() as u8)
134 .unwrap_or_default(),
135 std::sync::atomic::Ordering::Relaxed,
136 );
137 }
138
139 pub fn as_config(&self) -> SurbBalancerConfig {
141 SurbBalancerConfig {
142 target_surb_buffer_size: self.target_surb_buffer_size.load(std::sync::atomic::Ordering::Relaxed),
143 max_surbs_per_sec: self.max_surbs_per_sec.load(std::sync::atomic::Ordering::Relaxed),
144 surb_decay: self.surb_decay(),
145 }
146 }
147
148 pub fn is_disabled(&self) -> bool {
150 self.target_surb_buffer_size.load(std::sync::atomic::Ordering::Relaxed) == 0
151 }
152
153 pub fn surb_decay(&self) -> Option<(Duration, f64)> {
155 Some((
156 self.decay_duration_msec.load(std::sync::atomic::Ordering::Relaxed),
157 self.decay_volume_pct.load(std::sync::atomic::Ordering::Relaxed),
158 ))
159 .filter(|&(d, p)| d > 0 && p > 0)
160 .map(|(d, p)| (Duration::from_millis(d), p as f64 / 100.0))
161 }
162
163 #[inline]
165 pub fn buffer_level(&self) -> u64 {
166 self.buffer_level.load(std::sync::atomic::Ordering::Relaxed)
167 }
168
169 #[inline]
171 pub fn controller_bounds(&self) -> BalancerControllerBounds {
172 BalancerControllerBounds::new(
173 self.target_surb_buffer_size.load(std::sync::atomic::Ordering::Relaxed),
174 self.max_surbs_per_sec.load(std::sync::atomic::Ordering::Relaxed),
175 )
176 }
177}
178
179impl From<SurbBalancerConfig> for BalancerStateValues {
180 fn from(cfg: SurbBalancerConfig) -> Self {
181 Self::new(cfg)
182 }
183}
184
/// Keeps the estimated SURB buffer of a session close to a target level.
///
/// On each balancing step, the balancer:
/// - snapshots the SURB flow estimator and updates the buffer level estimate stored in the shared
///   [`BalancerStateValues`];
/// - applies the optional SURB decay;
/// - feeds the new level to the controller and passes the controller output to the SURB flow
///   controller.
///
/// The balancing steps are driven by [`SurbBalancer::start_control_loop`].
pub struct SurbBalancer<C, E, F> {
    /// Session this balancer belongs to; used as the metric label and the tracing span field.
    session_id: SessionId,
    /// Computes the control output from the current buffer level estimate.
    controller: C,
    /// Source of SURB flow counters, snapshotted on every evaluated step.
    surb_estimator: E,
    /// Receives the controller output through `adjust_surb_flow`.
    flow_control: F,
    /// Shared configuration and buffer level estimate.
    state: Arc<BalancerStateValues>,
    /// Estimator snapshot from the last successful step; the baseline for the next buffer change.
    last_estimator_state: SimpleSurbFlowEstimator,
    /// Time base for the interval since the previous step (throttling and SURB rate estimate).
    last_update: std::time::Instant,
    /// Time at which the SURB decay was last applied (or the balancer was created).
    last_decay: std::time::Instant,
    /// Whether the last evaluated level was below the target; used to trace target crossings.
    was_below_target: bool,
}
214
215impl<C, E, F> SurbBalancer<C, E, F>
216where
217 C: SurbBalancerController + Send + Sync + 'static,
218 E: SurbFlowEstimator + Send + Sync + 'static,
219 F: SurbFlowController + Send + Sync + 'static,
220{
221 pub fn new(
222 session_id: SessionId,
223 mut controller: C,
224 surb_estimator: E,
225 flow_control: F,
226 state: Arc<BalancerStateValues>,
227 ) -> Self {
228 #[cfg(all(feature = "telemetry", not(test)))]
229 {
230 let sid = session_id.to_string();
231 METRIC_TARGET_ERROR_ESTIMATE.set(&[&sid], 0.0);
232 METRIC_CONTROL_OUTPUT.set(&[&sid], 0.0);
233 }
234
235 controller.set_target_and_limit(state.controller_bounds());
236
237 Self {
238 surb_estimator,
239 flow_control,
240 controller,
241 session_id,
242 state,
243 last_estimator_state: Default::default(),
244 last_update: std::time::Instant::now(),
245 last_decay: std::time::Instant::now(),
246 was_below_target: true,
247 }
248 }
249
250 #[tracing::instrument(level = "trace", skip_all)]
252 fn update(&mut self) -> u64 {
253 let dt = self.last_update.elapsed();
254
255 let mut current = self.state.buffer_level.load(std::sync::atomic::Ordering::Acquire);
257
258 if dt < Duration::from_millis(10) {
259 tracing::debug!("time elapsed since last update is too short, skipping update");
260 return current;
261 }
262
263 self.last_update = std::time::Instant::now();
264
265 let snapshot = SimpleSurbFlowEstimator::from(&self.surb_estimator);
267 let Some(target_buffer_change) = snapshot.estimated_surb_buffer_change(&self.last_estimator_state) else {
268 tracing::error!("non-monotonic change in SURB estimators");
269 return current;
270 };
271
272 self.last_estimator_state = snapshot;
273 current = current.saturating_add_signed(target_buffer_change);
274
275 if let Some(num_decayed_surbs) = self
278 .state
279 .surb_decay()
280 .filter(|(decay_window, _)| &self.last_decay.elapsed() >= decay_window)
281 .map(|(_, decay_coeff)| (self.controller.bounds().target() as f64 * decay_coeff).round() as u64)
282 {
283 current = current.saturating_sub(num_decayed_surbs);
284 self.last_decay = std::time::Instant::now();
285 tracing::trace!(num_decayed_surbs, "SURBs were discarded due to automatic decay");
286 }
287
288 self.state
289 .buffer_level
290 .store(current, std::sync::atomic::Ordering::Release);
291
292 let error = current as i64 - self.controller.bounds().target() as i64;
294
295 if self.was_below_target && error >= 0 {
296 tracing::trace!(current, "reached target SURB buffer size");
297 self.was_below_target = false;
298 } else if !self.was_below_target && error < 0 {
299 tracing::trace!(current, "SURB buffer size is below target");
300 self.was_below_target = true;
301 }
302
303 tracing::trace!(
304 ?dt,
305 delta = target_buffer_change,
306 rate = target_buffer_change as f64 / dt.as_secs_f64(),
307 current,
308 error,
309 "estimated SURB buffer change"
310 );
311
312 let output = self.controller.next_control_output(current);
313 tracing::trace!(output, "next balancer control output for session");
314
315 self.flow_control.adjust_surb_flow(output as usize);
316
317 #[cfg(all(feature = "telemetry", not(test)))]
318 {
319 let sid = self.session_id.to_string();
320 METRIC_CURRENT_BUFFER.set(&[&sid], current as f64);
321 METRIC_CURRENT_TARGET.set(&[&sid], self.controller.bounds().target() as f64);
322 METRIC_TARGET_ERROR_ESTIMATE.set(&[&sid], error as f64);
323 METRIC_CONTROL_OUTPUT.set(&[&sid], output as f64);
324 METRIC_SURB_RATE.set(&[&sid], target_buffer_change as f64 / dt.as_secs_f64());
325 }
326
327 current
328 }
329
330 #[instrument(level = "debug", skip(self), fields(session_id = %self.session_id))]
339 pub fn start_control_loop(
340 mut self,
341 sampling_interval: Duration,
342 ) -> (impl futures::Stream<Item = u64>, AbortHandle) {
343 let (abort_handle, abort_reg) = AbortHandle::new_pair();
344
345 let sampling_stream = futures::stream::Abortable::new(
347 futures_time::stream::interval(sampling_interval.max(MIN_BALANCER_SAMPLING_INTERVAL).into()),
348 abort_reg,
349 );
350
351 let balancer_level_capacity = std::env::var("HOPR_INTERNAL_SESSION_BALANCER_LEVEL_CAPACITY")
352 .ok()
353 .and_then(|s| s.trim().parse::<usize>().ok())
354 .filter(|&c| c > 0)
355 .unwrap_or(32_768);
356
357 tracing::debug!(
358 capacity = balancer_level_capacity,
359 "Creating session balancer level channel"
360 );
361 let (mut level_tx, level_rx) = futures::channel::mpsc::channel(balancer_level_capacity);
362 hopr_utils::runtime::prelude::spawn(
363 async move {
364 pin_mut!(sampling_stream);
365 while sampling_stream.next().await.is_some() {
366 let current_bounds = self.state.controller_bounds();
368 if current_bounds != self.controller.bounds() {
369 self.controller.set_target_and_limit(current_bounds);
370 tracing::debug!(new_cfg = ?self.state.as_config(), "surb balancer has been reconfigured");
371 }
372
373 let level = self.update();
377 if !level_tx.is_closed()
378 && let Err(error) = level_tx.try_send(level)
379 {
380 tracing::error!(%error, "cannot send balancer level update");
381 }
382 }
383
384 tracing::debug!("balancer done");
385 }
386 .in_current_span(),
387 );
388
389 (level_rx, abort_handle)
390 }
391}
392
393#[cfg(test)]
394mod tests {
395 use std::sync::{Arc, atomic::AtomicU64};
396
397 use hopr_types::{crypto_random::Randomizable, internal::prelude::HoprPseudonym};
398
399 use super::*;
400 use crate::balancer::{AtomicSurbFlowEstimator, MockSurbFlowController, pid::PidBalancerController};
401
402 #[test]
403 fn surb_balancer_config_should_be_convertible_to_atomics() {
404 let cfg = SurbBalancerConfig::default();
405 let state_data = BalancerStateValues::new(cfg);
406 assert_eq!(cfg, state_data.as_config());
407 }
408
409 #[test]
410 fn surb_balancer_config_default_snapshot() {
411 let cfg = SurbBalancerConfig::default();
412 insta::assert_debug_snapshot!(cfg);
413 }
414
415 #[test]
416 fn surb_balancer_config_as_controller_bounds() {
417 let cfg = SurbBalancerConfig {
418 target_surb_buffer_size: 1000,
419 max_surbs_per_sec: 500,
420 surb_decay: None,
421 };
422 let bounds = cfg.as_controller_bounds();
423 assert_eq!(bounds.target(), 1000);
424 assert_eq!(bounds.output_limit(), 500);
425 }
426
427 #[test]
428 fn balancer_state_values_disabled_when_target_is_zero() {
429 let cfg = SurbBalancerConfig {
430 target_surb_buffer_size: 0,
431 max_surbs_per_sec: 0,
432 surb_decay: None,
433 };
434 let state = BalancerStateValues::new(cfg);
435 assert!(state.is_disabled());
436 }
437
438 #[test]
439 fn balancer_state_values_enabled_when_target_is_nonzero() {
440 let state = BalancerStateValues::new(SurbBalancerConfig::default());
441 assert!(!state.is_disabled());
442 }
443
444 #[test]
445 fn balancer_state_values_update_propagates_all_fields() {
446 let state = BalancerStateValues::default();
447 let cfg = SurbBalancerConfig {
448 target_surb_buffer_size: 3000,
449 max_surbs_per_sec: 1500,
450 surb_decay: Some((Duration::from_secs(30), 0.10)),
451 };
452 state.update(&cfg);
453 assert_eq!(state.as_config(), cfg);
454 assert_eq!(state.controller_bounds(), cfg.as_controller_bounds());
455 }
456
457 #[test]
458 fn balancer_state_values_surb_decay_none_maps_to_none() {
459 let cfg = SurbBalancerConfig {
460 target_surb_buffer_size: 1000,
461 max_surbs_per_sec: 500,
462 surb_decay: None,
463 };
464 let state = BalancerStateValues::new(cfg);
465 assert!(state.surb_decay().is_none());
466 }
467
468 #[test]
469 fn balancer_state_values_buffer_level_default_is_zero() {
470 let state = BalancerStateValues::default();
471 assert_eq!(state.buffer_level(), 0);
472 }
473
474 #[test]
475 fn balancer_state_values_buffer_level_can_be_updated() {
476 let state = BalancerStateValues::default();
477 state.buffer_level.store(42, std::sync::atomic::Ordering::Relaxed);
478 assert_eq!(state.buffer_level(), 42);
479 }
480
481 #[test]
482 fn balancer_state_values_from_config() {
483 let cfg = SurbBalancerConfig {
484 target_surb_buffer_size: 5000,
485 max_surbs_per_sec: 2500,
486 surb_decay: Some((Duration::from_secs(60), 0.05)),
487 };
488 let state: BalancerStateValues = cfg.into();
489 assert_eq!(state.as_config(), cfg);
490 }
491
492 #[test]
493 fn balancer_state_values_decay_zero_duration_should_map_to_none() {
494 let cfg = SurbBalancerConfig {
495 surb_decay: Some((Duration::ZERO, 0.10)),
496 ..Default::default()
497 };
498 let state = BalancerStateValues::new(cfg);
499 assert!(
500 state.surb_decay().is_none(),
501 "zero duration decay should be filtered out"
502 );
503 }
504
505 #[test]
506 fn balancer_state_values_decay_zero_percent_should_map_to_none() {
507 let cfg = SurbBalancerConfig {
508 surb_decay: Some((Duration::from_secs(60), 0.0)),
509 ..Default::default()
510 };
511 let state = BalancerStateValues::new(cfg);
512 assert!(
513 state.surb_decay().is_none(),
514 "zero percent decay should be filtered out"
515 );
516 }
517
518 #[test]
519 fn balancer_state_values_decay_should_clamp_above_one() {
520 let cfg = SurbBalancerConfig {
521 surb_decay: Some((Duration::from_secs(1), 1.5)), ..Default::default()
523 };
524 let state = BalancerStateValues::new(cfg);
525 let (_, pct) = state.surb_decay().expect("decay should be present");
526 assert!((pct - 1.0).abs() < f64::EPSILON, "percentage should be clamped to 1.0");
527 }
528
529 #[test_log::test]
530 fn surb_balancer_should_start_increase_level_when_below_target() {
531 let production_rate = Arc::new(AtomicU64::new(0));
532 let consumption_rate = 100;
533 let steps = 3;
534 let step_duration = std::time::Duration::from_millis(1000);
535
536 let mut controller = MockSurbFlowController::new();
537 let production_rate_clone = production_rate.clone();
538 controller
539 .expect_adjust_surb_flow()
540 .times(steps)
541 .with(mockall::predicate::ge(100))
542 .returning(move |r| {
543 production_rate_clone.store(r as u64, std::sync::atomic::Ordering::Relaxed);
544 });
545
546 let surb_estimator = AtomicSurbFlowEstimator::default();
547 let mut balancer = SurbBalancer::new(
548 SessionId::new(1234_u64, HoprPseudonym::random()),
549 PidBalancerController::default(),
550 surb_estimator.clone(),
551 controller,
552 Arc::new(
553 SurbBalancerConfig {
554 target_surb_buffer_size: 5_000,
555 max_surbs_per_sec: 2500,
556 surb_decay: None,
557 }
558 .into(),
559 ),
560 );
561
562 let mut last_update = 0;
563 for i in 0..steps {
564 std::thread::sleep(step_duration);
565 surb_estimator.produced.fetch_add(
566 production_rate.load(std::sync::atomic::Ordering::Relaxed) * step_duration.as_secs(),
567 std::sync::atomic::Ordering::Relaxed,
568 );
569 surb_estimator.consumed.fetch_add(
570 consumption_rate * step_duration.as_secs(),
571 std::sync::atomic::Ordering::Relaxed,
572 );
573
574 let next_update = balancer.update();
575 assert!(
576 i == 0 || next_update > last_update,
577 "{next_update} should be greater than {last_update}"
578 );
579 last_update = next_update;
580 }
581 }
582
583 #[test_log::test]
584 fn surb_balancer_should_start_decrease_level_when_above_target() {
585 let production_rate = Arc::new(AtomicU64::new(11_000));
586 let consumption_rate = 100;
587 let steps = 3;
588 let step_duration = std::time::Duration::from_millis(1000);
589
590 let mut controller = MockSurbFlowController::new();
591 let production_rate_clone = production_rate.clone();
592 controller
593 .expect_adjust_surb_flow()
594 .times(steps)
595 .with(mockall::predicate::ge(0))
596 .returning(move |r| {
597 production_rate_clone.store(r as u64, std::sync::atomic::Ordering::Relaxed);
598 });
599
600 let surb_estimator = AtomicSurbFlowEstimator::default();
601 let mut balancer = SurbBalancer::new(
602 SessionId::new(1234_u64, HoprPseudonym::random()),
603 PidBalancerController::default(),
604 surb_estimator.clone(),
605 controller,
606 Arc::new(
607 SurbBalancerConfig {
608 surb_decay: None,
609 ..Default::default()
610 }
611 .into(),
612 ),
613 );
614
615 let mut last_update = 0;
616 for i in 0..steps {
617 std::thread::sleep(step_duration);
618 surb_estimator.produced.fetch_add(
619 production_rate.load(std::sync::atomic::Ordering::Relaxed) * step_duration.as_secs(),
620 std::sync::atomic::Ordering::Relaxed,
621 );
622 surb_estimator.consumed.fetch_add(
623 consumption_rate * step_duration.as_secs(),
624 std::sync::atomic::Ordering::Relaxed,
625 );
626
627 let next_update = balancer.update();
628 assert!(
629 i == 0 || next_update < last_update,
630 "{next_update} should be greater than {last_update}"
631 );
632 last_update = next_update;
633 }
634 }
635
636 #[test_log::test(tokio::test)]
637 async fn surb_balancer_should_start_decrease_level_when_above_target_and_decay_enabled() {
638 const NUM_STEPS: usize = 5;
639 let session_id = SessionId::new(1234_u64, HoprPseudonym::random());
640 let cfg = SurbBalancerConfig {
641 target_surb_buffer_size: 5_000,
642 max_surbs_per_sec: 2500,
643 surb_decay: Some((Duration::from_millis(200), 0.05)),
644 };
645
646 let mut mock_flow_ctl = MockSurbFlowController::new();
647 mock_flow_ctl
648 .expect_adjust_surb_flow()
649 .times(NUM_STEPS)
650 .returning(|_| ());
651
652 let balancer = SurbBalancer::new(
653 session_id,
654 PidBalancerController::default(),
655 SimpleSurbFlowEstimator::default(),
656 mock_flow_ctl,
657 Arc::new(cfg.into()),
658 );
659
660 balancer
661 .state
662 .buffer_level
663 .store(5000, std::sync::atomic::Ordering::Relaxed);
664
665 let (stream, handle) = balancer.start_control_loop(Duration::from_millis(100));
666 let levels = stream.take(NUM_STEPS).collect::<Vec<_>>().await;
667 handle.abort();
668
669 assert_eq!(levels.len(), NUM_STEPS);
670 assert!(
671 levels.windows(2).all(|w| w[1] <= w[0]),
672 "buffer levels should be monotonic non-increasing: {levels:?}"
673 );
674 assert!(
675 levels.last().is_some_and(|last| *last < 5_000),
676 "expected at least one decay step: {levels:?}"
677 );
678 }
679}