Skip to main content

hopr_protocol_session/protocol/
mod.rs

//! # `Session` protocol messages
//!
//! The protocol components are built via low-level types of the `frame` module, such as
//! [`Segment`] and [`Frame`](crate::session::Frame).
//! Most importantly, the `Session` protocol fixes the maximum number of segments per frame
//! to 8 (see [`MAX_SEGMENTS_PER_FRAME`](SessionMessage::MAX_SEGMENTS_PER_FRAME)).
//! Since each segment must fit within a maximum transmission unit (MTU),
//! a frame can be at most *eight* times the size of the MTU.
//!
//! The [current version](SessionMessage::VERSION) of the protocol consists of three
//! messages that are sent and received via the underlying transport:
//! - [`Segment message`](Segment)
//! - [`Retransmission request`](SegmentRequest)
//! - [`Frame acknowledgement`](FrameAcknowledgements)
//!
//! All of these messages are bundled within the [`SessionMessage`] enum,
//! which is then encoded as a byte array of a maximum
//! MTU size `C` (which is a generic const argument of the `SessionMessage` type).
//! The header of the `SessionMessage` encoding consists of the [`version`](SessionMessage::VERSION)
//! byte, followed by the discriminator byte of one of the above messages, followed by
//! the two-byte (big-endian) message length. The encoding of the message itself comes after the header.
//!
//! ## Segment message ([`Segment`](SessionMessage::Segment))
//!
//! The Segment message contains the payload [`Segment`] of some [`Frame`](crate::session::Frame).
//! The size of this message can range from [`the minimum message size`](SessionMessage::minimum_message_size)
//! up to `C`.
//!
//! ## Retransmission request message ([`Request`](SessionMessage::Request))
//!
//! Contains a request for retransmission of missing segments in a frame. This is sent from
//! the segment recipient to the sender, once it realizes some of the received frames are incomplete
//! (after a certain period of time).
//!
//! The encoding of this message consists of pairs of [frame ID](FrameId) and
//! a single byte bitmap of requested segments in this frame.
//! Each pair is therefore [`ENTRY_SIZE`](SegmentRequest::ENTRY_SIZE) bytes long.
//! There can be at most [`MAX_ENTRIES`](SegmentRequest::MAX_ENTRIES) such pairs
//! in a single Retransmission request message, given `C` as the MTU size. If the message contains
//! fewer entries, it is padded with zeros (0 is not a valid frame ID).
//! If more frames have missing segments, multiple retransmission request messages need to be sent.
//!
//! ## Frame acknowledgement message ([`Acknowledge`](SessionMessage::Acknowledge))
//!
//! This message is sent from the segment recipient to the segment sender to acknowledge that
//! all segments of certain frames have been completely and correctly received by the recipient.
//!
//! The message consists simply of a [frame ID](FrameId) list of the completely received
//! frames. There can be at most [`MAX_ACK_FRAMES`](FrameAcknowledgements::MAX_ACK_FRAMES)
//! per message. If more frames need to be acknowledged, more messages need to be sent.
//! If the message contains fewer entries, it is padded with zeros (0 is not a valid frame ID).

53mod frames;
54mod messages;
55
56use asynchronous_codec::{Decoder, Encoder};
57use bytes::{Buf, BufMut, BytesMut};
58pub use frames::{Frame, FrameId, OrderedFrame, Segment, SegmentId, SeqIndicator, SeqNum};
59pub use messages::{FrameAcknowledgements, MissingSegmentsBitmap, SegmentRequest};
60
61use crate::errors::SessionError;
62
63/// Contains all messages of the Session sub-protocol.
64///
65/// The maximum size of the Session sub-protocol message is given by `C`.
66#[derive(Debug, Clone, PartialEq, Eq, strum::EnumDiscriminants, strum::EnumTryAs)]
67#[strum_discriminants(derive(strum::FromRepr), repr(u8))]
68pub enum SessionMessage<const C: usize> {
69    /// Represents a message containing a segment.
70    Segment(Segment),
71    /// Represents a message containing a [request](SegmentRequest) for segments.
72    Request(SegmentRequest<C>),
73    /// Represents a message containing [frame acknowledgements](FrameAcknowledgements).
74    Acknowledge(FrameAcknowledgements<C>),
75}
76
77impl<const C: usize> std::fmt::Display for SessionMessage<C> {
78    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79        match &self {
80            SessionMessage::Segment(s) => write!(f, "segment {}", s.id()),
81            SessionMessage::Request(r) => write!(f, "retransmission request of {:?}", r.0),
82            SessionMessage::Acknowledge(a) => write!(f, "acknowledgement of {:?}", a.0),
83        }
84    }
85}
86
87impl<const C: usize> SessionMessage<C> {
88    /// Header size of the session message.
89    /// This is currently the version byte, the size of [`SessionMessageDiscriminants`] representation
90    /// and two bytes for the message length.
91    pub const HEADER_SIZE: usize = 1 + size_of::<SessionMessageDiscriminants>() + size_of::<u16>();
92    /// Maximum size of the message in v1.
93    pub const MAX_MESSAGE_LENGTH: usize = C.saturating_sub(Self::HEADER_SIZE);
94    /// Size of the overhead that's added to the raw payload of each [`Segment`].
95    ///
96    /// This amounts to [`SessionMessage::HEADER_SIZE`] + [`Segment::HEADER_SIZE`].
97    pub const SEGMENT_OVERHEAD: usize = Self::HEADER_SIZE + Segment::HEADER_SIZE;
98    /// Current version of the protocol.
99    pub const VERSION: u8 = 1;
100
101    /// Returns the minimum size of a [`SessionMessage`].
102    pub fn minimum_message_size() -> usize {
103        // Make this a "const fn" once "min" is const fn too
104        Self::HEADER_SIZE
105            + Segment::HEADER_SIZE
106                .min(SegmentRequest::<C>::SIZE)
107                .min(FrameAcknowledgements::<C>::SIZE)
108    }
109
110    /// Convenience method to encode the session message.
111    pub fn into_encoded(self) -> Box<[u8]> {
112        Vec::from(self).into_boxed_slice()
113    }
114}
115
116impl<const C: usize> From<SessionMessage<C>> for Vec<u8> {
117    fn from(message: SessionMessage<C>) -> Self {
118        debug_assert!(
119            C > SessionMessage::<C>::HEADER_SIZE && SessionMessage::<C>::MAX_MESSAGE_LENGTH <= u16::MAX as usize
120        );
121
122        let mut result = BytesMut::new();
123        SessionCodec::<C>
124            .encode(message, &mut result)
125            .expect("encoding never fails");
126
127        result.to_vec()
128    }
129}
130
131impl<const C: usize> TryFrom<&[u8]> for SessionMessage<C> {
132    type Error = SessionError;
133
134    fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
135        SessionCodec
136            .decode(&mut BytesMut::from(value))?
137            .ok_or(SessionError::IncorrectMessageLength)
138    }
139}
140
/// Stateless codec for [`SessionMessage`]s with the maximum encoded message size `C`.
///
/// The type carries no data: it only ties the `Encoder` and `Decoder`
/// implementations to the const parameter `C`.
#[derive(Debug, Clone, Copy, Default)]
pub struct SessionCodec<const C: usize>;
143
144impl<const C: usize> Encoder for SessionCodec<C> {
145    type Error = SessionError;
146    type Item<'a> = SessionMessage<C>;
147
148    fn encode(&mut self, item: Self::Item<'_>, dst: &mut BytesMut) -> Result<(), Self::Error> {
149        debug_assert!(
150            C > SessionMessage::<C>::HEADER_SIZE && SessionMessage::<C>::MAX_MESSAGE_LENGTH <= u16::MAX as usize
151        );
152
153        let disc = SessionMessageDiscriminants::from(&item) as u8;
154
155        let msg = match item {
156            SessionMessage::Segment(s) => Vec::from(s),
157            SessionMessage::Request(r) => Vec::from(r),
158            SessionMessage::Acknowledge(a) => Vec::from(a),
159        };
160
161        if msg.len() > SessionMessage::<C>::MAX_MESSAGE_LENGTH {
162            return Err(SessionError::IncorrectMessageLength);
163        }
164
165        let msg_len = msg.len() as u16;
166        dst.put_u8(SessionMessage::<C>::VERSION);
167        dst.put_u8(disc);
168        dst.put_u16(msg_len);
169        dst.extend_from_slice(&msg);
170
171        tracing::trace!(disc, msg_len, "encoded message");
172        Ok(())
173    }
174}
175
176impl<const C: usize> Decoder for SessionCodec<C> {
177    type Error = SessionError;
178    type Item = SessionMessage<C>;
179
180    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
181        debug_assert!(C > SessionMessage::<C>::HEADER_SIZE);
182
183        tracing::trace!(msg_len = src.len(), "decoding message");
184        if src.len() < SessionMessage::<C>::minimum_message_size() {
185            return Ok(None);
186        }
187
188        // Protocol version
189        if src[0] != SessionMessage::<C>::VERSION {
190            return Err(SessionError::WrongVersion);
191        }
192
193        // Message discriminant
194        let disc = src[1];
195
196        // Message length
197        let payload_len = u16::from_be_bytes([src[2], src[3]]) as usize;
198
199        // Check the maximum message length for version 1
200        if payload_len > SessionMessage::<C>::MAX_MESSAGE_LENGTH {
201            return Err(SessionError::IncorrectMessageLength);
202        }
203
204        // Check if there's enough data so that we can read the rest of the message
205        if src.len() < SessionMessage::<C>::HEADER_SIZE + payload_len {
206            return Ok(None);
207        }
208
209        // Read the message
210        let res = match SessionMessageDiscriminants::from_repr(disc).ok_or(SessionError::UnknownMessageTag)? {
211            SessionMessageDiscriminants::Segment => SessionMessage::Segment(
212                src[SessionMessage::<C>::HEADER_SIZE..SessionMessage::<C>::HEADER_SIZE + payload_len].try_into()?,
213            ),
214            SessionMessageDiscriminants::Request => SessionMessage::Request(
215                src[SessionMessage::<C>::HEADER_SIZE..SessionMessage::<C>::HEADER_SIZE + payload_len].try_into()?,
216            ),
217            SessionMessageDiscriminants::Acknowledge => SessionMessage::Acknowledge(
218                src[SessionMessage::<C>::HEADER_SIZE..SessionMessage::<C>::HEADER_SIZE + payload_len].try_into()?,
219            ),
220        };
221
222        src.advance(SessionMessage::<C>::HEADER_SIZE + payload_len);
223        Ok(Some(res))
224    }
225}
226
#[cfg(test)]
mod tests {
    use hex_literal::hex;
    use hopr_protocol_app::prelude::ApplicationData;
    use rand::{RngExt, rngs::ThreadRng};

    use super::*;
    use crate::{
        protocol::{FrameId, SegmentId},
        utils::segment,
    };

    /// Pins the wire-level constants of protocol version 1.
    #[test]
    fn ensure_session_protocol_version_1_values() {
        // MAX_MESSAGE_LENGTH depends on C, so the application payload size is used as C
        // for all the checked constants.
        type Message = SessionMessage<{ ApplicationData::PAYLOAD_SIZE }>;

        assert_eq!(1, Message::VERSION);
        assert_eq!(4, Message::HEADER_SIZE);
        assert_eq!(10, Message::SEGMENT_OVERHEAD);
        assert_eq!(1024, Message::MAX_MESSAGE_LENGTH);
    }

    /// A Segment message must survive an encode/decode round trip.
    #[test]
    fn session_message_segment_should_serialize_and_deserialize() -> anyhow::Result<()> {
        const SEG_SIZE: usize = 8;
        const MTU: usize = SEG_SIZE + SessionMessage::<0>::SEGMENT_OVERHEAD;

        let mut segments = segment(hex!("deadbeefcafebabe"), SEG_SIZE, 10)?;
        let last_segment = segments.pop().unwrap();

        let msg_1 = SessionMessage::<MTU>::Segment(last_segment);
        let data: Vec<u8> = msg_1.clone().into();
        let msg_2 = SessionMessage::<MTU>::try_from(data.as_slice())?;

        assert_eq!(msg_1, msg_2);

        Ok(())
    }

    /// A Retransmission request must survive an encode/decode round trip
    /// and iterate over exactly the segments marked in its bitmaps.
    #[test]
    fn session_message_segment_request_should_serialize_and_deserialize() -> anyhow::Result<()> {
        // Segments 0, 1 and 7 are missing in Frame 2; segments 1 and 5 are missing in Frame 10
        let msg_1 = SessionMessage::<466>::Request(SegmentRequest::from_iter([
            (2 as FrameId, [0b1100_0001].into()),
            (10 as FrameId, [0b0100_0100].into()),
        ]));
        let data: Vec<u8> = msg_1.clone().into();
        let msg_2 = SessionMessage::<466>::try_from(data.as_slice())?;

        assert_eq!(msg_1, msg_2);

        let SessionMessage::Request(request) = msg_1 else {
            panic!("invalid type");
        };

        let missing_segments: Vec<_> = request.into_iter().collect();
        let expected = vec![
            SegmentId(2, 0),
            SegmentId(2, 1),
            SegmentId(2, 7),
            SegmentId(10, 1),
            SegmentId(10, 5),
        ];
        assert_eq!(expected, missing_segments);

        Ok(())
    }

    /// A Frame acknowledgement filled to capacity with random frame IDs
    /// must survive an encode/decode round trip.
    #[test]
    fn session_message_ack_should_serialize_and_deserialize() -> anyhow::Result<()> {
        let mut rng = ThreadRng::default();
        let frame_ids: Vec<u32> = std::iter::repeat_with(|| rng.random())
            .take(FrameAcknowledgements::<466>::MAX_ACK_FRAMES)
            .collect();

        let msg_1 = SessionMessage::<466>::Acknowledge(frame_ids.try_into()?);
        let data: Vec<u8> = msg_1.clone().into();
        let msg_2 = SessionMessage::<466>::try_from(data.as_slice())?;

        assert_eq!(msg_1, msg_2);

        Ok(())
    }

    /// Iterating a request must yield only the segments whose bits are set.
    #[test]
    fn session_message_segment_request_should_yield_correct_bitset_values() {
        let seg_req = SegmentRequest::<466>::from_iter([(10, MissingSegmentsBitmap::from([0b0010_1000]))]);

        let yielded: Vec<_> = seg_req.into_iter().collect();
        assert_eq!(yielded, vec![SegmentId(10, 2), SegmentId(10, 4)]);
    }
}