hopr_transport_protocol/ack/
codec.rs

1use std::marker::PhantomData;
2
3use tokio_util::bytes::{Buf, BufMut, BytesMut};
4use tokio_util::codec::{Decoder, Encoder};
5
6use serde::{Deserialize, Serialize};
7
8/// A codec for CBOR encoding and decoding using serde_cbor
9///
10/// Inspired by the [asynchronous_codec](`https://docs.rs/asynchronous-codec/latest/asynchronous_codec/`) crate
11/// but better fitting this codebase.
12///
13/// TODO: Replace with cbor4ii
14
15#[derive(Debug, thiserror::Error)]
16pub enum CborCodecError {
17    #[error("IO error: {0}")]
18    Io(#[from] std::io::Error),
19
20    #[error("CBOR error: {0}")]
21    Cbor(#[from] serde_cbor::Error),
22}
23
24#[derive(Debug, Clone, Default, PartialEq)]
25pub struct CborCodec<T: Serialize + for<'de> Deserialize<'de>> {
26    _phantom: PhantomData<T>,
27}
28
29impl<T: Serialize + for<'de> Deserialize<'de>> CborCodec<T> {
30    /// Creates a new `CborCodec` with the associated types
31    pub fn new() -> Self {
32        Self { _phantom: PhantomData }
33    }
34}
35
36/// Decode the type from bytes
37impl<T: Serialize + for<'de> Deserialize<'de>> Decoder for CborCodec<T> {
38    type Item = T;
39    type Error = CborCodecError;
40
41    fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
42        let mut de = serde_cbor::Deserializer::from_slice(buf);
43
44        let res: Result<T, _> = serde::de::Deserialize::deserialize(&mut de);
45
46        let item = match res {
47            Ok(item) => item,
48            Err(e) => {
49                if e.is_eof() {
50                    return Ok(None);
51                } else {
52                    return Err(e.into());
53                }
54            }
55        };
56
57        let offset = de.byte_offset();
58
59        buf.advance(offset);
60
61        Ok(Some(item))
62    }
63}
64
65/// Encoder impl encodes object streams to bytes
66impl<T: Serialize + for<'de> Deserialize<'de>> Encoder<T> for CborCodec<T> {
67    type Error = CborCodecError;
68
69    fn encode(&mut self, data: T, buf: &mut BytesMut) -> Result<(), Self::Error> {
70        let j = serde_cbor::to_vec(&data)?;
71
72        buf.reserve(j.len());
73        buf.put_slice(&j);
74
75        Ok(())
76    }
77}
78
#[cfg(test)]
mod test {
    use serde::{Deserialize, Serialize};

    use super::CborCodec;
    use super::{BytesMut, Decoder, Encoder};

    /// Simple serde-derived payload used to exercise the codec.
    #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
    struct TestStruct {
        pub name: String,
        pub data: u16,
    }

    #[test]
    fn cbor_codec_encode_decode_is_reversible() {
        let mut codec = CborCodec::<TestStruct>::new();
        let mut buff = BytesMut::new();

        let item1 = TestStruct {
            name: "Test name".to_owned(),
            data: 16,
        };
        codec.encode(item1.clone(), &mut buff).unwrap();

        let item2 = codec.decode(&mut buff).unwrap().unwrap();
        assert_eq!(item1, item2);

        // Nothing left to decode: an empty buffer is reported as "need more data".
        assert_eq!(codec.decode(&mut buff).unwrap(), None);

        assert_eq!(buff.len(), 0);
    }

    #[test]
    fn cbor_codec_partial_decode() {
        let mut codec = CborCodec::<TestStruct>::new();
        let mut buff = BytesMut::new();

        let item1 = TestStruct {
            name: "Test name".to_owned(),
            data: 34,
        };
        codec.encode(item1, &mut buff).unwrap();

        let mut start = buff.clone().split_to(4);
        assert_eq!(codec.decode(&mut start).unwrap(), None);
        // An incomplete item must not consume any bytes.
        assert_eq!(start.len(), 4);

        codec.decode(&mut buff).unwrap().unwrap();

        assert_eq!(buff.len(), 0);
    }

    #[test]
    fn cbor_codec_eof_reached() {
        let mut codec = CborCodec::<TestStruct>::new();
        let mut buff = BytesMut::new();

        let item1 = TestStruct {
            name: "Test name".to_owned(),
            data: 34,
        };
        codec.encode(item1.clone(), &mut buff).unwrap();

        // Split the encoded item into a 4-byte head and the remaining tail.
        let buff_end = buff.split_off(4);
        let mut buff_start = buff;

        // Attempt to decode only the head. This should return `Ok(None)` and not
        // advance the buffer.
        assert_eq!(codec.decode(&mut buff_start).unwrap(), None);
        assert_eq!(buff_start.len(), 4);

        // Combine the buffer back together.
        buff_start.extend_from_slice(&buff_end);

        // The reassembled buffer should now decode successfully and be fully consumed.
        let item2 = codec.decode(&mut buff_start).unwrap().unwrap();
        assert_eq!(item1, item2);
        assert!(buff_start.is_empty());
    }

    #[test]
    fn cbor_codec_decode_error() {
        let mut codec = CborCodec::<TestStruct>::new();
        let mut buff = BytesMut::new();

        let item1 = TestStruct {
            name: "Test name".to_owned(),
            data: 34,
        };
        codec.encode(item1, &mut buff).unwrap();

        // Split the end off the buffer, so decoding starts in the middle of the item.
        let mut buff_end = buff.clone().split_off(4);
        let buff_end_length = buff_end.len();

        // Attempting to decode should return an error and leave the buffer untouched.
        assert!(codec.decode(&mut buff_end).is_err());
        assert_eq!(buff_end.len(), buff_end_length);
    }
}