1use futures::{FutureExt, SinkExt, Stream, StreamExt};
2use futures_timer::Delay;
3use std::{
4 cmp::Reverse,
5 collections::BinaryHeap,
6 future::poll_fn,
7 sync::{
8 atomic::{AtomicBool, AtomicUsize, Ordering},
9 Arc, Mutex,
10 },
11 task::Poll,
12 time::Duration,
13};
14use tracing::{error, trace};
15
16use crate::{config::MixerConfig, data::DelayedData};
17
18#[cfg(all(feature = "prometheus", not(test)))]
19use hopr_metrics::metrics::SimpleGauge;
20
#[cfg(all(feature = "prometheus", not(test)))]
lazy_static::lazy_static! {
    /// Gauge tracking the number of items currently buffered in the mixer queue.
    ///
    /// Incremented when a sender enqueues an item and decremented when the
    /// receiver yields one.
    pub static ref METRIC_QUEUE_SIZE: SimpleGauge = SimpleGauge::new(
        "hopr_mixer_queue_size",
        "Current mixer queue size"
    )
    .unwrap();

    /// Gauge holding a weighted moving average (in milliseconds) of the delays
    /// generated for enqueued items.
    pub static ref METRIC_MIXER_AVERAGE_DELAY: SimpleGauge = SimpleGauge::new(
        "hopr_mixer_average_packet_delay",
        "Average mixer packet delay averaged over a packet window"
    )
    .unwrap();
}
31
32struct Channel<T> {
48 buffer: BinaryHeap<Reverse<DelayedData<T>>>,
50 timer: futures_timer::Delay,
51 waker: Option<std::task::Waker>,
52 cfg: MixerConfig,
53}
54
55struct TrackedChannel<T> {
57 channel: Arc<Mutex<Channel<T>>>,
58 sender_count: Arc<AtomicUsize>,
59 receiver_active: Arc<AtomicBool>,
60}
61
62impl<T> Clone for TrackedChannel<T> {
63 fn clone(&self) -> Self {
64 Self {
65 channel: self.channel.clone(),
66 sender_count: self.sender_count.clone(),
67 receiver_active: self.receiver_active.clone(),
68 }
69 }
70}
71
72#[derive(Debug, thiserror::Error)]
74pub enum SenderError {
75 #[error("Channel is closed")]
77 Closed,
78
79 #[error("Channel lock failed")]
81 Lock,
82}
83
84pub struct Sender<T> {
86 channel: TrackedChannel<T>,
87}
88
89impl<T> Clone for Sender<T> {
90 fn clone(&self) -> Self {
91 let channel = self.channel.clone();
92 channel.sender_count.fetch_add(1, Ordering::Relaxed);
93
94 Sender { channel }
95 }
96}
97
98impl<T> Drop for Sender<T> {
99 fn drop(&mut self) {
100 if self.channel.sender_count.fetch_sub(1, Ordering::Relaxed) == 1
101 && !self.channel.receiver_active.load(Ordering::Relaxed)
102 {
103 let mut channel = self.channel.channel.lock().unwrap_or_else(|e| {
104 self.channel.channel.clear_poison();
105 e.into_inner()
106 });
107
108 channel.waker = None;
109 }
110 }
111}
112
113impl<T> Sender<T> {
114 pub fn send(&self, item: T) -> Result<(), SenderError> {
116 let mut sender = self.clone();
117 sender.start_send_unpin(item)
118 }
119}
120
121impl<T> futures::sink::Sink<T> for Sender<T> {
122 type Error = SenderError;
123
124 fn poll_ready(self: std::pin::Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
125 let is_active = self.channel.receiver_active.load(Ordering::Relaxed);
126 if is_active {
127 Poll::Ready(Ok(()))
128 } else {
129 Poll::Ready(Err(SenderError::Closed))
130 }
131 }
132
133 #[tracing::instrument(level = "trace", skip(self, item))]
134 fn start_send(self: std::pin::Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
135 let is_active = self.channel.receiver_active.load(Ordering::Relaxed);
136
137 if is_active {
138 let mut channel = self.channel.channel.lock().map_err(|_| SenderError::Lock)?;
139
140 let random_delay = channel.cfg.random_delay();
141
142 trace!(delay_in_ms = random_delay.as_millis(), "generated mixer delay",);
143
144 let delayed_data: DelayedData<T> = (std::time::Instant::now() + random_delay, item).into();
145 channel.buffer.push(Reverse(delayed_data));
146
147 if let Some(waker) = channel.waker.as_ref() {
148 waker.wake_by_ref();
149 }
150
151 #[cfg(all(feature = "prometheus", not(test)))]
152 {
153 METRIC_QUEUE_SIZE.increment(1.0f64);
154
155 let weight = 1.0f64 / channel.cfg.metric_delay_window as f64;
156 METRIC_MIXER_AVERAGE_DELAY.set(
157 (weight * random_delay.as_millis() as f64) + ((1.0f64 - weight) * METRIC_MIXER_AVERAGE_DELAY.get()),
158 );
159 }
160
161 Ok(())
162 } else {
163 Err(SenderError::Closed)
164 }
165 }
166
167 fn poll_flush(self: std::pin::Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
168 Poll::Ready(Ok(()))
169 }
170
171 fn poll_close(self: std::pin::Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
172 Poll::Ready(Ok(()))
174 }
175}
176
177#[derive(Debug, thiserror::Error)]
179pub enum ReceiverError {
180 #[error("Channel is closed")]
182 Closed,
183
184 #[error("Channel lock failed")]
186 Lock,
187}
188
189pub struct Receiver<T> {
194 channel: TrackedChannel<T>,
195}
196
197impl<T> Stream for Receiver<T> {
198 type Item = T;
199
200 #[tracing::instrument(level = "trace", skip(self, cx))]
201 fn poll_next(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Option<Self::Item>> {
202 let now = std::time::Instant::now();
203 if self.channel.sender_count.load(Ordering::Relaxed) > 0 {
204 let Ok(mut channel) = self.channel.channel.lock() else {
205 error!("mutex is poisoned, terminating stream");
206 return Poll::Ready(None);
207 };
208
209 if channel.buffer.peek().map(|x| x.0.release_at < now).unwrap_or(false) {
210 let data = channel
211 .buffer
212 .pop()
213 .expect("The value should be present within the same locked access")
214 .0
215 .item;
216
217 trace!(from = "direct", "yield item");
218
219 #[cfg(all(feature = "prometheus", not(test)))]
220 METRIC_QUEUE_SIZE.decrement(1.0f64);
221
222 return Poll::Ready(Some(data));
223 }
224
225 if let Some(waker) = channel.waker.as_mut() {
226 waker.clone_from(cx.waker());
227 } else {
228 let waker = cx.waker().clone();
229 channel.waker = Some(waker);
230 }
231
232 if let Some(next) = channel.buffer.peek() {
233 let remaining = next.0.release_at.duration_since(now);
234
235 trace!("reseting the timer");
236 channel.timer.reset(remaining);
237
238 futures::ready!(channel.timer.poll_unpin(cx));
239
240 trace!(from = "timer", "yield item");
241
242 #[cfg(all(feature = "prometheus", not(test)))]
243 METRIC_QUEUE_SIZE.decrement(1.0f64);
244
245 return Poll::Ready(Some(
246 channel
247 .buffer
248 .pop()
249 .expect("The value should be present within the locked access")
250 .0
251 .item,
252 ));
253 }
254
255 trace!(from = "direct", "pending");
256 Poll::Pending
257 } else {
258 self.channel.receiver_active.store(false, Ordering::Relaxed);
259 Poll::Ready(None)
260 }
261 }
262}
263
264impl<T> Receiver<T> {
265 pub async fn recv(&mut self) -> Option<T> {
267 poll_fn(|cx| self.poll_next_unpin(cx)).await
268 }
269}
270
271pub fn channel<T>(cfg: crate::config::MixerConfig) -> (Sender<T>, Receiver<T>) {
273 #[cfg(all(feature = "prometheus", not(test)))]
274 {
275 lazy_static::initialize(&METRIC_QUEUE_SIZE);
277 lazy_static::initialize(&METRIC_MIXER_AVERAGE_DELAY);
278 }
279
280 let mut buffer = BinaryHeap::new();
281 buffer.reserve(cfg.capacity);
282
283 let channel = TrackedChannel {
284 channel: Arc::new(Mutex::new(Channel::<T> {
285 buffer,
286 timer: Delay::new(Duration::from_secs(0)),
287 waker: None,
288 cfg,
289 })),
290 sender_count: Arc::new(AtomicUsize::new(1)),
291 receiver_active: Arc::new(AtomicBool::new(true)),
292 };
293 (
294 Sender {
295 channel: channel.clone(),
296 },
297 Receiver { channel },
298 )
299}
300
#[cfg(test)]
mod tests {
    use async_std::prelude::FutureExt;
    use futures::StreamExt;

    use super::*;

    /// Slack added on top of the expected maximum delay to absorb scheduling overhead.
    const PROCESSING_LEEWAY: Duration = Duration::from_millis(20);
    /// Sum of the default minimum delay and the default delay range, used as the
    /// upper bound for a single delay under the default configuration.
    const MAXIMUM_SINGLE_DELAY_DURATION: Duration = Duration::from_millis(
        crate::config::HOPR_MIXER_MINIMUM_DEFAULT_DELAY_IN_MS + crate::config::HOPR_MIXER_DEFAULT_DELAY_RANGE_IN_MS,
    );

    #[async_std::test]
    async fn mixer_channel_should_pass_an_element() -> anyhow::Result<()> {
        let (tx, mut rx) = channel(MixerConfig::default());
        tx.send(1)?;
        assert_eq!(rx.recv().await, Some(1));

        Ok(())
    }

    #[async_std::test]
    async fn mixer_channel_should_introduce_random_delay() -> anyhow::Result<()> {
        // A monotonic clock is used so that wall-clock adjustments cannot skew
        // (or fail) the elapsed-time measurement.
        let start = std::time::Instant::now();

        let (tx, mut rx) = channel(MixerConfig::default());
        tx.send(1)?;
        assert_eq!(rx.recv().await, Some(1));

        let elapsed = start.elapsed();

        assert!(elapsed < MAXIMUM_SINGLE_DELAY_DURATION + PROCESSING_LEEWAY);
        assert!(elapsed > Duration::from_millis(crate::config::HOPR_MIXER_MINIMUM_DEFAULT_DELAY_IN_MS));

        Ok(())
    }

    #[async_std::test]
    async fn mixer_channel_should_batch_on_sending_emulating_concurrency() -> anyhow::Result<()> {
        const ITERATIONS: usize = 10;

        let (tx, mut rx) = channel(MixerConfig::default());

        // Monotonic clock, see `mixer_channel_should_introduce_random_delay`.
        let start = std::time::Instant::now();

        for i in 0..ITERATIONS {
            tx.send(i)?;
        }
        for _ in 0..ITERATIONS {
            let data = rx.next().timeout(MAXIMUM_SINGLE_DELAY_DURATION).await?;
            assert!(data.is_some());
        }

        let elapsed = start.elapsed();

        // All items are delayed concurrently, so the whole batch must fit into
        // a single maximum delay.
        assert!(elapsed < MAXIMUM_SINGLE_DELAY_DURATION + PROCESSING_LEEWAY);
        assert!(elapsed > Duration::from_millis(crate::config::HOPR_MIXER_MINIMUM_DEFAULT_DELAY_IN_MS));

        Ok(())
    }

    #[async_std::test]
    async fn mixer_channel_should_work_concurrently_and_properly_closed_channels() -> anyhow::Result<()> {
        const ITERATIONS: usize = 1000;

        let (tx, mut rx) = channel(MixerConfig::default());

        // Drains the receiver until the stream terminates.
        let recv_task = async_std::task::spawn(async move {
            while let Some(_item) = rx
                .next()
                .timeout(2 * MAXIMUM_SINGLE_DELAY_DURATION)
                .await
                .expect("receiver should not fail")
            {}
        });

        // `forward` consumes the sender, so it is dropped once all items are sent.
        let send_task =
            async_std::task::spawn(async move { futures::stream::iter(0..ITERATIONS).map(Ok).forward(tx).await });

        let (_recv, send) = futures::try_join!(
            recv_task.timeout(MAXIMUM_SINGLE_DELAY_DURATION),
            send_task.timeout(MAXIMUM_SINGLE_DELAY_DURATION)
        )?;

        send?;

        Ok(())
    }

    #[async_std::test]
    async fn mixer_channel_should_produce_mixed_output_from_the_supplied_input_using_sync_send() -> anyhow::Result<()> {
        // NOTE(review): the final assertion is probabilistic; it assumes that the
        // generated delays virtually never preserve the order of this many items.
        const ITERATIONS: usize = 20;

        let (tx, rx) = channel(MixerConfig::default());

        let input = (0..ITERATIONS).collect::<Vec<_>>();

        for i in input.iter() {
            tx.send(*i)?;
        }

        let mixed_output = rx
            .take(ITERATIONS)
            .collect::<Vec<_>>()
            .timeout(2 * MAXIMUM_SINGLE_DELAY_DURATION)
            .await?;

        tracing::info!(?input, ?mixed_output, "asserted data");
        assert_ne!(input, mixed_output);

        Ok(())
    }

    #[async_std::test]
    async fn mixer_channel_should_produce_mixed_output_from_the_supplied_input_using_async_send() -> anyhow::Result<()>
    {
        // NOTE(review): the final assertion is probabilistic; it assumes that the
        // generated delays virtually never preserve the order of this many items.
        const ITERATIONS: usize = 20;

        let (mut tx, rx) = channel(MixerConfig::default());

        let input = (0..ITERATIONS).collect::<Vec<_>>();

        for i in input.iter() {
            SinkExt::send(&mut tx, *i).await?;
        }

        let mixed_output = rx
            .take(ITERATIONS)
            .collect::<Vec<_>>()
            .timeout(2 * MAXIMUM_SINGLE_DELAY_DURATION)
            .await?;

        tracing::info!(?input, ?mixed_output, "asserted data");
        assert_ne!(input, mixed_output);

        Ok(())
    }

    #[async_std::test]
    async fn mixer_channel_should_produce_mixed_output_from_the_supplied_input_using_async_feed() -> anyhow::Result<()>
    {
        // NOTE(review): the final assertion is probabilistic; it assumes that the
        // generated delays virtually never preserve the order of this many items.
        const ITERATIONS: usize = 20;

        let (mut tx, rx) = channel(MixerConfig::default());

        let input = (0..ITERATIONS).collect::<Vec<_>>();

        for i in input.iter() {
            SinkExt::feed(&mut tx, *i).await?;
        }
        SinkExt::flush(&mut tx).await?;

        let mixed_output = rx
            .take(ITERATIONS)
            .collect::<Vec<_>>()
            .timeout(2 * MAXIMUM_SINGLE_DELAY_DURATION)
            .await?;

        tracing::info!(?input, ?mixed_output, "asserted data");
        assert_ne!(input, mixed_output);

        Ok(())
    }

    #[async_std::test]
    async fn mixer_channel_should_not_mix_the_order_if_the_min_delay_and_delay_range_is_0() -> anyhow::Result<()> {
        const ITERATIONS: usize = 20;

        // With both delay parameters zeroed the output is expected to keep the input order.
        let (tx, rx) = channel(MixerConfig {
            min_delay: Duration::from_millis(0),
            delay_range: Duration::from_millis(0),
            ..MixerConfig::default()
        });

        let input = (0..ITERATIONS).collect::<Vec<_>>();

        for i in input.iter() {
            tx.send(*i)?;
        }

        let mixed_output = rx
            .take(ITERATIONS)
            .collect::<Vec<_>>()
            .timeout(2 * MAXIMUM_SINGLE_DELAY_DURATION)
            .await?;

        tracing::info!(?input, ?mixed_output, "asserted data");
        assert_eq!(input, mixed_output);

        Ok(())
    }
}