hopr_network_types/session/
protocol.rs

1//! # `Session` protocol messages
2//!
3//! The protocol components are built via low-level types of the `frame` module, such as
4//! [`Segment`] and [`Frame`](crate::session::Frame).
5//! Most importantly, the `Session` protocol fixes the maximum number of segments per frame
6//! to 8 (see [`MAX_SEGMENTS_PER_FRAME`](SessionMessage::MAX_SEGMENTS_PER_FRAME)).
7//! Since each segment must fit within a maximum transmission unit (MTU),
8//! a frame can be at most *eight* times the size of the MTU.
9//!
10//! The [current version](SessionMessage::VERSION) of the protocol consists of three
11//! messages that are sent and received via the underlying transport:
12//! - [`Segment message`](Segment)
13//! - [`Retransmission request`](SegmentRequest)
14//! - [`Frame acknowledgement`](FrameAcknowledgements)
15//!
16//! All of these messages are bundled within the [`SessionMessage`] enum,
17//! which is then [encoded](SessionMessage::into_encoded) as a byte array of a maximum
18//! MTU size `C` (which is a generic const argument of the `SessionMessage` type).
19//! The header of the `SessionMessage` encoding consists of the [`version`](SessionMessage::VERSION)
//! byte, followed by the discriminator byte of one of the above messages, then by
//! the message length (a big-endian `u16`) and finally the message's encoding itself.
22//!
23//! Multiple [`SessionMessages`](SessionMessage) can be read from a binary blob using the
24//! [`SessionMessageIter`].
25//!
26//! ## Segment message ([`Segment`](SessionMessage::Segment))
27//! The Segment message contains the payload [`Segment`] of some [`Frame`](crate::session::Frame).
28//! The size of this message can range from [`the minimum message size`](SessionMessage::minimum_message_size)
29//! up to `C`.
30//!
31//! ## Retransmission request message ([`Request`](SessionMessage::Request))
32//! Contains a request for retransmission of missing segments in a frame. This is sent from
33//! the segment recipient to the sender, once it realizes some of the received frames are incomplete
34//! (after a certain period of time).
35//!
36//! The encoding of this message consists of pairs of [frame ID](FrameId) and
37//! a single byte bitmap of requested segments in this frame.
38//! Each pair is therefore [`ENTRY_SIZE`](SegmentRequest::ENTRY_SIZE) bytes long.
//! There can be at most [`MAX_ENTRIES`](SegmentRequest::MAX_ENTRIES) entries
//! in a single Retransmission request message, given `C` as the MTU size. If the message contains
41//! fewer entries, it is padded with zeros (0 is not a valid frame ID).
42//! If more frames have missing segments, multiple retransmission request messages need to be sent.
43//!
44//! ## Frame acknowledgement message ([`Acknowledge`](SessionMessage::Acknowledge))
45//! This message is sent from the segment recipient to the segment sender, to acknowledge that
46//! all segments of certain frames have been completely and correctly received by the recipient.
47//!
48//! The message consists simply of a [frame ID](FrameId) list of the completely received
49//! frames. There can be at most [`MAX_ACK_FRAMES`](FrameAcknowledgements::MAX_ACK_FRAMES)
50//! per message. If more frames need to be acknowledged, more messages need to be sent.
51//! If the message contains fewer entries, it is padded with zeros (0 is not a valid frame ID).
52use std::{
53    borrow::Cow,
54    collections::{BTreeMap, BTreeSet},
55    fmt::{Display, Formatter},
56    mem,
57};
58
59use bitvec::field::BitField;
60
61use crate::{
62    errors::NetworkTypeError,
63    session::{
64        errors::SessionError,
65        frame::{FrameId, FrameInfo, MissingSegmentsBitmap, Segment, SegmentId, SeqNum},
66    },
67};
68
69/// Holds the Segment Retransmission Request message.
70/// That is an ordered map of frame IDs and a bitmap of missing segments in each frame.
71/// The bitmap can cover up a request for up to [`SegmentRequest::MAX_ENTRIES`] segments.
72#[derive(Debug, Clone, PartialEq, Eq, Default)]
73pub struct SegmentRequest<const C: usize>(BTreeMap<FrameId, SeqNum>);
74
75impl<const C: usize> SegmentRequest<C> {
76    /// Size of a single segment retransmission request entry.
77    pub const ENTRY_SIZE: usize = mem::size_of::<FrameId>() + mem::size_of::<SeqNum>();
78    /// Maximum number of segment retransmission entries.
79    pub const MAX_ENTRIES: usize = Self::SIZE / Self::ENTRY_SIZE;
80    /// Maximum number of missing segments per frame.
81    pub const MAX_MISSING_SEGMENTS_PER_FRAME: usize = SeqNum::BITS as usize;
82    pub const SIZE: usize = C - SessionMessage::<C>::HEADER_SIZE;
83
84    /// Returns the number of segments to retransmit.
85    pub fn len(&self) -> usize {
86        self.0
87            .values()
88            .take(Self::MAX_ENTRIES)
89            .map(|e| e.count_ones() as usize)
90            .sum()
91    }
92
93    /// Returns true if there are no segments to retransmit in this request.
94    pub fn is_empty(&self) -> bool {
95        self.0.is_empty()
96    }
97}
98
99impl<const C: usize> IntoIterator for SegmentRequest<C> {
100    type IntoIter = std::vec::IntoIter<SegmentId>;
101    type Item = SegmentId;
102
103    fn into_iter(self) -> Self::IntoIter {
104        let seq_size = SeqNum::BITS as usize;
105        let mut ret = Vec::with_capacity(seq_size * self.0.len());
106        for (frame_id, missing) in self.0 {
107            ret.extend(
108                MissingSegmentsBitmap::from([missing])
109                    .iter_ones()
110                    .map(|i| SegmentId(frame_id, i as SeqNum)),
111            );
112        }
113        ret.into_iter()
114    }
115}
116
117impl<const C: usize> FromIterator<FrameInfo> for SegmentRequest<C> {
118    fn from_iter<T: IntoIterator<Item = FrameInfo>>(iter: T) -> Self {
119        let mut ret = Self::default();
120        for frame in iter.into_iter().take(Self::MAX_ENTRIES) {
121            let frame_id = frame.frame_id;
122            ret.0.insert(frame_id, frame.missing_segments.load());
123        }
124        ret
125    }
126}
127
128impl<const C: usize> TryFrom<&[u8]> for SegmentRequest<C> {
129    type Error = SessionError;
130
131    fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
132        if value.len() == Self::SIZE {
133            let mut ret = Self::default();
134            for (frame_id, missing) in value
135                .chunks_exact(Self::ENTRY_SIZE)
136                .map(|c| c.split_at(mem::size_of::<FrameId>()))
137            {
138                let frame_id = FrameId::from_be_bytes(frame_id.try_into().map_err(|_| SessionError::ParseError)?);
139                if frame_id > 0 {
140                    ret.0.insert(
141                        frame_id,
142                        SeqNum::from_be_bytes(missing.try_into().map_err(|_| SessionError::ParseError)?),
143                    );
144                }
145            }
146            Ok(ret)
147        } else {
148            Err(SessionError::ParseError)
149        }
150    }
151}
152
153impl<const C: usize> From<SegmentRequest<C>> for Vec<u8> {
154    fn from(value: SegmentRequest<C>) -> Self {
155        let mut ret = vec![0u8; SegmentRequest::<C>::SIZE];
156        let mut offset = 0;
157        for (frame_id, seq_num) in value.0 {
158            if offset + mem::size_of::<FrameId>() + mem::size_of::<SeqNum>() < C {
159                ret[offset..offset + mem::size_of::<FrameId>()].copy_from_slice(&frame_id.to_be_bytes());
160                offset += mem::size_of::<FrameId>();
161                ret[offset..offset + mem::size_of::<SeqNum>()].copy_from_slice(&seq_num.to_be_bytes());
162                offset += mem::size_of::<SeqNum>();
163            } else {
164                break;
165            }
166        }
167        ret
168    }
169}
170
171/// Holds the Frame Acknowledgement message.
172/// This carries an ordered set of up to [`FrameAcknowledgements::MAX_ACK_FRAMES`] [frame IDs](FrameId) that has
173/// been acknowledged by the counterparty.
174#[derive(Debug, Clone, PartialEq, Eq, Default)]
175pub struct FrameAcknowledgements<const C: usize>(BTreeSet<FrameId>);
176
177impl<const C: usize> FrameAcknowledgements<C> {
178    /// Maximum number of [frame IDs](FrameId) that can be accommodated.
179    pub const MAX_ACK_FRAMES: usize = Self::SIZE / mem::size_of::<FrameId>();
180    pub const SIZE: usize = C - SessionMessage::<C>::HEADER_SIZE;
181
182    /// Pushes the frame ID.
183    /// Returns true if the value has been pushed or false it the container is full or already
184    /// contains that value.
185    #[inline]
186    pub fn push(&mut self, frame_id: FrameId) -> bool {
187        !self.is_full() && self.0.insert(frame_id)
188    }
189
190    /// Number of acknowledged frame IDs in this instance.
191    #[inline]
192    pub fn len(&self) -> usize {
193        self.0.len()
194    }
195
196    /// Returns true if there are no frame IDs in this instance.
197    pub fn is_empty(&self) -> bool {
198        self.0.is_empty()
199    }
200
201    /// Indicates whether the [maximum number of frame IDs](FrameAcknowledgements::MAX_ACK_FRAMES)
202    /// has been reached.
203    #[inline]
204    pub fn is_full(&self) -> bool {
205        self.0.len() == Self::MAX_ACK_FRAMES
206    }
207}
208
209impl<const C: usize> From<Vec<FrameId>> for FrameAcknowledgements<C> {
210    fn from(value: Vec<FrameId>) -> Self {
211        Self(
212            value
213                .into_iter()
214                .take(Self::MAX_ACK_FRAMES)
215                .filter(|v| *v > 0)
216                .collect(),
217        )
218    }
219}
220
221impl<const C: usize> IntoIterator for FrameAcknowledgements<C> {
222    type IntoIter = std::collections::btree_set::IntoIter<Self::Item>;
223    type Item = FrameId;
224
225    fn into_iter(self) -> Self::IntoIter {
226        self.0.into_iter()
227    }
228}
229
230impl<'a, const C: usize> TryFrom<&'a [u8]> for FrameAcknowledgements<C> {
231    type Error = SessionError;
232
233    fn try_from(value: &'a [u8]) -> Result<Self, Self::Error> {
234        if value.len() == Self::SIZE {
235            Ok(Self(
236                // chunks_exact discards the remainder bytes
237                value
238                    .chunks_exact(mem::size_of::<FrameId>())
239                    .map(|v| FrameId::from_be_bytes(v.try_into().unwrap()))
240                    .filter(|f| *f > 0)
241                    .collect(),
242            ))
243        } else {
244            Err(SessionError::ParseError)
245        }
246    }
247}
248
249impl<const C: usize> From<FrameAcknowledgements<C>> for Vec<u8> {
250    fn from(value: FrameAcknowledgements<C>) -> Self {
251        value
252            .0
253            .iter()
254            .flat_map(|v| v.to_be_bytes())
255            .chain(std::iter::repeat(0_u8))
256            .take(FrameAcknowledgements::<C>::SIZE)
257            .collect::<Vec<_>>()
258    }
259}
260
/// Contains all messages of the Session sub-protocol.
///
/// The maximum size of the Session sub-protocol message is given by `C`.
///
/// The encoder writes the `u8` representation of each variant's [`SessionMessageDiscriminants`]
/// value as the message tag byte. Reordering the variants therefore changes the wire format.
#[derive(Debug, Clone, PartialEq, Eq, strum::EnumDiscriminants, strum::EnumTryAs)]
#[strum_discriminants(derive(strum::FromRepr), repr(u8))]
pub enum SessionMessage<const C: usize> {
    /// Represents a message containing a segment.
    Segment(Segment),
    /// Represents a message containing a [request](SegmentRequest) for segments.
    Request(SegmentRequest<C>),
    /// Represents a message containing [frame acknowledgements](FrameAcknowledgements).
    Acknowledge(FrameAcknowledgements<C>),
}
274
275impl<const C: usize> Display for SessionMessage<C> {
276    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
277        match &self {
278            SessionMessage::Segment(s) => write!(f, "segment {}", s.id()),
279            SessionMessage::Request(r) => write!(f, "retransmission request of {:?}", r.0),
280            SessionMessage::Acknowledge(a) => write!(f, "acknowledgement of {:?}", a.0),
281        }
282    }
283}
284
285impl<const C: usize> SessionMessage<C> {
286    /// Header size of the session message.
287    /// This is currently the version byte, the size of [SessionMessageDiscriminants] representation
288    /// and two bytes for the message length.
289    pub const HEADER_SIZE: usize = 1 + mem::size_of::<SessionMessageDiscriminants>() + mem::size_of::<u16>();
290    /// Maximum size of the Session protocol message.
291    ///
292    /// This is equal to the typical Ethernet MTU size minus [`Self::SEGMENT_OVERHEAD`].
293    pub const MAX_MESSAGE_SIZE: usize = 1492 - Self::SEGMENT_OVERHEAD;
294    /// Maximum number of segments per frame.
295    pub const MAX_SEGMENTS_PER_FRAME: usize = SegmentRequest::<C>::MAX_MISSING_SEGMENTS_PER_FRAME;
296    /// Size of the overhead that's added to the raw payload of each [`Segment`].
297    ///
298    /// This amounts to [`SessionMessage::HEADER_SIZE`] + [`Segment::HEADER_SIZE`].
299    pub const SEGMENT_OVERHEAD: usize = Self::HEADER_SIZE + Segment::HEADER_SIZE;
300    /// Current version of the protocol.
301    pub const VERSION: u8 = 1;
302
303    /// Returns the minimum size of a [SessionMessage].
304    pub fn minimum_message_size() -> usize {
305        // Make this a "const fn" once "min" is const fn too
306        Self::HEADER_SIZE
307            + Segment::MINIMUM_SIZE
308                .min(SegmentRequest::<C>::SIZE)
309                .min(FrameAcknowledgements::<C>::SIZE)
310    }
311
312    /// Convenience method to encode the session message.
313    pub fn into_encoded(self) -> Box<[u8]> {
314        Vec::from(self).into_boxed_slice()
315    }
316}
317
318impl<const C: usize> TryFrom<&[u8]> for SessionMessage<C> {
319    type Error = SessionError;
320
321    fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
322        SessionMessageIter::from(value).try_next()
323    }
324}
325
326impl<const C: usize> From<SessionMessage<C>> for Vec<u8> {
327    fn from(value: SessionMessage<C>) -> Self {
328        let disc = SessionMessageDiscriminants::from(&value) as u8;
329
330        let msg = match value {
331            SessionMessage::Segment(s) => Vec::from(s),
332            SessionMessage::Request(r) => Vec::from(r),
333            SessionMessage::Acknowledge(a) => Vec::from(a),
334        };
335
336        let msg_len = msg.len() as u16;
337
338        let mut ret = Vec::with_capacity(SessionMessage::<C>::HEADER_SIZE + msg_len as usize);
339        ret.push(SessionMessage::<C>::VERSION);
340        ret.push(disc);
341        ret.extend(msg_len.to_be_bytes());
342        ret.extend(msg);
343        ret
344    }
345}
346
/// Allows parsing of multiple [`SessionMessages`](SessionMessage)
/// from a borrowed or an owned binary chunk.
///
/// The iterator will yield [`SessionMessages`](SessionMessage) until all the messages from
/// the underlying data chunk are completely parsed or an error occurs.
/// Trailing bytes shorter than [`SessionMessage::minimum_message_size`] are not parsed.
///
/// In other words, it keeps yielding `Some(Ok(_))` until it yields either `None`
/// or `Some(Err(_))` immediately followed by `None`.
///
/// This iterator is [fused](std::iter::FusedIterator).
#[derive(Debug, Clone)]
pub struct SessionMessageIter<'a, const C: usize> {
    /// The borrowed or owned bytes being parsed.
    data: Cow<'a, [u8]>,
    /// Position in `data` where the next message starts; advanced only after a message parses.
    offset: usize,
    /// The parse error that stopped the iterator, if any; once set, no more messages are yielded.
    last_err: Option<SessionError>,
}
363
364impl<const C: usize> SessionMessageIter<'_, C> {
365    /// Determines if there was an error reading the last message.
366    ///
367    /// If this function returns some error value, the iterator will not
368    /// yield any more messages.
369    pub fn last_error(&self) -> Option<&SessionError> {
370        self.last_err.as_ref()
371    }
372
373    /// Check if this iterator can yield any more messages.
374    ///
375    /// Returns `true` only if a [prior error](SessionMessageIter::last_error) occurred or all useful bytes
376    /// from the underlying chunk were consumed and all messages were parsed.
377    pub fn is_done(&self) -> bool {
378        self.last_err.is_some() || self.data.len() - self.offset < SessionMessage::<C>::minimum_message_size()
379    }
380
381    /// Attempts to parse the current message and moves the offset if successful.
382    fn try_next(&mut self) -> Result<SessionMessage<C>, SessionError> {
383        let mut offset = self.offset;
384
385        // Protocol version
386        if self.data[offset] != SessionMessage::<C>::VERSION {
387            return Err(SessionError::WrongVersion);
388        }
389        offset += 1;
390
391        // Message discriminant
392        let disc = self.data[offset];
393        offset += 1;
394
395        // Message length
396        let len = u16::from_be_bytes(
397            self.data[offset..offset + mem::size_of::<u16>()]
398                .try_into()
399                .map_err(|_| SessionError::IncorrectMessageLength)?,
400        ) as usize;
401        offset += mem::size_of::<u16>();
402
403        if len > SessionMessage::<C>::MAX_MESSAGE_SIZE {
404            return Err(SessionError::IncorrectMessageLength);
405        }
406
407        // The upper 6 bits of the size are reserved for future use,
408        // since MAX_MESSAGE_SIZE always fits within 10 bits (<= MAX_MESSAGE_SIZE = 1500)
409        let reserved = len & 0b111111_0000000000;
410
411        // In version 1 check that the reserved bits are all 0
412        if reserved != 0 {
413            return Err(SessionError::ParseError);
414        }
415
416        // Read the message
417        let res = match SessionMessageDiscriminants::from_repr(disc).ok_or(SessionError::UnknownMessageTag)? {
418            SessionMessageDiscriminants::Segment => {
419                SessionMessage::Segment(self.data[offset..offset + len].try_into()?)
420            }
421            SessionMessageDiscriminants::Request => {
422                SessionMessage::Request(self.data[offset..offset + len].try_into()?)
423            }
424            SessionMessageDiscriminants::Acknowledge => {
425                SessionMessage::Acknowledge(self.data[offset..offset + len].try_into()?)
426            }
427        };
428
429        // Move the internal offset only once the message has been fully parsed
430        self.offset = offset + len;
431        Ok(res)
432    }
433}
434
435impl<'a, const C: usize, T: Into<Cow<'a, [u8]>>> From<T> for SessionMessageIter<'a, C> {
436    fn from(value: T) -> Self {
437        Self {
438            data: value.into(),
439            offset: 0,
440            last_err: None,
441        }
442    }
443}
444
445impl<const C: usize> Iterator for SessionMessageIter<'_, C> {
446    type Item = Result<SessionMessage<C>, NetworkTypeError>;
447
448    fn next(&mut self) -> Option<Self::Item> {
449        if !self.is_done() {
450            self.try_next()
451                .inspect_err(|e| self.last_err = Some(e.clone()))
452                .map_err(NetworkTypeError::SessionProtocolError)
453                .into()
454        } else {
455            None
456        }
457    }
458}
459
// Once `next` returns `None`, `is_done` stays true: nothing in this type clears `last_err` after an
// error, and `offset` only moves forward over unchanging `data`.
impl<const C: usize> std::iter::FusedIterator for SessionMessageIter<'_, C> {}
461
#[cfg(test)]
mod tests {
    use std::time::SystemTime;

    use bitvec::bitarr;
    use hex_literal::hex;
    use hopr_platform::time::native::current_time;
    use rand::{Rng, prelude::IteratorRandom, thread_rng};

    use super::*;
    use crate::session::{
        Frame,
        frame::{MissingSegmentsBitmap, NO_MISSING_SEGMENTS},
    };

    // Bitmap with all `SeqNum::BITS` segment bits set (every segment missing).
    pub const ALL_MISSING_SEGMENTS: MissingSegmentsBitmap =
        bitarr![SeqNum, bitvec::prelude::Msb0; 1; SeqNum::BITS as usize];

    // Pins the protocol-version-1 constants.
    #[test]
    fn ensure_session_protocol_version_1_values() {
        // All of these values are independent of C, so we can set C = 0
        assert_eq!(1, SessionMessage::<0>::VERSION);
        assert_eq!(4, SessionMessage::<0>::HEADER_SIZE);
        assert_eq!(10, SessionMessage::<0>::SEGMENT_OVERHEAD);
        assert_eq!(8, SessionMessage::<0>::MAX_SEGMENTS_PER_FRAME);

        // Compile-time check that MAX_MESSAGE_SIZE fits into 11 bits (< 2^11)
        const _: () = {
            assert!(SessionMessage::<0>::MAX_MESSAGE_SIZE < 2048);
        };
    }

    // Each of 19 frames gets 4 random missing segments among indices 0..7. The request built
    // from them must expand back into exactly those segment IDs.
    #[test]
    fn segment_request_should_be_constructible_from_frame_info() {
        let frames = (1..20)
            .map(|i| {
                let mut missing_segments = NO_MISSING_SEGMENTS;
                (0..7_usize)
                    .choose_multiple(&mut thread_rng(), 4)
                    .into_iter()
                    .for_each(|i| missing_segments.set(i, true));
                FrameInfo {
                    frame_id: i,
                    missing_segments,
                    total_segments: 8,
                    last_update: SystemTime::UNIX_EPOCH,
                }
            })
            .collect::<Vec<_>>();

        let mut req = SegmentRequest::<466>::from_iter(frames.clone())
            .into_iter()
            .collect::<Vec<_>>();
        req.sort();

        assert_eq!(frames.len() * 4, req.len());
        assert_eq!(
            req,
            frames
                .into_iter()
                .flat_map(|f| f.into_missing_segments())
                .collect::<Vec<_>>()
        );
    }

    // Round-trip of a single Segment message.
    #[test]
    fn session_message_segment_should_serialize_and_deserialize() -> anyhow::Result<()> {
        const SEG_SIZE: usize = 8;

        let mut segments = Frame {
            frame_id: 10,
            data: hex!("deadbeefcafebabe").into(),
        }
        .segment(SEG_SIZE)?;

        const MTU: usize = SEG_SIZE + Segment::HEADER_SIZE + 2;

        let msg_1 = SessionMessage::<MTU>::Segment(segments.pop().unwrap());
        let data = Vec::from(msg_1.clone());
        let msg_2 = SessionMessage::try_from(&data[..])?;

        assert_eq!(msg_1, msg_2);

        Ok(())
    }

    // Round-trip of a Request message. Bitmap 0b10100001 maps to segments 0, 2 and 7,
    // i.e. segment index 0 is the most significant bit.
    #[test]
    fn session_message_segment_request_should_serialize_and_deserialize() -> anyhow::Result<()> {
        let frame_info = FrameInfo {
            frame_id: 10,
            total_segments: 8,
            missing_segments: [0b10100001].into(),
            last_update: SystemTime::now(),
        };

        let msg_1 = SessionMessage::<466>::Request(SegmentRequest::from_iter(vec![frame_info]));
        let data = Vec::from(msg_1.clone());
        let msg_2 = SessionMessage::try_from(&data[..])?;

        assert_eq!(msg_1, msg_2);

        match msg_1 {
            SessionMessage::Request(r) => {
                let missing_segments = r.into_iter().collect::<Vec<_>>();
                let expected = vec![SegmentId(10, 0), SegmentId(10, 2), SegmentId(10, 7)];
                assert_eq!(expected, missing_segments);
            }
            _ => panic!("invalid type"),
        }

        Ok(())
    }

    // Round-trip of an Acknowledge message built from more IDs (500) than fit into one message.
    #[test]
    fn session_message_ack_should_serialize_and_deserialize() -> anyhow::Result<()> {
        let mut rng = thread_rng();
        let frame_ids: Vec<u32> = (0..500).map(|_| rng.r#gen()).collect();

        let msg_1 = SessionMessage::<466>::Acknowledge(frame_ids.into());
        let data = Vec::from(msg_1.clone());
        let msg_2 = SessionMessage::try_from(&data[..])?;

        assert_eq!(msg_1, msg_2);

        Ok(())
    }

    // Checks the bit-to-segment-index mapping of requests, both when built directly and
    // when built from `FrameInfo`.
    #[test]
    fn session_message_segment_request_should_yield_correct_bitset_values() {
        let seg_req = SegmentRequest::<466>([(3, 0b01000001), (10, 0b00101000)].into());

        let mut iter = seg_req.into_iter();
        assert_eq!(iter.next(), Some(SegmentId(3, 1)));
        assert_eq!(iter.next(), Some(SegmentId(3, 7)));
        assert_eq!(iter.next(), Some(SegmentId(10, 2)));
        assert_eq!(iter.next(), Some(SegmentId(10, 4)));
        assert_eq!(iter.next(), None);

        let mut frame_info = FrameInfo {
            frame_id: 10,
            missing_segments: NO_MISSING_SEGMENTS,
            total_segments: 10,
            last_update: current_time(),
        };
        frame_info.missing_segments.set(2, true);
        frame_info.missing_segments.set(4, true);

        let mut iter = frame_info.clone().into_missing_segments();

        assert_eq!(iter.next(), Some(SegmentId(10, 2)));
        assert_eq!(iter.next(), Some(SegmentId(10, 4)));
        assert_eq!(iter.next(), None);

        let mut iter = SegmentRequest::<466>::from_iter(vec![frame_info]).into_iter();
        assert_eq!(iter.next(), Some(SegmentId(10, 2)));
        assert_eq!(iter.next(), Some(SegmentId(10, 4)));
        assert_eq!(iter.next(), None);
    }

    // Inputs too short to hold any message yield nothing and leave the iterator done.
    #[test]
    fn session_message_iter_should_be_empty_if_slice_has_no_messages() {
        const MTU: usize = 462;

        let mut iter = SessionMessageIter::<MTU>::from(Vec::<u8>::new());
        assert!(iter.next().is_none());
        assert!(iter.is_done());

        let mut iter = SessionMessageIter::<MTU>::from(&[0u8; 2]);
        assert!(iter.next().is_none());
        assert!(iter.is_done());
    }

    // Mixed Segment/Request/Acknowledge messages followed by 10 zero bytes. The trailing bytes
    // are expected to be ignored rather than parsed as a message.
    #[test]
    fn session_message_iter_should_deserialize_multiple_messages() -> anyhow::Result<()> {
        const MTU: usize = 462;

        let mut messages_1 = Frame {
            frame_id: 10,
            data: hopr_crypto_random::random_bytes::<1500>().into(),
        }
        .segment(MTU - SessionMessage::<MTU>::HEADER_SIZE - Segment::HEADER_SIZE)?
        .into_iter()
        .map(SessionMessage::<MTU>::Segment)
        .collect::<Vec<_>>();

        let frame_info = FrameInfo {
            frame_id: 10,
            total_segments: 255,
            missing_segments: ALL_MISSING_SEGMENTS,
            last_update: SystemTime::now(),
        };

        messages_1.push(SessionMessage::<MTU>::Request(SegmentRequest::from_iter(vec![
            frame_info,
        ])));

        let mut rng = thread_rng();
        let frame_ids: Vec<u32> = (0..100).map(|_| rng.r#gen()).collect();
        messages_1.push(SessionMessage::<MTU>::Acknowledge(frame_ids.into()));

        let iter = SessionMessageIter::<MTU>::from(
            messages_1
                .iter()
                .cloned()
                .flat_map(|m| m.into_encoded().into_vec())
                .chain(std::iter::repeat_n(0, 10))
                .collect::<Vec<u8>>(),
        );

        let messages_2 = iter.collect::<Result<Vec<_>, _>>()?;
        assert_eq!(messages_1, messages_2);

        Ok(())
    }

    // After all messages are consumed, the iterator ends cleanly without recording an error.
    #[test]
    fn session_message_iter_should_not_contain_error_when_consuming_everything() -> anyhow::Result<()> {
        const MTU: usize = 462;

        let messages = Frame {
            frame_id: 10,
            data: hopr_crypto_random::random_bytes::<{ 3 * MTU }>().into(),
        }
        .segment(MTU - SessionMessage::<MTU>::HEADER_SIZE - Segment::HEADER_SIZE)?
        .into_iter()
        .map(SessionMessage::<MTU>::Segment)
        .collect::<Vec<_>>();

        assert_eq!(4, messages.len());

        let data = messages
            .iter()
            .cloned()
            .flat_map(|m| m.into_encoded().into_vec())
            .chain(std::iter::repeat_n(0u8, 10))
            .collect::<Vec<_>>();

        let mut iter = SessionMessageIter::<MTU>::from(data);
        assert!(matches!(iter.next(), Some(Ok(m)) if m == messages[0]));
        assert!(matches!(iter.next(), Some(Ok(m)) if m == messages[1]));
        assert!(matches!(iter.next(), Some(Ok(m)) if m == messages[2]));
        assert!(matches!(iter.next(), Some(Ok(m)) if m == messages[3]));

        assert!(iter.next().is_none());
        assert!(iter.last_error().is_none());
        assert!(iter.is_done());

        Ok(())
    }

    // The third message is replaced by random bytes. The iterator must yield one error and
    // then stop. NOTE(review): with very low probability the random bytes could form a valid
    // header, which would make this test flaky.
    #[test]
    fn session_message_iter_should_not_yield_more_after_error() -> anyhow::Result<()> {
        const MTU: usize = 462;

        let messages = Frame {
            frame_id: 10,
            data: hopr_crypto_random::random_bytes::<{ 3 * MTU }>().into(),
        }
        .segment(MTU - SessionMessage::<MTU>::HEADER_SIZE - Segment::HEADER_SIZE)?
        .into_iter()
        .map(SessionMessage::<MTU>::Segment)
        .collect::<Vec<_>>();

        assert_eq!(4, messages.len());

        let data = messages
            .iter()
            .cloned()
            .enumerate()
            .flat_map(|(i, m)| {
                if i == 2 {
                    Vec::from(hopr_crypto_random::random_bytes::<MTU>())
                } else {
                    m.into_encoded().into_vec()
                }
            })
            .collect::<Vec<_>>();

        let mut iter = SessionMessageIter::<MTU>::from(data);
        assert!(matches!(iter.next(), Some(Ok(m)) if m == messages[0]));
        assert!(matches!(iter.next(), Some(Ok(m)) if m == messages[1]));

        let err = iter.next();
        assert!(matches!(err, Some(Err(_))));
        assert!(iter.is_done());
        assert!(iter.last_error().is_some());

        assert!(iter.next().is_none());

        Ok(())
    }
}