1use hopr_transport_packet::prelude::{ApplicationData, ReservedTag, Tag};
4
5use crate::{Capabilities, errors::TransportSessionError, types::SessionTarget};
6
/// Challenge value exchanged while starting a Session.
///
/// It appears as [`StartInitiation::challenge`], [`StartEstablished::orig_challenge`]
/// and [`StartErrorType::challenge`].
pub type StartChallenge = u64;
9
10#[repr(u8)]
12#[derive(Debug, Copy, Clone, PartialEq, Eq, strum::Display)]
13#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
14pub enum StartErrorReason {
15 NoSlotsAvailable,
17 Busy,
19}
20
21#[derive(Debug, Copy, Clone, PartialEq, Eq)]
23#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
24pub struct StartErrorType {
25 pub challenge: StartChallenge,
27 pub reason: StartErrorReason,
29}
30
31#[derive(Debug, Clone, PartialEq, Eq)]
33#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
34pub struct StartInitiation {
35 pub challenge: StartChallenge,
37 pub target: SessionTarget,
39 pub capabilities: Capabilities,
41}
42
43#[derive(Debug, Clone, PartialEq, Eq)]
45#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
46pub struct StartEstablished<T> {
47 pub orig_challenge: StartChallenge,
49 pub session_id: T,
51}
52
#[cfg_attr(doc, aquamarine::aquamarine)]
#[derive(Debug, Clone, PartialEq, Eq, strum::EnumDiscriminants)]
/// Messages of the Session start protocol, generic over the Session identifier type `T`.
///
/// With the `serde` feature, `encode`/`decode` lay each message out as
/// `[version, discriminant, bincode payload]`. The discriminant byte is the value of the
/// generated `StartProtocolDiscriminants` enum. Its values follow the declaration order of
/// the variants below, so reordering or inserting variants changes the wire format.
#[strum_discriminants(vis(pub(crate)))]
#[strum_discriminants(derive(strum::FromRepr, strum::EnumCount), repr(u8))]
pub enum StartProtocol<T> {
    /// Request to start a Session; carries the challenge, target and capabilities.
    StartSession(StartInitiation),
    /// Session was established; carries the originating challenge and the Session ID.
    SessionEstablished(StartEstablished<T>),
    /// Session could not be started; carries the challenge and the reason.
    SessionError(StartErrorType),
    /// Close the Session with the given ID.
    CloseSession(T),
    /// Keep-alive for the Session identified by [`KeepAliveMessage::id`].
    KeepAlive(KeepAliveMessage<T>),
}
92
/// Content of a [`StartProtocol::KeepAlive`] message.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct KeepAliveMessage<T> {
    /// Identifier of the Session being kept alive.
    pub id: T,
    /// Additional flag bits. The `From<T>` conversion leaves them all cleared.
    ///
    /// NOTE(review): the meaning of individual bits is not defined in this file.
    pub flags: u8,
}

/// Builds a keep-alive for the given Session ID with no flags set.
impl<T> From<T> for KeepAliveMessage<T> {
    fn from(id: T) -> Self {
        KeepAliveMessage { id, flags: 0 }
    }
}
107
108impl<T> StartProtocol<T> {
109 pub(crate) const START_PROTOCOL_MESSAGE_TAG: Tag = Tag::Reserved(ReservedTag::SessionStart as u64);
110 const START_PROTOCOL_VERSION: u8 = 0x01;
111}
112
#[cfg(feature = "serde")]
impl<T: serde::Serialize + for<'de> serde::Deserialize<'de>> StartProtocol<T> {
    /// Bincode configuration for every message payload: little-endian byte order with
    /// variable-length integer encoding.
    const SESSION_BINCODE_CONFIGURATION: bincode::config::Configuration = bincode::config::standard()
        .with_little_endian()
        .with_variable_int_encoding();

    /// Encodes the message as a `(tag, bytes)` pair suitable for [`ApplicationData`].
    ///
    /// The byte layout is `[version, discriminant, payload...]`:
    /// - `version` is `START_PROTOCOL_VERSION`;
    /// - `discriminant` is the `StartProtocolDiscriminants` value of this variant;
    /// - `payload` is the bincode encoding of the variant's content.
    ///
    /// The returned tag is always `START_PROTOCOL_MESSAGE_TAG`.
    ///
    /// # Errors
    /// Fails if bincode cannot serialize the payload.
    pub fn encode(self) -> crate::errors::Result<(Tag, Box<[u8]>)> {
        let mut out = Vec::with_capacity(ApplicationData::PAYLOAD_SIZE);
        out.push(Self::START_PROTOCOL_VERSION);
        out.push(StartProtocolDiscriminants::from(&self) as u8);

        // Every variant's content is serialized by reference with the same configuration.
        match &self {
            StartProtocol::StartSession(init) => Self::encode_bincode_payload(init, &mut out),
            StartProtocol::SessionEstablished(est) => Self::encode_bincode_payload(est, &mut out),
            StartProtocol::SessionError(err) => Self::encode_bincode_payload(err, &mut out),
            StartProtocol::CloseSession(id) => Self::encode_bincode_payload(id, &mut out),
            StartProtocol::KeepAlive(msg) => Self::encode_bincode_payload(msg, &mut out),
        }?;

        Ok((Self::START_PROTOCOL_MESSAGE_TAG, out.into_boxed_slice()))
    }

    /// Decodes a message in the layout produced by [`StartProtocol::encode`].
    ///
    /// # Errors
    /// Returns [`TransportSessionError::StartProtocolError`] in any of these cases:
    /// - `tag` is not the start protocol tag;
    /// - `data` is shorter than 3 bytes;
    /// - the version byte does not match;
    /// - the discriminant byte is unknown.
    ///
    /// Bincode errors from decoding the payload are propagated through `?`. Bytes that
    /// follow the decoded payload are ignored.
    pub fn decode(tag: Tag, data: &[u8]) -> crate::errors::Result<Self> {
        if tag != Self::START_PROTOCOL_MESSAGE_TAG {
            return Err(TransportSessionError::StartProtocolError("unknown message tag".into()));
        }

        // The shortest accepted message is the version byte, the discriminant byte and
        // one payload byte.
        if data.len() < 3 {
            return Err(TransportSessionError::StartProtocolError("message too short".into()));
        }

        if data[0] != Self::START_PROTOCOL_VERSION {
            return Err(TransportSessionError::StartProtocolError(
                "unknown message version".into(),
            ));
        }

        // `ok_or_else` builds the error value only when the discriminant is actually unknown.
        let discriminant = StartProtocolDiscriminants::from_repr(data[1])
            .ok_or_else(|| TransportSessionError::StartProtocolError("unknown message".into()))?;
        let payload = &data[2..];

        Ok(match discriminant {
            StartProtocolDiscriminants::StartSession => {
                StartProtocol::StartSession(Self::decode_bincode_payload(payload)?)
            }
            StartProtocolDiscriminants::SessionEstablished => {
                StartProtocol::SessionEstablished(Self::decode_bincode_payload(payload)?)
            }
            StartProtocolDiscriminants::SessionError => {
                StartProtocol::SessionError(Self::decode_bincode_payload(payload)?)
            }
            StartProtocolDiscriminants::CloseSession => {
                StartProtocol::CloseSession(Self::decode_bincode_payload(payload)?)
            }
            StartProtocolDiscriminants::KeepAlive => {
                StartProtocol::KeepAlive(Self::decode_bincode_payload(payload)?)
            }
        })
    }

    /// Appends the bincode encoding of `payload` to `out`.
    fn encode_bincode_payload<P: serde::Serialize>(payload: &P, out: &mut Vec<u8>) -> crate::errors::Result<()> {
        bincode::serde::encode_into_std_write(payload, out, Self::SESSION_BINCODE_CONFIGURATION)?;
        Ok(())
    }

    /// Decodes one bincode-encoded value from the start of `payload`, ignoring trailing bytes.
    fn decode_bincode_payload<P>(payload: &[u8]) -> crate::errors::Result<P>
    where
        P: for<'de> serde::Deserialize<'de>,
    {
        let (value, _bytes_read) =
            bincode::serde::borrow_decode_from_slice(payload, Self::SESSION_BINCODE_CONFIGURATION)?;
        Ok(value)
    }
}
191
192#[cfg(not(feature = "serde"))]
193impl<T> StartProtocol<T> {
194 pub fn encode(self) -> crate::errors::Result<(u16, Box<[u8]>)> {
195 unimplemented!()
196 }
197
198 pub fn decode(_tag: u16, _data: &[u8]) -> crate::errors::Result<Self> {
199 unimplemented!()
200 }
201}
202
#[cfg(feature = "serde")]
impl<T: serde::Serialize + for<'de> serde::Deserialize<'de>> TryFrom<StartProtocol<T>> for ApplicationData {
    type Error = TransportSessionError;

    /// Encodes the start protocol message and wraps the tag and bytes as application data.
    fn try_from(value: StartProtocol<T>) -> Result<Self, Self::Error> {
        value
            .encode()
            .map(|(application_tag, plain_text)| ApplicationData { application_tag, plain_text })
    }
}
215
216#[cfg(not(feature = "serde"))]
217impl<T> TryFrom<StartProtocol<T>> for ApplicationData {
218 type Error = TransportSessionError;
219
220 fn try_from(value: StartProtocol<T>) -> Result<Self, Self::Error> {
221 let (application_tag, plain_text) = value.encode()?;
222 Ok(ApplicationData {
223 application_tag,
224 plain_text,
225 })
226 }
227}
228
#[cfg(feature = "serde")]
impl<T: serde::Serialize + for<'de> serde::Deserialize<'de>> TryFrom<ApplicationData> for StartProtocol<T> {
    type Error = TransportSessionError;

    /// Decodes application data as a start protocol message.
    fn try_from(value: ApplicationData) -> Result<Self, Self::Error> {
        let ApplicationData {
            application_tag,
            plain_text,
        } = value;
        Self::decode(application_tag, &plain_text)
    }
}
237
238#[cfg(not(feature = "serde"))]
239impl<T> TryFrom<ApplicationData> for StartProtocol<T> {
240 type Error = TransportSessionError;
241
242 fn try_from(value: ApplicationData) -> Result<Self, Self::Error> {
243 Self::decode(value.application_tag, &value.plain_text)
244 }
245}
246
// Every test here calls `StartProtocol::encode`/`decode`, which only work with the
// `serde` feature, so the whole module is gated on it.
#[cfg(all(test, feature = "serde"))]
mod tests {
    use hopr_crypto_packet::prelude::HoprPacket;
    use hopr_crypto_random::Randomizable;
    use hopr_internal_types::prelude::HoprPseudonym;
    use hopr_network_types::prelude::SealedHost;
    use hopr_transport_packet::prelude::Tag;

    use super::*;
    use crate::{Capability, SessionId};

    #[test]
    fn start_protocol_start_session_message_should_encode_and_decode() -> anyhow::Result<()> {
        let msg_1 = StartProtocol::<i32>::StartSession(StartInitiation {
            challenge: 0,
            target: SessionTarget::TcpStream(SealedHost::Plain("127.0.0.1:1234".parse()?)),
            capabilities: Default::default(),
        });

        let (tag, msg) = msg_1.clone().encode()?;
        let expected: Tag = StartProtocol::<()>::START_PROTOCOL_MESSAGE_TAG;
        assert_eq!(tag, expected);

        let msg_2 = StartProtocol::<i32>::decode(tag, &msg)?;

        assert_eq!(msg_1, msg_2);
        Ok(())
    }

    #[test]
    fn start_protocol_message_start_session_message_should_allow_for_at_least_one_surb() -> anyhow::Result<()> {
        let msg = StartProtocol::<SessionId>::StartSession(StartInitiation {
            challenge: 0,
            target: SessionTarget::TcpStream(SealedHost::Plain("127.0.0.1:1234".parse()?)),
            capabilities: Default::default(),
        });

        let len = msg.encode()?.1.len();
        assert!(
            HoprPacket::max_surbs_with_message(len) >= 1,
            "StartSession message size ({}) must allow for at least 1 SURB in packet",
            len,
        );

        Ok(())
    }

    #[test]
    fn start_protocol_session_established_message_should_encode_and_decode() -> anyhow::Result<()> {
        let msg_1 = StartProtocol::<i32>::SessionEstablished(StartEstablished {
            orig_challenge: 0,
            session_id: 10,
        });

        let (tag, msg) = msg_1.clone().encode()?;
        let expected: Tag = StartProtocol::<()>::START_PROTOCOL_MESSAGE_TAG;
        assert_eq!(tag, expected);

        let msg_2 = StartProtocol::<i32>::decode(tag, &msg)?;

        assert_eq!(msg_1, msg_2);
        Ok(())
    }

    #[test]
    fn start_protocol_session_error_message_should_encode_and_decode() -> anyhow::Result<()> {
        let msg_1 = StartProtocol::<i32>::SessionError(StartErrorType {
            challenge: 10,
            reason: StartErrorReason::NoSlotsAvailable,
        });

        let (tag, msg) = msg_1.clone().encode()?;
        let expected: Tag = StartProtocol::<()>::START_PROTOCOL_MESSAGE_TAG;
        assert_eq!(tag, expected);

        let msg_2 = StartProtocol::<i32>::decode(tag, &msg)?;

        assert_eq!(msg_1, msg_2);
        Ok(())
    }

    #[test]
    fn start_protocol_close_session_message_should_encode_and_decode() -> anyhow::Result<()> {
        let msg_1 = StartProtocol::<i32>::CloseSession(10);

        let (tag, msg) = msg_1.clone().encode()?;
        let expected: Tag = StartProtocol::<()>::START_PROTOCOL_MESSAGE_TAG;
        assert_eq!(tag, expected);

        let msg_2 = StartProtocol::<i32>::decode(tag, &msg)?;

        assert_eq!(msg_1, msg_2);
        Ok(())
    }

    #[test]
    fn start_protocol_keep_alive_message_should_encode_and_decode() -> anyhow::Result<()> {
        let msg_1 = StartProtocol::<i32>::KeepAlive(10.into());

        let (tag, msg) = msg_1.clone().encode()?;
        let expected: Tag = StartProtocol::<()>::START_PROTOCOL_MESSAGE_TAG;
        assert_eq!(tag, expected);

        let msg_2 = StartProtocol::<i32>::decode(tag, &msg)?;

        assert_eq!(msg_1, msg_2);
        Ok(())
    }

    #[test]
    fn start_protocol_messages_must_fit_within_hopr_packet() -> anyhow::Result<()> {
        let msg = StartProtocol::<SessionId>::StartSession(StartInitiation {
            challenge: StartChallenge::MAX,
            target: SessionTarget::TcpStream(SealedHost::Plain(
                "example-of-a-very-very-long-second-level-name.on-a-very-very-long-domain-name.info:65530".parse()?,
            )),
            capabilities: Capability::RetransmissionAck | Capability::RetransmissionNack | Capability::Segmentation,
        });

        assert!(
            msg.encode()?.1.len() <= HoprPacket::PAYLOAD_SIZE,
            "StartSession must fit within {}",
            HoprPacket::PAYLOAD_SIZE
        );

        let msg = StartProtocol::SessionEstablished(StartEstablished {
            orig_challenge: StartChallenge::MAX,
            session_id: SessionId::new(Tag::MAX, HoprPseudonym::random()),
        });

        assert!(
            msg.encode()?.1.len() <= HoprPacket::PAYLOAD_SIZE,
            "SessionEstablished must fit within {}",
            HoprPacket::PAYLOAD_SIZE
        );

        let msg = StartProtocol::<i32>::SessionError(StartErrorType {
            challenge: StartChallenge::MAX,
            reason: StartErrorReason::NoSlotsAvailable,
        });

        assert!(
            msg.encode()?.1.len() <= HoprPacket::PAYLOAD_SIZE,
            "SessionError must fit within {}",
            HoprPacket::PAYLOAD_SIZE
        );

        let msg = StartProtocol::CloseSession(SessionId::new(Tag::MAX, HoprPseudonym::random()));
        assert!(
            msg.encode()?.1.len() <= HoprPacket::PAYLOAD_SIZE,
            "CloseSession must fit within {}",
            HoprPacket::PAYLOAD_SIZE
        );

        let msg = StartProtocol::KeepAlive(SessionId::new(Tag::MAX, HoprPseudonym::random()).into());
        assert!(
            msg.encode()?.1.len() <= HoprPacket::PAYLOAD_SIZE,
            "KeepAlive must fit within {}",
            HoprPacket::PAYLOAD_SIZE
        );

        Ok(())
    }

    #[test]
    fn start_protocol_message_keep_alive_message_should_allow_for_maximum_surbs() -> anyhow::Result<()> {
        let msg = StartProtocol::KeepAlive(SessionId::new(Tag::MAX, HoprPseudonym::random()).into());
        let len = msg.encode()?.1.len();
        assert!(
            HoprPacket::max_surbs_with_message(len) >= HoprPacket::MAX_SURBS_IN_PACKET,
            "KeepAlive message size ({}) must allow for at least {} SURBs in packet",
            len,
            HoprPacket::MAX_SURBS_IN_PACKET
        );

        Ok(())
    }
}