// hopr_transport_protocol/ack/codec.rs
use std::fmt;
use std::marker::PhantomData;

use serde::{Deserialize, Serialize};
use tokio_util::bytes::{Buf, BufMut, BytesMut};
use tokio_util::codec::{Decoder, Encoder};

8#[derive(Debug, thiserror::Error)]
16pub enum CborCodecError {
17 #[error("IO error: {0}")]
18 Io(#[from] std::io::Error),
19
20 #[error("CBOR error: {0}")]
21 Cbor(#[from] serde_cbor::Error),
22}
23
/// A stateless `tokio_util` codec that encodes and decodes values of type `T` as CBOR.
///
/// Encoded values are appended back to back with no length prefix. Decoding relies on
/// CBOR being self-delimiting to tell where one value ends and the next begins.
///
/// The standard traits below are implemented by hand instead of derived. `#[derive]` would
/// add `T: Debug`/`T: Clone`/`T: Default`/`T: PartialEq` bounds, even though the codec never
/// stores a `T` (it only holds a `PhantomData<T>`). For example, a derived `Default` would make
/// `CborCodec<T>::default()` unavailable for payload types without a `Default` impl. The
/// serde bounds live on the impls that actually need them rather than on the struct.
pub struct CborCodec<T> {
    _phantom: PhantomData<T>,
}

impl<T> fmt::Debug for CborCodec<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Same shape as the previously derived output, but without requiring `T: Debug`.
        f.debug_struct("CborCodec")
            .field("_phantom", &self._phantom)
            .finish()
    }
}

impl<T> Clone for CborCodec<T> {
    fn clone(&self) -> Self {
        Self { _phantom: PhantomData }
    }
}

impl<T> Default for CborCodec<T> {
    fn default() -> Self {
        Self { _phantom: PhantomData }
    }
}

impl<T> PartialEq for CborCodec<T> {
    /// All codecs for the same `T` are interchangeable (they carry no state), so they are
    /// always equal.
    fn eq(&self, _other: &Self) -> bool {
        true
    }
}

impl<T> Eq for CborCodec<T> {}

29impl<T: Serialize + for<'de> Deserialize<'de>> CborCodec<T> {
30 pub fn new() -> Self {
32 Self { _phantom: PhantomData }
33 }
34}
35
36impl<T: Serialize + for<'de> Deserialize<'de>> Decoder for CborCodec<T> {
38 type Item = T;
39 type Error = CborCodecError;
40
41 fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
42 let mut de = serde_cbor::Deserializer::from_slice(buf);
43
44 let res: Result<T, _> = serde::de::Deserialize::deserialize(&mut de);
45
46 let item = match res {
47 Ok(item) => item,
48 Err(e) => {
49 if e.is_eof() {
50 return Ok(None);
51 } else {
52 return Err(e.into());
53 }
54 }
55 };
56
57 let offset = de.byte_offset();
58
59 buf.advance(offset);
60
61 Ok(Some(item))
62 }
63}
64
65impl<T: Serialize + for<'de> Deserialize<'de>> Encoder<T> for CborCodec<T> {
67 type Error = CborCodecError;
68
69 fn encode(&mut self, data: T, buf: &mut BytesMut) -> Result<(), Self::Error> {
70 let j = serde_cbor::to_vec(&data)?;
71
72 buf.reserve(j.len());
73 buf.put_slice(&j);
74
75 Ok(())
76 }
77}
78
#[cfg(test)]
mod test {
    use serde::{Deserialize, Serialize};

    use super::CborCodec;
    use super::{BytesMut, Decoder, Encoder};

    #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
    struct TestStruct {
        pub name: String,
        pub data: u16,
    }

    /// Builds the sample value used by the tests, with the given `data` payload.
    fn test_item(data: u16) -> TestStruct {
        TestStruct {
            name: "Test name".to_owned(),
            data,
        }
    }

    #[test]
    fn cbor_codec_encode_decode_is_reversible() {
        let mut codec = CborCodec::<TestStruct>::new();
        let mut buff = BytesMut::new();

        let item1 = test_item(16);
        codec.encode(item1.clone(), &mut buff).unwrap();

        let item2 = codec.decode(&mut buff).unwrap().unwrap();
        assert_eq!(item1, item2);

        // The buffer was fully consumed, so there is nothing left to decode.
        assert_eq!(codec.decode(&mut buff).unwrap(), None);

        assert_eq!(buff.len(), 0);
    }

    #[test]
    fn cbor_codec_decodes_consecutive_items() {
        let mut codec = CborCodec::<TestStruct>::new();
        let mut buff = BytesMut::new();

        let item1 = test_item(1);
        let item2 = test_item(2);
        codec.encode(item1.clone(), &mut buff).unwrap();
        codec.encode(item2.clone(), &mut buff).unwrap();

        // Each decode must consume exactly one item and leave the following one intact.
        assert_eq!(codec.decode(&mut buff).unwrap(), Some(item1));
        assert_eq!(codec.decode(&mut buff).unwrap(), Some(item2));
        assert_eq!(codec.decode(&mut buff).unwrap(), None);

        assert_eq!(buff.len(), 0);
    }

    #[test]
    fn cbor_codec_partial_decode() {
        let mut codec = CborCodec::<TestStruct>::new();
        let mut buff = BytesMut::new();

        codec.encode(test_item(34), &mut buff).unwrap();

        let mut start = buff.clone().split_to(4);
        assert_eq!(codec.decode(&mut start).unwrap(), None);

        codec.decode(&mut buff).unwrap().unwrap();

        assert_eq!(buff.len(), 0);
    }

    #[test]
    fn cbor_codec_eof_reached() {
        let mut codec = CborCodec::<TestStruct>::new();
        let mut buff = BytesMut::new();

        let item1 = test_item(34);
        codec.encode(item1.clone(), &mut buff).unwrap();

        let mut buff_start = buff.clone().split_to(4);
        let buff_end = buff.clone().split_off(4);

        // A truncated item yields `None` and must not consume any of the buffered bytes.
        assert_eq!(codec.decode(&mut buff_start).unwrap(), None);
        assert_eq!(buff_start.len(), 4);

        // Once the remaining bytes arrive, the *reassembled* buffer must decode to the
        // original item and be fully consumed.
        buff_start.extend_from_slice(&buff_end);

        let item2 = codec.decode(&mut buff_start).unwrap().unwrap();
        assert_eq!(item1, item2);
        assert_eq!(buff_start.len(), 0);
    }

    #[test]
    fn cbor_codec_decode_error() {
        let mut codec = CborCodec::<TestStruct>::new();
        let mut buff = BytesMut::new();

        codec.encode(test_item(34), &mut buff).unwrap();

        let mut buff_end = buff.clone().split_off(4);
        let buff_end_length = buff_end.len();

        // Starting in the middle of an item yields bytes that are not a valid `TestStruct`
        // encoding: this must be an error, and no bytes may be consumed.
        assert!(codec.decode(&mut buff_end).is_err());
        assert_eq!(buff_end.len(), buff_end_length);
    }
}