Skip to main content

hopr_protocol_session/processing/
reassembly.rs

1//! Contains the frame [`Reassembler`]:
2//! an inverse component to the `Segmenter`.
3
4use std::{
5    future::Future,
6    pin::Pin,
7    task::{Context, Poll},
8    time::{Duration, Instant},
9};
10
11use futures_time::future::Timer;
12use tracing::instrument;
13
14use crate::{
15    errors::SessionError,
16    processing::types::{
17        FrameBuilder, FrameDashMap, FrameHashMap, FrameInspector, FrameMap, FrameMapEntry, FrameMapOccupiedEntry,
18        FrameMapVacantEntry,
19    },
20    protocol::{Frame, FrameId, Segment},
21};
22
#[cfg(all(not(test), feature = "telemetry"))]
lazy_static::lazy_static! {
    /// Histogram of frame reassembly times in milliseconds.
    ///
    /// Only compiled outside of tests and when the `telemetry` feature is enabled.
    static ref METRIC_TIME_TO_FRAME_FINISH: hopr_types::telemetry::SimpleHistogram =
        hopr_types::telemetry::SimpleHistogram::new(
            "hopr_session_time_to_finish_frame",
            "Measures time in milliseconds it takes a frame to be reassembled",
            vec![
                1.0, 2.0, 5.0, 10.0, 25.0, 50.0, 75.0,
                100.0, 150.0, 200.0, 250.0, 300.0, 400.0, 500.0,
            ],
        )
        .unwrap();
}
32
33/// Reassembler is a stream adaptor that reads [`Segments`](Segment) from the underlying
34/// stream and tries to put them into correct order so they form a [`Frame`].
35///
36/// This is essentially the inverse of the `Segmenter`.
37///
38/// Reassembler takes two parameters: `max_age` and `capacity`:
39///
40/// The `max_age` specifies how long an incomplete Frame (with a missing segment(s)) is to be kept
41/// in the internal buffer until it is considered definitely lost.
42/// In other words, it specifies how long the reassembler is allowed to wait for all segments
43/// of a Frame to arrive from the underlying stream.
44///
45/// The `capacity` specifies the maximum number of incomplete frames to keep in
46/// the internal buffer. If the reassembler is at maximum capacity, the underlying stream is not
47/// polled for new segments, leaving the oldest incomplete frames in the reassembler to expire and
48/// be definitely lost.
49///
50/// By definition, Reassembler is a fallible stream, yielding either `Ok(Some(`[`Frame`]`))`,
51/// `Err(`[`SessionError::FrameDiscarded`]`)` when a frame is lost due to expiry, or `Ok(None)` when
52/// there are no more elements in the underlying stream.
53///
54/// The reassemblers internal buffer is stored in a [`FrameMap`] and can be constructed using
55/// different implementations of it, suitable for different use-cases.
56///
57/// Use [`ReassemblerExt`] methods to turn a ` Segment ` stream into a fallible `Frame` stream using the `Reassembler`.
58#[must_use = "streams do nothing unless polled"]
59#[pin_project::pin_project]
60pub struct Reassembler<S, M> {
61    #[pin]
62    inner: S,
63    #[pin]
64    timer: futures_time::task::Sleep,
65    incomplete_frames: M,
66    expired_frames: Vec<FrameId>,
67    max_age: Duration,
68    capacity: usize,
69    last_expiration: Option<Instant>,
70}
71
72impl<S: futures::Stream<Item = Segment>, M: FrameMap> Reassembler<S, M> {
73    fn new(inner: S, incomplete_frames: M, max_age: Duration, capacity: usize) -> Self {
74        Self {
75            inner,
76            timer: futures_time::task::sleep(
77                (max_age + Duration::from_millis(1))
78                    .max(Duration::from_millis(1))
79                    .into(),
80            ),
81            incomplete_frames,
82            expired_frames: Vec::with_capacity(capacity),
83            last_expiration: None,
84            max_age,
85            capacity,
86        }
87    }
88
89    fn expire_frames(incomplete_frames: &mut M, expired_frames: &mut Vec<FrameId>, max_age: Duration) {
90        incomplete_frames.retain(|id, builder| {
91            if builder.last_recv.elapsed() >= max_age {
92                expired_frames.push(*id);
93                false
94            } else {
95                true
96            }
97        });
98    }
99}
100
101impl<S: futures::Stream<Item = Segment>, M: FrameMap> futures::Stream for Reassembler<S, M> {
102    type Item = Result<Frame, SessionError>;
103
104    #[instrument(name = "Reassembler::poll_next", level = "trace", skip(self, cx), fields(num_incomplete = self.incomplete_frames.len()), ret)]
105    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
106        let mut this = self.project();
107        loop {
108            if let Some(frame_id) = this.expired_frames.pop() {
109                tracing::trace!(frame_id, "emit discarded frame");
110                return Poll::Ready(Some(Err(SessionError::FrameDiscarded(frame_id))));
111            }
112
113            // Poll the timer only if there are incomplete frames
114            let timer_poll = if this.incomplete_frames.len() > 0 {
115                this.timer.as_mut().poll(cx)
116            } else {
117                Poll::Pending
118            };
119
120            // Poll the inner stream only if there's space in the reassembler
121            let inner_poll = if this.incomplete_frames.len() < *this.capacity {
122                this.inner.as_mut().poll_next(cx)
123            } else {
124                // This essentially forces the incomplete frames to be expired
125                tracing::warn!("reassembler has reached its capacity");
126                Poll::Pending
127            };
128
129            tracing::trace!("polling next");
130            match (inner_poll, timer_poll) {
131                (Poll::Ready(Some(item)), timer) => {
132                    if timer.is_ready() {
133                        this.timer.as_mut().reset_timer();
134                    }
135
136                    tracing::trace!(
137                        frame_id = item.frame_id,
138                        seq_idx = item.seq_idx,
139                        seq_len = %item.seq_flags,
140                        "received segment"
141                    );
142
143                    match this.incomplete_frames.entry(item.frame_id) {
144                        FrameMapEntry::Occupied(mut e) => {
145                            let builder = e.get_builder_mut();
146                            let seg_id = item.id();
147                            match builder.add_segment(item) {
148                                Ok(_) => {
149                                    tracing::trace!(frame_id = builder.frame_id(), %seg_id, "added segment");
150                                    if builder.is_complete() {
151                                        #[cfg(all(not(test), feature = "telemetry"))]
152                                        METRIC_TIME_TO_FRAME_FINISH
153                                            .observe(builder.created.elapsed().as_millis() as f64);
154
155                                        tracing::trace!(frame_id = builder.frame_id(), "frame is complete");
156                                        return Poll::Ready(Some(e.finalize().try_into()));
157                                    }
158                                }
159                                Err(error) => {
160                                    tracing::error!(%error, %seg_id, "encountered invalid segment");
161                                }
162                            }
163                        }
164                        FrameMapEntry::Vacant(e) => {
165                            let builder = FrameBuilder::from(item);
166                            if builder.is_complete() {
167                                #[cfg(all(not(test), feature = "telemetry"))]
168                                METRIC_TIME_TO_FRAME_FINISH.observe(builder.created.elapsed().as_millis() as f64);
169
170                                tracing::trace!(frame_id = builder.frame_id(), "segment frame is complete");
171                                return Poll::Ready(Some(builder.try_into()));
172                            } else {
173                                tracing::trace!(frame_id = builder.frame_id(), "added segment for new frame");
174                                e.insert_builder(builder);
175                            }
176                        }
177                    };
178
179                    // Since the retaining operation is potentially expensive,
180                    // we do it actually only if there's a real chance that a frame is expired
181                    if this.last_expiration.is_none_or(|e| e.elapsed() >= *this.max_age) {
182                        Self::expire_frames(this.incomplete_frames, this.expired_frames, *this.max_age);
183                        *this.last_expiration = Some(Instant::now());
184                    }
185                }
186                (Poll::Ready(None), _) => {
187                    // Inner stream closed, dump all incomplete frames
188                    tracing::trace!("inner stream closed, dumping incomplete frames");
189                    if this.incomplete_frames.len() > 0 {
190                        this.incomplete_frames.retain(|id, _| {
191                            this.expired_frames.push(*id);
192                            false
193                        });
194                    } else {
195                        tracing::trace!("done");
196                        return Poll::Ready(None);
197                    }
198                }
199                (Poll::Pending, Poll::Ready(_)) => {
200                    // Check if some frames are expired
201                    Self::expire_frames(this.incomplete_frames, this.expired_frames, *this.max_age);
202                    *this.last_expiration = Some(Instant::now());
203                    this.timer.as_mut().reset_timer();
204                }
205                (Poll::Pending, Poll::Pending) => return Poll::Pending,
206            }
207        }
208    }
209}
210
211/// Stream extension methods for frame reassembly.
212pub trait ReassemblerExt: futures::Stream<Item = Segment> {
213    /// Attaches a [`Reassembler`] with the given `timeout` for frame completion and `capacity`
214    /// to this stream.
215    fn reassembler(self, timeout: Duration, capacity: usize) -> Reassembler<Self, FrameHashMap>
216    where
217        Self: Sized,
218    {
219        // FrameHashMap is much faster than a FrameDashMap used in a FrameInspector
220        Reassembler::new(
221            self,
222            FrameHashMap::with_capacity(FrameInspector::INCOMPLETE_FRAME_RATIO * capacity + 1),
223            timeout,
224            capacity,
225        )
226    }
227
228    /// Attaches a [`Reassembler`] with the given `timeout` for frame completion, `capacity`
229    /// to this stream and [`FrameInspector`].
230    ///
231    /// Use only in situations where the [`FrameInspector`] is really needed, as such Reassembler
232    /// is slower than a Reassembler without a `FrameInspector`.
233    fn reassembler_with_inspector(
234        self,
235        timeout: Duration,
236        capacity: usize,
237        inspector: FrameInspector,
238    ) -> Reassembler<Self, FrameDashMap>
239    where
240        Self: Sized,
241    {
242        Reassembler::new(self, inspector.0.clone(), timeout, capacity)
243    }
244}
245
246impl<T: ?Sized> ReassemblerExt for T where T: futures::Stream<Item = Segment> {}
247
#[cfg(test)]
mod tests {
    use std::cmp::Ordering;

    use anyhow::anyhow;
    use futures::{SinkExt, StreamExt, TryStreamExt, pin_mut};
    use futures_time::future::FutureExt;
    use hex_literal::hex;
    use rand::{SeedableRng, prelude::SliceRandom, rngs::StdRng};

    use super::*;
    use crate::utils::segment;

    /// Fixed seed, so that the segment shuffling is the same in every test run.
    const RNG_SEED: [u8; 32] = hex!("d8a471f1c20490a3442b96fdde9d1807428096e1601b0cef0eea7e6d44a24c01");

    /// Orders reassembler outputs by the ID of the frame they refer to, regardless of whether
    /// the frame was reassembled or discarded.
    ///
    /// Panics on any error other than [`SessionError::FrameDiscarded`].
    fn result_comparator(a: &Result<Frame, SessionError>, b: &Result<Frame, SessionError>) -> Ordering {
        fn frame_id_of(result: &Result<Frame, SessionError>) -> FrameId {
            match result {
                Ok(frame) => frame.frame_id,
                Err(SessionError::FrameDiscarded(id)) => *id,
                Err(_) => panic!("unexpected result"),
            }
        }

        frame_id_of(a).cmp(&frame_id_of(b))
    }

    /// All segments of all frames arrive (in shuffled order), so every frame must be reassembled.
    #[test_log::test(tokio::test)]
    pub async fn reassembler_should_reassemble_frames() -> anyhow::Result<()> {
        let expected = (1u32..=10)
            .map(|frame_id| Frame {
                frame_id,
                data: hopr_types::crypto_random::random_bytes::<100>().into(),
                is_terminating: false,
            })
            .collect::<Vec<_>>();

        let (r_sink, r_stream) = futures::channel::mpsc::unbounded();
        let r_stream = r_stream.reassembler(Duration::from_secs(5), 1024);

        let mut segments = expected
            .iter()
            .cloned()
            .flat_map(|f| segment(f.data, 22, f.frame_id).unwrap())
            .collect::<Vec<_>>();

        let mut rng = StdRng::from_seed(RNG_SEED);
        segments.shuffle(&mut rng);

        let jh = hopr_utils::runtime::prelude::spawn(futures::stream::iter(segments).map(Ok).forward(r_sink));

        let mut actual = r_stream
            .try_collect::<Vec<_>>()
            .timeout(futures_time::time::Duration::from_secs(5))
            .await??;

        assert_eq!(actual.len(), expected.len());

        actual.sort_by_key(|a| a.frame_id);
        assert_eq!(actual, expected);

        let _ = jh.await?;
        Ok(())
    }

    /// One segment of Frame 2 never arrives, so Frame 2 must be discarded once it expires,
    /// while the stream is still open.
    #[test_log::test(tokio::test)]
    pub async fn reassembler_should_discard_incomplete_frames_on_expiration() -> anyhow::Result<()> {
        let expected = (1u32..=10)
            .map(|frame_id| Frame {
                frame_id,
                data: hopr_types::crypto_random::random_bytes::<100>().into(),
                is_terminating: false,
            })
            .collect::<Vec<_>>();

        let (r_sink, r_stream) = futures::channel::mpsc::unbounded();
        let r_stream = r_stream.reassembler(Duration::from_millis(45), 1024);

        let mut segments = expected
            .iter()
            .cloned()
            .flat_map(|f| segment(f.data, 22, f.frame_id).unwrap())
            .filter(|s| s.frame_id != 2 || s.seq_idx != 1)
            .collect::<Vec<_>>();

        let mut rng = StdRng::from_seed(RNG_SEED);
        segments.shuffle(&mut rng);

        pin_mut!(r_sink);
        r_sink.send_all(&mut futures::stream::iter(segments).map(Ok)).await?;

        let mut actual = Vec::new();
        pin_mut!(r_stream);
        for _ in 0..expected.len() {
            actual.push(r_stream.next().await.ok_or_else(|| anyhow!("missing frame"))?);
        }
        r_sink.close().await?;
        assert_eq!(None, r_stream.try_next().await?);

        actual.sort_by(result_comparator);

        assert_eq!(actual.len(), expected.len());

        for (idx, (got, want)) in actual.iter().zip(&expected).enumerate() {
            if idx == 1 {
                // Frame 2 had a missing segment; therefore, there should be an error
                assert!(matches!(got, Err(SessionError::FrameDiscarded(2))));
            } else {
                assert!(matches!(got, Ok(f) if f == want));
            }
        }

        Ok(())
    }

    /// One segment of Frame 5 never arrives, so Frame 5 must be discarded when the stream closes.
    #[test_log::test(tokio::test)]
    pub async fn reassembler_should_discard_incomplete_frames_on_close() -> anyhow::Result<()> {
        let expected = (1u32..=10)
            .map(|frame_id| Frame {
                frame_id,
                data: hopr_types::crypto_random::random_bytes::<100>().into(),
                is_terminating: false,
            })
            .collect::<Vec<_>>();

        let (r_sink, r_stream) = futures::channel::mpsc::unbounded();
        let r_stream = r_stream.reassembler(Duration::from_millis(100), 1024);

        let mut segments = expected
            .iter()
            .cloned()
            .flat_map(|f| segment(f.data, 22, f.frame_id).unwrap())
            .filter(|s| s.frame_id != 5 || s.seq_idx != 2)
            .collect::<Vec<_>>();

        let mut rng = StdRng::from_seed(RNG_SEED);
        segments.shuffle(&mut rng);

        let jh = hopr_utils::runtime::prelude::spawn(futures::stream::iter(segments).map(Ok).forward(r_sink));

        let mut actual = r_stream
            .collect::<Vec<_>>()
            .timeout(futures_time::time::Duration::from_secs(5))
            .await?;

        // Since `forward` closed the sink, even the incomplete Frame 5 should be yielded as error
        assert_eq!(actual.len(), expected.len());

        actual.sort_by(result_comparator);

        for (idx, (got, want)) in actual.iter().zip(&expected).enumerate() {
            if idx == 4 {
                // Frame 5 had a missing segment, therefore there should be an error
                assert!(matches!(got, Err(SessionError::FrameDiscarded(5))));
            } else {
                assert!(matches!(got, Ok(f) if f == want));
            }
        }

        let _ = jh.await?;
        Ok(())
    }

    /// A reassembler filled up to its capacity with incomplete frames must not accept new segments
    /// until the incomplete frames expire.
    #[test_log::test(tokio::test)]
    pub async fn reassembler_should_wait_and_discard_if_full() -> anyhow::Result<()> {
        let expected = (1u32..=5)
            .map(|frame_id| Frame {
                frame_id,
                data: hopr_types::crypto_random::random_bytes::<30>().into(),
                is_terminating: false,
            })
            .collect::<Vec<_>>();

        let (r_sink, r_stream) = futures::channel::mpsc::unbounded();
        let r_stream = r_stream.reassembler(Duration::from_millis(200), 3);

        pin_mut!(r_sink);
        pin_mut!(r_stream);

        // This creates 5 frames with 2 segments each
        let segments = expected
            .iter()
            .cloned()
            .flat_map(|f| segment(f.data, 20, f.frame_id).unwrap())
            .collect::<Vec<_>>();

        let to_send = [
            // Frame 1: Segment 2, Segment 1 missing
            segments[1].clone(),
            // Frame 2: Segment 1, Segment 2 missing
            segments[2].clone(),
            // Frame 3: Segment 2, Segment 1 missing
            segments[5].clone(),
        ];

        let start = Instant::now();

        // Reassembler now contains 3 incomplete frames
        r_sink.send_all(&mut futures::stream::iter(to_send).map(Ok)).await?;

        // It must not yield anything
        assert!(
            r_stream
                .next()
                .timeout(futures_time::time::Duration::from_millis(20))
                .await
                .is_err()
        );

        // Entire Frames 4 & 5
        r_sink
            .send_all(
                &mut futures::stream::iter([
                    segments[6].clone(),
                    segments[7].clone(),
                    segments[8].clone(),
                    segments[9].clone(),
                ])
                .map(Ok),
            )
            .await?;

        let mut reassembled = Vec::new();
        for _ in 0..5 {
            reassembled.push(r_stream.next().await.ok_or_else(|| anyhow!("missing frame"))?);
        }
        reassembled.sort_by(result_comparator);

        assert!(
            matches!(reassembled[0], Err(SessionError::FrameDiscarded(1))),
            "{:?} must be discarded ID 1",
            reassembled[0]
        );
        assert!(
            matches!(reassembled[1], Err(SessionError::FrameDiscarded(2))),
            "{:?} must be discarded ID 2",
            reassembled[1]
        );
        assert!(
            matches!(reassembled[2], Err(SessionError::FrameDiscarded(3))),
            "{:?} must be discarded ID 3",
            reassembled[2]
        );
        assert!(
            matches!(&reassembled[3], Ok(f) if f == &expected[3]),
            "{:?} (idx 3) must be {:?}",
            reassembled[3],
            expected[3]
        );
        assert!(
            matches!(&reassembled[4], Ok(f) if f == &expected[4]),
            "{:?} (idx 4) must be {:?}",
            reassembled[4],
            expected[4]
        );

        r_sink.close().await?;
        assert_eq!(None, r_stream.try_next().await?);

        // The three incomplete frames could only be yielded after they have expired
        assert!(start.elapsed() >= Duration::from_millis(200));

        Ok(())
    }
}
509}