1use std::{
42 collections::BinaryHeap,
43 fmt::{Debug, Display, Formatter},
44 mem,
45 ops::{Add, Sub},
46 pin::Pin,
47 sync::{
48 OnceLock,
49 atomic::{AtomicU8, AtomicU32, AtomicU64, Ordering},
50 },
51 task::{Context, Poll},
52 time::{Duration, SystemTime},
53};
54
55use bitvec::{BitArr, bitarr, prelude::Msb0};
56use dashmap::{DashMap, mapref::entry::Entry};
57use futures::{Sink, Stream};
58use hopr_platform::time::native::current_time;
59use hopr_primitive_types::prelude::AsUnixTimestamp;
60
61use crate::{errors::NetworkTypeError, session::errors::SessionError};
62
/// Identifier of a frame. The value 0 is reserved: it is rejected by [`segment`] and by segment decoding.
pub type FrameId = u32;
/// Type of a segment's sequence index and of a frame's segment count; its maximum bounds how many
/// segments a single frame may be split into.
pub type SeqNum = u8;
67
68#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Ord, PartialOrd)]
70#[cfg_attr(feature = "serde", derive(serde::Serialize), derive(serde::Deserialize))]
71pub struct SegmentId(pub FrameId, pub SeqNum);
72
/// Eviction passes that take longer than this many milliseconds are reported with a trace log.
const EVICTION_TIME_THRESHOLD_MS: u64 = 50;
/// Segment pushes that take longer than this many milliseconds are reported with a trace log.
const PUSH_TIME_THRESHOLD_MS: u64 = 50;
75
76impl From<&Segment> for SegmentId {
77 fn from(value: &Segment) -> Self {
78 value.id()
79 }
80}
81
82impl Display for SegmentId {
83 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
84 write!(f, "seg({},{})", self.0, self.1)
85 }
86}
87
88pub fn segment(data: &[u8], max_segment_size: usize, frame_id: u32) -> crate::session::errors::Result<Vec<Segment>> {
91 if frame_id == 0 {
92 return Err(SessionError::InvalidFrameId);
93 }
94
95 if max_segment_size == 0 {
96 return Err(SessionError::InvalidSegmentSize);
97 }
98
99 let num_chunks = data.len().div_ceil(max_segment_size);
100 if num_chunks > SeqNum::MAX as usize {
101 return Err(SessionError::DataTooLong);
102 }
103
104 let chunks = data.chunks(max_segment_size);
105
106 let seq_len = chunks.len() as SeqNum;
107 Ok(chunks
108 .enumerate()
109 .map(|(idx, data)| Segment {
110 frame_id,
111 seq_len,
112 seq_idx: idx as u8,
113 data: data.into(),
114 })
115 .collect())
116}
117
118#[derive(Debug, Clone, PartialEq, Eq)]
122pub struct Frame {
123 pub frame_id: FrameId,
125 pub data: Box<[u8]>,
127}
128
129impl Frame {
130 #[inline]
132 pub fn segment(&self, max_segment_size: usize) -> crate::session::errors::Result<Vec<Segment>> {
133 segment(self.data.as_ref(), max_segment_size, self.frame_id)
134 }
135}
136
137impl AsRef<[u8]> for Frame {
138 fn as_ref(&self) -> &[u8] {
139 &self.data
140 }
141}
142
143#[derive(Clone, Eq, PartialEq)]
148#[cfg_attr(feature = "serde", derive(serde::Serialize), derive(serde::Deserialize))]
149pub struct Segment {
150 pub frame_id: FrameId,
152 pub seq_idx: SeqNum,
154 pub seq_len: SeqNum,
156 #[cfg_attr(feature = "serde", serde(with = "serde_bytes"))]
158 pub data: Box<[u8]>,
159}
160
161impl Segment {
162 pub const HEADER_SIZE: usize = mem::size_of::<FrameId>() + 2 * mem::size_of::<SeqNum>();
164 pub const MINIMUM_SIZE: usize = Self::HEADER_SIZE + 1;
166
167 pub fn id(&self) -> SegmentId {
169 SegmentId(self.frame_id, self.seq_idx)
170 }
171}
172
173impl Debug for Segment {
174 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
175 f.debug_struct("Segment")
176 .field("frame_id", &self.frame_id)
177 .field("seq_id", &self.seq_idx)
178 .field("seq_len", &self.seq_len)
179 .field("data", &hex::encode(&self.data))
180 .finish()
181 }
182}
183
184impl PartialOrd<Segment> for Segment {
185 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
186 Some(self.cmp(other))
187 }
188}
189
190impl Ord for Segment {
191 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
192 match self.frame_id.cmp(&other.frame_id) {
193 std::cmp::Ordering::Equal => self.seq_idx.cmp(&other.seq_idx),
194 cmp => cmp,
195 }
196 }
197}
198
199impl From<Segment> for Vec<u8> {
200 fn from(value: Segment) -> Self {
201 let mut ret = Vec::with_capacity(Segment::HEADER_SIZE + value.data.len());
202 ret.extend_from_slice(value.frame_id.to_be_bytes().as_ref());
203 ret.extend_from_slice(value.seq_idx.to_be_bytes().as_ref());
204 ret.extend_from_slice(value.seq_len.to_be_bytes().as_ref());
205 ret.extend_from_slice(value.data.as_ref());
206 ret
207 }
208}
209
210impl TryFrom<&[u8]> for Segment {
211 type Error = SessionError;
212
213 fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
214 let (header, data) = value.split_at(Self::HEADER_SIZE);
215 let segment = Segment {
216 frame_id: FrameId::from_be_bytes(header[0..4].try_into().map_err(|_| SessionError::InvalidSegment)?),
217 seq_idx: SeqNum::from_be_bytes(header[4..5].try_into().map_err(|_| SessionError::InvalidSegment)?),
218 seq_len: SeqNum::from_be_bytes(header[5..6].try_into().map_err(|_| SessionError::InvalidSegment)?),
219 data: data.into(),
220 };
221 (segment.frame_id > 0 && segment.seq_idx < segment.seq_len)
222 .then_some(segment)
223 .ok_or(SessionError::InvalidSegment)
224 }
225}
226
/// Collects the segments of a single frame until all of them have arrived.
#[derive(Debug)]
struct FrameBuilder {
    /// ID of the frame being built.
    frame_id: FrameId,
    /// Creation time; used to report how long reassembly of this frame took.
    _initiated: std::time::Instant,
    /// One write-once slot per expected segment, indexed by `seq_idx`.
    segments: Vec<OnceLock<Box<[u8]>>>,
    /// Number of slots that are still empty.
    remaining: AtomicU8,
    /// Milliseconds since the UNIX epoch of the latest accepted segment (0 until one is accepted).
    last_ts: AtomicU64,
}
236
237impl FrameBuilder {
238 fn new(initial: Segment, ts: SystemTime) -> Self {
240 let ret = Self::empty(initial.frame_id, initial.seq_len);
241 ret.put(initial, ts).unwrap();
242 ret
243 }
244
245 fn empty(frame_id: FrameId, seq_len: SeqNum) -> Self {
247 Self {
248 frame_id,
249 _initiated: std::time::Instant::now(),
250 segments: vec![OnceLock::new(); seq_len as usize],
251 remaining: AtomicU8::new(seq_len),
252 last_ts: AtomicU64::new(0),
253 }
254 }
255
256 fn put(&self, segment: Segment, ts: SystemTime) -> crate::session::errors::Result<SeqNum> {
259 if self.frame_id == segment.frame_id {
260 if !self.is_complete() {
261 if self.segments[segment.seq_idx as usize].set(segment.data).is_ok() {
262 self.remaining.fetch_sub(1, Ordering::Relaxed);
264 self.last_ts
265 .fetch_max(ts.as_unix_timestamp().as_millis() as u64, Ordering::Relaxed);
266 }
267 Ok(self.remaining.load(Ordering::SeqCst))
268 } else {
269 Ok(0)
271 }
272 } else {
273 Err(SessionError::InvalidFrameId)
274 }
275 }
276
277 fn is_complete(&self) -> bool {
279 self.remaining.load(Ordering::SeqCst) == 0
280 }
281
282 fn is_expired(&self, cutoff: u64) -> bool {
286 self.last_ts.load(Ordering::SeqCst) < cutoff
287 }
288
289 pub fn info(&self) -> FrameInfo {
291 let mut missing_segments = NO_MISSING_SEGMENTS;
292 self.segments
293 .iter()
294 .enumerate()
295 .take(SeqNum::BITS as usize) .filter_map(|(i, s)| s.get().is_none().then_some(i))
297 .for_each(|i| missing_segments.set(i, true));
298
299 FrameInfo {
300 frame_id: self.frame_id,
301 missing_segments,
302 total_segments: self.segments.len() as SeqNum,
303 last_update: SystemTime::UNIX_EPOCH.add(Duration::from_millis(self.last_ts.load(Ordering::SeqCst))),
304 }
305 }
306
307 fn reassemble(self) -> crate::session::errors::Result<Frame> {
310 if self.is_complete() {
311 Ok(Frame {
312 frame_id: self.frame_id,
313 data: self
314 .segments
315 .into_iter()
316 .map(|lock| lock.into_inner().unwrap())
317 .collect::<Vec<Box<[u8]>>>()
318 .concat()
319 .into_boxed_slice(),
320 })
321 } else {
322 Err(SessionError::IncompleteFrame(self.frame_id))
323 }
324 }
325}
326
/// Bitmap of missing segment indices, stored in a single `SeqNum` with MSB-first bit order.
///
/// Bit `i` set means the segment with sequence index `i` has not been received.
pub type MissingSegmentsBitmap = BitArr!(for 1, in SeqNum, Msb0);
/// A bitmap with all `SeqNum::BITS` bits cleared, i.e. no missing segments.
pub const NO_MISSING_SEGMENTS: MissingSegmentsBitmap = bitarr![SeqNum, Msb0; 0; SeqNum::BITS as usize];
334
335#[derive(Debug, Clone, PartialEq, Eq)]
338pub struct FrameInfo {
339 pub frame_id: FrameId,
341 pub missing_segments: MissingSegmentsBitmap,
343 pub total_segments: SeqNum,
345 pub last_update: SystemTime,
347}
348
349impl FrameInfo {
350 pub fn iter_missing_sequence_indices(&self) -> impl Iterator<Item = SeqNum> + '_ {
352 self.missing_segments
353 .iter()
354 .by_vals()
355 .enumerate()
356 .filter(|(i, s)| *s && *i <= SeqNum::MAX as usize)
357 .map(|(s, _)| s as SeqNum)
358 }
359
360 pub fn into_missing_segments(self) -> impl Iterator<Item = SegmentId> {
361 self.missing_segments
362 .into_iter()
363 .enumerate()
364 .filter(|(i, s)| *s && *i <= SeqNum::MAX as usize)
365 .map(move |(i, _)| SegmentId(self.frame_id, i as SeqNum))
366 }
367}
368
369impl PartialOrd for FrameInfo {
370 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
371 Some(self.cmp(other))
372 }
373}
374
375impl Ord for FrameInfo {
376 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
377 match self.last_update.cmp(&other.last_update) {
378 std::cmp::Ordering::Equal => self.frame_id.cmp(&self.frame_id),
379 cmp => cmp,
380 }
381 .reverse()
382 }
383}
384
/// Reassembles [`Frame`]s from [`Segment`]s that may arrive out of order and from several threads.
///
/// Frames are emitted on the stream returned by [`FrameReassembler::new`] following the
/// `next_emitted_frame` counter, which starts at frame ID 1. Frames that cannot be completed in time
/// are reported as discarded by [`FrameReassembler::evict`].
#[derive(Debug)]
pub struct FrameReassembler {
    /// Builders of frames that have not been emitted yet, keyed by frame ID.
    sequences: DashMap<FrameId, FrameBuilder>,
    /// Highest frame ID ever inserted into `sequences`.
    highest_buffered_frame: AtomicU32,
    /// ID of the next frame to be emitted (or discarded/skipped).
    next_emitted_frame: AtomicU32,
    /// Milliseconds since the UNIX epoch of the last emission or skip; `u64::MAX` until the first one.
    last_emission: AtomicU64,
    /// Output channel for reassembled frames and discard notifications.
    reassembled: futures::channel::mpsc::UnboundedSender<crate::session::errors::Result<Frame>>,
    /// Age after which an incomplete frame, or a frame that never arrived, may be evicted.
    max_age: Duration,
}
437
impl FrameReassembler {
    /// Creates a new reassembler together with the stream on which it emits its results.
    ///
    /// The stream yields `Ok(frame)` for every reassembled frame and
    /// `Err(SessionError::FrameDiscarded(id))` for every frame that was given up on.
    /// `max_age` controls when [`FrameReassembler::evict`] may give up on a frame.
    pub fn new(max_age: Duration) -> (Self, impl Stream<Item = crate::session::errors::Result<Frame>>) {
        let (reassembled, reassembled_recv) = futures::channel::mpsc::unbounded();
        (
            Self {
                sequences: DashMap::new(),
                highest_buffered_frame: AtomicU32::new(0),
                // Frame IDs start at 1; ID 0 is rejected by segmentation and by segment decoding.
                next_emitted_frame: AtomicU32::new(1),
                // `evict` only skips a frame that never arrived when `last_emission < cutoff`, which
                // can never hold for `u64::MAX`.
                // NOTE(review): hence, if frame 1 is lost entirely, it is never skipped and all later
                // frames stay buffered — confirm this is intended.
                last_emission: AtomicU64::new(u64::MAX),
                reassembled,
                max_age,
            },
            reassembled_recv,
        )
    }

    /// Advances `next_emitted_frame` by one and sends `builder`'s frame if it was the expected one and
    /// is complete; otherwise sends `SessionError::FrameDiscarded` for it. Records the emission time.
    ///
    /// # Errors
    /// `SessionError::ReassemblerClosed` if the output channel is closed.
    fn emit_if_complete_discard_otherwise(&self, builder: FrameBuilder) -> crate::session::errors::Result<()> {
        let time_spent = builder._initiated.elapsed();
        let frame_id = builder.frame_id;

        // `fetch_add` always runs (it is the left operand), so the counter advances in both branches.
        if self.next_emitted_frame.fetch_add(1, Ordering::SeqCst) == builder.frame_id && builder.is_complete() {
            self.reassembled
                .unbounded_send(builder.reassemble())
                .map_err(|_| SessionError::ReassemblerClosed)?;
        } else {
            self.reassembled
                .unbounded_send(Err(SessionError::FrameDiscarded(builder.frame_id)))
                .map_err(|_| SessionError::ReassemblerClosed)?;
        }
        self.last_emission
            .store(current_time().as_unix_timestamp().as_millis() as u64, Ordering::Relaxed);

        tracing::trace!(frame_id, ?time_spent, "frame finished");

        Ok(())
    }

    /// Pushes a single segment into the reassembler.
    ///
    /// If the segment completes the frame that is next in line, that frame is sent to the output
    /// stream right away, followed by every already-complete frame directly after it. Otherwise the
    /// segment is buffered until its frame can be emitted or is evicted.
    ///
    /// # Errors
    /// - `SessionError::ReassemblerClosed` if the output channel is closed.
    /// - `SessionError::OldSegment` if the segment's frame ID is below the next expected frame ID
    ///   (this includes frame ID 0).
    /// - any error returned by `FrameBuilder::put` for the segment.
    pub fn push_segment(&self, segment: Segment) -> crate::session::errors::Result<()> {
        if self.reassembled.is_closed() {
            return Err(SessionError::ReassemblerClosed);
        }

        let start = std::time::Instant::now();

        let frame_id = segment.frame_id;
        // The frame was already emitted, discarded or skipped.
        if frame_id < self.next_emitted_frame.load(Ordering::SeqCst) {
            tracing::trace!("trying to push segment of a frame that has been emitted");
            return Err(SessionError::OldSegment(frame_id));
        }

        let ts = current_time();
        let mut cascade = false;

        // NOTE(review): in both branches below, advancing `next_emitted_frame` (CAS) and sending the
        // frame are two separate steps. Another thread can observe the advanced counter and send frame
        // N+1 before this thread's send of frame N completes, so strict emission order under
        // concurrent pushes looks unguaranteed — confirm.
        match self.sequences.entry(frame_id) {
            Entry::Occupied(e) => {
                // Emit only if this segment completed the frame AND it is the next frame in line.
                if e.get().put(segment, ts)? == 0
                    && self
                        .next_emitted_frame
                        .compare_exchange(frame_id, frame_id + 1, Ordering::SeqCst, Ordering::Relaxed)
                        .is_ok()
                {
                    let builder = e.remove();
                    let time_spent = builder._initiated.elapsed();

                    self.reassembled
                        .unbounded_send(builder.reassemble())
                        .map_err(|_| SessionError::ReassemblerClosed)?;
                    self.last_emission
                        .store(current_time().as_unix_timestamp().as_millis() as u64, Ordering::Relaxed);

                    tracing::trace!(frame_id, ?time_spent, "frame finished");

                    cascade = true;
                }
            }
            Entry::Vacant(v) => {
                let builder = FrameBuilder::new(segment, ts);
                // A single-segment frame may be complete right away; emit it without buffering
                // if it is also the next frame in line.
                if builder.is_complete()
                    && self
                        .next_emitted_frame
                        .compare_exchange(frame_id, frame_id + 1, Ordering::SeqCst, Ordering::Relaxed)
                        .is_ok()
                {
                    let time_spent = builder._initiated.elapsed();

                    self.reassembled
                        .unbounded_send(builder.reassemble())
                        .map_err(|_| SessionError::ReassemblerClosed)?;
                    self.last_emission
                        .store(current_time().as_unix_timestamp().as_millis() as u64, Ordering::Relaxed);

                    tracing::trace!(frame_id, ?time_spent, "frame finished");

                    cascade = true;
                } else {
                    v.insert(builder);
                    self.highest_buffered_frame.fetch_max(frame_id, Ordering::Relaxed);
                }
            }
        }

        if cascade {
            // Having emitted a frame, flush every directly following frame that is already complete.
            while let Some((_, builder)) = self
                .sequences
                .remove_if(&self.next_emitted_frame.load(Ordering::SeqCst), |_, b| b.is_complete())
            {
                self.emit_if_complete_discard_otherwise(builder)?;
            }
        }

        let push_time = start.elapsed();
        if push_time > Duration::from_millis(PUSH_TIME_THRESHOLD_MS) {
            tracing::trace!(?push_time, "segment push done");
        }

        Ok(())
    }

    /// Returns information about every not-yet-emitted frame between the next expected frame and the
    /// highest buffered frame ID that is not complete.
    ///
    /// A frame of which no segment was received is reported with `total_segments == 1`, segment 0
    /// missing and `last_update == UNIX_EPOCH`. The heap order is the one defined by `Ord for FrameInfo`.
    pub fn incomplete_frames(&self) -> BinaryHeap<FrameInfo> {
        (self.next_emitted_frame.load(Ordering::SeqCst)..=self.highest_buffered_frame.load(Ordering::SeqCst))
            .filter_map(|frame_id| match self.sequences.get(&frame_id) {
                Some(e) => (!e.is_complete()).then(|| e.info()),
                None => Some({
                    let mut missing_segments = NO_MISSING_SEGMENTS;
                    missing_segments.set(0, true);
                    FrameInfo {
                        frame_id,
                        missing_segments,
                        total_segments: 1,
                        last_update: SystemTime::UNIX_EPOCH,
                    }
                }),
            })
            .collect()
    }

    /// Emits, discards or skips frames at the head of the queue and returns how many were processed.
    ///
    /// With `cutoff = now - max_age`, repeatedly looks at the next expected frame:
    /// - if it is buffered and complete, it is emitted; if it is buffered and its latest segment is
    ///   older than `cutoff`, it is discarded (`SessionError::FrameDiscarded` is sent);
    /// - if no segment of it was ever received and the last emission is older than `cutoff`, its ID
    ///   is skipped without sending anything;
    /// - otherwise the pass stops.
    ///
    /// Every emission or skip resets `last_emission` to now, so within one pass a missing frame can
    /// only be skipped before anything else has been emitted.
    ///
    /// # Errors
    /// `SessionError::ReassemblerClosed` if the output channel is closed.
    pub fn evict(&self) -> crate::session::errors::Result<usize> {
        if self.reassembled.is_closed() {
            return Err(SessionError::ReassemblerClosed);
        }

        if self.sequences.is_empty() {
            return Ok(0);
        }

        let start = std::time::Instant::now();

        let cutoff = current_time().sub(self.max_age).as_unix_timestamp().as_millis() as u64;
        let mut count = 0;
        loop {
            let next = self.next_emitted_frame.load(Ordering::SeqCst);
            if let Some((_, builder)) = self
                .sequences
                .remove_if(&next, |_, b| b.is_complete() || b.is_expired(cutoff))
            {
                self.emit_if_complete_discard_otherwise(builder)?;
                count += 1;
            } else if !self.sequences.contains_key(&next) && self.last_emission.load(Ordering::SeqCst) < cutoff {
                // No segment of `next` ever arrived and nothing was emitted for `max_age`: skip it.
                self.next_emitted_frame.fetch_add(1, Ordering::Relaxed);
                self.last_emission
                    .store(current_time().as_unix_timestamp().as_millis() as u64, Ordering::Relaxed);
                count += 1;
            } else {
                tracing::trace!(incomplete = self.sequences.len(), "incomplete frames in reassembler");
                break;
            }
        }

        let eviction_time = start.elapsed();
        if eviction_time > Duration::from_millis(EVICTION_TIME_THRESHOLD_MS) {
            tracing::trace!(?eviction_time, count, "eviction done");
        }

        Ok(count)
    }

    /// Closes the output channel; subsequent pushes and evictions fail with `ReassemblerClosed`.
    pub fn close(&self) {
        self.reassembled.close_channel();
    }
}
646
647impl Drop for FrameReassembler {
648 fn drop(&mut self) {
649 let _ = self.evict();
650 self.reassembled.close_channel();
651 }
652}
653
654impl Extend<Segment> for FrameReassembler {
655 fn extend<T: IntoIterator<Item = Segment>>(&mut self, iter: T) {
656 iter.into_iter()
657 .try_for_each(|s| self.push_segment(s))
658 .expect("failed to extend")
659 }
660}
661
662impl Sink<Segment> for FrameReassembler {
663 type Error = NetworkTypeError;
664
665 fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
666 Poll::Ready(Ok(()))
667 }
668
669 fn start_send(self: Pin<&mut Self>, item: Segment) -> Result<(), Self::Error> {
670 Ok(self.push_segment(item)?)
671 }
672
673 fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
674 Poll::Ready(self.evict().map(|_| ()).map_err(NetworkTypeError::SessionProtocolError))
675 }
676
677 fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
678 self.reassembled.close_channel();
679 Poll::Ready(Ok(()))
680 }
681}
682
#[cfg(test)]
pub(crate) mod tests {
    use std::{
        collections::{HashSet, VecDeque},
        convert::identity,
        sync::{
            Arc,
            atomic::{AtomicBool, Ordering},
        },
        time::Duration,
    };

    use async_stream::stream;
    use futures::{Stream, StreamExt, TryStreamExt, pin_mut};
    use hex_literal::hex;
    use lazy_static::lazy_static;
    use rand::{
        Rng, SeedableRng,
        prelude::{Distribution, SliceRandom},
        seq::IteratorRandom,
        thread_rng,
    };
    use rand_distr::Normal;
    use rayon::prelude::*;

    use super::*;

    const MTU: usize = 448;
    const FRAME_COUNT: u32 = 65_535;
    const FRAME_SIZE: usize = 3072;
    const MIXING_FACTOR: f64 = 4.0;

    lazy_static! {
        // Seed shared by all deterministic shuffles/selections in this module.
        static ref RAND_SEED: [u8; 32] = hopr_crypto_random::random_bytes();
        // `FRAME_COUNT` random frames with IDs `1..=FRAME_COUNT`.
        static ref FRAMES: Vec<Frame> = (0..FRAME_COUNT)
            .into_par_iter()
            .map(|frame_id| Frame {
                frame_id: frame_id + 1,
                data: hopr_crypto_random::random_bytes::<FRAME_SIZE>().into(),
            })
            .collect::<Vec<_>>();
        // Segments of all `FRAMES`, locally mixed by a half-normal shuffle.
        static ref SEGMENTS: Vec<Segment> = {
            let vec = FRAMES.par_iter().flat_map(|f| f.segment(MTU).unwrap()).collect::<VecDeque<_>>();
            let mut rng = rand::rngs::StdRng::from_seed(*RAND_SEED);
            linear_half_normal_shuffle(&mut rng, vec, MIXING_FACTOR)
        };
    }

    /// Samples an index from `dist`, clamped into `0..len`.
    pub fn sample_index<T: Distribution<f64>, R: Rng>(dist: &mut T, rng: &mut R, len: usize) -> usize {
        let f: f64 = dist.sample(rng);
        (f.max(0.0).round() as usize).min(len - 1)
    }

    /// Shuffles `vec` by repeatedly removing the element at an index drawn from a normal distribution
    /// with standard deviation `factor` (negative draws clamp to the front), so elements mostly move
    /// only a short distance.
    fn linear_half_normal_shuffle<T, R: Rng>(rng: &mut R, mut vec: VecDeque<T>, factor: f64) -> Vec<T> {
        if factor == 0.0 || vec.is_empty() {
            return vec.into();
        }

        let mut dist = Normal::new(0.0, factor).unwrap();
        let mut ret = Vec::new();
        while !vec.is_empty() {
            ret.push(vec.remove(sample_index(&mut dist, rng, vec.len())).unwrap());
        }
        ret
    }

    #[ctor::ctor]
    fn init() {
        lazy_static::initialize(&FRAMES);
        lazy_static::initialize(&SEGMENTS);
    }

    #[test]
    fn segmentation_should_segment_data_correctly() -> anyhow::Result<()> {
        let data = hex!("deadbeefcafebabe");
        let frame = Frame {
            frame_id: 1,
            data: data.as_ref().into(),
        };

        let segments = frame.segment(3)?;
        assert_eq!(3, segments.len());

        assert_eq!(hex!("deadbe"), segments[0].data.as_ref());
        assert_eq!(0, segments[0].seq_idx);
        assert_eq!(3, segments[0].seq_len);
        assert_eq!(frame.frame_id, segments[0].frame_id);

        assert_eq!(hex!("efcafe"), segments[1].data.as_ref());
        assert_eq!(1, segments[1].seq_idx);
        assert_eq!(3, segments[1].seq_len);
        assert_eq!(frame.frame_id, segments[1].frame_id);

        assert_eq!(hex!("babe"), segments[2].data.as_ref());
        assert_eq!(2, segments[2].seq_idx);
        assert_eq!(3, segments[2].seq_len);
        assert_eq!(frame.frame_id, segments[2].frame_id);

        Ok(())
    }

    #[test]
    fn segment_must_serialize_and_deserialize() {
        let data = hopr_crypto_random::random_bytes::<128>();

        let segment = Segment {
            frame_id: 1234,
            seq_len: 123,
            seq_idx: 12,
            data: data.into(),
        };

        let boxed: Vec<u8> = segment.clone().into();
        let recovered: Segment = (&boxed[..]).try_into().unwrap();

        assert_eq!(segment, recovered);
    }

    #[tokio::test]
    async fn frame_reassembler_must_process_ordered_frames() -> anyhow::Result<()> {
        let (fragmented, reassembled) = FrameReassembler::new(Duration::from_secs(30));

        FRAMES
            .iter()
            .flat_map(|f| f.segment(MTU).unwrap())
            .try_for_each(|s| fragmented.push_segment(s))?;

        drop(fragmented);
        let reassembled_frames = reassembled.try_collect::<Vec<_>>().await?;

        reassembled_frames
            .into_par_iter()
            .enumerate()
            .for_each(|(i, frame)| assert_eq!(frame, FRAMES[i]));

        Ok(())
    }

    #[tokio::test]
    async fn frame_reassembler_must_process_single_frame() -> anyhow::Result<()> {
        let (fragmented, reassembled) = FrameReassembler::new(Duration::from_secs(10));

        let data = hex!("cafe");

        let segment = Segment {
            frame_id: 1,
            seq_idx: 0,
            seq_len: 1,
            data: hex!("cafe").into(),
        };

        fragmented.push_segment(segment)?;
        drop(fragmented);
        let mut reassembled_frames = reassembled.try_collect::<Vec<_>>().await?;

        assert_eq!(1, reassembled_frames.len());
        let frame = reassembled_frames.pop().ok_or(SessionError::InvalidSegment)?;

        assert_eq!(1, frame.frame_id);
        assert_eq!(&data, frame.data.as_ref());

        Ok(())
    }

    #[test]
    fn should_not_push_frame_id_0_into_reassembler() -> anyhow::Result<()> {
        let frame = Frame {
            frame_id: 1,
            data: hex!("deadbeefcafe").into(),
        };

        let mut segments = frame.segment(2)?;
        segments[0].frame_id = 0;

        let (fragmented, _reassembled) = FrameReassembler::new(Duration::from_secs(30));
        fragmented
            .push_segment(segments[0].clone())
            .expect_err("must not push frame id 0");

        Ok(())
    }

    #[test]
    fn pushing_segment_of_a_completed_frame_into_reassembler_should_fail() -> anyhow::Result<()> {
        let (fragmented, _reassembled) = FrameReassembler::new(Duration::from_secs(30));

        let segments = FRAMES[0].segment(MTU)?;
        let segment_1 = segments[0].clone();

        segments.into_iter().try_for_each(|s| fragmented.push_segment(s))?;

        fragmented
            .push_segment(segment_1)
            .expect_err("must fail pushing segment of a completed frame");

        Ok(())
    }

    #[tokio::test]
    async fn pushing_segment_of_an_evicted_frame_into_reassembler_should_fail() -> anyhow::Result<()> {
        let (fragmented, _reassembled) = FrameReassembler::new(Duration::from_millis(5));

        let mut segments = FRAMES[0].segment(MTU)?;
        let segment_1 = segments.pop().unwrap();
        segments.into_iter().try_for_each(|s| fragmented.push_segment(s))?;

        tokio::time::sleep(Duration::from_millis(10)).await;
        assert_eq!(1, fragmented.evict()?);

        fragmented
            .push_segment(segment_1)
            .expect_err("must fail pushing segment of an evicted frame");

        Ok(())
    }

    #[tokio::test]
    async fn frame_reassembler_reassembles_single_frame() -> anyhow::Result<()> {
        let (fragmented, reassembled) = FrameReassembler::new(Duration::from_secs(30));

        let mut rng = thread_rng();

        let frame = FRAMES[0].clone();
        let mut segments = frame.segment(MTU)?;
        segments.shuffle(&mut rng);

        segments.into_iter().try_for_each(|s| fragmented.push_segment(s))?;

        drop(fragmented);
        let reassembled_frames = reassembled.try_collect::<Vec<_>>().await?;

        assert_eq!(1, reassembled_frames.len());
        assert_eq!(frame, reassembled_frames[0]);

        Ok(())
    }

    #[tokio::test]
    async fn frame_reassembler_reassembles_shuffled_randomized_frames() -> anyhow::Result<()> {
        let (fragmented, reassembled) = FrameReassembler::new(Duration::from_secs(30));

        SEGMENTS.iter().cloned().try_for_each(|b| fragmented.push_segment(b))?;

        assert_eq!(0, fragmented.evict().unwrap());
        drop(fragmented);

        let reassembled_frames = reassembled.try_collect::<Vec<_>>().await?;

        reassembled_frames
            .into_par_iter()
            .enumerate()
            .for_each(|(i, frame)| assert_eq!(frame, FRAMES[i]));

        Ok(())
    }

    #[tokio::test]
    async fn frame_reassembler_reassembles_shuffled_randomized_frames_in_parallel() -> anyhow::Result<()> {
        let (fragmented, reassembled) = FrameReassembler::new(Duration::from_secs(30));

        SEGMENTS
            .par_iter()
            .cloned()
            .try_for_each(|b| fragmented.push_segment(b))?;

        assert_eq!(0, fragmented.evict()?);
        drop(fragmented);

        let reassembled_frames = reassembled.try_collect::<Vec<_>>().await?;

        reassembled_frames
            .into_par_iter()
            .enumerate()
            .for_each(|(i, frame)| assert_eq!(frame, FRAMES[i]));

        Ok(())
    }

    #[tokio::test]
    async fn frame_reassembler_should_evict_expired_incomplete_frames() -> anyhow::Result<()> {
        let frames = vec![
            Frame {
                frame_id: 1,
                data: hex!("deadbeefcafebabe").into(),
            },
            Frame {
                frame_id: 2,
                data: hex!("feedbeefbaadcafe").into(),
            },
            Frame {
                frame_id: 3,
                data: hex!("00112233abcd").into(),
            },
        ];

        let mut segments = frames
            .iter()
            .flat_map(|f| f.segment(3).unwrap())
            .collect::<VecDeque<_>>();
        // Drop the last segment of frame 2, so that it can never be completed.
        segments.retain(|s| s.frame_id != 2 || s.seq_idx != 2);

        let (fragmented, reassembled) = FrameReassembler::new(Duration::from_millis(10));

        segments.into_iter().try_for_each(|b| fragmented.push_segment(b))?;

        let frames_cpy = frames.clone();
        let jh: hopr_async_runtime::prelude::JoinHandle<Result<(), SessionError>> = tokio::task::spawn(async move {
            pin_mut!(reassembled);

            assert_eq!(Some(frames_cpy[0].clone()), reassembled.try_next().await?);

            assert!(matches!(
                reassembled.try_next().await,
                Err(SessionError::FrameDiscarded(2))
            ));

            assert_eq!(Some(frames_cpy[2].clone()), reassembled.try_next().await?);

            Ok(())
        });

        tokio::time::sleep(Duration::from_millis(20)).await;

        assert_eq!(2, fragmented.evict()?);
        jh.await??;

        Ok(())
    }

    #[tokio::test]
    async fn frame_reassembler_should_evict_frame_that_never_arrived() -> anyhow::Result<()> {
        let frames = vec![
            Frame {
                frame_id: 1,
                data: hex!("deadbeefcafebabe").into(),
            },
            Frame {
                frame_id: 3,
                data: hex!("00112233abcd").into(),
            },
        ];

        let segments = frames
            .iter()
            .flat_map(|f| f.segment(3).unwrap())
            .collect::<VecDeque<_>>();

        let (fragmented, reassembled) = FrameReassembler::new(Duration::from_millis(10));

        segments.into_iter().try_for_each(|b| fragmented.push_segment(b))?;

        let flushed = Arc::new(AtomicBool::new(false));

        let flushed_cpy = flushed.clone();
        let frames_cpy = frames.clone();
        let jh: hopr_async_runtime::prelude::JoinHandle<Result<(), SessionError>> = tokio::task::spawn(async move {
            pin_mut!(reassembled);

            assert_eq!(Some(frames_cpy[0].clone()), reassembled.try_next().await?);

            assert!(!flushed_cpy.load(Ordering::SeqCst));

            assert_eq!(Some(frames_cpy[1].clone()), reassembled.try_next().await?);

            assert!(flushed_cpy.load(Ordering::SeqCst));

            Ok(())
        });

        tokio::time::sleep(Duration::from_millis(20)).await;

        flushed.store(true, Ordering::SeqCst);
        assert_eq!(2, fragmented.evict()?);
        jh.await??;

        Ok(())
    }

    #[tokio::test]
    async fn frame_reassembler_reassembles_randomized_delayed_frames_in_parallel() -> anyhow::Result<()> {
        let frames = FRAMES.iter().take(100).collect::<Vec<_>>();

        let segments = frames
            .iter()
            .flat_map(|frame| frame.segment(MTU).unwrap())
            .collect::<Vec<_>>();

        let (fragmented, reassembled) = FrameReassembler::new(Duration::from_secs(30));

        futures::stream::iter(segments)
            .map(|segment| {
                let delay = Duration::from_millis(thread_rng().gen_range(0..10u64));
                tokio::task::spawn(async move {
                    tokio::time::sleep(delay).await;
                    Ok(segment)
                })
            })
            .buffer_unordered(4)
            .map(Result::unwrap)
            .forward(fragmented)
            .await
            .unwrap();

        let reassembled_frames = reassembled.try_collect::<Vec<_>>().await?;

        reassembled_frames
            .into_par_iter()
            .enumerate()
            .for_each(|(i, frame)| assert_eq!(&frame, frames[i]));

        Ok(())
    }

    /// Takes the segments of frames with ID below `num_frames` and removes one random segment from a
    /// `corrupted_ratio` fraction of frame IDs in `1..=num_frames`.
    ///
    /// Returns the remaining segments, the frames expected to be fully reassembled, and the removed
    /// segment IDs.
    fn corrupt_frames(
        num_frames: u32,
        corrupted_ratio: f32,
    ) -> (Vec<Segment>, Vec<&'static Frame>, HashSet<SegmentId>) {
        assert!((0.0..=1.0).contains(&corrupted_ratio));

        let mut rng = rand::rngs::StdRng::from_seed(*RAND_SEED);

        let (excluded_frame_ids, excluded_segments): (HashSet<FrameId>, HashSet<SegmentId>) = (1..num_frames + 1)
            .choose_multiple(&mut rng, ((num_frames as f32) * corrupted_ratio) as usize)
            .into_iter()
            .map(|frame_id| {
                (
                    frame_id,
                    SegmentId(
                        frame_id,
                        rng.gen_range(0..SEGMENTS.iter().find(|s| s.frame_id == frame_id).unwrap().seq_len),
                    ),
                )
            })
            .unzip();

        let segments = SEGMENTS
            .par_iter()
            .filter(|s| s.frame_id < num_frames && !excluded_segments.contains(&SegmentId(s.frame_id, s.seq_idx)))
            .cloned()
            .collect::<Vec<_>>();

        let expected_frames = FRAMES
            .par_iter()
            .filter(|f| f.frame_id < num_frames && !excluded_frame_ids.contains(&f.frame_id))
            .collect::<Vec<_>>();

        (segments, expected_frames, excluded_segments)
    }

    #[tokio::test]
    async fn frame_reassembler_yields_correct_frames_when_also_corrupted_frames_are_present() -> anyhow::Result<()> {
        let num_frames = FRAME_COUNT / 4;
        let (segments, expected_frames, excluded) = corrupt_frames(num_frames, 0.3);

        let (fragmented, reassembled) = FrameReassembler::new(Duration::from_millis(25));

        segments.into_iter().try_for_each(|s| fragmented.push_segment(s))?;

        let computed_missing = fragmented
            .incomplete_frames()
            .into_par_iter()
            .flat_map_iter(|e| e.into_missing_segments())
            .collect::<HashSet<_>>();

        assert!(computed_missing.par_iter().all(|s| excluded.contains(s)));

        // Sleep well past `max_age` so that the eviction performed on drop treats every incomplete
        // frame as expired (timestamps only have millisecond granularity).
        tokio::time::sleep(Duration::from_millis(50)).await;
        drop(fragmented);

        let (reassembled_frames, discarded_frames) = reassembled
            .map(|f| match f {
                Ok(f) => (Some(f), None),
                Err(e) => (None, Some(e)),
            })
            .unzip::<_, _, Vec<_>, Vec<_>>()
            .await;

        let reassembled_frames = reassembled_frames
            .into_par_iter()
            .filter_map(identity)
            .collect::<Vec<_>>();

        assert_eq!(expected_frames.len(), reassembled_frames.len());
        assert!(
            (reassembled_frames, expected_frames)
                .into_par_iter()
                .all(|(a, b)| a.eq(b)),
            "reassembled frames differ from the expected frames"
        );

        let mut discarded_frames = discarded_frames
            .into_par_iter()
            .filter_map(|s| match s {
                Some(SessionError::FrameDiscarded(f)) => Some(f),
                _ => None,
            })
            .collect::<Vec<_>>();
        discarded_frames.sort_unstable();

        // Only frames with an ID below `num_frames` were pushed, so only those can be discarded.
        let mut expected_discarded_frames = excluded
            .into_par_iter()
            .map(|s| s.0)
            .filter(|f| *f < num_frames)
            .collect::<Vec<_>>();
        expected_discarded_frames.sort_unstable();

        assert_eq!(expected_discarded_frames, discarded_frames);

        Ok(())
    }

    #[tokio::test]
    async fn frame_reassembler_yields_no_frames_when_all_corrupted() -> anyhow::Result<()> {
        let (segments, expected_frames, _) = corrupt_frames(1000, 1.0);
        assert!(expected_frames.is_empty());

        let (fragmented, reassembled) = FrameReassembler::new(Duration::from_millis(100));

        segments.into_par_iter().try_for_each(|s| fragmented.push_segment(s))?;
        drop(fragmented);

        let reassembled_frames = reassembled.try_collect::<Vec<_>>().await?;

        assert!(reassembled_frames.is_empty());

        Ok(())
    }

    /// Builds a stream of the segments of the first `num_frames` frames, locally shuffled, delayed by
    /// up to `max_latency` each, with a `corruption_ratio` fraction of segments dropped.
    ///
    /// Also returns the frames (among those first `num_frames`) that lost no segment.
    fn create_unreliable_segment_stream(
        num_frames: usize,
        max_latency: Duration,
        mixing_factor: f64,
        corruption_ratio: f64,
    ) -> (impl Stream<Item = Segment>, Vec<&'static Frame>) {
        let mut segments = FRAMES
            .par_iter()
            .take(num_frames)
            .flat_map(|f| f.segment(MTU).unwrap())
            .collect::<VecDeque<_>>();

        let (corrupted_frames, corrupted_segments): (HashSet<FrameId>, HashSet<SegmentId>) = segments
            .iter()
            .choose_multiple(
                &mut thread_rng(),
                (segments.len() as f64 * corruption_ratio).round() as usize,
            )
            .into_par_iter()
            .map(|s| (s.frame_id, SegmentId(s.frame_id, s.seq_idx)))
            .unzip();

        (
            stream! {
                let mut rng = thread_rng();
                let mut distr = Normal::new(0.0, mixing_factor).unwrap();
                while !segments.is_empty() {
                    let segment = segments.remove(sample_index(&mut distr, &mut rng, segments.len())).unwrap();

                    if !corrupted_segments.contains(&SegmentId(segment.frame_id, segment.seq_idx)) {
                        tokio::time::sleep(max_latency.mul_f64(rng.gen())).await;
                        yield segment;
                    }
                }
            },
            // Only the frames whose segments are actually streamed can be expected.
            FRAMES
                .par_iter()
                .take(num_frames)
                .filter(|f| !corrupted_frames.contains(&f.frame_id))
                .collect(),
        )
    }

    #[tokio::test]
    async fn frame_reassembler_yields_and_evicts_frames_on_unreliable_network() -> anyhow::Result<()> {
        let (fragmented, reassembled) = FrameReassembler::new(Duration::from_millis(25));
        let fragmented = Arc::new(fragmented);

        let done = Arc::new(AtomicBool::new(false));
        let done_clone = done.clone();
        let frag_clone = fragmented.clone();
        let eviction_jh = tokio::task::spawn(async move {
            while !done_clone.load(Ordering::SeqCst) {
                tokio::time::sleep(Duration::from_millis(25)).await;
                frag_clone.evict().unwrap();
            }
        });

        let (stream, expected_frames) =
            create_unreliable_segment_stream(200, Duration::from_millis(2), MIXING_FACTOR, 0.2);
        stream
            .map(Ok)
            .try_for_each(|s| futures::future::ready(fragmented.push_segment(s)))
            .await?;

        done.store(true, Ordering::SeqCst);
        eviction_jh.await?;
        drop(fragmented);

        let reassembled_frames = reassembled
            .filter(|f| futures::future::ready(f.is_ok()))
            .try_collect::<Vec<_>>()
            .await?;
        reassembled_frames
            .into_iter()
            .enumerate()
            .for_each(|(i, frame)| assert_eq!(&frame, expected_frames[i]));

        Ok(())
    }
}