Skip to main content

hopr_protocol_session/protocol/
frames.rs

1//! Contains basic types for the Session protocol.
2
3use std::{
4    cmp::Ordering,
5    fmt::{Debug, Display, Formatter},
6};
7
8use hopr_primitive_types::prelude::{GeneralError, to_hex_shortened};
9
10use crate::errors::SessionError;
11
/// Identifier of a [Frame], a 32-bit unsigned integer.
pub type FrameId = u32;

/// Sequence number type used for segment indices and sequence lengths within a [Frame].
pub type SeqNum = u8;

/// Uniquely identifies one segment: the [`FrameId`] of the frame it belongs to,
/// paired with the segment's [`SeqNum`] index inside that frame.
///
/// The derived ordering sorts by frame ID first and by segment index second.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize), derive(serde::Deserialize))]
pub struct SegmentId(pub FrameId, pub SeqNum);

impl Display for SegmentId {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let SegmentId(frame_id, seq_idx) = self;
        write!(f, "seg({},{})", frame_id, seq_idx)
    }
}
28
29/// Data frame of arbitrary length.
30///
31/// The frame can be segmented into [segments](Segment) and reassembled back
32/// via [`FrameBuilder`].
33#[derive(Clone, PartialEq, Eq)]
34pub struct Frame {
35    /// Identifier of this frame.
36    pub frame_id: FrameId,
37    /// Frame data.
38    pub data: Box<[u8]>,
39    /// Indicates whether the frame is the last one of the frame sequence.
40    ///
41    /// This indicates that the Session is over.
42    pub is_terminating: bool,
43}
44
45impl Debug for Frame {
46    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
47        f.debug_struct("Frame")
48            .field("frame_id", &self.frame_id)
49            .field("len", &self.data.len())
50            .field("data", &to_hex_shortened::<16>(&self.data))
51            .field("is_terminating", &self.is_terminating)
52            .finish()
53    }
54}
55
56impl AsRef<[u8]> for Frame {
57    fn as_ref(&self) -> &[u8] {
58        &self.data
59    }
60}
61
62/// Wrapper for [`Frame`] that implements comparison and total ordering based on [`FrameId`].
63#[derive(Clone, Debug)]
64pub struct OrderedFrame(pub Frame);
65
66impl Eq for OrderedFrame {}
67
68impl PartialEq<Self> for OrderedFrame {
69    fn eq(&self, other: &Self) -> bool {
70        self.0.frame_id == other.0.frame_id
71    }
72}
73
74impl PartialEq<FrameId> for OrderedFrame {
75    fn eq(&self, other: &FrameId) -> bool {
76        self.0.frame_id == *other
77    }
78}
79
80impl PartialOrd<Self> for OrderedFrame {
81    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
82        Some(self.cmp(other))
83    }
84}
85
86impl PartialOrd<FrameId> for OrderedFrame {
87    fn partial_cmp(&self, other: &FrameId) -> Option<Ordering> {
88        Some(self.0.frame_id.cmp(other))
89    }
90}
91
92impl Ord for OrderedFrame {
93    fn cmp(&self, other: &Self) -> Ordering {
94        self.0.frame_id.cmp(&other.0.frame_id)
95    }
96}
97
98impl From<Frame> for OrderedFrame {
99    fn from(value: Frame) -> Self {
100        Self(value)
101    }
102}
103
104impl From<OrderedFrame> for Frame {
105    fn from(value: OrderedFrame) -> Self {
106        value.0
107    }
108}
109
110/// Carries segment flags and the length of the segment sequence.
111#[derive(Copy, Clone, Eq, PartialEq, Default, PartialOrd, Ord, Hash)]
112#[cfg_attr(feature = "serde", derive(serde::Serialize), derive(serde::Deserialize))]
113pub struct SeqIndicator(SeqNum);
114
115impl SeqIndicator {
116    /// Maximum length of a segment sequence.
117    pub const MAX: SeqNum = 0b0011_1111;
118
119    #[inline]
120    pub const fn new_with_flags(seq_len: SeqNum, is_terminating: bool) -> Self {
121        let flags = ((is_terminating as u8) << 7) | (seq_len & Self::MAX);
122        Self(flags)
123    }
124
125    #[inline]
126    pub const fn new(seq_len: SeqNum) -> Self {
127        Self::new_with_flags(seq_len, false)
128    }
129
130    #[inline]
131    const fn new_unchecked(seq_ind: SeqNum) -> Self {
132        Self(seq_ind)
133    }
134
135    #[inline]
136    pub fn with_terminating_bit(self, is_terminating: bool) -> Self {
137        Self::new_with_flags(self.0, is_terminating)
138    }
139
140    #[inline]
141    pub const fn is_terminating(&self) -> bool {
142        self.0 & 0b1000_0000 != 0
143    }
144
145    #[inline]
146    pub const fn seq_len(&self) -> SeqNum {
147        self.0 & Self::MAX
148    }
149
150    #[inline]
151    pub const fn value(&self) -> SeqNum {
152        self.0
153    }
154}
155
156impl Debug for SeqIndicator {
157    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
158        f.debug_struct("SeqIndicator")
159            .field("seq_len", &self.seq_len())
160            .field("is_terminating", &self.is_terminating())
161            .finish()
162    }
163}
164
165impl Display for SeqIndicator {
166    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
167        write!(f, "{}{}", self.seq_len(), if self.is_terminating() { "*" } else { "" })
168    }
169}
170
171impl TryFrom<SeqNum> for SeqIndicator {
172    type Error = GeneralError;
173
174    fn try_from(value: SeqNum) -> Result<Self, Self::Error> {
175        if value <= Self::MAX {
176            Ok(Self(value))
177        } else {
178            Err(GeneralError::InvalidInput)
179        }
180    }
181}
182
183/// Represents a frame segment.
184///
185/// Besides the data, a segment carries information about the total number of
186/// segments in the original frame ([`SeqIndicator`]), its index within the frame ([`SeqNum`]), and
187/// ID of that frame ([`FrameId`]).
188#[derive(Clone, Eq, PartialEq)]
189#[cfg_attr(feature = "serde", derive(serde::Serialize), derive(serde::Deserialize))]
190pub struct Segment {
191    /// ID of the [Frame] this segment belongs to.
192    pub frame_id: FrameId,
193    /// Index of this segment within the segment sequence.
194    pub seq_idx: SeqNum,
195    /// Flags of the segment sequence (includes sequence length).
196    pub seq_flags: SeqIndicator,
197    /// Data in this segment.
198    #[cfg_attr(feature = "serde", serde(with = "serde_bytes"))]
199    pub data: Box<[u8]>,
200}
201
202impl Segment {
203    /// Size of the segment header.
204    pub const HEADER_SIZE: usize = size_of::<FrameId>() + 2 * size_of::<SeqNum>();
205
206    /// Returns the [SegmentId] for this segment.
207    pub fn id(&self) -> SegmentId {
208        SegmentId(self.frame_id, self.seq_idx)
209    }
210
211    /// Length of the segment data plus header.
212    #[allow(clippy::len_without_is_empty)]
213    pub fn len(&self) -> usize {
214        Self::HEADER_SIZE + self.data.len()
215    }
216
217    /// Indicates whether this segment is the last one from the frame.
218    #[inline]
219    pub fn is_last(&self) -> bool {
220        self.seq_idx == self.seq_flags.seq_len() - 1
221    }
222
223    /// Short-cut to check if this segment is a terminating segment.
224    #[inline]
225    pub fn is_terminating(&self) -> bool {
226        self.seq_flags.is_terminating()
227    }
228
229    /// Creates an empty `Segment` with the terminating flag set.
230    pub fn terminating(frame_id: FrameId) -> Self {
231        Self {
232            frame_id,
233            seq_idx: 0,
234            seq_flags: SeqIndicator::new_with_flags(1, true),
235            data: Box::default(),
236        }
237    }
238}
239
240impl Debug for Segment {
241    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
242        f.debug_struct("Segment")
243            .field("frame_id", &self.frame_id)
244            .field("seq_id", &self.seq_idx)
245            .field("seq_flags", &self.seq_flags)
246            .field("data", &to_hex_shortened::<16>(&self.data))
247            .finish()
248    }
249}
250
251impl PartialOrd<Segment> for Segment {
252    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
253        Some(self.cmp(other))
254    }
255}
256
257impl Ord for Segment {
258    fn cmp(&self, other: &Self) -> Ordering {
259        match self.frame_id.cmp(&other.frame_id) {
260            Ordering::Equal => self.seq_idx.cmp(&other.seq_idx),
261            cmp => cmp,
262        }
263    }
264}
265
266impl From<&Segment> for SegmentId {
267    fn from(value: &Segment) -> Self {
268        value.id()
269    }
270}
271
272impl From<Segment> for Vec<u8> {
273    fn from(value: Segment) -> Self {
274        let mut ret = Vec::with_capacity(Segment::HEADER_SIZE + value.data.len());
275        ret.extend_from_slice(value.frame_id.to_be_bytes().as_ref());
276        ret.extend_from_slice(value.seq_idx.to_be_bytes().as_ref());
277        ret.push(value.seq_flags.value());
278        ret.extend_from_slice(value.data.as_ref());
279        ret
280    }
281}
282
283impl TryFrom<&[u8]> for Segment {
284    type Error = SessionError;
285
286    fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
287        if value.len() < Self::HEADER_SIZE {
288            return Err(SessionError::InvalidSegment);
289        }
290
291        let (header, data) = value.split_at(Self::HEADER_SIZE);
292        let segment = Segment {
293            frame_id: FrameId::from_be_bytes(header[0..4].try_into().map_err(|_| SessionError::InvalidSegment)?),
294            seq_idx: SeqNum::from_be_bytes(header[4..5].try_into().map_err(|_| SessionError::InvalidSegment)?),
295            seq_flags: SeqIndicator::new_unchecked(header[5]),
296            data: data.into(),
297        };
298        (segment.frame_id > 0 && segment.seq_idx < segment.seq_flags.seq_len())
299            .then_some(segment)
300            .ok_or(SessionError::InvalidSegment)
301    }
302}
303
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn terminating_sequence_indicator_should_be_greater_than_non_terminating() -> anyhow::Result<()> {
        let ind_1 = SeqIndicator::new_with_flags(1, true);
        let ind_2 = SeqIndicator::new_with_flags(1, false);

        assert!(ind_1 > ind_2);
        Ok(())
    }

    #[test]
    fn sequence_indicator_should_keep_length_when_toggling_terminating_bit() -> anyhow::Result<()> {
        let ind = SeqIndicator::new_with_flags(5, true);
        assert_eq!(ind.seq_len(), 5);
        assert!(ind.is_terminating());

        let ind = ind.with_terminating_bit(false);
        assert_eq!(ind.seq_len(), 5);
        assert!(!ind.is_terminating());
        Ok(())
    }

    #[test]
    fn sequence_indicator_should_reject_lengths_above_max() {
        assert!(SeqIndicator::try_from(SeqIndicator::MAX).is_ok());
        assert!(SeqIndicator::try_from(SeqIndicator::MAX + 1).is_err());
    }

    #[test]
    fn segment_should_serialize_and_deserialize() -> anyhow::Result<()> {
        let seg_1 = Segment {
            frame_id: 10,
            seq_idx: 0,
            seq_flags: 2.try_into()?,
            data: Box::new([123u8]),
        };

        let seg_2 = Segment::try_from(Vec::from(seg_1.clone()).as_slice())?;
        assert_eq!(seg_1, seg_2);
        Ok(())
    }

    #[test]
    fn terminating_segment_should_serialize_and_deserialize() -> anyhow::Result<()> {
        let seg_1 = Segment::terminating(7);
        assert!(seg_1.is_terminating());
        assert!(seg_1.is_last());

        let seg_2 = Segment::try_from(Vec::from(seg_1.clone()).as_slice())?;
        assert_eq!(seg_1, seg_2);
        Ok(())
    }

    #[test]
    fn segment_deserialization_should_reject_invalid_input() {
        // Shorter than the header.
        assert!(Segment::try_from([0u8, 0, 0, 1, 0].as_slice()).is_err());
        // Frame ID 0 is not allowed.
        assert!(Segment::try_from([0u8, 0, 0, 0, 0, 1].as_slice()).is_err());
        // Segment index must be smaller than the sequence length.
        assert!(Segment::try_from([0u8, 0, 0, 1, 2, 2].as_slice()).is_err());
        // A sequence length of 0 can never hold a valid index.
        assert!(Segment::try_from([0u8, 0, 0, 1, 0, 0].as_slice()).is_err());
    }
}