hopr_protocol_session/protocol/
messages.rs1use std::collections::{BTreeMap, BTreeSet};
4
5use bitvec::{BitArr, field::BitField, prelude::Msb0};
6
7use crate::{
8 errors::SessionError,
9 protocol::{FrameId, SegmentId, SeqNum, SessionMessage},
10};
11
12#[derive(Debug, Clone, PartialEq, Eq, Default)]
17pub struct SegmentRequest<const C: usize>(pub(super) BTreeMap<FrameId, SeqNum>);
18
19pub type MissingSegmentsBitmap = BitArr!(for 1, in SeqNum, Msb0);
24
25impl<const C: usize> SegmentRequest<C> {
26 pub const ENTRY_SIZE: usize = size_of::<FrameId>() + size_of::<SeqNum>();
28 pub const MAX_ENTRIES: usize = Self::SIZE / Self::ENTRY_SIZE;
30 pub const MAX_MISSING_SEGMENTS_PER_FRAME: usize = SeqNum::BITS as usize;
32 pub const SIZE: usize = C - SessionMessage::<C>::HEADER_SIZE;
34
35 pub fn len(&self) -> usize {
37 self.0
38 .values()
39 .take(Self::MAX_ENTRIES)
40 .map(|e| e.count_ones() as usize)
41 .sum()
42 }
43
44 pub fn is_empty(&self) -> bool {
46 self.0.iter().take(Self::MAX_ENTRIES).all(|(_, e)| e.count_ones() == 0)
47 }
48}
49
50impl<const C: usize> IntoIterator for SegmentRequest<C> {
51 type IntoIter = std::vec::IntoIter<Self::Item>;
52 type Item = SegmentId;
53
54 fn into_iter(self) -> Self::IntoIter {
56 let seq_size = SeqNum::BITS as usize;
57 let mut ret = Vec::with_capacity(seq_size * self.0.len());
58 for (frame_id, missing) in self.0 {
59 ret.extend(
60 MissingSegmentsBitmap::from([missing])
61 .iter_ones()
62 .map(|i| SegmentId(frame_id, i as SeqNum)),
63 );
64 }
65 ret.into_iter()
66 }
67}
68
69impl<const C: usize> FromIterator<(FrameId, MissingSegmentsBitmap)> for SegmentRequest<C> {
71 fn from_iter<T: IntoIterator<Item = (FrameId, MissingSegmentsBitmap)>>(iter: T) -> Self {
72 Self(
73 iter.into_iter()
74 .map(|(fid, missing_segments)| (fid, missing_segments.load()))
75 .collect(),
76 )
77 }
78}
79
80impl<const C: usize> TryFrom<&[u8]> for SegmentRequest<C> {
81 type Error = SessionError;
82
83 fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
84 if value.len() == Self::SIZE {
85 let mut ret = Self::default();
86 for (frame_id, missing) in value
87 .chunks_exact(Self::ENTRY_SIZE)
88 .map(|c| c.split_at(size_of::<FrameId>()))
89 {
90 let frame_id = FrameId::from_be_bytes(frame_id.try_into().map_err(|_| SessionError::ParseError)?);
91 if frame_id > 0 {
92 ret.0.insert(
93 frame_id,
94 SeqNum::from_be_bytes(missing.try_into().map_err(|_| SessionError::ParseError)?),
95 );
96 }
97 }
98 Ok(ret)
99 } else {
100 Err(SessionError::ParseError)
101 }
102 }
103}
104
105impl<const C: usize> From<SegmentRequest<C>> for Vec<u8> {
106 fn from(value: SegmentRequest<C>) -> Self {
107 let mut ret = vec![0u8; SegmentRequest::<C>::SIZE];
108 let mut offset = 0;
109 for (frame_id, seq_num) in value.0 {
110 if offset + size_of::<FrameId>() + size_of::<SeqNum>() <= SegmentRequest::<C>::SIZE {
111 ret[offset..offset + size_of::<FrameId>()].copy_from_slice(&frame_id.to_be_bytes());
112 offset += size_of::<FrameId>();
113 ret[offset..offset + size_of::<SeqNum>()].copy_from_slice(&seq_num.to_be_bytes());
114 offset += size_of::<SeqNum>();
115 } else {
116 break;
117 }
118 }
119 ret
120 }
121}
122
123#[derive(Debug, Clone, PartialEq, Eq, Default)]
127pub struct FrameAcknowledgements<const C: usize>(pub(super) BTreeSet<FrameId>);
128
129impl<const C: usize> FrameAcknowledgements<C> {
130 pub const MAX_ACK_FRAMES: usize = Self::SIZE / size_of::<FrameId>();
132 pub const SIZE: usize = C - SessionMessage::<C>::HEADER_SIZE;
134
135 #[inline]
139 pub fn push(&mut self, frame_id: FrameId) -> bool {
140 !self.is_full() && self.0.insert(frame_id)
141 }
142
143 #[inline]
145 pub fn len(&self) -> usize {
146 self.0.len()
147 }
148
149 pub fn is_empty(&self) -> bool {
151 self.0.is_empty()
152 }
153
154 #[inline]
157 pub fn is_full(&self) -> bool {
158 self.0.len() == Self::MAX_ACK_FRAMES
159 }
160
161 pub fn new_multiple<T: IntoIterator<Item = FrameId>>(items: T) -> Vec<Self> {
164 let mut out = Vec::with_capacity(2);
165 let mut frame_ack = Self::default();
166 for frame_id in items {
167 if frame_ack.is_full() {
168 out.push(frame_ack);
169 frame_ack = Self::default();
170 }
171
172 frame_ack.push(frame_id);
173 }
174 out.push(frame_ack);
175 out
176 }
177}
178
179impl<const C: usize> TryFrom<Vec<FrameId>> for FrameAcknowledgements<C> {
180 type Error = SessionError;
181
182 fn try_from(value: Vec<FrameId>) -> Result<Self, Self::Error> {
183 if value.len() <= Self::MAX_ACK_FRAMES {
184 value
185 .into_iter()
186 .map(|v| {
187 if v > 0 {
188 Ok(v)
189 } else {
190 Err(SessionError::InvalidFrameId)
191 }
192 })
193 .collect::<Result<BTreeSet<_>, _>>()
194 .map(Self)
195 } else {
196 Err(SessionError::DataTooLong)
197 }
198 }
199}
200
201impl<const C: usize> IntoIterator for FrameAcknowledgements<C> {
202 type IntoIter = std::collections::btree_set::IntoIter<Self::Item>;
203 type Item = FrameId;
204
205 fn into_iter(self) -> Self::IntoIter {
206 self.0.into_iter()
207 }
208}
209
210impl<'a, const C: usize> TryFrom<&'a [u8]> for FrameAcknowledgements<C> {
211 type Error = SessionError;
212
213 fn try_from(value: &'a [u8]) -> Result<Self, Self::Error> {
214 if value.len() == Self::SIZE {
215 Ok(Self(
216 value
218 .chunks_exact(size_of::<FrameId>())
219 .map(|v| FrameId::from_be_bytes(v.try_into().unwrap()))
220 .filter(|f| *f > 0)
221 .collect(),
222 ))
223 } else {
224 Err(SessionError::ParseError)
225 }
226 }
227}
228
229impl<const C: usize> From<FrameAcknowledgements<C>> for Vec<u8> {
230 fn from(value: FrameAcknowledgements<C>) -> Self {
231 value
232 .0
233 .iter()
234 .flat_map(|v| v.to_be_bytes())
235 .chain(std::iter::repeat(0_u8))
236 .take(FrameAcknowledgements::<C>::SIZE)
237 .collect::<Vec<_>>()
238 }
239}
240
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_frame_acks_multiple_single() {
        let mut acks = FrameAcknowledgements::<1024>::new_multiple(vec![1, 2, 3]);
        assert_eq!(acks.len(), 1);

        let ids = acks.remove(0).into_iter().collect::<Vec<_>>();
        assert_eq!(ids, vec![1, 2, 3]);
    }

    #[test]
    fn test_frame_acks_multiple_many() {
        const MAX: usize = FrameAcknowledgements::<1024>::MAX_ACK_FRAMES;

        let expected = (0..(2 * MAX + 2) as FrameId).collect::<Vec<_>>();
        let acks = FrameAcknowledgements::<1024>::new_multiple(expected.clone());
        assert_eq!(3, acks.len());

        assert_eq!(MAX, acks[0].len());
        assert_eq!(MAX, acks[1].len());
        assert_eq!(2, acks[2].len());

        let actual = acks.into_iter().flat_map(|a| a.into_iter()).collect::<Vec<_>>();
        assert_eq!(expected, actual);
    }

    #[test]
    fn test_missing_segments_in_segment_request() {
        let frame_1_missing: MissingSegmentsBitmap = [0b00000000_u8].into();
        let frame_2_missing: MissingSegmentsBitmap = [0b00100000_u8].into();
        let frame_3_missing: MissingSegmentsBitmap = [0b00111001_u8].into();
        let frame_4_missing: MissingSegmentsBitmap = [0b11111111_u8].into();

        let req = SegmentRequest::<1000>::from_iter([
            (4, frame_4_missing),
            (1, frame_1_missing),
            (3, frame_3_missing),
            (2, frame_2_missing),
        ]);

        let missing = req.into_iter().collect::<Vec<SegmentId>>();

        let missing_seg_ids = [
            SegmentId(2, 2),
            SegmentId(3, 2),
            SegmentId(3, 3),
            SegmentId(3, 4),
            SegmentId(3, 7),
            SegmentId(4, 0),
            SegmentId(4, 1),
            SegmentId(4, 2),
            SegmentId(4, 3),
            SegmentId(4, 4),
            SegmentId(4, 5),
            SegmentId(4, 6),
            SegmentId(4, 7),
        ];

        assert_eq!(missing, missing_seg_ids);
    }

    /// Encoding then decoding a segment request must reproduce it exactly.
    #[test]
    fn test_segment_request_roundtrip() {
        let req = SegmentRequest::<1024>::from_iter([
            (7, MissingSegmentsBitmap::from([0b1000_0001_u8])),
            (2, MissingSegmentsBitmap::from([0b0001_0000_u8])),
        ]);

        let bytes: Vec<u8> = req.clone().into();
        assert_eq!(SegmentRequest::<1024>::SIZE, bytes.len());

        let Ok(decoded) = SegmentRequest::<1024>::try_from(bytes.as_slice()) else {
            panic!("a freshly serialized segment request must parse");
        };
        assert_eq!(req, decoded);
        assert_eq!(3, decoded.len());
    }

    /// Decoding a segment request from a buffer of the wrong length must fail.
    #[test]
    fn test_segment_request_rejects_wrong_length() {
        assert!(SegmentRequest::<1024>::try_from(&[0_u8; 3][..]).is_err());
    }

    /// Encoding then decoding acknowledgements must reproduce them, with IDs in ascending order.
    #[test]
    fn test_frame_acks_roundtrip() {
        let ids: Vec<FrameId> = vec![5, 1, 3];
        let Ok(acks) = FrameAcknowledgements::<1024>::try_from(ids) else {
            panic!("non-zero frame ids within capacity must be accepted");
        };

        let bytes: Vec<u8> = acks.clone().into();
        assert_eq!(FrameAcknowledgements::<1024>::SIZE, bytes.len());

        let Ok(decoded) = FrameAcknowledgements::<1024>::try_from(bytes.as_slice()) else {
            panic!("freshly serialized acknowledgements must parse");
        };
        assert_eq!(acks, decoded);
        assert_eq!(vec![1, 3, 5], decoded.into_iter().collect::<Vec<_>>());
    }

    /// Zero frame IDs, over-capacity lists and wrongly sized buffers must be rejected.
    #[test]
    fn test_frame_acks_reject_invalid_input() {
        const MAX: usize = FrameAcknowledgements::<1024>::MAX_ACK_FRAMES;

        let with_zero: Vec<FrameId> = vec![0, 1];
        assert!(FrameAcknowledgements::<1024>::try_from(with_zero).is_err());

        let too_many: Vec<FrameId> = (1..=(MAX + 1) as FrameId).collect();
        assert!(FrameAcknowledgements::<1024>::try_from(too_many).is_err());

        assert!(FrameAcknowledgements::<1024>::try_from(&[0_u8; 3][..]).is_err());
    }
}