Skip to main content

hopr_protocol_session/protocol/
messages.rs

1//! Contains definitions of Session protocol messages.
2
3use std::collections::{BTreeMap, BTreeSet};
4
5use bitvec::{BitArr, field::BitField, prelude::Msb0};
6
7use crate::{
8    errors::SessionError,
9    protocol::{FrameId, SegmentId, SeqNum, SessionMessage},
10};
11
12/// Holds the Segment Retransmission Request message.
13///
14/// That is an ordered map of frame IDs and a bitmap of missing segments in each frame.
15/// The bitmap can cover up a request for up to [`SegmentRequest::MAX_ENTRIES`] segments.
16#[derive(Debug, Clone, PartialEq, Eq, Default)]
17pub struct SegmentRequest<const C: usize>(pub(super) BTreeMap<FrameId, SeqNum>);
18
19/// Bitmap of segments missing in a frame.
20///
21/// Represented by `u8`, it can cover up to 8 segments per frame.
22/// If a bit is set, the segment is *missing* from the frame.
23pub type MissingSegmentsBitmap = BitArr!(for 1, in SeqNum, Msb0);
24
25impl<const C: usize> SegmentRequest<C> {
26    /// Size of a single segment retransmission request entry.
27    pub const ENTRY_SIZE: usize = size_of::<FrameId>() + size_of::<SeqNum>();
28    /// Maximum number of segment retransmission entries.
29    pub const MAX_ENTRIES: usize = Self::SIZE / Self::ENTRY_SIZE;
30    /// Maximum number of missing segments per frame.
31    pub const MAX_MISSING_SEGMENTS_PER_FRAME: usize = SeqNum::BITS as usize;
32    /// Size of the message.
33    pub const SIZE: usize = C - SessionMessage::<C>::HEADER_SIZE;
34
35    /// Returns the total number of segments to retransmit for all frames in this request.
36    pub fn len(&self) -> usize {
37        self.0
38            .values()
39            .take(Self::MAX_ENTRIES)
40            .map(|e| e.count_ones() as usize)
41            .sum()
42    }
43
44    /// Returns true if there are no segments to retransmit in this request.
45    pub fn is_empty(&self) -> bool {
46        self.0.iter().take(Self::MAX_ENTRIES).all(|(_, e)| e.count_ones() == 0)
47    }
48}
49
50impl<const C: usize> IntoIterator for SegmentRequest<C> {
51    type IntoIter = std::vec::IntoIter<Self::Item>;
52    type Item = SegmentId;
53
54    // An ordered iterator of missing segments in the form of SegmentId tuples.
55    fn into_iter(self) -> Self::IntoIter {
56        let seq_size = SeqNum::BITS as usize;
57        let mut ret = Vec::with_capacity(seq_size * self.0.len());
58        for (frame_id, missing) in self.0 {
59            ret.extend(
60                MissingSegmentsBitmap::from([missing])
61                    .iter_ones()
62                    .map(|i| SegmentId(frame_id, i as SeqNum)),
63            );
64        }
65        ret.into_iter()
66    }
67}
68
69// From FrameIds and bitmap of missing segments per frame
70impl<const C: usize> FromIterator<(FrameId, MissingSegmentsBitmap)> for SegmentRequest<C> {
71    fn from_iter<T: IntoIterator<Item = (FrameId, MissingSegmentsBitmap)>>(iter: T) -> Self {
72        Self(
73            iter.into_iter()
74                .map(|(fid, missing_segments)| (fid, missing_segments.load()))
75                .collect(),
76        )
77    }
78}
79
80impl<const C: usize> TryFrom<&[u8]> for SegmentRequest<C> {
81    type Error = SessionError;
82
83    fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
84        if value.len() == Self::SIZE {
85            let mut ret = Self::default();
86            for (frame_id, missing) in value
87                .chunks_exact(Self::ENTRY_SIZE)
88                .map(|c| c.split_at(size_of::<FrameId>()))
89            {
90                let frame_id = FrameId::from_be_bytes(frame_id.try_into().map_err(|_| SessionError::ParseError)?);
91                if frame_id > 0 {
92                    ret.0.insert(
93                        frame_id,
94                        SeqNum::from_be_bytes(missing.try_into().map_err(|_| SessionError::ParseError)?),
95                    );
96                }
97            }
98            Ok(ret)
99        } else {
100            Err(SessionError::ParseError)
101        }
102    }
103}
104
105impl<const C: usize> From<SegmentRequest<C>> for Vec<u8> {
106    fn from(value: SegmentRequest<C>) -> Self {
107        let mut ret = vec![0u8; SegmentRequest::<C>::SIZE];
108        let mut offset = 0;
109        for (frame_id, seq_num) in value.0 {
110            if offset + size_of::<FrameId>() + size_of::<SeqNum>() <= SegmentRequest::<C>::SIZE {
111                ret[offset..offset + size_of::<FrameId>()].copy_from_slice(&frame_id.to_be_bytes());
112                offset += size_of::<FrameId>();
113                ret[offset..offset + size_of::<SeqNum>()].copy_from_slice(&seq_num.to_be_bytes());
114                offset += size_of::<SeqNum>();
115            } else {
116                break;
117            }
118        }
119        ret
120    }
121}
122
123/// Holds the Frame Acknowledgement message.
124/// This carries an ordered set of up to [`FrameAcknowledgements::MAX_ACK_FRAMES`] [frame IDs](FrameId)
125/// that has been acknowledged by the counterparty.
126#[derive(Debug, Clone, PartialEq, Eq, Default)]
127pub struct FrameAcknowledgements<const C: usize>(pub(super) BTreeSet<FrameId>);
128
129impl<const C: usize> FrameAcknowledgements<C> {
130    /// Maximum number of [`FrameIds`](FrameId) that can be accommodated.
131    pub const MAX_ACK_FRAMES: usize = Self::SIZE / size_of::<FrameId>();
132    /// Size of the message.
133    pub const SIZE: usize = C - SessionMessage::<C>::HEADER_SIZE;
134
135    /// Pushes the frame ID.
136    /// Returns true if the value has been pushed or false it the container is full or already
137    /// contains that value.
138    #[inline]
139    pub fn push(&mut self, frame_id: FrameId) -> bool {
140        !self.is_full() && self.0.insert(frame_id)
141    }
142
143    /// Number of acknowledged frame IDs in this instance.
144    #[inline]
145    pub fn len(&self) -> usize {
146        self.0.len()
147    }
148
149    /// Returns true if there are no frame IDs in this instance.
150    pub fn is_empty(&self) -> bool {
151        self.0.is_empty()
152    }
153
154    /// Indicates whether the [maximum number of frame IDs](FrameAcknowledgements::MAX_ACK_FRAMES)
155    /// has been reached.
156    #[inline]
157    pub fn is_full(&self) -> bool {
158        self.0.len() == Self::MAX_ACK_FRAMES
159    }
160
161    /// Creates a vector of [`FrameAcknowledgements`](FrameAcknowledgements) from the given iterator
162    /// of acknowledged [`FrameIds`](FrameId).
163    pub fn new_multiple<T: IntoIterator<Item = FrameId>>(items: T) -> Vec<Self> {
164        let mut out = Vec::with_capacity(2);
165        let mut frame_ack = Self::default();
166        for frame_id in items {
167            if frame_ack.is_full() {
168                out.push(frame_ack);
169                frame_ack = Self::default();
170            }
171
172            frame_ack.push(frame_id);
173        }
174        out.push(frame_ack);
175        out
176    }
177}
178
179impl<const C: usize> TryFrom<Vec<FrameId>> for FrameAcknowledgements<C> {
180    type Error = SessionError;
181
182    fn try_from(value: Vec<FrameId>) -> Result<Self, Self::Error> {
183        if value.len() <= Self::MAX_ACK_FRAMES {
184            value
185                .into_iter()
186                .map(|v| {
187                    if v > 0 {
188                        Ok(v)
189                    } else {
190                        Err(SessionError::InvalidFrameId)
191                    }
192                })
193                .collect::<Result<BTreeSet<_>, _>>()
194                .map(Self)
195        } else {
196            Err(SessionError::DataTooLong)
197        }
198    }
199}
200
201impl<const C: usize> IntoIterator for FrameAcknowledgements<C> {
202    type IntoIter = std::collections::btree_set::IntoIter<Self::Item>;
203    type Item = FrameId;
204
205    fn into_iter(self) -> Self::IntoIter {
206        self.0.into_iter()
207    }
208}
209
210impl<'a, const C: usize> TryFrom<&'a [u8]> for FrameAcknowledgements<C> {
211    type Error = SessionError;
212
213    fn try_from(value: &'a [u8]) -> Result<Self, Self::Error> {
214        if value.len() == Self::SIZE {
215            Ok(Self(
216                // chunks_exact discards the remainder bytes
217                value
218                    .chunks_exact(size_of::<FrameId>())
219                    .map(|v| FrameId::from_be_bytes(v.try_into().unwrap()))
220                    .filter(|f| *f > 0)
221                    .collect(),
222            ))
223        } else {
224            Err(SessionError::ParseError)
225        }
226    }
227}
228
229impl<const C: usize> From<FrameAcknowledgements<C>> for Vec<u8> {
230    fn from(value: FrameAcknowledgements<C>) -> Self {
231        value
232            .0
233            .iter()
234            .flat_map(|v| v.to_be_bytes())
235            .chain(std::iter::repeat(0_u8))
236            .take(FrameAcknowledgements::<C>::SIZE)
237            .collect::<Vec<_>>()
238    }
239}
240
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_frame_acks_multiple_single() {
        let mut acks = FrameAcknowledgements::<1024>::new_multiple(vec![1, 2, 3]);
        assert_eq!(acks.len(), 1);

        let ids = acks.remove(0).into_iter().collect::<Vec<_>>();
        assert_eq!(ids, vec![1, 2, 3]);
    }

    #[test]
    fn test_frame_acks_multiple_many() {
        const MAX: usize = FrameAcknowledgements::<1024>::MAX_ACK_FRAMES;

        let expected = (0..(2 * MAX + 2) as FrameId).collect::<Vec<_>>();
        let acks = FrameAcknowledgements::<1024>::new_multiple(expected.clone());
        assert_eq!(3, acks.len());

        assert_eq!(MAX, acks[0].len());
        assert_eq!(MAX, acks[1].len());
        assert_eq!(2, acks[2].len());

        let actual = acks.into_iter().flat_map(|a| a.into_iter()).collect::<Vec<_>>();
        assert_eq!(expected, actual);
    }

    #[test]
    fn test_frame_acks_serialization_roundtrip() {
        let ids: Vec<FrameId> = vec![10, 1, 5, 3];
        let acks = FrameAcknowledgements::<1024>::try_from(ids).expect("valid frame IDs must be accepted");

        let bytes: Vec<u8> = acks.clone().into();
        assert_eq!(FrameAcknowledgements::<1024>::SIZE, bytes.len());

        let parsed = FrameAcknowledgements::<1024>::try_from(bytes.as_slice())
            .expect("serialized acknowledgements must parse");
        assert_eq!(acks, parsed);
        assert_eq!(vec![1, 3, 5, 10], parsed.into_iter().collect::<Vec<_>>());

        // Payloads of the wrong length are rejected.
        assert!(FrameAcknowledgements::<1024>::try_from(&bytes[1..]).is_err());
    }

    #[test]
    fn test_frame_acks_reject_zero_frame_id() {
        let ids: Vec<FrameId> = vec![1, 0, 2];
        assert!(FrameAcknowledgements::<1024>::try_from(ids).is_err());
    }

    #[test]
    fn test_missing_segments_in_segment_request() {
        let frame_1_missing: MissingSegmentsBitmap = [0b00000000_u8].into();
        let frame_2_missing: MissingSegmentsBitmap = [0b00100000_u8].into();
        let frame_3_missing: MissingSegmentsBitmap = [0b00111001_u8].into();
        let frame_4_missing: MissingSegmentsBitmap = [0b11111111_u8].into();

        let req = SegmentRequest::<1000>::from_iter([
            (4, frame_4_missing),
            (1, frame_1_missing),
            (3, frame_3_missing),
            (2, frame_2_missing),
        ]);

        // Iterator of SegmentIds is guaranteed to be sorted
        let missing = req.into_iter().collect::<Vec<SegmentId>>();
        let missing_seg_ids = [
            SegmentId(2, 2),
            SegmentId(3, 2),
            SegmentId(3, 3),
            SegmentId(3, 4),
            SegmentId(3, 7),
            SegmentId(4, 0),
            SegmentId(4, 1),
            SegmentId(4, 2),
            SegmentId(4, 3),
            SegmentId(4, 4),
            SegmentId(4, 5),
            SegmentId(4, 6),
            SegmentId(4, 7),
        ];

        assert_eq!(missing, missing_seg_ids);
    }

    #[test]
    fn test_segment_request_serialization_roundtrip() {
        let req = SegmentRequest::<1000>::from_iter([
            (9, MissingSegmentsBitmap::from([0b00000011_u8])),
            (3, MissingSegmentsBitmap::from([0b10100000_u8])),
        ]);
        assert_eq!(4, req.len());

        let bytes: Vec<u8> = req.clone().into();
        assert_eq!(SegmentRequest::<1000>::SIZE, bytes.len());

        let parsed = SegmentRequest::<1000>::try_from(bytes.as_slice()).expect("serialized request must parse");
        assert_eq!(req, parsed);

        // Payloads of the wrong length are rejected.
        assert!(SegmentRequest::<1000>::try_from(&bytes[1..]).is_err());
    }

    #[test]
    fn test_segment_request_serialization_truncates_to_max_entries() {
        const MAX: usize = SegmentRequest::<1000>::MAX_ENTRIES;

        let req = SegmentRequest::<1000>::from_iter(
            (1..=(MAX + 1) as FrameId).map(|id| (id, MissingSegmentsBitmap::from([0b10000000_u8]))),
        );
        // Only the first MAX_ENTRIES frames are counted...
        assert_eq!(MAX, req.len());

        // ...and only those survive serialization.
        let bytes: Vec<u8> = req.into();
        let parsed = SegmentRequest::<1000>::try_from(bytes.as_slice()).expect("serialized request must parse");
        assert_eq!(MAX, parsed.len());
        assert!(parsed.0.contains_key(&(MAX as FrameId)));
        assert!(!parsed.0.contains_key(&((MAX + 1) as FrameId)));
    }
}