1use std::{
53 borrow::Cow,
54 collections::{BTreeMap, BTreeSet},
55 fmt::{Display, Formatter},
56 mem,
57};
58
59use bitvec::field::BitField;
60
61use crate::{
62 errors::NetworkTypeError,
63 session::{
64 errors::SessionError,
65 frame::{FrameId, FrameInfo, MissingSegmentsBitmap, Segment, SegmentId, SeqNum},
66 },
67};
68
69#[derive(Debug, Clone, PartialEq, Eq, Default)]
73pub struct SegmentRequest<const C: usize>(BTreeMap<FrameId, SeqNum>);
74
75impl<const C: usize> SegmentRequest<C> {
76 pub const ENTRY_SIZE: usize = mem::size_of::<FrameId>() + mem::size_of::<SeqNum>();
78 pub const MAX_ENTRIES: usize = Self::SIZE / Self::ENTRY_SIZE;
80 pub const MAX_MISSING_SEGMENTS_PER_FRAME: usize = SeqNum::BITS as usize;
82 pub const SIZE: usize = C - SessionMessage::<C>::HEADER_SIZE;
83
84 pub fn len(&self) -> usize {
86 self.0
87 .values()
88 .take(Self::MAX_ENTRIES)
89 .map(|e| e.count_ones() as usize)
90 .sum()
91 }
92
93 pub fn is_empty(&self) -> bool {
95 self.0.is_empty()
96 }
97}
98
99impl<const C: usize> IntoIterator for SegmentRequest<C> {
100 type IntoIter = std::vec::IntoIter<SegmentId>;
101 type Item = SegmentId;
102
103 fn into_iter(self) -> Self::IntoIter {
104 let seq_size = SeqNum::BITS as usize;
105 let mut ret = Vec::with_capacity(seq_size * self.0.len());
106 for (frame_id, missing) in self.0 {
107 ret.extend(
108 MissingSegmentsBitmap::from([missing])
109 .iter_ones()
110 .map(|i| SegmentId(frame_id, i as SeqNum)),
111 );
112 }
113 ret.into_iter()
114 }
115}
116
117impl<const C: usize> FromIterator<FrameInfo> for SegmentRequest<C> {
118 fn from_iter<T: IntoIterator<Item = FrameInfo>>(iter: T) -> Self {
119 let mut ret = Self::default();
120 for frame in iter.into_iter().take(Self::MAX_ENTRIES) {
121 let frame_id = frame.frame_id;
122 ret.0.insert(frame_id, frame.missing_segments.load());
123 }
124 ret
125 }
126}
127
128impl<const C: usize> TryFrom<&[u8]> for SegmentRequest<C> {
129 type Error = SessionError;
130
131 fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
132 if value.len() == Self::SIZE {
133 let mut ret = Self::default();
134 for (frame_id, missing) in value
135 .chunks_exact(Self::ENTRY_SIZE)
136 .map(|c| c.split_at(mem::size_of::<FrameId>()))
137 {
138 let frame_id = FrameId::from_be_bytes(frame_id.try_into().map_err(|_| SessionError::ParseError)?);
139 if frame_id > 0 {
140 ret.0.insert(
141 frame_id,
142 SeqNum::from_be_bytes(missing.try_into().map_err(|_| SessionError::ParseError)?),
143 );
144 }
145 }
146 Ok(ret)
147 } else {
148 Err(SessionError::ParseError)
149 }
150 }
151}
152
153impl<const C: usize> From<SegmentRequest<C>> for Vec<u8> {
154 fn from(value: SegmentRequest<C>) -> Self {
155 let mut ret = vec![0u8; SegmentRequest::<C>::SIZE];
156 let mut offset = 0;
157 for (frame_id, seq_num) in value.0 {
158 if offset + mem::size_of::<FrameId>() + mem::size_of::<SeqNum>() < C {
159 ret[offset..offset + mem::size_of::<FrameId>()].copy_from_slice(&frame_id.to_be_bytes());
160 offset += mem::size_of::<FrameId>();
161 ret[offset..offset + mem::size_of::<SeqNum>()].copy_from_slice(&seq_num.to_be_bytes());
162 offset += mem::size_of::<SeqNum>();
163 } else {
164 break;
165 }
166 }
167 ret
168 }
169}
170
171#[derive(Debug, Clone, PartialEq, Eq, Default)]
175pub struct FrameAcknowledgements<const C: usize>(BTreeSet<FrameId>);
176
177impl<const C: usize> FrameAcknowledgements<C> {
178 pub const MAX_ACK_FRAMES: usize = Self::SIZE / mem::size_of::<FrameId>();
180 pub const SIZE: usize = C - SessionMessage::<C>::HEADER_SIZE;
181
182 #[inline]
186 pub fn push(&mut self, frame_id: FrameId) -> bool {
187 !self.is_full() && self.0.insert(frame_id)
188 }
189
190 #[inline]
192 pub fn len(&self) -> usize {
193 self.0.len()
194 }
195
196 pub fn is_empty(&self) -> bool {
198 self.0.is_empty()
199 }
200
201 #[inline]
204 pub fn is_full(&self) -> bool {
205 self.0.len() == Self::MAX_ACK_FRAMES
206 }
207}
208
209impl<const C: usize> From<Vec<FrameId>> for FrameAcknowledgements<C> {
210 fn from(value: Vec<FrameId>) -> Self {
211 Self(
212 value
213 .into_iter()
214 .take(Self::MAX_ACK_FRAMES)
215 .filter(|v| *v > 0)
216 .collect(),
217 )
218 }
219}
220
221impl<const C: usize> IntoIterator for FrameAcknowledgements<C> {
222 type IntoIter = std::collections::btree_set::IntoIter<Self::Item>;
223 type Item = FrameId;
224
225 fn into_iter(self) -> Self::IntoIter {
226 self.0.into_iter()
227 }
228}
229
230impl<'a, const C: usize> TryFrom<&'a [u8]> for FrameAcknowledgements<C> {
231 type Error = SessionError;
232
233 fn try_from(value: &'a [u8]) -> Result<Self, Self::Error> {
234 if value.len() == Self::SIZE {
235 Ok(Self(
236 value
238 .chunks_exact(mem::size_of::<FrameId>())
239 .map(|v| FrameId::from_be_bytes(v.try_into().unwrap()))
240 .filter(|f| *f > 0)
241 .collect(),
242 ))
243 } else {
244 Err(SessionError::ParseError)
245 }
246 }
247}
248
249impl<const C: usize> From<FrameAcknowledgements<C>> for Vec<u8> {
250 fn from(value: FrameAcknowledgements<C>) -> Self {
251 value
252 .0
253 .iter()
254 .flat_map(|v| v.to_be_bytes())
255 .chain(std::iter::repeat(0_u8))
256 .take(FrameAcknowledgements::<C>::SIZE)
257 .collect::<Vec<_>>()
258 }
259}
260
/// A message of the session protocol.
///
/// `strum::EnumDiscriminants` generates a `SessionMessageDiscriminants` enum with
/// `repr(u8)`. Its value is written as the tag byte after the version byte when encoding,
/// and is looked up via `from_repr` when decoding. The implicit discriminants follow
/// declaration order, so reordering the variants changes the wire format.
#[derive(Debug, Clone, PartialEq, Eq, strum::EnumDiscriminants, strum::EnumTryAs)]
#[strum_discriminants(derive(strum::FromRepr), repr(u8))]
pub enum SessionMessage<const C: usize> {
    /// A single segment (its `id()` is used by the `Display` output).
    Segment(Segment),
    /// Retransmission request for missing segments.
    Request(SegmentRequest<C>),
    /// Acknowledgement of frame IDs.
    Acknowledge(FrameAcknowledgements<C>),
}
274
275impl<const C: usize> Display for SessionMessage<C> {
276 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
277 match &self {
278 SessionMessage::Segment(s) => write!(f, "segment {}", s.id()),
279 SessionMessage::Request(r) => write!(f, "retransmission request of {:?}", r.0),
280 SessionMessage::Acknowledge(a) => write!(f, "acknowledgement of {:?}", a.0),
281 }
282 }
283}
284
285impl<const C: usize> SessionMessage<C> {
286 pub const HEADER_SIZE: usize = 1 + mem::size_of::<SessionMessageDiscriminants>() + mem::size_of::<u16>();
290 pub const MAX_MESSAGE_SIZE: usize = 1492 - Self::SEGMENT_OVERHEAD;
294 pub const MAX_SEGMENTS_PER_FRAME: usize = SegmentRequest::<C>::MAX_MISSING_SEGMENTS_PER_FRAME;
296 pub const SEGMENT_OVERHEAD: usize = Self::HEADER_SIZE + Segment::HEADER_SIZE;
300 pub const VERSION: u8 = 1;
302
303 pub fn minimum_message_size() -> usize {
305 Self::HEADER_SIZE
307 + Segment::MINIMUM_SIZE
308 .min(SegmentRequest::<C>::SIZE)
309 .min(FrameAcknowledgements::<C>::SIZE)
310 }
311
312 pub fn into_encoded(self) -> Box<[u8]> {
314 Vec::from(self).into_boxed_slice()
315 }
316}
317
318impl<const C: usize> TryFrom<&[u8]> for SessionMessage<C> {
319 type Error = SessionError;
320
321 fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
322 SessionMessageIter::from(value).try_next()
323 }
324}
325
326impl<const C: usize> From<SessionMessage<C>> for Vec<u8> {
327 fn from(value: SessionMessage<C>) -> Self {
328 let disc = SessionMessageDiscriminants::from(&value) as u8;
329
330 let msg = match value {
331 SessionMessage::Segment(s) => Vec::from(s),
332 SessionMessage::Request(r) => Vec::from(r),
333 SessionMessage::Acknowledge(a) => Vec::from(a),
334 };
335
336 let msg_len = msg.len() as u16;
337
338 let mut ret = Vec::with_capacity(SessionMessage::<C>::HEADER_SIZE + msg_len as usize);
339 ret.push(SessionMessage::<C>::VERSION);
340 ret.push(disc);
341 ret.extend(msg_len.to_be_bytes());
342 ret.extend(msg);
343 ret
344 }
345}
346
/// Iterator that decodes consecutive [`SessionMessage`]s from a byte buffer.
///
/// Iteration stops after the first decoding error, which is kept and exposed via
/// `last_error`. It also stops once the unconsumed bytes are fewer than
/// [`SessionMessage::minimum_message_size`]; such trailing bytes are not decoded.
#[derive(Debug, Clone)]
pub struct SessionMessageIter<'a, const C: usize> {
    // Buffer being decoded, either borrowed or owned.
    data: Cow<'a, [u8]>,
    // Position in `data` where the next message starts.
    offset: usize,
    // First decoding error; once set, the iterator yields no more items.
    last_err: Option<SessionError>,
}
363
364impl<const C: usize> SessionMessageIter<'_, C> {
365 pub fn last_error(&self) -> Option<&SessionError> {
370 self.last_err.as_ref()
371 }
372
373 pub fn is_done(&self) -> bool {
378 self.last_err.is_some() || self.data.len() - self.offset < SessionMessage::<C>::minimum_message_size()
379 }
380
381 fn try_next(&mut self) -> Result<SessionMessage<C>, SessionError> {
383 let mut offset = self.offset;
384
385 if self.data[offset] != SessionMessage::<C>::VERSION {
387 return Err(SessionError::WrongVersion);
388 }
389 offset += 1;
390
391 let disc = self.data[offset];
393 offset += 1;
394
395 let len = u16::from_be_bytes(
397 self.data[offset..offset + mem::size_of::<u16>()]
398 .try_into()
399 .map_err(|_| SessionError::IncorrectMessageLength)?,
400 ) as usize;
401 offset += mem::size_of::<u16>();
402
403 if len > SessionMessage::<C>::MAX_MESSAGE_SIZE {
404 return Err(SessionError::IncorrectMessageLength);
405 }
406
407 let reserved = len & 0b111111_0000000000;
410
411 if reserved != 0 {
413 return Err(SessionError::ParseError);
414 }
415
416 let res = match SessionMessageDiscriminants::from_repr(disc).ok_or(SessionError::UnknownMessageTag)? {
418 SessionMessageDiscriminants::Segment => {
419 SessionMessage::Segment(self.data[offset..offset + len].try_into()?)
420 }
421 SessionMessageDiscriminants::Request => {
422 SessionMessage::Request(self.data[offset..offset + len].try_into()?)
423 }
424 SessionMessageDiscriminants::Acknowledge => {
425 SessionMessage::Acknowledge(self.data[offset..offset + len].try_into()?)
426 }
427 };
428
429 self.offset = offset + len;
431 Ok(res)
432 }
433}
434
435impl<'a, const C: usize, T: Into<Cow<'a, [u8]>>> From<T> for SessionMessageIter<'a, C> {
436 fn from(value: T) -> Self {
437 Self {
438 data: value.into(),
439 offset: 0,
440 last_err: None,
441 }
442 }
443}
444
445impl<const C: usize> Iterator for SessionMessageIter<'_, C> {
446 type Item = Result<SessionMessage<C>, NetworkTypeError>;
447
448 fn next(&mut self) -> Option<Self::Item> {
449 if !self.is_done() {
450 self.try_next()
451 .inspect_err(|e| self.last_err = Some(e.clone()))
452 .map_err(NetworkTypeError::SessionProtocolError)
453 .into()
454 } else {
455 None
456 }
457 }
458}
459
// Once `is_done` returns true it stays true. The offset only advances on a successful
// decode, and a recorded error is never cleared, so `next` keeps returning `None`.
impl<const C: usize> std::iter::FusedIterator for SessionMessageIter<'_, C> {}
461
// Unit tests for session message encoding/decoding and the message iterator.
#[cfg(test)]
mod tests {
    use std::time::SystemTime;

    use bitvec::bitarr;
    use hex_literal::hex;
    use hopr_platform::time::native::current_time;
    use rand::{Rng, prelude::IteratorRandom, thread_rng};

    use super::*;
    use crate::session::{
        Frame,
        frame::{MissingSegmentsBitmap, NO_MISSING_SEGMENTS},
    };

    // Bitmap with every bit set, i.e. every segment of a frame marked as missing.
    pub const ALL_MISSING_SEGMENTS: MissingSegmentsBitmap =
        bitarr![SeqNum, bitvec::prelude::Msb0; 1; SeqNum::BITS as usize];

    // Pins the protocol constants of version 1. A change here is a wire-format change.
    #[test]
    fn ensure_session_protocol_version_1_values() {
        assert_eq!(1, SessionMessage::<0>::VERSION);
        assert_eq!(4, SessionMessage::<0>::HEADER_SIZE);
        assert_eq!(10, SessionMessage::<0>::SEGMENT_OVERHEAD);
        assert_eq!(8, SessionMessage::<0>::MAX_SEGMENTS_PER_FRAME);

        // Evaluated at compile time: the build fails if MAX_MESSAGE_SIZE reaches 2048.
        const _: () = {
            assert!(SessionMessage::<0>::MAX_MESSAGE_SIZE < 2048);
        };
    }

    // A request built from FrameInfo must expand to exactly the missing segments of each frame.
    #[test]
    fn segment_request_should_be_constructible_from_frame_info() {
        let frames = (1..20)
            .map(|i| {
                let mut missing_segments = NO_MISSING_SEGMENTS;
                // Mark 4 random segments out of indices 0..7 as missing.
                (0..7_usize)
                    .choose_multiple(&mut thread_rng(), 4)
                    .into_iter()
                    .for_each(|i| missing_segments.set(i, true));
                FrameInfo {
                    frame_id: i,
                    missing_segments,
                    total_segments: 8,
                    last_update: SystemTime::UNIX_EPOCH,
                }
            })
            .collect::<Vec<_>>();

        let mut req = SegmentRequest::<466>::from_iter(frames.clone())
            .into_iter()
            .collect::<Vec<_>>();
        req.sort();

        assert_eq!(frames.len() * 4, req.len());
        assert_eq!(
            req,
            frames
                .into_iter()
                .flat_map(|f| f.into_missing_segments())
                .collect::<Vec<_>>()
        );
    }

    // Round-trip of a Segment message through encoding and decoding.
    #[test]
    fn session_message_segment_should_serialize_and_deserialize() -> anyhow::Result<()> {
        const SEG_SIZE: usize = 8;

        let mut segments = Frame {
            frame_id: 10,
            data: hex!("deadbeefcafebabe").into(),
        }
        .segment(SEG_SIZE)?;

        const MTU: usize = SEG_SIZE + Segment::HEADER_SIZE + 2;

        let msg_1 = SessionMessage::<MTU>::Segment(segments.pop().unwrap());
        let data = Vec::from(msg_1.clone());
        let msg_2 = SessionMessage::try_from(&data[..])?;

        assert_eq!(msg_1, msg_2);

        Ok(())
    }

    // Round-trip of a Request message, then checks the bitmap expands to the expected segments.
    #[test]
    fn session_message_segment_request_should_serialize_and_deserialize() -> anyhow::Result<()> {
        let frame_info = FrameInfo {
            frame_id: 10,
            total_segments: 8,
            missing_segments: [0b10100001].into(),
            last_update: SystemTime::now(),
        };

        let msg_1 = SessionMessage::<466>::Request(SegmentRequest::from_iter(vec![frame_info]));
        let data = Vec::from(msg_1.clone());
        let msg_2 = SessionMessage::try_from(&data[..])?;

        assert_eq!(msg_1, msg_2);

        match msg_1 {
            SessionMessage::Request(r) => {
                let missing_segments = r.into_iter().collect::<Vec<_>>();
                // 0b10100001 read MSB-first: bits 0, 2 and 7 are set.
                let expected = vec![SegmentId(10, 0), SegmentId(10, 2), SegmentId(10, 7)];
                assert_eq!(expected, missing_segments);
            }
            _ => panic!("invalid type"),
        }

        Ok(())
    }

    // Round-trip of an Acknowledge message. 500 IDs exceed the capacity, so the
    // truncated set must survive encoding unchanged.
    #[test]
    fn session_message_ack_should_serialize_and_deserialize() -> anyhow::Result<()> {
        let mut rng = thread_rng();
        let frame_ids: Vec<u32> = (0..500).map(|_| rng.r#gen()).collect();

        let msg_1 = SessionMessage::<466>::Acknowledge(frame_ids.into());
        let data = Vec::from(msg_1.clone());
        let msg_2 = SessionMessage::try_from(&data[..])?;

        assert_eq!(msg_1, msg_2);

        Ok(())
    }

    // Checks the bit-to-segment mapping of requests built directly and from FrameInfo.
    #[test]
    fn session_message_segment_request_should_yield_correct_bitset_values() {
        let seg_req = SegmentRequest::<466>([(3, 0b01000001), (10, 0b00101000)].into());

        let mut iter = seg_req.into_iter();
        assert_eq!(iter.next(), Some(SegmentId(3, 1)));
        assert_eq!(iter.next(), Some(SegmentId(3, 7)));
        assert_eq!(iter.next(), Some(SegmentId(10, 2)));
        assert_eq!(iter.next(), Some(SegmentId(10, 4)));
        assert_eq!(iter.next(), None);

        let mut frame_info = FrameInfo {
            frame_id: 10,
            missing_segments: NO_MISSING_SEGMENTS,
            total_segments: 10,
            last_update: current_time(),
        };
        frame_info.missing_segments.set(2, true);
        frame_info.missing_segments.set(4, true);

        let mut iter = frame_info.clone().into_missing_segments();

        assert_eq!(iter.next(), Some(SegmentId(10, 2)));
        assert_eq!(iter.next(), Some(SegmentId(10, 4)));
        assert_eq!(iter.next(), None);

        let mut iter = SegmentRequest::<466>::from_iter(vec![frame_info]).into_iter();
        assert_eq!(iter.next(), Some(SegmentId(10, 2)));
        assert_eq!(iter.next(), Some(SegmentId(10, 4)));
        assert_eq!(iter.next(), None);
    }

    // Buffers shorter than the minimum message size yield nothing and report done.
    #[test]
    fn session_message_iter_should_be_empty_if_slice_has_no_messages() {
        const MTU: usize = 462;

        let mut iter = SessionMessageIter::<MTU>::from(Vec::<u8>::new());
        assert!(iter.next().is_none());
        assert!(iter.is_done());

        let mut iter = SessionMessageIter::<MTU>::from(&[0u8; 2]);
        assert!(iter.next().is_none());
        assert!(iter.is_done());
    }

    // Several concatenated messages of all three kinds, followed by zero padding, decode in order.
    #[test]
    fn session_message_iter_should_deserialize_multiple_messages() -> anyhow::Result<()> {
        const MTU: usize = 462;

        let mut messages_1 = Frame {
            frame_id: 10,
            data: hopr_crypto_random::random_bytes::<1500>().into(),
        }
        .segment(MTU - SessionMessage::<MTU>::HEADER_SIZE - Segment::HEADER_SIZE)?
        .into_iter()
        .map(SessionMessage::<MTU>::Segment)
        .collect::<Vec<_>>();

        let frame_info = FrameInfo {
            frame_id: 10,
            total_segments: 255,
            missing_segments: ALL_MISSING_SEGMENTS,
            last_update: SystemTime::now(),
        };

        messages_1.push(SessionMessage::<MTU>::Request(SegmentRequest::from_iter(vec![
            frame_info,
        ])));

        let mut rng = thread_rng();
        let frame_ids: Vec<u32> = (0..100).map(|_| rng.r#gen()).collect();
        messages_1.push(SessionMessage::<MTU>::Acknowledge(frame_ids.into()));

        let iter = SessionMessageIter::<MTU>::from(
            messages_1
                .iter()
                .cloned()
                .flat_map(|m| m.into_encoded().into_vec())
                .chain(std::iter::repeat_n(0, 10))
                .collect::<Vec<u8>>(),
        );

        let messages_2 = iter.collect::<Result<Vec<_>, _>>()?;
        assert_eq!(messages_1, messages_2);

        Ok(())
    }

    // Consuming all messages, with trailing padding, ends cleanly without a recorded error.
    #[test]
    fn session_message_iter_should_not_contain_error_when_consuming_everything() -> anyhow::Result<()> {
        const MTU: usize = 462;

        let messages = Frame {
            frame_id: 10,
            data: hopr_crypto_random::random_bytes::<{ 3 * MTU }>().into(),
        }
        .segment(MTU - SessionMessage::<MTU>::HEADER_SIZE - Segment::HEADER_SIZE)?
        .into_iter()
        .map(SessionMessage::<MTU>::Segment)
        .collect::<Vec<_>>();

        assert_eq!(4, messages.len());

        let data = messages
            .iter()
            .cloned()
            .flat_map(|m| m.into_encoded().into_vec())
            .chain(std::iter::repeat_n(0u8, 10))
            .collect::<Vec<_>>();

        let mut iter = SessionMessageIter::<MTU>::from(data);
        assert!(matches!(iter.next(), Some(Ok(m)) if m == messages[0]));
        assert!(matches!(iter.next(), Some(Ok(m)) if m == messages[1]));
        assert!(matches!(iter.next(), Some(Ok(m)) if m == messages[2]));
        assert!(matches!(iter.next(), Some(Ok(m)) if m == messages[3]));

        assert!(iter.next().is_none());
        assert!(iter.last_error().is_none());
        assert!(iter.is_done());

        Ok(())
    }

    // The third message is replaced with random bytes. After yielding that error once,
    // the iterator must stop.
    #[test]
    fn session_message_iter_should_not_yield_more_after_error() -> anyhow::Result<()> {
        const MTU: usize = 462;

        let messages = Frame {
            frame_id: 10,
            data: hopr_crypto_random::random_bytes::<{ 3 * MTU }>().into(),
        }
        .segment(MTU - SessionMessage::<MTU>::HEADER_SIZE - Segment::HEADER_SIZE)?
        .into_iter()
        .map(SessionMessage::<MTU>::Segment)
        .collect::<Vec<_>>();

        assert_eq!(4, messages.len());

        let data = messages
            .iter()
            .cloned()
            .enumerate()
            .flat_map(|(i, m)| {
                if i == 2 {
                    Vec::from(hopr_crypto_random::random_bytes::<MTU>())
                } else {
                    m.into_encoded().into_vec()
                }
            })
            .collect::<Vec<_>>();

        let mut iter = SessionMessageIter::<MTU>::from(data);
        assert!(matches!(iter.next(), Some(Ok(m)) if m == messages[0]));
        assert!(matches!(iter.next(), Some(Ok(m)) if m == messages[1]));

        let err = iter.next();
        assert!(matches!(err, Some(Err(_))));
        assert!(iter.is_done());
        assert!(iter.last_error().is_some());

        assert!(iter.next().is_none());

        Ok(())
    }
}
752}