hopr_transport_session/balancer/
controller.rs1use std::fmt::Display;
2
3use pid::Pid;
4
5use crate::balancer::{SurbFlowController, SurbFlowEstimator};
6
#[cfg(all(feature = "prometheus", not(test)))]
lazy_static::lazy_static! {
    /// Difference between the estimated SURB buffer level and the configured target, per session.
    static ref METRIC_TARGET_ERROR_ESTIMATE: hopr_metrics::metrics::MultiGauge =
        hopr_metrics::metrics::MultiGauge::new(
            "hopr_surb_balancer_target_error_estimate",
            "Target error estimation by the SURB balancer",
            &["session_id"]
        ).unwrap();
    /// Clamped PID controller output passed to the SURB flow controller, per session.
    static ref METRIC_CONTROL_OUTPUT: hopr_metrics::metrics::MultiGauge =
        hopr_metrics::metrics::MultiGauge::new(
            "hopr_surb_balancer_control_output",
            "Control output of the SURB balancer",
            &["session_id"]
        ).unwrap();
    /// Running total of SURB consumption estimates, per session.
    static ref METRIC_SURBS_CONSUMED: hopr_metrics::metrics::MultiCounter =
        hopr_metrics::metrics::MultiCounter::new(
            "hopr_surb_balancer_surbs_consumed",
            "Estimations of the number of SURBs consumed by the counterparty",
            &["session_id"]
        ).unwrap();
    /// Running total of SURB production estimates, per session.
    static ref METRIC_SURBS_PRODUCED: hopr_metrics::metrics::MultiCounter =
        hopr_metrics::metrics::MultiCounter::new(
            "hopr_surb_balancer_surbs_produced",
            "Estimations of the number of SURBs produced for the counterparty",
            &["session_id"]
        ).unwrap();
}
34
35#[derive(Clone, Copy, Debug, PartialEq, Eq, smart_default::SmartDefault)]
37#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
38pub struct SurbBalancerConfig {
39 #[default(5_000)]
44 pub target_surb_buffer_size: u64,
45 #[default(2_500)]
49 pub max_surbs_per_sec: u64,
50}
51
52pub struct SurbBalancer<I, O, F, S> {
64 session_id: S,
65 pid: Pid<f64>,
66 surb_production_estimator: O,
67 surb_consumption_estimator: I,
68 controller: F,
69 cfg: SurbBalancerConfig,
70 current_target_buffer: u64,
71 last_surbs_delivered: u64,
72 last_surbs_used: u64,
73 last_update: std::time::Instant,
74 was_below_target: bool,
75}
76
/// Default proportional gain of the balancer's PID controller.
///
/// Overridable through the `HOPR_BALANCER_PID_P_GAIN` environment variable.
const DEFAULT_P_GAIN: f64 = 0.6;
/// Default integral gain of the balancer's PID controller.
///
/// Overridable through the `HOPR_BALANCER_PID_I_GAIN` environment variable.
const DEFAULT_I_GAIN: f64 = 0.7;
/// Default derivative gain of the balancer's PID controller.
///
/// Overridable through the `HOPR_BALANCER_PID_D_GAIN` environment variable.
const DEFAULT_D_GAIN: f64 = 0.2;
82
83impl<I, O, F, S> SurbBalancer<I, O, F, S>
84where
85 O: SurbFlowEstimator,
86 I: SurbFlowEstimator,
87 F: SurbFlowController,
88 S: Display,
89{
90 pub fn new(
91 session_id: S,
92 surb_production_estimator: O,
93 surb_consumption_estimator: I,
94 controller: F,
95 cfg: SurbBalancerConfig,
96 ) -> Self {
97 let max_surbs_per_sec = cfg.max_surbs_per_sec as f64;
99 let mut pid = Pid::new(cfg.target_surb_buffer_size as f64, max_surbs_per_sec);
100 pid.p(
101 std::env::var("HOPR_BALANCER_PID_P_GAIN")
102 .ok()
103 .and_then(|v| v.parse().ok())
104 .unwrap_or(DEFAULT_P_GAIN),
105 max_surbs_per_sec,
106 );
107 pid.i(
108 std::env::var("HOPR_BALANCER_PID_I_GAIN")
109 .ok()
110 .and_then(|v| v.parse().ok())
111 .unwrap_or(DEFAULT_I_GAIN),
112 max_surbs_per_sec,
113 );
114 pid.d(
115 std::env::var("HOPR_BALANCER_PID_D_GAIN")
116 .ok()
117 .and_then(|v| v.parse().ok())
118 .unwrap_or(DEFAULT_D_GAIN),
119 max_surbs_per_sec,
120 );
121
122 #[cfg(all(feature = "prometheus", not(test)))]
123 {
124 let sid = session_id.to_string();
125 METRIC_TARGET_ERROR_ESTIMATE.set(&[&sid], 0.0);
126 METRIC_CONTROL_OUTPUT.set(&[&sid], 0.0);
127 }
128
129 Self {
130 surb_production_estimator,
131 surb_consumption_estimator,
132 controller,
133 pid,
134 cfg,
135 session_id,
136 current_target_buffer: 0,
137 last_surbs_delivered: 0,
138 last_surbs_used: 0,
139 last_update: std::time::Instant::now(),
140 was_below_target: true,
141 }
142 }
143
144 #[tracing::instrument(level = "trace", skip(self), fields(session_id = %self.session_id))]
146 pub fn update(&mut self) -> u64 {
147 let dt = self.last_update.elapsed();
148 if dt < std::time::Duration::from_millis(10) {
149 tracing::debug!("time elapsed since last update is too short, skipping update");
150 return self.current_target_buffer;
151 }
152
153 self.last_update = std::time::Instant::now();
154
155 let current_surbs_delivered = self.surb_production_estimator.estimate_surb_turnout();
157 let surbs_delivered_delta = current_surbs_delivered - self.last_surbs_delivered;
158 self.last_surbs_delivered = current_surbs_delivered;
159
160 let current_surbs_used = self.surb_consumption_estimator.estimate_surb_turnout();
162 let surbs_used_delta = current_surbs_used - self.last_surbs_used;
163 self.last_surbs_used = current_surbs_used;
164
165 let target_buffer_change = surbs_delivered_delta as i64 - surbs_used_delta as i64;
167 self.current_target_buffer = self.current_target_buffer.saturating_add_signed(target_buffer_change);
168
169 let error = self.current_target_buffer as i64 - self.cfg.target_surb_buffer_size as i64;
171
172 if self.was_below_target && error >= 0 {
173 tracing::debug!(session_id = %self.session_id, current = self.current_target_buffer, "reached target SURB buffer size");
174 self.was_below_target = false;
175 } else if !self.was_below_target && error < 0 {
176 tracing::debug!(session_id = %self.session_id, current = self.current_target_buffer, "SURB buffer size is below target");
177 self.was_below_target = true;
178 }
179
180 tracing::trace!(
181 session_id = %self.session_id,
182 ?dt,
183 delta = target_buffer_change,
184 current = self.current_target_buffer,
185 error,
186 rate_up = surbs_delivered_delta as f64 / dt.as_secs_f64(),
187 rate_down = surbs_used_delta as f64 / dt.as_secs_f64(),
188 "estimated SURB buffer change"
189 );
190
191 let output = self.pid.next_control_output(self.current_target_buffer as f64);
192 let corrected_output = output.output.clamp(0.0, self.cfg.max_surbs_per_sec as f64);
193 self.controller.adjust_surb_flow(corrected_output as usize);
194
195 tracing::trace!(control_output = corrected_output, "next control output");
196
197 #[cfg(all(feature = "prometheus", not(test)))]
198 {
199 let sid = self.session_id.to_string();
200 METRIC_TARGET_ERROR_ESTIMATE.set(&[&sid], error as f64);
201 METRIC_CONTROL_OUTPUT.set(&[&sid], corrected_output);
202 METRIC_SURBS_CONSUMED.increment_by(&[&sid], surbs_used_delta);
203 METRIC_SURBS_PRODUCED.increment_by(&[&sid], surbs_delivered_delta);
204 }
205
206 self.current_target_buffer
207 }
208
209 #[allow(unused)]
211 pub fn set_exact_target_buffer_size(&mut self, target_buffer_size: u64) {
212 self.current_target_buffer = target_buffer_size;
213 }
214}
215
#[cfg(test)]
mod tests {
    use std::sync::{Arc, atomic::AtomicU64};

    use super::*;
    use crate::balancer::MockSurbFlowController;

    /// Starting from an empty buffer with a steady consumption of 100 SURBs/s, every control
    /// output must request at least 100 SURBs/s. After the first step the buffer estimate must
    /// grow strictly.
    #[test_log::test]
    fn surb_balancer_should_start_increase_level_when_below_target() {
        let production_rate = Arc::new(AtomicU64::new(0));
        let consumption_rate = 100;
        let steps = 3;
        let step_duration = std::time::Duration::from_millis(1000);

        // The mock feeds each control output back as the production rate of the next step.
        let mut controller = MockSurbFlowController::new();
        let production_rate_clone = production_rate.clone();
        controller
            .expect_adjust_surb_flow()
            .times(steps)
            .with(mockall::predicate::ge(100))
            .returning(move |r| {
                production_rate_clone.store(r as u64, std::sync::atomic::Ordering::Relaxed);
            });

        let surb_production_count = Arc::new(AtomicU64::new(0));
        let surb_consumption_count = Arc::new(AtomicU64::new(0));
        let mut balancer = SurbBalancer::new(
            "test",
            surb_production_count.clone(),
            surb_consumption_count.clone(),
            controller,
            SurbBalancerConfig::default(),
        );

        let mut last_update = 0;
        for i in 0..steps {
            // Sleep for real so that `update` does not skip the step as too early.
            std::thread::sleep(step_duration);
            surb_production_count.fetch_add(
                production_rate.load(std::sync::atomic::Ordering::Relaxed) * step_duration.as_secs(),
                std::sync::atomic::Ordering::Relaxed,
            );
            surb_consumption_count.fetch_add(
                consumption_rate * step_duration.as_secs(),
                std::sync::atomic::Ordering::Relaxed,
            );

            let next_update = balancer.update();
            assert!(
                i == 0 || next_update > last_update,
                "{next_update} should be greater than {last_update}"
            );
            last_update = next_update;
        }
    }

    /// With an initial production far above consumption, the buffer estimate overshoots the
    /// target in the first step. From then on it must shrink strictly.
    #[test_log::test]
    fn surb_balancer_should_start_decrease_level_when_above_target() {
        let production_rate = Arc::new(AtomicU64::new(11_000));
        let consumption_rate = 100;
        let steps = 3;
        let step_duration = std::time::Duration::from_millis(1000);

        // The mock feeds each control output back as the production rate of the next step.
        let mut controller = MockSurbFlowController::new();
        let production_rate_clone = production_rate.clone();
        controller
            .expect_adjust_surb_flow()
            .times(steps)
            .with(mockall::predicate::ge(0))
            .returning(move |r| {
                production_rate_clone.store(r as u64, std::sync::atomic::Ordering::Relaxed);
            });

        let surb_production_count = Arc::new(AtomicU64::new(0));
        let surb_consumption_count = Arc::new(AtomicU64::new(0));
        let mut balancer = SurbBalancer::new(
            "test",
            surb_production_count.clone(),
            surb_consumption_count.clone(),
            controller,
            SurbBalancerConfig::default(),
        );

        let mut last_update = 0;
        for i in 0..steps {
            // Sleep for real so that `update` does not skip the step as too early.
            std::thread::sleep(step_duration);
            surb_production_count.fetch_add(
                production_rate.load(std::sync::atomic::Ordering::Relaxed) * step_duration.as_secs(),
                std::sync::atomic::Ordering::Relaxed,
            );
            surb_consumption_count.fetch_add(
                consumption_rate * step_duration.as_secs(),
                std::sync::atomic::Ordering::Relaxed,
            );

            let next_update = balancer.update();
            assert!(
                i == 0 || next_update < last_update,
                "{next_update} should be less than {last_update}"
            );
            last_update = next_update;
        }
    }
}