Skip to main content

hopr_transport_session/balancer/
controller.rs

1use std::{
2    sync::{
3        Arc,
4        atomic::{AtomicU8, AtomicU64},
5    },
6    time::Duration,
7};
8
9use futures::{StreamExt, pin_mut};
10use hopr_utils::runtime::AbortHandle;
11use tracing::{Instrument, instrument};
12
13use super::{
14    BalancerControllerBounds, MIN_BALANCER_SAMPLING_INTERVAL, SimpleSurbFlowEstimator, SurbBalancerController,
15    SurbFlowController, SurbFlowEstimator,
16};
17use crate::SessionId;
18
// Per-session telemetry gauges of the SURB balancer. They are compiled only when the
// `telemetry` feature is enabled and never in test builds.
#[cfg(all(feature = "telemetry", not(test)))]
lazy_static::lazy_static! {
    /// Latest error of the buffer estimate against the target, labelled by `session_id`.
    static ref METRIC_TARGET_ERROR_ESTIMATE: hopr_types::telemetry::MultiGauge =
        hopr_types::telemetry::MultiGauge::new(
            "hopr_surb_balancer_target_error_estimate",
            "Target error estimation by the SURB balancer",
            &["session_id"],
        )
        .unwrap();

    /// Latest control output produced by the balancer controller, labelled by `session_id`.
    static ref METRIC_CONTROL_OUTPUT: hopr_types::telemetry::MultiGauge =
        hopr_types::telemetry::MultiGauge::new(
            "hopr_surb_balancer_control_output",
            "Control output of the SURB balancer",
            &["session_id"],
        )
        .unwrap();

    /// Latest estimate of the SURB buffer level, labelled by `session_id`.
    static ref METRIC_CURRENT_BUFFER: hopr_types::telemetry::MultiGauge =
        hopr_types::telemetry::MultiGauge::new(
            "hopr_surb_balancer_current_buffer_estimate",
            "Estimated number of SURBs in the buffer",
            &["session_id"],
        )
        .unwrap();

    /// Current target (setpoint) of the SURB buffer, labelled by `session_id`.
    static ref METRIC_CURRENT_TARGET: hopr_types::telemetry::MultiGauge =
        hopr_types::telemetry::MultiGauge::new(
            "hopr_surb_balancer_current_buffer_target",
            "Current target (setpoint) number of SURBs in the buffer",
            &["session_id"],
        )
        .unwrap();

    /// Estimated SURB buffer change per second, labelled by `session_id`.
    static ref METRIC_SURB_RATE: hopr_types::telemetry::MultiGauge =
        hopr_types::telemetry::MultiGauge::new(
            "hopr_surb_balancer_surbs_rate",
            "Estimation of SURB rate per second (positive is buffer surplus, negative is buffer loss)",
            &["session_id"],
        )
        .unwrap();
}
52
53/// Configuration for the `SurbBalancer`.
54#[derive(Clone, Copy, Debug, PartialEq, smart_default::SmartDefault)]
55pub struct SurbBalancerConfig {
56    /// The desired number of SURBs to be always kept as a buffer locally or at the Session counterparty.
57    ///
58    /// The `SurbBalancer` will try to maintain approximately this number of SURBs
59    /// locally or remotely (at the counterparty) at all times.
60    ///
61    /// The local buffer is maintained by regulating (`SurbFlowController`) the egress from the Session.
62    /// The remote buffer (at session counterparty) is maintained by regulating the flow of non-organic SURBs via
63    /// keep-alive messages.
64    ///
65    /// It does not make sense to set this value higher than the [`max_surb_buffer_size`](crate::SessionManagerConfig)
66    /// configuration at the counterparty.
67    ///
68    /// Default is 7000 SURBs.
69    #[default(7_000)]
70    pub target_surb_buffer_size: u64,
71    /// Maximum outflow of SURBs.
72    ///
73    /// - In the context of the local SURB buffer (Entry), this is the maximum egress Session traffic (= SURB
74    ///   consumption).
75    /// - In the context of the remote SURB buffer (Exit), this is the maximum egress of keep-alive messages to the
76    ///   counterparty (= artificial SURB production).
77    ///
78    /// The default is 5000 (which is 2500 packets/second currently)
79    #[default(5_000)]
80    pub max_surbs_per_sec: u64,
81
82    /// Sets what percentage of the target buffer size should be discarded at each window.
83    ///
84    /// The `SurbBalancer` will discard the given percentage of `target_surb_buffer_size` at each
85    /// window with the given `Duration`.
86    ///
87    /// The default is `(60, 0.05)` (5% of the target buffer size is discarded every 60 seconds).
88    #[default(_code = "Some((Duration::from_secs(60), 0.05))")]
89    pub surb_decay: Option<(Duration, f64)>,
90}
91
92impl SurbBalancerConfig {
93    /// Convenience function to convert the [`SurbBalancerConfig`] into `BalancerControllerBounds`.
94    #[inline]
95    pub fn as_controller_bounds(&self) -> BalancerControllerBounds {
96        BalancerControllerBounds::new(self.target_surb_buffer_size, self.max_surbs_per_sec)
97    }
98}
99
/// Shared, atomically updatable runtime state of a `SurbBalancer`.
///
/// All values default to zero.
#[derive(Debug, Default)]
pub struct BalancerStateValues {
    /// Target (setpoint) number of SURBs in the buffer; zero means balancing is disabled.
    pub target_surb_buffer_size: AtomicU64,
    /// Maximum SURB outflow per second.
    pub max_surbs_per_sec: AtomicU64,
    /// Length of the decay window in milliseconds; zero means no decay.
    pub decay_duration_msec: AtomicU64,
    /// Share of the target buffer size discarded per decay window, in whole percent;
    /// zero means no decay.
    pub decay_volume_pct: AtomicU8,
    /// Most recent estimate of the SURB buffer level.
    pub buffer_level: AtomicU64,
}
109
110impl BalancerStateValues {
111    /// Constructor from a [`SurbBalancerConfig`].
112    pub fn new(cfg: SurbBalancerConfig) -> Self {
113        let state = Self::default();
114        state.update(&cfg);
115        state
116    }
117
118    /// Performs update of the [`BalancerStateValues`] from the [`SurbBalancerConfig`] and
119    /// enables it.
120    pub fn update(&self, cfg: &SurbBalancerConfig) {
121        self.target_surb_buffer_size
122            .store(cfg.target_surb_buffer_size, std::sync::atomic::Ordering::Relaxed);
123        self.max_surbs_per_sec
124            .store(cfg.max_surbs_per_sec, std::sync::atomic::Ordering::Relaxed);
125        self.decay_duration_msec.store(
126            cfg.surb_decay
127                .map(|(d, _)| d.as_millis().min(u64::MAX as u128) as u64)
128                .unwrap_or_default(),
129            std::sync::atomic::Ordering::Relaxed,
130        );
131        self.decay_volume_pct.store(
132            cfg.surb_decay
133                .map(|(_, p)| (p.clamp(0.0, 1.0) * 100.0).round() as u8)
134                .unwrap_or_default(),
135            std::sync::atomic::Ordering::Relaxed,
136        );
137    }
138
139    /// Extracts the [`SurbBalancerConfig`] from the [`BalancerStateValues`].
140    pub fn as_config(&self) -> SurbBalancerConfig {
141        SurbBalancerConfig {
142            target_surb_buffer_size: self.target_surb_buffer_size.load(std::sync::atomic::Ordering::Relaxed),
143            max_surbs_per_sec: self.max_surbs_per_sec.load(std::sync::atomic::Ordering::Relaxed),
144            surb_decay: self.surb_decay(),
145        }
146    }
147
148    /// Checks if SURB balancing is disabled (no target buffer size set).
149    pub fn is_disabled(&self) -> bool {
150        self.target_surb_buffer_size.load(std::sync::atomic::Ordering::Relaxed) == 0
151    }
152
153    /// Extracts the SURB decay configuration from the [`BalancerStateValues`].
154    pub fn surb_decay(&self) -> Option<(Duration, f64)> {
155        Some((
156            self.decay_duration_msec.load(std::sync::atomic::Ordering::Relaxed),
157            self.decay_volume_pct.load(std::sync::atomic::Ordering::Relaxed),
158        ))
159        .filter(|&(d, p)| d > 0 && p > 0)
160        .map(|(d, p)| (Duration::from_millis(d), p as f64 / 100.0))
161    }
162
163    /// Gets the current estimated SURB buffer level.
164    #[inline]
165    pub fn buffer_level(&self) -> u64 {
166        self.buffer_level.load(std::sync::atomic::Ordering::Relaxed)
167    }
168
169    /// Returns the current `BalancerControllerBounds` from the [`BalancerStateValues`].
170    #[inline]
171    pub fn controller_bounds(&self) -> BalancerControllerBounds {
172        BalancerControllerBounds::new(
173            self.target_surb_buffer_size.load(std::sync::atomic::Ordering::Relaxed),
174            self.max_surbs_per_sec.load(std::sync::atomic::Ordering::Relaxed),
175        )
176    }
177}
178
179impl From<SurbBalancerConfig> for BalancerStateValues {
180    fn from(cfg: SurbBalancerConfig) -> Self {
181        Self::new(cfg)
182    }
183}
184
185/// Runs a continuous process that attempts to [evaluate](SurbFlowEstimator) and
186/// [regulate](SurbFlowController) the flow of SURBs to the Session counterparty,
187/// to keep the number of SURBs locally or at the counterparty at a certain level.
188///
189/// Internally, the Balancer uses an implementation of [`SurbBalancerController`] to
190/// control the rate of SURBs consumed or sent to the counterparty
191/// each time the [`update`](SurbBalancer::update) method is called:
192///
193/// 1. The size of the SURB buffer at locally or at the counterparty is estimated using [`SurbFlowEstimator`].
194/// 2. Error against a set-point given in [`SurbBalancerConfig`] is evaluated in the `SurbBalancerController`.
195/// 3. The `SurbBalancerController` applies a new SURB flow rate value using the [`SurbFlowController`].
196///
197/// In the local context, the `SurbFlowController` might simply regulate the egress traffic from the
198/// Session, slowing it down to avoid fast SURB drainage.
199///
200/// In the remote context, the `SurbFlowController` might regulate the flow of non-organic SURBs via
201/// Start protocol's `KeepAlive` messages to deliver additional
202/// SURBs to the counterparty.
203pub struct SurbBalancer<C, E, F> {
204    session_id: SessionId,
205    controller: C,
206    surb_estimator: E,
207    flow_control: F,
208    state: Arc<BalancerStateValues>,
209    last_estimator_state: SimpleSurbFlowEstimator,
210    last_update: std::time::Instant,
211    last_decay: std::time::Instant,
212    was_below_target: bool,
213}
214
215impl<C, E, F> SurbBalancer<C, E, F>
216where
217    C: SurbBalancerController + Send + Sync + 'static,
218    E: SurbFlowEstimator + Send + Sync + 'static,
219    F: SurbFlowController + Send + Sync + 'static,
220{
221    pub fn new(
222        session_id: SessionId,
223        mut controller: C,
224        surb_estimator: E,
225        flow_control: F,
226        state: Arc<BalancerStateValues>,
227    ) -> Self {
228        #[cfg(all(feature = "telemetry", not(test)))]
229        {
230            let sid = session_id.to_string();
231            METRIC_TARGET_ERROR_ESTIMATE.set(&[&sid], 0.0);
232            METRIC_CONTROL_OUTPUT.set(&[&sid], 0.0);
233        }
234
235        controller.set_target_and_limit(state.controller_bounds());
236
237        Self {
238            surb_estimator,
239            flow_control,
240            controller,
241            session_id,
242            state,
243            last_estimator_state: Default::default(),
244            last_update: std::time::Instant::now(),
245            last_decay: std::time::Instant::now(),
246            was_below_target: true,
247        }
248    }
249
250    /// Computes the next control update and adjusts the [`SurbFlowController`] rate accordingly.
251    #[tracing::instrument(level = "trace", skip_all)]
252    fn update(&mut self) -> u64 {
253        let dt = self.last_update.elapsed();
254
255        // Load the updated current buffer level
256        let mut current = self.state.buffer_level.load(std::sync::atomic::Ordering::Acquire);
257
258        if dt < Duration::from_millis(10) {
259            tracing::debug!("time elapsed since last update is too short, skipping update");
260            return current;
261        }
262
263        self.last_update = std::time::Instant::now();
264
265        // Take a snapshot of the active SURB estimator and calculate the balance change
266        let snapshot = SimpleSurbFlowEstimator::from(&self.surb_estimator);
267        let Some(target_buffer_change) = snapshot.estimated_surb_buffer_change(&self.last_estimator_state) else {
268            tracing::error!("non-monotonic change in SURB estimators");
269            return current;
270        };
271
272        self.last_estimator_state = snapshot;
273        current = current.saturating_add_signed(target_buffer_change);
274
275        // If SURB decaying is enabled, check if the decay window has elapsed
276        // and calculate the number of SURBs that will be discarded
277        if let Some(num_decayed_surbs) = self
278            .state
279            .surb_decay()
280            .filter(|(decay_window, _)| &self.last_decay.elapsed() >= decay_window)
281            .map(|(_, decay_coeff)| (self.controller.bounds().target() as f64 * decay_coeff).round() as u64)
282        {
283            current = current.saturating_sub(num_decayed_surbs);
284            self.last_decay = std::time::Instant::now();
285            tracing::trace!(num_decayed_surbs, "SURBs were discarded due to automatic decay");
286        }
287
288        self.state
289            .buffer_level
290            .store(current, std::sync::atomic::Ordering::Release);
291
292        // Error from the desired target SURB buffer size at counterparty
293        let error = current as i64 - self.controller.bounds().target() as i64;
294
295        if self.was_below_target && error >= 0 {
296            tracing::trace!(current, "reached target SURB buffer size");
297            self.was_below_target = false;
298        } else if !self.was_below_target && error < 0 {
299            tracing::trace!(current, "SURB buffer size is below target");
300            self.was_below_target = true;
301        }
302
303        tracing::trace!(
304            ?dt,
305            delta = target_buffer_change,
306            rate = target_buffer_change as f64 / dt.as_secs_f64(),
307            current,
308            error,
309            "estimated SURB buffer change"
310        );
311
312        let output = self.controller.next_control_output(current);
313        tracing::trace!(output, "next balancer control output for session");
314
315        self.flow_control.adjust_surb_flow(output as usize);
316
317        #[cfg(all(feature = "telemetry", not(test)))]
318        {
319            let sid = self.session_id.to_string();
320            METRIC_CURRENT_BUFFER.set(&[&sid], current as f64);
321            METRIC_CURRENT_TARGET.set(&[&sid], self.controller.bounds().target() as f64);
322            METRIC_TARGET_ERROR_ESTIMATE.set(&[&sid], error as f64);
323            METRIC_CONTROL_OUTPUT.set(&[&sid], output as f64);
324            METRIC_SURB_RATE.set(&[&sid], target_buffer_change as f64 / dt.as_secs_f64());
325        }
326
327        current
328    }
329
330    /// Spawns a new task that performs updates of the given [`SurbBalancer`] at the given `sampling_interval`.
331    ///
332    /// If `cfg_feedback` is given, [`SurbBalancerConfig`] can be queried for updates and also updated
333    /// if the underlying [`SurbBalancerController`] also does target updates.
334    ///
335    /// Returns a stream of current estimated buffer levels, and also an `AbortHandle`
336    /// to terminate the loop. If `abort_reg` was given, the returned `AbortHandle` corresponds
337    /// to it.
338    #[instrument(level = "debug", skip(self), fields(session_id = %self.session_id))]
339    pub fn start_control_loop(
340        mut self,
341        sampling_interval: Duration,
342    ) -> (impl futures::Stream<Item = u64>, AbortHandle) {
343        let (abort_handle, abort_reg) = AbortHandle::new_pair();
344
345        // Start an interval stream at which the balancer will sample and perform updates
346        let sampling_stream = futures::stream::Abortable::new(
347            futures_time::stream::interval(sampling_interval.max(MIN_BALANCER_SAMPLING_INTERVAL).into()),
348            abort_reg,
349        );
350
351        let balancer_level_capacity = std::env::var("HOPR_INTERNAL_SESSION_BALANCER_LEVEL_CAPACITY")
352            .ok()
353            .and_then(|s| s.trim().parse::<usize>().ok())
354            .filter(|&c| c > 0)
355            .unwrap_or(32_768);
356
357        tracing::debug!(
358            capacity = balancer_level_capacity,
359            "Creating session balancer level channel"
360        );
361        let (mut level_tx, level_rx) = futures::channel::mpsc::channel(balancer_level_capacity);
362        hopr_utils::runtime::prelude::spawn(
363            async move {
364                pin_mut!(sampling_stream);
365                while sampling_stream.next().await.is_some() {
366                    // Check if the balancer controller needs to be reconfigured
367                    let current_bounds = self.state.controller_bounds();
368                    if current_bounds != self.controller.bounds() {
369                        self.controller.set_target_and_limit(current_bounds);
370                        tracing::debug!(new_cfg = ?self.state.as_config(), "surb balancer has been reconfigured");
371                    }
372
373                    // Perform controller update (this internally samples the SurbFlowEstimator)
374                    // and send an update about the current level to the outgoing stream.
375                    // If the other party has closed the stream, we don't care about the update.
376                    let level = self.update();
377                    if !level_tx.is_closed()
378                        && let Err(error) = level_tx.try_send(level)
379                    {
380                        tracing::error!(%error, "cannot send balancer level update");
381                    }
382                }
383
384                tracing::debug!("balancer done");
385            }
386            .in_current_span(),
387        );
388
389        (level_rx, abort_handle)
390    }
391}
392
/// Unit tests of the balancer configuration, its atomic state values and the control loop.
#[cfg(test)]
mod tests {
    use std::sync::{Arc, atomic::AtomicU64};

    use hopr_types::{crypto_random::Randomizable, internal::prelude::HoprPseudonym};

    use super::*;
    use crate::balancer::{AtomicSurbFlowEstimator, MockSurbFlowController, pid::PidBalancerController};

    #[test]
    fn surb_balancer_config_should_be_convertible_to_atomics() {
        let cfg = SurbBalancerConfig::default();
        let state_data = BalancerStateValues::new(cfg);
        assert_eq!(cfg, state_data.as_config());
    }

    #[test]
    fn surb_balancer_config_default_snapshot() {
        let cfg = SurbBalancerConfig::default();
        insta::assert_debug_snapshot!(cfg);
    }

    #[test]
    fn surb_balancer_config_as_controller_bounds() {
        let cfg = SurbBalancerConfig {
            target_surb_buffer_size: 1000,
            max_surbs_per_sec: 500,
            surb_decay: None,
        };
        let bounds = cfg.as_controller_bounds();
        assert_eq!(bounds.target(), 1000);
        assert_eq!(bounds.output_limit(), 500);
    }

    #[test]
    fn balancer_state_values_disabled_when_target_is_zero() {
        let cfg = SurbBalancerConfig {
            target_surb_buffer_size: 0,
            max_surbs_per_sec: 0,
            surb_decay: None,
        };
        let state = BalancerStateValues::new(cfg);
        assert!(state.is_disabled());
    }

    #[test]
    fn balancer_state_values_enabled_when_target_is_nonzero() {
        let state = BalancerStateValues::new(SurbBalancerConfig::default());
        assert!(!state.is_disabled());
    }

    #[test]
    fn balancer_state_values_update_propagates_all_fields() {
        let state = BalancerStateValues::default();
        let cfg = SurbBalancerConfig {
            target_surb_buffer_size: 3000,
            max_surbs_per_sec: 1500,
            surb_decay: Some((Duration::from_secs(30), 0.10)),
        };
        state.update(&cfg);
        assert_eq!(state.as_config(), cfg);
        assert_eq!(state.controller_bounds(), cfg.as_controller_bounds());
    }

    #[test]
    fn balancer_state_values_surb_decay_none_maps_to_none() {
        let cfg = SurbBalancerConfig {
            target_surb_buffer_size: 1000,
            max_surbs_per_sec: 500,
            surb_decay: None,
        };
        let state = BalancerStateValues::new(cfg);
        assert!(state.surb_decay().is_none());
    }

    #[test]
    fn balancer_state_values_buffer_level_default_is_zero() {
        let state = BalancerStateValues::default();
        assert_eq!(state.buffer_level(), 0);
    }

    #[test]
    fn balancer_state_values_buffer_level_can_be_updated() {
        let state = BalancerStateValues::default();
        state.buffer_level.store(42, std::sync::atomic::Ordering::Relaxed);
        assert_eq!(state.buffer_level(), 42);
    }

    #[test]
    fn balancer_state_values_from_config() {
        let cfg = SurbBalancerConfig {
            target_surb_buffer_size: 5000,
            max_surbs_per_sec: 2500,
            surb_decay: Some((Duration::from_secs(60), 0.05)),
        };
        let state: BalancerStateValues = cfg.into();
        assert_eq!(state.as_config(), cfg);
    }

    #[test]
    fn balancer_state_values_decay_zero_duration_should_map_to_none() {
        let cfg = SurbBalancerConfig {
            surb_decay: Some((Duration::ZERO, 0.10)),
            ..Default::default()
        };
        let state = BalancerStateValues::new(cfg);
        assert!(
            state.surb_decay().is_none(),
            "zero duration decay should be filtered out"
        );
    }

    #[test]
    fn balancer_state_values_decay_zero_percent_should_map_to_none() {
        let cfg = SurbBalancerConfig {
            surb_decay: Some((Duration::from_secs(60), 0.0)),
            ..Default::default()
        };
        let state = BalancerStateValues::new(cfg);
        assert!(
            state.surb_decay().is_none(),
            "zero percent decay should be filtered out"
        );
    }

    #[test]
    fn balancer_state_values_decay_should_clamp_above_one() {
        let cfg = SurbBalancerConfig {
            surb_decay: Some((Duration::from_secs(1), 1.5)), // > 1.0 should be clamped
            ..Default::default()
        };
        let state = BalancerStateValues::new(cfg);
        let (_, pct) = state.surb_decay().expect("decay should be present");
        assert!((pct - 1.0).abs() < f64::EPSILON, "percentage should be clamped to 1.0");
    }

    // Starts with an empty buffer and constant consumption: the flow controller is expected
    // to be asked for a rate that makes the estimated level grow on every step.
    #[test_log::test]
    fn surb_balancer_should_start_increase_level_when_below_target() {
        let production_rate = Arc::new(AtomicU64::new(0));
        let consumption_rate = 100;
        let steps = 3;
        let step_duration = std::time::Duration::from_millis(1000);

        let mut controller = MockSurbFlowController::new();
        let production_rate_clone = production_rate.clone();
        controller
            .expect_adjust_surb_flow()
            .times(steps)
            .with(mockall::predicate::ge(100))
            .returning(move |r| {
                production_rate_clone.store(r as u64, std::sync::atomic::Ordering::Relaxed);
            });

        let surb_estimator = AtomicSurbFlowEstimator::default();
        let mut balancer = SurbBalancer::new(
            SessionId::new(1234_u64, HoprPseudonym::random()),
            PidBalancerController::default(),
            surb_estimator.clone(),
            controller,
            Arc::new(
                SurbBalancerConfig {
                    target_surb_buffer_size: 5_000,
                    max_surbs_per_sec: 2500,
                    surb_decay: None,
                }
                .into(),
            ),
        );

        let mut last_update = 0;
        for i in 0..steps {
            std::thread::sleep(step_duration);
            surb_estimator.produced.fetch_add(
                production_rate.load(std::sync::atomic::Ordering::Relaxed) * step_duration.as_secs(),
                std::sync::atomic::Ordering::Relaxed,
            );
            surb_estimator.consumed.fetch_add(
                consumption_rate * step_duration.as_secs(),
                std::sync::atomic::Ordering::Relaxed,
            );

            let next_update = balancer.update();
            assert!(
                i == 0 || next_update > last_update,
                "{next_update} should be greater than {last_update}"
            );
            last_update = next_update;
        }
    }

    // Starts with a production rate that overshoots the default target in the first step:
    // the estimated level is expected to shrink on every following step.
    #[test_log::test]
    fn surb_balancer_should_start_decrease_level_when_above_target() {
        let production_rate = Arc::new(AtomicU64::new(11_000));
        let consumption_rate = 100;
        let steps = 3;
        let step_duration = std::time::Duration::from_millis(1000);

        let mut controller = MockSurbFlowController::new();
        let production_rate_clone = production_rate.clone();
        controller
            .expect_adjust_surb_flow()
            .times(steps)
            .with(mockall::predicate::ge(0))
            .returning(move |r| {
                production_rate_clone.store(r as u64, std::sync::atomic::Ordering::Relaxed);
            });

        let surb_estimator = AtomicSurbFlowEstimator::default();
        let mut balancer = SurbBalancer::new(
            SessionId::new(1234_u64, HoprPseudonym::random()),
            PidBalancerController::default(),
            surb_estimator.clone(),
            controller,
            Arc::new(
                SurbBalancerConfig {
                    surb_decay: None,
                    ..Default::default()
                }
                .into(),
            ),
        );

        let mut last_update = 0;
        for i in 0..steps {
            std::thread::sleep(step_duration);
            surb_estimator.produced.fetch_add(
                production_rate.load(std::sync::atomic::Ordering::Relaxed) * step_duration.as_secs(),
                std::sync::atomic::Ordering::Relaxed,
            );
            surb_estimator.consumed.fetch_add(
                consumption_rate * step_duration.as_secs(),
                std::sync::atomic::Ordering::Relaxed,
            );

            let next_update = balancer.update();
            assert!(
                i == 0 || next_update < last_update,
                "{next_update} should be less than {last_update}"
            );
            last_update = next_update;
        }
    }

    // Runs the spawned control loop with a pre-set buffer level, no SURB flow and a short
    // decay window: the reported levels must never grow and must eventually decay.
    #[test_log::test(tokio::test)]
    async fn surb_balancer_should_start_decrease_level_when_above_target_and_decay_enabled() {
        const NUM_STEPS: usize = 5;
        let session_id = SessionId::new(1234_u64, HoprPseudonym::random());
        let cfg = SurbBalancerConfig {
            target_surb_buffer_size: 5_000,
            max_surbs_per_sec: 2500,
            surb_decay: Some((Duration::from_millis(200), 0.05)),
        };

        let mut mock_flow_ctl = MockSurbFlowController::new();
        mock_flow_ctl
            .expect_adjust_surb_flow()
            .times(NUM_STEPS)
            .returning(|_| ());

        let balancer = SurbBalancer::new(
            session_id,
            PidBalancerController::default(),
            SimpleSurbFlowEstimator::default(),
            mock_flow_ctl,
            Arc::new(cfg.into()),
        );

        balancer
            .state
            .buffer_level
            .store(5000, std::sync::atomic::Ordering::Relaxed);

        let (stream, handle) = balancer.start_control_loop(Duration::from_millis(100));
        let levels = stream.take(NUM_STEPS).collect::<Vec<_>>().await;
        handle.abort();

        assert_eq!(levels.len(), NUM_STEPS);
        assert!(
            levels.windows(2).all(|w| w[1] <= w[0]),
            "buffer levels should be monotonic non-increasing: {levels:?}"
        );
        assert!(
            levels.last().is_some_and(|last| *last < 5_000),
            "expected at least one decay step: {levels:?}"
        );
    }
}