1use std::{
2 sync::{
3 Arc,
4 atomic::{AtomicU8, AtomicU64},
5 },
6 time::Duration,
7};
8
9use futures::{StreamExt, pin_mut};
10use hopr_async_runtime::AbortHandle;
11use tracing::{Instrument, instrument};
12
13use super::{
14 BalancerControllerBounds, MIN_BALANCER_SAMPLING_INTERVAL, SimpleSurbFlowEstimator, SurbBalancerController,
15 SurbFlowController, SurbFlowEstimator,
16};
17use crate::SessionId;
18
#[cfg(all(feature = "prometheus", not(test)))]
lazy_static::lazy_static! {
    // Per-session gauges published by the SURB balancer. Each gauge has a single
    // `session_id` label. The whole group is compiled only when the `prometheus`
    // feature is enabled and the crate is not built for tests.

    /// Buffer level estimate computed by the most recent balancer update.
    static ref METRIC_CURRENT_BUFFER: hopr_metrics::MultiGauge = hopr_metrics::MultiGauge::new(
        "hopr_surb_balancer_current_buffer_estimate",
        "Estimated number of SURBs in the buffer",
        &["session_id"],
    )
    .unwrap();

    /// Target (setpoint) the controller was using during the most recent update.
    static ref METRIC_CURRENT_TARGET: hopr_metrics::MultiGauge = hopr_metrics::MultiGauge::new(
        "hopr_surb_balancer_current_buffer_target",
        "Current target (setpoint) number of SURBs in the buffer",
        &["session_id"],
    )
    .unwrap();

    /// Difference between the estimated buffer level and the target.
    static ref METRIC_TARGET_ERROR_ESTIMATE: hopr_metrics::MultiGauge = hopr_metrics::MultiGauge::new(
        "hopr_surb_balancer_target_error_estimate",
        "Target error estimation by the SURB balancer",
        &["session_id"],
    )
    .unwrap();

    /// Value most recently handed to the SURB flow controller.
    static ref METRIC_CONTROL_OUTPUT: hopr_metrics::MultiGauge = hopr_metrics::MultiGauge::new(
        "hopr_surb_balancer_control_output",
        "Control output of the SURB balancer",
        &["session_id"],
    )
    .unwrap();

    /// Estimated buffer change divided by the time elapsed between two updates.
    static ref METRIC_SURB_RATE: hopr_metrics::MultiGauge = hopr_metrics::MultiGauge::new(
        "hopr_surb_balancer_surbs_rate",
        "Estimation of SURB rate per second (positive is buffer surplus, negative is buffer loss)",
        &["session_id"],
    )
    .unwrap();
}
52
53#[derive(Clone, Copy, Debug, PartialEq, smart_default::SmartDefault)]
55pub struct SurbBalancerConfig {
56 #[default(7_000)]
70 pub target_surb_buffer_size: u64,
71 #[default(5_000)]
80 pub max_surbs_per_sec: u64,
81
82 #[default(_code = "Some((Duration::from_secs(60), 0.05))")]
89 pub surb_decay: Option<(Duration, f64)>,
90}
91
92impl SurbBalancerConfig {
93 #[inline]
95 pub fn as_controller_bounds(&self) -> BalancerControllerBounds {
96 BalancerControllerBounds::new(self.target_surb_buffer_size, self.max_surbs_per_sec)
97 }
98}
99
100#[derive(Debug, Default)]
102pub struct BalancerStateValues {
103 pub target_surb_buffer_size: AtomicU64,
104 pub max_surbs_per_sec: AtomicU64,
105 pub decay_duration_msec: AtomicU64,
106 pub decay_volume_pct: AtomicU8,
107 pub buffer_level: AtomicU64,
108}
109
110impl BalancerStateValues {
111 pub fn new(cfg: SurbBalancerConfig) -> Self {
113 let state = Self::default();
114 state.update(&cfg);
115 state
116 }
117
118 pub fn update(&self, cfg: &SurbBalancerConfig) {
121 self.target_surb_buffer_size
122 .store(cfg.target_surb_buffer_size, std::sync::atomic::Ordering::Relaxed);
123 self.max_surbs_per_sec
124 .store(cfg.max_surbs_per_sec, std::sync::atomic::Ordering::Relaxed);
125 self.decay_duration_msec.store(
126 cfg.surb_decay
127 .map(|(d, _)| d.as_millis().min(u64::MAX as u128) as u64)
128 .unwrap_or_default(),
129 std::sync::atomic::Ordering::Relaxed,
130 );
131 self.decay_volume_pct.store(
132 cfg.surb_decay
133 .map(|(_, p)| (p.clamp(0.0, 1.0) * 100.0).round() as u8)
134 .unwrap_or_default(),
135 std::sync::atomic::Ordering::Relaxed,
136 );
137 }
138
139 pub fn as_config(&self) -> SurbBalancerConfig {
141 SurbBalancerConfig {
142 target_surb_buffer_size: self.target_surb_buffer_size.load(std::sync::atomic::Ordering::Relaxed),
143 max_surbs_per_sec: self.max_surbs_per_sec.load(std::sync::atomic::Ordering::Relaxed),
144 surb_decay: self.surb_decay(),
145 }
146 }
147
148 pub fn is_disabled(&self) -> bool {
150 self.target_surb_buffer_size.load(std::sync::atomic::Ordering::Relaxed) == 0
151 }
152
153 pub fn surb_decay(&self) -> Option<(Duration, f64)> {
155 Some((
156 self.decay_duration_msec.load(std::sync::atomic::Ordering::Relaxed),
157 self.decay_volume_pct.load(std::sync::atomic::Ordering::Relaxed),
158 ))
159 .filter(|&(d, p)| d > 0 && p > 0)
160 .map(|(d, p)| (Duration::from_millis(d), p as f64 / 100.0))
161 }
162
163 #[inline]
165 pub fn buffer_level(&self) -> u64 {
166 self.buffer_level.load(std::sync::atomic::Ordering::Relaxed)
167 }
168
169 #[inline]
171 pub fn controller_bounds(&self) -> BalancerControllerBounds {
172 BalancerControllerBounds::new(
173 self.target_surb_buffer_size.load(std::sync::atomic::Ordering::Relaxed),
174 self.max_surbs_per_sec.load(std::sync::atomic::Ordering::Relaxed),
175 )
176 }
177}
178
179impl From<SurbBalancerConfig> for BalancerStateValues {
180 fn from(cfg: SurbBalancerConfig) -> Self {
181 Self::new(cfg)
182 }
183}
184
185pub struct SurbBalancer<C, E, F> {
204 session_id: SessionId,
205 controller: C,
206 surb_estimator: E,
207 flow_control: F,
208 state: Arc<BalancerStateValues>,
209 last_estimator_state: SimpleSurbFlowEstimator,
210 last_update: std::time::Instant,
211 last_decay: std::time::Instant,
212 was_below_target: bool,
213}
214
215impl<C, E, F> SurbBalancer<C, E, F>
216where
217 C: SurbBalancerController + Send + Sync + 'static,
218 E: SurbFlowEstimator + Send + Sync + 'static,
219 F: SurbFlowController + Send + Sync + 'static,
220{
221 pub fn new(
222 session_id: SessionId,
223 mut controller: C,
224 surb_estimator: E,
225 flow_control: F,
226 state: Arc<BalancerStateValues>,
227 ) -> Self {
228 #[cfg(all(feature = "prometheus", not(test)))]
229 {
230 let sid = session_id.to_string();
231 METRIC_TARGET_ERROR_ESTIMATE.set(&[&sid], 0.0);
232 METRIC_CONTROL_OUTPUT.set(&[&sid], 0.0);
233 }
234
235 controller.set_target_and_limit(state.controller_bounds());
236
237 Self {
238 surb_estimator,
239 flow_control,
240 controller,
241 session_id,
242 state,
243 last_estimator_state: Default::default(),
244 last_update: std::time::Instant::now(),
245 last_decay: std::time::Instant::now(),
246 was_below_target: true,
247 }
248 }
249
250 #[tracing::instrument(level = "trace", skip_all)]
252 fn update(&mut self) -> u64 {
253 let dt = self.last_update.elapsed();
254
255 let mut current = self.state.buffer_level.load(std::sync::atomic::Ordering::Acquire);
257
258 if dt < Duration::from_millis(10) {
259 tracing::debug!("time elapsed since last update is too short, skipping update");
260 return current;
261 }
262
263 self.last_update = std::time::Instant::now();
264
265 let snapshot = SimpleSurbFlowEstimator::from(&self.surb_estimator);
267 let Some(target_buffer_change) = snapshot.estimated_surb_buffer_change(&self.last_estimator_state) else {
268 tracing::error!("non-monotonic change in SURB estimators");
269 return current;
270 };
271
272 self.last_estimator_state = snapshot;
273 current = current.saturating_add_signed(target_buffer_change);
274
275 if let Some(num_decayed_surbs) = self
278 .state
279 .surb_decay()
280 .filter(|(decay_window, _)| &self.last_decay.elapsed() >= decay_window)
281 .map(|(_, decay_coeff)| (self.controller.bounds().target() as f64 * decay_coeff).round() as u64)
282 {
283 current = current.saturating_sub(num_decayed_surbs);
284 self.last_decay = std::time::Instant::now();
285 tracing::trace!(num_decayed_surbs, "SURBs were discarded due to automatic decay");
286 }
287
288 self.state
289 .buffer_level
290 .store(current, std::sync::atomic::Ordering::Release);
291
292 let error = current as i64 - self.controller.bounds().target() as i64;
294
295 if self.was_below_target && error >= 0 {
296 tracing::trace!(current, "reached target SURB buffer size");
297 self.was_below_target = false;
298 } else if !self.was_below_target && error < 0 {
299 tracing::trace!(current, "SURB buffer size is below target");
300 self.was_below_target = true;
301 }
302
303 tracing::trace!(
304 ?dt,
305 delta = target_buffer_change,
306 rate = target_buffer_change as f64 / dt.as_secs_f64(),
307 current,
308 error,
309 "estimated SURB buffer change"
310 );
311
312 let output = self.controller.next_control_output(current);
313 tracing::trace!(output, "next balancer control output for session");
314
315 self.flow_control.adjust_surb_flow(output as usize);
316
317 #[cfg(all(feature = "prometheus", not(test)))]
318 {
319 let sid = self.session_id.to_string();
320 METRIC_CURRENT_BUFFER.set(&[&sid], current as f64);
321 METRIC_CURRENT_TARGET.set(&[&sid], self.controller.bounds().target() as f64);
322 METRIC_TARGET_ERROR_ESTIMATE.set(&[&sid], error as f64);
323 METRIC_CONTROL_OUTPUT.set(&[&sid], output as f64);
324 METRIC_SURB_RATE.set(&[&sid], target_buffer_change as f64 / dt.as_secs_f64());
325 }
326
327 current
328 }
329
330 #[instrument(level = "debug", skip(self), fields(session_id = %self.session_id))]
339 pub fn start_control_loop(
340 mut self,
341 sampling_interval: Duration,
342 ) -> (impl futures::Stream<Item = u64>, AbortHandle) {
343 let (abort_handle, abort_reg) = AbortHandle::new_pair();
344
345 let sampling_stream = futures::stream::Abortable::new(
347 futures_time::stream::interval(sampling_interval.max(MIN_BALANCER_SAMPLING_INTERVAL).into()),
348 abort_reg,
349 );
350
351 let balancer_level_capacity = std::env::var("HOPR_INTERNAL_SESSION_BALANCER_LEVEL_CAPACITY")
352 .ok()
353 .and_then(|s| s.trim().parse::<usize>().ok())
354 .filter(|&c| c > 0)
355 .unwrap_or(32_768);
356
357 tracing::debug!(
358 capacity = balancer_level_capacity,
359 "Creating session balancer level channel"
360 );
361 let (mut level_tx, level_rx) = futures::channel::mpsc::channel(balancer_level_capacity);
362 hopr_async_runtime::prelude::spawn(
363 async move {
364 pin_mut!(sampling_stream);
365 while sampling_stream.next().await.is_some() {
366 let current_bounds = self.state.controller_bounds();
368 if current_bounds != self.controller.bounds() {
369 self.controller.set_target_and_limit(current_bounds);
370 tracing::debug!(new_cfg = ?self.state.as_config(), "surb balancer has been reconfigured");
371 }
372
373 let level = self.update();
377 if !level_tx.is_closed()
378 && let Err(error) = level_tx.try_send(level)
379 {
380 tracing::error!(%error, "cannot send balancer level update");
381 }
382 }
383
384 tracing::debug!("balancer done");
385 }
386 .in_current_span(),
387 );
388
389 (level_rx, abort_handle)
390 }
391}
392
#[cfg(test)]
mod tests {
    use std::sync::{Arc, atomic::AtomicU64};

    use hopr_crypto_random::Randomizable;
    use hopr_internal_types::prelude::HoprPseudonym;

    use super::*;
    use crate::balancer::{AtomicSurbFlowEstimator, MockSurbFlowController, pid::PidBalancerController};

    /// The default configuration must survive a round trip through the atomic state.
    #[test]
    fn surb_balancer_config_should_be_convertible_to_atomics() {
        let cfg = SurbBalancerConfig::default();
        let state_data = BalancerStateValues::new(cfg);
        assert_eq!(cfg, state_data.as_config());
    }

    /// Starting from an empty buffer, the level must grow on every step after the first.
    #[test_log::test]
    fn surb_balancer_should_start_increase_level_when_below_target() {
        let production_rate = Arc::new(AtomicU64::new(0));
        let consumption_rate = 100;
        let steps = 3;
        let step_duration = std::time::Duration::from_millis(1000);

        // The mocked flow controller feeds the control output back as the new production rate.
        let mut controller = MockSurbFlowController::new();
        let production_rate_clone = production_rate.clone();
        controller
            .expect_adjust_surb_flow()
            .times(steps)
            .with(mockall::predicate::ge(100))
            .returning(move |r| {
                production_rate_clone.store(r as u64, std::sync::atomic::Ordering::Relaxed);
            });

        let surb_estimator = AtomicSurbFlowEstimator::default();
        let mut balancer = SurbBalancer::new(
            SessionId::new(1234_u64, HoprPseudonym::random()),
            PidBalancerController::default(),
            surb_estimator.clone(),
            controller,
            Arc::new(
                SurbBalancerConfig {
                    target_surb_buffer_size: 5_000,
                    max_surbs_per_sec: 2500,
                    surb_decay: None,
                }
                .into(),
            ),
        );

        let mut last_update = 0;
        for i in 0..steps {
            // `update` measures wall-clock time, so the step has to actually elapse.
            std::thread::sleep(step_duration);
            surb_estimator.produced.fetch_add(
                production_rate.load(std::sync::atomic::Ordering::Relaxed) * step_duration.as_secs(),
                std::sync::atomic::Ordering::Relaxed,
            );
            surb_estimator.consumed.fetch_add(
                consumption_rate * step_duration.as_secs(),
                std::sync::atomic::Ordering::Relaxed,
            );

            let next_update = balancer.update();
            assert!(
                i == 0 || next_update > last_update,
                "{next_update} should be greater than {last_update}"
            );
            last_update = next_update;
        }
    }

    /// Starting with a production rate above the target, the level must shrink on every
    /// step after the first.
    #[test_log::test]
    fn surb_balancer_should_start_decrease_level_when_above_target() {
        let production_rate = Arc::new(AtomicU64::new(11_000));
        let consumption_rate = 100;
        let steps = 3;
        let step_duration = std::time::Duration::from_millis(1000);

        // The mocked flow controller feeds the control output back as the new production rate.
        let mut controller = MockSurbFlowController::new();
        let production_rate_clone = production_rate.clone();
        controller
            .expect_adjust_surb_flow()
            .times(steps)
            .with(mockall::predicate::ge(0))
            .returning(move |r| {
                production_rate_clone.store(r as u64, std::sync::atomic::Ordering::Relaxed);
            });

        let surb_estimator = AtomicSurbFlowEstimator::default();
        let mut balancer = SurbBalancer::new(
            SessionId::new(1234_u64, HoprPseudonym::random()),
            PidBalancerController::default(),
            surb_estimator.clone(),
            controller,
            Arc::new(
                SurbBalancerConfig {
                    surb_decay: None,
                    ..Default::default()
                }
                .into(),
            ),
        );

        let mut last_update = 0;
        for i in 0..steps {
            // `update` measures wall-clock time, so the step has to actually elapse.
            std::thread::sleep(step_duration);
            surb_estimator.produced.fetch_add(
                production_rate.load(std::sync::atomic::Ordering::Relaxed) * step_duration.as_secs(),
                std::sync::atomic::Ordering::Relaxed,
            );
            surb_estimator.consumed.fetch_add(
                consumption_rate * step_duration.as_secs(),
                std::sync::atomic::Ordering::Relaxed,
            );

            let next_update = balancer.update();
            assert!(
                i == 0 || next_update < last_update,
                "{next_update} should be less than {last_update}"
            );
            last_update = next_update;
        }
    }

    /// With a static estimator and decay enabled, only the decay changes the level.
    #[test_log::test(tokio::test)]
    async fn surb_balancer_should_start_decrease_level_when_above_target_and_decay_enabled() {
        const NUM_STEPS: usize = 5;
        let session_id = SessionId::new(1234_u64, HoprPseudonym::random());
        let cfg = SurbBalancerConfig {
            target_surb_buffer_size: 5_000,
            max_surbs_per_sec: 2500,
            surb_decay: Some((Duration::from_millis(200), 0.05)),
        };

        let mut mock_flow_ctl = MockSurbFlowController::new();
        mock_flow_ctl
            .expect_adjust_surb_flow()
            .times(NUM_STEPS)
            .returning(|_| ());

        let balancer = SurbBalancer::new(
            session_id,
            PidBalancerController::default(),
            SimpleSurbFlowEstimator::default(),
            mock_flow_ctl,
            Arc::new(cfg.into()),
        );

        // Start exactly at the target.
        balancer
            .state
            .buffer_level
            .store(5000, std::sync::atomic::Ordering::Relaxed);

        let (stream, handle) = balancer.start_control_loop(Duration::from_millis(100));
        let levels = stream.take(NUM_STEPS).collect::<Vec<_>>().await;
        handle.abort();

        // 5% of the 5 000 target is 250; with a 200 ms decay window and 100 ms sampling,
        // the expected sequence drops by 250 on every second sample.
        assert_eq!(levels, vec![5000, 4750, 4750, 4500, 4500]);
    }
}