hopr_protocol_session/utils/
skip_queue.rs1use std::{
2 cmp::Ordering,
3 collections::BTreeSet,
4 pin::Pin,
5 sync::{Arc, atomic::AtomicBool},
6 task::{Context, Poll, Waker},
7 time::{Duration, Instant},
8};
9
10use futures::FutureExt;
11use tracing::instrument;
12
/// One scheduled item stored in the shared queue.
///
/// The receiver yields `item` once `at` is (almost) reached, unless `cancelled` has been
/// set in the meantime, in which case the entry is discarded instead of yielded.
#[derive(Debug)]
struct DelayedEntry<T> {
    /// Payload handed to the receiver when the entry becomes due.
    item: T,
    /// Deadline at which the entry becomes due.
    at: std::time::Instant,
    /// Set by a cancel request; the entry stays in the set but is skipped by the receiver.
    cancelled: std::sync::atomic::AtomicBool,
}
20
21impl<T: PartialEq> PartialEq for DelayedEntry<T> {
23 fn eq(&self, other: &Self) -> bool {
24 self.item == other.item
25 }
26}
27
28impl<T: Eq> Eq for DelayedEntry<T> {}
29
30impl<T: Ord> Ord for DelayedEntry<T> {
31 fn cmp(&self, other: &Self) -> Ordering {
32 if other.item != self.item {
33 match self.at.cmp(&other.at) {
35 Ordering::Equal => self.item.cmp(&other.item),
39 x => x,
40 }
41 } else {
42 Ordering::Equal
44 }
45 }
46}
47
48impl<T: Ord> PartialOrd for DelayedEntry<T> {
49 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
50 Some(self.cmp(other))
51 }
52}
53
54struct SkipDelayQueue<T> {
56 entries: BTreeSet<DelayedEntry<T>>,
57 next_wakeup: Option<futures_time::task::SleepUntil>,
58 rx_waker: Option<Waker>,
59 is_closed: bool,
60}
61
/// A request sent through a [`SkipDelaySender`].
///
/// Usually built via `From`: `(item, Duration)`, `(item, Instant)` or `(item, Skip)`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum DelayedItem<T> {
    /// Schedule the item to be yielded by the receiver at the given instant.
    New(T, Instant),
    /// Mark every pending entry holding an equal item as cancelled, so the receiver skips it.
    Cancel(T),
}
73
/// Marker that turns `(item, Skip)` into a [`DelayedItem::Cancel`] via `From`/`Into`.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct Skip;
77
78impl<T> From<(T, Duration)> for DelayedItem<T> {
79 fn from(value: (T, Duration)) -> Self {
80 Self::New(value.0, Instant::now() + value.1)
81 }
82}
83
84impl<T> From<(T, Instant)> for DelayedItem<T> {
85 fn from(value: (T, Instant)) -> Self {
86 Self::New(value.0, value.1)
87 }
88}
89
90impl<T> From<(T, Skip)> for DelayedItem<T> {
91 fn from(value: (T, Skip)) -> Self {
92 Self::Cancel(value.0)
93 }
94}
95
96impl<T> SkipDelayQueue<T> {
97 const TOLERANCE: Duration = Duration::from_millis(5);
98
99 pub fn new() -> Self {
104 Self {
105 entries: BTreeSet::new(),
106 next_wakeup: None,
107 rx_waker: None,
108 is_closed: false,
109 }
110 }
111}
112
113pub struct SkipDelayReceiver<T>(Arc<std::sync::Mutex<SkipDelayQueue<T>>>);
115
116impl<T> Drop for SkipDelayReceiver<T> {
117 #[instrument(name = "SkipDelayReceiver::drop", level = "trace", skip(self))]
118 fn drop(&mut self) {
119 self.0.clear_poison();
121 let mut queue = self.0.lock().expect("cannot panic because poison is cleared");
122 queue.is_closed = true;
123 queue.rx_waker = None;
124 }
125}
126
127impl<T: Ord> futures::Stream for SkipDelayReceiver<T> {
128 type Item = T;
129
130 #[instrument(name = "SkipDelayReceiver::poll_next", level = "trace", skip(self, cx))]
131 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
132 let Ok(mut queue) = self.0.lock() else {
133 tracing::error!("poisoned mutex");
134 return Poll::Ready(None);
135 };
136
137 if let Some(next_wakeup) = queue.next_wakeup.as_mut() {
139 tracing::trace!("polling timer");
140 let _ = futures::ready!(next_wakeup.poll_unpin(cx));
141 queue.next_wakeup = None;
142 }
143
144 tracing::trace!("timer finished");
145
146 let now = Instant::now();
147 while let Some(e) = queue.entries.first() {
148 if !e.cancelled.load(std::sync::atomic::Ordering::SeqCst) {
149 return if e.at.saturating_duration_since(now) < SkipDelayQueue::<T>::TOLERANCE {
150 tracing::trace!("ready");
152 Poll::Ready(queue.entries.pop_first().map(|e| e.item))
153 } else {
154 tracing::trace!("pending new timer");
156 queue.next_wakeup = Some(futures_time::task::sleep_until(e.at.into()));
157 cx.waker().wake_by_ref();
158 Poll::Pending
159 };
160 } else {
161 queue.entries.pop_first();
163 tracing::trace!("item cancelled");
164 }
165 }
166
167 if !queue.is_closed {
168 tracing::trace!("pending for data");
170 queue.rx_waker = Some(cx.waker().clone());
171 Poll::Pending
172 } else {
173 Poll::Ready(None)
175 }
176 }
177}
178
179pub struct SkipDelaySender<T>(Option<Arc<std::sync::Mutex<SkipDelayQueue<T>>>>);
181
182impl<T> Clone for SkipDelaySender<T> {
183 fn clone(&self) -> Self {
184 Self(self.0.clone())
185 }
186}
187
188impl<T> SkipDelaySender<T> {
189 fn ensure_closure(&mut self) {
190 if let Some(queue) = self.0.take() {
191 let count_holders = Arc::strong_count(&queue);
192 tracing::trace!(count_holders, "ensure_closure");
193
194 if count_holders == 2 {
196 Self::finalize_closure(queue);
197 }
198 }
199 }
200
201 fn finalize_closure(queue: Arc<std::sync::Mutex<SkipDelayQueue<T>>>) {
202 tracing::trace!("finalize_closure");
203 queue.clear_poison();
204 let mut queue = queue.lock().expect("cannot panic because poison is cleared");
205 queue.is_closed = true;
206 queue.rx_waker = None;
207 }
208
209 pub fn force_close(&mut self) {
211 if let Some(queue) = self.0.take() {
212 Self::finalize_closure(queue);
213 }
214 }
215}
216
217impl<T: Ord> SkipDelaySender<T> {
218 #[instrument(
219 name = "SkipDelaySender::send_internal",
220 level = "trace",
221 skip(self, items, flush),
222 ret
223 )]
224 fn send_internal<I: Iterator<Item = DelayedItem<T>>>(&self, items: I, flush: bool) -> Result<(), std::io::Error> {
225 if let Some(queue) = self.0.as_ref() {
226 let mut queue = queue.lock().map_err(|_| std::io::ErrorKind::BrokenPipe)?;
227
228 if queue.is_closed {
230 return Err(std::io::ErrorKind::BrokenPipe.into());
231 }
232
233 for item in items {
234 match item {
235 DelayedItem::New(item, at) => {
236 tracing::trace!(at = ?at.saturating_duration_since(Instant::now()), "inserting");
237 queue.entries.replace(DelayedEntry {
238 item,
239 at,
240 cancelled: AtomicBool::new(false),
241 });
242 }
243 DelayedItem::Cancel(item) => {
244 tracing::trace!("cancelling");
245 queue
246 .entries
247 .iter()
248 .filter(|e| item == e.item)
249 .for_each(|e| e.cancelled.store(true, std::sync::atomic::Ordering::SeqCst));
250 }
251 }
252 }
253
254 if flush {
255 tracing::trace!("flushing");
256 if let Some(waker) = queue.rx_waker.take() {
257 waker.wake();
258 }
259 }
260
261 Ok(())
262 } else {
263 Err(std::io::ErrorKind::NotConnected.into())
264 }
265 }
266
267 pub fn send_one<I: Into<DelayedItem<T>>>(&mut self, item: I) -> Result<(), std::io::Error> {
269 self.send_internal(std::iter::once(item.into()), true)
270 }
271
272 pub fn send_many<I: IntoIterator<Item = DelayedItem<T>>>(&mut self, items: I) -> Result<(), std::io::Error> {
274 self.send_internal(items.into_iter(), true)
275 }
276}
277
278impl<T> Drop for SkipDelaySender<T> {
279 fn drop(&mut self) {
280 self.ensure_closure();
281 }
282}
283
284impl<T: Ord> futures::Sink<DelayedItem<T>> for SkipDelaySender<T> {
285 type Error = std::io::Error;
286
287 fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
288 if self.0.is_some() {
289 Poll::Ready(Ok(()))
290 } else {
291 Poll::Ready(Err(std::io::ErrorKind::NotConnected.into()))
292 }
293 }
294
295 fn start_send(self: Pin<&mut Self>, item: DelayedItem<T>) -> Result<(), Self::Error> {
296 self.send_internal(std::iter::once(item), false)
297 }
298
299 #[instrument(name = "SkipDelaySender::poll_flush", level = "trace", skip(self), ret)]
300 fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
301 if let Some(queue) = self.0.as_ref() {
302 let Ok(mut queue) = queue.lock() else {
303 return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into()));
304 };
305
306 tracing::trace!("flushing");
307 if let Some(waker) = queue.rx_waker.take() {
308 waker.wake();
309 }
310
311 Poll::Ready(Ok(()))
312 } else {
313 Poll::Ready(Err(std::io::ErrorKind::NotConnected.into()))
314 }
315 }
316
317 fn poll_close(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
318 if self.0.is_none() {
319 return Poll::Ready(Err(std::io::ErrorKind::NotConnected.into()));
320 }
321
322 self.ensure_closure();
323 Poll::Ready(Ok(()))
324 }
325}
326
327pub fn skip_delay_channel<T: Ord>() -> (SkipDelaySender<T>, SkipDelayReceiver<T>) {
339 let queue = Arc::new(std::sync::Mutex::new(SkipDelayQueue::new()));
340 (SkipDelaySender(Some(queue.clone())), SkipDelayReceiver(queue))
341}
342
// Timing-based tests for the skip-delay channel. Deadlines are given relative to `Instant::now()`
// and the assertions check elapsed wall-clock time, so statement order matters.
#[cfg(test)]
mod tests {
    use futures::{SinkExt, StreamExt, pin_mut};

    use super::*;

    // A single item is yielded no earlier than its deadline, after which the closed stream ends.
    #[test_log::test(tokio::test)]
    async fn skip_delay_queue_should_yield_items() -> anyhow::Result<()> {
        let (mut tx, rx) = skip_delay_channel();
        pin_mut!(rx);

        let now = Instant::now();
        tx.send((1, now + Duration::from_millis(100)).into()).await?;
        tx.close().await?;

        assert_eq!(Some(1), rx.next().await);
        assert!(now.elapsed() >= Duration::from_millis(100));
        assert_eq!(None, rx.next().await);

        Ok(())
    }

    // Re-sending the same item with a later deadline reschedules it: it is yielded once, at the
    // later deadline.
    #[test_log::test(tokio::test)]
    async fn skip_delay_queue_should_replace_and_yield_items() -> anyhow::Result<()> {
        let (mut tx, rx) = skip_delay_channel();
        pin_mut!(rx);

        let now = Instant::now();
        tx.send((1, now + Duration::from_millis(100)).into()).await?;
        tx.send((1, now + Duration::from_millis(200)).into()).await?;
        tx.close().await?;

        assert_eq!(Some(1), rx.next().await);
        assert!(now.elapsed() >= Duration::from_millis(200));
        assert_eq!(None, rx.next().await);

        Ok(())
    }

    // Items from cloned senders share one queue. The stream ends only after the last sender
    // closes.
    #[test_log::test(tokio::test)]
    async fn skip_delay_queue_should_yield_items_from_multiple_senders() -> anyhow::Result<()> {
        let (mut tx, rx) = skip_delay_channel();
        pin_mut!(rx);

        let mut tx2 = tx.clone();

        let now = Instant::now();
        tx.send((2, now + Duration::from_millis(100)).into()).await?;
        // Closing the first sender must not end the stream while `tx2` is still connected.
        tx.close().await?;

        tx2.send((1, now + Duration::from_millis(150)).into()).await?;
        tx2.close().await?;

        assert_eq!(Some(2), rx.next().await);
        assert!(now.elapsed() >= Duration::from_millis(100));
        assert_eq!(Some(1), rx.next().await);
        assert!(now.elapsed() >= Duration::from_millis(150));

        assert_eq!(None, rx.next().await);

        Ok(())
    }

    // Each item honours its own deadline, measured from when it was sent.
    #[test_log::test(tokio::test)]
    async fn skip_delay_queue_yielded_items_should_be_apart() -> anyhow::Result<()> {
        let (mut tx, rx) = skip_delay_channel();
        pin_mut!(rx);

        let now1 = Instant::now();
        tx.send((1, now1 + Duration::from_millis(100)).into()).await?;
        let now2 = Instant::now();
        tx.send((2, now2 + Duration::from_millis(200)).into()).await?;
        tx.close().await?;

        assert_eq!(Some(1), rx.next().await);
        assert!(now1.elapsed() >= Duration::from_millis(100));
        assert_eq!(Some(2), rx.next().await);
        assert!(now2.elapsed() >= Duration::from_millis(200));

        assert_eq!(None, rx.next().await);

        Ok(())
    }

    // A cancelled (skipped) item is never yielded.
    #[test_log::test(tokio::test)]
    async fn skip_delay_queue_should_not_yield_cancelled_items() -> anyhow::Result<()> {
        let (mut tx, rx) = skip_delay_channel();
        pin_mut!(rx);

        let now = Instant::now();
        tx.send((1, now + Duration::from_millis(100)).into()).await?;
        tx.send((1, Skip).into()).await?;
        tx.close().await?;

        assert_eq!(None, rx.next().await);

        Ok(())
    }

    // Items whose deadline has already passed are yielded without waiting for a timer.
    #[test_log::test(tokio::test)]
    async fn skip_delay_queue_should_yield_past_items_immediately() -> anyhow::Result<()> {
        let (mut tx, rx) = skip_delay_channel();
        pin_mut!(rx);

        let now = Instant::now();
        tx.send((1, now).into()).await?;
        tx.send((2, now).into()).await?;
        tx.close().await?;

        let now = Instant::now();
        assert_eq!(Some(1), rx.next().await);
        assert_eq!(Some(2), rx.next().await);
        assert_eq!(None, rx.next().await);

        assert!(now.elapsed() < Duration::from_millis(25));

        Ok(())
    }

    // Cancelling a future item ends the stream right after the remaining due item, without
    // waiting out the cancelled item's deadline.
    #[test_log::test(tokio::test)]
    async fn skip_delay_queue_should_not_yield_future_cancelled_items() -> anyhow::Result<()> {
        let (mut tx, rx) = skip_delay_channel();
        pin_mut!(rx);

        let now = Instant::now();
        tx.send((1, now).into()).await?;
        tx.send((2, now + Duration::from_millis(100)).into()).await?;
        tx.send((2, Skip).into()).await?;
        tx.close().await?;

        assert_eq!(Some(1), rx.next().await);
        assert_eq!(None, rx.next().await);
        assert!(now.elapsed() < Duration::from_millis(50));

        Ok(())
    }

    // Sending the same item twice with the same deadline yields it only once.
    #[test_log::test(tokio::test)]
    async fn skip_delay_queue_should_discard_duplicate_entries() -> anyhow::Result<()> {
        let (mut tx, rx) = skip_delay_channel();
        pin_mut!(rx);

        let now = Instant::now();
        tx.send((1, now).into()).await?;
        tx.send((1, now).into()).await?;
        tx.close().await?;

        assert_eq!(Some(1), rx.next().await);
        assert_eq!(None, rx.next().await);

        Ok(())
    }

    // Items with identical deadlines come out ordered by the item's own `Ord`, not by send order.
    #[test_log::test(tokio::test)]
    async fn skip_delay_queue_should_yield_items_in_order() -> anyhow::Result<()> {
        let (mut tx, rx) = skip_delay_channel();
        pin_mut!(rx);

        let now = Instant::now();
        tx.send((2, now).into()).await?;
        tx.send((1, now).into()).await?;
        tx.close().await?;

        assert_eq!(Some(1), rx.next().await);
        assert_eq!(Some(2), rx.next().await);
        assert_eq!(None, rx.next().await);

        Ok(())
    }

    // Same as above, but via `feed` (no wake-up) followed by an explicit `flush`.
    #[test_log::test(tokio::test)]
    async fn skip_delay_queue_should_yield_fed_items_in_order() -> anyhow::Result<()> {
        let (mut tx, rx) = skip_delay_channel();
        pin_mut!(rx);

        let now = Instant::now();
        tx.feed((2, now).into()).await?;
        tx.feed((1, now).into()).await?;
        tx.flush().await?;
        tx.close().await?;

        assert_eq!(Some(1), rx.next().await);
        assert_eq!(Some(2), rx.next().await);
        assert_eq!(None, rx.next().await);

        Ok(())
    }

    // After closing, the sender is disconnected: both sending and closing again return errors.
    #[test_log::test(tokio::test)]
    async fn skip_delay_queue_should_not_send_items_when_closed() -> anyhow::Result<()> {
        let (mut tx, rx) = skip_delay_channel();
        pin_mut!(rx);
        tx.close().await?;

        let now = Instant::now();
        tx.send((1, now).into()).await.unwrap_err();
        tx.close().await.unwrap_err();

        assert_eq!(None, rx.next().await);

        Ok(())
    }

    // A producer task sends items spaced 100 ms apart in deadline order while the consumer
    // collects concurrently. Each item must arrive within 20 ms of its deadline, either side.
    // The receiver may yield up to `SkipDelayQueue::TOLERANCE` early.
    #[test_log::test(tokio::test)]
    async fn skip_delay_queue_should_continuously_yield_items() -> anyhow::Result<()> {
        let (mut tx, rx) = skip_delay_channel();

        let items = [5, 2, 1, 4, 3];

        let now = Instant::now();
        let timed_items = (0..5)
            .map(|i| (items[i], now + Duration::from_millis(100) * (i as u32)))
            .collect::<Vec<_>>();

        let timed_items_clone = timed_items.clone();
        let jh = hopr_utils::runtime::prelude::spawn(async move {
            for (n, time) in timed_items_clone {
                tx.send((n, time).into()).await?;
                hopr_utils::runtime::prelude::sleep(Duration::from_millis(50)).await;
            }
            tx.close().await?;
            Ok::<_, std::io::Error>(())
        });

        let collected = rx.map(|item| (item, Instant::now())).collect::<Vec<_>>().await;

        assert_eq!(timed_items.len(), collected.len());

        for (i, (item, received_at)) in collected.into_iter().enumerate() {
            assert_eq!(timed_items[i].0, item);
            if received_at < timed_items[i].1 {
                assert!(timed_items[i].1.saturating_duration_since(received_at) < Duration::from_millis(20));
            } else {
                assert!(received_at.saturating_duration_since(timed_items[i].1) < Duration::from_millis(20));
            }
        }

        jh.await??;
        Ok(())
    }
}