1use std::{
2 cmp::Reverse,
3 collections::BinaryHeap,
4 future::poll_fn,
5 sync::{
6 Arc, Mutex,
7 atomic::{AtomicBool, AtomicUsize, Ordering},
8 },
9 task::Poll,
10 time::Duration,
11};
12
13use futures::{FutureExt, SinkExt, Stream, StreamExt};
14use futures_timer::Delay;
15use tracing::{error, trace};
16
17use crate::{config::MixerConfig, data::DelayedData};
18
// Prometheus gauges for the mixer. They are only compiled with the `prometheus` feature and
// never in test builds; every use site is guarded by the same `cfg`.
#[cfg(all(feature = "prometheus", not(test)))]
lazy_static::lazy_static! {
    /// Number of items currently buffered in mixer channels: incremented on every accepted
    /// send and decremented whenever the receiver yields an item.
    pub static ref METRIC_QUEUE_SIZE: hopr_metrics::SimpleGauge =
        hopr_metrics::SimpleGauge::new("hopr_mixer_queue_size", "Current mixer queue size").unwrap();
    /// Exponentially weighted moving average (in milliseconds) of the generated per-item
    /// delays, with weight `1 / metric_delay_window` given to each new sample.
    pub static ref METRIC_MIXER_AVERAGE_DELAY: hopr_metrics::SimpleGauge = hopr_metrics::SimpleGauge::new(
        "hopr_mixer_average_packet_delay",
        "Average mixer packet delay averaged over a packet window"
    )
    .unwrap();
}
29
30struct Channel<T> {
46 buffer: BinaryHeap<Reverse<DelayedData<T>>>,
48 timer: futures_timer::Delay,
49 waker: Option<std::task::Waker>,
50 cfg: MixerConfig,
51}
52
53struct TrackedChannel<T> {
55 channel: Arc<Mutex<Channel<T>>>,
56 sender_count: Arc<AtomicUsize>,
57 receiver_active: Arc<AtomicBool>,
58}
59
60impl<T> Clone for TrackedChannel<T> {
61 fn clone(&self) -> Self {
62 Self {
63 channel: self.channel.clone(),
64 sender_count: self.sender_count.clone(),
65 receiver_active: self.receiver_active.clone(),
66 }
67 }
68}
69
70#[derive(Clone, Debug, thiserror::Error)]
72pub enum SenderError {
73 #[error("Channel is closed")]
75 Closed,
76
77 #[error("Channel lock failed")]
79 Lock,
80}
81
82pub struct Sender<T> {
84 channel: TrackedChannel<T>,
85}
86
87impl<T> Clone for Sender<T> {
88 fn clone(&self) -> Self {
89 let channel = self.channel.clone();
90 channel.sender_count.fetch_add(1, Ordering::Relaxed);
91
92 Sender { channel }
93 }
94}
95
96impl<T> Drop for Sender<T> {
97 fn drop(&mut self) {
98 if self.channel.sender_count.fetch_sub(1, Ordering::Relaxed) == 1
99 && !self.channel.receiver_active.load(Ordering::Relaxed)
100 {
101 let mut channel = self.channel.channel.lock().unwrap_or_else(|e| {
102 self.channel.channel.clear_poison();
103 e.into_inner()
104 });
105
106 channel.waker = None;
107 }
108 }
109}
110
111impl<T> Sender<T> {
112 pub fn send(&self, item: T) -> Result<(), SenderError> {
114 let mut sender = self.clone();
115 sender.start_send_unpin(item)
116 }
117}
118
119impl<T> futures::sink::Sink<T> for Sender<T> {
120 type Error = SenderError;
121
122 fn poll_ready(self: std::pin::Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
123 let is_active = self.channel.receiver_active.load(Ordering::Relaxed);
124 if is_active {
125 Poll::Ready(Ok(()))
126 } else {
127 Poll::Ready(Err(SenderError::Closed))
128 }
129 }
130
131 #[tracing::instrument(level = "trace", skip(self, item))]
132 fn start_send(self: std::pin::Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
133 let is_active = self.channel.receiver_active.load(Ordering::Relaxed);
134
135 if is_active {
136 let mut channel = self.channel.channel.lock().map_err(|_| SenderError::Lock)?;
137
138 let random_delay = channel.cfg.random_delay();
139
140 trace!(delay_in_ms = random_delay.as_millis(), "generated mixer delay",);
141
142 let delayed_data: DelayedData<T> = (std::time::Instant::now() + random_delay, item).into();
143 channel.buffer.push(Reverse(delayed_data));
144
145 if let Some(waker) = channel.waker.as_ref() {
146 waker.wake_by_ref();
147 }
148
149 #[cfg(all(feature = "prometheus", not(test)))]
150 {
151 METRIC_QUEUE_SIZE.increment(1.0f64);
152
153 let weight = 1.0f64 / channel.cfg.metric_delay_window as f64;
154 METRIC_MIXER_AVERAGE_DELAY.set(
155 (weight * random_delay.as_millis() as f64) + ((1.0f64 - weight) * METRIC_MIXER_AVERAGE_DELAY.get()),
156 );
157 }
158
159 Ok(())
160 } else {
161 Err(SenderError::Closed)
162 }
163 }
164
165 fn poll_flush(self: std::pin::Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
166 Poll::Ready(Ok(()))
167 }
168
169 fn poll_close(self: std::pin::Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
170 Poll::Ready(Ok(()))
172 }
173}
174
175#[derive(Debug, thiserror::Error)]
177pub enum ReceiverError {
178 #[error("Channel is closed")]
180 Closed,
181
182 #[error("Channel lock failed")]
184 Lock,
185}
186
187pub struct Receiver<T> {
192 channel: TrackedChannel<T>,
193}
194
195impl<T> Stream for Receiver<T> {
196 type Item = T;
197
198 #[tracing::instrument(level = "trace", skip(self, cx))]
199 fn poll_next(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Option<Self::Item>> {
200 let now = std::time::Instant::now();
201 if self.channel.sender_count.load(Ordering::Relaxed) > 0 {
202 let Ok(mut channel) = self.channel.channel.lock() else {
203 error!("mutex is poisoned, terminating stream");
204 return Poll::Ready(None);
205 };
206
207 if channel.buffer.peek().map(|x| x.0.release_at < now).unwrap_or(false) {
208 let data = channel
209 .buffer
210 .pop()
211 .expect("The value should be present within the same locked access")
212 .0
213 .item;
214
215 trace!(from = "direct", "yield item");
216
217 #[cfg(all(feature = "prometheus", not(test)))]
218 METRIC_QUEUE_SIZE.decrement(1.0f64);
219
220 return Poll::Ready(Some(data));
221 }
222
223 if let Some(waker) = channel.waker.as_mut() {
224 waker.clone_from(cx.waker());
225 } else {
226 let waker = cx.waker().clone();
227 channel.waker = Some(waker);
228 }
229
230 if let Some(next) = channel.buffer.peek() {
231 let remaining = next.0.release_at.duration_since(now);
232
233 trace!("reseting the timer");
234 channel.timer.reset(remaining);
235
236 futures::ready!(channel.timer.poll_unpin(cx));
237
238 trace!(from = "timer", "yield item");
239
240 #[cfg(all(feature = "prometheus", not(test)))]
241 METRIC_QUEUE_SIZE.decrement(1.0f64);
242
243 return Poll::Ready(Some(
244 channel
245 .buffer
246 .pop()
247 .expect("The value should be present within the locked access")
248 .0
249 .item,
250 ));
251 }
252
253 trace!(from = "direct", "pending");
254 Poll::Pending
255 } else {
256 self.channel.receiver_active.store(false, Ordering::Relaxed);
257 Poll::Ready(None)
258 }
259 }
260}
261
262impl<T> Receiver<T> {
263 pub async fn recv(&mut self) -> Option<T> {
265 poll_fn(|cx| self.poll_next_unpin(cx)).await
266 }
267}
268
269pub fn channel<T>(cfg: crate::config::MixerConfig) -> (Sender<T>, Receiver<T>) {
271 #[cfg(all(feature = "prometheus", not(test)))]
272 {
273 lazy_static::initialize(&METRIC_QUEUE_SIZE);
275 lazy_static::initialize(&METRIC_MIXER_AVERAGE_DELAY);
276 }
277
278 let mut buffer = BinaryHeap::new();
279 buffer.reserve(cfg.capacity);
280
281 let channel = TrackedChannel {
282 channel: Arc::new(Mutex::new(Channel::<T> {
283 buffer,
284 timer: Delay::new(Duration::from_secs(0)),
285 waker: None,
286 cfg,
287 })),
288 sender_count: Arc::new(AtomicUsize::new(1)),
289 receiver_active: Arc::new(AtomicBool::new(true)),
290 };
291 (
292 Sender {
293 channel: channel.clone(),
294 },
295 Receiver { channel },
296 )
297}
298
#[cfg(test)]
mod tests {
    use std::time::Instant;

    use futures::StreamExt;
    use tokio::time::timeout;

    use super::*;

    const PROCESSING_LEEWAY: Duration = Duration::from_millis(20);
    const MAXIMUM_SINGLE_DELAY_DURATION: Duration = Duration::from_millis(
        crate::config::HOPR_MIXER_MINIMUM_DEFAULT_DELAY_IN_MS + crate::config::HOPR_MIXER_DEFAULT_DELAY_RANGE_IN_MS,
    );

    #[tokio::test]
    async fn mixer_channel_should_pass_an_element() -> anyhow::Result<()> {
        let (tx, mut rx) = channel(MixerConfig::default());
        tx.send(1)?;
        assert_eq!(rx.recv().await, Some(1));

        Ok(())
    }

    #[tokio::test]
    async fn mixer_channel_should_introduce_random_delay() -> anyhow::Result<()> {
        // `Instant` is monotonic. `SystemTime` can jump (NTP, manual changes), and its
        // `elapsed()` fails outright when the clock steps backwards.
        let start = Instant::now();

        let (tx, mut rx) = channel(MixerConfig::default());
        tx.send(1)?;
        assert_eq!(rx.recv().await, Some(1));

        let elapsed = start.elapsed();

        assert!(elapsed < MAXIMUM_SINGLE_DELAY_DURATION + PROCESSING_LEEWAY);
        assert!(elapsed > Duration::from_millis(crate::config::HOPR_MIXER_MINIMUM_DEFAULT_DELAY_IN_MS));
        Ok(())
    }

    #[tokio::test]
    async fn mixer_channel_should_batch_on_sending_emulating_concurrency() -> anyhow::Result<()> {
        const ITERATIONS: usize = 10;

        let (tx, mut rx) = channel(MixerConfig::default());

        let start = Instant::now();

        for i in 0..ITERATIONS {
            tx.send(i)?;
        }
        for _ in 0..ITERATIONS {
            let data = timeout(MAXIMUM_SINGLE_DELAY_DURATION, rx.next()).await?;
            assert!(data.is_some());
        }

        let elapsed = start.elapsed();

        assert!(elapsed < MAXIMUM_SINGLE_DELAY_DURATION + PROCESSING_LEEWAY);
        assert!(elapsed > Duration::from_millis(crate::config::HOPR_MIXER_MINIMUM_DEFAULT_DELAY_IN_MS));
        Ok(())
    }

    #[tokio::test]
    async fn mixer_channel_should_work_concurrently_and_properly_closed_channels() -> anyhow::Result<()> {
        const ITERATIONS: usize = 1000;

        let (tx, mut rx) = channel(MixerConfig::default());

        let recv_task = tokio::task::spawn(async move {
            while let Some(_item) = timeout(2 * MAXIMUM_SINGLE_DELAY_DURATION, rx.next())
                .await
                .expect("receiver should not fail")
            {}
        });

        let send_task =
            tokio::task::spawn(async move { futures::stream::iter(0..ITERATIONS).map(Ok).forward(tx).await });

        let (recv, send) = futures::try_join!(
            timeout(MAXIMUM_SINGLE_DELAY_DURATION, recv_task),
            timeout(MAXIMUM_SINGLE_DELAY_DURATION, send_task)
        )?;

        // A panic inside the receiver task (e.g. its per-item timeout firing) must fail the
        // test instead of being discarded.
        recv?;
        send??;

        Ok(())
    }

    #[tokio::test]
    async fn mixer_channel_should_produce_mixed_output_from_the_supplied_input_using_sync_send() -> anyhow::Result<()> {
        const ITERATIONS: usize = 20;

        let (tx, rx) = channel(MixerConfig::default());

        let input = (0..ITERATIONS).collect::<Vec<_>>();

        for i in input.iter() {
            tx.send(*i)?;
        }

        let mixed_output = timeout(
            2 * MAXIMUM_SINGLE_DELAY_DURATION,
            rx.take(ITERATIONS).collect::<Vec<_>>(),
        )
        .await?;

        tracing::info!(?input, ?mixed_output, "asserted data");
        assert_ne!(input, mixed_output);
        Ok(())
    }

    #[tokio::test]
    async fn mixer_channel_should_produce_mixed_output_from_the_supplied_input_using_async_send() -> anyhow::Result<()>
    {
        const ITERATIONS: usize = 20;

        let (mut tx, rx) = channel(MixerConfig::default());

        let input = (0..ITERATIONS).collect::<Vec<_>>();

        for i in input.iter() {
            SinkExt::send(&mut tx, *i).await?;
        }

        let mixed_output = timeout(
            2 * MAXIMUM_SINGLE_DELAY_DURATION,
            rx.take(ITERATIONS).collect::<Vec<_>>(),
        )
        .await?;

        tracing::info!(?input, ?mixed_output, "asserted data");
        assert_ne!(input, mixed_output);
        Ok(())
    }

    #[tokio::test]
    async fn mixer_channel_should_produce_mixed_output_from_the_supplied_input_using_async_feed() -> anyhow::Result<()>
    {
        const ITERATIONS: usize = 20;

        let (mut tx, rx) = channel(MixerConfig::default());

        let input = (0..ITERATIONS).collect::<Vec<_>>();

        for i in input.iter() {
            SinkExt::feed(&mut tx, *i).await?;
        }
        SinkExt::flush(&mut tx).await?;

        let mixed_output = timeout(
            2 * MAXIMUM_SINGLE_DELAY_DURATION,
            rx.take(ITERATIONS).collect::<Vec<_>>(),
        )
        .await?;

        tracing::info!(?input, ?mixed_output, "asserted data");
        assert_ne!(input, mixed_output);
        Ok(())
    }

    #[tokio::test]
    async fn mixer_channel_should_not_mix_the_order_if_the_min_delay_and_delay_range_is_0() -> anyhow::Result<()> {
        const ITERATIONS: usize = 40;

        let (tx, rx) = channel(MixerConfig {
            min_delay: Duration::from_millis(0),
            delay_range: Duration::from_millis(0),
            ..MixerConfig::default()
        });

        let input = (0..ITERATIONS).collect::<Vec<_>>();

        for i in input.iter() {
            tx.send(*i)?;
            // Space the sends out so consecutive items get strictly increasing release times.
            tokio::time::sleep(std::time::Duration::from_micros(10)).await;
        }

        let mixed_output = timeout(
            2 * MAXIMUM_SINGLE_DELAY_DURATION,
            rx.take(ITERATIONS).collect::<Vec<_>>(),
        )
        .await?;

        tracing::info!(?input, ?mixed_output, "asserted data");
        assert_eq!(input, mixed_output);

        Ok(())
    }
}