hopr_transport_session/balancer/controller.rs

1use std::time::Duration;
2
3use futures::{StreamExt, future::AbortRegistration, pin_mut};
4use hopr_async_runtime::AbortHandle;
5use tracing::{Instrument, debug, error, instrument, trace};
6
7use super::{
8    BalancerControllerBounds, MIN_BALANCER_SAMPLING_INTERVAL, SimpleSurbFlowEstimator, SurbBalancerController,
9    SurbFlowController, SurbFlowEstimator,
10};
11use crate::SessionId;
12
// Per-session gauges exported by the SURB balancer, all labelled by `session_id`.
// They exist only when the `prometheus` feature is enabled and never in unit tests.
#[cfg(all(feature = "prometheus", not(test)))]
lazy_static::lazy_static! {
    // Running estimate of the SURB buffer level.
    static ref METRIC_CURRENT_BUFFER: hopr_metrics::MultiGauge = hopr_metrics::MultiGauge::new(
        "hopr_surb_balancer_current_buffer_estimate",
        "Estimated number of SURBs in the buffer",
        &["session_id"],
    )
    .unwrap();

    // Set-point the balancer is currently steering towards.
    static ref METRIC_CURRENT_TARGET: hopr_metrics::MultiGauge = hopr_metrics::MultiGauge::new(
        "hopr_surb_balancer_current_buffer_target",
        "Current target (setpoint) number of SURBs in the buffer",
        &["session_id"],
    )
    .unwrap();

    // Signed difference between the buffer estimate and the set-point.
    static ref METRIC_TARGET_ERROR_ESTIMATE: hopr_metrics::MultiGauge = hopr_metrics::MultiGauge::new(
        "hopr_surb_balancer_target_error_estimate",
        "Target error estimation by the SURB balancer",
        &["session_id"],
    )
    .unwrap();

    // Last output produced by the balancer controller.
    static ref METRIC_CONTROL_OUTPUT: hopr_metrics::MultiGauge = hopr_metrics::MultiGauge::new(
        "hopr_surb_balancer_control_output",
        "Control output of the SURB balancer",
        &["session_id"],
    )
    .unwrap();

    // Buffer change per second observed between two consecutive updates.
    static ref METRIC_SURB_RATE: hopr_metrics::MultiGauge = hopr_metrics::MultiGauge::new(
        "hopr_surb_balancer_surbs_rate",
        "Estimation of SURB rate per second (positive is buffer surplus, negative is buffer loss)",
        &["session_id"],
    )
    .unwrap();
}
46
47/// Allows updating [`SurbBalancerConfig`] from the [control loop](spawn_surb_balancer_control_loop).
48#[cfg_attr(test, mockall::automock)]
49#[async_trait::async_trait]
50pub trait BalancerConfigFeedback {
51    /// Gets current [`SurbBalancerConfig`] for a Session with `id`.
52    async fn get_config(&self, id: &SessionId) -> crate::errors::Result<SurbBalancerConfig>;
53    /// Updates current [`SurbBalancerConfig`] for a Session with `id`.
54    async fn on_config_update(&self, id: &SessionId, cfg: SurbBalancerConfig) -> crate::errors::Result<()>;
55}
56
57/// Configuration for the `SurbBalancer`.
58#[derive(Clone, Copy, Debug, PartialEq, smart_default::SmartDefault)]
59pub struct SurbBalancerConfig {
60    /// The desired number of SURBs to be always kept as a buffer locally or at the Session counterparty.
61    ///
62    /// The `SurbBalancer` will try to maintain approximately this number of SURBs
63    /// locally or remotely (at the counterparty) at all times.
64    ///
65    /// The local buffer is maintained by [regulating](SurbFlowController) the egress from the Session.
66    /// The remote buffer (at session counterparty) is maintained by regulating the flow of non-organic SURBs via
67    /// [keep-alive messages](crate::initiation::StartProtocol::KeepAlive).
68    ///
69    /// It does not make sense to set this value higher than the [`max_surb_buffer_size`](crate::SessionManagerConfig)
70    /// configuration at the counterparty.
71    ///
72    /// Default is 7000 SURBs.
73    #[default(7_000)]
74    pub target_surb_buffer_size: u64,
75    /// Maximum outflow of SURBs.
76    ///
77    /// - In the context of the local SURB buffer, this is the maximum egress Session traffic (= SURB consumption).
78    /// - In the context of the remote SURB buffer, this is the maximum egress of keep-alive messages to the
79    ///   counterparty (= artificial SURB production).
80    ///
81    /// The default is 5000 (which is 2500 packets/second currently)
82    #[default(5_000)]
83    pub max_surbs_per_sec: u64,
84
85    /// Sets what percentage of the target buffer size should be discarded at each window.
86    ///
87    /// The [`SurbBalancer`] will discard the given percentage of `target_surb_buffer_size` at each
88    /// window with the given `Duration`.
89    ///
90    /// The default is `(60, 0.05)` (5% of the target buffer size is discarded every 60 seconds).
91    #[default(_code = "Some((Duration::from_secs(60), 0.05))")]
92    pub surb_decay: Option<(Duration, f64)>,
93}
94
95/// Runs a continuous process that attempts to [evaluate](SurbFlowEstimator) and
96/// [regulate](SurbFlowController) the flow of SURBs to the Session counterparty,
97/// to keep the number of SURBs locally or at the counterparty at a certain level.
98///
99/// Internally, the Balancer uses an implementation of [`SurbBalancerController`] to
100/// control the rate of SURBs consumed or sent to the counterparty
101/// each time the [`update`](SurbBalancer::update) method is called:
102///
103/// 1. The size of the SURB buffer at locally or at the counterparty is estimated using [`SurbFlowEstimator`].
104/// 2. Error against a set-point given in [`SurbBalancerConfig`] is evaluated in the `SurbBalancerController`.
105/// 3. The `SurbBalancerController` applies a new SURB flow rate value using the [`SurbFlowController`].
106///
107/// In the local context, the `SurbFlowController` might simply regulate the egress traffic from the
108/// Session, slowing it down to avoid fast SURB drainage.
109///
110/// In the remote context, the `SurbFlowController` might regulate the flow of non-organic SURBs via
111/// Start protocol's [`KeepAlive`](crate::initiation::StartProtocol::KeepAlive) messages to deliver additional
112/// SURBs to the counterparty.
113pub struct SurbBalancer<C, E, F> {
114    session_id: SessionId,
115    controller: C,
116    surb_estimator: E,
117    flow_control: F,
118    current_buffer: u64,
119    last_estimator_state: SimpleSurbFlowEstimator,
120    last_update: std::time::Instant,
121    last_decay: std::time::Instant,
122    was_below_target: bool,
123}
124
125impl<C, E, F> SurbBalancer<C, E, F>
126where
127    C: SurbBalancerController + Send + Sync + 'static,
128    E: SurbFlowEstimator + Send + Sync + 'static,
129    F: SurbFlowController + Send + Sync + 'static,
130{
131    pub fn new(
132        session_id: SessionId,
133        mut controller: C,
134        surb_estimator: E,
135        flow_control: F,
136        cfg: SurbBalancerConfig,
137    ) -> Self {
138        #[cfg(all(feature = "prometheus", not(test)))]
139        {
140            let sid = session_id.to_string();
141            METRIC_TARGET_ERROR_ESTIMATE.set(&[&sid], 0.0);
142            METRIC_CONTROL_OUTPUT.set(&[&sid], 0.0);
143        }
144
145        controller.set_target_and_limit(BalancerControllerBounds::new(
146            cfg.target_surb_buffer_size,
147            cfg.max_surbs_per_sec,
148        ));
149
150        Self {
151            surb_estimator,
152            flow_control,
153            controller,
154            session_id,
155            current_buffer: 0,
156            last_estimator_state: Default::default(),
157            last_update: std::time::Instant::now(),
158            last_decay: std::time::Instant::now(),
159            was_below_target: true,
160        }
161    }
162
163    /// Computes the next control update and adjusts the [`SurbFlowController`] rate accordingly.
164    #[tracing::instrument(level = "trace", skip_all)]
165    fn update(&mut self, surb_decay: Option<&(Duration, f64)>) -> u64 {
166        let dt = self.last_update.elapsed();
167        if dt < Duration::from_millis(10) {
168            debug!("time elapsed since last update is too short, skipping update");
169            return self.current_buffer;
170        }
171
172        self.last_update = std::time::Instant::now();
173
174        // Take a snapshot of the active SURB estimator and calculate the balance change
175        let snapshot = SimpleSurbFlowEstimator::from(&self.surb_estimator);
176        let Some(target_buffer_change) = snapshot.estimated_surb_buffer_change(&self.last_estimator_state) else {
177            error!("non-monotonic change in SURB estimators");
178            return self.current_buffer;
179        };
180
181        self.last_estimator_state = snapshot;
182        self.current_buffer = self.current_buffer.saturating_add_signed(target_buffer_change);
183
184        // If SURB decaying is enabled, check if the decay window has elapsed
185        // and calculate the number of SURBs that will be discarded
186        if let Some(num_decayed_surbs) = surb_decay
187            .as_ref()
188            .filter(|(decay_window, _)| &self.last_decay.elapsed() >= decay_window)
189            .map(|(_, decay_coeff)| (self.controller.bounds().target() as f64 * *decay_coeff).round() as u64)
190        {
191            self.current_buffer = self.current_buffer.saturating_sub(num_decayed_surbs);
192            self.last_decay = std::time::Instant::now();
193            trace!(num_decayed_surbs, "SURBs were discarded due to automatic decay");
194        }
195
196        // Error from the desired target SURB buffer size at counterparty
197        let error = self.current_buffer as i64 - self.controller.bounds().target() as i64;
198
199        if self.was_below_target && error >= 0 {
200            trace!(current = self.current_buffer, "reached target SURB buffer size");
201            self.was_below_target = false;
202        } else if !self.was_below_target && error < 0 {
203            trace!(current = self.current_buffer, "SURB buffer size is below target");
204            self.was_below_target = true;
205        }
206
207        trace!(
208            ?dt,
209            delta = target_buffer_change,
210            rate = target_buffer_change as f64 / dt.as_secs_f64(),
211            current = self.current_buffer,
212            error,
213            "estimated SURB buffer change"
214        );
215
216        let output = self.controller.next_control_output(self.current_buffer);
217        trace!(output, "next balancer control output for session");
218
219        self.flow_control.adjust_surb_flow(output as usize);
220
221        #[cfg(all(feature = "prometheus", not(test)))]
222        {
223            let sid = self.session_id.to_string();
224            METRIC_CURRENT_BUFFER.set(&[&sid], self.current_buffer as f64);
225            METRIC_CURRENT_TARGET.set(&[&sid], self.controller.bounds().target() as f64);
226            METRIC_TARGET_ERROR_ESTIMATE.set(&[&sid], error as f64);
227            METRIC_CONTROL_OUTPUT.set(&[&sid], output as f64);
228            METRIC_SURB_RATE.set(&[&sid], target_buffer_change as f64 / dt.as_secs_f64());
229        }
230
231        self.current_buffer
232    }
233
234    /// Spawns a new task that performs updates of the given [`SurbBalancer`] at the given `sampling_interval`.
235    ///
236    /// If `cfg_feedback` is given, [`SurbBalancerConfig`] can be queried for updates and also updated
237    /// if the underlying [`SurbBalancerController`] also does target updates.
238    ///
239    /// Returns a stream of current estimated buffer levels, and also an `AbortHandle`
240    /// to terminate the loop. If `abort_reg` was given, the returned `AbortHandle` corresponds
241    /// to it.
242    #[instrument(level = "debug", skip_all, fields(session_id = %self.session_id))]
243    pub fn start_control_loop<B>(
244        mut self,
245        sampling_interval: Duration,
246        cfg_feedback: B,
247        abort_reg: Option<AbortRegistration>,
248    ) -> (impl futures::Stream<Item = u64>, AbortHandle)
249    where
250        B: BalancerConfigFeedback + Send + Sync + 'static,
251    {
252        // Get abort handle and registration (or create new ones)
253        let (abort_handle, abort_reg) = abort_reg
254            .map(|reg| (reg.handle(), reg))
255            .unwrap_or_else(AbortHandle::new_pair);
256
257        // Start an interval stream at which the balancer will sample and perform updates
258        let sampling_stream = futures::stream::Abortable::new(
259            futures_time::stream::interval(sampling_interval.max(MIN_BALANCER_SAMPLING_INTERVAL).into()),
260            abort_reg,
261        );
262
263        let balancer_level_capacity = std::env::var("HOPR_INTERNAL_SESSION_BALANCER_LEVEL_CAPACITY")
264            .ok()
265            .and_then(|s| s.trim().parse::<usize>().ok())
266            .filter(|&c| c > 0)
267            .unwrap_or(32_768);
268
269        tracing::debug!(
270            capacity = balancer_level_capacity,
271            "Creating session balancer level channel"
272        );
273        let (mut level_tx, level_rx) = futures::channel::mpsc::channel(balancer_level_capacity);
274        hopr_async_runtime::prelude::spawn(
275            async move {
276                pin_mut!(sampling_stream);
277                while sampling_stream.next().await.is_some() {
278                    // This call should not update the popularity estimator of the cache,
279                    // so that it is still allowed to expire.
280                    let Ok(mut current_cfg) = cfg_feedback.get_config(&self.session_id).await else {
281                        error!("cannot get config for session");
282                        break;
283                    };
284
285                    let current_bounds = BalancerControllerBounds::new(
286                        current_cfg.target_surb_buffer_size,
287                        current_cfg.max_surbs_per_sec,
288                    );
289
290                    // Check if the balancer controller needs to be reconfigured
291                    if current_bounds != self.controller.bounds() {
292                        self.controller.set_target_and_limit(current_bounds);
293                        debug!(?current_cfg, "surb balancer has been reconfigured");
294                    }
295
296                    let bounds_before_update = self.controller.bounds();
297
298                    // Perform controller update (this internally samples the SurbFlowEstimator)
299                    // and send an update about the current level to the outgoing stream.
300                    // If the other party has closed the stream, we don't care about the update.
301                    let level = self.update(current_cfg.surb_decay.as_ref());
302                    if !level_tx.is_closed() {
303                        if let Err(error) = level_tx.try_send(level) {
304                            tracing::error!(%error, "cannot send balancer level update");
305                        }
306                    }
307
308                    // See if the setpoint has been updated at the controller as a result
309                    // of the update step, because some controllers (such as the SimpleBalancerController)
310                    // permit that.
311                    let bounds_after_update = self.controller.bounds();
312                    if bounds_before_update != bounds_after_update {
313                        current_cfg.target_surb_buffer_size = bounds_after_update.target();
314                        current_cfg.max_surbs_per_sec = bounds_after_update.output_limit();
315                        match cfg_feedback.on_config_update(&self.session_id, current_cfg).await {
316                            Ok(_) => debug!(
317                                ?bounds_before_update,
318                                ?bounds_after_update,
319                                "controller bounds has changed after update"
320                            ),
321                            Err(error) => error!(%error, "failed to update controller bounds after it changed"),
322                        }
323                    }
324                }
325
326                debug!("balancer done");
327            }
328            .in_current_span(),
329        );
330
331        (level_rx, abort_handle)
332    }
333}
334
#[cfg(test)]
mod tests {
    use std::sync::{Arc, atomic::AtomicU64};

    use hopr_crypto_random::Randomizable;
    use hopr_internal_types::prelude::HoprPseudonym;

    use super::*;
    use crate::balancer::{
        AtomicSurbFlowEstimator, MockSurbFlowController, pid::PidBalancerController, simple::SimpleBalancerController,
    };

    #[test_log::test]
    fn surb_balancer_should_start_increase_level_when_below_target() {
        let production_rate = Arc::new(AtomicU64::new(0));
        let consumption_rate = 100;
        let steps = 3;
        let step_duration = std::time::Duration::from_millis(1000);

        let mut controller = MockSurbFlowController::new();
        let production_rate_clone = production_rate.clone();
        controller
            .expect_adjust_surb_flow()
            .times(steps)
            .with(mockall::predicate::ge(100))
            .returning(move |r| {
                production_rate_clone.store(r as u64, std::sync::atomic::Ordering::Relaxed);
            });

        let surb_estimator = AtomicSurbFlowEstimator::default();
        let mut balancer = SurbBalancer::new(
            SessionId::new(1234_u64, HoprPseudonym::random()),
            PidBalancerController::default(),
            surb_estimator.clone(),
            controller,
            SurbBalancerConfig {
                target_surb_buffer_size: 5_000,
                max_surbs_per_sec: 2500,
                ..Default::default()
            },
        );

        let mut last_update = 0;
        for i in 0..steps {
            std::thread::sleep(step_duration);
            surb_estimator.produced.fetch_add(
                production_rate.load(std::sync::atomic::Ordering::Relaxed) * step_duration.as_secs(),
                std::sync::atomic::Ordering::Relaxed,
            );
            surb_estimator.consumed.fetch_add(
                consumption_rate * step_duration.as_secs(),
                std::sync::atomic::Ordering::Relaxed,
            );

            let next_update = balancer.update(None);
            assert!(
                i == 0 || next_update > last_update,
                "{next_update} should be greater than {last_update}"
            );
            last_update = next_update;
        }
    }

    #[test_log::test]
    fn surb_balancer_should_start_decrease_level_when_above_target() {
        let production_rate = Arc::new(AtomicU64::new(11_000));
        let consumption_rate = 100;
        let steps = 3;
        let step_duration = std::time::Duration::from_millis(1000);

        let mut controller = MockSurbFlowController::new();
        let production_rate_clone = production_rate.clone();
        controller
            .expect_adjust_surb_flow()
            .times(steps)
            .with(mockall::predicate::ge(0))
            .returning(move |r| {
                production_rate_clone.store(r as u64, std::sync::atomic::Ordering::Relaxed);
            });

        let surb_estimator = AtomicSurbFlowEstimator::default();
        let mut balancer = SurbBalancer::new(
            SessionId::new(1234_u64, HoprPseudonym::random()),
            PidBalancerController::default(),
            surb_estimator.clone(),
            controller,
            SurbBalancerConfig::default(),
        );

        let mut last_update = 0;
        for i in 0..steps {
            std::thread::sleep(step_duration);
            surb_estimator.produced.fetch_add(
                production_rate.load(std::sync::atomic::Ordering::Relaxed) * step_duration.as_secs(),
                std::sync::atomic::Ordering::Relaxed,
            );
            surb_estimator.consumed.fetch_add(
                consumption_rate * step_duration.as_secs(),
                std::sync::atomic::Ordering::Relaxed,
            );

            let next_update = balancer.update(None);
            // The level must strictly decrease after the first step
            assert!(
                i == 0 || next_update < last_update,
                "{next_update} should be less than {last_update}"
            );
            last_update = next_update;
        }
    }

    #[test_log::test(tokio::test)]
    async fn surb_balancer_should_start_decrease_level_when_above_target_and_decay_enabled() {
        const NUM_STEPS: usize = 5;
        let session_id = SessionId::new(1234_u64, HoprPseudonym::random());
        let cfg = SurbBalancerConfig {
            target_surb_buffer_size: 5_000,
            max_surbs_per_sec: 2500,
            surb_decay: Some((Duration::from_millis(200), 0.05)),
            ..Default::default()
        };

        let mut mock_balancer_feedback = MockBalancerConfigFeedback::new();
        mock_balancer_feedback
            .expect_get_config()
            .with(mockall::predicate::eq(session_id))
            .times(NUM_STEPS)
            .returning(move |_| Ok(cfg));

        mock_balancer_feedback.expect_on_config_update().never();

        let mut mock_flow_ctl = MockSurbFlowController::new();
        mock_flow_ctl
            .expect_adjust_surb_flow()
            .times(NUM_STEPS)
            .returning(|_| ());

        let mut balancer = SurbBalancer::new(
            session_id,
            PidBalancerController::default(),
            SimpleSurbFlowEstimator::default(),
            mock_flow_ctl,
            cfg,
        );

        balancer.current_buffer = 5000;

        let (stream, handle) = balancer.start_control_loop(Duration::from_millis(100), mock_balancer_feedback, None);
        let levels = stream.take(NUM_STEPS).collect::<Vec<_>>().await;
        handle.abort();

        assert_eq!(levels, vec![5000, 4750, 4750, 4500, 4500]);
    }

    /// Test estimator yielding the next value of the given iterators on every query
    /// (0 once an iterator is exhausted).
    struct IterSurbFlowEstimator<P, C>(std::sync::Mutex<P>, std::sync::Mutex<C>);

    impl<P, C> IterSurbFlowEstimator<P, C> {
        fn new<I1, I2>(production: I1, consumption: I2) -> Self
        where
            I1: IntoIterator<Item = u64, IntoIter = P>,
            I2: IntoIterator<Item = u64, IntoIter = C>,
        {
            Self(
                std::sync::Mutex::new(production.into_iter()),
                std::sync::Mutex::new(consumption.into_iter()),
            )
        }
    }

    impl<P, C> SurbFlowEstimator for IterSurbFlowEstimator<P, C>
    where
        P: Iterator<Item = u64>,
        C: Iterator<Item = u64>,
    {
        fn estimate_surbs_consumed(&self) -> u64 {
            self.1.lock().ok().and_then(|mut it| it.next()).unwrap_or(0)
        }

        fn estimate_surbs_produced(&self) -> u64 {
            self.0.lock().ok().and_then(|mut it| it.next()).unwrap_or(0)
        }
    }

    #[test_log::test(tokio::test)]
    async fn surb_balancer_should_increase_target_when_using_simple_controller() {
        let session_id = SessionId::new(1234_u64, HoprPseudonym::random());
        let cfg_1 = SurbBalancerConfig {
            target_surb_buffer_size: 4500,
            max_surbs_per_sec: 2500,
            surb_decay: None,
            ..Default::default()
        };

        let cfg_2 = SurbBalancerConfig {
            target_surb_buffer_size: 5500,
            max_surbs_per_sec: 3055,
            surb_decay: None,
            ..Default::default()
        };

        let mut seq = mockall::Sequence::new();

        let mut mock_balancer_feedback = MockBalancerConfigFeedback::new();
        mock_balancer_feedback
            .expect_get_config()
            .times(3)
            .in_sequence(&mut seq)
            .with(mockall::predicate::eq(session_id))
            .returning(move |_| {
                tracing::trace!("get config 1");
                Ok(cfg_1)
            });

        mock_balancer_feedback
            .expect_on_config_update()
            .once()
            .in_sequence(&mut seq)
            .with(mockall::predicate::eq(session_id), mockall::predicate::eq(cfg_2))
            .returning(|_, _| {
                tracing::trace!("on config update");
                Ok(())
            });

        mock_balancer_feedback
            .expect_get_config()
            .times(2)
            .in_sequence(&mut seq)
            .with(mockall::predicate::eq(session_id))
            .returning(move |_| {
                tracing::trace!("get config 2");
                Ok(cfg_2)
            });

        let mut mock_flow_ctl = MockSurbFlowController::new();
        mock_flow_ctl.expect_adjust_surb_flow().times(5).returning(|_| ());

        let mut balancer = SurbBalancer::new(
            session_id,
            SimpleBalancerController::with_increasing_setpoint(0.2, 5),
            IterSurbFlowEstimator::new([1000, 2000, 3000, 3000, 3000], vec![500, 1000, 1500, 1500, 1500]),
            mock_flow_ctl,
            cfg_1,
        );

        balancer.current_buffer = 4500;

        let (stream, handle) = balancer.start_control_loop(Duration::from_millis(100), mock_balancer_feedback, None);

        let levels = stream.take(5).collect::<Vec<_>>().await;
        handle.abort();

        assert_eq!(levels, vec![5000, 5500, 6000, 6000, 6000]);
    }
}
587}