Skip to main content

hopr_protocol_session/processing/
sequencer.rs

1//! This module defines the [`Sequencer`] stream adaptor.
2
3use std::{
4    collections::BinaryHeap,
5    future::Future,
6    pin::Pin,
7    task::{Context, Poll},
8    time::{Duration, Instant},
9};
10
11use futures_time::future::Timer;
12use tracing::instrument;
13
14use crate::{errors::SessionError, protocol::FrameId};
15
16/// Sequencer is an adaptor for streams, that yield elements that have a natural ordering and
17/// can be compared with [`FrameId`] and puts them in the correct sequence starting with
18/// `FrameId` equal to 1.
19///
20/// Sequencer internally maintains a `FrameId` to be yielded next, polls the underlying stream
21/// and yields elements only when they match the next `FrameId` to be yielded, incrementing the
22/// value on each yield.
23///
24/// The Sequencer takes to arguments: `max_wait` and `capacity`:
25///
26/// The `max_wait` indicates the maximum amount of time to wait for a certain `FrameId` to
27/// be yielded from the underlying stream.
28/// If this does not happen, the Segmenter yields an error,
29/// indicating that the given frame was discarded.
30///
31/// The `capacity` parameter sets the maximum number of buffered elements inside the Sequencer.
32/// If this value is reached, the Sequencer will stop polling the underlying stream, waiting for the
33/// next element to expire.
34///
35/// By definition, Sequencer is a fallible stream, yielding either `Ok(Item)`, `Err(`[`SessionError::FrameDiscarded`]`)`
36/// or `Ok(None)` when the underlying stream is closed and no more elements can be yielded.
37///
38/// Use [`SequencerExt`] methods to turn a stream into a sequenced stream.
39#[must_use = "streams do nothing unless polled"]
40#[pin_project::pin_project]
41pub struct Sequencer<S: futures::Stream> {
42    #[pin]
43    inner: S,
44    #[pin]
45    timer: futures_time::task::Sleep,
46    buffer: BinaryHeap<std::cmp::Reverse<S::Item>>,
47    next_id: FrameId,
48    last_emitted: Instant,
49    max_wait: Duration,
50    state: State,
51}
52
53impl<S> Sequencer<S>
54where
55    S: futures::Stream,
56    S::Item: Ord + PartialOrd<FrameId>,
57{
58    /// Creates a new instance, wrapping the given `inner` Segment sink.
59    ///
60    /// The `frame_size` value will be clamped into the `[C, (C - SessionMessage::SEGMENT_OVERHEAD) * SeqIndicator::MAX
61    /// + 1]` interval.
62    fn new(inner: S, max_wait: Duration, capacity: usize) -> Self {
63        assert!(capacity > 0, "capacity should be positive");
64        Self {
65            inner,
66            buffer: BinaryHeap::with_capacity(capacity),
67            timer: futures_time::task::sleep(max_wait.max(Duration::from_millis(1)).into()),
68            next_id: 1,
69            last_emitted: Instant::now(),
70            max_wait,
71            state: State::Polling,
72        }
73    }
74}
75
/// States of the `Sequencer` polling state machine.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
enum State {
    /// Polling the underlying stream, and the timer whenever the buffer is non-empty.
    Polling,
    /// The buffer changed, the timer fired, or the buffer is full: check whether the expected
    /// frame can be yielded or must be discarded.
    BufferUpdated,
    /// The underlying stream is exhausted; only the buffer is still being drained.
    Done,
}
82
83impl<S> futures::Stream for Sequencer<S>
84where
85    S: futures::Stream,
86    S::Item: Ord + PartialOrd<FrameId>,
87{
88    type Item = Result<S::Item, SessionError>;
89
90    #[instrument(name = "Sequencer::poll_next", level = "trace", skip(self, cx), fields(next_frame_id = self.next_id, state = ?self.state))]
91    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
92        let mut this = self.project();
93        if *this.next_id == 0 {
94            tracing::debug!("end of frame sequence reached");
95            return Poll::Ready(None);
96        }
97
98        loop {
99            match *this.state {
100                State::Polling => {
101                    if this.buffer.len() < this.buffer.capacity() {
102                        // We still have capacity available, poll the underlying stream
103                        let stream_poll = this.inner.as_mut().poll_next(cx);
104
105                        // Only poll timer if there's something in the buffer
106                        let timer_poll = if !this.buffer.is_empty() {
107                            let poll = this.timer.as_mut().poll(cx);
108                            if poll.is_ready() {
109                                this.timer.as_mut().reset_timer();
110                            }
111                            poll
112                        } else {
113                            Poll::Pending
114                        };
115
116                        match (stream_poll, timer_poll) {
117                            (Poll::Pending, Poll::Pending) => {
118                                tracing::trace!("pending");
119                                *this.state = State::Polling;
120                                return Poll::Pending;
121                            }
122                            (Poll::Ready(Some(item)), _) => {
123                                // We have to reset the last emitted timestamp if
124                                // the buffer was empty until now
125                                if this.buffer.is_empty() {
126                                    *this.last_emitted = Instant::now();
127                                }
128
129                                if item.lt(this.next_id) {
130                                    // Do not accept older frame ids
131                                    tracing::error!("old item");
132                                    *this.state = State::Polling;
133                                } else {
134                                    // Push new item to the buffer
135                                    tracing::trace!("new item");
136                                    this.buffer.push(std::cmp::Reverse(item));
137                                    *this.state = State::BufferUpdated;
138                                }
139                            }
140                            (Poll::Ready(None), _) => {
141                                tracing::trace!(len = this.buffer.len(), "stream is done");
142                                *this.state = State::Done
143                            }
144                            (_, Poll::Ready(_)) => {
145                                // Simulate buffer update when the timer elapses
146                                tracing::trace!("timer elapsed");
147                                *this.state = State::BufferUpdated;
148                            }
149                        }
150                    } else {
151                        // Simulate buffer update when at capacity
152                        tracing::warn!("sequencer buffer is full");
153                        *this.state = State::BufferUpdated;
154                    }
155                }
156                State::BufferUpdated => {
157                    // The buffer has been updated, check if we can yield something
158                    if let Some(next) = this.buffer.peek().map(|item| &item.0) {
159                        if next.eq(this.next_id) {
160                            *this.next_id = this.next_id.wrapping_add(1);
161                            *this.last_emitted = Instant::now();
162                            *this.state = State::BufferUpdated;
163
164                            tracing::trace!("emit next frame");
165
166                            return Poll::Ready(this.buffer.pop().map(|item| Ok(item.0)));
167                        } else if this.last_emitted.elapsed() >= *this.max_wait
168                            || this.buffer.len() == this.buffer.capacity()
169                        {
170                            let discarded = *this.next_id;
171                            *this.next_id = this.next_id.wrapping_add(1);
172                            *this.last_emitted = Instant::now();
173                            *this.state = State::BufferUpdated;
174
175                            tracing::trace!(discarded, "discard frame");
176
177                            return Poll::Ready(Some(Err(SessionError::FrameDiscarded(discarded))));
178                        }
179                    } else {
180                        tracing::trace!("buffer is empty");
181                    }
182
183                    // Nothing to yield, keep on polling
184                    *this.state = State::Polling;
185                }
186                State::Done => {
187                    // The underlying stream is done, drain what we have in the internal buffer
188                    return if let Some(next) = this.buffer.peek().map(|item| &item.0) {
189                        if next.lt(this.next_id) {
190                            tracing::error!("old item");
191                            this.buffer.pop();
192                            continue;
193                        } else if next.eq(this.next_id) {
194                            *this.next_id = this.next_id.wrapping_add(1);
195                            tracing::trace!("emit next frame when done");
196
197                            Poll::Ready(this.buffer.pop().map(|item| Ok(item.0)))
198                        } else {
199                            let discarded = *this.next_id;
200                            *this.next_id = this.next_id.wrapping_add(1);
201                            tracing::trace!(discarded, "discard frame when done");
202
203                            Poll::Ready(Some(Err(SessionError::FrameDiscarded(discarded))))
204                        }
205                    } else {
206                        tracing::trace!("buffer is empty and done");
207                        Poll::Ready(None)
208                    };
209                }
210            }
211        }
212    }
213}
214
215/// Stream extensions methods for item sequencing.
216pub trait SequencerExt: futures::Stream {
217    /// Attaches a [`Sequencer`] to the underlying stream, given the item `timeout` and `capacity`
218    /// of items.
219    fn sequencer(self, timeout: Duration, capacity: usize) -> Sequencer<Self>
220    where
221        Self::Item: Ord + PartialOrd<FrameId>,
222        Self: Sized,
223    {
224        Sequencer::new(self, timeout, capacity)
225    }
226}
227
228impl<T: ?Sized> SequencerExt for T where T: futures::Stream {}
229
#[cfg(test)]
mod tests {
    use futures::{SinkExt, StreamExt, TryStreamExt, pin_mut};
    use futures_time::future::FutureExt;

    use super::*;

    // A shuffled, gap-free input must come out sorted, without any discards.
    #[test_log::test(tokio::test)]
    async fn sequencer_should_return_entries_in_order() -> anyhow::Result<()> {
        let mut expected = vec![4u32, 1, 5, 7, 8, 6, 2, 3];

        // The outer timeout makes the test fail instead of hanging if the sequencer stalls.
        let actual: Vec<u32> = futures::stream::iter(expected.clone())
            .sequencer(Duration::from_secs(5), 4096)
            .try_collect()
            .timeout(futures_time::time::Duration::from_secs(5))
            .await??;

        expected.sort();
        assert_eq!(expected, actual);

        Ok(())
    }

    // Frames that were already emitted must not be emitted again.
    #[test_log::test(tokio::test)]
    async fn sequencer_should_not_allow_emitted_entries() -> anyhow::Result<()> {
        let (seq_sink, seq_stream) = futures::channel::mpsc::unbounded();

        let seq_stream = seq_stream.sequencer(Duration::from_secs(1), 4096);

        pin_mut!(seq_sink);
        pin_mut!(seq_stream);

        seq_sink.send(1u32).await?;
        assert_eq!(Some(1), seq_stream.try_next().await?);

        seq_sink.send(2u32).await?;
        assert_eq!(Some(2), seq_stream.try_next().await?);

        // Re-sent ids 2 and 1 are below the next expected id (3) and must be ignored.
        seq_sink.send(2u32).await?;
        seq_sink.send(1u32).await?;

        seq_sink.send(3u32).await?;
        assert_eq!(Some(3), seq_stream.try_next().await?);

        Ok(())
    }

    // Frames 3 and 6 are never sent and must be reported as discarded; frame 3 no earlier
    // than `timeout` after frame 2 was yielded.
    #[test_log::test(tokio::test)]
    async fn sequencer_should_discard_entry_on_timeout() -> anyhow::Result<()> {
        let timeout = Duration::from_millis(25);
        let (mut seq_sink, seq_stream) = futures::channel::mpsc::unbounded();

        let input = vec![2u32, 1, 4, 5, 8, 7, 9, 11, 10];

        // Producer: feeds each item after a 5 ms delay, then flushes and closes the sink.
        let input_clone = input.clone();
        let jh = hopr_utils::runtime::prelude::spawn(async move {
            for v in input_clone {
                seq_sink
                    .feed(v)
                    .delay(futures_time::time::Duration::from_millis(5))
                    .await?;
            }
            seq_sink.flush().await?;
            seq_sink.close().await
        });

        let seq_stream = seq_stream.sequencer(timeout, 4096);

        pin_mut!(seq_stream);

        assert_eq!(Some(1), seq_stream.try_next().await?);
        assert_eq!(Some(2), seq_stream.try_next().await?);

        let now = Instant::now();
        assert!(matches!(
            seq_stream.try_next().await,
            Err(SessionError::FrameDiscarded(3))
        ));
        assert!(now.elapsed() >= timeout);

        assert_eq!(Some(4), seq_stream.try_next().await?);
        assert_eq!(Some(5), seq_stream.try_next().await?);

        assert!(matches!(
            seq_stream.try_next().await,
            Err(SessionError::FrameDiscarded(6))
        ));

        assert_eq!(Some(7), seq_stream.try_next().await?);
        assert_eq!(Some(8), seq_stream.try_next().await?);
        assert_eq!(Some(9), seq_stream.try_next().await?);
        assert_eq!(Some(10), seq_stream.try_next().await?);
        assert_eq!(Some(11), seq_stream.try_next().await?);

        // The producer closed the sink, so the sequenced stream ends.
        assert_eq!(None, seq_stream.try_next().await?);

        let _ = jh.await?;
        Ok(())
    }

    // The whole input is in the channel before sequencing starts. The sink is moved into
    // `forward`, so the channel is closed afterwards. The missing ids 6, 7, 9 and 10 must be
    // reported as discarded, in order, between the buffered frames.
    #[test_log::test(tokio::test)]
    async fn sequencer_should_discard_entry_close() -> anyhow::Result<()> {
        let (seq_sink, seq_stream) = futures::channel::mpsc::unbounded();

        let input = vec![2u32, 1, 3, 5, 4, 8, 11];

        hopr_utils::runtime::prelude::spawn(futures::stream::iter(input.clone()).map(Ok).forward(seq_sink)).await??;

        let seq_stream = seq_stream.sequencer(Duration::from_millis(25), 4096);

        pin_mut!(seq_stream);

        assert_eq!(Some(1), seq_stream.try_next().await?);
        assert_eq!(Some(2), seq_stream.try_next().await?);
        assert_eq!(Some(3), seq_stream.try_next().await?);
        assert_eq!(Some(4), seq_stream.try_next().await?);
        assert_eq!(Some(5), seq_stream.try_next().await?);
        assert!(matches!(
            seq_stream.try_next().await,
            Err(SessionError::FrameDiscarded(6))
        ));
        assert!(matches!(
            seq_stream.try_next().await,
            Err(SessionError::FrameDiscarded(7))
        ));
        assert_eq!(Some(8), seq_stream.try_next().await?);
        assert!(matches!(
            seq_stream.try_next().await,
            Err(SessionError::FrameDiscarded(9))
        ));
        assert!(matches!(
            seq_stream.try_next().await,
            Err(SessionError::FrameDiscarded(10))
        ));
        assert_eq!(Some(11), seq_stream.try_next().await?);
        assert_eq!(None, seq_stream.try_next().await?);

        Ok(())
    }

    // The sender `tx` stays alive for the whole test, so the inner stream becomes pending
    // rather than closed. Frame 5 is never sent and must be discarded once the 10 ms
    // `max_wait` has elapsed.
    #[test_log::test(tokio::test)]
    async fn sequencer_should_discard_entry_when_inner_stream_pending() -> anyhow::Result<()> {
        let sent = vec![4u32, 1, 7, 8, 6, 2, 3];
        let (tx, rx) = futures::channel::mpsc::unbounded();

        pin_mut!(tx);
        tx.send_all(&mut futures::stream::iter(sent.clone()).map(Ok)).await?;

        let rx = rx.sequencer(Duration::from_millis(10), 4096);
        pin_mut!(rx);

        assert!(matches!(rx.next().await, Some(Ok(1))));
        assert!(matches!(rx.next().await, Some(Ok(2))));
        assert!(matches!(rx.next().await, Some(Ok(3))));
        assert!(matches!(rx.next().await, Some(Ok(4))));
        assert!(matches!(rx.next().await, Some(Err(SessionError::FrameDiscarded(5)))));
        assert!(matches!(rx.next().await, Some(Ok(6))));
        assert!(matches!(rx.next().await, Some(Ok(7))));
        assert!(matches!(rx.next().await, Some(Ok(8))));

        Ok(())
    }

    // With capacity 4, the first four items (4, 5, 7, 8) fill the buffer, so frames 1-3 are
    // discarded because the buffer is full. The later 2 and 3 are then below the next
    // expected id and ignored.
    #[test_log::test(tokio::test)]
    async fn sequencer_should_discard_entry_when_capacity_is_reached() -> anyhow::Result<()> {
        let sent = vec![4u32, 5, 7, 8, 2, 6, 3];
        let (tx, rx) = futures::channel::mpsc::unbounded();

        pin_mut!(tx);
        tx.send_all(&mut futures::stream::iter(sent.clone()).map(Ok)).await?;

        let rx = rx.sequencer(Duration::from_millis(10), 4);
        pin_mut!(rx);

        assert!(matches!(rx.next().await, Some(Err(SessionError::FrameDiscarded(1)))));
        assert!(matches!(rx.next().await, Some(Err(SessionError::FrameDiscarded(2)))));
        assert!(matches!(rx.next().await, Some(Err(SessionError::FrameDiscarded(3)))));
        assert!(matches!(rx.next().await, Some(Ok(4))));
        assert!(matches!(rx.next().await, Some(Ok(5))));
        assert!(matches!(rx.next().await, Some(Ok(6))));
        assert!(matches!(rx.next().await, Some(Ok(7))));
        assert!(matches!(rx.next().await, Some(Ok(8))));

        Ok(())
    }

    // The sequencer starts right before the end of the id space (`next_id` is set directly).
    // After `FrameId::MAX` is yielded the stream must end, even though frames 1 and 2 are
    // available in the channel.
    #[test_log::test(tokio::test)]
    async fn sequencer_must_terminate_on_last_frame_id() -> anyhow::Result<()> {
        let (tx, rx) = futures::channel::mpsc::unbounded();

        pin_mut!(tx);
        tx.send_all(&mut futures::stream::iter([FrameId::MAX - 1, FrameId::MAX, 1, 2]).map(Ok))
            .await?;

        let mut rx = rx.sequencer(Duration::from_millis(10), 1024);
        rx.next_id = FrameId::MAX - 1;
        pin_mut!(rx);

        const LAST_ID: FrameId = FrameId::MAX - 1;
        assert!(matches!(rx.next().await, Some(Ok(LAST_ID))));
        assert!(matches!(rx.next().await, Some(Ok(FrameId::MAX))));
        assert!(rx.next().await.is_none());

        Ok(())
    }

    // There is a 150 ms pause, longer than the 50 ms `max_wait`, during which the buffer is
    // empty. The first frame after the pause (6, arriving before 5) must not be discarded.
    // The producer task drops `tx` when it finishes, which ends the stream.
    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn sequencer_must_not_discard_frames_when_buffer_was_empty_after_timeout() -> anyhow::Result<()> {
        let (tx, rx) = futures::channel::mpsc::unbounded();

        let jh = tokio::task::spawn(async move {
            tokio::time::sleep(Duration::from_millis(2)).await;
            pin_mut!(tx);
            tx.send_all(&mut futures::stream::iter([3, 1, 2, 4]).map(Ok)).await?;

            tokio::time::sleep(Duration::from_millis(150)).await;

            tx.send_all(&mut futures::stream::iter([6, 5, 7]).map(Ok)).await?;

            anyhow::Ok(())
        });

        // `try_ready_chunks` groups the items that are ready together: one chunk per burst.
        let chunks = rx
            .sequencer(Duration::from_millis(50), 1024)
            .try_ready_chunks(10)
            .try_collect::<Vec<Vec<_>>>()
            .await?;

        assert_eq!(chunks, vec![vec![1, 2, 3, 4], vec![5, 6, 7]]);
        jh.await??;

        Ok(())
    }
}