1use std::{
4 collections::BinaryHeap,
5 future::Future,
6 pin::Pin,
7 task::{Context, Poll},
8 time::{Duration, Instant},
9};
10
11use futures_time::future::Timer;
12use tracing::instrument;
13
14use crate::{errors::SessionError, protocol::FrameId};
15
16#[must_use = "streams do nothing unless polled"]
40#[pin_project::pin_project]
41pub struct Sequencer<S: futures::Stream> {
42 #[pin]
43 inner: S,
44 #[pin]
45 timer: futures_time::task::Sleep,
46 buffer: BinaryHeap<std::cmp::Reverse<S::Item>>,
47 next_id: FrameId,
48 last_emitted: Instant,
49 max_wait: Duration,
50 state: State,
51}
52
53impl<S> Sequencer<S>
54where
55 S: futures::Stream,
56 S::Item: Ord + PartialOrd<FrameId>,
57{
58 fn new(inner: S, max_wait: Duration, capacity: usize) -> Self {
63 assert!(capacity > 0, "capacity should be positive");
64 Self {
65 inner,
66 buffer: BinaryHeap::with_capacity(capacity),
67 timer: futures_time::task::sleep(max_wait.max(Duration::from_millis(1)).into()),
68 next_id: 1,
69 last_emitted: Instant::now(),
70 max_wait,
71 state: State::Polling,
72 }
73 }
74}
75
/// States of the [`Sequencer`] polling state machine.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
enum State {
    /// Waiting for the inner stream (or the wake-up timer) to make progress.
    Polling,
    /// The buffer or the clock changed: check whether a frame can be emitted or discarded.
    BufferUpdated,
    /// The inner stream has ended: drain whatever is left in the buffer.
    Done,
}
82
83impl<S> futures::Stream for Sequencer<S>
84where
85 S: futures::Stream,
86 S::Item: Ord + PartialOrd<FrameId>,
87{
88 type Item = Result<S::Item, SessionError>;
89
90 #[instrument(name = "Sequencer::poll_next", level = "trace", skip(self, cx), fields(next_frame_id = self.next_id, state = ?self.state))]
91 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
92 let mut this = self.project();
93 if *this.next_id == 0 {
94 tracing::debug!("end of frame sequence reached");
95 return Poll::Ready(None);
96 }
97
98 loop {
99 match *this.state {
100 State::Polling => {
101 if this.buffer.len() < this.buffer.capacity() {
102 let stream_poll = this.inner.as_mut().poll_next(cx);
104
105 let timer_poll = if !this.buffer.is_empty() {
107 let poll = this.timer.as_mut().poll(cx);
108 if poll.is_ready() {
109 this.timer.as_mut().reset_timer();
110 }
111 poll
112 } else {
113 Poll::Pending
114 };
115
116 match (stream_poll, timer_poll) {
117 (Poll::Pending, Poll::Pending) => {
118 tracing::trace!("pending");
119 *this.state = State::Polling;
120 return Poll::Pending;
121 }
122 (Poll::Ready(Some(item)), _) => {
123 if this.buffer.is_empty() {
126 *this.last_emitted = Instant::now();
127 }
128
129 if item.lt(this.next_id) {
130 tracing::error!("old item");
132 *this.state = State::Polling;
133 } else {
134 tracing::trace!("new item");
136 this.buffer.push(std::cmp::Reverse(item));
137 *this.state = State::BufferUpdated;
138 }
139 }
140 (Poll::Ready(None), _) => {
141 tracing::trace!(len = this.buffer.len(), "stream is done");
142 *this.state = State::Done
143 }
144 (_, Poll::Ready(_)) => {
145 tracing::trace!("timer elapsed");
147 *this.state = State::BufferUpdated;
148 }
149 }
150 } else {
151 tracing::warn!("sequencer buffer is full");
153 *this.state = State::BufferUpdated;
154 }
155 }
156 State::BufferUpdated => {
157 if let Some(next) = this.buffer.peek().map(|item| &item.0) {
159 if next.eq(this.next_id) {
160 *this.next_id = this.next_id.wrapping_add(1);
161 *this.last_emitted = Instant::now();
162 *this.state = State::BufferUpdated;
163
164 tracing::trace!("emit next frame");
165
166 return Poll::Ready(this.buffer.pop().map(|item| Ok(item.0)));
167 } else if this.last_emitted.elapsed() >= *this.max_wait
168 || this.buffer.len() == this.buffer.capacity()
169 {
170 let discarded = *this.next_id;
171 *this.next_id = this.next_id.wrapping_add(1);
172 *this.last_emitted = Instant::now();
173 *this.state = State::BufferUpdated;
174
175 tracing::trace!(discarded, "discard frame");
176
177 return Poll::Ready(Some(Err(SessionError::FrameDiscarded(discarded))));
178 }
179 } else {
180 tracing::trace!("buffer is empty");
181 }
182
183 *this.state = State::Polling;
185 }
186 State::Done => {
187 return if let Some(next) = this.buffer.peek().map(|item| &item.0) {
189 if next.lt(this.next_id) {
190 tracing::error!("old item");
191 this.buffer.pop();
192 continue;
193 } else if next.eq(this.next_id) {
194 *this.next_id = this.next_id.wrapping_add(1);
195 tracing::trace!("emit next frame when done");
196
197 Poll::Ready(this.buffer.pop().map(|item| Ok(item.0)))
198 } else {
199 let discarded = *this.next_id;
200 *this.next_id = this.next_id.wrapping_add(1);
201 tracing::trace!(discarded, "discard frame when done");
202
203 Poll::Ready(Some(Err(SessionError::FrameDiscarded(discarded))))
204 }
205 } else {
206 tracing::trace!("buffer is empty and done");
207 Poll::Ready(None)
208 };
209 }
210 }
211 }
212 }
213}
214
215pub trait SequencerExt: futures::Stream {
217 fn sequencer(self, timeout: Duration, capacity: usize) -> Sequencer<Self>
220 where
221 Self::Item: Ord + PartialOrd<FrameId>,
222 Self: Sized,
223 {
224 Sequencer::new(self, timeout, capacity)
225 }
226}
227
228impl<T: ?Sized> SequencerExt for T where T: futures::Stream {}
229
#[cfg(test)]
mod tests {
    use futures::{SinkExt, StreamExt, TryStreamExt, pin_mut};
    use futures_time::future::FutureExt;

    use super::*;

    /// A shuffled but complete sequence comes out sorted, with nothing discarded.
    #[test_log::test(tokio::test)]
    async fn sequencer_should_return_entries_in_order() -> anyhow::Result<()> {
        let mut expected = vec![4u32, 1, 5, 7, 8, 6, 2, 3];

        let actual: Vec<u32> = futures::stream::iter(expected.clone())
            .sequencer(Duration::from_secs(5), 4096)
            .try_collect()
            .timeout(futures_time::time::Duration::from_secs(5))
            .await??;

        expected.sort();
        assert_eq!(expected, actual);

        Ok(())
    }

    /// Items whose IDs were already emitted are silently dropped.
    #[test_log::test(tokio::test)]
    async fn sequencer_should_not_allow_emitted_entries() -> anyhow::Result<()> {
        let (seq_sink, seq_stream) = futures::channel::mpsc::unbounded();

        let seq_stream = seq_stream.sequencer(Duration::from_secs(1), 4096);

        pin_mut!(seq_sink);
        pin_mut!(seq_stream);

        seq_sink.send(1u32).await?;
        assert_eq!(Some(1), seq_stream.try_next().await?);

        seq_sink.send(2u32).await?;
        assert_eq!(Some(2), seq_stream.try_next().await?);

        // Re-sent, already emitted frames must not show up again.
        seq_sink.send(2u32).await?;
        seq_sink.send(1u32).await?;

        seq_sink.send(3u32).await?;
        assert_eq!(Some(3), seq_stream.try_next().await?);

        Ok(())
    }

    /// Missing frames are reported as discarded once the timeout elapses.
    #[test_log::test(tokio::test)]
    async fn sequencer_should_discard_entry_on_timeout() -> anyhow::Result<()> {
        let timeout = Duration::from_millis(25);
        let (mut seq_sink, seq_stream) = futures::channel::mpsc::unbounded();

        let input = vec![2u32, 1, 4, 5, 8, 7, 9, 11, 10];

        let input_clone = input.clone();
        let jh = hopr_utils::runtime::prelude::spawn(async move {
            for v in input_clone {
                seq_sink
                    .feed(v)
                    .delay(futures_time::time::Duration::from_millis(5))
                    .await?;
            }
            seq_sink.flush().await?;
            seq_sink.close().await
        });

        let seq_stream = seq_stream.sequencer(timeout, 4096);

        pin_mut!(seq_stream);

        assert_eq!(Some(1), seq_stream.try_next().await?);
        assert_eq!(Some(2), seq_stream.try_next().await?);

        // Frame 3 never arrives: it must only be discarded after the timeout.
        let now = Instant::now();
        assert!(matches!(
            seq_stream.try_next().await,
            Err(SessionError::FrameDiscarded(3))
        ));
        assert!(now.elapsed() >= timeout);

        assert_eq!(Some(4), seq_stream.try_next().await?);
        assert_eq!(Some(5), seq_stream.try_next().await?);

        assert!(matches!(
            seq_stream.try_next().await,
            Err(SessionError::FrameDiscarded(6))
        ));

        assert_eq!(Some(7), seq_stream.try_next().await?);
        assert_eq!(Some(8), seq_stream.try_next().await?);
        assert_eq!(Some(9), seq_stream.try_next().await?);
        assert_eq!(Some(10), seq_stream.try_next().await?);
        assert_eq!(Some(11), seq_stream.try_next().await?);

        assert_eq!(None, seq_stream.try_next().await?);

        // Propagate failures of the producer task instead of silently ignoring them.
        jh.await??;
        Ok(())
    }

    /// When the inner stream closes, gaps below buffered frames are reported as discarded.
    #[test_log::test(tokio::test)]
    async fn sequencer_should_discard_entry_close() -> anyhow::Result<()> {
        let (seq_sink, seq_stream) = futures::channel::mpsc::unbounded();

        let input = vec![2u32, 1, 3, 5, 4, 8, 11];

        hopr_utils::runtime::prelude::spawn(futures::stream::iter(input.clone()).map(Ok).forward(seq_sink)).await??;

        let seq_stream = seq_stream.sequencer(Duration::from_millis(25), 4096);

        pin_mut!(seq_stream);

        assert_eq!(Some(1), seq_stream.try_next().await?);
        assert_eq!(Some(2), seq_stream.try_next().await?);
        assert_eq!(Some(3), seq_stream.try_next().await?);
        assert_eq!(Some(4), seq_stream.try_next().await?);
        assert_eq!(Some(5), seq_stream.try_next().await?);
        assert!(matches!(
            seq_stream.try_next().await,
            Err(SessionError::FrameDiscarded(6))
        ));
        assert!(matches!(
            seq_stream.try_next().await,
            Err(SessionError::FrameDiscarded(7))
        ));
        assert_eq!(Some(8), seq_stream.try_next().await?);
        assert!(matches!(
            seq_stream.try_next().await,
            Err(SessionError::FrameDiscarded(9))
        ));
        assert!(matches!(
            seq_stream.try_next().await,
            Err(SessionError::FrameDiscarded(10))
        ));
        assert_eq!(Some(11), seq_stream.try_next().await?);
        assert_eq!(None, seq_stream.try_next().await?);

        Ok(())
    }

    /// A missing frame is discarded by the timer even while the inner stream stays pending.
    #[test_log::test(tokio::test)]
    async fn sequencer_should_discard_entry_when_inner_stream_pending() -> anyhow::Result<()> {
        let sent = vec![4u32, 1, 7, 8, 6, 2, 3];
        let (tx, rx) = futures::channel::mpsc::unbounded();

        pin_mut!(tx);
        tx.send_all(&mut futures::stream::iter(sent.clone()).map(Ok)).await?;

        let rx = rx.sequencer(Duration::from_millis(10), 4096);
        pin_mut!(rx);

        assert!(matches!(rx.next().await, Some(Ok(1))));
        assert!(matches!(rx.next().await, Some(Ok(2))));
        assert!(matches!(rx.next().await, Some(Ok(3))));
        assert!(matches!(rx.next().await, Some(Ok(4))));
        assert!(matches!(rx.next().await, Some(Err(SessionError::FrameDiscarded(5)))));
        assert!(matches!(rx.next().await, Some(Ok(6))));
        assert!(matches!(rx.next().await, Some(Ok(7))));
        assert!(matches!(rx.next().await, Some(Ok(8))));

        Ok(())
    }

    /// A full buffer forces missing frames to be discarded immediately.
    #[test_log::test(tokio::test)]
    async fn sequencer_should_discard_entry_when_capacity_is_reached() -> anyhow::Result<()> {
        let sent = vec![4u32, 5, 7, 8, 2, 6, 3];
        let (tx, rx) = futures::channel::mpsc::unbounded();

        pin_mut!(tx);
        tx.send_all(&mut futures::stream::iter(sent.clone()).map(Ok)).await?;

        let rx = rx.sequencer(Duration::from_millis(10), 4);
        pin_mut!(rx);

        assert!(matches!(rx.next().await, Some(Err(SessionError::FrameDiscarded(1)))));
        assert!(matches!(rx.next().await, Some(Err(SessionError::FrameDiscarded(2)))));
        assert!(matches!(rx.next().await, Some(Err(SessionError::FrameDiscarded(3)))));
        assert!(matches!(rx.next().await, Some(Ok(4))));
        assert!(matches!(rx.next().await, Some(Ok(5))));
        assert!(matches!(rx.next().await, Some(Ok(6))));
        assert!(matches!(rx.next().await, Some(Ok(7))));
        assert!(matches!(rx.next().await, Some(Ok(8))));

        Ok(())
    }

    /// The stream ends once the frame ID counter wraps past `FrameId::MAX`.
    #[test_log::test(tokio::test)]
    async fn sequencer_must_terminate_on_last_frame_id() -> anyhow::Result<()> {
        let (tx, rx) = futures::channel::mpsc::unbounded();

        pin_mut!(tx);
        tx.send_all(&mut futures::stream::iter([FrameId::MAX - 1, FrameId::MAX, 1, 2]).map(Ok))
            .await?;

        let mut rx = rx.sequencer(Duration::from_millis(10), 1024);
        rx.next_id = FrameId::MAX - 1;
        pin_mut!(rx);

        const LAST_ID: FrameId = FrameId::MAX - 1;
        assert!(matches!(rx.next().await, Some(Ok(LAST_ID))));
        assert!(matches!(rx.next().await, Some(Ok(FrameId::MAX))));
        assert!(rx.next().await.is_none());

        Ok(())
    }

    /// An idle period with an empty buffer must not count towards the wait for the next gap.
    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn sequencer_must_not_discard_frames_when_buffer_was_empty_after_timeout() -> anyhow::Result<()> {
        let (tx, rx) = futures::channel::mpsc::unbounded();

        let jh = tokio::task::spawn(async move {
            tokio::time::sleep(Duration::from_millis(2)).await;
            pin_mut!(tx);
            tx.send_all(&mut futures::stream::iter([3, 1, 2, 4]).map(Ok)).await?;

            // Stay idle for longer than the sequencer timeout.
            tokio::time::sleep(Duration::from_millis(150)).await;

            tx.send_all(&mut futures::stream::iter([6, 5, 7]).map(Ok)).await?;

            anyhow::Ok(())
        });

        let chunks = rx
            .sequencer(Duration::from_millis(50), 1024)
            .try_ready_chunks(10)
            .try_collect::<Vec<Vec<_>>>()
            .await?;

        assert_eq!(chunks, vec![vec![1, 2, 3, 4], vec![5, 6, 7]]);
        jh.await??;

        Ok(())
    }
}