hopr_transport_session/balancer/controller.rs

1use std::fmt::Display;
2
3use pid::Pid;
4
5use crate::balancer::{SurbFlowController, SurbFlowEstimator};
6
// Per-session metrics of the SURB balancer. They are compiled out in unit tests and
// when the `prometheus` feature is disabled.
#[cfg(all(feature = "prometheus", not(test)))]
lazy_static::lazy_static! {
    /// Difference between the estimated and the target SURB buffer size.
    static ref METRIC_TARGET_ERROR_ESTIMATE: hopr_metrics::metrics::MultiGauge =
        hopr_metrics::metrics::MultiGauge::new(
            "hopr_surb_balancer_target_error_estimate",
            "Target error estimation by the SURB balancer",
            &["session_id"]
        ).unwrap();
    /// Clamped PID controller output, i.e. the non-organic SURB flow rate requested.
    static ref METRIC_CONTROL_OUTPUT: hopr_metrics::metrics::MultiGauge =
        hopr_metrics::metrics::MultiGauge::new(
            "hopr_surb_balancer_control_output",
            "Control output (non-organic SURBs per second) of the SURB balancer",
            &["session_id"]
        ).unwrap();
    /// Running total of SURBs estimated to be consumed by the counterparty.
    static ref METRIC_SURBS_CONSUMED: hopr_metrics::metrics::MultiCounter =
        hopr_metrics::metrics::MultiCounter::new(
            "hopr_surb_balancer_surbs_consumed",
            "Estimations of the number of SURBs consumed by the counterparty",
            &["session_id"]
        ).unwrap();
    /// Running total of SURBs estimated to be produced for the counterparty.
    static ref METRIC_SURBS_PRODUCED: hopr_metrics::metrics::MultiCounter =
        hopr_metrics::metrics::MultiCounter::new(
            "hopr_surb_balancer_surbs_produced",
            "Estimations of the number of SURBs produced for the counterparty",
            &["session_id"]
        ).unwrap();
}
34
35/// Configuration for the [`SurbBalancer`].
36#[derive(Clone, Copy, Debug, PartialEq, Eq, smart_default::SmartDefault)]
37#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
38pub struct SurbBalancerConfig {
39    /// The desired number of SURBs to be always kept as a buffer at the Session counterparty.
40    ///
41    /// The [`SurbBalancer`] will try to maintain approximately this number of SURBs
42    /// at the counterparty at all times, by regulating the [flow of non-organic SURBs](SurbFlowController).
43    #[default(5_000)]
44    pub target_surb_buffer_size: u64,
45    /// Maximum outflow of non-organic surbs.
46    ///
47    /// The default is 2500 (which is 1250 packets/second currently)
48    #[default(2_500)]
49    pub max_surbs_per_sec: u64,
50}
51
52/// Runs a continuous process that attempts to [evaluate](SurbFlowEstimator) and
53/// [regulate](SurbFlowController) the flow of non-organic SURBs to the Session counterparty,
54/// to keep the number of SURBs at the counterparty at a certain level.
55///
56/// Internally, the Balancer uses a PID (Proportional Integral Derivative) controller to
57/// control the rate of SURBs sent to the counterparty
58/// each time the [`update`](SurbBalancer::update) method is called:
59///
60/// 1. The size of the SURB buffer at the counterparty is estimated using [`SurbFlowEstimator`].
61/// 2. Error against a set-point given in [`SurbBalancerConfig`] is evaluated in the PID controller.
62/// 3. The PID controller applies a new SURB flow rate value using the [`SurbFlowController`].
63pub struct SurbBalancer<I, O, F, S> {
64    session_id: S,
65    pid: Pid<f64>,
66    surb_production_estimator: O,
67    surb_consumption_estimator: I,
68    controller: F,
69    cfg: SurbBalancerConfig,
70    current_target_buffer: u64,
71    last_surbs_delivered: u64,
72    last_surbs_used: u64,
73    last_update: std::time::Instant,
74    was_below_target: bool,
75}
76
77// Default coefficients for the PID controller
78// This might be tweaked in the future.
79const DEFAULT_P_GAIN: f64 = 0.6;
80const DEFAULT_I_GAIN: f64 = 0.7;
81const DEFAULT_D_GAIN: f64 = 0.2;
82
83impl<I, O, F, S> SurbBalancer<I, O, F, S>
84where
85    O: SurbFlowEstimator,
86    I: SurbFlowEstimator,
87    F: SurbFlowController,
88    S: Display,
89{
90    pub fn new(
91        session_id: S,
92        surb_production_estimator: O,
93        surb_consumption_estimator: I,
94        controller: F,
95        cfg: SurbBalancerConfig,
96    ) -> Self {
97        // Initialize the PID controller with default tuned gains
98        let max_surbs_per_sec = cfg.max_surbs_per_sec as f64;
99        let mut pid = Pid::new(cfg.target_surb_buffer_size as f64, max_surbs_per_sec);
100        pid.p(
101            std::env::var("HOPR_BALANCER_PID_P_GAIN")
102                .ok()
103                .and_then(|v| v.parse().ok())
104                .unwrap_or(DEFAULT_P_GAIN),
105            max_surbs_per_sec,
106        );
107        pid.i(
108            std::env::var("HOPR_BALANCER_PID_I_GAIN")
109                .ok()
110                .and_then(|v| v.parse().ok())
111                .unwrap_or(DEFAULT_I_GAIN),
112            max_surbs_per_sec,
113        );
114        pid.d(
115            std::env::var("HOPR_BALANCER_PID_D_GAIN")
116                .ok()
117                .and_then(|v| v.parse().ok())
118                .unwrap_or(DEFAULT_D_GAIN),
119            max_surbs_per_sec,
120        );
121
122        #[cfg(all(feature = "prometheus", not(test)))]
123        {
124            let sid = session_id.to_string();
125            METRIC_TARGET_ERROR_ESTIMATE.set(&[&sid], 0.0);
126            METRIC_CONTROL_OUTPUT.set(&[&sid], 0.0);
127        }
128
129        Self {
130            surb_production_estimator,
131            surb_consumption_estimator,
132            controller,
133            pid,
134            cfg,
135            session_id,
136            current_target_buffer: 0,
137            last_surbs_delivered: 0,
138            last_surbs_used: 0,
139            last_update: std::time::Instant::now(),
140            was_below_target: true,
141        }
142    }
143
144    /// Computes the next control update and adjusts the [`SurbFlowController`] rate accordingly.
145    #[tracing::instrument(level = "trace", skip(self), fields(session_id = %self.session_id))]
146    pub fn update(&mut self) -> u64 {
147        let dt = self.last_update.elapsed();
148        if dt < std::time::Duration::from_millis(10) {
149            tracing::debug!("time elapsed since last update is too short, skipping update");
150            return self.current_target_buffer;
151        }
152
153        self.last_update = std::time::Instant::now();
154
155        // Number of SURBs sent to the counterparty since the last update
156        let current_surbs_delivered = self.surb_production_estimator.estimate_surb_turnout();
157        let surbs_delivered_delta = current_surbs_delivered - self.last_surbs_delivered;
158        self.last_surbs_delivered = current_surbs_delivered;
159
160        // Number of SURBs used by the counterparty since the last update
161        let current_surbs_used = self.surb_consumption_estimator.estimate_surb_turnout();
162        let surbs_used_delta = current_surbs_used - self.last_surbs_used;
163        self.last_surbs_used = current_surbs_used;
164
165        // Estimated change in counterparty's SURB buffer
166        let target_buffer_change = surbs_delivered_delta as i64 - surbs_used_delta as i64;
167        self.current_target_buffer = self.current_target_buffer.saturating_add_signed(target_buffer_change);
168
169        // Error from the desired target SURB buffer size at counterparty
170        let error = self.current_target_buffer as i64 - self.cfg.target_surb_buffer_size as i64;
171
172        if self.was_below_target && error >= 0 {
173            tracing::debug!(session_id = %self.session_id, current = self.current_target_buffer, "reached target SURB buffer size");
174            self.was_below_target = false;
175        } else if !self.was_below_target && error < 0 {
176            tracing::debug!(session_id = %self.session_id, current = self.current_target_buffer, "SURB buffer size is below target");
177            self.was_below_target = true;
178        }
179
180        tracing::trace!(
181            session_id = %self.session_id,
182            ?dt,
183            delta = target_buffer_change,
184            current = self.current_target_buffer,
185            error,
186            rate_up = surbs_delivered_delta as f64 / dt.as_secs_f64(),
187            rate_down = surbs_used_delta as f64 / dt.as_secs_f64(),
188            "estimated SURB buffer change"
189        );
190
191        let output = self.pid.next_control_output(self.current_target_buffer as f64);
192        let corrected_output = output.output.clamp(0.0, self.cfg.max_surbs_per_sec as f64);
193        self.controller.adjust_surb_flow(corrected_output as usize);
194
195        tracing::trace!(control_output = corrected_output, "next control output");
196
197        #[cfg(all(feature = "prometheus", not(test)))]
198        {
199            let sid = self.session_id.to_string();
200            METRIC_TARGET_ERROR_ESTIMATE.set(&[&sid], error as f64);
201            METRIC_CONTROL_OUTPUT.set(&[&sid], corrected_output);
202            METRIC_SURBS_CONSUMED.increment_by(&[&sid], surbs_used_delta);
203            METRIC_SURBS_PRODUCED.increment_by(&[&sid], surbs_delivered_delta);
204        }
205
206        self.current_target_buffer
207    }
208
209    /// Allows setting the target buffer size when its value is known exactly.
210    #[allow(unused)]
211    pub fn set_exact_target_buffer_size(&mut self, target_buffer_size: u64) {
212        self.current_target_buffer = target_buffer_size;
213    }
214}
215
#[cfg(test)]
mod tests {
    use std::sync::{
        Arc,
        atomic::{AtomicU64, Ordering},
    };

    use super::*;
    use crate::balancer::MockSurbFlowController;

    /// Number of SURBs the simulated counterparty consumes per second.
    const CONSUMPTION_RATE: u64 = 100;
    /// Number of balancer updates performed by each simulation.
    const STEPS: usize = 3;
    /// Time slept between two balancer updates.
    const STEP_DURATION: std::time::Duration = std::time::Duration::from_millis(1000);

    /// Simulates a Session for `STEPS` updates and returns the balancer's buffer estimate
    /// after each one.
    ///
    /// The counterparty consumes SURBs at `CONSUMPTION_RATE`. SURBs are produced at the rate
    /// last set by the balancer, starting at `initial_production_rate`. The mocked controller
    /// expects exactly `STEPS` rate adjustments, each at least `min_control_output`. This is
    /// verified when the balancer (and thus the mock) is dropped.
    fn simulate_balancer(initial_production_rate: u64, min_control_output: usize) -> Vec<u64> {
        let production_rate = Arc::new(AtomicU64::new(initial_production_rate));

        let mut controller = MockSurbFlowController::new();
        let production_rate_clone = production_rate.clone();
        controller
            .expect_adjust_surb_flow()
            .times(STEPS)
            .with(mockall::predicate::ge(min_control_output))
            .returning(move |r| {
                production_rate_clone.store(r as u64, Ordering::Relaxed);
            });

        let surb_production_count = Arc::new(AtomicU64::new(0));
        let surb_consumption_count = Arc::new(AtomicU64::new(0));
        let mut balancer = SurbBalancer::new(
            "test",
            surb_production_count.clone(),
            surb_consumption_count.clone(),
            controller,
            SurbBalancerConfig::default(),
        );

        (0..STEPS)
            .map(|_| {
                std::thread::sleep(STEP_DURATION);
                surb_production_count.fetch_add(
                    production_rate.load(Ordering::Relaxed) * STEP_DURATION.as_secs(),
                    Ordering::Relaxed,
                );
                surb_consumption_count.fetch_add(CONSUMPTION_RATE * STEP_DURATION.as_secs(), Ordering::Relaxed);
                balancer.update()
            })
            .collect()
    }

    #[test_log::test]
    fn surb_balancer_should_start_increase_level_when_below_target() {
        let levels = simulate_balancer(0, 100);
        for pair in levels.windows(2) {
            let (last_update, next_update) = (pair[0], pair[1]);
            assert!(
                next_update > last_update,
                "{next_update} should be greater than {last_update}"
            );
        }
    }

    #[test_log::test]
    fn surb_balancer_should_start_decrease_level_when_above_target() {
        let levels = simulate_balancer(11_000, 0);
        for pair in levels.windows(2) {
            let (last_update, next_update) = (pair[0], pair[1]);
            assert!(
                next_update < last_update,
                "{next_update} should be less than {last_update}"
            );
        }
    }
}