1use crate::errors::TransportSessionError;
4use crate::types::SessionTarget;
5use crate::Capability;
6use hopr_crypto_types::prelude::PeerId;
7use hopr_internal_types::prelude::ApplicationData;
8use hopr_network_types::prelude::RoutingOptions;
9use std::collections::HashSet;
10
/// Challenge value of the session start protocol.
///
/// It appears as [`StartInitiation::challenge`], [`StartEstablished::orig_challenge`] and
/// [`StartErrorType::challenge`]; presumably replies echo the initiator's challenge so they can
/// be correlated with the originating request — NOTE(review): confirm against the session manager.
pub type StartChallenge = u64;
13
/// Reason carried in a [`StartErrorType`] reply when a session could not be started.
///
/// NOTE: bincode's serde support encodes unit variants by their index, so the variant
/// order below is part of the wire format and must not be changed.
#[repr(u8)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, strum::Display)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum StartErrorReason {
    /// No session slot is available.
    NoSlotsAvailable,
    /// The responder is busy (NOTE(review): exact condition is decided elsewhere).
    Busy,
}
24
/// Negative reply to a session start request.
///
/// NOTE: fields are serialized positionally by bincode; do not reorder them.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct StartErrorType {
    /// Challenge of the request this error responds to (presumably copied from
    /// [`StartInitiation::challenge`] — confirm with the caller).
    pub challenge: StartChallenge,
    /// Why the session could not be started.
    pub reason: StartErrorReason,
}
33}
34
/// Request to start a new session.
///
/// NOTE: fields are serialized positionally by bincode; do not reorder them.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct StartInitiation {
    /// Challenge chosen by the initiator for this request.
    pub challenge: StartChallenge,
    /// Target the session should be connected to.
    pub target: SessionTarget,
    /// Capabilities requested for the session.
    pub capabilities: HashSet<Capability>,
    /// Optional routing information paired with a peer id.
    ///
    /// NOTE(review): presumably the routing the responder should use for the return path
    /// and the peer to route back to — confirm with the session manager.
    pub back_routing: Option<(RoutingOptions, PeerId)>,
}
51
/// Positive reply to a session start request, generic over the session identifier type `T`.
///
/// NOTE: fields are serialized positionally by bincode; do not reorder them.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct StartEstablished<T> {
    /// Challenge of the originating [`StartInitiation`].
    pub orig_challenge: StartChallenge,
    /// Identifier of the newly established session.
    pub session_id: T,
}
61
#[cfg_attr(doc, aquamarine::aquamarine)]
#[derive(Debug, Clone, PartialEq, Eq, strum::EnumDiscriminants)]
/// Messages of the session start protocol, generic over the session identifier type `T`.
///
/// `StartProtocol::encode` maps each variant to the application tag `discriminant + 1`
/// (`StartSession` = 1, `SessionEstablished` = 2, `SessionError` = 3, `CloseSession` = 4),
/// so the variant order defines the wire tags and must not be changed.
// The generated `StartProtocolDiscriminants` enum is crate-visible and `repr(u8)`, so that
// `decode` can map a tag back to a variant via `FromRepr`.
#[strum_discriminants(vis(pub(crate)))]
#[strum_discriminants(derive(strum::FromRepr, strum::EnumCount), repr(u8))]
pub enum StartProtocol<T> {
    /// Request to start a new session.
    StartSession(StartInitiation),
    /// Reply signalling that the session was established.
    SessionEstablished(StartEstablished<T>),
    /// Reply signalling that the session could not be started.
    SessionError(StartErrorType),
    /// Request to close the session with the given identifier.
    CloseSession(T),
}
98
/// bincode configuration used for every start protocol payload.
///
/// Little-endian with variable-length integer encoding. Both settings are spelled out
/// explicitly (instead of relying on `standard()` defaults) because they define the wire format.
const SESSION_BINCODE_CONFIGURATION: bincode::config::Configuration = bincode::config::standard()
    .with_little_endian()
    .with_variable_int_encoding();
102
#[cfg(feature = "serde")]
impl<T: serde::Serialize + for<'de> serde::Deserialize<'de>> StartProtocol<T> {
    /// Serializes this message into an application tag and a binary payload.
    ///
    /// The tag is the variant's discriminant plus one (`StartSession` = 1,
    /// `SessionEstablished` = 2, `SessionError` = 3, `CloseSession` = 4), so tag `0` is never
    /// produced. The payload is the variant's inner value encoded with
    /// [`SESSION_BINCODE_CONFIGURATION`].
    ///
    /// # Errors
    /// Returns an error if bincode fails to serialize the inner value.
    pub fn encode(self) -> crate::errors::Result<(u16, Box<[u8]>)> {
        // Offset by one so that the tag is never 0 (which `decode` rejects).
        let tag = u16::from(StartProtocolDiscriminants::from(&self) as u8) + 1;
        let payload = match self {
            StartProtocol::StartSession(init) => bincode::serde::encode_to_vec(&init, SESSION_BINCODE_CONFIGURATION),
            StartProtocol::SessionEstablished(est) => {
                bincode::serde::encode_to_vec(&est, SESSION_BINCODE_CONFIGURATION)
            }
            StartProtocol::SessionError(err) => bincode::serde::encode_to_vec(&err, SESSION_BINCODE_CONFIGURATION),
            StartProtocol::CloseSession(id) => bincode::serde::encode_to_vec(&id, SESSION_BINCODE_CONFIGURATION),
        }?;

        Ok((tag, payload.into_boxed_slice()))
    }

    /// Deserializes a message previously produced by [`Self::encode`].
    ///
    /// `tag` selects the variant (see [`Self::encode`] for the mapping), and `data` holds the
    /// bincode-encoded inner value. Bytes following the encoded value are ignored.
    ///
    /// # Errors
    /// - [`TransportSessionError::Tag`] if `tag` is `0` or does not correspond to any variant.
    /// - A decoding error if `data` is not a valid encoding of the selected variant's payload.
    pub fn decode(tag: u16, data: &[u8]) -> crate::errors::Result<Self> {
        // Undo the +1 offset from `encode`. Checked arithmetic is required because `tag`
        // comes from the network: a plain `tag as u8 - 1` would truncate (e.g. 257 would alias
        // to `StartSession`) and underflow for 256, panicking in debug builds.
        let disc = tag
            .checked_sub(1)
            .and_then(|raw| u8::try_from(raw).ok())
            .and_then(StartProtocolDiscriminants::from_repr)
            .ok_or(TransportSessionError::Tag)?;

        Ok(match disc {
            StartProtocolDiscriminants::StartSession => StartProtocol::StartSession(Self::decode_payload(data)?),
            StartProtocolDiscriminants::SessionEstablished => {
                StartProtocol::SessionEstablished(Self::decode_payload(data)?)
            }
            StartProtocolDiscriminants::SessionError => StartProtocol::SessionError(Self::decode_payload(data)?),
            StartProtocolDiscriminants::CloseSession => StartProtocol::CloseSession(Self::decode_payload(data)?),
        })
    }

    /// Decodes a single payload value from `data`, discarding the number of consumed bytes.
    fn decode_payload<V: for<'de> serde::Deserialize<'de>>(data: &[u8]) -> crate::errors::Result<V> {
        let (value, _consumed) = bincode::serde::borrow_decode_from_slice(data, SESSION_BINCODE_CONFIGURATION)?;
        Ok(value)
    }
}
144
/// Converts a [`StartProtocol`] message into [`ApplicationData`].
///
/// The message tag from [`StartProtocol::encode`] goes into `application_tag`, and the
/// serialized payload goes into `plain_text`. Gated on the `serde` feature because `encode`
/// only exists when that feature is enabled.
#[cfg(feature = "serde")]
impl<T: serde::Serialize + for<'de> serde::Deserialize<'de>> TryFrom<StartProtocol<T>> for ApplicationData {
    type Error = TransportSessionError;

    fn try_from(value: StartProtocol<T>) -> Result<Self, Self::Error> {
        let (tag, plain_text) = value.encode()?;
        Ok(ApplicationData {
            application_tag: Some(tag),
            plain_text,
        })
    }
}
156
/// Parses a [`StartProtocol`] message out of [`ApplicationData`].
///
/// Fails with [`TransportSessionError::Tag`] when `application_tag` is absent. Otherwise it
/// returns whatever [`StartProtocol::decode`] returns for the tag and `plain_text`. Gated on
/// the `serde` feature because `decode` only exists when that feature is enabled.
#[cfg(feature = "serde")]
impl<T: serde::Serialize + for<'de> serde::Deserialize<'de>> TryFrom<ApplicationData> for StartProtocol<T> {
    type Error = TransportSessionError;

    fn try_from(value: ApplicationData) -> Result<Self, Self::Error> {
        // Without a tag there is no way to tell which message variant the payload encodes.
        let tag = value.application_tag.ok_or(TransportSessionError::Tag)?;
        Self::decode(tag, &value.plain_text)
    }
}
167
#[cfg(test)]
mod tests {
    use super::*;
    use crate::SessionId;
    use hopr_internal_types::prelude::PAYLOAD_SIZE;
    use hopr_network_types::prelude::SealedHost;

    /// Encodes `original` and checks that it received `expected_tag`. Then decodes the payload
    /// and checks that the result equals `original`.
    #[cfg(feature = "serde")]
    fn assert_roundtrip(original: StartProtocol<i32>, expected_tag: u16) -> anyhow::Result<()> {
        let (tag, payload) = original.clone().encode()?;
        assert_eq!(expected_tag, tag);

        let decoded = StartProtocol::<i32>::decode(tag, &payload)?;
        assert_eq!(original, decoded);
        Ok(())
    }

    /// Returns the number of payload bytes produced by encoding `msg`.
    #[cfg(feature = "serde")]
    fn encoded_len<T>(msg: StartProtocol<T>) -> anyhow::Result<usize>
    where
        T: serde::Serialize + for<'de> serde::Deserialize<'de>,
    {
        let (_tag, payload) = msg.encode()?;
        Ok(payload.len())
    }

    #[cfg(feature = "serde")]
    #[test]
    fn start_protocol_start_session_message_should_encode_and_decode() -> anyhow::Result<()> {
        let target = SessionTarget::TcpStream(SealedHost::Plain("127.0.0.1:1234".parse()?));
        let back_path = RoutingOptions::IntermediatePath(vec![PeerId::random()].try_into()?);

        let initiation = StartInitiation {
            challenge: 0,
            target,
            capabilities: Default::default(),
            back_routing: Some((back_path, PeerId::random())),
        };

        assert_roundtrip(StartProtocol::StartSession(initiation), 1)
    }

    #[cfg(feature = "serde")]
    #[test]
    fn start_protocol_session_established_message_should_encode_and_decode() -> anyhow::Result<()> {
        let established = StartEstablished {
            orig_challenge: 0,
            session_id: 10,
        };

        assert_roundtrip(StartProtocol::SessionEstablished(established), 2)
    }

    #[cfg(feature = "serde")]
    #[test]
    fn start_protocol_session_error_message_should_encode_and_decode() -> anyhow::Result<()> {
        let failure = StartErrorType {
            challenge: 10,
            reason: StartErrorReason::NoSlotsAvailable,
        };

        assert_roundtrip(StartProtocol::SessionError(failure), 3)
    }

    #[cfg(feature = "serde")]
    #[test]
    fn start_protocol_close_session_message_should_encode_and_decode() -> anyhow::Result<()> {
        assert_roundtrip(StartProtocol::CloseSession(10), 4)
    }

    #[cfg(feature = "serde")]
    #[test]
    fn start_protocol_messages_must_fit_within_hopr_packet() -> anyhow::Result<()> {
        // Worst case: maximal challenge, a very long host name, all capabilities and a 3-hop path.
        let long_host = "example-of-a-very-very-long-second-level-name.on-a-very-very-long-domain-name.info:65530";
        let three_hops = vec![PeerId::random(), PeerId::random(), PeerId::random()];

        let start_session = StartProtocol::<i32>::StartSession(StartInitiation {
            challenge: StartChallenge::MAX,
            target: SessionTarget::TcpStream(SealedHost::Plain(long_host.parse()?)),
            capabilities: HashSet::from_iter([Capability::Retransmission, Capability::Segmentation]),
            back_routing: Some((
                RoutingOptions::IntermediatePath(three_hops.try_into()?),
                PeerId::random(),
            )),
        });
        assert!(
            encoded_len(start_session)? <= PAYLOAD_SIZE,
            "StartSession must fit within {PAYLOAD_SIZE}"
        );

        let established = StartProtocol::SessionEstablished(StartEstablished {
            orig_challenge: StartChallenge::MAX,
            session_id: SessionId::new(u16::MAX, PeerId::random()),
        });
        assert!(
            encoded_len(established)? <= PAYLOAD_SIZE,
            "SessionEstablished must fit within {PAYLOAD_SIZE}"
        );

        let failure = StartProtocol::<i32>::SessionError(StartErrorType {
            challenge: StartChallenge::MAX,
            reason: StartErrorReason::NoSlotsAvailable,
        });
        assert!(
            encoded_len(failure)? <= PAYLOAD_SIZE,
            "SessionError must fit within {PAYLOAD_SIZE}"
        );

        let close = StartProtocol::CloseSession(SessionId::new(u16::MAX, PeerId::random()));
        assert!(
            encoded_len(close)? <= PAYLOAD_SIZE,
            "CloseSession must fit within {PAYLOAD_SIZE}"
        );

        Ok(())
    }
}