hopr_protocol_session/protocol/
mod.rs1mod frames;
54mod messages;
55
56use asynchronous_codec::{Decoder, Encoder};
57use bytes::{Buf, BufMut, BytesMut};
58pub use frames::{Frame, FrameId, OrderedFrame, Segment, SegmentId, SeqIndicator, SeqNum};
59pub use messages::{FrameAcknowledgements, MissingSegmentsBitmap, SegmentRequest};
60
61use crate::errors::SessionError;
62
/// A single message of the session protocol, parametrized by the maximum total encoded
/// message size `C` (header included).
///
/// The derived [`SessionMessageDiscriminants`] (`repr(u8)`) provide the one-byte message tag
/// that [`SessionCodec`] writes on the wire and parses back via `FromRepr`. The tag values are
/// the implicit discriminants assigned in declaration order (0, 1, 2), so reordering or
/// inserting variants changes the wire format.
#[derive(Debug, Clone, PartialEq, Eq, strum::EnumDiscriminants, strum::EnumTryAs)]
#[strum_discriminants(derive(strum::FromRepr), repr(u8))]
pub enum SessionMessage<const C: usize> {
    /// A segment of a frame (tag 0).
    Segment(Segment),
    /// A retransmission request for missing segments (tag 1).
    Request(SegmentRequest<C>),
    /// An acknowledgement of frames (tag 2).
    Acknowledge(FrameAcknowledgements<C>),
}
76
77impl<const C: usize> std::fmt::Display for SessionMessage<C> {
78 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79 match &self {
80 SessionMessage::Segment(s) => write!(f, "segment {}", s.id()),
81 SessionMessage::Request(r) => write!(f, "retransmission request of {:?}", r.0),
82 SessionMessage::Acknowledge(a) => write!(f, "acknowledgement of {:?}", a.0),
83 }
84 }
85}
86
87impl<const C: usize> SessionMessage<C> {
88 pub const HEADER_SIZE: usize = 1 + size_of::<SessionMessageDiscriminants>() + size_of::<u16>();
92 pub const MAX_MESSAGE_LENGTH: usize = C.saturating_sub(Self::HEADER_SIZE);
94 pub const SEGMENT_OVERHEAD: usize = Self::HEADER_SIZE + Segment::HEADER_SIZE;
98 pub const VERSION: u8 = 1;
100
101 pub fn minimum_message_size() -> usize {
103 Self::HEADER_SIZE
105 + Segment::HEADER_SIZE
106 .min(SegmentRequest::<C>::SIZE)
107 .min(FrameAcknowledgements::<C>::SIZE)
108 }
109
110 pub fn into_encoded(self) -> Box<[u8]> {
112 Vec::from(self).into_boxed_slice()
113 }
114}
115
116impl<const C: usize> From<SessionMessage<C>> for Vec<u8> {
117 fn from(message: SessionMessage<C>) -> Self {
118 debug_assert!(
119 C > SessionMessage::<C>::HEADER_SIZE && SessionMessage::<C>::MAX_MESSAGE_LENGTH <= u16::MAX as usize
120 );
121
122 let mut result = BytesMut::new();
123 SessionCodec::<C>
124 .encode(message, &mut result)
125 .expect("encoding never fails");
126
127 result.to_vec()
128 }
129}
130
131impl<const C: usize> TryFrom<&[u8]> for SessionMessage<C> {
132 type Error = SessionError;
133
134 fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
135 SessionCodec
136 .decode(&mut BytesMut::from(value))?
137 .ok_or(SessionError::IncorrectMessageLength)
138 }
139}
140
/// Stateless [`Encoder`]/[`Decoder`] for [`SessionMessage`]s whose total encoded size
/// (header included) is bounded by `C` bytes.
#[derive(Clone, Copy, Debug, Default)]
pub struct SessionCodec<const C: usize>;
143
144impl<const C: usize> Encoder for SessionCodec<C> {
145 type Error = SessionError;
146 type Item<'a> = SessionMessage<C>;
147
148 fn encode(&mut self, item: Self::Item<'_>, dst: &mut BytesMut) -> Result<(), Self::Error> {
149 debug_assert!(
150 C > SessionMessage::<C>::HEADER_SIZE && SessionMessage::<C>::MAX_MESSAGE_LENGTH <= u16::MAX as usize
151 );
152
153 let disc = SessionMessageDiscriminants::from(&item) as u8;
154
155 let msg = match item {
156 SessionMessage::Segment(s) => Vec::from(s),
157 SessionMessage::Request(r) => Vec::from(r),
158 SessionMessage::Acknowledge(a) => Vec::from(a),
159 };
160
161 if msg.len() > SessionMessage::<C>::MAX_MESSAGE_LENGTH {
162 return Err(SessionError::IncorrectMessageLength);
163 }
164
165 let msg_len = msg.len() as u16;
166 dst.put_u8(SessionMessage::<C>::VERSION);
167 dst.put_u8(disc);
168 dst.put_u16(msg_len);
169 dst.extend_from_slice(&msg);
170
171 tracing::trace!(disc, msg_len, "encoded message");
172 Ok(())
173 }
174}
175
176impl<const C: usize> Decoder for SessionCodec<C> {
177 type Error = SessionError;
178 type Item = SessionMessage<C>;
179
180 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
181 debug_assert!(C > SessionMessage::<C>::HEADER_SIZE);
182
183 tracing::trace!(msg_len = src.len(), "decoding message");
184 if src.len() < SessionMessage::<C>::minimum_message_size() {
185 return Ok(None);
186 }
187
188 if src[0] != SessionMessage::<C>::VERSION {
190 return Err(SessionError::WrongVersion);
191 }
192
193 let disc = src[1];
195
196 let payload_len = u16::from_be_bytes([src[2], src[3]]) as usize;
198
199 if payload_len > SessionMessage::<C>::MAX_MESSAGE_LENGTH {
201 return Err(SessionError::IncorrectMessageLength);
202 }
203
204 if src.len() < SessionMessage::<C>::HEADER_SIZE + payload_len {
206 return Ok(None);
207 }
208
209 let res = match SessionMessageDiscriminants::from_repr(disc).ok_or(SessionError::UnknownMessageTag)? {
211 SessionMessageDiscriminants::Segment => SessionMessage::Segment(
212 src[SessionMessage::<C>::HEADER_SIZE..SessionMessage::<C>::HEADER_SIZE + payload_len].try_into()?,
213 ),
214 SessionMessageDiscriminants::Request => SessionMessage::Request(
215 src[SessionMessage::<C>::HEADER_SIZE..SessionMessage::<C>::HEADER_SIZE + payload_len].try_into()?,
216 ),
217 SessionMessageDiscriminants::Acknowledge => SessionMessage::Acknowledge(
218 src[SessionMessage::<C>::HEADER_SIZE..SessionMessage::<C>::HEADER_SIZE + payload_len].try_into()?,
219 ),
220 };
221
222 src.advance(SessionMessage::<C>::HEADER_SIZE + payload_len);
223 Ok(Some(res))
224 }
225}
226
#[cfg(test)]
mod tests {
    use asynchronous_codec::{Decoder, Encoder};
    use bytes::BytesMut;
    use hex_literal::hex;
    use hopr_protocol_app::prelude::ApplicationData;
    use rand::{RngExt, rngs::ThreadRng};

    use super::*;
    use crate::{
        protocol::{FrameId, SegmentId},
        utils::segment,
    };

    #[test]
    fn ensure_session_protocol_version_1_values() {
        assert_eq!(1, SessionMessage::<{ ApplicationData::PAYLOAD_SIZE }>::VERSION);
        assert_eq!(4, SessionMessage::<{ ApplicationData::PAYLOAD_SIZE }>::HEADER_SIZE);
        assert_eq!(10, SessionMessage::<{ ApplicationData::PAYLOAD_SIZE }>::SEGMENT_OVERHEAD);
        assert_eq!(1024, SessionMessage::<{ ApplicationData::PAYLOAD_SIZE }>::MAX_MESSAGE_LENGTH);
    }

    #[test]
    fn session_message_segment_should_serialize_and_deserialize() -> anyhow::Result<()> {
        const SEG_SIZE: usize = 8;

        let mut segments = segment(hex!("deadbeefcafebabe"), SEG_SIZE, 10)?;

        const MTU: usize = SEG_SIZE + SessionMessage::<0>::SEGMENT_OVERHEAD;

        let msg_1 = SessionMessage::<MTU>::Segment(segments.pop().unwrap());
        let data = Vec::from(msg_1.clone());
        let msg_2 = SessionMessage::try_from(&data[..])?;

        assert_eq!(msg_1, msg_2);

        Ok(())
    }

    #[test]
    fn session_message_segment_request_should_serialize_and_deserialize() -> anyhow::Result<()> {
        let msg_1 = SessionMessage::<466>::Request(SegmentRequest::from_iter([
            (2 as FrameId, [0b11000001].into()),
            (10 as FrameId, [0b01000100].into()),
        ]));
        let data = Vec::from(msg_1.clone());
        let msg_2 = SessionMessage::try_from(&data[..])?;

        assert_eq!(msg_1, msg_2);

        match msg_1 {
            SessionMessage::Request(r) => {
                let missing_segments = r.into_iter().collect::<Vec<_>>();
                let expected = vec![
                    SegmentId(2, 0),
                    SegmentId(2, 1),
                    SegmentId(2, 7),
                    SegmentId(10, 1),
                    SegmentId(10, 5),
                ];
                assert_eq!(expected, missing_segments);
            }
            _ => panic!("invalid type"),
        }

        Ok(())
    }

    #[test]
    fn session_message_ack_should_serialize_and_deserialize() -> anyhow::Result<()> {
        let mut rng = ThreadRng::default();
        let frame_ids: Vec<u32> = (0..FrameAcknowledgements::<466>::MAX_ACK_FRAMES)
            .map(|_| rng.random())
            .collect();

        let msg_1 = SessionMessage::<466>::Acknowledge(frame_ids.try_into()?);
        let data = Vec::from(msg_1.clone());
        let msg_2 = SessionMessage::try_from(&data[..])?;

        assert_eq!(msg_1, msg_2);

        Ok(())
    }

    #[test]
    fn session_message_segment_request_should_yield_correct_bitset_values() {
        let seg_req = SegmentRequest::<466>::from_iter([(10, MissingSegmentsBitmap::from([0b00101000]))]);

        let mut iter = seg_req.into_iter();
        assert_eq!(iter.next(), Some(SegmentId(10, 2)));
        assert_eq!(iter.next(), Some(SegmentId(10, 4)));
        assert_eq!(iter.next(), None);
    }

    #[test]
    fn session_codec_should_decode_consecutive_messages_from_one_buffer() -> anyhow::Result<()> {
        let mut segments = segment(hex!("deadbeefcafebabe"), 8, 10)?;
        let msg_1 = SessionMessage::<466>::Segment(segments.pop().unwrap());
        let msg_2 =
            SessionMessage::<466>::Request(SegmentRequest::from_iter([(2 as FrameId, [0b11000001].into())]));

        let mut buf = BytesMut::new();
        SessionCodec::<466>.encode(msg_1.clone(), &mut buf)?;
        SessionCodec::<466>.encode(msg_2.clone(), &mut buf)?;

        assert_eq!(Some(msg_1), SessionCodec::<466>.decode(&mut buf)?);
        assert_eq!(Some(msg_2), SessionCodec::<466>.decode(&mut buf)?);
        assert!(buf.is_empty());
        assert_eq!(None, SessionCodec::<466>.decode(&mut buf)?);

        Ok(())
    }

    #[test]
    fn session_codec_should_wait_for_complete_message() -> anyhow::Result<()> {
        let msg =
            SessionMessage::<466>::Request(SegmentRequest::from_iter([(2 as FrameId, [0b11000001].into())]));
        let data = Vec::from(msg.clone());

        let (head, tail) = data.split_at(data.len() - 1);
        let mut buf = BytesMut::from(head);
        assert_eq!(None, SessionCodec::<466>.decode(&mut buf)?);
        assert_eq!(head.len(), buf.len(), "incomplete input must not be consumed");

        buf.extend_from_slice(tail);
        assert_eq!(Some(msg), SessionCodec::<466>.decode(&mut buf)?);
        assert!(buf.is_empty());

        Ok(())
    }

    #[test]
    fn session_codec_should_reject_wrong_version() {
        let mut frame = [0u8; 16];
        frame[0] = SessionMessage::<466>::VERSION + 1;

        let mut buf = BytesMut::from(&frame[..]);
        assert!(matches!(
            SessionCodec::<466>.decode(&mut buf),
            Err(SessionError::WrongVersion)
        ));
    }

    #[test]
    fn session_codec_should_reject_unknown_message_tag() {
        // Valid version, unknown tag, zero-length payload.
        let mut frame = [0u8; 16];
        frame[0] = SessionMessage::<466>::VERSION;
        frame[1] = 0xff;

        let mut buf = BytesMut::from(&frame[..]);
        assert!(matches!(
            SessionCodec::<466>.decode(&mut buf),
            Err(SessionError::UnknownMessageTag)
        ));
    }

    #[test]
    fn session_codec_should_reject_oversized_payload_length() {
        let mut frame = [0u8; 16];
        frame[0] = SessionMessage::<466>::VERSION;
        frame[1] = SessionMessageDiscriminants::Segment as u8;
        frame[2..4].copy_from_slice(&u16::MAX.to_be_bytes());

        let mut buf = BytesMut::from(&frame[..]);
        assert!(matches!(
            SessionCodec::<466>.decode(&mut buf),
            Err(SessionError::IncorrectMessageLength)
        ));
    }
}