1use std::time::Duration;
2
3use futures::{StreamExt, future::AbortRegistration, pin_mut};
4use hopr_async_runtime::AbortHandle;
5use tracing::{Instrument, debug, error, instrument, trace};
6
7use super::{
8 BalancerControllerBounds, MIN_BALANCER_SAMPLING_INTERVAL, SimpleSurbFlowEstimator, SurbBalancerController,
9 SurbFlowController, SurbFlowEstimator,
10};
11use crate::SessionId;
12
// Prometheus gauges published by the SURB balancer.
//
// Every gauge carries a single `session_id` label and is written from `SurbBalancer`.
// The whole block is compiled only with the `prometheus` feature and never in tests.
// `lazy_static` creates each gauge on first access; the `unwrap` panics at that point
// if `MultiGauge::new` returns an error.
#[cfg(all(feature = "prometheus", not(test)))]
lazy_static::lazy_static! {
    // Signed difference between the estimated buffer level and the controller target.
    static ref METRIC_TARGET_ERROR_ESTIMATE: hopr_metrics::MultiGauge =
        hopr_metrics::MultiGauge::new(
            "hopr_surb_balancer_target_error_estimate",
            "Target error estimation by the SURB balancer",
            &["session_id"]
        ).unwrap();
    // Last value returned by the controller and handed over to the flow controller.
    static ref METRIC_CONTROL_OUTPUT: hopr_metrics::MultiGauge =
        hopr_metrics::MultiGauge::new(
            "hopr_surb_balancer_control_output",
            "Control output of the SURB balancer",
            &["session_id"]
        ).unwrap();
    // Balancer's current estimate of the SURB buffer level.
    static ref METRIC_CURRENT_BUFFER: hopr_metrics::MultiGauge =
        hopr_metrics::MultiGauge::new(
            "hopr_surb_balancer_current_buffer_estimate",
            "Estimated number of SURBs in the buffer",
            &["session_id"]
        ).unwrap();
    // Target (setpoint) currently reported by the controller bounds.
    static ref METRIC_CURRENT_TARGET: hopr_metrics::MultiGauge =
        hopr_metrics::MultiGauge::new(
            "hopr_surb_balancer_current_buffer_target",
            "Current target (setpoint) number of SURBs in the buffer",
            &["session_id"]
        ).unwrap();
    // Estimated buffer change since the previous update divided by the elapsed seconds.
    static ref METRIC_SURB_RATE: hopr_metrics::MultiGauge =
        hopr_metrics::MultiGauge::new(
            "hopr_surb_balancer_surbs_rate",
            "Estimation of SURB rate per second (positive is buffer surplus, negative is buffer loss)",
            &["session_id"]
        ).unwrap();
}
46
/// Gives the [`SurbBalancer`] control loop access to the [`SurbBalancerConfig`] of a Session.
///
/// The control loop reads the configuration on every sampling tick and writes it back
/// when the controller bounds were changed by an update.
#[cfg_attr(test, mockall::automock)]
#[async_trait::async_trait]
pub trait BalancerConfigFeedback {
    /// Returns the current balancer configuration of the Session identified by `id`.
    ///
    /// The control loop terminates when this returns an error.
    async fn get_config(&self, id: &SessionId) -> crate::errors::Result<SurbBalancerConfig>;
    /// Called by the control loop with the updated configuration `cfg` of the Session `id`
    /// after the controller bounds differ before and after a balancer update.
    ///
    /// An error returned from here is only logged by the control loop.
    async fn on_config_update(&self, id: &SessionId, cfg: SurbBalancerConfig) -> crate::errors::Result<()>;
}
56
/// Configuration of the [`SurbBalancer`].
#[derive(Clone, Copy, Debug, PartialEq, smart_default::SmartDefault)]
pub struct SurbBalancerConfig {
    /// Target number of SURBs in the buffer; passed to the controller as the target
    /// of its bounds.
    ///
    /// Default is 7000.
    #[default(7_000)]
    pub target_surb_buffer_size: u64,
    /// Second component of the controller bounds (its output limit).
    ///
    /// Default is 5000.
    #[default(5_000)]
    pub max_surbs_per_sec: u64,

    /// Optional automatic decay of the buffer estimate as `(window, coefficient)`.
    ///
    /// Each time `window` has elapsed since the last decay, the balancer subtracts
    /// `target * coefficient` (rounded) SURBs from its buffer estimate.
    /// `None` disables the decay.
    ///
    /// Default is `(60 seconds, 0.05)`.
    #[default(_code = "Some((Duration::from_secs(60), 0.05))")]
    pub surb_decay: Option<(Duration, f64)>,
}
94
/// Keeps an estimate of the SURB buffer level of a Session and steers the SURB flow
/// using a controller `C`, a flow estimator `E` and a flow controller `F`.
pub struct SurbBalancer<C, E, F> {
    /// Session this balancer belongs to; used in logs, metric labels and config callbacks.
    session_id: SessionId,
    /// Computes the next control output from the current buffer estimate.
    controller: C,
    /// Source of the produced/consumed SURB estimates that are snapshotted on each update.
    surb_estimator: E,
    /// Receives the control output computed on each update.
    flow_control: F,
    /// Current estimate of the number of SURBs in the buffer.
    current_buffer: u64,
    /// Estimator snapshot taken during the last successful update.
    last_estimator_state: SimpleSurbFlowEstimator,
    /// Time of the last update that was not skipped for being too early.
    last_update: std::time::Instant,
    /// Time the automatic decay was last applied (initially the construction time).
    last_decay: std::time::Instant,
    /// Whether the last observed level was below the target; only drives transition logging.
    was_below_target: bool,
}
124
125impl<C, E, F> SurbBalancer<C, E, F>
126where
127 C: SurbBalancerController + Send + Sync + 'static,
128 E: SurbFlowEstimator + Send + Sync + 'static,
129 F: SurbFlowController + Send + Sync + 'static,
130{
131 pub fn new(
132 session_id: SessionId,
133 mut controller: C,
134 surb_estimator: E,
135 flow_control: F,
136 cfg: SurbBalancerConfig,
137 ) -> Self {
138 #[cfg(all(feature = "prometheus", not(test)))]
139 {
140 let sid = session_id.to_string();
141 METRIC_TARGET_ERROR_ESTIMATE.set(&[&sid], 0.0);
142 METRIC_CONTROL_OUTPUT.set(&[&sid], 0.0);
143 }
144
145 controller.set_target_and_limit(BalancerControllerBounds::new(
146 cfg.target_surb_buffer_size,
147 cfg.max_surbs_per_sec,
148 ));
149
150 Self {
151 surb_estimator,
152 flow_control,
153 controller,
154 session_id,
155 current_buffer: 0,
156 last_estimator_state: Default::default(),
157 last_update: std::time::Instant::now(),
158 last_decay: std::time::Instant::now(),
159 was_below_target: true,
160 }
161 }
162
163 #[tracing::instrument(level = "trace", skip_all)]
165 fn update(&mut self, surb_decay: Option<&(Duration, f64)>) -> u64 {
166 let dt = self.last_update.elapsed();
167 if dt < Duration::from_millis(10) {
168 debug!("time elapsed since last update is too short, skipping update");
169 return self.current_buffer;
170 }
171
172 self.last_update = std::time::Instant::now();
173
174 let snapshot = SimpleSurbFlowEstimator::from(&self.surb_estimator);
176 let Some(target_buffer_change) = snapshot.estimated_surb_buffer_change(&self.last_estimator_state) else {
177 error!("non-monotonic change in SURB estimators");
178 return self.current_buffer;
179 };
180
181 self.last_estimator_state = snapshot;
182 self.current_buffer = self.current_buffer.saturating_add_signed(target_buffer_change);
183
184 if let Some(num_decayed_surbs) = surb_decay
187 .as_ref()
188 .filter(|(decay_window, _)| &self.last_decay.elapsed() >= decay_window)
189 .map(|(_, decay_coeff)| (self.controller.bounds().target() as f64 * *decay_coeff).round() as u64)
190 {
191 self.current_buffer = self.current_buffer.saturating_sub(num_decayed_surbs);
192 self.last_decay = std::time::Instant::now();
193 trace!(num_decayed_surbs, "SURBs were discarded due to automatic decay");
194 }
195
196 let error = self.current_buffer as i64 - self.controller.bounds().target() as i64;
198
199 if self.was_below_target && error >= 0 {
200 trace!(current = self.current_buffer, "reached target SURB buffer size");
201 self.was_below_target = false;
202 } else if !self.was_below_target && error < 0 {
203 trace!(current = self.current_buffer, "SURB buffer size is below target");
204 self.was_below_target = true;
205 }
206
207 trace!(
208 ?dt,
209 delta = target_buffer_change,
210 rate = target_buffer_change as f64 / dt.as_secs_f64(),
211 current = self.current_buffer,
212 error,
213 "estimated SURB buffer change"
214 );
215
216 let output = self.controller.next_control_output(self.current_buffer);
217 trace!(output, "next balancer control output for session");
218
219 self.flow_control.adjust_surb_flow(output as usize);
220
221 #[cfg(all(feature = "prometheus", not(test)))]
222 {
223 let sid = self.session_id.to_string();
224 METRIC_CURRENT_BUFFER.set(&[&sid], self.current_buffer as f64);
225 METRIC_CURRENT_TARGET.set(&[&sid], self.controller.bounds().target() as f64);
226 METRIC_TARGET_ERROR_ESTIMATE.set(&[&sid], error as f64);
227 METRIC_CONTROL_OUTPUT.set(&[&sid], output as f64);
228 METRIC_SURB_RATE.set(&[&sid], target_buffer_change as f64 / dt.as_secs_f64());
229 }
230
231 self.current_buffer
232 }
233
234 #[instrument(level = "debug", skip_all, fields(session_id = %self.session_id))]
243 pub fn start_control_loop<B>(
244 mut self,
245 sampling_interval: Duration,
246 cfg_feedback: B,
247 abort_reg: Option<AbortRegistration>,
248 ) -> (impl futures::Stream<Item = u64>, AbortHandle)
249 where
250 B: BalancerConfigFeedback + Send + Sync + 'static,
251 {
252 let (abort_handle, abort_reg) = abort_reg
254 .map(|reg| (reg.handle(), reg))
255 .unwrap_or_else(AbortHandle::new_pair);
256
257 let sampling_stream = futures::stream::Abortable::new(
259 futures_time::stream::interval(sampling_interval.max(MIN_BALANCER_SAMPLING_INTERVAL).into()),
260 abort_reg,
261 );
262
263 let balancer_level_capacity = std::env::var("HOPR_INTERNAL_SESSION_BALANCER_LEVEL_CAPACITY")
264 .ok()
265 .and_then(|s| s.trim().parse::<usize>().ok())
266 .filter(|&c| c > 0)
267 .unwrap_or(32_768);
268
269 tracing::debug!(
270 capacity = balancer_level_capacity,
271 "Creating session balancer level channel"
272 );
273 let (mut level_tx, level_rx) = futures::channel::mpsc::channel(balancer_level_capacity);
274 hopr_async_runtime::prelude::spawn(
275 async move {
276 pin_mut!(sampling_stream);
277 while sampling_stream.next().await.is_some() {
278 let Ok(mut current_cfg) = cfg_feedback.get_config(&self.session_id).await else {
281 error!("cannot get config for session");
282 break;
283 };
284
285 let current_bounds = BalancerControllerBounds::new(
286 current_cfg.target_surb_buffer_size,
287 current_cfg.max_surbs_per_sec,
288 );
289
290 if current_bounds != self.controller.bounds() {
292 self.controller.set_target_and_limit(current_bounds);
293 debug!(?current_cfg, "surb balancer has been reconfigured");
294 }
295
296 let bounds_before_update = self.controller.bounds();
297
298 let level = self.update(current_cfg.surb_decay.as_ref());
302 if !level_tx.is_closed() {
303 if let Err(error) = level_tx.try_send(level) {
304 tracing::error!(%error, "cannot send balancer level update");
305 }
306 }
307
308 let bounds_after_update = self.controller.bounds();
312 if bounds_before_update != bounds_after_update {
313 current_cfg.target_surb_buffer_size = bounds_after_update.target();
314 current_cfg.max_surbs_per_sec = bounds_after_update.output_limit();
315 match cfg_feedback.on_config_update(&self.session_id, current_cfg).await {
316 Ok(_) => debug!(
317 ?bounds_before_update,
318 ?bounds_after_update,
319 "controller bounds has changed after update"
320 ),
321 Err(error) => error!(%error, "failed to update controller bounds after it changed"),
322 }
323 }
324 }
325
326 debug!("balancer done");
327 }
328 .in_current_span(),
329 );
330
331 (level_rx, abort_handle)
332 }
333}
334
#[cfg(test)]
mod tests {
    use std::sync::{
        Arc,
        atomic::{AtomicU64, Ordering},
    };

    use hopr_crypto_random::Randomizable;
    use hopr_internal_types::prelude::HoprPseudonym;

    use super::*;
    use crate::balancer::{
        AtomicSurbFlowEstimator, MockSurbFlowController, pid::PidBalancerController, simple::SimpleBalancerController,
    };

    /// Starting from an empty buffer, the level reported by consecutive updates must grow
    /// when the flow controller output is fed back as the production rate.
    #[test_log::test]
    fn surb_balancer_should_start_increase_level_when_below_target() {
        let production_rate = Arc::new(AtomicU64::new(0));
        let consumption_rate = 100;
        let steps = 3;
        let step_duration = std::time::Duration::from_millis(1000);

        // The mocked flow controller stores the requested flow as the new production rate.
        let mut controller = MockSurbFlowController::new();
        let production_rate_clone = production_rate.clone();
        controller
            .expect_adjust_surb_flow()
            .times(steps)
            .with(mockall::predicate::ge(100))
            .returning(move |r| {
                production_rate_clone.store(r as u64, Ordering::Relaxed);
            });

        let surb_estimator = AtomicSurbFlowEstimator::default();
        let mut balancer = SurbBalancer::new(
            SessionId::new(1234_u64, HoprPseudonym::random()),
            PidBalancerController::default(),
            surb_estimator.clone(),
            controller,
            SurbBalancerConfig {
                target_surb_buffer_size: 5_000,
                max_surbs_per_sec: 2500,
                ..Default::default()
            },
        );

        let mut last_update = 0;
        for i in 0..steps {
            // `update` measures real elapsed time, so the test has to actually wait.
            std::thread::sleep(step_duration);
            surb_estimator.produced.fetch_add(
                production_rate.load(Ordering::Relaxed) * step_duration.as_secs(),
                Ordering::Relaxed,
            );
            surb_estimator.consumed.fetch_add(
                consumption_rate * step_duration.as_secs(),
                Ordering::Relaxed,
            );

            let next_update = balancer.update(None);
            assert!(
                i == 0 || next_update > last_update,
                "{next_update} should be greater than {last_update}"
            );
            last_update = next_update;
        }
    }

    /// With an initial production rate far above the consumption rate, the level reported
    /// by consecutive updates must shrink after the first step.
    #[test_log::test]
    fn surb_balancer_should_start_decrease_level_when_above_target() {
        let production_rate = Arc::new(AtomicU64::new(11_000));
        let consumption_rate = 100;
        let steps = 3;
        let step_duration = std::time::Duration::from_millis(1000);

        // The mocked flow controller stores the requested flow as the new production rate.
        let mut controller = MockSurbFlowController::new();
        let production_rate_clone = production_rate.clone();
        controller
            .expect_adjust_surb_flow()
            .times(steps)
            .with(mockall::predicate::ge(0))
            .returning(move |r| {
                production_rate_clone.store(r as u64, Ordering::Relaxed);
            });

        let surb_estimator = AtomicSurbFlowEstimator::default();
        let mut balancer = SurbBalancer::new(
            SessionId::new(1234_u64, HoprPseudonym::random()),
            PidBalancerController::default(),
            surb_estimator.clone(),
            controller,
            SurbBalancerConfig::default(),
        );

        let mut last_update = 0;
        for i in 0..steps {
            // `update` measures real elapsed time, so the test has to actually wait.
            std::thread::sleep(step_duration);
            surb_estimator.produced.fetch_add(
                production_rate.load(Ordering::Relaxed) * step_duration.as_secs(),
                Ordering::Relaxed,
            );
            surb_estimator.consumed.fetch_add(
                consumption_rate * step_duration.as_secs(),
                Ordering::Relaxed,
            );

            let next_update = balancer.update(None);
            assert!(
                i == 0 || next_update < last_update,
                "{next_update} should be less than {last_update}"
            );
            last_update = next_update;
        }
    }

    /// With no SURB flow at all and a 200 ms decay window sampled every 100 ms,
    /// the level must drop by 5% of the target (250 SURBs) on every other sample.
    #[test_log::test(tokio::test)]
    async fn surb_balancer_should_start_decrease_level_when_above_target_and_decay_enabled() {
        const NUM_STEPS: usize = 5;
        let session_id = SessionId::new(1234_u64, HoprPseudonym::random());
        let cfg = SurbBalancerConfig {
            target_surb_buffer_size: 5_000,
            max_surbs_per_sec: 2500,
            surb_decay: Some((Duration::from_millis(200), 0.05)),
            ..Default::default()
        };

        let mut mock_balancer_feedback = MockBalancerConfigFeedback::new();
        mock_balancer_feedback
            .expect_get_config()
            .with(mockall::predicate::eq(session_id))
            .times(NUM_STEPS)
            .returning(move |_| Ok(cfg));

        // The bounds are not expected to change, so no config update may be reported.
        mock_balancer_feedback.expect_on_config_update().never();

        let mut mock_flow_ctl = MockSurbFlowController::new();
        mock_flow_ctl
            .expect_adjust_surb_flow()
            .times(NUM_STEPS)
            .returning(|_| ());

        let mut balancer = SurbBalancer::new(
            session_id,
            PidBalancerController::default(),
            SimpleSurbFlowEstimator::default(),
            mock_flow_ctl,
            cfg,
        );

        balancer.current_buffer = 5000;

        let (stream, handle) = balancer.start_control_loop(Duration::from_millis(100), mock_balancer_feedback, None);
        let levels = stream.take(NUM_STEPS).collect::<Vec<_>>().await;
        handle.abort();

        assert_eq!(levels, vec![5000, 4750, 4750, 4500, 4500]);
    }

    /// Flow estimator that replays the produced (`.0`) and consumed (`.1`) estimates from iterators.
    ///
    /// Each call advances the respective iterator; 0 is returned once it is exhausted
    /// (or if its lock is poisoned).
    struct IterSurbFlowEstimator<P, C>(std::sync::Mutex<P>, std::sync::Mutex<C>);

    impl<P, C> IterSurbFlowEstimator<P, C> {
        /// Creates the estimator from the sequences of produced and consumed estimates.
        fn new<I1, I2>(production: I1, consumption: I2) -> Self
        where
            I1: IntoIterator<Item = u64, IntoIter = P>,
            I2: IntoIterator<Item = u64, IntoIter = C>,
        {
            Self(
                std::sync::Mutex::new(production.into_iter()),
                std::sync::Mutex::new(consumption.into_iter()),
            )
        }
    }

    impl<P, C> SurbFlowEstimator for IterSurbFlowEstimator<P, C>
    where
        P: Iterator<Item = u64>,
        C: Iterator<Item = u64>,
    {
        fn estimate_surbs_consumed(&self) -> u64 {
            self.1.lock().ok().and_then(|mut it| it.next()).unwrap_or(0)
        }

        fn estimate_surbs_produced(&self) -> u64 {
            self.0.lock().ok().and_then(|mut it| it.next()).unwrap_or(0)
        }
    }

    /// The simple controller is expected to change its bounds during the third update,
    /// which must be reported exactly once via `on_config_update`, after which the
    /// updated configuration is served by `get_config`.
    #[test_log::test(tokio::test)]
    async fn surb_balancer_should_increase_target_when_using_simple_controller() {
        let session_id = SessionId::new(1234_u64, HoprPseudonym::random());
        let cfg_1 = SurbBalancerConfig {
            target_surb_buffer_size: 4500,
            max_surbs_per_sec: 2500,
            surb_decay: None,
            ..Default::default()
        };

        let cfg_2 = SurbBalancerConfig {
            target_surb_buffer_size: 5500,
            max_surbs_per_sec: 3055,
            surb_decay: None,
            ..Default::default()
        };

        // The sequence enforces: 3x initial config, 1x config update, 2x updated config.
        let mut seq = mockall::Sequence::new();

        let mut mock_balancer_feedback = MockBalancerConfigFeedback::new();
        mock_balancer_feedback
            .expect_get_config()
            .times(3)
            .in_sequence(&mut seq)
            .with(mockall::predicate::eq(session_id))
            .returning(move |_| {
                tracing::trace!("get config 1");
                Ok(cfg_1)
            });

        mock_balancer_feedback
            .expect_on_config_update()
            .once()
            .in_sequence(&mut seq)
            .with(mockall::predicate::eq(session_id), mockall::predicate::eq(cfg_2))
            .returning(|_, _| {
                tracing::trace!("on config update");
                Ok(())
            });

        mock_balancer_feedback
            .expect_get_config()
            .times(2)
            .in_sequence(&mut seq)
            .with(mockall::predicate::eq(session_id))
            .returning(move |_| {
                tracing::trace!("get config 2");
                Ok(cfg_2)
            });

        let mut mock_flow_ctl = MockSurbFlowController::new();
        mock_flow_ctl.expect_adjust_surb_flow().times(5).returning(|_| ());

        // Net buffer gain is 500 SURBs in each of the first three samples, then none.
        let mut balancer = SurbBalancer::new(
            session_id,
            SimpleBalancerController::with_increasing_setpoint(0.2, 5),
            IterSurbFlowEstimator::new([1000, 2000, 3000, 3000, 3000], vec![500, 1000, 1500, 1500, 1500]),
            mock_flow_ctl,
            cfg_1,
        );

        balancer.current_buffer = 4500;

        let (stream, handle) = balancer.start_control_loop(Duration::from_millis(100), mock_balancer_feedback, None);

        let levels = stream.take(5).collect::<Vec<_>>().await;
        handle.abort();

        assert_eq!(levels, vec![5000, 5500, 6000, 6000, 6000]);
    }
}