1use bitvec::array::BitArray;
42use bitvec::{bitarr, BitArr};
43use dashmap::mapref::entry::Entry;
44use dashmap::DashMap;
45use futures::{Sink, Stream};
46use std::collections::BinaryHeap;
47use std::fmt::{Debug, Display, Formatter};
48use std::mem;
49use std::ops::{Add, Sub};
50use std::pin::Pin;
51use std::sync::atomic::{AtomicU32, AtomicU64, AtomicU8, Ordering};
52use std::sync::OnceLock;
53use std::task::{Context, Poll};
54use std::time::{Duration, SystemTime};
55
56use hopr_platform::time::native::current_time;
57use hopr_primitive_types::prelude::AsUnixTimestamp;
58
59use crate::errors::NetworkTypeError;
60use crate::session::errors::SessionError;
61
/// Identifier of a [`Frame`]. The value `0` is reserved and rejected as invalid.
pub type FrameId = u32;
/// Position of a [`Segment`] within its frame; also used for the total segment count.
pub type SeqNum = u8;
66
67#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Ord, PartialOrd)]
69#[cfg_attr(feature = "serde", derive(serde::Serialize), derive(serde::Deserialize))]
70pub struct SegmentId(pub FrameId, pub SeqNum);
71
72impl From<&Segment> for SegmentId {
73 fn from(value: &Segment) -> Self {
74 value.id()
75 }
76}
77
78impl Display for SegmentId {
79 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
80 write!(f, "seg({},{})", self.0, self.1)
81 }
82}
83
84pub fn segment(data: &[u8], max_segment_size: usize, frame_id: u32) -> crate::session::errors::Result<Vec<Segment>> {
87 if frame_id == 0 {
88 return Err(SessionError::InvalidFrameId);
89 }
90
91 if max_segment_size == 0 {
92 return Err(SessionError::InvalidSegmentSize);
93 }
94
95 let num_chunks = data.len().div_ceil(max_segment_size);
96 if num_chunks > SeqNum::MAX as usize {
97 return Err(SessionError::DataTooLong);
98 }
99
100 let chunks = data.chunks(max_segment_size);
101
102 let seq_len = chunks.len() as SeqNum;
103 Ok(chunks
104 .enumerate()
105 .map(|(idx, data)| Segment {
106 frame_id,
107 seq_len,
108 seq_idx: idx as u8,
109 data: data.into(),
110 })
111 .collect())
112}
113
114#[derive(Debug, Clone, PartialEq, Eq)]
118pub struct Frame {
119 pub frame_id: FrameId,
121 pub data: Box<[u8]>,
123}
124
125impl Frame {
126 #[inline]
128 pub fn segment(&self, max_segment_size: usize) -> crate::session::errors::Result<Vec<Segment>> {
129 segment(self.data.as_ref(), max_segment_size, self.frame_id)
130 }
131}
132
133impl AsRef<[u8]> for Frame {
134 fn as_ref(&self) -> &[u8] {
135 &self.data
136 }
137}
138
139#[derive(Clone, Eq, PartialEq)]
144#[cfg_attr(feature = "serde", derive(serde::Serialize), derive(serde::Deserialize))]
145pub struct Segment {
146 pub frame_id: FrameId,
148 pub seq_idx: SeqNum,
150 pub seq_len: SeqNum,
152 #[cfg_attr(feature = "serde", serde(with = "serde_bytes"))]
154 pub data: Box<[u8]>,
155}
156
157impl Segment {
158 pub const HEADER_SIZE: usize = mem::size_of::<FrameId>() + 2 * mem::size_of::<SeqNum>();
160
161 pub const MINIMUM_SIZE: usize = Self::HEADER_SIZE + 1;
163
164 pub fn id(&self) -> SegmentId {
166 SegmentId(self.frame_id, self.seq_idx)
167 }
168}
169
170impl Debug for Segment {
171 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
172 f.debug_struct("Segment")
173 .field("frame_id", &self.frame_id)
174 .field("seq_id", &self.seq_idx)
175 .field("seq_len", &self.seq_len)
176 .field("data", &hex::encode(&self.data))
177 .finish()
178 }
179}
180
181impl PartialOrd<Segment> for Segment {
182 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
183 Some(self.cmp(other))
184 }
185}
186
187impl Ord for Segment {
188 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
189 match self.frame_id.cmp(&other.frame_id) {
190 std::cmp::Ordering::Equal => self.seq_idx.cmp(&other.seq_idx),
191 cmp => cmp,
192 }
193 }
194}
195
196impl From<Segment> for Vec<u8> {
197 fn from(value: Segment) -> Self {
198 let mut ret = Vec::with_capacity(Segment::HEADER_SIZE + value.data.len());
199 ret.extend_from_slice(value.frame_id.to_be_bytes().as_ref());
200 ret.extend_from_slice(value.seq_idx.to_be_bytes().as_ref());
201 ret.extend_from_slice(value.seq_len.to_be_bytes().as_ref());
202 ret.extend_from_slice(value.data.as_ref());
203 ret
204 }
205}
206
207impl TryFrom<&[u8]> for Segment {
208 type Error = SessionError;
209
210 fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
211 let (header, data) = value.split_at(Self::HEADER_SIZE);
212 let segment = Segment {
213 frame_id: FrameId::from_be_bytes(header[0..4].try_into().map_err(|_| SessionError::InvalidSegment)?),
214 seq_idx: SeqNum::from_be_bytes(header[4..5].try_into().map_err(|_| SessionError::InvalidSegment)?),
215 seq_len: SeqNum::from_be_bytes(header[5..6].try_into().map_err(|_| SessionError::InvalidSegment)?),
216 data: data.into(),
217 };
218 (segment.frame_id > 0 && segment.seq_idx < segment.seq_len)
219 .then_some(segment)
220 .ok_or(SessionError::InvalidSegment)
221 }
222}
223
224#[derive(Debug)]
226struct FrameBuilder {
227 frame_id: FrameId,
228 segments: Vec<OnceLock<Box<[u8]>>>,
229 remaining: AtomicU8,
230 last_ts: AtomicU64,
231}
232
233impl FrameBuilder {
234 fn new(initial: Segment, ts: SystemTime) -> Self {
236 let ret = Self::empty(initial.frame_id, initial.seq_len);
237 ret.put(initial, ts).unwrap();
238 ret
239 }
240
241 fn empty(frame_id: FrameId, seq_len: SeqNum) -> Self {
243 Self {
244 frame_id,
245 segments: vec![OnceLock::new(); seq_len as usize],
246 remaining: AtomicU8::new(seq_len),
247 last_ts: AtomicU64::new(0),
248 }
249 }
250
251 fn put(&self, segment: Segment, ts: SystemTime) -> crate::session::errors::Result<SeqNum> {
254 if self.frame_id == segment.frame_id {
255 if !self.is_complete() {
256 if self.segments[segment.seq_idx as usize].set(segment.data).is_ok() {
257 self.remaining.fetch_sub(1, Ordering::Relaxed);
259 self.last_ts
260 .fetch_max(ts.as_unix_timestamp().as_millis() as u64, Ordering::Relaxed);
261 }
262 Ok(self.remaining.load(Ordering::SeqCst))
263 } else {
264 Ok(0)
266 }
267 } else {
268 Err(SessionError::InvalidFrameId)
269 }
270 }
271
272 fn is_complete(&self) -> bool {
274 self.remaining.load(Ordering::SeqCst) == 0
275 }
276
277 fn is_expired(&self, cutoff: u64) -> bool {
281 self.last_ts.load(Ordering::SeqCst) < cutoff
282 }
283
284 pub fn info(&self) -> FrameInfo {
286 let mut missing_segments = bitarr![0; 256];
287 self.segments
288 .iter()
289 .enumerate()
290 .filter_map(|(i, s)| s.get().is_none().then_some(i))
291 .for_each(|i| missing_segments.set(i, true));
292
293 FrameInfo {
294 frame_id: self.frame_id,
295 missing_segments,
296 total_segments: self.segments.len() as SeqNum,
297 last_update: SystemTime::UNIX_EPOCH.add(Duration::from_millis(self.last_ts.load(Ordering::SeqCst))),
298 }
299 }
300
301 fn reassemble(self) -> crate::session::errors::Result<Frame> {
303 if self.is_complete() {
304 Ok(Frame {
305 frame_id: self.frame_id,
306 data: self
307 .segments
308 .into_iter()
309 .map(|lock| lock.into_inner().unwrap())
310 .collect::<Vec<Box<[u8]>>>()
311 .concat()
312 .into_boxed_slice(),
313 })
314 } else {
315 Err(SessionError::IncompleteFrame(self.frame_id))
316 }
317 }
318}
319
320#[derive(Debug, Clone, PartialEq, Eq)]
323pub struct FrameInfo {
324 pub frame_id: FrameId,
326 pub missing_segments: BitArr!(for 256),
328 pub total_segments: SeqNum,
330 pub last_update: SystemTime,
332}
333
334impl FrameInfo {
335 pub fn iter_missing_sequence_indices(&self) -> impl Iterator<Item = SeqNum> + '_ {
337 self.missing_segments
338 .iter()
339 .by_vals()
340 .enumerate()
341 .filter(|(i, s)| *s && *i <= SeqNum::MAX as usize)
342 .map(|(s, _)| s as SeqNum)
343 }
344
345 pub fn into_missing_segments(self) -> impl Iterator<Item = SegmentId> {
346 self.missing_segments
347 .into_iter()
348 .enumerate()
349 .filter(|(i, s)| *s && *i <= SeqNum::MAX as usize)
350 .map(move |(i, _)| SegmentId(self.frame_id, i as SeqNum))
351 }
352}
353
354impl PartialOrd for FrameInfo {
355 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
356 Some(self.cmp(other))
357 }
358}
359
360impl Ord for FrameInfo {
361 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
362 match self.last_update.cmp(&other.last_update) {
363 std::cmp::Ordering::Equal => self.frame_id.cmp(&self.frame_id),
364 cmp => cmp,
365 }
366 .reverse()
367 }
368}
369
370#[derive(Debug)]
414pub struct FrameReassembler {
415 sequences: DashMap<FrameId, FrameBuilder>,
416 highest_buffered_frame: AtomicU32,
417 next_emitted_frame: AtomicU32,
418 last_emission: AtomicU64,
419 reassembled: futures::channel::mpsc::UnboundedSender<crate::session::errors::Result<Frame>>,
420 max_age: Duration,
421}
422
423impl FrameReassembler {
424 pub fn new(max_age: Duration) -> (Self, impl Stream<Item = crate::session::errors::Result<Frame>>) {
429 let (reassembled, reassembled_recv) = futures::channel::mpsc::unbounded();
430 (
431 Self {
432 sequences: DashMap::new(),
433 highest_buffered_frame: AtomicU32::new(0),
434 next_emitted_frame: AtomicU32::new(1),
435 last_emission: AtomicU64::new(u64::MAX),
436 reassembled,
437 max_age,
438 },
439 reassembled_recv,
440 )
441 }
442
443 fn emit_if_complete_discard_otherwise(&self, builder: FrameBuilder) -> crate::session::errors::Result<()> {
446 if self.next_emitted_frame.fetch_add(1, Ordering::SeqCst) == builder.frame_id && builder.is_complete() {
447 self.reassembled
448 .unbounded_send(builder.reassemble())
449 .map_err(|_| SessionError::ReassemblerClosed)?;
450 } else {
451 self.reassembled
452 .unbounded_send(Err(SessionError::FrameDiscarded(builder.frame_id)))
453 .map_err(|_| SessionError::ReassemblerClosed)?;
454 }
455 self.last_emission
456 .store(current_time().as_unix_timestamp().as_millis() as u64, Ordering::Relaxed);
457 Ok(())
458 }
459
460 pub fn push_segment(&self, segment: Segment) -> crate::session::errors::Result<()> {
465 if self.reassembled.is_closed() {
466 return Err(SessionError::ReassemblerClosed);
467 }
468
469 let frame_id = segment.frame_id;
471 if frame_id < self.next_emitted_frame.load(Ordering::SeqCst) {
472 return Err(SessionError::OldSegment(frame_id));
473 }
474
475 let ts = current_time();
476 let mut cascade = false;
477
478 match self.sequences.entry(frame_id) {
479 Entry::Occupied(e) => {
480 if e.get().put(segment, ts)? == 0
482 && self
483 .next_emitted_frame
484 .compare_exchange(frame_id, frame_id + 1, Ordering::SeqCst, Ordering::Relaxed)
485 .is_ok()
486 {
487 self.reassembled
489 .unbounded_send(e.remove().reassemble())
490 .map_err(|_| SessionError::ReassemblerClosed)?;
491 self.last_emission
492 .store(current_time().as_unix_timestamp().as_millis() as u64, Ordering::Relaxed);
493 cascade = true; }
495 }
496 Entry::Vacant(v) => {
497 let builder = FrameBuilder::new(segment, ts);
498 if builder.is_complete()
500 && self
501 .next_emitted_frame
502 .compare_exchange(frame_id, frame_id + 1, Ordering::SeqCst, Ordering::Relaxed)
503 .is_ok()
504 {
505 self.reassembled
507 .unbounded_send(builder.reassemble())
508 .map_err(|_| SessionError::ReassemblerClosed)?;
509 self.last_emission
510 .store(current_time().as_unix_timestamp().as_millis() as u64, Ordering::Relaxed);
511 cascade = true; } else {
513 v.insert(builder);
515 self.highest_buffered_frame.fetch_max(frame_id, Ordering::Relaxed);
516 }
517 }
518 }
519
520 if cascade {
522 while let Some((_, builder)) = self
523 .sequences
524 .remove_if(&self.next_emitted_frame.load(Ordering::SeqCst), |_, b| b.is_complete())
525 {
526 self.emit_if_complete_discard_otherwise(builder)?;
528 }
529 }
530
531 Ok(())
532 }
533
534 pub fn incomplete_frames(&self) -> BinaryHeap<FrameInfo> {
537 (self.next_emitted_frame.load(Ordering::SeqCst)..=self.highest_buffered_frame.load(Ordering::SeqCst))
538 .filter_map(|frame_id| match self.sequences.get(&frame_id) {
539 Some(e) => (!e.is_complete()).then(|| e.info()),
540 None => Some({
541 let mut missing_segments = BitArray::ZERO;
542 missing_segments.set(0, true);
543 FrameInfo {
544 frame_id,
545 missing_segments,
546 total_segments: 1,
547 last_update: SystemTime::UNIX_EPOCH,
548 }
549 }),
550 })
551 .collect()
552 }
553
554 pub fn evict(&self) -> crate::session::errors::Result<usize> {
558 if self.reassembled.is_closed() {
559 return Err(SessionError::ReassemblerClosed);
560 }
561
562 if self.sequences.is_empty() {
563 return Ok(0);
564 }
565
566 let cutoff = current_time().sub(self.max_age).as_unix_timestamp().as_millis() as u64;
567 let mut count = 0;
568 loop {
569 let next = self.next_emitted_frame.load(Ordering::SeqCst);
570 if let Some((_, builder)) = self
571 .sequences
572 .remove_if(&next, |_, b| b.is_complete() || b.is_expired(cutoff))
573 {
574 self.emit_if_complete_discard_otherwise(builder)?;
576 count += 1;
577 } else if !self.sequences.contains_key(&next) && self.last_emission.load(Ordering::SeqCst) < cutoff {
578 self.next_emitted_frame.fetch_add(1, Ordering::Relaxed);
580 self.last_emission
581 .store(current_time().as_unix_timestamp().as_millis() as u64, Ordering::Relaxed);
582 count += 1;
583 } else {
584 break;
586 }
587 }
588
589 Ok(count)
590 }
591
592 pub fn close(&self) {
595 self.reassembled.close_channel();
596 }
597}
598
599impl Drop for FrameReassembler {
600 fn drop(&mut self) {
601 let _ = self.evict();
602 self.reassembled.close_channel();
603 }
604}
605
606impl Extend<Segment> for FrameReassembler {
607 fn extend<T: IntoIterator<Item = Segment>>(&mut self, iter: T) {
608 iter.into_iter()
609 .try_for_each(|s| self.push_segment(s))
610 .expect("failed to extend")
611 }
612}
613
614impl Sink<Segment> for FrameReassembler {
615 type Error = NetworkTypeError;
616
617 fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
618 Poll::Ready(Ok(()))
619 }
620
621 fn start_send(self: Pin<&mut Self>, item: Segment) -> Result<(), Self::Error> {
622 Ok(self.push_segment(item)?)
623 }
624
625 fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
626 Poll::Ready(self.evict().map(|_| ()).map_err(NetworkTypeError::SessionProtocolError))
627 }
628
629 fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
630 self.reassembled.close_channel();
631 Poll::Ready(Ok(()))
632 }
633}
634
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use async_stream::stream;
    use futures::{pin_mut, Stream, StreamExt, TryStreamExt};
    use hex_literal::hex;
    use lazy_static::lazy_static;
    use rand::prelude::{Distribution, SliceRandom};
    use rand::{seq::IteratorRandom, thread_rng, Rng, SeedableRng};
    use rand_distr::Normal;
    use rayon::prelude::*;
    use std::collections::{HashSet, VecDeque};
    use std::convert::identity;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;
    use std::time::Duration;

    const MTU: usize = 448;
    const FRAME_COUNT: u32 = 65_535;
    const FRAME_SIZE: usize = 4096;
    const MIXING_FACTOR: f64 = 4.0;

    lazy_static! {
        // Seed shared by all deterministic shuffles in this module.
        static ref RAND_SEED: [u8; 32] = hopr_crypto_random::random_bytes();
        // Random frames with IDs 1..=FRAME_COUNT.
        static ref FRAMES: Vec<Frame> = (0..FRAME_COUNT)
            .into_par_iter()
            .map(|frame_id| Frame {
                frame_id: frame_id + 1,
                data: hopr_crypto_random::random_bytes::<FRAME_SIZE>().into(),
            })
            .collect::<Vec<_>>();
        // Segments of all FRAMES, locally shuffled.
        static ref SEGMENTS: Vec<Segment> = {
            let vec = FRAMES.par_iter().flat_map(|f| f.segment(MTU).unwrap()).collect::<VecDeque<_>>();
            let mut rng = rand::rngs::StdRng::from_seed(*RAND_SEED);
            linear_half_normal_shuffle(&mut rng, vec, MIXING_FACTOR)
        };
    }

    /// Samples `dist` and clamps the rounded result into `0..len`.
    pub fn sample_index<T: Distribution<f64>, R: Rng>(dist: &mut T, rng: &mut R, len: usize) -> usize {
        let f: f64 = dist.sample(rng);
        (f.max(0.0).round() as usize).min(len - 1)
    }

    /// Shuffles `vec` so that elements only move a short, half-normally distributed distance.
    fn linear_half_normal_shuffle<T, R: Rng>(rng: &mut R, mut vec: VecDeque<T>, factor: f64) -> Vec<T> {
        if factor == 0.0 || vec.is_empty() {
            return vec.into();
        }

        let mut dist = Normal::new(0.0, factor).unwrap();
        let mut ret = Vec::new();
        while !vec.is_empty() {
            ret.push(vec.remove(sample_index(&mut dist, rng, vec.len())).unwrap());
        }
        ret
    }

    #[ctor::ctor]
    fn init() {
        lazy_static::initialize(&FRAMES);
        lazy_static::initialize(&SEGMENTS);
    }

    #[test]
    fn segmentation_should_segment_data_correctly() -> anyhow::Result<()> {
        let data = hex!("deadbeefcafebabe");
        let frame = Frame {
            frame_id: 1,
            data: data.as_ref().into(),
        };

        let segments = frame.segment(3)?;
        assert_eq!(3, segments.len());

        assert_eq!(hex!("deadbe"), segments[0].data.as_ref());
        assert_eq!(0, segments[0].seq_idx);
        assert_eq!(3, segments[0].seq_len);
        assert_eq!(frame.frame_id, segments[0].frame_id);

        assert_eq!(hex!("efcafe"), segments[1].data.as_ref());
        assert_eq!(1, segments[1].seq_idx);
        assert_eq!(3, segments[1].seq_len);
        assert_eq!(frame.frame_id, segments[1].frame_id);

        assert_eq!(hex!("babe"), segments[2].data.as_ref());
        assert_eq!(2, segments[2].seq_idx);
        assert_eq!(3, segments[2].seq_len);
        assert_eq!(frame.frame_id, segments[2].frame_id);

        Ok(())
    }

    #[test]
    fn segment_must_serialize_and_deserialize() {
        let data = hopr_crypto_random::random_bytes::<128>();

        let segment = Segment {
            frame_id: 1234,
            seq_len: 123,
            seq_idx: 12,
            data: data.into(),
        };

        let boxed: Vec<u8> = segment.clone().into();
        let recovered: Segment = (&boxed[..]).try_into().unwrap();

        assert_eq!(segment, recovered);
    }

    #[async_std::test]
    async fn frame_reassembler_must_process_ordered_frames() -> anyhow::Result<()> {
        let (fragmented, reassembled) = FrameReassembler::new(Duration::from_secs(30));

        FRAMES
            .iter()
            .flat_map(|f| f.segment(MTU).unwrap())
            .try_for_each(|s| fragmented.push_segment(s))?;

        drop(fragmented);
        let reassembled_frames = reassembled.try_collect::<Vec<_>>().await?;

        assert_eq!(FRAMES.len(), reassembled_frames.len());
        reassembled_frames
            .into_par_iter()
            .enumerate()
            .for_each(|(i, frame)| assert_eq!(frame, FRAMES[i]));

        Ok(())
    }

    #[async_std::test]
    async fn frame_reassembler_must_process_single_frame() -> anyhow::Result<()> {
        let (fragmented, reassembled) = FrameReassembler::new(Duration::from_secs(10));

        let data = hex!("cafe");

        let segment = Segment {
            frame_id: 1,
            seq_idx: 0,
            seq_len: 1,
            data: hex!("cafe").into(),
        };

        fragmented.push_segment(segment)?;
        drop(fragmented);
        let mut reassembled_frames = reassembled.try_collect::<Vec<_>>().await?;

        assert_eq!(1, reassembled_frames.len());
        let frame = reassembled_frames.pop().ok_or(SessionError::InvalidSegment)?;

        assert_eq!(1, frame.frame_id);
        assert_eq!(&data, frame.data.as_ref());

        Ok(())
    }

    #[test]
    fn should_not_push_frame_id_0_into_reassembler() -> anyhow::Result<()> {
        let frame = Frame {
            frame_id: 1,
            data: hex!("deadbeefcafe").into(),
        };

        let mut segments = frame.segment(2)?;
        segments[0].frame_id = 0;

        let (fragmented, _reassembled) = FrameReassembler::new(Duration::from_secs(30));
        fragmented
            .push_segment(segments[0].clone())
            .expect_err("must not push frame id 0");

        Ok(())
    }

    #[test]
    fn pushing_segment_of_a_completed_frame_into_reassembler_should_fail() -> anyhow::Result<()> {
        let (fragmented, _reassembled) = FrameReassembler::new(Duration::from_secs(30));

        let segments = FRAMES[0].segment(MTU)?;
        let segment_1 = segments[0].clone();

        segments.into_iter().try_for_each(|s| fragmented.push_segment(s))?;

        fragmented
            .push_segment(segment_1)
            .expect_err("must fail pushing segment of a completed frame");

        Ok(())
    }

    #[async_std::test]
    async fn pushing_segment_of_an_evicted_frame_into_reassembler_should_fail() -> anyhow::Result<()> {
        let (fragmented, _reassembled) = FrameReassembler::new(Duration::from_millis(5));

        let mut segments = FRAMES[0].segment(MTU)?;
        let segment_1 = segments.pop().unwrap();
        segments.into_iter().try_for_each(|s| fragmented.push_segment(s))?;

        async_std::task::sleep(Duration::from_millis(10)).await;
        assert_eq!(1, fragmented.evict()?);

        fragmented
            .push_segment(segment_1)
            .expect_err("must fail pushing segment of an evicted frame");

        Ok(())
    }

    #[async_std::test]
    async fn frame_reassembler_reassembles_single_frame() -> anyhow::Result<()> {
        let (fragmented, reassembled) = FrameReassembler::new(Duration::from_secs(30));

        let mut rng = thread_rng();

        let frame = FRAMES[0].clone();
        let mut segments = frame.segment(MTU)?;
        segments.shuffle(&mut rng);

        segments.into_iter().try_for_each(|s| fragmented.push_segment(s))?;

        drop(fragmented);
        let reassembled_frames = reassembled.try_collect::<Vec<_>>().await?;

        assert_eq!(1, reassembled_frames.len());
        assert_eq!(frame, reassembled_frames[0]);

        Ok(())
    }

    #[async_std::test]
    async fn frame_reassembler_reassembles_shuffled_randomized_frames() -> anyhow::Result<()> {
        let (fragmented, reassembled) = FrameReassembler::new(Duration::from_secs(30));

        SEGMENTS.iter().cloned().try_for_each(|b| fragmented.push_segment(b))?;

        assert_eq!(0, fragmented.evict().unwrap());
        drop(fragmented);

        let reassembled_frames = reassembled.try_collect::<Vec<_>>().await?;

        assert_eq!(FRAMES.len(), reassembled_frames.len());
        reassembled_frames
            .into_par_iter()
            .enumerate()
            .for_each(|(i, frame)| assert_eq!(frame, FRAMES[i]));

        Ok(())
    }

    #[async_std::test]
    async fn frame_reassembler_reassembles_shuffled_randomized_frames_in_parallel() -> anyhow::Result<()> {
        let (fragmented, reassembled) = FrameReassembler::new(Duration::from_secs(30));

        SEGMENTS
            .par_iter()
            .cloned()
            .try_for_each(|b| fragmented.push_segment(b))?;

        assert_eq!(0, fragmented.evict()?);
        drop(fragmented);

        let reassembled_frames = reassembled.try_collect::<Vec<_>>().await?;

        assert_eq!(FRAMES.len(), reassembled_frames.len());
        reassembled_frames
            .into_par_iter()
            .enumerate()
            .for_each(|(i, frame)| assert_eq!(frame, FRAMES[i]));

        Ok(())
    }

    #[async_std::test]
    async fn frame_reassembler_should_evict_expired_incomplete_frames() -> anyhow::Result<()> {
        let frames = vec![
            Frame {
                frame_id: 1,
                data: hex!("deadbeefcafebabe").into(),
            },
            Frame {
                frame_id: 2,
                data: hex!("feedbeefbaadcafe").into(),
            },
            Frame {
                frame_id: 3,
                data: hex!("00112233abcd").into(),
            },
        ];

        let mut segments = frames
            .iter()
            .flat_map(|f| f.segment(3).unwrap())
            .collect::<VecDeque<_>>();
        // Drop the last segment of frame 2 so that it can never complete.
        segments.retain(|s| s.frame_id != 2 || s.seq_idx != 2);
        let (fragmented, reassembled) = FrameReassembler::new(Duration::from_millis(10));

        segments.into_iter().try_for_each(|b| fragmented.push_segment(b))?;

        let frames_cpy = frames.clone();
        let jh = async_std::task::spawn(async move {
            pin_mut!(reassembled);

            assert_eq!(Some(frames_cpy[0].clone()), reassembled.try_next().await?);

            assert!(matches!(
                reassembled.try_next().await,
                Err(SessionError::FrameDiscarded(2))
            ));

            assert_eq!(Some(frames_cpy[2].clone()), reassembled.try_next().await?);

            Ok(())
        });

        async_std::task::sleep(Duration::from_millis(20)).await;

        assert_eq!(2, fragmented.evict()?);
        jh.await
    }

    #[async_std::test]
    async fn frame_reassembler_should_evict_frame_that_never_arrived() -> anyhow::Result<()> {
        let frames = vec![
            Frame {
                frame_id: 1,
                data: hex!("deadbeefcafebabe").into(),
            },
            Frame {
                frame_id: 3,
                data: hex!("00112233abcd").into(),
            },
        ];

        let segments = frames
            .iter()
            .flat_map(|f| f.segment(3).unwrap())
            .collect::<VecDeque<_>>();

        let (fragmented, reassembled) = FrameReassembler::new(Duration::from_millis(10));

        segments.into_iter().try_for_each(|b| fragmented.push_segment(b))?;

        let flushed = Arc::new(AtomicBool::new(false));

        let flushed_cpy = flushed.clone();
        let frames_cpy = frames.clone();
        let jh = async_std::task::spawn(async move {
            pin_mut!(reassembled);

            assert_eq!(Some(frames_cpy[0].clone()), reassembled.try_next().await?);

            assert!(!flushed_cpy.load(Ordering::SeqCst));

            assert_eq!(Some(frames_cpy[1].clone()), reassembled.try_next().await?);

            assert!(flushed_cpy.load(Ordering::SeqCst));

            Ok(())
        });

        async_std::task::sleep(Duration::from_millis(20)).await;

        flushed.store(true, Ordering::SeqCst);
        assert_eq!(2, fragmented.evict()?);
        jh.await
    }

    #[async_std::test]
    async fn frame_reassembler_reassembles_randomized_delayed_frames_in_parallel() -> anyhow::Result<()> {
        let frames = FRAMES.iter().take(100).collect::<Vec<_>>();

        let segments = frames
            .iter()
            .flat_map(|frame| frame.segment(MTU).unwrap())
            .collect::<Vec<_>>();

        let (fragmented, reassembled) = FrameReassembler::new(Duration::from_secs(30));

        futures::stream::iter(segments)
            .map(|segment| {
                let delay = Duration::from_millis(thread_rng().gen_range(0..10u64));
                async_std::task::spawn(async move {
                    async_std::task::sleep(delay).await;
                    Ok(segment)
                })
            })
            .buffer_unordered(4)
            .forward(fragmented)
            .await
            .unwrap();

        let reassembled_frames = reassembled.try_collect::<Vec<_>>().await?;

        assert_eq!(frames.len(), reassembled_frames.len());
        reassembled_frames
            .into_par_iter()
            .enumerate()
            .for_each(|(i, frame)| assert_eq!(&frame, frames[i]));

        Ok(())
    }

    /// Picks `corrupted_ratio` of the frames `1..=num_frames` and removes one random segment
    /// from each picked frame.
    ///
    /// Returns the surviving segments of frames `1..=num_frames`, the frames expected to be
    /// reassembled, and the IDs of the removed segments.
    fn corrupt_frames(
        num_frames: u32,
        corrupted_ratio: f32,
    ) -> (Vec<Segment>, Vec<&'static Frame>, HashSet<SegmentId>) {
        assert!((0.0..=1.0).contains(&corrupted_ratio));

        let mut rng = rand::rngs::StdRng::from_seed(*RAND_SEED);

        let (excluded_frame_ids, excluded_segments): (HashSet<FrameId>, HashSet<SegmentId>) = (1..num_frames + 1)
            .choose_multiple(&mut rng, ((num_frames as f32) * corrupted_ratio) as usize)
            .into_iter()
            .map(|frame_id| {
                (
                    frame_id,
                    SegmentId(
                        frame_id,
                        rng.gen_range(0..SEGMENTS.iter().find(|s| s.frame_id == frame_id).unwrap().seq_len),
                    ),
                )
            })
            .unzip();

        // The range must match the one corrupted frames are picked from (`1..=num_frames`);
        // otherwise frame `num_frames` may be reported as corrupted without ever being pushed.
        let segments = SEGMENTS
            .par_iter()
            .filter(|s| s.frame_id <= num_frames && !excluded_segments.contains(&SegmentId(s.frame_id, s.seq_idx)))
            .cloned()
            .collect::<Vec<_>>();

        let expected_frames = FRAMES
            .par_iter()
            .filter(|f| f.frame_id <= num_frames && !excluded_frame_ids.contains(&f.frame_id))
            .collect::<Vec<_>>();

        (segments, expected_frames, excluded_segments)
    }

    #[async_std::test]
    async fn frame_reassembler_yields_correct_frames_when_also_corrupted_frames_are_present() -> anyhow::Result<()> {
        let (segments, expected_frames, excluded) = corrupt_frames(FRAME_COUNT / 4, 0.3);

        let (fragmented, reassembled) = FrameReassembler::new(Duration::from_millis(25));

        segments.into_iter().try_for_each(|s| fragmented.push_segment(s))?;

        let computed_missing = fragmented
            .incomplete_frames()
            .into_par_iter()
            .flat_map_iter(|e| e.into_missing_segments())
            .collect::<HashSet<_>>();

        assert!(computed_missing.par_iter().all(|s| excluded.contains(s)));

        // Wait well past `max_age`: expiry uses a strict `<` on millisecond timestamps, and the
        // eviction triggered by dropping must reach every buffered frame.
        async_std::task::sleep(Duration::from_millis(50)).await;
        drop(fragmented);

        let (reassembled_frames, discarded_frames) = reassembled
            .map(|f| match f {
                Ok(f) => (Some(f), None),
                Err(e) => (None, Some(e)),
            })
            .unzip::<_, _, Vec<_>, Vec<_>>()
            .await;

        let reassembled_frames = reassembled_frames
            .into_par_iter()
            .filter_map(identity)
            .collect::<Vec<_>>();

        // The result of the comparison must be asserted, and the lengths must match, because
        // zipping silently truncates to the shorter side.
        assert_eq!(expected_frames.len(), reassembled_frames.len());
        assert!((reassembled_frames, expected_frames)
            .into_par_iter()
            .all(|(a, b)| a.eq(b)));

        let discarded_frames = discarded_frames
            .into_par_iter()
            .filter_map(|s| match s {
                Some(SessionError::FrameDiscarded(f)) => Some(f),
                _ => None,
            })
            .collect::<Vec<_>>();

        // Discards are emitted in ascending frame ID order, while `excluded` is an unordered set.
        let mut expected_discarded_frames = excluded.into_iter().map(|s| s.0).collect::<Vec<_>>();
        expected_discarded_frames.sort_unstable();

        assert_eq!(expected_discarded_frames, discarded_frames);

        Ok(())
    }

    #[async_std::test]
    async fn frame_reassembler_yields_no_frames_when_all_corrupted() -> anyhow::Result<()> {
        let (segments, expected_frames, _) = corrupt_frames(1000, 1.0);
        assert!(expected_frames.is_empty());

        let (fragmented, reassembled) = FrameReassembler::new(Duration::from_millis(100));

        segments.into_par_iter().try_for_each(|s| fragmented.push_segment(s))?;
        drop(fragmented);

        let reassembled_frames = reassembled.try_collect::<Vec<_>>().await?;

        assert!(reassembled_frames.is_empty());

        Ok(())
    }

    /// Builds a stream over the segments of the first `num_frames` frames. The stream locally
    /// shuffles the segments, drops `corruption_ratio` of them, and delays each one by up to
    /// `max_latency`.
    ///
    /// Also returns the frames that have no dropped segment.
    fn create_unreliable_segment_stream(
        num_frames: usize,
        max_latency: Duration,
        mixing_factor: f64,
        corruption_ratio: f64,
    ) -> (impl Stream<Item = Segment>, Vec<&'static Frame>) {
        let mut segments = FRAMES
            .par_iter()
            .take(num_frames)
            .flat_map(|f| f.segment(MTU).unwrap())
            .collect::<VecDeque<_>>();

        let (corrupted_frames, corrupted_segments): (HashSet<FrameId>, HashSet<SegmentId>) = segments
            .iter()
            .choose_multiple(
                &mut thread_rng(),
                (segments.len() as f64 * corruption_ratio).round() as usize,
            )
            .into_par_iter()
            .map(|s| (s.frame_id, SegmentId(s.frame_id, s.seq_idx)))
            .unzip();

        (
            stream! {
                let mut rng = thread_rng();
                let mut distr = Normal::new(0.0, mixing_factor).unwrap();
                while !segments.is_empty() {
                    let segment = segments.remove(sample_index(&mut distr, &mut rng, segments.len())).unwrap();

                    if !corrupted_segments.contains(&SegmentId(segment.frame_id, segment.seq_idx)) {
                        async_std::task::sleep(max_latency.mul_f64(rng.gen())).await;
                        yield segment;
                    }
                }
            },
            FRAMES
                .par_iter()
                .filter(|f| !corrupted_frames.contains(&f.frame_id))
                .collect(),
        )
    }

    #[async_std::test]
    async fn frame_reassembler_yields_and_evicts_frames_on_unreliable_network() -> anyhow::Result<()> {
        let (fragmented, reassembled) = FrameReassembler::new(Duration::from_millis(25));
        let fragmented = Arc::new(fragmented);

        let done = Arc::new(AtomicBool::new(false));
        let done_clone = done.clone();
        let frag_clone = fragmented.clone();
        let eviction_jh = async_std::task::spawn(async move {
            while !done_clone.load(Ordering::SeqCst) {
                async_std::task::sleep(Duration::from_millis(25)).await;
                frag_clone.evict().unwrap();
            }
        });

        let (stream, expected_frames) =
            create_unreliable_segment_stream(200, Duration::from_millis(2), MIXING_FACTOR, 0.2);
        stream
            .map(Ok)
            .try_for_each(|s| futures::future::ready(fragmented.push_segment(s)))
            .await?;

        done.store(true, Ordering::SeqCst);
        eviction_jh.await;
        drop(fragmented);

        let reassembled_frames = reassembled
            .filter(|f| futures::future::ready(f.is_ok()))
            .try_collect::<Vec<_>>()
            .await?;
        reassembled_frames
            .into_iter()
            .enumerate()
            .for_each(|(i, frame)| assert_eq!(&frame, expected_frames[i]));

        Ok(())
    }
}