hopr_transport_session/balancer/controller.rs

1use std::time::Duration;
2
3use futures::{StreamExt, future::AbortRegistration, pin_mut};
4use hopr_async_runtime::AbortHandle;
5use tracing::{Instrument, debug, error, instrument, trace};
6
7use super::{
8    BalancerControllerBounds, MIN_BALANCER_SAMPLING_INTERVAL, SimpleSurbFlowEstimator, SurbBalancerController,
9    SurbFlowController, SurbFlowEstimator,
10};
11use crate::SessionId;
12
#[cfg(all(feature = "prometheus", not(test)))]
lazy_static::lazy_static! {
    /// Difference between the estimated SURB buffer size and its target, per Session.
    static ref METRIC_TARGET_ERROR_ESTIMATE: hopr_metrics::metrics::MultiGauge = hopr_metrics::metrics::MultiGauge::new(
        "hopr_surb_balancer_target_error_estimate",
        "Target error estimation by the SURB balancer",
        &["session_id"],
    )
    .unwrap();

    /// Last output computed by the balancer controller, per Session.
    static ref METRIC_CONTROL_OUTPUT: hopr_metrics::metrics::MultiGauge = hopr_metrics::metrics::MultiGauge::new(
        "hopr_surb_balancer_control_output",
        "Control output of the SURB balancer",
        &["session_id"],
    )
    .unwrap();

    /// Estimated number of SURBs in the balanced buffer, per Session.
    static ref METRIC_CURRENT_BUFFER: hopr_metrics::metrics::MultiGauge = hopr_metrics::metrics::MultiGauge::new(
        "hopr_surb_balancer_current_buffer_estimate",
        "Estimated number of SURBs in the buffer",
        &["session_id"],
    )
    .unwrap();

    /// Target (setpoint) SURB buffer size currently set at the controller, per Session.
    static ref METRIC_CURRENT_TARGET: hopr_metrics::metrics::MultiGauge = hopr_metrics::metrics::MultiGauge::new(
        "hopr_surb_balancer_current_buffer_target",
        "Current target (setpoint) number of SURBs in the buffer",
        &["session_id"],
    )
    .unwrap();

    /// Estimated change of the SURB buffer per second, per Session.
    static ref METRIC_SURB_RATE: hopr_metrics::metrics::MultiGauge = hopr_metrics::metrics::MultiGauge::new(
        "hopr_surb_balancer_surbs_rate",
        "Estimation of SURB rate per second (positive is buffer surplus, negative is buffer loss)",
        &["session_id"],
    )
    .unwrap();
}
46
47/// Allows updating [`SurbBalancerConfig`] from the [control loop](spawn_surb_balancer_control_loop).
48#[cfg_attr(test, mockall::automock)]
49#[async_trait::async_trait]
50pub trait BalancerConfigFeedback {
51    /// Gets current [`SurbBalancerConfig`] for a Session with `id`.
52    async fn get_config(&self, id: &SessionId) -> crate::errors::Result<SurbBalancerConfig>;
53    /// Updates current [`SurbBalancerConfig`] for a Session with `id`.
54    async fn on_config_update(&self, id: &SessionId, cfg: SurbBalancerConfig) -> crate::errors::Result<()>;
55}
56
57/// Configuration for the `SurbBalancer`.
58#[derive(Clone, Copy, Debug, PartialEq, smart_default::SmartDefault)]
59pub struct SurbBalancerConfig {
60    /// The desired number of SURBs to be always kept as a buffer locally or at the Session counterparty.
61    ///
62    /// The `SurbBalancer` will try to maintain approximately this number of SURBs
63    /// locally or remotely (at the counterparty) at all times.
64    ///
65    /// The local buffer is maintained by [regulating](SurbFlowController) the egress from the Session.
66    /// The remote buffer (at session counterparty) is maintained by regulating the flow of non-organic SURBs via
67    /// [keep-alive messages](crate::initiation::StartProtocol::KeepAlive).
68    ///
69    /// It does not make sense to set this value higher than the [`max_surb_buffer_size`](crate::SessionManagerConfig)
70    /// configuration at the counterparty.
71    ///
72    /// Default is 7000 SURBs.
73    #[default(7_000)]
74    pub target_surb_buffer_size: u64,
75    /// Maximum outflow of SURBs.
76    ///
77    /// - In the context of the local SURB buffer, this is the maximum egress Session traffic.
78    /// - In the context of the remote SURB buffer, this is the maximum egress of keep-alive messages to the
79    ///   counterparty.
80    ///
81    /// The default is 5000 (which is 2500 packets/second currently)
82    #[default(5_000)]
83    pub max_surbs_per_sec: u64,
84
85    /// Sets what percentage of the target buffer size should be discarded at each window.
86    ///
87    /// The [`SurbBalancer`] will discard the given percentage of `target_surb_buffer_size` at each
88    /// window with the given `Duration`.
89    ///
90    /// The default is `(60, 0.05)` (5% of the target buffer size is discarded every 60 seconds).
91    #[default(_code = "Some((Duration::from_secs(60), 0.05))")]
92    pub surb_decay: Option<(Duration, f64)>,
93}
94
95/// Runs a continuous process that attempts to [evaluate](SurbFlowEstimator) and
96/// [regulate](SurbFlowController) the flow of SURBs to the Session counterparty,
97/// to keep the number of SURBs locally or at the counterparty at a certain level.
98///
99/// Internally, the Balancer uses an implementation of [`SurbBalancerController`] to
100/// control the rate of SURBs consumed or sent to the counterparty
101/// each time the [`update`](SurbBalancer::update) method is called:
102///
103/// 1. The size of the SURB buffer at locally or at the counterparty is estimated using [`SurbFlowEstimator`].
104/// 2. Error against a set-point given in [`SurbBalancerConfig`] is evaluated in the `SurbBalancerController`.
105/// 3. The `SurbBalancerController` applies a new SURB flow rate value using the [`SurbFlowController`].
106///
107/// In the local context, the `SurbFlowController` might simply regulate the egress traffic from the
108/// Session, slowing it down to avoid fast SURB drainage.
109///
110/// In the remote context, the `SurbFlowController` might regulate the flow of non-organic SURBs via
111/// Start protocol's [`KeepAlive`](crate::initiation::StartProtocol::KeepAlive) messages to deliver additional
112/// SURBs to the counterparty.
113pub struct SurbBalancer<C, E, F> {
114    session_id: SessionId,
115    controller: C,
116    surb_estimator: E,
117    flow_control: F,
118    current_buffer: u64,
119    last_estimator_state: SimpleSurbFlowEstimator,
120    last_update: std::time::Instant,
121    last_decay: std::time::Instant,
122    was_below_target: bool,
123}
124
125impl<C, E, F> SurbBalancer<C, E, F>
126where
127    C: SurbBalancerController + Send + Sync + 'static,
128    E: SurbFlowEstimator + Send + Sync + 'static,
129    F: SurbFlowController + Send + Sync + 'static,
130{
131    pub fn new(
132        session_id: SessionId,
133        mut controller: C,
134        surb_estimator: E,
135        flow_control: F,
136        cfg: SurbBalancerConfig,
137    ) -> Self {
138        #[cfg(all(feature = "prometheus", not(test)))]
139        {
140            let sid = session_id.to_string();
141            METRIC_TARGET_ERROR_ESTIMATE.set(&[&sid], 0.0);
142            METRIC_CONTROL_OUTPUT.set(&[&sid], 0.0);
143        }
144
145        controller.set_target_and_limit(BalancerControllerBounds::new(
146            cfg.target_surb_buffer_size,
147            cfg.max_surbs_per_sec,
148        ));
149
150        Self {
151            surb_estimator,
152            flow_control,
153            controller,
154            session_id,
155            current_buffer: 0,
156            last_estimator_state: Default::default(),
157            last_update: std::time::Instant::now(),
158            last_decay: std::time::Instant::now(),
159            was_below_target: true,
160        }
161    }
162
163    /// Computes the next control update and adjusts the [`SurbFlowController`] rate accordingly.
164    #[tracing::instrument(level = "trace", skip_all)]
165    fn update(&mut self, surb_decay: Option<&(Duration, f64)>) -> u64 {
166        let dt = self.last_update.elapsed();
167        if dt < Duration::from_millis(10) {
168            debug!("time elapsed since last update is too short, skipping update");
169            return self.current_buffer;
170        }
171
172        self.last_update = std::time::Instant::now();
173
174        // Take a snapshot of the active SURB estimator and calculate the balance change
175        let snapshot = SimpleSurbFlowEstimator::from(&self.surb_estimator);
176        let Some(target_buffer_change) = snapshot.estimated_surb_buffer_change(&self.last_estimator_state) else {
177            error!("non-monotonic change in SURB estimators");
178            return self.current_buffer;
179        };
180
181        self.last_estimator_state = snapshot;
182        self.current_buffer = self.current_buffer.saturating_add_signed(target_buffer_change);
183
184        // If SURB decaying is enabled, check if the decay window has elapsed
185        // and calculate the number of SURBs that will be discarded
186        if let Some(num_decayed_surbs) = surb_decay
187            .as_ref()
188            .filter(|(decay_window, _)| &self.last_decay.elapsed() >= decay_window)
189            .map(|(_, decay_coeff)| (self.controller.bounds().target() as f64 * *decay_coeff).round() as u64)
190        {
191            self.current_buffer = self.current_buffer.saturating_sub(num_decayed_surbs);
192            self.last_decay = std::time::Instant::now();
193            trace!(num_decayed_surbs, "SURBs were discarded due to automatic decay");
194        }
195
196        // Error from the desired target SURB buffer size at counterparty
197        let error = self.current_buffer as i64 - self.controller.bounds().target() as i64;
198
199        if self.was_below_target && error >= 0 {
200            trace!(current = self.current_buffer, "reached target SURB buffer size");
201            self.was_below_target = false;
202        } else if !self.was_below_target && error < 0 {
203            trace!(current = self.current_buffer, "SURB buffer size is below target");
204            self.was_below_target = true;
205        }
206
207        trace!(
208            ?dt,
209            delta = target_buffer_change,
210            rate = target_buffer_change as f64 / dt.as_secs_f64(),
211            current = self.current_buffer,
212            error,
213            "estimated SURB buffer change"
214        );
215
216        let output = self.controller.next_control_output(self.current_buffer);
217        trace!(output, "next balancer control output for session");
218
219        self.flow_control.adjust_surb_flow(output as usize);
220
221        #[cfg(all(feature = "prometheus", not(test)))]
222        {
223            let sid = self.session_id.to_string();
224            METRIC_CURRENT_BUFFER.set(&[&sid], self.current_buffer as f64);
225            METRIC_CURRENT_TARGET.set(&[&sid], self.controller.bounds().target() as f64);
226            METRIC_TARGET_ERROR_ESTIMATE.set(&[&sid], error as f64);
227            METRIC_CONTROL_OUTPUT.set(&[&sid], output as f64);
228            METRIC_SURB_RATE.set(&[&sid], target_buffer_change as f64 / dt.as_secs_f64());
229        }
230
231        self.current_buffer
232    }
233
234    /// Spawns a new task that performs updates of the given [`SurbBalancer`] at the given `sampling_interval`.
235    ///
236    /// If `surb_decay` is given, SURBs are removed at each window as the given percentage of the target buffer.
237    /// If `cfg_feedback` is given, [`SurbBalancerConfig`] can be queried for updates and also updated
238    /// if the underlying [`SurbBalancerController`] also does target updates.
239    ///
240    /// Returns a stream of current estimated buffer levels, and also an `AbortHandle`
241    /// to terminate the loop. If `abort_reg` was given, the returned `AbortHandle` corresponds
242    /// to it.
243    #[instrument(level = "debug", skip_all, fields(session_id = %self.session_id))]
244    pub fn start_control_loop<B>(
245        mut self,
246        sampling_interval: Duration,
247        cfg_feedback: B,
248        abort_reg: Option<AbortRegistration>,
249    ) -> (impl futures::Stream<Item = u64>, AbortHandle)
250    where
251        B: BalancerConfigFeedback + Send + Sync + 'static,
252    {
253        // Get abort handle and registration (or create new ones)
254        let (abort_handle, abort_reg) = abort_reg
255            .map(|reg| (reg.handle(), reg))
256            .unwrap_or_else(AbortHandle::new_pair);
257
258        // Start an interval stream at which the balancer will sample and perform updates
259        let sampling_stream = futures::stream::Abortable::new(
260            futures_time::stream::interval(sampling_interval.max(MIN_BALANCER_SAMPLING_INTERVAL).into()),
261            abort_reg,
262        );
263
264        let (level_tx, level_rx) = futures::channel::mpsc::unbounded();
265        hopr_async_runtime::prelude::spawn(
266            async move {
267                pin_mut!(sampling_stream);
268                while sampling_stream.next().await.is_some() {
269                    let Ok(mut current_cfg) = cfg_feedback.get_config(&self.session_id).await else {
270                        error!("cannot get config for session");
271                        break;
272                    };
273
274                    let current_bounds = BalancerControllerBounds::new(
275                        current_cfg.target_surb_buffer_size,
276                        current_cfg.max_surbs_per_sec,
277                    );
278
279                    // Check if the balancer controller needs to be reconfigured
280                    if current_bounds != self.controller.bounds() {
281                        self.controller.set_target_and_limit(current_bounds);
282                        debug!(?current_cfg, "surb balancer has been reconfigured");
283                    }
284
285                    let bounds_before_update = self.controller.bounds();
286
287                    // Perform controller update (this internally samples the SurbFlowEstimator)
288                    // and send an update about the current level to the outgoing stream
289                    let level = self.update(current_cfg.surb_decay.as_ref());
290                    let _ = level_tx.unbounded_send(level);
291
292                    // See if the setpoint has been updated at the controller as a result
293                    // of the update step, because some controllers (such as the SimpleBalancerController)
294                    // permit that.
295                    let bounds_after_update = self.controller.bounds();
296                    if bounds_before_update != bounds_after_update {
297                        current_cfg.target_surb_buffer_size = bounds_after_update.target();
298                        current_cfg.max_surbs_per_sec = bounds_after_update.output_limit();
299                        match cfg_feedback.on_config_update(&self.session_id, current_cfg).await {
300                            Ok(_) => debug!(
301                                ?bounds_before_update,
302                                ?bounds_after_update,
303                                "controller bounds has changed after update"
304                            ),
305                            Err(error) => error!(%error, "failed to update controller bounds after it changed"),
306                        }
307                    }
308                }
309
310                debug!("balancer done");
311            }
312            .in_current_span(),
313        );
314
315        (level_rx, abort_handle)
316    }
317}
318
#[cfg(test)]
mod tests {
    use std::sync::{Arc, atomic::AtomicU64};

    use hopr_crypto_random::Randomizable;
    use hopr_internal_types::prelude::HoprPseudonym;

    use super::*;
    use crate::balancer::{
        AtomicSurbFlowEstimator, MockSurbFlowController, pid::PidBalancerController, simple::SimpleBalancerController,
    };

    #[test_log::test]
    fn surb_balancer_should_start_increase_level_when_below_target() {
        let production_rate = Arc::new(AtomicU64::new(0));
        let consumption_rate = 100;
        let steps = 3;
        let step_duration = std::time::Duration::from_millis(1000);

        let mut controller = MockSurbFlowController::new();
        let production_rate_clone = production_rate.clone();
        controller
            .expect_adjust_surb_flow()
            .times(steps)
            .with(mockall::predicate::ge(100))
            .returning(move |r| {
                production_rate_clone.store(r as u64, std::sync::atomic::Ordering::Relaxed);
            });

        let surb_estimator = AtomicSurbFlowEstimator::default();
        let mut balancer = SurbBalancer::new(
            SessionId::new(1234_u64, HoprPseudonym::random()),
            PidBalancerController::default(),
            surb_estimator.clone(),
            controller,
            SurbBalancerConfig {
                target_surb_buffer_size: 5_000,
                max_surbs_per_sec: 2500,
                ..Default::default()
            },
        );

        let mut last_update = 0;
        for i in 0..steps {
            std::thread::sleep(step_duration);
            surb_estimator.produced.fetch_add(
                production_rate.load(std::sync::atomic::Ordering::Relaxed) * step_duration.as_secs(),
                std::sync::atomic::Ordering::Relaxed,
            );
            surb_estimator.consumed.fetch_add(
                consumption_rate * step_duration.as_secs(),
                std::sync::atomic::Ordering::Relaxed,
            );

            let next_update = balancer.update(None);
            assert!(
                i == 0 || next_update > last_update,
                "{next_update} should be greater than {last_update}"
            );
            last_update = next_update;
        }
    }

    #[test_log::test]
    fn surb_balancer_should_start_decrease_level_when_above_target() {
        let production_rate = Arc::new(AtomicU64::new(11_000));
        let consumption_rate = 100;
        let steps = 3;
        let step_duration = std::time::Duration::from_millis(1000);

        let mut controller = MockSurbFlowController::new();
        let production_rate_clone = production_rate.clone();
        controller
            .expect_adjust_surb_flow()
            .times(steps)
            .with(mockall::predicate::ge(0))
            .returning(move |r| {
                production_rate_clone.store(r as u64, std::sync::atomic::Ordering::Relaxed);
            });

        let surb_estimator = AtomicSurbFlowEstimator::default();
        let mut balancer = SurbBalancer::new(
            SessionId::new(1234_u64, HoprPseudonym::random()),
            PidBalancerController::default(),
            surb_estimator.clone(),
            controller,
            SurbBalancerConfig::default(),
        );

        let mut last_update = 0;
        for i in 0..steps {
            std::thread::sleep(step_duration);
            surb_estimator.produced.fetch_add(
                production_rate.load(std::sync::atomic::Ordering::Relaxed) * step_duration.as_secs(),
                std::sync::atomic::Ordering::Relaxed,
            );
            surb_estimator.consumed.fetch_add(
                consumption_rate * step_duration.as_secs(),
                std::sync::atomic::Ordering::Relaxed,
            );

            let next_update = balancer.update(None);
            assert!(
                i == 0 || next_update < last_update,
                "{next_update} should be less than {last_update}"
            );
            last_update = next_update;
        }
    }

    #[test_log::test(tokio::test)]
    async fn surb_balancer_should_start_decrease_level_when_above_target_and_decay_enabled() {
        const NUM_STEPS: usize = 5;
        let session_id = SessionId::new(1234_u64, HoprPseudonym::random());
        let cfg = SurbBalancerConfig {
            target_surb_buffer_size: 5_000,
            max_surbs_per_sec: 2500,
            surb_decay: Some((Duration::from_millis(200), 0.05)),
            ..Default::default()
        };

        let mut mock_balancer_feedback = MockBalancerConfigFeedback::new();
        mock_balancer_feedback
            .expect_get_config()
            .with(mockall::predicate::eq(session_id))
            .times(NUM_STEPS)
            .returning(move |_| Ok(cfg));

        mock_balancer_feedback.expect_on_config_update().never();

        let mut mock_flow_ctl = MockSurbFlowController::new();
        mock_flow_ctl
            .expect_adjust_surb_flow()
            .times(NUM_STEPS)
            .returning(|_| ());

        let mut balancer = SurbBalancer::new(
            session_id,
            PidBalancerController::default(),
            SimpleSurbFlowEstimator::default(),
            mock_flow_ctl,
            cfg,
        );

        balancer.current_buffer = 5000;

        let (stream, handle) = balancer.start_control_loop(Duration::from_millis(100), mock_balancer_feedback, None);
        let levels = stream.take(NUM_STEPS).collect::<Vec<_>>().await;
        handle.abort();

        assert_eq!(levels, vec![5000, 4750, 4750, 4500, 4500]);
    }

    /// Test estimator that yields the next value of the given sequences on each query
    /// (or 0 once a sequence is exhausted).
    struct IterSurbFlowEstimator<P, C>(std::sync::Mutex<P>, std::sync::Mutex<C>);

    impl<P, C> IterSurbFlowEstimator<P, C> {
        fn new<I1, I2>(production: I1, consumption: I2) -> Self
        where
            I1: IntoIterator<Item = u64, IntoIter = P>,
            I2: IntoIterator<Item = u64, IntoIter = C>,
        {
            Self(
                std::sync::Mutex::new(production.into_iter()),
                std::sync::Mutex::new(consumption.into_iter()),
            )
        }
    }

    impl<P, C> SurbFlowEstimator for IterSurbFlowEstimator<P, C>
    where
        P: Iterator<Item = u64>,
        C: Iterator<Item = u64>,
    {
        fn estimate_surbs_consumed(&self) -> u64 {
            self.1.lock().ok().and_then(|mut it| it.next()).unwrap_or(0)
        }

        fn estimate_surbs_produced(&self) -> u64 {
            self.0.lock().ok().and_then(|mut it| it.next()).unwrap_or(0)
        }
    }

    #[test_log::test(tokio::test)]
    async fn surb_balancer_should_increase_target_when_using_simple_controller() {
        let session_id = SessionId::new(1234_u64, HoprPseudonym::random());
        let cfg_1 = SurbBalancerConfig {
            target_surb_buffer_size: 4500,
            max_surbs_per_sec: 2500,
            surb_decay: None,
            ..Default::default()
        };

        let cfg_2 = SurbBalancerConfig {
            target_surb_buffer_size: 5500,
            max_surbs_per_sec: 3055,
            surb_decay: None,
            ..Default::default()
        };

        let mut seq = mockall::Sequence::new();

        let mut mock_balancer_feedback = MockBalancerConfigFeedback::new();
        mock_balancer_feedback
            .expect_get_config()
            .times(3)
            .in_sequence(&mut seq)
            .with(mockall::predicate::eq(session_id))
            .returning(move |_| {
                tracing::trace!("get config 1");
                Ok(cfg_1)
            });

        mock_balancer_feedback
            .expect_on_config_update()
            .once()
            .in_sequence(&mut seq)
            .with(mockall::predicate::eq(session_id), mockall::predicate::eq(cfg_2))
            .returning(|_, _| {
                tracing::trace!("on config update");
                Ok(())
            });

        mock_balancer_feedback
            .expect_get_config()
            .times(2)
            .in_sequence(&mut seq)
            .with(mockall::predicate::eq(session_id))
            .returning(move |_| {
                tracing::trace!("get config 2");
                Ok(cfg_2)
            });

        let mut mock_flow_ctl = MockSurbFlowController::new();
        mock_flow_ctl.expect_adjust_surb_flow().times(5).returning(|_| ());

        let mut balancer = SurbBalancer::new(
            session_id,
            SimpleBalancerController::with_increasing_setpoint(0.2, 5),
            IterSurbFlowEstimator::new([1000, 2000, 3000, 3000, 3000], vec![500, 1000, 1500, 1500, 1500]),
            mock_flow_ctl,
            cfg_1,
        );

        balancer.current_buffer = 4500;

        let (stream, handle) = balancer.start_control_loop(Duration::from_millis(100), mock_balancer_feedback, None);

        let levels = stream.take(5).collect::<Vec<_>>().await;
        handle.abort();

        assert_eq!(levels, vec![5000, 5500, 6000, 6000, 6000]);
    }
}