hopr_network_types/session/
frame.rs

1//! This module implements segmentation of [frames][Frame] into [segments][Segment] and
2//! their [reassembly](FrameReassembler) back into [`Frames`](Frame) and their sequencing.
3//!
4//! ## Frames
5//! Contain data of arbitrary length up to 65536 bytes, differently sized frames are supported.
6//! Each frame carries a [`frame_id`](FrameId) which
7//! should be unique within some higher level session. Frame ID ranges from 1 to 2^32-1.
8//! Frame ID of 0 is not allowed, and its segments cannot be pushed to the reassembler.
9//!
10//! ## Segmentation
11//! A [frame](Frame) can be [segmented](Frame::segment) into equally sized [`Segments`](Segment),
12//! each of them carrying its [sequence number](SeqNum).
13//! This operation runs in linear time with respect to the size of the frame.
//! There can be up to 255 segments per frame (the maximum value of a [sequence number](SeqNum)).
15//! Frame segments are uniquely identified via [`SegmentId`].
16//!
17//! ## Reassembly
18//! This is an inverse operation to segmentation. Reassembly is performed by a [`FrameReassembler`]
19//! and is implemented lock-free. The reassembler acts as a [`Sink`] for [`Segments`](Segment) and
20//! is always paired with a [`Stream`] that outputs the reassembled [`Frames`](Frame).
21//!
22//! ### Ordering
23//! The reassembled frames will always have the segments in correct order, and complete frames emitted
24//! from the reassembler will also be ordered correctly according to their frame IDs.
//! If the next frame in sequence cannot be completed within the `max_age` period given
//! upon [construction](FrameReassembler::new) of the reassembler, the [`SessionError::FrameDiscarded`]
//! error will be emitted by the reassembler (see the next section).
28//!
29//! ### Expiration
30//! The reassembler also implements segment expiration. Upon [construction](FrameReassembler::new), the maximum
31//! incomplete frame age can be specified. If a frame is not completed in the reassembler within
32//! this period, it can be [evicted](FrameReassembler::evict) from the reassembler, so that it will be lost
33//! forever.
34//! The eviction operation is supposed to be run periodically, so that the space could be freed up in the
35//! reassembler, and the reassembler does not wait indefinitely for the next frame in sequence.
36//!
//! Beware that once eviction is performed and an incomplete frame with ID `n` is destroyed,
//! the caller should make sure that frames with ID <= `n` will not arrive into the reassembler,
//! otherwise the [`SessionError::OldSegment`] error will be returned.
40
41use bitvec::array::BitArray;
42use bitvec::{bitarr, BitArr};
43use dashmap::mapref::entry::Entry;
44use dashmap::DashMap;
45use futures::{Sink, Stream};
46use std::collections::BinaryHeap;
47use std::fmt::{Debug, Display, Formatter};
48use std::mem;
49use std::ops::{Add, Sub};
50use std::pin::Pin;
51use std::sync::atomic::{AtomicU32, AtomicU64, AtomicU8, Ordering};
52use std::sync::OnceLock;
53use std::task::{Context, Poll};
54use std::time::{Duration, SystemTime};
55
56use hopr_platform::time::native::current_time;
57use hopr_primitive_types::prelude::AsUnixTimestamp;
58
59use crate::errors::NetworkTypeError;
60use crate::session::errors::SessionError;
61
62/// ID of a [Frame].
63pub type FrameId = u32;
64/// Type representing the sequence numbers in a [Frame].
65pub type SeqNum = u8;
66
67/// Convenience type that identifies a segment within a frame.
68#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Ord, PartialOrd)]
69#[cfg_attr(feature = "serde", derive(serde::Serialize), derive(serde::Deserialize))]
70pub struct SegmentId(pub FrameId, pub SeqNum);
71
72impl From<&Segment> for SegmentId {
73    fn from(value: &Segment) -> Self {
74        value.id()
75    }
76}
77
78impl Display for SegmentId {
79    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
80        write!(f, "seg({},{})", self.0, self.1)
81    }
82}
83
84/// Helper function to segment `data` into segments of a given ` max_segment_size ` length.
85/// All segments are tagged with the same `frame_id`.
86pub fn segment(data: &[u8], max_segment_size: usize, frame_id: u32) -> crate::session::errors::Result<Vec<Segment>> {
87    if frame_id == 0 {
88        return Err(SessionError::InvalidFrameId);
89    }
90
91    if max_segment_size == 0 {
92        return Err(SessionError::InvalidSegmentSize);
93    }
94
95    let num_chunks = data.len().div_ceil(max_segment_size);
96    if num_chunks > SeqNum::MAX as usize {
97        return Err(SessionError::DataTooLong);
98    }
99
100    let chunks = data.chunks(max_segment_size);
101
102    let seq_len = chunks.len() as SeqNum;
103    Ok(chunks
104        .enumerate()
105        .map(|(idx, data)| Segment {
106            frame_id,
107            seq_len,
108            seq_idx: idx as u8,
109            data: data.into(),
110        })
111        .collect())
112}
113
114/// Data frame of arbitrary length.
115/// The frame can be segmented into [segments](Segment) and reassembled back
116/// via [FrameReassembler].
117#[derive(Debug, Clone, PartialEq, Eq)]
118pub struct Frame {
119    /// Identifier of this frame.
120    pub frame_id: FrameId,
121    /// Frame data.
122    pub data: Box<[u8]>,
123}
124
125impl Frame {
126    /// Segments this frame into a list of [segments](Segment) each of maximum sizes `mtu`.
127    #[inline]
128    pub fn segment(&self, max_segment_size: usize) -> crate::session::errors::Result<Vec<Segment>> {
129        segment(self.data.as_ref(), max_segment_size, self.frame_id)
130    }
131}
132
133impl AsRef<[u8]> for Frame {
134    fn as_ref(&self) -> &[u8] {
135        &self.data
136    }
137}
138
139/// Represents a frame segment.
140/// Besides the data, a segment carries information about the total number of
141/// segments in the original frame, its index within the frame and
142/// ID of that frame.
143#[derive(Clone, Eq, PartialEq)]
144#[cfg_attr(feature = "serde", derive(serde::Serialize), derive(serde::Deserialize))]
145pub struct Segment {
146    /// ID of the [Frame] this segment belongs to.
147    pub frame_id: FrameId,
148    /// Index of this segment within the segment sequence.
149    pub seq_idx: SeqNum,
150    /// Total number of segments within this segment sequence.
151    pub seq_len: SeqNum,
152    /// Data in this segment.
153    #[cfg_attr(feature = "serde", serde(with = "serde_bytes"))]
154    pub data: Box<[u8]>,
155}
156
157impl Segment {
158    /// Size of the segment header.
159    pub const HEADER_SIZE: usize = mem::size_of::<FrameId>() + 2 * mem::size_of::<SeqNum>();
160
161    /// The minimum size of a segment: [`Segment::HEADER_SIZE`] + 1 byte of data.
162    pub const MINIMUM_SIZE: usize = Self::HEADER_SIZE + 1;
163
164    /// Returns the [SegmentId] for this segment.
165    pub fn id(&self) -> SegmentId {
166        SegmentId(self.frame_id, self.seq_idx)
167    }
168}
169
170impl Debug for Segment {
171    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
172        f.debug_struct("Segment")
173            .field("frame_id", &self.frame_id)
174            .field("seq_id", &self.seq_idx)
175            .field("seq_len", &self.seq_len)
176            .field("data", &hex::encode(&self.data))
177            .finish()
178    }
179}
180
181impl PartialOrd<Segment> for Segment {
182    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
183        Some(self.cmp(other))
184    }
185}
186
187impl Ord for Segment {
188    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
189        match self.frame_id.cmp(&other.frame_id) {
190            std::cmp::Ordering::Equal => self.seq_idx.cmp(&other.seq_idx),
191            cmp => cmp,
192        }
193    }
194}
195
196impl From<Segment> for Vec<u8> {
197    fn from(value: Segment) -> Self {
198        let mut ret = Vec::with_capacity(Segment::HEADER_SIZE + value.data.len());
199        ret.extend_from_slice(value.frame_id.to_be_bytes().as_ref());
200        ret.extend_from_slice(value.seq_idx.to_be_bytes().as_ref());
201        ret.extend_from_slice(value.seq_len.to_be_bytes().as_ref());
202        ret.extend_from_slice(value.data.as_ref());
203        ret
204    }
205}
206
207impl TryFrom<&[u8]> for Segment {
208    type Error = SessionError;
209
210    fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
211        let (header, data) = value.split_at(Self::HEADER_SIZE);
212        let segment = Segment {
213            frame_id: FrameId::from_be_bytes(header[0..4].try_into().map_err(|_| SessionError::InvalidSegment)?),
214            seq_idx: SeqNum::from_be_bytes(header[4..5].try_into().map_err(|_| SessionError::InvalidSegment)?),
215            seq_len: SeqNum::from_be_bytes(header[5..6].try_into().map_err(|_| SessionError::InvalidSegment)?),
216            data: data.into(),
217        };
218        (segment.frame_id > 0 && segment.seq_idx < segment.seq_len)
219            .then_some(segment)
220            .ok_or(SessionError::InvalidSegment)
221    }
222}
223
224/// Rebuilds the [Frame] from [Segments](Segment).
225#[derive(Debug)]
226struct FrameBuilder {
227    frame_id: FrameId,
228    segments: Vec<OnceLock<Box<[u8]>>>,
229    remaining: AtomicU8,
230    last_ts: AtomicU64,
231}
232
233impl FrameBuilder {
234    /// Creates a new builder with the given `initial` [Segment] and its timestamp `ts`.
235    fn new(initial: Segment, ts: SystemTime) -> Self {
236        let ret = Self::empty(initial.frame_id, initial.seq_len);
237        ret.put(initial, ts).unwrap();
238        ret
239    }
240
241    /// Creates a new empty builder for the given frame.
242    fn empty(frame_id: FrameId, seq_len: SeqNum) -> Self {
243        Self {
244            frame_id,
245            segments: vec![OnceLock::new(); seq_len as usize],
246            remaining: AtomicU8::new(seq_len),
247            last_ts: AtomicU64::new(0),
248        }
249    }
250
251    /// Adds a new [`segment`](Segment) to the builder with a timestamp `ts`.
252    /// Returns the number of segments remaining in this builder.
253    fn put(&self, segment: Segment, ts: SystemTime) -> crate::session::errors::Result<SeqNum> {
254        if self.frame_id == segment.frame_id {
255            if !self.is_complete() {
256                if self.segments[segment.seq_idx as usize].set(segment.data).is_ok() {
257                    // A new segment has been added, decrease the remaining number and update timestamp
258                    self.remaining.fetch_sub(1, Ordering::Relaxed);
259                    self.last_ts
260                        .fetch_max(ts.as_unix_timestamp().as_millis() as u64, Ordering::Relaxed);
261                }
262                Ok(self.remaining.load(Ordering::SeqCst))
263            } else {
264                // Silently throw away segments of a frame that is already complete
265                Ok(0)
266            }
267        } else {
268            Err(SessionError::InvalidFrameId)
269        }
270    }
271
272    /// Checks if the builder contains all segments of the frame.
273    fn is_complete(&self) -> bool {
274        self.remaining.load(Ordering::SeqCst) == 0
275    }
276
277    /// Checks if the last added segment to this frame happened before `cutoff`.
278    /// In other words, the frame under construction is considered expired if the last
279    /// segment was added before `cutoff`.
280    fn is_expired(&self, cutoff: u64) -> bool {
281        self.last_ts.load(Ordering::SeqCst) < cutoff
282    }
283
284    /// Returns information about the frame that is being built by this builder.
285    pub fn info(&self) -> FrameInfo {
286        let mut missing_segments = bitarr![0; 256];
287        self.segments
288            .iter()
289            .enumerate()
290            .filter_map(|(i, s)| s.get().is_none().then_some(i))
291            .for_each(|i| missing_segments.set(i, true));
292
293        FrameInfo {
294            frame_id: self.frame_id,
295            missing_segments,
296            total_segments: self.segments.len() as SeqNum,
297            last_update: SystemTime::UNIX_EPOCH.add(Duration::from_millis(self.last_ts.load(Ordering::SeqCst))),
298        }
299    }
300
301    /// Reassembles the [Frame]. Returns [`NetworkTypeError::IncompleteFrame`] if not [complete](FrameBuilder::is_complete).
302    fn reassemble(self) -> crate::session::errors::Result<Frame> {
303        if self.is_complete() {
304            Ok(Frame {
305                frame_id: self.frame_id,
306                data: self
307                    .segments
308                    .into_iter()
309                    .map(|lock| lock.into_inner().unwrap())
310                    .collect::<Vec<Box<[u8]>>>()
311                    .concat()
312                    .into_boxed_slice(),
313            })
314        } else {
315            Err(SessionError::IncompleteFrame(self.frame_id))
316        }
317    }
318}
319
320/// Contains information about a frame that being built.
321/// The instances are totally ordered as most recently used first.
322#[derive(Debug, Clone, PartialEq, Eq)]
323pub struct FrameInfo {
324    /// ID of the frame.
325    pub frame_id: FrameId,
326    /// Indices of segments that are missing. Empty if the frame is complete.
327    pub missing_segments: BitArr!(for 256),
328    /// The total number of segments in this frame.
329    pub total_segments: SeqNum,
330    /// Time of the last received segment in this frame.
331    pub last_update: SystemTime,
332}
333
334impl FrameInfo {
335    /// Transform self into iterator of missing segment numbers.
336    pub fn iter_missing_sequence_indices(&self) -> impl Iterator<Item = SeqNum> + '_ {
337        self.missing_segments
338            .iter()
339            .by_vals()
340            .enumerate()
341            .filter(|(i, s)| *s && *i <= SeqNum::MAX as usize)
342            .map(|(s, _)| s as SeqNum)
343    }
344
345    pub fn into_missing_segments(self) -> impl Iterator<Item = SegmentId> {
346        self.missing_segments
347            .into_iter()
348            .enumerate()
349            .filter(|(i, s)| *s && *i <= SeqNum::MAX as usize)
350            .map(move |(i, _)| SegmentId(self.frame_id, i as SeqNum))
351    }
352}
353
354impl PartialOrd for FrameInfo {
355    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
356        Some(self.cmp(other))
357    }
358}
359
360impl Ord for FrameInfo {
361    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
362        match self.last_update.cmp(&other.last_update) {
363            std::cmp::Ordering::Equal => self.frame_id.cmp(&self.frame_id),
364            cmp => cmp,
365        }
366        .reverse()
367    }
368}
369
370/// Represents a frame reassembler.
371///
372/// The [`FrameReassembler`] behaves as a [`Sink`] for [`Segment`].
373/// Upon creation, also [`Stream`] for reassembled [Frames](Frame) is created.
374/// The corresponding stream is closed either when the reassembler is dropped or
375/// [`futures::SinkExt::close`] is called.
376///
377/// As new segments are [pushed](FrameReassembler::push_segment) into the reassembler,
378/// the frames get reassembled, and once they are completed, they are automatically pushed out into
379/// the outgoing frame stream.
380///
381/// The reassembler can also have a `max_age` of frames that are under construction.
382/// The [`evict`](FrameReassembler::evict) method then can be called to remove
383/// the incomplete frames over `max_age`. The timestamps are measured with millisecond precision.
384///
385/// Note that the reassembler is also evicted when dropped.
386///
387/// ````rust
388/// # use std::time::Duration;
389/// futures::executor::block_on(async {
390/// use hopr_network_types::session::{Frame, FrameReassembler};
391/// use futures::{pin_mut, StreamExt, TryStreamExt};
392///
393/// let bytes = b"deadbeefcafe00112233";
394///
395/// // Build Frame and segment it
396/// let frame = Frame { frame_id: 1, data: bytes.as_ref().into() };
397/// let segments = frame.segment(2).unwrap();
398/// assert_eq!(bytes.len() / 2, segments.len());
399///
400/// // Create FrameReassembler and feed the segments to it
401/// let (fragmented, reassembled) = FrameReassembler::new(Duration::from_secs(10));
402///
403/// for segment in segments {
404///     fragmented.push_segment(segment).unwrap();
405/// }
406///
407/// drop(fragmented);
408/// pin_mut!(reassembled);
409///
410/// assert!(matches!(reassembled.try_next().await, Ok(Some(frame))));
411/// # });
412/// ````
413#[derive(Debug)]
414pub struct FrameReassembler {
415    sequences: DashMap<FrameId, FrameBuilder>,
416    highest_buffered_frame: AtomicU32,
417    next_emitted_frame: AtomicU32,
418    last_emission: AtomicU64,
419    reassembled: futures::channel::mpsc::UnboundedSender<crate::session::errors::Result<Frame>>,
420    max_age: Duration,
421}
422
423impl FrameReassembler {
424    /// Creates a new frame reassembler and a corresponding stream
425    /// for reassembled [Frames](Frame).
426    /// An optional `max_age` of segments can be specified,
427    /// which allows the [`evict`](FrameReassembler::evict) method to remove stale incomplete segments.
428    pub fn new(max_age: Duration) -> (Self, impl Stream<Item = crate::session::errors::Result<Frame>>) {
429        let (reassembled, reassembled_recv) = futures::channel::mpsc::unbounded();
430        (
431            Self {
432                sequences: DashMap::new(),
433                highest_buffered_frame: AtomicU32::new(0),
434                next_emitted_frame: AtomicU32::new(1),
435                last_emission: AtomicU64::new(u64::MAX),
436                reassembled,
437                max_age,
438            },
439            reassembled_recv,
440        )
441    }
442
443    /// Emits the frame if it is the next in sequence and complete.
444    /// If it is not next in the sequence or incomplete, it is discarded forever.
445    fn emit_if_complete_discard_otherwise(&self, builder: FrameBuilder) -> crate::session::errors::Result<()> {
446        if self.next_emitted_frame.fetch_add(1, Ordering::SeqCst) == builder.frame_id && builder.is_complete() {
447            self.reassembled
448                .unbounded_send(builder.reassemble())
449                .map_err(|_| SessionError::ReassemblerClosed)?;
450        } else {
451            self.reassembled
452                .unbounded_send(Err(SessionError::FrameDiscarded(builder.frame_id)))
453                .map_err(|_| SessionError::ReassemblerClosed)?;
454        }
455        self.last_emission
456            .store(current_time().as_unix_timestamp().as_millis() as u64, Ordering::Relaxed);
457        Ok(())
458    }
459
460    /// Pushes a new [Segment] for reassembly.
461    /// This function also pushes out the reassembled frame if this segment completed it.
462    /// Pushing a segment belonging to a frame ID that has been already
463    /// previously completed or [evicted](FrameReassembler::evict) will fail.
464    pub fn push_segment(&self, segment: Segment) -> crate::session::errors::Result<()> {
465        if self.reassembled.is_closed() {
466            return Err(SessionError::ReassemblerClosed);
467        }
468
469        // Check if this frame has not been emitted yet.
470        let frame_id = segment.frame_id;
471        if frame_id < self.next_emitted_frame.load(Ordering::SeqCst) {
472            return Err(SessionError::OldSegment(frame_id));
473        }
474
475        let ts = current_time();
476        let mut cascade = false;
477
478        match self.sequences.entry(frame_id) {
479            Entry::Occupied(e) => {
480                // No more segments missing in this frame, check if it is the next on to emit
481                if e.get().put(segment, ts)? == 0
482                    && self
483                        .next_emitted_frame
484                        .compare_exchange(frame_id, frame_id + 1, Ordering::SeqCst, Ordering::Relaxed)
485                        .is_ok()
486                {
487                    // Emit this complete frame
488                    self.reassembled
489                        .unbounded_send(e.remove().reassemble())
490                        .map_err(|_| SessionError::ReassemblerClosed)?;
491                    self.last_emission
492                        .store(current_time().as_unix_timestamp().as_millis() as u64, Ordering::Relaxed);
493                    cascade = true; // Try to emit next frames in sequence
494                }
495            }
496            Entry::Vacant(v) => {
497                let builder = FrameBuilder::new(segment, ts);
498                // If this frame is already complete, check if it is the next one to emit
499                if builder.is_complete()
500                    && self
501                        .next_emitted_frame
502                        .compare_exchange(frame_id, frame_id + 1, Ordering::SeqCst, Ordering::Relaxed)
503                        .is_ok()
504                {
505                    // Emit this frame if already complete
506                    self.reassembled
507                        .unbounded_send(builder.reassemble())
508                        .map_err(|_| SessionError::ReassemblerClosed)?;
509                    self.last_emission
510                        .store(current_time().as_unix_timestamp().as_millis() as u64, Ordering::Relaxed);
511                    cascade = true; // Try to emit the next frames in sequence
512                } else {
513                    // If not complete nor the next one to be emitted, just start building it
514                    v.insert(builder);
515                    self.highest_buffered_frame.fetch_max(frame_id, Ordering::Relaxed);
516                }
517            }
518        }
519
520        // As long as there are more in-sequence frames completed, emit them
521        if cascade {
522            while let Some((_, builder)) = self
523                .sequences
524                .remove_if(&self.next_emitted_frame.load(Ordering::SeqCst), |_, b| b.is_complete())
525            {
526                // If the frame is complete, push it out as reassembled
527                self.emit_if_complete_discard_otherwise(builder)?;
528            }
529        }
530
531        Ok(())
532    }
533
534    /// Returns [information](FrameInfo) about the incomplete frames.
535    /// The returned collection is ordered by frame IDs.
536    pub fn incomplete_frames(&self) -> BinaryHeap<FrameInfo> {
537        (self.next_emitted_frame.load(Ordering::SeqCst)..=self.highest_buffered_frame.load(Ordering::SeqCst))
538            .filter_map(|frame_id| match self.sequences.get(&frame_id) {
539                Some(e) => (!e.is_complete()).then(|| e.info()),
540                None => Some({
541                    let mut missing_segments = BitArray::ZERO;
542                    missing_segments.set(0, true);
543                    FrameInfo {
544                        frame_id,
545                        missing_segments,
546                        total_segments: 1,
547                        last_update: SystemTime::UNIX_EPOCH,
548                    }
549                }),
550            })
551            .collect()
552    }
553
554    /// According to the [max_age](FrameReassembler::new) set during construction, evicts
555    /// leading incomplete frames that are expired at the time this method was called.
556    /// Returns that total number of frames that were evicted.
557    pub fn evict(&self) -> crate::session::errors::Result<usize> {
558        if self.reassembled.is_closed() {
559            return Err(SessionError::ReassemblerClosed);
560        }
561
562        if self.sequences.is_empty() {
563            return Ok(0);
564        }
565
566        let cutoff = current_time().sub(self.max_age).as_unix_timestamp().as_millis() as u64;
567        let mut count = 0;
568        loop {
569            let next = self.next_emitted_frame.load(Ordering::SeqCst);
570            if let Some((_, builder)) = self
571                .sequences
572                .remove_if(&next, |_, b| b.is_complete() || b.is_expired(cutoff))
573            {
574                // If the frame is complete, push it out as reassembled or discard it as expired
575                self.emit_if_complete_discard_otherwise(builder)?;
576                count += 1;
577            } else if !self.sequences.contains_key(&next) && self.last_emission.load(Ordering::SeqCst) < cutoff {
578                // Do not stall the sequencer too long if we haven't seen this frame at all
579                self.next_emitted_frame.fetch_add(1, Ordering::Relaxed);
580                self.last_emission
581                    .store(current_time().as_unix_timestamp().as_millis() as u64, Ordering::Relaxed);
582                count += 1;
583            } else {
584                // Break on the first incomplete and non-expired frame
585                break;
586            }
587        }
588
589        Ok(count)
590    }
591
592    /// Closes the reassembler.
593    /// Any subsequent calls to [`FrameReassembler::push_segment`] will fail.
594    pub fn close(&self) {
595        self.reassembled.close_channel();
596    }
597}
598
599impl Drop for FrameReassembler {
600    fn drop(&mut self) {
601        let _ = self.evict();
602        self.reassembled.close_channel();
603    }
604}
605
606impl Extend<Segment> for FrameReassembler {
607    fn extend<T: IntoIterator<Item = Segment>>(&mut self, iter: T) {
608        iter.into_iter()
609            .try_for_each(|s| self.push_segment(s))
610            .expect("failed to extend")
611    }
612}
613
614impl Sink<Segment> for FrameReassembler {
615    type Error = NetworkTypeError;
616
617    fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
618        Poll::Ready(Ok(()))
619    }
620
621    fn start_send(self: Pin<&mut Self>, item: Segment) -> Result<(), Self::Error> {
622        Ok(self.push_segment(item)?)
623    }
624
625    fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
626        Poll::Ready(self.evict().map(|_| ()).map_err(NetworkTypeError::SessionProtocolError))
627    }
628
629    fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
630        self.reassembled.close_channel();
631        Poll::Ready(Ok(()))
632    }
633}
634
635#[cfg(test)]
636pub(crate) mod tests {
637    use super::*;
638    use async_stream::stream;
639    use futures::{pin_mut, Stream, StreamExt, TryStreamExt};
640    use hex_literal::hex;
641    use lazy_static::lazy_static;
642    use rand::prelude::{Distribution, SliceRandom};
643    use rand::{seq::IteratorRandom, thread_rng, Rng, SeedableRng};
644    use rand_distr::Normal;
645    use rayon::prelude::*;
646    use std::collections::{HashSet, VecDeque};
647    use std::convert::identity;
648    use std::sync::atomic::{AtomicBool, Ordering};
649    use std::sync::Arc;
650    use std::time::Duration;
651
652    const MTU: usize = 448;
653    const FRAME_COUNT: u32 = 65_535;
654    const FRAME_SIZE: usize = 4096;
655    const MIXING_FACTOR: f64 = 4.0;
656
657    lazy_static! {
658        // bd8e89c13f96c29377528865424efa7380a8c8e5cdd486b0fc508c9130ab39ef
659        static ref RAND_SEED: [u8; 32] = hopr_crypto_random::random_bytes();
660        static ref FRAMES: Vec<Frame> = (0..FRAME_COUNT)
661            .into_par_iter()
662            .map(|frame_id| Frame {
663                frame_id: frame_id + 1,
664                data: hopr_crypto_random::random_bytes::<FRAME_SIZE>().into(),
665            })
666            .collect::<Vec<_>>();
667        static ref SEGMENTS: Vec<Segment> = {
668            let vec = FRAMES.par_iter().flat_map(|f| f.segment(MTU).unwrap()).collect::<VecDeque<_>>();
669            let mut rng = rand::rngs::StdRng::from_seed(*RAND_SEED);
670            linear_half_normal_shuffle(&mut rng, vec, MIXING_FACTOR)
671        };
672    }
673
674    /// Sample an index between `0` and `len - 1` using the given distribution and RNG.
675    pub fn sample_index<T: Distribution<f64>, R: Rng>(dist: &mut T, rng: &mut R, len: usize) -> usize {
676        let f: f64 = dist.sample(rng);
677        (f.max(0.0).round() as usize).min(len - 1)
678    }
679
680    /// Shuffles the given `vec` by taking a next element with index `|N(0,factor^2)`|, where
681    /// `N` denotes normal distribution.
682    /// When used on frame segments vector, it will shuffle the segments in a controlled manner;
683    /// such that an entire frame can unlikely swap position with another, if `factor` ~ frame length.
684    fn linear_half_normal_shuffle<T, R: Rng>(rng: &mut R, mut vec: VecDeque<T>, factor: f64) -> Vec<T> {
685        if factor == 0.0 || vec.is_empty() {
686            return vec.into(); // no mixing
687        }
688
689        let mut dist = Normal::new(0.0, factor).unwrap();
690        let mut ret = Vec::new();
691        while !vec.is_empty() {
692            ret.push(vec.remove(sample_index(&mut dist, rng, vec.len())).unwrap());
693        }
694        ret
695    }
696
697    #[ctor::ctor]
698    fn init() {
699        lazy_static::initialize(&FRAMES);
700        lazy_static::initialize(&SEGMENTS);
701    }
702
703    #[test]
704    fn segmentation_should_segment_data_correctly() -> anyhow::Result<()> {
705        let data = hex!("deadbeefcafebabe");
706        let frame = Frame {
707            frame_id: 1,
708            data: data.as_ref().into(),
709        };
710
711        let segments = frame.segment(3)?;
712        assert_eq!(3, segments.len());
713
714        assert_eq!(hex!("deadbe"), segments[0].data.as_ref());
715        assert_eq!(0, segments[0].seq_idx);
716        assert_eq!(3, segments[0].seq_len);
717        assert_eq!(frame.frame_id, segments[0].frame_id);
718
719        assert_eq!(hex!("efcafe"), segments[1].data.as_ref());
720        assert_eq!(1, segments[1].seq_idx);
721        assert_eq!(3, segments[1].seq_len);
722        assert_eq!(frame.frame_id, segments[1].frame_id);
723
724        assert_eq!(hex!("babe"), segments[2].data.as_ref());
725        assert_eq!(2, segments[2].seq_idx);
726        assert_eq!(3, segments[2].seq_len);
727        assert_eq!(frame.frame_id, segments[2].frame_id);
728
729        Ok(())
730    }
731
732    #[test]
733    fn segment_must_serialize_and_deserialize() {
734        let data = hopr_crypto_random::random_bytes::<128>();
735
736        let segment = Segment {
737            frame_id: 1234,
738            seq_len: 123,
739            seq_idx: 12,
740            data: data.into(),
741        };
742
743        let boxed: Vec<u8> = segment.clone().into();
744        let recovered: Segment = (&boxed[..]).try_into().unwrap();
745
746        assert_eq!(segment, recovered);
747    }
748
749    #[async_std::test]
750    async fn frame_reassembler_must_process_ordered_frames() -> anyhow::Result<()> {
751        let (fragmented, reassembled) = FrameReassembler::new(Duration::from_secs(30));
752
753        FRAMES
754            .iter()
755            .flat_map(|f| f.segment(MTU).unwrap())
756            .try_for_each(|s| fragmented.push_segment(s))?;
757
758        drop(fragmented);
759        let reassembled_frames = reassembled.try_collect::<Vec<_>>().await?;
760
761        reassembled_frames
762            .into_par_iter()
763            .enumerate()
764            .for_each(|(i, frame)| assert_eq!(frame, FRAMES[i]));
765
766        Ok(())
767    }
768
769    #[async_std::test]
770    async fn frame_reassembler_must_process_single_frame() -> anyhow::Result<()> {
771        let (fragmented, reassembled) = FrameReassembler::new(Duration::from_secs(10));
772
773        let data = hex!("cafe");
774
775        let segment = Segment {
776            frame_id: 1,
777            seq_idx: 0,
778            seq_len: 1,
779            data: hex!("cafe").into(),
780        };
781
782        fragmented.push_segment(segment)?;
783        drop(fragmented);
784        let mut reassembled_frames = reassembled.try_collect::<Vec<_>>().await?;
785
786        assert_eq!(1, reassembled_frames.len());
787        let frame = reassembled_frames.pop().ok_or(SessionError::InvalidSegment)?;
788
789        assert_eq!(1, frame.frame_id);
790        assert_eq!(&data, frame.data.as_ref());
791
792        Ok(())
793    }
794
795    #[test]
796    fn should_not_push_frame_id_0_into_reassembler() -> anyhow::Result<()> {
797        let frame = Frame {
798            frame_id: 1,
799            data: hex!("deadbeefcafe").into(),
800        };
801
802        let mut segments = frame.segment(2)?;
803        segments[0].frame_id = 0;
804
805        let (fragmented, _reassembled) = FrameReassembler::new(Duration::from_secs(30));
806        fragmented
807            .push_segment(segments[0].clone())
808            .expect_err("must not push frame id 0");
809
810        Ok(())
811    }
812
813    #[test]
814    fn pushing_segment_of_a_completed_frame_into_reassembler_should_fail() -> anyhow::Result<()> {
815        let (fragmented, _reassembled) = FrameReassembler::new(Duration::from_secs(30));
816
817        let segments = FRAMES[0].segment(MTU)?;
818        let segment_1 = segments[0].clone();
819
820        segments.into_iter().try_for_each(|s| fragmented.push_segment(s))?;
821
822        fragmented
823            .push_segment(segment_1)
824            .expect_err("must fail pushing segment of a completed frame");
825
826        Ok(())
827    }
828
829    #[async_std::test]
830    async fn pushing_segment_of_an_evicted_frame_into_reassembler_should_fail() -> anyhow::Result<()> {
831        let (fragmented, _reassembled) = FrameReassembler::new(Duration::from_millis(5));
832
833        let mut segments = FRAMES[0].segment(MTU)?;
834        let segment_1 = segments.pop().unwrap(); // Remove the first segment
835
836        segments.into_iter().try_for_each(|s| fragmented.push_segment(s))?;
837
838        async_std::task::sleep(Duration::from_millis(10)).await;
839        assert_eq!(1, fragmented.evict()?);
840
841        fragmented
842            .push_segment(segment_1)
843            .expect_err("must fail pushing segment of an evicted frame");
844
845        Ok(())
846    }
847
848    #[async_std::test]
849    async fn frame_reassembler_reassembles_single_frame() -> anyhow::Result<()> {
850        let (fragmented, reassembled) = FrameReassembler::new(Duration::from_secs(30));
851
852        let mut rng = thread_rng();
853
854        let frame = FRAMES[0].clone();
855        let mut segments = frame.segment(MTU)?;
856        segments.shuffle(&mut rng);
857
858        segments.into_iter().try_for_each(|s| fragmented.push_segment(s))?;
859
860        drop(fragmented);
861        let reassembled_frames = reassembled.try_collect::<Vec<_>>().await?;
862
863        assert_eq!(1, reassembled_frames.len());
864        assert_eq!(frame, reassembled_frames[0]);
865
866        Ok(())
867    }
868
869    #[async_std::test]
870    async fn frame_reassembler_reassembles_shuffled_randomized_frames() -> anyhow::Result<()> {
871        let (fragmented, reassembled) = FrameReassembler::new(Duration::from_secs(30));
872
873        SEGMENTS.iter().cloned().try_for_each(|b| fragmented.push_segment(b))?;
874
875        assert_eq!(0, fragmented.evict().unwrap());
876        drop(fragmented);
877
878        let reassembled_frames = reassembled.try_collect::<Vec<_>>().await?;
879
880        reassembled_frames
881            .into_par_iter()
882            .enumerate()
883            .for_each(|(i, frame)| assert_eq!(frame, FRAMES[i]));
884
885        Ok(())
886    }
887
888    #[async_std::test]
889    async fn frame_reassembler_reassembles_shuffled_randomized_frames_in_parallel() -> anyhow::Result<()> {
890        let (fragmented, reassembled) = FrameReassembler::new(Duration::from_secs(30));
891
892        SEGMENTS
893            .par_iter()
894            .cloned()
895            .try_for_each(|b| fragmented.push_segment(b))?;
896
897        assert_eq!(0, fragmented.evict()?);
898        drop(fragmented);
899
900        let reassembled_frames = reassembled.try_collect::<Vec<_>>().await?;
901
902        reassembled_frames
903            .into_par_iter()
904            .enumerate()
905            .for_each(|(i, frame)| assert_eq!(frame, FRAMES[i]));
906
907        Ok(())
908    }
909
    /// Pushes three frames, one of which (Frame 2) lacks a segment, and checks that the output
    /// stream yields Frame 1, then a `FrameDiscarded(2)` error, then Frame 3.
    ///
    /// The reader runs in a spawned task, while this task sleeps for 20 ms (twice the 10 ms
    /// maximum age passed to the reassembler) before calling `evict`.
    #[async_std::test]
    async fn frame_reassembler_should_evict_expired_incomplete_frames() -> anyhow::Result<()> {
        let frames = vec![
            Frame {
                frame_id: 1,
                data: hex!("deadbeefcafebabe").into(),
            },
            Frame {
                frame_id: 2,
                data: hex!("feedbeefbaadcafe").into(),
            },
            Frame {
                frame_id: 3,
                data: hex!("00112233abcd").into(),
            },
        ];

        let mut segments = frames
            .iter()
            .flat_map(|f| f.segment(3).unwrap())
            .collect::<VecDeque<_>>();
        // Drop the segment with `seq_idx == 2` of Frame 2, so that Frame 2 can never be completed
        segments.retain(|s| s.frame_id != 2 || s.seq_idx != 2);

        let (fragmented, reassembled) = FrameReassembler::new(Duration::from_millis(10));

        segments.into_iter().try_for_each(|b| fragmented.push_segment(b))?;

        // The spawned reader needs its own copy of the frames to compare against
        let frames_cpy = frames.clone();
        let jh = async_std::task::spawn(async move {
            pin_mut!(reassembled);

            // Frame #1 should yield immediately
            assert_eq!(Some(frames_cpy[0].clone()), reassembled.try_next().await?);

            // Frame #2 will yield an error once `evict` has been called
            assert!(matches!(
                reassembled.try_next().await,
                Err(SessionError::FrameDiscarded(2))
            ));

            // Frame #3 will yield normally
            assert_eq!(Some(frames_cpy[2].clone()), reassembled.try_next().await?);

            Ok(())
        });

        // Sleep past the 10 ms maximum age before evicting
        async_std::task::sleep(Duration::from_millis(20)).await;

        assert_eq!(2, fragmented.evict()?); // One expired, one complete

        // Propagate the result (or panic) of the reader task
        jh.await
    }
962
    /// Pushes all segments of Frames 1 and 3 only (no segment of a Frame 2 is ever pushed) and
    /// checks that Frame 3 is yielded only after `evict` has been called.
    ///
    /// The `flushed` flag is set right before `evict` is invoked; the spawned reader asserts that
    /// the flag is unset after receiving Frame 1 and set after receiving Frame 3.
    #[async_std::test]
    async fn frame_reassembler_should_evict_frame_that_never_arrived() -> anyhow::Result<()> {
        let frames = vec![
            Frame {
                frame_id: 1,
                data: hex!("deadbeefcafebabe").into(),
            },
            Frame {
                frame_id: 3,
                data: hex!("00112233abcd").into(),
            },
        ];

        let segments = frames
            .iter()
            .flat_map(|f| f.segment(3).unwrap())
            .collect::<VecDeque<_>>();

        let (fragmented, reassembled) = FrameReassembler::new(Duration::from_millis(10));

        segments.into_iter().try_for_each(|b| fragmented.push_segment(b))?;

        // Shared flag telling the reader task whether `evict` is about to be (or has been) called
        let flushed = Arc::new(AtomicBool::new(false));

        let flushed_cpy = flushed.clone();
        let frames_cpy = frames.clone();
        let jh = async_std::task::spawn(async move {
            pin_mut!(reassembled);

            // The first frame should yield immediately
            assert_eq!(Some(frames_cpy[0].clone()), reassembled.try_next().await?);

            assert!(!flushed_cpy.load(Ordering::SeqCst));

            // The next frame is the third one
            assert_eq!(Some(frames_cpy[1].clone()), reassembled.try_next().await?);

            // and it must've happened only after pruning
            assert!(flushed_cpy.load(Ordering::SeqCst));

            Ok(())
        });

        // Sleep past the 10 ms maximum age before evicting
        async_std::task::sleep(Duration::from_millis(20)).await;

        // Prune the expired entry, which is Frame 2 (none of its segments was ever pushed)
        flushed.store(true, Ordering::SeqCst);
        assert_eq!(2, fragmented.evict()?); // One expired, one complete

        // Propagate the result (or panic) of the reader task
        jh.await
    }
1014
1015    #[async_std::test]
1016    async fn frame_reassembler_reassembles_randomized_delayed_frames_in_parallel() -> anyhow::Result<()> {
1017        let frames = FRAMES.iter().take(100).collect::<Vec<_>>();
1018
1019        let segments = frames
1020            .iter()
1021            .flat_map(|frame| frame.segment(MTU).unwrap())
1022            .collect::<Vec<_>>();
1023
1024        let (fragmented, reassembled) = FrameReassembler::new(Duration::from_secs(30));
1025
1026        futures::stream::iter(segments)
1027            .map(|segment| {
1028                let delay = Duration::from_millis(thread_rng().gen_range(0..10u64));
1029                async_std::task::spawn(async move {
1030                    async_std::task::sleep(delay).await;
1031                    Ok(segment)
1032                })
1033            })
1034            .buffer_unordered(4)
1035            .forward(fragmented)
1036            .await
1037            .unwrap();
1038
1039        let reassembled_frames = reassembled.try_collect::<Vec<_>>().await?;
1040
1041        reassembled_frames
1042            .into_par_iter()
1043            .enumerate()
1044            .for_each(|(i, frame)| assert_eq!(&frame, frames[i]));
1045
1046        Ok(())
1047    }
1048
1049    /// Creates `num_frames` out of which `num_corrupted` will have missing segments.
1050    fn corrupt_frames(
1051        num_frames: u32,
1052        corrupted_ratio: f32,
1053    ) -> (Vec<Segment>, Vec<&'static Frame>, HashSet<SegmentId>) {
1054        assert!((0.0..=1.0).contains(&corrupted_ratio));
1055
1056        let mut rng = rand::rngs::StdRng::from_seed(*RAND_SEED);
1057
1058        let (excluded_frame_ids, excluded_segments): (HashSet<FrameId>, HashSet<SegmentId>) = (1..num_frames + 1)
1059            .choose_multiple(&mut rng, ((num_frames as f32) * corrupted_ratio) as usize)
1060            .into_iter() // Must be sequentially generated due RNG determinism
1061            .map(|frame_id| {
1062                (
1063                    frame_id,
1064                    SegmentId(
1065                        frame_id,
1066                        rng.gen_range(0..SEGMENTS.iter().find(|s| s.frame_id == frame_id).unwrap().seq_len),
1067                    ),
1068                )
1069            })
1070            .unzip();
1071
1072        let segments = SEGMENTS
1073            .par_iter()
1074            .filter(|s| s.frame_id < num_frames && !excluded_segments.contains(&SegmentId(s.frame_id, s.seq_idx)))
1075            .cloned()
1076            .collect::<Vec<_>>();
1077
1078        let expected_frames = FRAMES
1079            .par_iter()
1080            .filter(|f| f.frame_id < num_frames && !excluded_frame_ids.contains(&f.frame_id))
1081            .collect::<Vec<_>>();
1082
1083        (segments, expected_frames, excluded_segments)
1084    }
1085
1086    #[async_std::test]
1087    async fn frame_reassembler_yields_correct_frames_when_also_corrupted_frames_are_present() -> anyhow::Result<()> {
1088        // Corrupt 30% of the frames, by removing a random segment from them
1089        let (segments, expected_frames, excluded) = corrupt_frames(FRAME_COUNT / 4, 0.3);
1090
1091        let (fragmented, reassembled) = FrameReassembler::new(Duration::from_millis(25));
1092
1093        segments.into_iter().try_for_each(|s| fragmented.push_segment(s))?;
1094
1095        let computed_missing = fragmented
1096            .incomplete_frames()
1097            .into_par_iter()
1098            .flat_map_iter(|e| e.into_missing_segments())
1099            .collect::<HashSet<_>>();
1100
1101        assert!(computed_missing.par_iter().all(|s| excluded.contains(s)));
1102        /*assert!(
1103            excluded.par_iter().all(|s| computed_missing.contains(&s)),
1104            "seed {}",
1105            hex::encode(RAND_SEED.clone())
1106        );*/
1107
1108        async_std::task::sleep(Duration::from_millis(25)).await;
1109        drop(fragmented);
1110
1111        let (reassembled_frames, discarded_frames) = reassembled
1112            .map(|f| match f {
1113                Ok(f) => (Some(f), None),
1114                Err(e) => (None, Some(e)),
1115            })
1116            .unzip::<_, _, Vec<_>, Vec<_>>()
1117            .await;
1118
1119        let reassembled_frames = reassembled_frames
1120            .into_par_iter()
1121            .filter_map(identity)
1122            .collect::<Vec<_>>();
1123
1124        (reassembled_frames, expected_frames)
1125            .into_par_iter()
1126            .all(|(a, b)| a.eq(b));
1127
1128        let discarded_frames = discarded_frames
1129            .into_par_iter()
1130            .filter_map(|s| match s {
1131                Some(SessionError::FrameDiscarded(f)) => Some(f),
1132                _ => None,
1133            })
1134            .collect::<Vec<_>>();
1135
1136        let expected_discarded_frames = excluded.into_par_iter().map(|s| s.0).collect::<Vec<_>>();
1137
1138        (discarded_frames, expected_discarded_frames)
1139            .into_par_iter()
1140            .all(|(a, b)| a == b);
1141
1142        Ok(())
1143    }
1144
1145    #[async_std::test]
1146    async fn frame_reassembler_yields_no_frames_when_all_corrupted() -> anyhow::Result<()> {
1147        // Corrupt each frame
1148        let (segments, expected_frames, _) = corrupt_frames(1000, 1.0);
1149        assert!(expected_frames.is_empty());
1150
1151        let (fragmented, reassembled) = FrameReassembler::new(Duration::from_millis(100));
1152
1153        segments.into_par_iter().try_for_each(|s| fragmented.push_segment(s))?;
1154        drop(fragmented);
1155
1156        let reassembled_frames = reassembled.try_collect::<Vec<_>>().await?;
1157
1158        assert!(reassembled_frames.is_empty());
1159
1160        Ok(())
1161    }
1162
1163    fn create_unreliable_segment_stream(
1164        num_frames: usize,
1165        max_latency: Duration,
1166        mixing_factor: f64,
1167        corruption_ratio: f64,
1168    ) -> (impl Stream<Item = Segment>, Vec<&'static Frame>) {
1169        let mut segments = FRAMES
1170            .par_iter()
1171            .take(num_frames)
1172            .flat_map(|f| f.segment(MTU).unwrap())
1173            .collect::<VecDeque<_>>();
1174
1175        let (corrupted_frames, corrupted_segments): (HashSet<FrameId>, HashSet<SegmentId>) = segments
1176            .iter()
1177            .choose_multiple(
1178                &mut thread_rng(),
1179                (segments.len() as f64 * corruption_ratio).round() as usize,
1180            )
1181            .into_par_iter()
1182            .map(|s| (s.frame_id, SegmentId(s.frame_id, s.seq_idx)))
1183            .unzip();
1184
1185        (
1186            stream! {
1187                let mut rng = thread_rng();
1188                let mut distr = Normal::new(0.0, mixing_factor).unwrap();
1189                while !segments.is_empty() {
1190                    let segment = segments.remove(sample_index(&mut distr, &mut rng, segments.len())).unwrap();
1191
1192                    if !corrupted_segments.contains(&SegmentId(segment.frame_id, segment.seq_idx)) {
1193                        async_std::task::sleep(max_latency.mul_f64(rng.gen())).await;
1194                        yield segment;
1195                    }
1196                }
1197            },
1198            FRAMES
1199                .par_iter()
1200                .filter(|f| !corrupted_frames.contains(&f.frame_id))
1201                .collect(),
1202        )
1203    }
1204
1205    #[async_std::test]
1206    async fn frame_reassembler_yields_and_evicts_frames_on_unreliable_network() -> anyhow::Result<()> {
1207        let (fragmented, reassembled) = FrameReassembler::new(Duration::from_millis(25));
1208        let fragmented = Arc::new(fragmented);
1209
1210        let done = Arc::new(AtomicBool::new(false));
1211        let done_clone = done.clone();
1212        let frag_clone = fragmented.clone();
1213        let eviction_jh = async_std::task::spawn(async move {
1214            while !done_clone.load(Ordering::SeqCst) {
1215                async_std::task::sleep(Duration::from_millis(25)).await;
1216                frag_clone.evict().unwrap();
1217            }
1218        });
1219
1220        // Corrupt 20% of the frames
1221        let (stream, expected_frames) =
1222            create_unreliable_segment_stream(200, Duration::from_millis(2), MIXING_FACTOR, 0.2);
1223        stream
1224            .map(Ok)
1225            .try_for_each(|s| futures::future::ready(fragmented.push_segment(s)))
1226            .await?;
1227
1228        done.store(true, Ordering::SeqCst);
1229        eviction_jh.await;
1230        drop(fragmented);
1231
1232        let reassembled_frames = reassembled
1233            .filter(|f| futures::future::ready(f.is_ok())) // Skip the discarded frames
1234            .try_collect::<Vec<_>>()
1235            .await?;
1236        reassembled_frames
1237            .into_iter()
1238            .enumerate()
1239            .for_each(|(i, frame)| assert_eq!(&frame, expected_frames[i]));
1240
1241        Ok(())
1242    }
1243}