hopr_transport_mixer/channel.rs

use std::{
    cmp::Reverse,
    collections::BinaryHeap,
    future::poll_fn,
    sync::{
        Arc, Mutex,
        atomic::{AtomicBool, AtomicUsize, Ordering},
    },
    task::Poll,
    time::Duration,
};

use futures::{FutureExt, SinkExt, Stream, StreamExt};
use futures_timer::Delay;
use tracing::{error, trace};

use crate::{config::MixerConfig, data::DelayedData};

#[cfg(all(feature = "prometheus", not(test)))]
lazy_static::lazy_static! {
    pub static ref METRIC_QUEUE_SIZE: hopr_metrics::SimpleGauge =
        hopr_metrics::SimpleGauge::new("hopr_mixer_queue_size", "Current mixer queue size").unwrap();
    pub static ref METRIC_MIXER_AVERAGE_DELAY: hopr_metrics::SimpleGauge = hopr_metrics::SimpleGauge::new(
        "hopr_mixer_average_packet_delay",
        "Average mixer packet delay averaged over a packet window"
    )
    .unwrap();
}

/// Mixing and delaying channel using a random delay function.
///
/// Mixing is performed by assigning a random delay to the ingress timestamp of each item,
/// then storing the items in a binary heap (Rust's `BinaryHeap` is a max-heap) with
/// `Reverse`-wrapped ordering. This effectively yields min-heap behavior, which ensures
/// that data is released in order of delay expiration.
///
/// When data arrives:
/// 1. A random delay is assigned
/// 2. Data is stored in the heap with its release timestamp
/// 3. The heap maintains ordering so the item with the earliest release time is at the top
///
/// The channel uses a single timer thread that is instantiated on the first
/// timer reset and shared across all operations. This channel is **unbounded**; the
/// `capacity` value in the configuration is used solely to pre-allocate the buffer.
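///
/// The ordering trick can be illustrated with plain standard-library types (a minimal
/// sketch, independent of this crate):
///
/// ```
/// use std::{cmp::Reverse, collections::BinaryHeap};
///
/// let mut heap = BinaryHeap::new();
/// heap.push(Reverse(3));
/// heap.push(Reverse(1));
/// heap.push(Reverse(2));
///
/// // `BinaryHeap` is a max-heap, so wrapping values in `Reverse` makes the smallest
/// // value (here: the earliest release timestamp) pop first.
/// assert_eq!(heap.pop(), Some(Reverse(1)));
/// ```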
struct Channel<T> {
    /// Buffer holding the data, ordered by release timestamp to provide the min-heap behavior.
    buffer: BinaryHeap<Reverse<DelayedData<T>>>,
    timer: futures_timer::Delay,
    waker: Option<std::task::Waker>,
    cfg: MixerConfig,
}

/// Channel wrapped with a sender counter and a receiver-active flag, allowing closure tracking.
struct TrackedChannel<T> {
    channel: Arc<Mutex<Channel<T>>>,
    sender_count: Arc<AtomicUsize>,
    receiver_active: Arc<AtomicBool>,
}

impl<T> Clone for TrackedChannel<T> {
    fn clone(&self) -> Self {
        Self {
            channel: self.channel.clone(),
            sender_count: self.sender_count.clone(),
            receiver_active: self.receiver_active.clone(),
        }
    }
}

/// Error returned by the [`Sender`].
#[derive(Clone, Debug, thiserror::Error)]
pub enum SenderError {
    /// The channel is closed because the receiver has been dropped.
    #[error("Channel is closed")]
    Closed,

    /// The mutex lock over the channel failed.
    #[error("Channel lock failed")]
    Lock,
}

/// Sender object interacting with the mixing channel.
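///
/// Items can be submitted either synchronously via [`Sender::send`] or through the
/// [`futures::Sink`] interface. A minimal sketch mirroring the tests below (marked
/// `ignore` because the surrounding setup and error handling are elided):
///
/// ```ignore
/// let (tx, mut rx) = channel::<u64>(MixerConfig::default());
///
/// // Synchronous, non-blocking enqueue.
/// tx.send(1)?;
///
/// // Asynchronous enqueue through the `Sink` trait; senders can be cloned freely.
/// use futures::SinkExt;
/// let mut tx2 = tx.clone();
/// SinkExt::send(&mut tx2, 2).await?;
/// ```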
pub struct Sender<T> {
    channel: TrackedChannel<T>,
}

impl<T> Clone for Sender<T> {
    fn clone(&self) -> Self {
        let channel = self.channel.clone();
        channel.sender_count.fetch_add(1, Ordering::Relaxed);

        Sender { channel }
    }
}

impl<T> Drop for Sender<T> {
    fn drop(&mut self) {
        if self.channel.sender_count.fetch_sub(1, Ordering::Relaxed) == 1
            && !self.channel.receiver_active.load(Ordering::Relaxed)
        {
            let mut channel = self.channel.channel.lock().unwrap_or_else(|e| {
                self.channel.channel.clear_poison();
                e.into_inner()
            });

            channel.waker = None;
        }
    }
}

impl<T> Sender<T> {
    /// Send one item to the mixing channel.
    pub fn send(&self, item: T) -> Result<(), SenderError> {
        let mut sender = self.clone();
        sender.start_send_unpin(item)
    }
}

impl<T> futures::sink::Sink<T> for Sender<T> {
    type Error = SenderError;

    fn poll_ready(self: std::pin::Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
        let is_active = self.channel.receiver_active.load(Ordering::Relaxed);
        if is_active {
            Poll::Ready(Ok(()))
        } else {
            Poll::Ready(Err(SenderError::Closed))
        }
    }

    #[tracing::instrument(level = "trace", skip(self, item))]
    fn start_send(self: std::pin::Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
        let is_active = self.channel.receiver_active.load(Ordering::Relaxed);

        if is_active {
            let mut channel = self.channel.channel.lock().map_err(|_| SenderError::Lock)?;

            let random_delay = channel.cfg.random_delay();

            trace!(delay_in_ms = random_delay.as_millis(), "generated mixer delay");

            let delayed_data: DelayedData<T> = (std::time::Instant::now() + random_delay, item).into();
            channel.buffer.push(Reverse(delayed_data));

            if let Some(waker) = channel.waker.as_ref() {
                waker.wake_by_ref();
            }

            #[cfg(all(feature = "prometheus", not(test)))]
            {
                METRIC_QUEUE_SIZE.increment(1.0f64);

                let weight = 1.0f64 / channel.cfg.metric_delay_window as f64;
                METRIC_MIXER_AVERAGE_DELAY.set(
                    (weight * random_delay.as_millis() as f64) + ((1.0f64 - weight) * METRIC_MIXER_AVERAGE_DELAY.get()),
                );
            }

            Ok(())
        } else {
            Err(SenderError::Closed)
        }
    }

    fn poll_flush(self: std::pin::Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    fn poll_close(self: std::pin::Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
        // The channel can only be closed by the receiver. The sender can be dropped at any point.
        Poll::Ready(Ok(()))
    }
}

/// Error returned by the [`Receiver`].
#[derive(Debug, thiserror::Error)]
pub enum ReceiverError {
    /// The channel is closed.
    #[error("Channel is closed")]
    Closed,

    /// The mutex lock over the channel failed.
    #[error("Channel lock failed")]
    Lock,
}

/// Receiver object interacting with the mixer channel.
///
/// The receiver yields already mixed elements, without any knowledge of
/// the original insertion order.
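///
/// A minimal sketch of draining the mixed output, mirroring the tests below (marked
/// `ignore` because setup and error handling are elided); since the receiver implements
/// [`futures::Stream`], the usual `StreamExt` combinators apply:
///
/// ```ignore
/// use futures::StreamExt;
///
/// let (tx, rx) = channel::<usize>(MixerConfig::default());
/// for i in 0..20 {
///     tx.send(i)?;
/// }
/// // Items come out after their random delays, i.e. most likely re-ordered.
/// let mixed: Vec<usize> = rx.take(20).collect().await;
/// ```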
pub struct Receiver<T> {
    channel: TrackedChannel<T>,
}

impl<T> Stream for Receiver<T> {
    type Item = T;

    #[tracing::instrument(level = "trace", skip(self, cx))]
    fn poll_next(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Option<Self::Item>> {
        let now = std::time::Instant::now();
        if self.channel.sender_count.load(Ordering::Relaxed) > 0 {
            let Ok(mut channel) = self.channel.channel.lock() else {
                error!("mutex is poisoned, terminating stream");
                return Poll::Ready(None);
            };

            if channel.buffer.peek().map(|x| x.0.release_at < now).unwrap_or(false) {
                let data = channel
                    .buffer
                    .pop()
                    .expect("The value should be present within the same locked access")
                    .0
                    .item;

                trace!(from = "direct", "yield item");

                #[cfg(all(feature = "prometheus", not(test)))]
                METRIC_QUEUE_SIZE.decrement(1.0f64);

                return Poll::Ready(Some(data));
            }

            if let Some(waker) = channel.waker.as_mut() {
                waker.clone_from(cx.waker());
            } else {
                let waker = cx.waker().clone();
                channel.waker = Some(waker);
            }

            if let Some(next) = channel.buffer.peek() {
                let remaining = next.0.release_at.duration_since(now);

                trace!("resetting the timer");
                channel.timer.reset(remaining);

                futures::ready!(channel.timer.poll_unpin(cx));

                trace!(from = "timer", "yield item");

                #[cfg(all(feature = "prometheus", not(test)))]
                METRIC_QUEUE_SIZE.decrement(1.0f64);

                return Poll::Ready(Some(
                    channel
                        .buffer
                        .pop()
                        .expect("The value should be present within the locked access")
                        .0
                        .item,
                ));
            }

            trace!(from = "direct", "pending");
            Poll::Pending
        } else {
            self.channel.receiver_active.store(false, Ordering::Relaxed);
            Poll::Ready(None)
        }
    }
}

impl<T> Receiver<T> {
    /// Receive a single delayed mixed item.
    pub async fn recv(&mut self) -> Option<T> {
        poll_fn(|cx| self.poll_next_unpin(cx)).await
    }
}

/// Instantiate a mixing channel and return the sender and receiver end of the channel.
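///
/// A minimal sketch mirroring `mixer_channel_should_pass_an_element` in the tests below
/// (marked `ignore` because the exact import paths depend on how the crate re-exports
/// this module):
///
/// ```ignore
/// let (tx, mut rx) = channel::<u64>(MixerConfig::default());
/// tx.send(1)?;
/// assert_eq!(rx.recv().await, Some(1));
/// ```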
pub fn channel<T>(cfg: crate::config::MixerConfig) -> (Sender<T>, Receiver<T>) {
    #[cfg(all(feature = "prometheus", not(test)))]
    {
        // Initialize the lazy statics here
        lazy_static::initialize(&METRIC_QUEUE_SIZE);
        lazy_static::initialize(&METRIC_MIXER_AVERAGE_DELAY);
    }

    let buffer = BinaryHeap::with_capacity(cfg.capacity);

    let channel = TrackedChannel {
        channel: Arc::new(Mutex::new(Channel::<T> {
            buffer,
            timer: Delay::new(Duration::from_secs(0)),
            waker: None,
            cfg,
        })),
        sender_count: Arc::new(AtomicUsize::new(1)),
        receiver_active: Arc::new(AtomicBool::new(true)),
    };
    (
        Sender {
            channel: channel.clone(),
        },
        Receiver { channel },
    )
}

#[cfg(test)]
mod tests {
    use futures::StreamExt;
    use tokio::time::timeout;

    use super::*;

    const PROCESSING_LEEWAY: Duration = Duration::from_millis(20);
    const MAXIMUM_SINGLE_DELAY_DURATION: Duration = Duration::from_millis(
        crate::config::HOPR_MIXER_MINIMUM_DEFAULT_DELAY_IN_MS + crate::config::HOPR_MIXER_DEFAULT_DELAY_RANGE_IN_MS,
    );

    #[tokio::test]
    async fn mixer_channel_should_pass_an_element() -> anyhow::Result<()> {
        let (tx, mut rx) = channel(MixerConfig::default());
        tx.send(1)?;
        assert_eq!(rx.recv().await, Some(1));

        Ok(())
    }

    #[tokio::test]
    async fn mixer_channel_should_introduce_random_delay() -> anyhow::Result<()> {
        let start = std::time::SystemTime::now();

        let (tx, mut rx) = channel(MixerConfig::default());
        tx.send(1)?;
        assert_eq!(rx.recv().await, Some(1));

        let elapsed = start.elapsed()?;

        assert!(elapsed < MAXIMUM_SINGLE_DELAY_DURATION + PROCESSING_LEEWAY);
        assert!(elapsed > Duration::from_millis(crate::config::HOPR_MIXER_MINIMUM_DEFAULT_DELAY_IN_MS));
        Ok(())
    }

    #[tokio::test]
    // #[tracing_test::traced_test]
    async fn mixer_channel_should_batch_on_sending_emulating_concurrency() -> anyhow::Result<()> {
        const ITERATIONS: usize = 10;

        let (tx, mut rx) = channel(MixerConfig::default());

        let start = std::time::SystemTime::now();

        for i in 0..ITERATIONS {
            tx.send(i)?;
        }
        for _ in 0..ITERATIONS {
            let data = timeout(MAXIMUM_SINGLE_DELAY_DURATION, rx.next()).await?;
            assert!(data.is_some());
        }

        let elapsed = start.elapsed()?;

        assert!(elapsed < MAXIMUM_SINGLE_DELAY_DURATION + PROCESSING_LEEWAY);
        assert!(elapsed > Duration::from_millis(crate::config::HOPR_MIXER_MINIMUM_DEFAULT_DELAY_IN_MS));
        Ok(())
    }

    #[tokio::test]
    // #[tracing_test::traced_test]
    async fn mixer_channel_should_work_concurrently_and_properly_closed_channels() -> anyhow::Result<()> {
        const ITERATIONS: usize = 1000;

        let (tx, mut rx) = channel(MixerConfig::default());

        let recv_task = tokio::task::spawn(async move {
            while let Some(_item) = timeout(2 * MAXIMUM_SINGLE_DELAY_DURATION, rx.next())
                .await
                .expect("receiver should not fail")
            {}
        });

        let send_task =
            tokio::task::spawn(async move { futures::stream::iter(0..ITERATIONS).map(Ok).forward(tx).await });

        let (_recv, send) = futures::try_join!(
            timeout(MAXIMUM_SINGLE_DELAY_DURATION, recv_task),
            timeout(MAXIMUM_SINGLE_DELAY_DURATION, send_task)
        )?;

        send??;

        Ok(())
    }

    #[tokio::test]
    // #[tracing_test::traced_test]
    async fn mixer_channel_should_produce_mixed_output_from_the_supplied_input_using_sync_send() -> anyhow::Result<()> {
        const ITERATIONS: usize = 20; // highly unlikely that this produces the same order as the input, given the size

        let (tx, rx) = channel(MixerConfig::default());

        let input = (0..ITERATIONS).collect::<Vec<_>>();

        for i in input.iter() {
            tx.send(*i)?;
        }

        let mixed_output = timeout(
            2 * MAXIMUM_SINGLE_DELAY_DURATION,
            rx.take(ITERATIONS).collect::<Vec<_>>(),
        )
        .await?;

        tracing::info!(?input, ?mixed_output, "asserted data");
        assert_ne!(input, mixed_output);
        Ok(())
    }

    #[tokio::test]
    // #[tracing_test::traced_test]
    async fn mixer_channel_should_produce_mixed_output_from_the_supplied_input_using_async_send() -> anyhow::Result<()>
    {
        const ITERATIONS: usize = 20; // highly unlikely that this produces the same order as the input, given the size

        let (mut tx, rx) = channel(MixerConfig::default());

        let input = (0..ITERATIONS).collect::<Vec<_>>();

        for i in input.iter() {
            SinkExt::send(&mut tx, *i).await?;
        }

        let mixed_output = timeout(
            2 * MAXIMUM_SINGLE_DELAY_DURATION,
            rx.take(ITERATIONS).collect::<Vec<_>>(),
        )
        .await?;

        tracing::info!(?input, ?mixed_output, "asserted data");
        assert_ne!(input, mixed_output);
        Ok(())
    }

    #[tokio::test]
    // #[tracing_test::traced_test]
    async fn mixer_channel_should_produce_mixed_output_from_the_supplied_input_using_async_feed() -> anyhow::Result<()>
    {
        const ITERATIONS: usize = 20; // highly unlikely that this produces the same order as the input, given the size

        let (mut tx, rx) = channel(MixerConfig::default());

        let input = (0..ITERATIONS).collect::<Vec<_>>();

        for i in input.iter() {
            SinkExt::feed(&mut tx, *i).await?;
        }
        SinkExt::flush(&mut tx).await?;

        let mixed_output = timeout(
            2 * MAXIMUM_SINGLE_DELAY_DURATION,
            rx.take(ITERATIONS).collect::<Vec<_>>(),
        )
        .await?;

        tracing::info!(?input, ?mixed_output, "asserted data");
        assert_ne!(input, mixed_output);
        Ok(())
    }

    #[tokio::test]
    // #[tracing_test::traced_test]
    async fn mixer_channel_should_not_mix_the_order_if_the_min_delay_and_delay_range_is_0() -> anyhow::Result<()> {
        const ITERATIONS: usize = 40; // with zero delay, the output order must match the input order regardless of size

        let (tx, rx) = channel(MixerConfig {
            min_delay: Duration::from_millis(0),
            delay_range: Duration::from_millis(0),
            ..MixerConfig::default()
        });

        let input = (0..ITERATIONS).collect::<Vec<_>>();

        for i in input.iter() {
            tx.send(*i)?;
            tokio::time::sleep(std::time::Duration::from_micros(10)).await; // ensure we don't send too fast
        }

        let mixed_output = timeout(
            2 * MAXIMUM_SINGLE_DELAY_DURATION,
            rx.take(ITERATIONS).collect::<Vec<_>>(),
        )
        .await?;

        tracing::info!(?input, ?mixed_output, "asserted data");
        assert_eq!(input, mixed_output);

        Ok(())
    }
}