1use std::{
2 cmp::Reverse,
3 collections::BinaryHeap,
4 future::poll_fn,
5 sync::{
6 Arc, Mutex,
7 atomic::{AtomicBool, AtomicUsize, Ordering},
8 },
9 task::Poll,
10 time::Duration,
11};
12
13use futures::{FutureExt, SinkExt, Stream, StreamExt};
14use futures_timer::Delay;
15#[cfg(all(feature = "prometheus", not(test)))]
16use hopr_metrics::metrics::SimpleGauge;
17use tracing::{error, trace};
18
19use crate::{config::MixerConfig, data::DelayedData};
20
// Mixer gauges: only compiled with the `prometheus` feature and never in test builds.
#[cfg(all(feature = "prometheus", not(test)))]
lazy_static::lazy_static! {
    /// Gauge incremented for every item pushed by a sender and decremented for every item
    /// yielded by the receiver.
    pub static ref METRIC_QUEUE_SIZE: SimpleGauge = SimpleGauge::new(
        "hopr_mixer_queue_size",
        "Current mixer queue size"
    )
    .unwrap();

    /// Gauge holding a weighted running average (in milliseconds) of the generated delays;
    /// the weight is derived from `MixerConfig::metric_delay_window` on every send.
    pub static ref METRIC_MIXER_AVERAGE_DELAY: SimpleGauge = SimpleGauge::new(
        "hopr_mixer_average_packet_delay",
        "Average mixer packet delay averaged over a packet window"
    )
    .unwrap();
}
31
32struct Channel<T> {
48 buffer: BinaryHeap<Reverse<DelayedData<T>>>,
50 timer: futures_timer::Delay,
51 waker: Option<std::task::Waker>,
52 cfg: MixerConfig,
53}
54
55struct TrackedChannel<T> {
57 channel: Arc<Mutex<Channel<T>>>,
58 sender_count: Arc<AtomicUsize>,
59 receiver_active: Arc<AtomicBool>,
60}
61
62impl<T> Clone for TrackedChannel<T> {
63 fn clone(&self) -> Self {
64 Self {
65 channel: self.channel.clone(),
66 sender_count: self.sender_count.clone(),
67 receiver_active: self.receiver_active.clone(),
68 }
69 }
70}
71
72#[derive(Debug, thiserror::Error)]
74pub enum SenderError {
75 #[error("Channel is closed")]
77 Closed,
78
79 #[error("Channel lock failed")]
81 Lock,
82}
83
84pub struct Sender<T> {
86 channel: TrackedChannel<T>,
87}
88
89impl<T> Clone for Sender<T> {
90 fn clone(&self) -> Self {
91 let channel = self.channel.clone();
92 channel.sender_count.fetch_add(1, Ordering::Relaxed);
93
94 Sender { channel }
95 }
96}
97
98impl<T> Drop for Sender<T> {
99 fn drop(&mut self) {
100 if self.channel.sender_count.fetch_sub(1, Ordering::Relaxed) == 1
101 && !self.channel.receiver_active.load(Ordering::Relaxed)
102 {
103 let mut channel = self.channel.channel.lock().unwrap_or_else(|e| {
104 self.channel.channel.clear_poison();
105 e.into_inner()
106 });
107
108 channel.waker = None;
109 }
110 }
111}
112
113impl<T> Sender<T> {
114 pub fn send(&self, item: T) -> Result<(), SenderError> {
116 let mut sender = self.clone();
117 sender.start_send_unpin(item)
118 }
119}
120
121impl<T> futures::sink::Sink<T> for Sender<T> {
122 type Error = SenderError;
123
124 fn poll_ready(self: std::pin::Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
125 let is_active = self.channel.receiver_active.load(Ordering::Relaxed);
126 if is_active {
127 Poll::Ready(Ok(()))
128 } else {
129 Poll::Ready(Err(SenderError::Closed))
130 }
131 }
132
133 #[tracing::instrument(level = "trace", skip(self, item))]
134 fn start_send(self: std::pin::Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
135 let is_active = self.channel.receiver_active.load(Ordering::Relaxed);
136
137 if is_active {
138 let mut channel = self.channel.channel.lock().map_err(|_| SenderError::Lock)?;
139
140 let random_delay = channel.cfg.random_delay();
141
142 trace!(delay_in_ms = random_delay.as_millis(), "generated mixer delay",);
143
144 let delayed_data: DelayedData<T> = (std::time::Instant::now() + random_delay, item).into();
145 channel.buffer.push(Reverse(delayed_data));
146
147 if let Some(waker) = channel.waker.as_ref() {
148 waker.wake_by_ref();
149 }
150
151 #[cfg(all(feature = "prometheus", not(test)))]
152 {
153 METRIC_QUEUE_SIZE.increment(1.0f64);
154
155 let weight = 1.0f64 / channel.cfg.metric_delay_window as f64;
156 METRIC_MIXER_AVERAGE_DELAY.set(
157 (weight * random_delay.as_millis() as f64) + ((1.0f64 - weight) * METRIC_MIXER_AVERAGE_DELAY.get()),
158 );
159 }
160
161 Ok(())
162 } else {
163 Err(SenderError::Closed)
164 }
165 }
166
167 fn poll_flush(self: std::pin::Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
168 Poll::Ready(Ok(()))
169 }
170
171 fn poll_close(self: std::pin::Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
172 Poll::Ready(Ok(()))
174 }
175}
176
177#[derive(Debug, thiserror::Error)]
179pub enum ReceiverError {
180 #[error("Channel is closed")]
182 Closed,
183
184 #[error("Channel lock failed")]
186 Lock,
187}
188
189pub struct Receiver<T> {
194 channel: TrackedChannel<T>,
195}
196
197impl<T> Stream for Receiver<T> {
198 type Item = T;
199
200 #[tracing::instrument(level = "trace", skip(self, cx))]
201 fn poll_next(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Option<Self::Item>> {
202 let now = std::time::Instant::now();
203 if self.channel.sender_count.load(Ordering::Relaxed) > 0 {
204 let Ok(mut channel) = self.channel.channel.lock() else {
205 error!("mutex is poisoned, terminating stream");
206 return Poll::Ready(None);
207 };
208
209 if channel.buffer.peek().map(|x| x.0.release_at < now).unwrap_or(false) {
210 let data = channel
211 .buffer
212 .pop()
213 .expect("The value should be present within the same locked access")
214 .0
215 .item;
216
217 trace!(from = "direct", "yield item");
218
219 #[cfg(all(feature = "prometheus", not(test)))]
220 METRIC_QUEUE_SIZE.decrement(1.0f64);
221
222 return Poll::Ready(Some(data));
223 }
224
225 if let Some(waker) = channel.waker.as_mut() {
226 waker.clone_from(cx.waker());
227 } else {
228 let waker = cx.waker().clone();
229 channel.waker = Some(waker);
230 }
231
232 if let Some(next) = channel.buffer.peek() {
233 let remaining = next.0.release_at.duration_since(now);
234
235 trace!("reseting the timer");
236 channel.timer.reset(remaining);
237
238 futures::ready!(channel.timer.poll_unpin(cx));
239
240 trace!(from = "timer", "yield item");
241
242 #[cfg(all(feature = "prometheus", not(test)))]
243 METRIC_QUEUE_SIZE.decrement(1.0f64);
244
245 return Poll::Ready(Some(
246 channel
247 .buffer
248 .pop()
249 .expect("The value should be present within the locked access")
250 .0
251 .item,
252 ));
253 }
254
255 trace!(from = "direct", "pending");
256 Poll::Pending
257 } else {
258 self.channel.receiver_active.store(false, Ordering::Relaxed);
259 Poll::Ready(None)
260 }
261 }
262}
263
264impl<T> Receiver<T> {
265 pub async fn recv(&mut self) -> Option<T> {
267 poll_fn(|cx| self.poll_next_unpin(cx)).await
268 }
269}
270
271pub fn channel<T>(cfg: crate::config::MixerConfig) -> (Sender<T>, Receiver<T>) {
273 #[cfg(all(feature = "prometheus", not(test)))]
274 {
275 lazy_static::initialize(&METRIC_QUEUE_SIZE);
277 lazy_static::initialize(&METRIC_MIXER_AVERAGE_DELAY);
278 }
279
280 let mut buffer = BinaryHeap::new();
281 buffer.reserve(cfg.capacity);
282
283 let channel = TrackedChannel {
284 channel: Arc::new(Mutex::new(Channel::<T> {
285 buffer,
286 timer: Delay::new(Duration::from_secs(0)),
287 waker: None,
288 cfg,
289 })),
290 sender_count: Arc::new(AtomicUsize::new(1)),
291 receiver_active: Arc::new(AtomicBool::new(true)),
292 };
293 (
294 Sender {
295 channel: channel.clone(),
296 },
297 Receiver { channel },
298 )
299}
300
#[cfg(test)]
mod tests {
    use futures::StreamExt;
    use tokio::time::timeout;

    use super::*;

    /// Slack granted on top of the theoretical maximum delay for scheduling and processing.
    const PROCESSING_LEEWAY: Duration = Duration::from_millis(20);
    /// Lower bound of the delay generated by the default configuration.
    const MINIMUM_SINGLE_DELAY_DURATION: Duration =
        Duration::from_millis(crate::config::HOPR_MIXER_MINIMUM_DEFAULT_DELAY_IN_MS);
    /// Upper bound of the delay generated by the default configuration.
    const MAXIMUM_SINGLE_DELAY_DURATION: Duration = Duration::from_millis(
        crate::config::HOPR_MIXER_MINIMUM_DEFAULT_DELAY_IN_MS + crate::config::HOPR_MIXER_DEFAULT_DELAY_RANGE_IN_MS,
    );

    #[tokio::test]
    async fn mixer_channel_should_pass_an_element() -> anyhow::Result<()> {
        let (tx, mut rx) = channel(MixerConfig::default());
        tx.send(1)?;
        assert_eq!(rx.recv().await, Some(1));

        Ok(())
    }

    #[tokio::test]
    async fn mixer_channel_should_introduce_random_delay() -> anyhow::Result<()> {
        // `Instant` is monotonic, so the measurement cannot fail or be skewed by clock changes.
        let start = std::time::Instant::now();

        let (tx, mut rx) = channel(MixerConfig::default());
        tx.send(1)?;
        assert_eq!(rx.recv().await, Some(1));

        let elapsed = start.elapsed();

        assert!(elapsed < MAXIMUM_SINGLE_DELAY_DURATION + PROCESSING_LEEWAY);
        assert!(elapsed > MINIMUM_SINGLE_DELAY_DURATION);
        Ok(())
    }

    #[tokio::test]
    async fn mixer_channel_should_batch_on_sending_emulating_concurrency() -> anyhow::Result<()> {
        const ITERATIONS: usize = 10;

        let (tx, mut rx) = channel(MixerConfig::default());

        let start = std::time::Instant::now();

        for i in 0..ITERATIONS {
            tx.send(i)?;
        }
        for _ in 0..ITERATIONS {
            let data = timeout(MAXIMUM_SINGLE_DELAY_DURATION, rx.next()).await?;
            assert!(data.is_some());
        }

        let elapsed = start.elapsed();

        // All items were delayed concurrently, so the total stays within a single maximum delay.
        assert!(elapsed < MAXIMUM_SINGLE_DELAY_DURATION + PROCESSING_LEEWAY);
        assert!(elapsed > MINIMUM_SINGLE_DELAY_DURATION);
        Ok(())
    }

    #[tokio::test]
    async fn mixer_channel_should_work_concurrently_and_properly_closed_channels() -> anyhow::Result<()> {
        const ITERATIONS: usize = 1000;

        let (tx, mut rx) = channel(MixerConfig::default());

        // Drains the stream until it terminates after the sender has been dropped.
        let recv_task = tokio::task::spawn(async move {
            while let Some(_item) = timeout(2 * MAXIMUM_SINGLE_DELAY_DURATION, rx.next())
                .await
                .expect("receiver should not fail")
            {}
        });

        let send_task =
            tokio::task::spawn(async move { futures::stream::iter(0..ITERATIONS).map(Ok).forward(tx).await });

        let (_recv, send) = futures::try_join!(
            timeout(MAXIMUM_SINGLE_DELAY_DURATION, recv_task),
            timeout(MAXIMUM_SINGLE_DELAY_DURATION, send_task)
        )?;

        send??;

        Ok(())
    }

    #[tokio::test]
    async fn mixer_channel_should_produce_mixed_output_from_the_supplied_input_using_sync_send() -> anyhow::Result<()> {
        const ITERATIONS: usize = 20;

        let (tx, rx) = channel(MixerConfig::default());

        let input = (0..ITERATIONS).collect::<Vec<_>>();

        for &i in &input {
            tx.send(i)?;
        }

        let mixed_output = timeout(
            2 * MAXIMUM_SINGLE_DELAY_DURATION,
            rx.take(ITERATIONS).collect::<Vec<_>>(),
        )
        .await?;

        tracing::info!(?input, ?mixed_output, "asserted data");
        assert_ne!(input, mixed_output);
        Ok(())
    }

    #[tokio::test]
    async fn mixer_channel_should_produce_mixed_output_from_the_supplied_input_using_async_send() -> anyhow::Result<()>
    {
        const ITERATIONS: usize = 20;

        let (mut tx, rx) = channel(MixerConfig::default());

        let input = (0..ITERATIONS).collect::<Vec<_>>();

        for &i in &input {
            SinkExt::send(&mut tx, i).await?;
        }

        let mixed_output = timeout(
            2 * MAXIMUM_SINGLE_DELAY_DURATION,
            rx.take(ITERATIONS).collect::<Vec<_>>(),
        )
        .await?;

        tracing::info!(?input, ?mixed_output, "asserted data");
        assert_ne!(input, mixed_output);
        Ok(())
    }

    #[tokio::test]
    async fn mixer_channel_should_produce_mixed_output_from_the_supplied_input_using_async_feed() -> anyhow::Result<()>
    {
        const ITERATIONS: usize = 20;

        let (mut tx, rx) = channel(MixerConfig::default());

        let input = (0..ITERATIONS).collect::<Vec<_>>();

        for &i in &input {
            SinkExt::feed(&mut tx, i).await?;
        }
        SinkExt::flush(&mut tx).await?;

        let mixed_output = timeout(
            2 * MAXIMUM_SINGLE_DELAY_DURATION,
            rx.take(ITERATIONS).collect::<Vec<_>>(),
        )
        .await?;

        tracing::info!(?input, ?mixed_output, "asserted data");
        assert_ne!(input, mixed_output);
        Ok(())
    }

    #[tokio::test]
    async fn mixer_channel_should_not_mix_the_order_if_the_min_delay_and_delay_range_is_0() -> anyhow::Result<()> {
        const ITERATIONS: usize = 40;

        let (tx, rx) = channel(MixerConfig {
            min_delay: Duration::from_millis(0),
            delay_range: Duration::from_millis(0),
            ..MixerConfig::default()
        });

        let input = (0..ITERATIONS).collect::<Vec<_>>();

        for &i in &input {
            tx.send(i)?;
            // Space the sends out slightly so that consecutive items get distinct release times.
            tokio::time::sleep(Duration::from_micros(10)).await;
        }

        let mixed_output = timeout(
            2 * MAXIMUM_SINGLE_DELAY_DURATION,
            rx.take(ITERATIONS).collect::<Vec<_>>(),
        )
        .await?;

        tracing::info!(?input, ?mixed_output, "asserted data");
        assert_eq!(input, mixed_output);

        Ok(())
    }
}