Skip to main content

hopr_transport_session/balancer/
controller.rs

1use std::{
2    sync::{
3        Arc,
4        atomic::{AtomicU8, AtomicU64},
5    },
6    time::Duration,
7};
8
9use futures::{StreamExt, pin_mut};
10use hopr_async_runtime::AbortHandle;
11use tracing::{Instrument, instrument};
12
13use super::{
14    BalancerControllerBounds, MIN_BALANCER_SAMPLING_INTERVAL, SimpleSurbFlowEstimator, SurbBalancerController,
15    SurbFlowController, SurbFlowEstimator,
16};
17use crate::SessionId;
18
/// Creates a gauge labelled by `session_id` for one of the SURB balancer metrics below.
///
/// Panics when the gauge cannot be created, which fails the first access to the metric.
#[cfg(all(feature = "prometheus", not(test)))]
fn new_session_gauge(name: &str, description: &str) -> hopr_metrics::MultiGauge {
    hopr_metrics::MultiGauge::new(name, description, &["session_id"])
        .expect("SURB balancer metric must be constructible")
}

// The gauges below are created lazily on first access. They are compiled only when the
// `prometheus` feature is enabled and are left out of test builds.

/// Difference between the estimated buffer level and the target, per session.
#[cfg(all(feature = "prometheus", not(test)))]
static METRIC_TARGET_ERROR_ESTIMATE: std::sync::LazyLock<hopr_metrics::MultiGauge> = std::sync::LazyLock::new(|| {
    new_session_gauge(
        "hopr_surb_balancer_target_error_estimate",
        "Target error estimation by the SURB balancer",
    )
});

/// Last control output produced by the balancer controller, per session.
#[cfg(all(feature = "prometheus", not(test)))]
static METRIC_CONTROL_OUTPUT: std::sync::LazyLock<hopr_metrics::MultiGauge> = std::sync::LazyLock::new(|| {
    new_session_gauge(
        "hopr_surb_balancer_control_output",
        "Control output of the SURB balancer",
    )
});

/// Estimated number of SURBs currently in the buffer, per session.
#[cfg(all(feature = "prometheus", not(test)))]
static METRIC_CURRENT_BUFFER: std::sync::LazyLock<hopr_metrics::MultiGauge> = std::sync::LazyLock::new(|| {
    new_session_gauge(
        "hopr_surb_balancer_current_buffer_estimate",
        "Estimated number of SURBs in the buffer",
    )
});

/// Current target (set-point) of the buffer size, per session.
#[cfg(all(feature = "prometheus", not(test)))]
static METRIC_CURRENT_TARGET: std::sync::LazyLock<hopr_metrics::MultiGauge> = std::sync::LazyLock::new(|| {
    new_session_gauge(
        "hopr_surb_balancer_current_buffer_target",
        "Current target (setpoint) number of SURBs in the buffer",
    )
});

/// Estimated rate of change of the SURB buffer per second, per session.
#[cfg(all(feature = "prometheus", not(test)))]
static METRIC_SURB_RATE: std::sync::LazyLock<hopr_metrics::MultiGauge> = std::sync::LazyLock::new(|| {
    new_session_gauge(
        "hopr_surb_balancer_surbs_rate",
        "Estimation of SURB rate per second (positive is buffer surplus, negative is buffer loss)",
    )
});
52
53/// Configuration for the `SurbBalancer`.
54#[derive(Clone, Copy, Debug, PartialEq, smart_default::SmartDefault)]
55pub struct SurbBalancerConfig {
56    /// The desired number of SURBs to be always kept as a buffer locally or at the Session counterparty.
57    ///
58    /// The `SurbBalancer` will try to maintain approximately this number of SURBs
59    /// locally or remotely (at the counterparty) at all times.
60    ///
61    /// The local buffer is maintained by [regulating](SurbFlowController) the egress from the Session.
62    /// The remote buffer (at session counterparty) is maintained by regulating the flow of non-organic SURBs via
63    /// [keep-alive messages](crate::initiation::StartProtocol::KeepAlive).
64    ///
65    /// It does not make sense to set this value higher than the [`max_surb_buffer_size`](crate::SessionManagerConfig)
66    /// configuration at the counterparty.
67    ///
68    /// Default is 7000 SURBs.
69    #[default(7_000)]
70    pub target_surb_buffer_size: u64,
71    /// Maximum outflow of SURBs.
72    ///
73    /// - In the context of the local SURB buffer (Entry), this is the maximum egress Session traffic (= SURB
74    ///   consumption).
75    /// - In the context of the remote SURB buffer (Exit), this is the maximum egress of keep-alive messages to the
76    ///   counterparty (= artificial SURB production).
77    ///
78    /// The default is 5000 (which is 2500 packets/second currently)
79    #[default(5_000)]
80    pub max_surbs_per_sec: u64,
81
82    /// Sets what percentage of the target buffer size should be discarded at each window.
83    ///
84    /// The [`SurbBalancer`] will discard the given percentage of `target_surb_buffer_size` at each
85    /// window with the given `Duration`.
86    ///
87    /// The default is `(60, 0.05)` (5% of the target buffer size is discarded every 60 seconds).
88    #[default(_code = "Some((Duration::from_secs(60), 0.05))")]
89    pub surb_decay: Option<(Duration, f64)>,
90}
91
92impl SurbBalancerConfig {
93    /// Convenience function to convert the [`SurbBalancerConfig`] into [`BalancerControllerBounds`].
94    #[inline]
95    pub fn as_controller_bounds(&self) -> BalancerControllerBounds {
96        BalancerControllerBounds::new(self.target_surb_buffer_size, self.max_surbs_per_sec)
97    }
98}
99
/// Runtime parameters and state of a [`SurbBalancer`], kept in atomics so they can be read and
/// changed through a shared reference.
#[derive(Debug, Default)]
pub struct BalancerStateValues {
    /// Target (set-point) number of SURBs in the buffer. Zero means the balancing is disabled.
    pub target_surb_buffer_size: AtomicU64,
    /// Maximum outflow of SURBs per second.
    pub max_surbs_per_sec: AtomicU64,
    /// Length of the SURB decay window in milliseconds. Zero means no decay.
    pub decay_duration_msec: AtomicU64,
    /// Share of the target buffer size discarded per decay window, in whole percent.
    /// Zero means no decay.
    pub decay_volume_pct: AtomicU8,
    /// Current estimate of the number of SURBs in the buffer.
    pub buffer_level: AtomicU64,
}
109
110impl BalancerStateValues {
111    /// Constructor from a [`SurbBalancerConfig`].
112    pub fn new(cfg: SurbBalancerConfig) -> Self {
113        let state = Self::default();
114        state.update(&cfg);
115        state
116    }
117
118    /// Performs update of the [`BalancerStateValues`] from the [`SurbBalancerConfig`] and
119    /// enables it.
120    pub fn update(&self, cfg: &SurbBalancerConfig) {
121        self.target_surb_buffer_size
122            .store(cfg.target_surb_buffer_size, std::sync::atomic::Ordering::Relaxed);
123        self.max_surbs_per_sec
124            .store(cfg.max_surbs_per_sec, std::sync::atomic::Ordering::Relaxed);
125        self.decay_duration_msec.store(
126            cfg.surb_decay
127                .map(|(d, _)| d.as_millis().min(u64::MAX as u128) as u64)
128                .unwrap_or_default(),
129            std::sync::atomic::Ordering::Relaxed,
130        );
131        self.decay_volume_pct.store(
132            cfg.surb_decay
133                .map(|(_, p)| (p.clamp(0.0, 1.0) * 100.0).round() as u8)
134                .unwrap_or_default(),
135            std::sync::atomic::Ordering::Relaxed,
136        );
137    }
138
139    /// Extracts the [`SurbBalancerConfig`] from the [`BalancerStateValues`].
140    pub fn as_config(&self) -> SurbBalancerConfig {
141        SurbBalancerConfig {
142            target_surb_buffer_size: self.target_surb_buffer_size.load(std::sync::atomic::Ordering::Relaxed),
143            max_surbs_per_sec: self.max_surbs_per_sec.load(std::sync::atomic::Ordering::Relaxed),
144            surb_decay: self.surb_decay(),
145        }
146    }
147
148    /// Checks if SURB balancing is disabled (no target buffer size set).
149    pub fn is_disabled(&self) -> bool {
150        self.target_surb_buffer_size.load(std::sync::atomic::Ordering::Relaxed) == 0
151    }
152
153    /// Extracts the SURB decay configuration from the [`BalancerStateValues`].
154    pub fn surb_decay(&self) -> Option<(Duration, f64)> {
155        Some((
156            self.decay_duration_msec.load(std::sync::atomic::Ordering::Relaxed),
157            self.decay_volume_pct.load(std::sync::atomic::Ordering::Relaxed),
158        ))
159        .filter(|&(d, p)| d > 0 && p > 0)
160        .map(|(d, p)| (Duration::from_millis(d), p as f64 / 100.0))
161    }
162
163    /// Gets the current estimated SURB buffer level.
164    #[inline]
165    pub fn buffer_level(&self) -> u64 {
166        self.buffer_level.load(std::sync::atomic::Ordering::Relaxed)
167    }
168
169    /// Returns the current [`BalancerControllerBounds`] from the [`BalancerStateValues`].
170    #[inline]
171    pub fn controller_bounds(&self) -> BalancerControllerBounds {
172        BalancerControllerBounds::new(
173            self.target_surb_buffer_size.load(std::sync::atomic::Ordering::Relaxed),
174            self.max_surbs_per_sec.load(std::sync::atomic::Ordering::Relaxed),
175        )
176    }
177}
178
179impl From<SurbBalancerConfig> for BalancerStateValues {
180    fn from(cfg: SurbBalancerConfig) -> Self {
181        Self::new(cfg)
182    }
183}
184
185/// Runs a continuous process that attempts to [evaluate](SurbFlowEstimator) and
186/// [regulate](SurbFlowController) the flow of SURBs to the Session counterparty,
187/// to keep the number of SURBs locally or at the counterparty at a certain level.
188///
189/// Internally, the Balancer uses an implementation of [`SurbBalancerController`] to
190/// control the rate of SURBs consumed or sent to the counterparty
191/// each time the [`update`](SurbBalancer::update) method is called:
192///
193/// 1. The size of the SURB buffer at locally or at the counterparty is estimated using [`SurbFlowEstimator`].
194/// 2. Error against a set-point given in [`SurbBalancerConfig`] is evaluated in the `SurbBalancerController`.
195/// 3. The `SurbBalancerController` applies a new SURB flow rate value using the [`SurbFlowController`].
196///
197/// In the local context, the `SurbFlowController` might simply regulate the egress traffic from the
198/// Session, slowing it down to avoid fast SURB drainage.
199///
200/// In the remote context, the `SurbFlowController` might regulate the flow of non-organic SURBs via
201/// Start protocol's [`KeepAlive`](crate::initiation::StartProtocol::KeepAlive) messages to deliver additional
202/// SURBs to the counterparty.
203pub struct SurbBalancer<C, E, F> {
204    session_id: SessionId,
205    controller: C,
206    surb_estimator: E,
207    flow_control: F,
208    state: Arc<BalancerStateValues>,
209    last_estimator_state: SimpleSurbFlowEstimator,
210    last_update: std::time::Instant,
211    last_decay: std::time::Instant,
212    was_below_target: bool,
213}
214
215impl<C, E, F> SurbBalancer<C, E, F>
216where
217    C: SurbBalancerController + Send + Sync + 'static,
218    E: SurbFlowEstimator + Send + Sync + 'static,
219    F: SurbFlowController + Send + Sync + 'static,
220{
221    pub fn new(
222        session_id: SessionId,
223        mut controller: C,
224        surb_estimator: E,
225        flow_control: F,
226        state: Arc<BalancerStateValues>,
227    ) -> Self {
228        #[cfg(all(feature = "prometheus", not(test)))]
229        {
230            let sid = session_id.to_string();
231            METRIC_TARGET_ERROR_ESTIMATE.set(&[&sid], 0.0);
232            METRIC_CONTROL_OUTPUT.set(&[&sid], 0.0);
233        }
234
235        controller.set_target_and_limit(state.controller_bounds());
236
237        Self {
238            surb_estimator,
239            flow_control,
240            controller,
241            session_id,
242            state,
243            last_estimator_state: Default::default(),
244            last_update: std::time::Instant::now(),
245            last_decay: std::time::Instant::now(),
246            was_below_target: true,
247        }
248    }
249
250    /// Computes the next control update and adjusts the [`SurbFlowController`] rate accordingly.
251    #[tracing::instrument(level = "trace", skip_all)]
252    fn update(&mut self) -> u64 {
253        let dt = self.last_update.elapsed();
254
255        // Load the updated current buffer level
256        let mut current = self.state.buffer_level.load(std::sync::atomic::Ordering::Acquire);
257
258        if dt < Duration::from_millis(10) {
259            tracing::debug!("time elapsed since last update is too short, skipping update");
260            return current;
261        }
262
263        self.last_update = std::time::Instant::now();
264
265        // Take a snapshot of the active SURB estimator and calculate the balance change
266        let snapshot = SimpleSurbFlowEstimator::from(&self.surb_estimator);
267        let Some(target_buffer_change) = snapshot.estimated_surb_buffer_change(&self.last_estimator_state) else {
268            tracing::error!("non-monotonic change in SURB estimators");
269            return current;
270        };
271
272        self.last_estimator_state = snapshot;
273        current = current.saturating_add_signed(target_buffer_change);
274
275        // If SURB decaying is enabled, check if the decay window has elapsed
276        // and calculate the number of SURBs that will be discarded
277        if let Some(num_decayed_surbs) = self
278            .state
279            .surb_decay()
280            .filter(|(decay_window, _)| &self.last_decay.elapsed() >= decay_window)
281            .map(|(_, decay_coeff)| (self.controller.bounds().target() as f64 * decay_coeff).round() as u64)
282        {
283            current = current.saturating_sub(num_decayed_surbs);
284            self.last_decay = std::time::Instant::now();
285            tracing::trace!(num_decayed_surbs, "SURBs were discarded due to automatic decay");
286        }
287
288        self.state
289            .buffer_level
290            .store(current, std::sync::atomic::Ordering::Release);
291
292        // Error from the desired target SURB buffer size at counterparty
293        let error = current as i64 - self.controller.bounds().target() as i64;
294
295        if self.was_below_target && error >= 0 {
296            tracing::trace!(current, "reached target SURB buffer size");
297            self.was_below_target = false;
298        } else if !self.was_below_target && error < 0 {
299            tracing::trace!(current, "SURB buffer size is below target");
300            self.was_below_target = true;
301        }
302
303        tracing::trace!(
304            ?dt,
305            delta = target_buffer_change,
306            rate = target_buffer_change as f64 / dt.as_secs_f64(),
307            current,
308            error,
309            "estimated SURB buffer change"
310        );
311
312        let output = self.controller.next_control_output(current);
313        tracing::trace!(output, "next balancer control output for session");
314
315        self.flow_control.adjust_surb_flow(output as usize);
316
317        #[cfg(all(feature = "prometheus", not(test)))]
318        {
319            let sid = self.session_id.to_string();
320            METRIC_CURRENT_BUFFER.set(&[&sid], current as f64);
321            METRIC_CURRENT_TARGET.set(&[&sid], self.controller.bounds().target() as f64);
322            METRIC_TARGET_ERROR_ESTIMATE.set(&[&sid], error as f64);
323            METRIC_CONTROL_OUTPUT.set(&[&sid], output as f64);
324            METRIC_SURB_RATE.set(&[&sid], target_buffer_change as f64 / dt.as_secs_f64());
325        }
326
327        current
328    }
329
330    /// Spawns a new task that performs updates of the given [`SurbBalancer`] at the given `sampling_interval`.
331    ///
332    /// If `cfg_feedback` is given, [`SurbBalancerConfig`] can be queried for updates and also updated
333    /// if the underlying [`SurbBalancerController`] also does target updates.
334    ///
335    /// Returns a stream of current estimated buffer levels, and also an `AbortHandle`
336    /// to terminate the loop. If `abort_reg` was given, the returned `AbortHandle` corresponds
337    /// to it.
338    #[instrument(level = "debug", skip(self), fields(session_id = %self.session_id))]
339    pub fn start_control_loop(
340        mut self,
341        sampling_interval: Duration,
342    ) -> (impl futures::Stream<Item = u64>, AbortHandle) {
343        let (abort_handle, abort_reg) = AbortHandle::new_pair();
344
345        // Start an interval stream at which the balancer will sample and perform updates
346        let sampling_stream = futures::stream::Abortable::new(
347            futures_time::stream::interval(sampling_interval.max(MIN_BALANCER_SAMPLING_INTERVAL).into()),
348            abort_reg,
349        );
350
351        let balancer_level_capacity = std::env::var("HOPR_INTERNAL_SESSION_BALANCER_LEVEL_CAPACITY")
352            .ok()
353            .and_then(|s| s.trim().parse::<usize>().ok())
354            .filter(|&c| c > 0)
355            .unwrap_or(32_768);
356
357        tracing::debug!(
358            capacity = balancer_level_capacity,
359            "Creating session balancer level channel"
360        );
361        let (mut level_tx, level_rx) = futures::channel::mpsc::channel(balancer_level_capacity);
362        hopr_async_runtime::prelude::spawn(
363            async move {
364                pin_mut!(sampling_stream);
365                while sampling_stream.next().await.is_some() {
366                    // Check if the balancer controller needs to be reconfigured
367                    let current_bounds = self.state.controller_bounds();
368                    if current_bounds != self.controller.bounds() {
369                        self.controller.set_target_and_limit(current_bounds);
370                        tracing::debug!(new_cfg = ?self.state.as_config(), "surb balancer has been reconfigured");
371                    }
372
373                    // Perform controller update (this internally samples the SurbFlowEstimator)
374                    // and send an update about the current level to the outgoing stream.
375                    // If the other party has closed the stream, we don't care about the update.
376                    let level = self.update();
377                    if !level_tx.is_closed()
378                        && let Err(error) = level_tx.try_send(level)
379                    {
380                        tracing::error!(%error, "cannot send balancer level update");
381                    }
382                }
383
384                tracing::debug!("balancer done");
385            }
386            .in_current_span(),
387        );
388
389        (level_rx, abort_handle)
390    }
391}
392
#[cfg(test)]
mod tests {
    use std::sync::{Arc, atomic::AtomicU64};

    use hopr_crypto_random::Randomizable;
    use hopr_internal_types::prelude::HoprPseudonym;

    use super::*;
    use crate::balancer::{AtomicSurbFlowEstimator, MockSurbFlowController, pid::PidBalancerController};

    /// The default configuration must survive the conversion to atomics and back.
    #[test]
    fn surb_balancer_config_should_be_convertible_to_atomics() {
        let cfg = SurbBalancerConfig::default();
        let state_data = BalancerStateValues::new(cfg);
        assert_eq!(cfg, state_data.as_config());
    }

    /// Starting from an empty buffer, the level must grow with every update after the first one.
    #[test_log::test]
    fn surb_balancer_should_start_increase_level_when_below_target() {
        let production_rate = Arc::new(AtomicU64::new(0));
        let consumption_rate = 100;
        let steps = 3;
        let step_duration = std::time::Duration::from_millis(1000);

        // The mocked flow controller feeds the control output back as the production rate
        let mut controller = MockSurbFlowController::new();
        let production_rate_clone = production_rate.clone();
        controller
            .expect_adjust_surb_flow()
            .times(steps)
            .with(mockall::predicate::ge(100))
            .returning(move |r| {
                production_rate_clone.store(r as u64, std::sync::atomic::Ordering::Relaxed);
            });

        let surb_estimator = AtomicSurbFlowEstimator::default();
        let mut balancer = SurbBalancer::new(
            SessionId::new(1234_u64, HoprPseudonym::random()),
            PidBalancerController::default(),
            surb_estimator.clone(),
            controller,
            Arc::new(
                SurbBalancerConfig {
                    target_surb_buffer_size: 5_000,
                    max_surbs_per_sec: 2500,
                    surb_decay: None,
                }
                .into(),
            ),
        );

        let mut last_update = 0;
        for i in 0..steps {
            // Simulate one step of SURB production and consumption at the current rates
            std::thread::sleep(step_duration);
            surb_estimator.produced.fetch_add(
                production_rate.load(std::sync::atomic::Ordering::Relaxed) * step_duration.as_secs(),
                std::sync::atomic::Ordering::Relaxed,
            );
            surb_estimator.consumed.fetch_add(
                consumption_rate * step_duration.as_secs(),
                std::sync::atomic::Ordering::Relaxed,
            );

            let next_update = balancer.update();
            assert!(
                i == 0 || next_update > last_update,
                "{next_update} should be greater than {last_update}"
            );
            last_update = next_update;
        }
    }

    /// Starting with a production rate far above the consumption, the level must shrink
    /// with every update after the first one.
    #[test_log::test]
    fn surb_balancer_should_start_decrease_level_when_above_target() {
        let production_rate = Arc::new(AtomicU64::new(11_000));
        let consumption_rate = 100;
        let steps = 3;
        let step_duration = std::time::Duration::from_millis(1000);

        // The mocked flow controller feeds the control output back as the production rate
        let mut controller = MockSurbFlowController::new();
        let production_rate_clone = production_rate.clone();
        controller
            .expect_adjust_surb_flow()
            .times(steps)
            .with(mockall::predicate::ge(0))
            .returning(move |r| {
                production_rate_clone.store(r as u64, std::sync::atomic::Ordering::Relaxed);
            });

        let surb_estimator = AtomicSurbFlowEstimator::default();
        let mut balancer = SurbBalancer::new(
            SessionId::new(1234_u64, HoprPseudonym::random()),
            PidBalancerController::default(),
            surb_estimator.clone(),
            controller,
            Arc::new(
                SurbBalancerConfig {
                    surb_decay: None,
                    ..Default::default()
                }
                .into(),
            ),
        );

        let mut last_update = 0;
        for i in 0..steps {
            // Simulate one step of SURB production and consumption at the current rates
            std::thread::sleep(step_duration);
            surb_estimator.produced.fetch_add(
                production_rate.load(std::sync::atomic::Ordering::Relaxed) * step_duration.as_secs(),
                std::sync::atomic::Ordering::Relaxed,
            );
            surb_estimator.consumed.fetch_add(
                consumption_rate * step_duration.as_secs(),
                std::sync::atomic::Ordering::Relaxed,
            );

            let next_update = balancer.update();
            assert!(
                i == 0 || next_update < last_update,
                "{next_update} should be less than {last_update}"
            );
            last_update = next_update;
        }
    }

    /// With no SURB traffic at all, only the decay changes the level: 5% of the 5000 target
    /// (250 SURBs) is expected to be discarded once per 200 ms window, sampled every 100 ms.
    #[test_log::test(tokio::test)]
    async fn surb_balancer_should_start_decrease_level_when_above_target_and_decay_enabled() {
        const NUM_STEPS: usize = 5;
        let session_id = SessionId::new(1234_u64, HoprPseudonym::random());
        let cfg = SurbBalancerConfig {
            target_surb_buffer_size: 5_000,
            max_surbs_per_sec: 2500,
            surb_decay: Some((Duration::from_millis(200), 0.05)),
        };

        let mut mock_flow_ctl = MockSurbFlowController::new();
        mock_flow_ctl
            .expect_adjust_surb_flow()
            .times(NUM_STEPS)
            .returning(|_| ());

        let balancer = SurbBalancer::new(
            session_id,
            PidBalancerController::default(),
            SimpleSurbFlowEstimator::default(),
            mock_flow_ctl,
            Arc::new(cfg.into()),
        );

        // Start exactly at the target level
        balancer
            .state
            .buffer_level
            .store(5000, std::sync::atomic::Ordering::Relaxed);

        let (stream, handle) = balancer.start_control_loop(Duration::from_millis(100));
        let levels = stream.take(NUM_STEPS).collect::<Vec<_>>().await;
        handle.abort();

        assert_eq!(levels, vec![5000, 4750, 4750, 4500, 4500]);
    }
}