Skip to main content

hopr_protocol_session/utils/
skip_queue.rs

1use std::{
2    cmp::Ordering,
3    collections::BTreeSet,
4    pin::Pin,
5    sync::{Arc, atomic::AtomicBool},
6    task::{Context, Poll, Waker},
7    time::{Duration, Instant},
8};
9
10use futures::FutureExt;
11use tracing::instrument;
12
/// An internal type used by the [`SkipDelayQueue`].
///
/// Pairs a queued item with its deadline and a cancellation flag. The flag is an
/// [`AtomicBool`] so that it can be set through the shared references handed out
/// when iterating a `BTreeSet`.
#[derive(Debug)]
struct DelayedEntry<T> {
    item: T,
    at: Instant,
    cancelled: AtomicBool,
}

/// Two entries are equal exactly when the items they carry are equal; neither the
/// deadline nor the cancellation flag takes part in the comparison.
impl<T: PartialEq> PartialEq for DelayedEntry<T> {
    fn eq(&self, other: &Self) -> bool {
        PartialEq::eq(&self.item, &other.item)
    }
}

impl<T: Eq> Eq for DelayedEntry<T> {}

/// Orders entries by deadline and breaks ties by the natural order of the items.
/// It reports `Equal` whenever the items themselves are equal, to stay consistent
/// with [`PartialEq`].
///
/// NOTE: because equality ignores the deadline, this is a lawful total order only
/// over entries whose items are pairwise distinct. Two entries carrying the same item
/// but different deadlines compare `Equal` to each other, yet may compare differently
/// against a third entry. A `BTreeSet` search therefore cannot be relied upon to find
/// an equal item stored under a different deadline.
impl<T: Ord> Ord for DelayedEntry<T> {
    fn cmp(&self, other: &Self) -> Ordering {
        if other.item == self.item {
            // Be consistent with PartialEq
            return Ordering::Equal;
        }

        // Distinct items: earliest deadline first, then the items' own order
        self.at.cmp(&other.at).then_with(|| self.item.cmp(&other.item))
    }
}

/// Delegates to the total [`Ord`] implementation.
impl<T: Ord> PartialOrd for DelayedEntry<T> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}
53
54/// Internal type used by the [`skip_delay_channel`].
55struct SkipDelayQueue<T> {
56    entries: BTreeSet<DelayedEntry<T>>,
57    next_wakeup: Option<futures_time::task::SleepUntil>,
58    rx_waker: Option<Waker>,
59    is_closed: bool,
60}
61
/// An item with a deadline, which can be pushed into a [`skip_delay_channel`].
///
/// For convenience, the type can be built via [`From`] out of
/// `(T, Instant)`, `(T, Duration)` and `(T, Skip)` tuples.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum DelayedItem<T> {
    /// Adds a new item (or replaces an existing equal one) with the given deadline.
    New(T, Instant),
    /// Cancels a previously added item.
    Cancel(T),
}

/// A marker type for cancelling items pushed into a [`skip_delay_channel`].
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct Skip;

/// Schedules the item to be yielded once the given delay has elapsed.
/// The delay is measured from the moment of this conversion.
impl<T> From<(T, Duration)> for DelayedItem<T> {
    fn from((item, delay): (T, Duration)) -> Self {
        let deadline = Instant::now() + delay;
        DelayedItem::New(item, deadline)
    }
}

/// Schedules the item to be yielded at the given absolute deadline.
impl<T> From<(T, Instant)> for DelayedItem<T> {
    fn from((item, deadline): (T, Instant)) -> Self {
        DelayedItem::New(item, deadline)
    }
}

/// Requests cancellation of a previously scheduled, equal item.
impl<T> From<(T, Skip)> for DelayedItem<T> {
    fn from((item, _marker): (T, Skip)) -> Self {
        DelayedItem::Cancel(item)
    }
}
95
96impl<T> SkipDelayQueue<T> {
97    const TOLERANCE: Duration = Duration::from_millis(5);
98
99    /// Creates a new instance.
100    ///
101    /// As a common practice, [`futures::StreamExt::split`] can be called to
102    /// get separate sending and receiving part of the queue.
103    pub fn new() -> Self {
104        Self {
105            entries: BTreeSet::new(),
106            next_wakeup: None,
107            rx_waker: None,
108            is_closed: false,
109        }
110    }
111}
112
113/// Receiver part for the [`skip_delay_channel`].
114pub struct SkipDelayReceiver<T>(Arc<std::sync::Mutex<SkipDelayQueue<T>>>);
115
116impl<T> Drop for SkipDelayReceiver<T> {
117    #[instrument(name = "SkipDelayReceiver::drop", level = "trace", skip(self))]
118    fn drop(&mut self) {
119        // When the receiver is dropped, clear the poison and mark the queue as closed.
120        self.0.clear_poison();
121        let mut queue = self.0.lock().expect("cannot panic because poison is cleared");
122        queue.is_closed = true;
123        queue.rx_waker = None;
124    }
125}
126
127impl<T: Ord> futures::Stream for SkipDelayReceiver<T> {
128    type Item = T;
129
130    #[instrument(name = "SkipDelayReceiver::poll_next", level = "trace", skip(self, cx))]
131    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
132        let Ok(mut queue) = self.0.lock() else {
133            tracing::error!("poisoned mutex");
134            return Poll::Ready(None);
135        };
136
137        // Wait until the timer is done, if any
138        if let Some(next_wakeup) = queue.next_wakeup.as_mut() {
139            tracing::trace!("polling timer");
140            let _ = futures::ready!(next_wakeup.poll_unpin(cx));
141            queue.next_wakeup = None;
142        }
143
144        tracing::trace!("timer finished");
145
146        let now = Instant::now();
147        while let Some(e) = queue.entries.first() {
148            if !e.cancelled.load(std::sync::atomic::Ordering::SeqCst) {
149                return if e.at.saturating_duration_since(now) < SkipDelayQueue::<T>::TOLERANCE {
150                    // If the item is already in the past, yield it
151                    tracing::trace!("ready");
152                    Poll::Ready(queue.entries.pop_first().map(|e| e.item))
153                } else {
154                    // The next item is in the future, set up the timer and wake us up to start it
155                    tracing::trace!("pending new timer");
156                    queue.next_wakeup = Some(futures_time::task::sleep_until(e.at.into()));
157                    cx.waker().wake_by_ref();
158                    Poll::Pending
159                };
160            } else {
161                // If the item has been canceled, remove it and continue
162                queue.entries.pop_first();
163                tracing::trace!("item cancelled");
164            }
165        }
166
167        if !queue.is_closed {
168            // Need more data, wake us up when some are added
169            tracing::trace!("pending for data");
170            queue.rx_waker = Some(cx.waker().clone());
171            Poll::Pending
172        } else {
173            // We're done
174            Poll::Ready(None)
175        }
176    }
177}
178
179/// Sender part for the [`skip_delay_channel`].
180pub struct SkipDelaySender<T>(Option<Arc<std::sync::Mutex<SkipDelayQueue<T>>>>);
181
182impl<T> Clone for SkipDelaySender<T> {
183    fn clone(&self) -> Self {
184        Self(self.0.clone())
185    }
186}
187
188impl<T> SkipDelaySender<T> {
189    fn ensure_closure(&mut self) {
190        if let Some(queue) = self.0.take() {
191            let count_holders = Arc::strong_count(&queue);
192            tracing::trace!(count_holders, "ensure_closure");
193
194            // Check if the last holders are this instance and (potentially) the receiver
195            if count_holders == 2 {
196                Self::finalize_closure(queue);
197            }
198        }
199    }
200
201    fn finalize_closure(queue: Arc<std::sync::Mutex<SkipDelayQueue<T>>>) {
202        tracing::trace!("finalize_closure");
203        queue.clear_poison();
204        let mut queue = queue.lock().expect("cannot panic because poison is cleared");
205        queue.is_closed = true;
206        queue.rx_waker = None;
207    }
208
209    /// Forces closure of the queue (regardless of any remaining senders).
210    pub fn force_close(&mut self) {
211        if let Some(queue) = self.0.take() {
212            Self::finalize_closure(queue);
213        }
214    }
215}
216
217impl<T: Ord> SkipDelaySender<T> {
218    #[instrument(
219        name = "SkipDelaySender::send_internal",
220        level = "trace",
221        skip(self, items, flush),
222        ret
223    )]
224    fn send_internal<I: Iterator<Item = DelayedItem<T>>>(&self, items: I, flush: bool) -> Result<(), std::io::Error> {
225        if let Some(queue) = self.0.as_ref() {
226            let mut queue = queue.lock().map_err(|_| std::io::ErrorKind::BrokenPipe)?;
227
228            // This can happen only when the receiver was dropped.
229            if queue.is_closed {
230                return Err(std::io::ErrorKind::BrokenPipe.into());
231            }
232
233            for item in items {
234                match item {
235                    DelayedItem::New(item, at) => {
236                        tracing::trace!(at =  ?at.saturating_duration_since(Instant::now()), "inserting");
237                        queue.entries.replace(DelayedEntry {
238                            item,
239                            at,
240                            cancelled: AtomicBool::new(false),
241                        });
242                    }
243                    DelayedItem::Cancel(item) => {
244                        tracing::trace!("cancelling");
245                        queue
246                            .entries
247                            .iter()
248                            .filter(|e| item == e.item)
249                            .for_each(|e| e.cancelled.store(true, std::sync::atomic::Ordering::SeqCst));
250                    }
251                }
252            }
253
254            if flush {
255                tracing::trace!("flushing");
256                if let Some(waker) = queue.rx_waker.take() {
257                    waker.wake();
258                }
259            }
260
261            Ok(())
262        } else {
263            Err(std::io::ErrorKind::NotConnected.into())
264        }
265    }
266
267    /// Sends the given single item and flushes the queue.
268    pub fn send_one<I: Into<DelayedItem<T>>>(&mut self, item: I) -> Result<(), std::io::Error> {
269        self.send_internal(std::iter::once(item.into()), true)
270    }
271
272    /// Sends many items at once and then flushes the queue.
273    pub fn send_many<I: IntoIterator<Item = DelayedItem<T>>>(&mut self, items: I) -> Result<(), std::io::Error> {
274        self.send_internal(items.into_iter(), true)
275    }
276}
277
278impl<T> Drop for SkipDelaySender<T> {
279    fn drop(&mut self) {
280        self.ensure_closure();
281    }
282}
283
284impl<T: Ord> futures::Sink<DelayedItem<T>> for SkipDelaySender<T> {
285    type Error = std::io::Error;
286
287    fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
288        if self.0.is_some() {
289            Poll::Ready(Ok(()))
290        } else {
291            Poll::Ready(Err(std::io::ErrorKind::NotConnected.into()))
292        }
293    }
294
295    fn start_send(self: Pin<&mut Self>, item: DelayedItem<T>) -> Result<(), Self::Error> {
296        self.send_internal(std::iter::once(item), false)
297    }
298
299    #[instrument(name = "SkipDelaySender::poll_flush", level = "trace", skip(self), ret)]
300    fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
301        if let Some(queue) = self.0.as_ref() {
302            let Ok(mut queue) = queue.lock() else {
303                return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into()));
304            };
305
306            tracing::trace!("flushing");
307            if let Some(waker) = queue.rx_waker.take() {
308                waker.wake();
309            }
310
311            Poll::Ready(Ok(()))
312        } else {
313            Poll::Ready(Err(std::io::ErrorKind::NotConnected.into()))
314        }
315    }
316
317    fn poll_close(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
318        if self.0.is_none() {
319            return Poll::Ready(Err(std::io::ErrorKind::NotConnected.into()));
320        }
321
322        self.ensure_closure();
323        Poll::Ready(Ok(()))
324    }
325}
326
327/// A MPSC queue of [items](DelayedItem) with attached [`Instant`] that determines a deadline
328/// at which it should be yielded from the [`Stream`](futures::Stream) side of the queue.
329/// The queue also has the cancellation ability: an item that has been pushed into the
330/// queue earlier can be canceled before it meets its deadline.
331/// A canceled item will then be skipped in the output stream.
332///
333/// The items are internally sorted based on their deadline.
334/// In a case when two items have equal deadlines, they are sorted according
335/// to their values; therefore, items must implement [`Ord`].
336///
337/// If equal items are inserted, the deadline of the earlier one inserted is updated.
338pub fn skip_delay_channel<T: Ord>() -> (SkipDelaySender<T>, SkipDelayReceiver<T>) {
339    let queue = Arc::new(std::sync::Mutex::new(SkipDelayQueue::new()));
340    (SkipDelaySender(Some(queue.clone())), SkipDelayReceiver(queue))
341}
342
#[cfg(test)]
mod tests {
    use futures::{SinkExt, StreamExt, pin_mut};

    use super::*;

    /// Shorthand for a [`Duration`] of the given number of milliseconds.
    fn ms(millis: u64) -> Duration {
        Duration::from_millis(millis)
    }

    #[test_log::test(tokio::test)]
    async fn skip_delay_queue_should_yield_items() -> anyhow::Result<()> {
        let (mut sender, receiver) = skip_delay_channel();
        pin_mut!(receiver);

        let start = Instant::now();
        sender.send((1, start + ms(100)).into()).await?;
        sender.close().await?;

        assert_eq!(Some(1), receiver.next().await);
        assert!(start.elapsed() >= ms(100));
        assert_eq!(None, receiver.next().await);

        Ok(())
    }

    #[test_log::test(tokio::test)]
    async fn skip_delay_queue_should_replace_and_yield_items() -> anyhow::Result<()> {
        let (mut sender, receiver) = skip_delay_channel();
        pin_mut!(receiver);

        let start = Instant::now();
        sender.send((1, start + ms(100)).into()).await?;
        // Re-sending the same item moves its deadline
        sender.send((1, start + ms(200)).into()).await?;
        sender.close().await?;

        assert_eq!(Some(1), receiver.next().await);
        assert!(start.elapsed() >= ms(200));
        assert_eq!(None, receiver.next().await);

        Ok(())
    }

    #[test_log::test(tokio::test)]
    async fn skip_delay_queue_should_yield_items_from_multiple_senders() -> anyhow::Result<()> {
        let (mut first_sender, receiver) = skip_delay_channel();
        pin_mut!(receiver);

        let mut second_sender = first_sender.clone();

        let start = Instant::now();
        first_sender.send((2, start + ms(100)).into()).await?;
        first_sender.close().await?;

        second_sender.send((1, start + ms(150)).into()).await?;
        second_sender.close().await?;

        assert_eq!(Some(2), receiver.next().await);
        assert!(start.elapsed() >= ms(100));
        assert_eq!(Some(1), receiver.next().await);
        assert!(start.elapsed() >= ms(150));

        assert_eq!(None, receiver.next().await);

        Ok(())
    }

    #[test_log::test(tokio::test)]
    async fn skip_delay_queue_yielded_items_should_be_apart() -> anyhow::Result<()> {
        let (mut sender, receiver) = skip_delay_channel();
        pin_mut!(receiver);

        let first_start = Instant::now();
        sender.send((1, first_start + ms(100)).into()).await?;
        let second_start = Instant::now();
        sender.send((2, second_start + ms(200)).into()).await?;
        sender.close().await?;

        assert_eq!(Some(1), receiver.next().await);
        assert!(first_start.elapsed() >= ms(100));
        assert_eq!(Some(2), receiver.next().await);
        assert!(second_start.elapsed() >= ms(200));

        assert_eq!(None, receiver.next().await);

        Ok(())
    }

    #[test_log::test(tokio::test)]
    async fn skip_delay_queue_should_not_yield_cancelled_items() -> anyhow::Result<()> {
        let (mut sender, receiver) = skip_delay_channel();
        pin_mut!(receiver);

        let start = Instant::now();
        sender.send((1, start + ms(100)).into()).await?;
        sender.send((1, Skip).into()).await?;
        sender.close().await?;

        assert_eq!(None, receiver.next().await);

        Ok(())
    }

    #[test_log::test(tokio::test)]
    async fn skip_delay_queue_should_yield_past_items_immediately() -> anyhow::Result<()> {
        let (mut sender, receiver) = skip_delay_channel();
        pin_mut!(receiver);

        let deadline = Instant::now();
        sender.send((1, deadline).into()).await?;
        sender.send((2, deadline).into()).await?;
        sender.close().await?;

        let receive_start = Instant::now();
        assert_eq!(Some(1), receiver.next().await);
        assert_eq!(Some(2), receiver.next().await);
        assert_eq!(None, receiver.next().await);

        assert!(receive_start.elapsed() < ms(25));

        Ok(())
    }

    #[test_log::test(tokio::test)]
    async fn skip_delay_queue_should_not_yield_future_cancelled_items() -> anyhow::Result<()> {
        let (mut sender, receiver) = skip_delay_channel();
        pin_mut!(receiver);

        let start = Instant::now();
        sender.send((1, start).into()).await?;
        sender.send((2, start + ms(100)).into()).await?;
        sender.send((2, Skip).into()).await?;
        sender.close().await?;

        assert_eq!(Some(1), receiver.next().await);
        assert_eq!(None, receiver.next().await);
        assert!(start.elapsed() < ms(50));

        Ok(())
    }

    #[test_log::test(tokio::test)]
    async fn skip_delay_queue_should_discard_duplicate_entries() -> anyhow::Result<()> {
        let (mut sender, receiver) = skip_delay_channel();
        pin_mut!(receiver);

        let start = Instant::now();
        sender.send((1, start).into()).await?;
        sender.send((1, start).into()).await?;
        sender.close().await?;

        assert_eq!(Some(1), receiver.next().await);
        assert_eq!(None, receiver.next().await);

        Ok(())
    }

    #[test_log::test(tokio::test)]
    async fn skip_delay_queue_should_yield_items_in_order() -> anyhow::Result<()> {
        let (mut sender, receiver) = skip_delay_channel();
        pin_mut!(receiver);

        let start = Instant::now();
        sender.send((2, start).into()).await?;
        sender.send((1, start).into()).await?;
        sender.close().await?;

        // Equal deadlines: items come out in their natural order
        assert_eq!(Some(1), receiver.next().await);
        assert_eq!(Some(2), receiver.next().await);
        assert_eq!(None, receiver.next().await);

        Ok(())
    }

    #[test_log::test(tokio::test)]
    async fn skip_delay_queue_should_yield_fed_items_in_order() -> anyhow::Result<()> {
        let (mut sender, receiver) = skip_delay_channel();
        pin_mut!(receiver);

        let start = Instant::now();
        sender.feed((2, start).into()).await?;
        sender.feed((1, start).into()).await?;
        sender.flush().await?;
        sender.close().await?;

        assert_eq!(Some(1), receiver.next().await);
        assert_eq!(Some(2), receiver.next().await);
        assert_eq!(None, receiver.next().await);

        Ok(())
    }

    #[test_log::test(tokio::test)]
    async fn skip_delay_queue_should_not_send_items_when_closed() -> anyhow::Result<()> {
        let (mut sender, receiver) = skip_delay_channel();
        pin_mut!(receiver);
        sender.close().await?;

        let start = Instant::now();
        sender.send((1, start).into()).await.unwrap_err();
        sender.close().await.unwrap_err();

        assert_eq!(None, receiver.next().await);

        Ok(())
    }

    #[test_log::test(tokio::test)]
    async fn skip_delay_queue_should_continuously_yield_items() -> anyhow::Result<()> {
        let (mut sender, receiver) = skip_delay_channel();

        let items = [5, 2, 1, 4, 3];

        // The i-th item is due i * 100ms from now
        let start = Instant::now();
        let schedule = items
            .iter()
            .enumerate()
            .map(|(i, &n)| (n, start + ms(100) * (i as u32)))
            .collect::<Vec<_>>();

        let to_send = schedule.clone();
        let producer = hopr_utils::runtime::prelude::spawn(async move {
            for (n, deadline) in to_send {
                sender.send((n, deadline).into()).await?;
                hopr_utils::runtime::prelude::sleep(ms(50)).await;
            }
            sender.close().await?;
            Ok::<_, std::io::Error>(())
        });

        let received = receiver.map(|n| (n, Instant::now())).collect::<Vec<_>>().await;

        assert_eq!(schedule.len(), received.len());

        for ((expected, deadline), (n, received_at)) in schedule.iter().zip(received) {
            assert_eq!(*expected, n);
            // Tolerate a small drift in either direction around the deadline
            let drift = if received_at < *deadline {
                deadline.saturating_duration_since(received_at)
            } else {
                received_at.saturating_duration_since(*deadline)
            };
            assert!(drift < ms(20));
        }

        producer.await??;
        Ok(())
    }
}