hopr_network_types/session/
protocol.rs1use std::borrow::Cow;
54use std::collections::{BTreeMap, BTreeSet};
55use std::fmt::{Display, Formatter};
56use std::mem;
57
58use crate::errors::NetworkTypeError;
59use crate::session::errors::SessionError;
60use crate::session::frame::{FrameId, FrameInfo, Segment, SegmentId, SeqNum};
61
62#[derive(Debug, Clone, PartialEq, Eq, Default)]
66pub struct SegmentRequest<const C: usize>(BTreeMap<FrameId, SeqNum>);
67
68impl<const C: usize> SegmentRequest<C> {
69 pub const ENTRY_SIZE: usize = mem::size_of::<FrameId>() + mem::size_of::<SeqNum>();
71
72 pub const MAX_MISSING_SEGMENTS_PER_FRAME: usize = mem::size_of::<SeqNum>() * 8;
74
75 pub const MAX_ENTRIES: usize = Self::SIZE / Self::ENTRY_SIZE;
77
78 pub const SIZE: usize = C - SessionMessage::<C>::HEADER_SIZE;
79
80 pub fn len(&self) -> usize {
82 self.0
83 .values()
84 .take(Self::MAX_ENTRIES)
85 .map(|e| e.count_ones() as usize)
86 .sum()
87 }
88
89 pub fn is_empty(&self) -> bool {
91 self.0.is_empty()
92 }
93}
94
95pub struct SegmentIdIter(Vec<SegmentId>);
97
98impl Iterator for SegmentIdIter {
99 type Item = SegmentId;
100
101 fn next(&mut self) -> Option<Self::Item> {
102 self.0.pop()
103 }
104}
105
106impl<const C: usize> IntoIterator for SegmentRequest<C> {
107 type Item = SegmentId;
108 type IntoIter = SegmentIdIter;
109
110 fn into_iter(self) -> Self::IntoIter {
111 let seq_size = mem::size_of::<SeqNum>() * 8;
112 let mut ret = SegmentIdIter(Vec::with_capacity(seq_size * 8 * self.0.len()));
113 for (frame_id, missing) in self.0 {
114 for i in (0..seq_size).rev() {
115 let mask = (1 << i) as SeqNum;
116 if (mask & missing) != 0 {
117 ret.0.push(SegmentId(frame_id, i as SeqNum));
118 }
119 }
120 }
121 ret.0.shrink_to_fit();
122 ret
123 }
124}
125
126impl<const C: usize> FromIterator<FrameInfo> for SegmentRequest<C> {
127 fn from_iter<T: IntoIterator<Item = FrameInfo>>(iter: T) -> Self {
128 let mut ret = Self::default();
129 for frame in iter.into_iter().take(Self::MAX_ENTRIES) {
130 let frame_id = frame.frame_id;
131 let missing = frame
132 .iter_missing_sequence_indices()
133 .filter(|s| *s < Self::MAX_MISSING_SEGMENTS_PER_FRAME as SeqNum)
134 .map(|idx| 1 << idx)
135 .fold(SeqNum::default(), |acc, n| acc | n);
136 ret.0.insert(frame_id, missing);
137 }
138 ret
139 }
140}
141
142impl<const C: usize> TryFrom<&[u8]> for SegmentRequest<C> {
143 type Error = SessionError;
144
145 fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
146 if value.len() == Self::SIZE {
147 let mut ret = Self::default();
148 for (frame_id, missing) in value
149 .chunks_exact(Self::ENTRY_SIZE)
150 .map(|c| c.split_at(mem::size_of::<FrameId>()))
151 {
152 let frame_id = FrameId::from_be_bytes(frame_id.try_into().map_err(|_| SessionError::ParseError)?);
153 if frame_id > 0 {
154 ret.0.insert(
155 frame_id,
156 SeqNum::from_be_bytes(missing.try_into().map_err(|_| SessionError::ParseError)?),
157 );
158 }
159 }
160 Ok(ret)
161 } else {
162 Err(SessionError::ParseError)
163 }
164 }
165}
166
167impl<const C: usize> From<SegmentRequest<C>> for Vec<u8> {
168 fn from(value: SegmentRequest<C>) -> Self {
169 let mut ret = vec![0u8; SegmentRequest::<C>::SIZE];
170 let mut offset = 0;
171 for (frame_id, seq_num) in value.0 {
172 if offset + mem::size_of::<FrameId>() + mem::size_of::<SeqNum>() < C {
173 ret[offset..offset + mem::size_of::<FrameId>()].copy_from_slice(&frame_id.to_be_bytes());
174 offset += mem::size_of::<FrameId>();
175 ret[offset..offset + mem::size_of::<SeqNum>()].copy_from_slice(&seq_num.to_be_bytes());
176 offset += mem::size_of::<SeqNum>();
177 } else {
178 break;
179 }
180 }
181 ret
182 }
183}
184
185#[derive(Debug, Clone, PartialEq, Eq, Default)]
189pub struct FrameAcknowledgements<const C: usize>(BTreeSet<FrameId>);
190
191impl<const C: usize> FrameAcknowledgements<C> {
192 pub const MAX_ACK_FRAMES: usize = Self::SIZE / mem::size_of::<FrameId>();
194
195 pub const SIZE: usize = C - SessionMessage::<C>::HEADER_SIZE;
196
197 #[inline]
201 pub fn push(&mut self, frame_id: FrameId) -> bool {
202 !self.is_full() && self.0.insert(frame_id)
203 }
204
205 #[inline]
207 pub fn len(&self) -> usize {
208 self.0.len()
209 }
210
211 pub fn is_empty(&self) -> bool {
213 self.0.is_empty()
214 }
215
216 #[inline]
219 pub fn is_full(&self) -> bool {
220 self.0.len() == Self::MAX_ACK_FRAMES
221 }
222}
223
224impl<const C: usize> From<Vec<FrameId>> for FrameAcknowledgements<C> {
225 fn from(value: Vec<FrameId>) -> Self {
226 Self(
227 value
228 .into_iter()
229 .take(Self::MAX_ACK_FRAMES)
230 .filter(|v| *v > 0)
231 .collect(),
232 )
233 }
234}
235
236impl<const C: usize> IntoIterator for FrameAcknowledgements<C> {
237 type Item = FrameId;
238 type IntoIter = std::collections::btree_set::IntoIter<Self::Item>;
239
240 fn into_iter(self) -> Self::IntoIter {
241 self.0.into_iter()
242 }
243}
244
245impl<'a, const C: usize> TryFrom<&'a [u8]> for FrameAcknowledgements<C> {
246 type Error = SessionError;
247
248 fn try_from(value: &'a [u8]) -> Result<Self, Self::Error> {
249 if value.len() == Self::SIZE {
250 Ok(Self(
251 value
253 .chunks_exact(mem::size_of::<FrameId>())
254 .map(|v| FrameId::from_be_bytes(v.try_into().unwrap()))
255 .filter(|f| *f > 0)
256 .collect(),
257 ))
258 } else {
259 Err(SessionError::ParseError)
260 }
261 }
262}
263
264impl<const C: usize> From<FrameAcknowledgements<C>> for Vec<u8> {
265 fn from(value: FrameAcknowledgements<C>) -> Self {
266 value
267 .0
268 .iter()
269 .flat_map(|v| v.to_be_bytes())
270 .chain(std::iter::repeat(0_u8))
271 .take(FrameAcknowledgements::<C>::SIZE)
272 .collect::<Vec<_>>()
273 }
274}
275
276#[derive(Debug, Clone, PartialEq, Eq, strum::EnumDiscriminants, strum::EnumTryAs)]
280#[strum_discriminants(derive(strum::FromRepr), repr(u8))]
281pub enum SessionMessage<const C: usize> {
282 Segment(Segment),
284 Request(SegmentRequest<C>),
286 Acknowledge(FrameAcknowledgements<C>),
288}
289
290impl<const C: usize> Display for SessionMessage<C> {
291 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
292 match &self {
293 SessionMessage::Segment(s) => write!(f, "segment {}", s.id()),
294 SessionMessage::Request(r) => write!(f, "retransmission request of {:?}", r.0),
295 SessionMessage::Acknowledge(a) => write!(f, "acknowledgement of {:?}", a.0),
296 }
297 }
298}
299
300impl<const C: usize> SessionMessage<C> {
301 pub const HEADER_SIZE: usize = 1 + mem::size_of::<SessionMessageDiscriminants>() + mem::size_of::<u16>();
305
306 pub const SEGMENT_OVERHEAD: usize = Self::HEADER_SIZE + Segment::HEADER_SIZE;
310
311 pub const MAX_MESSAGE_SIZE: usize = 1492 - Self::SEGMENT_OVERHEAD;
315
316 pub const VERSION: u8 = 1;
318
319 pub const MAX_SEGMENTS_PER_FRAME: usize = SegmentRequest::<C>::MAX_MISSING_SEGMENTS_PER_FRAME;
321
322 pub fn minimum_message_size() -> usize {
324 Self::HEADER_SIZE
326 + Segment::MINIMUM_SIZE
327 .min(SegmentRequest::<C>::SIZE)
328 .min(FrameAcknowledgements::<C>::SIZE)
329 }
330
331 pub fn into_encoded(self) -> Box<[u8]> {
333 Vec::from(self).into_boxed_slice()
334 }
335}
336
337impl<const C: usize> TryFrom<&[u8]> for SessionMessage<C> {
338 type Error = SessionError;
339
340 fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
341 SessionMessageIter::from(value).try_next()
342 }
343}
344
345impl<const C: usize> From<SessionMessage<C>> for Vec<u8> {
346 fn from(value: SessionMessage<C>) -> Self {
347 let disc = SessionMessageDiscriminants::from(&value) as u8;
348
349 let msg = match value {
350 SessionMessage::Segment(s) => Vec::from(s),
351 SessionMessage::Request(r) => Vec::from(r),
352 SessionMessage::Acknowledge(a) => Vec::from(a),
353 };
354
355 let msg_len = msg.len() as u16;
356
357 let mut ret = Vec::with_capacity(SessionMessage::<C>::HEADER_SIZE + msg_len as usize);
358 ret.push(SessionMessage::<C>::VERSION);
359 ret.push(disc);
360 ret.extend(msg_len.to_be_bytes());
361 ret.extend(msg);
362 ret
363 }
364}
365
366#[derive(Debug, Clone)]
377pub struct SessionMessageIter<'a, const C: usize> {
378 data: Cow<'a, [u8]>,
379 offset: usize,
380 last_err: Option<SessionError>,
381}
382
383impl<const C: usize> SessionMessageIter<'_, C> {
384 pub fn last_error(&self) -> Option<&SessionError> {
389 self.last_err.as_ref()
390 }
391
392 pub fn is_done(&self) -> bool {
397 self.last_err.is_some() || self.data.len() - self.offset < SessionMessage::<C>::minimum_message_size()
398 }
399
400 fn try_next(&mut self) -> Result<SessionMessage<C>, SessionError> {
402 let mut offset = self.offset;
403
404 if self.data[offset] != SessionMessage::<C>::VERSION {
406 return Err(SessionError::WrongVersion);
407 }
408 offset += 1;
409
410 let disc = self.data[offset];
412 offset += 1;
413
414 let len = u16::from_be_bytes(
416 self.data[offset..offset + mem::size_of::<u16>()]
417 .try_into()
418 .map_err(|_| SessionError::IncorrectMessageLength)?,
419 ) as usize;
420 offset += mem::size_of::<u16>();
421
422 if len > SessionMessage::<C>::MAX_MESSAGE_SIZE {
423 return Err(SessionError::IncorrectMessageLength);
424 }
425
426 let reserved = len & 0b111111_0000000000;
429
430 if reserved != 0 {
432 return Err(SessionError::ParseError);
433 }
434
435 let res = match SessionMessageDiscriminants::from_repr(disc).ok_or(SessionError::UnknownMessageTag)? {
437 SessionMessageDiscriminants::Segment => {
438 SessionMessage::Segment(self.data[offset..offset + len].try_into()?)
439 }
440 SessionMessageDiscriminants::Request => {
441 SessionMessage::Request(self.data[offset..offset + len].try_into()?)
442 }
443 SessionMessageDiscriminants::Acknowledge => {
444 SessionMessage::Acknowledge(self.data[offset..offset + len].try_into()?)
445 }
446 };
447
448 self.offset = offset + len;
450 Ok(res)
451 }
452}
453
454impl<'a, const C: usize, T: Into<Cow<'a, [u8]>>> From<T> for SessionMessageIter<'a, C> {
455 fn from(value: T) -> Self {
456 Self {
457 data: value.into(),
458 offset: 0,
459 last_err: None,
460 }
461 }
462}
463
464impl<const C: usize> Iterator for SessionMessageIter<'_, C> {
465 type Item = Result<SessionMessage<C>, NetworkTypeError>;
466
467 fn next(&mut self) -> Option<Self::Item> {
468 if !self.is_done() {
469 self.try_next()
470 .inspect_err(|e| self.last_err = Some(e.clone()))
471 .map_err(NetworkTypeError::SessionProtocolError)
472 .into()
473 } else {
474 None
475 }
476 }
477}
478
479impl<const C: usize> std::iter::FusedIterator for SessionMessageIter<'_, C> {}
480
#[cfg(test)]
mod tests {
    use super::*;
    use crate::session::Frame;
    use bitvec::array::BitArray;
    use bitvec::bitarr;
    use hex_literal::hex;
    use hopr_platform::time::native::current_time;
    use rand::prelude::IteratorRandom;
    use rand::{thread_rng, Rng};
    use std::time::SystemTime;

    #[test]
    fn ensure_session_protocol_version_1_values() {
        assert_eq!(1, SessionMessage::<0>::VERSION);
        assert_eq!(4, SessionMessage::<0>::HEADER_SIZE);
        assert_eq!(10, SessionMessage::<0>::SEGMENT_OVERHEAD);
        assert_eq!(8, SessionMessage::<0>::MAX_SEGMENTS_PER_FRAME);

        assert!(SessionMessage::<0>::MAX_MESSAGE_SIZE < 2048);
    }

    #[test]
    fn segment_request_should_be_constructible_from_frame_info() {
        let frames = (1..20)
            .map(|i| {
                let mut missing_segments = BitArray::ZERO;
                (0..7_usize)
                    .choose_multiple(&mut thread_rng(), 4)
                    .into_iter()
                    .for_each(|i| missing_segments.set(i, true));
                FrameInfo {
                    frame_id: i,
                    missing_segments,
                    total_segments: 8,
                    last_update: SystemTime::UNIX_EPOCH,
                }
            })
            .collect::<Vec<_>>();

        let mut req = SegmentRequest::<466>::from_iter(frames.clone())
            .into_iter()
            .collect::<Vec<_>>();
        req.sort();

        assert_eq!(frames.len() * 4, req.len());
        assert_eq!(
            req,
            frames
                .into_iter()
                .flat_map(|f| f.into_missing_segments())
                .collect::<Vec<_>>()
        );
    }

    #[test]
    fn session_message_segment_should_serialize_and_deserialize() -> anyhow::Result<()> {
        const SEG_SIZE: usize = 8;

        let mut segments = Frame {
            frame_id: 10,
            data: hex!("deadbeefcafebabe").into(),
        }
        .segment(SEG_SIZE)?;

        const MTU: usize = SEG_SIZE + Segment::HEADER_SIZE + 2;

        let msg_1 = SessionMessage::<MTU>::Segment(segments.pop().unwrap());
        let data = Vec::from(msg_1.clone());
        let msg_2 = SessionMessage::try_from(&data[..])?;

        assert_eq!(msg_1, msg_2);

        Ok(())
    }

    #[test]
    fn session_message_segment_request_should_serialize_and_deserialize() -> anyhow::Result<()> {
        let frame_info = FrameInfo {
            frame_id: 10,
            total_segments: 255,
            missing_segments: bitarr![1; 256],
            last_update: SystemTime::now(),
        };

        let msg_1 = SessionMessage::<466>::Request(SegmentRequest::from_iter(vec![frame_info]));
        let data = Vec::from(msg_1.clone());
        let msg_2 = SessionMessage::try_from(&data[..])?;

        assert_eq!(msg_1, msg_2);

        match msg_1 {
            SessionMessage::Request(r) => {
                let missing_segments = r.into_iter().collect::<Vec<_>>();
                let expected = (0..=7).map(|s| SegmentId(10, s)).collect::<Vec<_>>();
                assert_eq!(expected, missing_segments);
            }
            _ => panic!("invalid type"),
        }

        Ok(())
    }

    #[test]
    fn session_message_ack_should_serialize_and_deserialize() -> anyhow::Result<()> {
        let mut rng = thread_rng();
        let frame_ids: Vec<u32> = (0..500).map(|_| rng.gen()).collect();

        let msg_1 = SessionMessage::<466>::Acknowledge(frame_ids.into());
        let data = Vec::from(msg_1.clone());
        let msg_2 = SessionMessage::try_from(&data[..])?;

        assert_eq!(msg_1, msg_2);

        Ok(())
    }

    #[test]
    fn session_message_segment_request_should_yield_correct_bitset_values() {
        let seg_req = SegmentRequest::<466>([(10, 0b00100100)].into());

        let mut iter = seg_req.into_iter();
        assert_eq!(iter.next(), Some(SegmentId(10, 2)));
        assert_eq!(iter.next(), Some(SegmentId(10, 5)));
        assert_eq!(iter.next(), None);

        let mut frame_info = FrameInfo {
            frame_id: 10,
            missing_segments: bitarr![0; 256],
            total_segments: 10,
            last_update: current_time(),
        };
        frame_info.missing_segments.set(2, true);
        frame_info.missing_segments.set(5, true);

        let mut iter = SegmentRequest::<466>::from_iter(vec![frame_info]).into_iter();
        assert_eq!(iter.next(), Some(SegmentId(10, 2)));
        assert_eq!(iter.next(), Some(SegmentId(10, 5)));
        assert_eq!(iter.next(), None);
    }

    #[test]
    fn session_message_iter_should_be_empty_if_slice_has_no_messages() {
        const MTU: usize = 462;

        let mut iter = SessionMessageIter::<MTU>::from(Vec::<u8>::new());
        assert!(iter.next().is_none());
        assert!(iter.is_done());

        let mut iter = SessionMessageIter::<MTU>::from(&[0u8; 2]);
        assert!(iter.next().is_none());
        assert!(iter.is_done());
    }

    #[test]
    fn session_message_iter_should_deserialize_multiple_messages() -> anyhow::Result<()> {
        const MTU: usize = 462;

        let mut messages_1 = Frame {
            frame_id: 10,
            data: hopr_crypto_random::random_bytes::<1500>().into(),
        }
        .segment(MTU - SessionMessage::<MTU>::HEADER_SIZE - Segment::HEADER_SIZE)?
        .into_iter()
        .map(SessionMessage::<MTU>::Segment)
        .collect::<Vec<_>>();

        let frame_info = FrameInfo {
            frame_id: 10,
            total_segments: 255,
            missing_segments: bitarr![1; 256],
            last_update: SystemTime::now(),
        };

        messages_1.push(SessionMessage::<MTU>::Request(SegmentRequest::from_iter(vec![
            frame_info,
        ])));

        let mut rng = thread_rng();
        let frame_ids: Vec<u32> = (0..100).map(|_| rng.gen()).collect();
        messages_1.push(SessionMessage::<MTU>::Acknowledge(frame_ids.into()));

        let iter = SessionMessageIter::<MTU>::from(
            messages_1
                .iter()
                .cloned()
                .flat_map(|m| m.into_encoded().into_vec())
                .chain(std::iter::repeat(0).take(10))
                .collect::<Vec<u8>>(),
        );

        let messages_2 = iter.collect::<Result<Vec<_>, _>>()?;
        assert_eq!(messages_1, messages_2);

        Ok(())
    }

    #[test]
    fn session_message_iter_should_not_contain_error_when_consuming_everything() -> anyhow::Result<()> {
        const MTU: usize = 462;

        let messages = Frame {
            frame_id: 10,
            data: hopr_crypto_random::random_bytes::<{ 3 * MTU }>().into(),
        }
        .segment(MTU - SessionMessage::<MTU>::HEADER_SIZE - Segment::HEADER_SIZE)?
        .into_iter()
        .map(SessionMessage::<MTU>::Segment)
        .collect::<Vec<_>>();

        assert_eq!(4, messages.len());

        let data = messages
            .iter()
            .cloned()
            .flat_map(|m| m.into_encoded().into_vec())
            .chain(std::iter::repeat(0u8).take(10))
            .collect::<Vec<_>>();

        let mut iter = SessionMessageIter::<MTU>::from(data);
        assert!(matches!(iter.next(), Some(Ok(m)) if m == messages[0]));
        assert!(matches!(iter.next(), Some(Ok(m)) if m == messages[1]));
        assert!(matches!(iter.next(), Some(Ok(m)) if m == messages[2]));
        assert!(matches!(iter.next(), Some(Ok(m)) if m == messages[3]));

        assert!(iter.next().is_none());
        assert!(iter.last_error().is_none());
        assert!(iter.is_done());

        Ok(())
    }

    #[test]
    fn session_message_iter_should_not_yield_more_after_error() -> anyhow::Result<()> {
        const MTU: usize = 462;

        let messages = Frame {
            frame_id: 10,
            data: hopr_crypto_random::random_bytes::<{ 3 * MTU }>().into(),
        }
        .segment(MTU - SessionMessage::<MTU>::HEADER_SIZE - Segment::HEADER_SIZE)?
        .into_iter()
        .map(SessionMessage::<MTU>::Segment)
        .collect::<Vec<_>>();

        assert_eq!(4, messages.len());

        let data = messages
            .iter()
            .cloned()
            .enumerate()
            .flat_map(|(i, m)| {
                if i == 2 {
                    // Replace the third message with random garbage. The version byte is forced
                    // to a wrong value so the decoding error is deterministic: otherwise a random
                    // first byte equal to VERSION would make the outcome depend on the RNG.
                    let mut garbage = Vec::from(hopr_crypto_random::random_bytes::<MTU>());
                    garbage[0] = SessionMessage::<MTU>::VERSION.wrapping_add(1);
                    garbage
                } else {
                    m.into_encoded().into_vec()
                }
            })
            .collect::<Vec<_>>();

        let mut iter = SessionMessageIter::<MTU>::from(data);
        assert!(matches!(iter.next(), Some(Ok(m)) if m == messages[0]));
        assert!(matches!(iter.next(), Some(Ok(m)) if m == messages[1]));

        let err = iter.next();
        assert!(matches!(err, Some(Err(_))));
        assert!(iter.is_done());
        assert!(matches!(iter.last_error(), Some(SessionError::WrongVersion)));

        assert!(iter.next().is_none());

        Ok(())
    }
}