hopr_network_types/session/
protocol.rs

1//! # `Session` protocol messages
2//!
3//! The protocol components are built via low-level types of the `frame` module, such as
4//! [`Segment`] and [`Frame`](crate::session::Frame).
5//! Most importantly, the `Session` protocol fixes the maximum number of segments per frame
6//! to 8 (see [`MAX_SEGMENTS_PER_FRAME`](SessionMessage::MAX_SEGMENTS_PER_FRAME)).
7//! Since each segment must fit within a maximum transmission unit (MTU),
8//! a frame can be at most *eight* times the size of the MTU.
9//!
10//! The [current version](SessionMessage::VERSION) of the protocol consists of three
11//! messages that are sent and received via the underlying transport:
12//! - [`Segment message`](Segment)
13//! - [`Retransmission request`](SegmentRequest)
14//! - [`Frame acknowledgement`](FrameAcknowledgements)
15//!
16//! All of these messages are bundled within the [`SessionMessage`] enum,
17//! which is then [encoded](SessionMessage::into_encoded) as a byte array of a maximum
18//! MTU size `C` (which is a generic const argument of the `SessionMessage` type).
//! The header of the `SessionMessage` encoding consists of the [`version`](SessionMessage::VERSION)
//! byte, followed by the discriminator byte of one of the above messages, then by the
//! two-byte big-endian message length, and finally by the message's encoding itself.
22//!
23//! Multiple [`SessionMessages`](SessionMessage) can be read from a binary blob using the
24//! [`SessionMessageIter`].
25//!
26//! ## Segment message ([`Segment`](SessionMessage::Segment))
27//! The Segment message contains the payload [`Segment`] of some [`Frame`](crate::session::Frame).
28//! The size of this message can range from [`the minimum message size`](SessionMessage::minimum_message_size)
29//! up to `C`.
30//!
31//! ## Retransmission request message ([`Request`](SessionMessage::Request))
32//! Contains a request for retransmission of missing segments in a frame. This is sent from
33//! the segment recipient to the sender, once it realizes some of the received frames are incomplete
34//! (after a certain period of time).
35//!
36//! The encoding of this message consists of pairs of [frame ID](FrameId) and
37//! a single byte bitmap of requested segments in this frame.
38//! Each pair is therefore [`ENTRY_SIZE`](SegmentRequest::ENTRY_SIZE) bytes long.
//! There can be at most [`MAX_ENTRIES`](SegmentRequest::MAX_ENTRIES) entries
//! in a single Retransmission request message, given `C` as the MTU size. If the message contains
41//! fewer entries, it is padded with zeros (0 is not a valid frame ID).
42//! If more frames have missing segments, multiple retransmission request messages need to be sent.
43//!
44//! ## Frame acknowledgement message ([`Acknowledge`](SessionMessage::Acknowledge))
45//! This message is sent from the segment recipient to the segment sender, to acknowledge that
46//! all segments of certain frames have been completely and correctly received by the recipient.
47//!
48//! The message consists simply of a [frame ID](FrameId) list of the completely received
49//! frames. There can be at most [`MAX_ACK_FRAMES`](FrameAcknowledgements::MAX_ACK_FRAMES)
50//! per message. If more frames need to be acknowledged, more messages need to be sent.
51//! If the message contains fewer entries, it is padded with zeros (0 is not a valid frame ID).
52//!
53use std::borrow::Cow;
54use std::collections::{BTreeMap, BTreeSet};
55use std::fmt::{Display, Formatter};
56use std::mem;
57
58use crate::errors::NetworkTypeError;
59use crate::session::errors::SessionError;
60use crate::session::frame::{FrameId, FrameInfo, Segment, SegmentId, SeqNum};
61
62/// Holds the Segment Retransmission Request message.
63/// That is an ordered map of frame IDs and a bitmap of missing segments in each frame.
64/// The bitmap can cover up a request for up to [`SegmentRequest::MAX_ENTRIES`] segments.
65#[derive(Debug, Clone, PartialEq, Eq, Default)]
66pub struct SegmentRequest<const C: usize>(BTreeMap<FrameId, SeqNum>);
67
68impl<const C: usize> SegmentRequest<C> {
69    /// Size of a single segment retransmission request entry.
70    pub const ENTRY_SIZE: usize = mem::size_of::<FrameId>() + mem::size_of::<SeqNum>();
71
72    /// Maximum number of missing segments per frame.
73    pub const MAX_MISSING_SEGMENTS_PER_FRAME: usize = mem::size_of::<SeqNum>() * 8;
74
75    /// Maximum number of segment retransmission entries.
76    pub const MAX_ENTRIES: usize = Self::SIZE / Self::ENTRY_SIZE;
77
78    pub const SIZE: usize = C - SessionMessage::<C>::HEADER_SIZE;
79
80    /// Returns the number of segments to retransmit.
81    pub fn len(&self) -> usize {
82        self.0
83            .values()
84            .take(Self::MAX_ENTRIES)
85            .map(|e| e.count_ones() as usize)
86            .sum()
87    }
88
89    /// Returns true if there are no segments to retransmit in this request.
90    pub fn is_empty(&self) -> bool {
91        self.0.is_empty()
92    }
93}
94
95/// Iterator over [`SegmentId`] in [`SegmentRequest`].
96pub struct SegmentIdIter(Vec<SegmentId>);
97
98impl Iterator for SegmentIdIter {
99    type Item = SegmentId;
100
101    fn next(&mut self) -> Option<Self::Item> {
102        self.0.pop()
103    }
104}
105
106impl<const C: usize> IntoIterator for SegmentRequest<C> {
107    type Item = SegmentId;
108    type IntoIter = SegmentIdIter;
109
110    fn into_iter(self) -> Self::IntoIter {
111        let seq_size = mem::size_of::<SeqNum>() * 8;
112        let mut ret = SegmentIdIter(Vec::with_capacity(seq_size * 8 * self.0.len()));
113        for (frame_id, missing) in self.0 {
114            for i in (0..seq_size).rev() {
115                let mask = (1 << i) as SeqNum;
116                if (mask & missing) != 0 {
117                    ret.0.push(SegmentId(frame_id, i as SeqNum));
118                }
119            }
120        }
121        ret.0.shrink_to_fit();
122        ret
123    }
124}
125
126impl<const C: usize> FromIterator<FrameInfo> for SegmentRequest<C> {
127    fn from_iter<T: IntoIterator<Item = FrameInfo>>(iter: T) -> Self {
128        let mut ret = Self::default();
129        for frame in iter.into_iter().take(Self::MAX_ENTRIES) {
130            let frame_id = frame.frame_id;
131            let missing = frame
132                .iter_missing_sequence_indices()
133                .filter(|s| *s < Self::MAX_MISSING_SEGMENTS_PER_FRAME as SeqNum)
134                .map(|idx| 1 << idx)
135                .fold(SeqNum::default(), |acc, n| acc | n);
136            ret.0.insert(frame_id, missing);
137        }
138        ret
139    }
140}
141
142impl<const C: usize> TryFrom<&[u8]> for SegmentRequest<C> {
143    type Error = SessionError;
144
145    fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
146        if value.len() == Self::SIZE {
147            let mut ret = Self::default();
148            for (frame_id, missing) in value
149                .chunks_exact(Self::ENTRY_SIZE)
150                .map(|c| c.split_at(mem::size_of::<FrameId>()))
151            {
152                let frame_id = FrameId::from_be_bytes(frame_id.try_into().map_err(|_| SessionError::ParseError)?);
153                if frame_id > 0 {
154                    ret.0.insert(
155                        frame_id,
156                        SeqNum::from_be_bytes(missing.try_into().map_err(|_| SessionError::ParseError)?),
157                    );
158                }
159            }
160            Ok(ret)
161        } else {
162            Err(SessionError::ParseError)
163        }
164    }
165}
166
167impl<const C: usize> From<SegmentRequest<C>> for Vec<u8> {
168    fn from(value: SegmentRequest<C>) -> Self {
169        let mut ret = vec![0u8; SegmentRequest::<C>::SIZE];
170        let mut offset = 0;
171        for (frame_id, seq_num) in value.0 {
172            if offset + mem::size_of::<FrameId>() + mem::size_of::<SeqNum>() < C {
173                ret[offset..offset + mem::size_of::<FrameId>()].copy_from_slice(&frame_id.to_be_bytes());
174                offset += mem::size_of::<FrameId>();
175                ret[offset..offset + mem::size_of::<SeqNum>()].copy_from_slice(&seq_num.to_be_bytes());
176                offset += mem::size_of::<SeqNum>();
177            } else {
178                break;
179            }
180        }
181        ret
182    }
183}
184
185/// Holds the Frame Acknowledgement message.
186/// This carries an ordered set of up to [`FrameAcknowledgements::MAX_ACK_FRAMES`] [frame IDs](FrameId) that have
187/// been acknowledged by the counterparty.
188#[derive(Debug, Clone, PartialEq, Eq, Default)]
189pub struct FrameAcknowledgements<const C: usize>(BTreeSet<FrameId>);
190
191impl<const C: usize> FrameAcknowledgements<C> {
192    /// Maximum number of [frame IDs](FrameId) that can be accommodated.
193    pub const MAX_ACK_FRAMES: usize = Self::SIZE / mem::size_of::<FrameId>();
194
195    pub const SIZE: usize = C - SessionMessage::<C>::HEADER_SIZE;
196
197    /// Pushes the frame ID.
198    /// Returns true if the value has been pushed or false it the container is full or already
199    /// contains that value.
200    #[inline]
201    pub fn push(&mut self, frame_id: FrameId) -> bool {
202        !self.is_full() && self.0.insert(frame_id)
203    }
204
205    /// Number of acknowledged frame IDs in this instance.
206    #[inline]
207    pub fn len(&self) -> usize {
208        self.0.len()
209    }
210
211    /// Returns true if there are no frame IDs in this instance.
212    pub fn is_empty(&self) -> bool {
213        self.0.is_empty()
214    }
215
216    /// Indicates whether the [maximum number of frame IDs](FrameAcknowledgements::MAX_ACK_FRAMES)
217    /// has been reached.
218    #[inline]
219    pub fn is_full(&self) -> bool {
220        self.0.len() == Self::MAX_ACK_FRAMES
221    }
222}
223
224impl<const C: usize> From<Vec<FrameId>> for FrameAcknowledgements<C> {
225    fn from(value: Vec<FrameId>) -> Self {
226        Self(
227            value
228                .into_iter()
229                .take(Self::MAX_ACK_FRAMES)
230                .filter(|v| *v > 0)
231                .collect(),
232        )
233    }
234}
235
236impl<const C: usize> IntoIterator for FrameAcknowledgements<C> {
237    type Item = FrameId;
238    type IntoIter = std::collections::btree_set::IntoIter<Self::Item>;
239
240    fn into_iter(self) -> Self::IntoIter {
241        self.0.into_iter()
242    }
243}
244
245impl<'a, const C: usize> TryFrom<&'a [u8]> for FrameAcknowledgements<C> {
246    type Error = SessionError;
247
248    fn try_from(value: &'a [u8]) -> Result<Self, Self::Error> {
249        if value.len() == Self::SIZE {
250            Ok(Self(
251                // chunks_exact discards the remainder bytes
252                value
253                    .chunks_exact(mem::size_of::<FrameId>())
254                    .map(|v| FrameId::from_be_bytes(v.try_into().unwrap()))
255                    .filter(|f| *f > 0)
256                    .collect(),
257            ))
258        } else {
259            Err(SessionError::ParseError)
260        }
261    }
262}
263
264impl<const C: usize> From<FrameAcknowledgements<C>> for Vec<u8> {
265    fn from(value: FrameAcknowledgements<C>) -> Self {
266        value
267            .0
268            .iter()
269            .flat_map(|v| v.to_be_bytes())
270            .chain(std::iter::repeat(0_u8))
271            .take(FrameAcknowledgements::<C>::SIZE)
272            .collect::<Vec<_>>()
273    }
274}
275
/// Contains all messages of the Session sub-protocol.
///
/// The maximum size of the Session sub-protocol message is given by `C`.
///
/// The derived `SessionMessageDiscriminants` enum (`repr(u8)`) supplies the discriminator byte
/// written into the encoding. The variants carry no explicit discriminants, so their values
/// follow declaration order. Reordering the variants therefore changes the wire format.
#[derive(Debug, Clone, PartialEq, Eq, strum::EnumDiscriminants, strum::EnumTryAs)]
#[strum_discriminants(derive(strum::FromRepr), repr(u8))]
pub enum SessionMessage<const C: usize> {
    /// Represents a message containing a [segment](Segment) of a frame.
    Segment(Segment),
    /// Represents a message containing a [request](SegmentRequest) for segments.
    Request(SegmentRequest<C>),
    /// Represents a message containing [frame acknowledgements](FrameAcknowledgements).
    Acknowledge(FrameAcknowledgements<C>),
}
289
290impl<const C: usize> Display for SessionMessage<C> {
291    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
292        match &self {
293            SessionMessage::Segment(s) => write!(f, "segment {}", s.id()),
294            SessionMessage::Request(r) => write!(f, "retransmission request of {:?}", r.0),
295            SessionMessage::Acknowledge(a) => write!(f, "acknowledgement of {:?}", a.0),
296        }
297    }
298}
299
300impl<const C: usize> SessionMessage<C> {
301    /// Header size of the session message.
302    /// This is currently the version byte, the size of [SessionMessageDiscriminants] representation
303    /// and two bytes for the message length.
304    pub const HEADER_SIZE: usize = 1 + mem::size_of::<SessionMessageDiscriminants>() + mem::size_of::<u16>();
305
306    /// Size of the overhead that's added to the raw payload of each [`Segment`].
307    ///
308    /// This amounts to [`SessionMessage::HEADER_SIZE`] + [`Segment::HEADER_SIZE`].
309    pub const SEGMENT_OVERHEAD: usize = Self::HEADER_SIZE + Segment::HEADER_SIZE;
310
311    /// Maximum size of the Session protocol message.
312    ///
313    /// This is equal to the typical Ethernet MTU size minus [`Self::SEGMENT_OVERHEAD`].
314    pub const MAX_MESSAGE_SIZE: usize = 1492 - Self::SEGMENT_OVERHEAD;
315
316    /// Current version of the protocol.
317    pub const VERSION: u8 = 1;
318
319    /// Maximum number of segments per frame.
320    pub const MAX_SEGMENTS_PER_FRAME: usize = SegmentRequest::<C>::MAX_MISSING_SEGMENTS_PER_FRAME;
321
322    /// Returns the minimum size of a [SessionMessage].
323    pub fn minimum_message_size() -> usize {
324        // Make this a "const fn" once "min" is const fn too
325        Self::HEADER_SIZE
326            + Segment::MINIMUM_SIZE
327                .min(SegmentRequest::<C>::SIZE)
328                .min(FrameAcknowledgements::<C>::SIZE)
329    }
330
331    /// Convenience method to encode the session message.
332    pub fn into_encoded(self) -> Box<[u8]> {
333        Vec::from(self).into_boxed_slice()
334    }
335}
336
337impl<const C: usize> TryFrom<&[u8]> for SessionMessage<C> {
338    type Error = SessionError;
339
340    fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
341        SessionMessageIter::from(value).try_next()
342    }
343}
344
345impl<const C: usize> From<SessionMessage<C>> for Vec<u8> {
346    fn from(value: SessionMessage<C>) -> Self {
347        let disc = SessionMessageDiscriminants::from(&value) as u8;
348
349        let msg = match value {
350            SessionMessage::Segment(s) => Vec::from(s),
351            SessionMessage::Request(r) => Vec::from(r),
352            SessionMessage::Acknowledge(a) => Vec::from(a),
353        };
354
355        let msg_len = msg.len() as u16;
356
357        let mut ret = Vec::with_capacity(SessionMessage::<C>::HEADER_SIZE + msg_len as usize);
358        ret.push(SessionMessage::<C>::VERSION);
359        ret.push(disc);
360        ret.extend(msg_len.to_be_bytes());
361        ret.extend(msg);
362        ret
363    }
364}
365
/// Allows parsing of multiple [`SessionMessages`](SessionMessage)
/// from a borrowed or an owned binary chunk.
///
/// The iterator will yield [`SessionMessages`](SessionMessage) until all the messages from
/// the underlying data chunk are completely parsed or an error occurs.
///
/// In other words, it keeps yielding `Some(Ok(_))` until it yields either `None`
/// or `Some(Err(_))` immediately followed by `None`.
///
/// Trailing bytes shorter than [`SessionMessage::minimum_message_size`] are not parsed.
///
/// This iterator is [fused](std::iter::FusedIterator).
#[derive(Debug, Clone)]
pub struct SessionMessageIter<'a, const C: usize> {
    /// The chunk being parsed, either borrowed or owned.
    data: Cow<'a, [u8]>,
    /// Position in `data` where the next message starts; advanced only after a successful parse.
    offset: usize,
    /// The error that stopped the iteration, if any.
    last_err: Option<SessionError>,
}
382
383impl<const C: usize> SessionMessageIter<'_, C> {
384    /// Determines if there was an error reading the last message.
385    ///
386    /// If this function returns some error value, the iterator will not
387    /// yield any more messages.
388    pub fn last_error(&self) -> Option<&SessionError> {
389        self.last_err.as_ref()
390    }
391
392    /// Check if this iterator can yield any more messages.
393    ///
394    /// Returns `true` only if a [prior error](SessionMessageIter::last_error) occurred or all useful bytes
395    /// from the underlying chunk were consumed and all messages were parsed.
396    pub fn is_done(&self) -> bool {
397        self.last_err.is_some() || self.data.len() - self.offset < SessionMessage::<C>::minimum_message_size()
398    }
399
400    /// Attempts to parse the current message and moves the offset if successful.
401    fn try_next(&mut self) -> Result<SessionMessage<C>, SessionError> {
402        let mut offset = self.offset;
403
404        // Protocol version
405        if self.data[offset] != SessionMessage::<C>::VERSION {
406            return Err(SessionError::WrongVersion);
407        }
408        offset += 1;
409
410        // Message discriminant
411        let disc = self.data[offset];
412        offset += 1;
413
414        // Message length
415        let len = u16::from_be_bytes(
416            self.data[offset..offset + mem::size_of::<u16>()]
417                .try_into()
418                .map_err(|_| SessionError::IncorrectMessageLength)?,
419        ) as usize;
420        offset += mem::size_of::<u16>();
421
422        if len > SessionMessage::<C>::MAX_MESSAGE_SIZE {
423            return Err(SessionError::IncorrectMessageLength);
424        }
425
426        // The upper 6 bits of the size are reserved for future use,
427        // since MAX_MESSAGE_SIZE always fits within 10 bits (<= MAX_MESSAGE_SIZE = 1500)
428        let reserved = len & 0b111111_0000000000;
429
430        // In version 1 check that the reserved bits are all 0
431        if reserved != 0 {
432            return Err(SessionError::ParseError);
433        }
434
435        // Read the message
436        let res = match SessionMessageDiscriminants::from_repr(disc).ok_or(SessionError::UnknownMessageTag)? {
437            SessionMessageDiscriminants::Segment => {
438                SessionMessage::Segment(self.data[offset..offset + len].try_into()?)
439            }
440            SessionMessageDiscriminants::Request => {
441                SessionMessage::Request(self.data[offset..offset + len].try_into()?)
442            }
443            SessionMessageDiscriminants::Acknowledge => {
444                SessionMessage::Acknowledge(self.data[offset..offset + len].try_into()?)
445            }
446        };
447
448        // Move the internal offset only once the message has been fully parsed
449        self.offset = offset + len;
450        Ok(res)
451    }
452}
453
454impl<'a, const C: usize, T: Into<Cow<'a, [u8]>>> From<T> for SessionMessageIter<'a, C> {
455    fn from(value: T) -> Self {
456        Self {
457            data: value.into(),
458            offset: 0,
459            last_err: None,
460        }
461    }
462}
463
464impl<const C: usize> Iterator for SessionMessageIter<'_, C> {
465    type Item = Result<SessionMessage<C>, NetworkTypeError>;
466
467    fn next(&mut self) -> Option<Self::Item> {
468        if !self.is_done() {
469            self.try_next()
470                .inspect_err(|e| self.last_err = Some(e.clone()))
471                .map_err(NetworkTypeError::SessionProtocolError)
472                .into()
473        } else {
474            None
475        }
476    }
477}
478
479impl<const C: usize> std::iter::FusedIterator for SessionMessageIter<'_, C> {}
480
#[cfg(test)]
mod tests {
    use super::*;
    use crate::session::Frame;
    use bitvec::array::BitArray;
    use bitvec::bitarr;
    use hex_literal::hex;
    use hopr_platform::time::native::current_time;
    use rand::prelude::IteratorRandom;
    use rand::{thread_rng, Rng};
    use std::time::SystemTime;

    /// Pins the protocol constants of version 1.
    #[test]
    fn ensure_session_protocol_version_1_values() {
        // All of these values are independent of C, so we can set C = 0
        assert_eq!(1, SessionMessage::<0>::VERSION);
        assert_eq!(4, SessionMessage::<0>::HEADER_SIZE);
        assert_eq!(10, SessionMessage::<0>::SEGMENT_OVERHEAD);
        assert_eq!(8, SessionMessage::<0>::MAX_SEGMENTS_PER_FRAME);

        // i.e. the maximum message length fits into 11 bits
        assert!(SessionMessage::<0>::MAX_MESSAGE_SIZE < 2048);
    }

    #[test]
    fn segment_request_should_be_constructible_from_frame_info() {
        // 19 frames (IDs 1..=19), each missing 4 distinct segments picked at random from indices 0..=6
        let frames = (1..20)
            .map(|i| {
                let mut missing_segments = BitArray::ZERO;
                (0..7_usize)
                    .choose_multiple(&mut thread_rng(), 4)
                    .into_iter()
                    .for_each(|i| missing_segments.set(i, true));
                FrameInfo {
                    frame_id: i,
                    missing_segments,
                    total_segments: 8,
                    last_update: SystemTime::UNIX_EPOCH,
                }
            })
            .collect::<Vec<_>>();

        // The request's iteration order is not globally ascending, so sort before comparing.
        let mut req = SegmentRequest::<466>::from_iter(frames.clone())
            .into_iter()
            .collect::<Vec<_>>();
        req.sort();

        assert_eq!(frames.len() * 4, req.len());
        // NOTE(review): this relies on `into_missing_segments` yielding segment IDs in ascending order.
        assert_eq!(
            req,
            frames
                .into_iter()
                .flat_map(|f| f.into_missing_segments())
                .collect::<Vec<_>>()
        );
    }

    #[test]
    fn session_message_segment_should_serialize_and_deserialize() -> anyhow::Result<()> {
        const SEG_SIZE: usize = 8;

        let mut segments = Frame {
            frame_id: 10,
            data: hex!("deadbeefcafebabe").into(),
        }
        .segment(SEG_SIZE)?;

        // NOTE(review): this MTU is smaller than HEADER_SIZE + Segment::HEADER_SIZE + SEG_SIZE.
        // The round-trip still works because `C` is not enforced when encoding/decoding segments.
        const MTU: usize = SEG_SIZE + Segment::HEADER_SIZE + 2;

        let msg_1 = SessionMessage::<MTU>::Segment(segments.pop().unwrap());
        let data = Vec::from(msg_1.clone());
        let msg_2 = SessionMessage::try_from(&data[..])?;

        assert_eq!(msg_1, msg_2);

        Ok(())
    }

    #[test]
    fn session_message_segment_request_should_serialize_and_deserialize() -> anyhow::Result<()> {
        // All bits are set, but only the first MAX_MISSING_SEGMENTS_PER_FRAME (8) segments
        // of a frame can be requested.
        let frame_info = FrameInfo {
            frame_id: 10,
            total_segments: 255,
            missing_segments: bitarr![1; 256],
            last_update: SystemTime::now(),
        };

        let msg_1 = SessionMessage::<466>::Request(SegmentRequest::from_iter(vec![frame_info]));
        let data = Vec::from(msg_1.clone());
        let msg_2 = SessionMessage::try_from(&data[..])?;

        assert_eq!(msg_1, msg_2);

        match msg_1 {
            SessionMessage::Request(r) => {
                let missing_segments = r.into_iter().collect::<Vec<_>>();
                let expected = (0..=7).map(|s| SegmentId(10, s)).collect::<Vec<_>>();
                assert_eq!(expected, missing_segments);
            }
            _ => panic!("invalid type"),
        }

        Ok(())
    }

    #[test]
    fn session_message_ack_should_serialize_and_deserialize() -> anyhow::Result<()> {
        // More IDs than fit into a single message; the conversion keeps at most MAX_ACK_FRAMES of them.
        let mut rng = thread_rng();
        let frame_ids: Vec<u32> = (0..500).map(|_| rng.gen()).collect();

        let msg_1 = SessionMessage::<466>::Acknowledge(frame_ids.into());
        let data = Vec::from(msg_1.clone());
        let msg_2 = SessionMessage::try_from(&data[..])?;

        assert_eq!(msg_1, msg_2);

        Ok(())
    }

    #[test]
    fn session_message_segment_request_should_yield_correct_bitset_values() {
        // Bits 2 and 5 are set
        let seg_req = SegmentRequest::<466>([(10, 0b00100100)].into());

        let mut iter = seg_req.into_iter();
        assert_eq!(iter.next(), Some(SegmentId(10, 2)));
        assert_eq!(iter.next(), Some(SegmentId(10, 5)));
        assert_eq!(iter.next(), None);

        // The same request built from a FrameInfo must yield the same segments
        let mut frame_info = FrameInfo {
            frame_id: 10,
            missing_segments: bitarr![0; 256],
            total_segments: 10,
            last_update: current_time(),
        };
        frame_info.missing_segments.set(2, true);
        frame_info.missing_segments.set(5, true);

        let mut iter = SegmentRequest::<466>::from_iter(vec![frame_info]).into_iter();
        assert_eq!(iter.next(), Some(SegmentId(10, 2)));
        assert_eq!(iter.next(), Some(SegmentId(10, 5)));
        assert_eq!(iter.next(), None);
    }

    #[test]
    fn session_message_iter_should_be_empty_if_slice_has_no_messages() {
        const MTU: usize = 462;

        // Neither an empty chunk nor a 2-byte chunk can hold a whole message header
        let mut iter = SessionMessageIter::<MTU>::from(Vec::<u8>::new());
        assert!(iter.next().is_none());
        assert!(iter.is_done());

        let mut iter = SessionMessageIter::<MTU>::from(&[0u8; 2]);
        assert!(iter.next().is_none());
        assert!(iter.is_done());
    }

    #[test]
    fn session_message_iter_should_deserialize_multiple_messages() -> anyhow::Result<()> {
        const MTU: usize = 462;

        // Segments of a 1500-byte frame, each sized so that its message fills the MTU
        let mut messages_1 = Frame {
            frame_id: 10,
            data: hopr_crypto_random::random_bytes::<1500>().into(),
        }
        .segment(MTU - SessionMessage::<MTU>::HEADER_SIZE - Segment::HEADER_SIZE)?
        .into_iter()
        .map(SessionMessage::<MTU>::Segment)
        .collect::<Vec<_>>();

        let frame_info = FrameInfo {
            frame_id: 10,
            total_segments: 255,
            missing_segments: bitarr![1; 256],
            last_update: SystemTime::now(),
        };

        messages_1.push(SessionMessage::<MTU>::Request(SegmentRequest::from_iter(vec![
            frame_info,
        ])));

        let mut rng = thread_rng();
        let frame_ids: Vec<u32> = (0..100).map(|_| rng.gen()).collect();
        messages_1.push(SessionMessage::<MTU>::Acknowledge(frame_ids.into()));

        // Trailing zero padding, presumably shorter than `minimum_message_size()`,
        // so the iterator stops without reporting an error.
        let iter = SessionMessageIter::<MTU>::from(
            messages_1
                .iter()
                .cloned()
                .flat_map(|m| m.into_encoded().into_vec())
                .chain(std::iter::repeat(0).take(10))
                .collect::<Vec<u8>>(),
        );

        let messages_2 = iter.collect::<Result<Vec<_>, _>>()?;
        assert_eq!(messages_1, messages_2);

        Ok(())
    }

    #[test]
    fn session_message_iter_should_not_contain_error_when_consuming_everything() -> anyhow::Result<()> {
        const MTU: usize = 462;

        // 3 * MTU bytes of payload do not fit into 3 segments (each carries less than MTU bytes),
        // hence 4 messages.
        let messages = Frame {
            frame_id: 10,
            data: hopr_crypto_random::random_bytes::<{ 3 * MTU }>().into(),
        }
        .segment(MTU - SessionMessage::<MTU>::HEADER_SIZE - Segment::HEADER_SIZE)?
        .into_iter()
        .map(SessionMessage::<MTU>::Segment)
        .collect::<Vec<_>>();

        assert_eq!(4, messages.len());

        let data = messages
            .iter()
            .cloned()
            .flat_map(|m| m.into_encoded().into_vec())
            .chain(std::iter::repeat(0u8).take(10))
            .collect::<Vec<_>>();

        let mut iter = SessionMessageIter::<MTU>::from(data);
        assert!(matches!(iter.next(), Some(Ok(m)) if m == messages[0]));
        assert!(matches!(iter.next(), Some(Ok(m)) if m == messages[1]));
        assert!(matches!(iter.next(), Some(Ok(m)) if m == messages[2]));
        assert!(matches!(iter.next(), Some(Ok(m)) if m == messages[3]));

        assert!(iter.next().is_none());
        assert!(iter.last_error().is_none());
        assert!(iter.is_done());

        Ok(())
    }

    #[test]
    fn session_message_iter_should_not_yield_more_after_error() -> anyhow::Result<()> {
        const MTU: usize = 462;

        let messages = Frame {
            frame_id: 10,
            data: hopr_crypto_random::random_bytes::<{ 3 * MTU }>().into(),
        }
        .segment(MTU - SessionMessage::<MTU>::HEADER_SIZE - Segment::HEADER_SIZE)?
        .into_iter()
        .map(SessionMessage::<MTU>::Segment)
        .collect::<Vec<_>>();

        assert_eq!(4, messages.len());

        // Replace the third message with MTU random bytes, which are expected
        // (with overwhelming probability) to fail parsing.
        let data = messages
            .iter()
            .cloned()
            .enumerate()
            .flat_map(|(i, m)| {
                if i == 2 {
                    Vec::from(hopr_crypto_random::random_bytes::<MTU>())
                } else {
                    m.into_encoded().into_vec()
                }
            })
            .collect::<Vec<_>>();

        let mut iter = SessionMessageIter::<MTU>::from(data);
        assert!(matches!(iter.next(), Some(Ok(m)) if m == messages[0]));
        assert!(matches!(iter.next(), Some(Ok(m)) if m == messages[1]));

        let err = iter.next();
        assert!(matches!(err, Some(Err(_))));
        assert!(iter.is_done());
        assert!(iter.last_error().is_some());

        assert!(iter.next().is_none());

        Ok(())
    }
}