// hopr_network_types/session/protocol.rs
use std::borrow::Cow;
use std::collections::{BTreeMap, BTreeSet};
use std::fmt::{Display, Formatter};
use std::mem;
use crate::errors::NetworkTypeError;
use crate::session::errors::SessionError;
use crate::session::frame::{FrameId, FrameInfo, Segment, SegmentId, SeqNum};
/// Retransmission request for missing segments of one or more frames.
///
/// Each entry maps a frame ID to a bitmap of type `SeqNum`, where a set bit `i`
/// marks the segment with sequence index `i` as requested.
/// The const parameter `C` is the total encoded size of the carrying message,
/// including the session message header.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct SegmentRequest<const C: usize>(BTreeMap<FrameId, SeqNum>);

impl<const C: usize> SegmentRequest<C> {
    /// Size of the encoded request payload (i.e. `C` without the session message header).
    pub const SIZE: usize = C - SessionMessage::<C>::HEADER_SIZE;

    /// Size of a single encoded entry: the frame ID followed by the missing-segments bitmap.
    pub const ENTRY_SIZE: usize = mem::size_of::<FrameId>() + mem::size_of::<SeqNum>();

    /// Maximum number of entries that fit into the payload.
    pub const MAX_ENTRIES: usize = Self::SIZE / Self::ENTRY_SIZE;

    /// Maximum number of missing segments representable per frame
    /// (the number of bits in the bitmap).
    pub const MAX_MISSING_SEGMENTS_PER_FRAME: usize = mem::size_of::<SeqNum>() * 8;

    /// Returns the total number of requested segments, i.e. the number of bits set
    /// across the bitmaps of the first [`Self::MAX_ENTRIES`] entries.
    pub fn len(&self) -> usize {
        self.0
            .values()
            .take(Self::MAX_ENTRIES)
            .fold(0, |total, bitmap| total + bitmap.count_ones() as usize)
    }

    /// Returns `true` if the request contains no frame entries.
    ///
    /// Note that this checks the entries, not the bitmaps: an entry with an
    /// all-zero bitmap makes the request non-empty while [`Self::len`] stays 0.
    pub fn is_empty(&self) -> bool {
        let Self(entries) = self;
        entries.is_empty()
    }
}
/// Iterator over the segment IDs of a [`SegmentRequest`].
///
/// The IDs are kept in a vector and yielded from its end,
/// so the last pushed ID comes out first.
pub struct SegmentIdIter(Vec<SegmentId>);

impl Iterator for SegmentIdIter {
    type Item = SegmentId;

    /// Removes and returns the last stored segment ID, or `None` once exhausted.
    fn next(&mut self) -> Option<Self::Item> {
        let Self(stack) = self;
        stack.pop()
    }
}
/// Expands the request into the individual [`SegmentId`]s it asks for.
///
/// Because [`SegmentIdIter`] yields from the end of its backing vector,
/// the segments of the entry with the highest frame ID come out first,
/// and within a frame the sequence indices come out in ascending order.
impl<const C: usize> IntoIterator for SegmentRequest<C> {
    type Item = SegmentId;
    type IntoIter = SegmentIdIter;

    fn into_iter(self) -> Self::IntoIter {
        let bits_per_entry = mem::size_of::<SeqNum>() * 8;

        // Allocate exactly as many slots as there are requested segments.
        // The bit count is known up-front, so no over-allocation followed by
        // `shrink_to_fit` (and a possible reallocation) is needed.
        let requested: usize = self.0.values().map(|bitmap| bitmap.count_ones() as usize).sum();
        let mut ids = Vec::with_capacity(requested);

        for (frame_id, missing) in self.0 {
            // Push the highest index first: the iterator pops from the back,
            // so the indices of a frame are yielded in ascending order.
            for i in (0..bits_per_entry).rev() {
                // Shift the bitmap itself rather than an untyped `1` literal,
                // so the test works for any bitmap width.
                if (missing >> i) & 1 != 0 {
                    ids.push(SegmentId(frame_id, i as SeqNum));
                }
            }
        }

        SegmentIdIter(ids)
    }
}
impl<const C: usize> FromIterator<FrameInfo> for SegmentRequest<C> {
fn from_iter<T: IntoIterator<Item = FrameInfo>>(iter: T) -> Self {
let mut ret = Self::default();
for frame in iter.into_iter().take(Self::MAX_ENTRIES) {
let frame_id = frame.frame_id;
let missing = frame
.iter_missing_sequence_indices()
.filter(|s| *s < Self::MAX_MISSING_SEGMENTS_PER_FRAME as SeqNum)
.map(|idx| 1 << idx)
.fold(SeqNum::default(), |acc, n| acc | n);
ret.0.insert(frame_id, missing);
}
ret
}
}
impl<const C: usize> TryFrom<&[u8]> for SegmentRequest<C> {
type Error = SessionError;
fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
if value.len() == Self::SIZE {
let mut ret = Self::default();
for (frame_id, missing) in value
.chunks_exact(Self::ENTRY_SIZE)
.map(|c| c.split_at(mem::size_of::<FrameId>()))
{
let frame_id = FrameId::from_be_bytes(frame_id.try_into().map_err(|_| SessionError::ParseError)?);
if frame_id > 0 {
ret.0.insert(
frame_id,
SeqNum::from_be_bytes(missing.try_into().map_err(|_| SessionError::ParseError)?),
);
}
}
Ok(ret)
} else {
Err(SessionError::ParseError)
}
}
}
/// Encodes the request into its fixed-size binary form.
///
/// The output is always [`SegmentRequest::SIZE`] bytes long. Entries are written in
/// ascending frame ID order, each as a big-endian frame ID followed by a big-endian
/// bitmap; unused space is zero-filled. At most [`SegmentRequest::MAX_ENTRIES`]
/// entries are written, any further entries are silently dropped.
impl<const C: usize> From<SegmentRequest<C>> for Vec<u8> {
    fn from(value: SegmentRequest<C>) -> Self {
        let id_len = mem::size_of::<FrameId>();
        let mut ret = vec![0u8; SegmentRequest::<C>::SIZE];

        // Iterate over the entry-sized slots of the *output buffer* instead of
        // comparing a running offset against `C`: the buffer is only `SIZE`
        // (`C` minus the header) bytes long, so bounding by `C` could slice past
        // its end and panic when the map holds more than `MAX_ENTRIES` entries.
        let slots = ret.chunks_exact_mut(SegmentRequest::<C>::ENTRY_SIZE);
        for (slot, (frame_id, missing)) in slots.zip(value.0) {
            let (id_bytes, bitmap_bytes) = slot.split_at_mut(id_len);
            id_bytes.copy_from_slice(&frame_id.to_be_bytes());
            bitmap_bytes.copy_from_slice(&missing.to_be_bytes());
        }

        ret
    }
}
/// Set of frame IDs that are being acknowledged.
///
/// The const parameter `C` is the total encoded size of the carrying message,
/// including the session message header.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct FrameAcknowledgements<const C: usize>(BTreeSet<FrameId>);

impl<const C: usize> FrameAcknowledgements<C> {
    /// Size of the encoded payload (i.e. `C` without the session message header).
    pub const SIZE: usize = C - SessionMessage::<C>::HEADER_SIZE;

    /// Maximum number of frame IDs that fit into the payload.
    pub const MAX_ACK_FRAMES: usize = Self::SIZE / mem::size_of::<FrameId>();

    /// Adds the given frame ID.
    ///
    /// Returns `false` (leaving the set unchanged) if the set is already full
    /// or already contains the ID; returns `true` if the ID was added.
    #[inline]
    pub fn push(&mut self, frame_id: FrameId) -> bool {
        if self.is_full() {
            return false;
        }
        self.0.insert(frame_id)
    }

    /// Number of acknowledged frame IDs.
    #[inline]
    pub fn len(&self) -> usize {
        let Self(ids) = self;
        ids.len()
    }

    /// Returns `true` if no frame IDs are present.
    pub fn is_empty(&self) -> bool {
        let Self(ids) = self;
        ids.is_empty()
    }

    /// Returns `true` once [`Self::MAX_ACK_FRAMES`] frame IDs are present.
    #[inline]
    pub fn is_full(&self) -> bool {
        Self::MAX_ACK_FRAMES == self.len()
    }
}
impl<const C: usize> From<Vec<FrameId>> for FrameAcknowledgements<C> {
fn from(value: Vec<FrameId>) -> Self {
Self(
value
.into_iter()
.take(Self::MAX_ACK_FRAMES)
.filter(|v| *v > 0)
.collect(),
)
}
}
impl<const C: usize> IntoIterator for FrameAcknowledgements<C> {
type Item = FrameId;
type IntoIter = std::collections::btree_set::IntoIter<Self::Item>;
fn into_iter(self) -> Self::IntoIter {
self.0.into_iter()
}
}
impl<'a, const C: usize> TryFrom<&'a [u8]> for FrameAcknowledgements<C> {
type Error = SessionError;
fn try_from(value: &'a [u8]) -> Result<Self, Self::Error> {
if value.len() == Self::SIZE {
Ok(Self(
value
.chunks_exact(mem::size_of::<FrameId>())
.map(|v| FrameId::from_be_bytes(v.try_into().unwrap()))
.filter(|f| *f > 0)
.collect(),
))
} else {
Err(SessionError::ParseError)
}
}
}
/// Encodes the acknowledgements into their fixed-size binary form.
///
/// The frame IDs are written in ascending order as big-endian values and the
/// output is zero-padded (or cut) to exactly [`FrameAcknowledgements::SIZE`] bytes.
impl<const C: usize> From<FrameAcknowledgements<C>> for Vec<u8> {
    fn from(value: FrameAcknowledgements<C>) -> Self {
        let size = FrameAcknowledgements::<C>::SIZE;

        let mut encoded = Vec::with_capacity(size);
        for frame_id in &value.0 {
            encoded.extend_from_slice(&frame_id.to_be_bytes());
        }

        // Pads with zeros when shorter, truncates when longer.
        encoded.resize(size, 0_u8);
        encoded
    }
}
/// A single message of the session protocol.
///
/// The const parameter `C` determines the encoded size of the `Request` and
/// `Acknowledge` payloads (`C` minus the message header size).
/// The derived `SessionMessageDiscriminants` enum (`repr(u8)`) is used as the
/// message tag in the encoded form.
#[derive(Debug, Clone, PartialEq, Eq, strum::EnumDiscriminants, strum::EnumTryAs)]
#[strum_discriminants(derive(strum::FromRepr), repr(u8))]
pub enum SessionMessage<const C: usize> {
/// Carries a single segment of a frame.
Segment(Segment),
/// Requests retransmission of missing segments.
Request(SegmentRequest<C>),
/// Acknowledges a set of frames.
Acknowledge(FrameAcknowledgements<C>),
}
impl<const C: usize> Display for SessionMessage<C> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match &self {
SessionMessage::Segment(s) => write!(f, "segment {}", s.id()),
SessionMessage::Request(r) => write!(f, "retransmission request of {:?}", r.0),
SessionMessage::Acknowledge(a) => write!(f, "acknowledgement of {:?}", a.0),
}
}
}
impl<const C: usize> SessionMessage<C> {
    /// Version byte written at the start of every encoded message.
    pub const VERSION: u8 = 1;

    /// Size of the message header: one version byte, the message tag
    /// and a 16-bit payload length.
    pub const HEADER_SIZE: usize = 1 + mem::size_of::<SessionMessageDiscriminants>() + mem::size_of::<u16>();

    /// Total per-segment overhead: the message header plus the segment header.
    pub const SEGMENT_OVERHEAD: usize = Self::HEADER_SIZE + Segment::HEADER_SIZE;

    /// Upper bound on the payload length accepted when decoding a message.
    pub const MAX_MESSAGE_SIZE: usize = 1492 - Self::SEGMENT_OVERHEAD;

    /// Maximum number of segments per frame, as limited by the
    /// retransmission request bitmap.
    pub const MAX_SEGMENTS_PER_FRAME: usize = SegmentRequest::<C>::MAX_MISSING_SEGMENTS_PER_FRAME;

    /// Size of the smallest possible encoded message: the header plus
    /// the smallest of the three payload kinds.
    pub fn minimum_message_size() -> usize {
        let smallest_payload = std::cmp::min(
            Segment::MINIMUM_SIZE,
            std::cmp::min(SegmentRequest::<C>::SIZE, FrameAcknowledgements::<C>::SIZE),
        );
        Self::HEADER_SIZE + smallest_payload
    }

    /// Consumes the message and returns its encoded form as a boxed slice.
    pub fn into_encoded(self) -> Box<[u8]> {
        let encoded: Vec<u8> = self.into();
        encoded.into_boxed_slice()
    }
}
/// Decodes a single message from the start of the given bytes.
///
/// Parsing is delegated to [`SessionMessageIter`]; any bytes following the
/// first message are ignored.
impl<const C: usize> TryFrom<&[u8]> for SessionMessage<C> {
    type Error = SessionError;

    fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
        let mut parser = SessionMessageIter::<C>::from(value);
        parser.try_next()
    }
}
/// Encodes the message as: version byte, message tag byte, big-endian 16-bit
/// payload length, and the payload itself.
impl<const C: usize> From<SessionMessage<C>> for Vec<u8> {
    fn from(value: SessionMessage<C>) -> Self {
        // The tag must be taken before the message is consumed below.
        let tag = SessionMessageDiscriminants::from(&value) as u8;

        let payload: Vec<u8> = match value {
            SessionMessage::Segment(segment) => segment.into(),
            SessionMessage::Request(request) => request.into(),
            SessionMessage::Acknowledge(acks) => acks.into(),
        };

        // NOTE(review): the cast truncates payloads longer than `u16::MAX`;
        // presumably payload sizes are bounded well below that - confirm.
        let payload_len = payload.len() as u16;

        let mut encoded = Vec::with_capacity(SessionMessage::<C>::HEADER_SIZE + payload_len as usize);
        encoded.extend_from_slice(&[SessionMessage::<C>::VERSION, tag]);
        encoded.extend_from_slice(&payload_len.to_be_bytes());
        encoded.extend_from_slice(&payload);
        encoded
    }
}
/// Iterator that decodes consecutive [`SessionMessage`]s from a byte buffer.
#[derive(Debug, Clone)]
pub struct SessionMessageIter<'a, const C: usize> {
// Buffer with the encoded messages (borrowed or owned).
data: Cow<'a, [u8]>,
// Position in `data` where the next message starts.
offset: usize,
// Error of the last failed decoding attempt, if any.
last_err: Option<SessionError>,
}
impl<const C: usize> SessionMessageIter<'_, C> {
    /// Returns the error that stopped the iteration, if there was one.
    pub fn last_error(&self) -> Option<&SessionError> {
        self.last_err.as_ref()
    }

    /// Returns `true` if no more messages can be produced: either a previous
    /// decoding attempt failed, or the remaining data is shorter than the
    /// smallest possible message.
    pub fn is_done(&self) -> bool {
        // Saturating: never underflow, even if the offset were past the end of the data.
        let remaining = self.data.len().saturating_sub(self.offset);
        self.last_err.is_some() || remaining < SessionMessage::<C>::minimum_message_size()
    }

    /// Attempts to decode the message starting at the current offset and, on
    /// success, advances the offset past it.
    ///
    /// All reads are bounds-checked, so empty, truncated or otherwise malformed
    /// input yields an error instead of a panic.
    ///
    /// # Errors
    ///
    /// - [`SessionError::WrongVersion`] if the version byte does not match.
    /// - [`SessionError::IncorrectMessageLength`] if the header is incomplete, the
    ///   declared payload length exceeds the maximum message size, or the payload
    ///   is longer than the remaining data.
    /// - [`SessionError::ParseError`] if reserved bits of the length field are set.
    /// - [`SessionError::UnknownMessageTag`] if the message tag is not recognized.
    /// - Any error produced while decoding the payload itself.
    fn try_next(&mut self) -> Result<SessionMessage<C>, SessionError> {
        let data: &[u8] = &self.data;
        let mut offset = self.offset;

        // Protocol version
        let version = *data.get(offset).ok_or(SessionError::IncorrectMessageLength)?;
        if version != SessionMessage::<C>::VERSION {
            return Err(SessionError::WrongVersion);
        }
        offset += 1;

        // Message tag
        let disc = *data.get(offset).ok_or(SessionError::IncorrectMessageLength)?;
        offset += 1;

        // Payload length
        let len_end = offset + mem::size_of::<u16>();
        let len = u16::from_be_bytes(
            data.get(offset..len_end)
                .and_then(|bytes| <[u8; 2]>::try_from(bytes).ok())
                .ok_or(SessionError::IncorrectMessageLength)?,
        ) as usize;
        offset = len_end;

        if len > SessionMessage::<C>::MAX_MESSAGE_SIZE {
            return Err(SessionError::IncorrectMessageLength);
        }

        // The upper 6 bits of the length field are reserved and must be zero.
        let reserved = len & 0b111111_0000000000;
        if reserved != 0 {
            return Err(SessionError::ParseError);
        }

        // Resolve the tag before touching the payload, so that an unknown tag
        // is reported as such even when the payload is truncated.
        let kind = SessionMessageDiscriminants::from_repr(disc).ok_or(SessionError::UnknownMessageTag)?;

        // The declared length comes from untrusted input and may exceed the data at hand.
        let payload = data
            .get(offset..offset + len)
            .ok_or(SessionError::IncorrectMessageLength)?;

        let res = match kind {
            SessionMessageDiscriminants::Segment => SessionMessage::Segment(payload.try_into()?),
            SessionMessageDiscriminants::Request => SessionMessage::Request(payload.try_into()?),
            SessionMessageDiscriminants::Acknowledge => SessionMessage::Acknowledge(payload.try_into()?),
        };

        // Only move forward once the whole message has been decoded successfully.
        self.offset = offset + len;
        Ok(res)
    }
}
/// Creates an iterator positioned at the start of the given data, which may be
/// anything convertible into a borrowed or owned byte buffer.
impl<'a, const C: usize, T: Into<Cow<'a, [u8]>>> From<T> for SessionMessageIter<'a, C> {
    fn from(value: T) -> Self {
        let data: Cow<'a, [u8]> = value.into();
        Self {
            data,
            offset: 0,
            last_err: None,
        }
    }
}
impl<const C: usize> Iterator for SessionMessageIter<'_, C> {
type Item = Result<SessionMessage<C>, NetworkTypeError>;
fn next(&mut self) -> Option<Self::Item> {
if !self.is_done() {
self.try_next()
.inspect_err(|e| self.last_err = Some(e.clone()))
.map_err(NetworkTypeError::SessionProtocolError)
.into()
} else {
None
}
}
}
// The iterator keeps returning `None` once finished: `next` only proceeds when `is_done`
// is false, and `is_done` stays true after an error or when too little data remains.
impl<const C: usize> std::iter::FusedIterator for SessionMessageIter<'_, C> {}
// Unit tests of the session protocol messages and of the message iterator.
#[cfg(test)]
mod tests {
use super::*;
use crate::session::Frame;
use bitvec::array::BitArray;
use bitvec::bitarr;
use hex_literal::hex;
use hopr_platform::time::native::current_time;
use rand::prelude::IteratorRandom;
use rand::{thread_rng, Rng};
use std::time::SystemTime;
// Pins the constants of protocol version 1, so that an accidental change is noticed.
#[test]
fn ensure_session_protocol_version_1_values() {
assert_eq!(1, SessionMessage::<0>::VERSION);
assert_eq!(4, SessionMessage::<0>::HEADER_SIZE);
assert_eq!(10, SessionMessage::<0>::SEGMENT_OVERHEAD);
assert_eq!(8, SessionMessage::<0>::MAX_SEGMENTS_PER_FRAME);
assert!(SessionMessage::<0>::MAX_MESSAGE_SIZE < 2048);
}
// A request built from frame infos must expand to exactly the missing segments of those frames.
#[test]
fn segment_request_should_be_constructible_from_frame_info() {
// 19 frames, each with 4 randomly chosen missing segments among indices 0..7
let frames = (1..20)
.map(|i| {
let mut missing_segments = BitArray::ZERO;
(0..7_usize)
.choose_multiple(&mut thread_rng(), 4)
.into_iter()
.for_each(|i| missing_segments.set(i, true));
FrameInfo {
frame_id: i,
missing_segments,
total_segments: 8,
last_update: SystemTime::UNIX_EPOCH,
}
})
.collect::<Vec<_>>();
let mut req = SegmentRequest::<466>::from_iter(frames.clone())
.into_iter()
.collect::<Vec<_>>();
// The request iterator does not yield in ascending frame order, hence the sort
req.sort();
assert_eq!(frames.len() * 4, req.len());
assert_eq!(
req,
frames
.into_iter()
.flat_map(|f| f.into_missing_segments())
.collect::<Vec<_>>()
);
}
// Round-trip of a `Segment` message through its binary encoding.
#[test]
fn session_message_segment_should_serialize_and_deserialize() -> anyhow::Result<()> {
const SEG_SIZE: usize = 8;
let mut segments = Frame {
frame_id: 10,
data: hex!("deadbeefcafebabe").into(),
}
.segment(SEG_SIZE)?;
const MTU: usize = SEG_SIZE + Segment::HEADER_SIZE + 2;
let msg_1 = SessionMessage::<MTU>::Segment(segments.pop().unwrap());
let data = Vec::from(msg_1.clone());
let msg_2 = SessionMessage::try_from(&data[..])?;
assert_eq!(msg_1, msg_2);
Ok(())
}
// Round-trip of a `Request` message; also checks that only segments 0..=7 are
// requested even though all 256 bits of the frame info are set.
#[test]
fn session_message_segment_request_should_serialize_and_deserialize() -> anyhow::Result<()> {
let frame_info = FrameInfo {
frame_id: 10,
total_segments: 255,
missing_segments: bitarr![1; 256],
last_update: SystemTime::now(),
};
let msg_1 = SessionMessage::<466>::Request(SegmentRequest::from_iter(vec![frame_info]));
let data = Vec::from(msg_1.clone());
let msg_2 = SessionMessage::try_from(&data[..])?;
assert_eq!(msg_1, msg_2);
match msg_1 {
SessionMessage::Request(r) => {
let missing_segments = r.into_iter().collect::<Vec<_>>();
let expected = (0..=7).map(|s| SegmentId(10, s)).collect::<Vec<_>>();
assert_eq!(expected, missing_segments);
}
_ => panic!("invalid type"),
}
Ok(())
}
// Round-trip of an `Acknowledge` message built from 500 random frame IDs
// (more than fit into a single message).
#[test]
fn session_message_ack_should_serialize_and_deserialize() -> anyhow::Result<()> {
let mut rng = thread_rng();
let frame_ids: Vec<u32> = (0..500).map(|_| rng.gen()).collect();
let msg_1 = SessionMessage::<466>::Acknowledge(frame_ids.into());
let data = Vec::from(msg_1.clone());
let msg_2 = SessionMessage::try_from(&data[..])?;
assert_eq!(msg_1, msg_2);
Ok(())
}
// Bits set in the request bitmap must map to the corresponding segment indices,
// both for a hand-built request and for one built from a frame info.
#[test]
fn session_message_segment_request_should_yield_correct_bitset_values() {
// Bits 2 and 5 are set
let seg_req = SegmentRequest::<466>([(10, 0b00100100)].into());
let mut iter = seg_req.into_iter();
assert_eq!(iter.next(), Some(SegmentId(10, 2)));
assert_eq!(iter.next(), Some(SegmentId(10, 5)));
assert_eq!(iter.next(), None);
let mut frame_info = FrameInfo {
frame_id: 10,
missing_segments: bitarr![0; 256],
total_segments: 10,
last_update: current_time(),
};
frame_info.missing_segments.set(2, true);
frame_info.missing_segments.set(5, true);
let mut iter = SegmentRequest::<466>::from_iter(vec![frame_info]).into_iter();
assert_eq!(iter.next(), Some(SegmentId(10, 2)));
assert_eq!(iter.next(), Some(SegmentId(10, 5)));
assert_eq!(iter.next(), None);
}
// Input that is empty or too short to hold a message must yield nothing.
#[test]
fn session_message_iter_should_be_empty_if_slice_has_no_messages() {
const MTU: usize = 462;
let mut iter = SessionMessageIter::<MTU>::from(Vec::<u8>::new());
assert!(iter.next().is_none());
assert!(iter.is_done());
let mut iter = SessionMessageIter::<MTU>::from(&[0u8; 2]);
assert!(iter.next().is_none());
assert!(iter.is_done());
}
// Messages of all three kinds, concatenated into a single buffer, must be decoded
// back in the same order.
#[test]
fn session_message_iter_should_deserialize_multiple_messages() -> anyhow::Result<()> {
const MTU: usize = 462;
let mut messages_1 = Frame {
frame_id: 10,
data: hopr_crypto_random::random_bytes::<1500>().into(),
}
.segment(MTU - SessionMessage::<MTU>::HEADER_SIZE - Segment::HEADER_SIZE)?
.into_iter()
.map(|s| SessionMessage::<MTU>::Segment(s))
.collect::<Vec<_>>();
let frame_info = FrameInfo {
frame_id: 10,
total_segments: 255,
missing_segments: bitarr![1; 256],
last_update: SystemTime::now(),
};
messages_1.push(SessionMessage::<MTU>::Request(SegmentRequest::from_iter(vec![
frame_info,
])));
let mut rng = thread_rng();
let frame_ids: Vec<u32> = (0..100).map(|_| rng.gen()).collect();
messages_1.push(SessionMessage::<MTU>::Acknowledge(frame_ids.into()));
// The encoded messages are followed by 10 bytes of zero padding,
// which the iterator is expected to ignore
let iter = SessionMessageIter::<MTU>::from(
messages_1
.iter()
.cloned()
.map(|m| m.into_encoded().into_vec())
.flatten()
.chain(std::iter::repeat(0).take(10))
.collect::<Vec<u8>>(),
);
let messages_2 = iter.collect::<Result<Vec<_>, _>>()?;
assert_eq!(messages_1, messages_2);
Ok(())
}
// Consuming all messages (with trailing zero padding) must not leave an error behind.
#[test]
fn session_message_iter_should_not_contain_error_when_consuming_everything() -> anyhow::Result<()> {
const MTU: usize = 462;
let messages = Frame {
frame_id: 10,
data: hopr_crypto_random::random_bytes::<{ 3 * MTU }>().into(),
}
.segment(MTU - SessionMessage::<MTU>::HEADER_SIZE - Segment::HEADER_SIZE)?
.into_iter()
.map(|s| SessionMessage::<MTU>::Segment(s))
.collect::<Vec<_>>();
assert_eq!(4, messages.len());
let data = messages
.iter()
.cloned()
.map(|m| m.into_encoded().into_vec())
.flatten()
.chain(std::iter::repeat(0u8).take(10))
.collect::<Vec<_>>();
let mut iter = SessionMessageIter::<MTU>::from(data);
assert!(matches!(iter.next(), Some(Ok(m)) if m == messages[0]));
assert!(matches!(iter.next(), Some(Ok(m)) if m == messages[1]));
assert!(matches!(iter.next(), Some(Ok(m)) if m == messages[2]));
assert!(matches!(iter.next(), Some(Ok(m)) if m == messages[3]));
assert!(iter.next().is_none());
assert!(iter.last_error().is_none());
assert!(iter.is_done());
Ok(())
}
// After the first decoding error, the iterator must report the error once and then stop.
#[test]
fn session_message_iter_should_not_yield_more_after_error() -> anyhow::Result<()> {
const MTU: usize = 462;
let messages = Frame {
frame_id: 10,
data: hopr_crypto_random::random_bytes::<{ 3 * MTU }>().into(),
}
.segment(MTU - SessionMessage::<MTU>::HEADER_SIZE - Segment::HEADER_SIZE)?
.into_iter()
.map(|s| SessionMessage::<MTU>::Segment(s))
.collect::<Vec<_>>();
assert_eq!(4, messages.len());
// The third message is replaced with random bytes
let data = messages
.iter()
.cloned()
.enumerate()
.map(|(i, m)| {
if i == 2 {
Vec::from(hopr_crypto_random::random_bytes::<MTU>())
} else {
m.into_encoded().into_vec()
}
})
.flatten()
.collect::<Vec<_>>();
let mut iter = SessionMessageIter::<MTU>::from(data);
assert!(matches!(iter.next(), Some(Ok(m)) if m == messages[0]));
assert!(matches!(iter.next(), Some(Ok(m)) if m == messages[1]));
let err = iter.next();
assert!(matches!(err, Some(Err(_))));
assert!(iter.is_done());
assert!(iter.last_error().is_some());
assert!(iter.next().is_none());
Ok(())
}
}