1pub mod errors;
19
20use hopr_crypto_packet::prelude::HoprPacket;
21use hopr_protocol_app::prelude::{ApplicationData, ReservedTag, Tag};
22
23use crate::errors::StartProtocolError;
24
/// Challenge value carried by a session start request.
///
/// The replies to such a request carry a challenge as well
/// (`StartEstablished::orig_challenge`, `StartErrorType::challenge`).
/// On the wire it is written as 8 big-endian bytes.
pub type StartChallenge = u64;
27
28#[repr(u8)]
30#[derive(Debug, Copy, Clone, PartialEq, Eq, strum::Display, strum::FromRepr)]
31pub enum StartErrorReason {
32 Unknown = 0,
34 NoSlotsAvailable = 1,
36 Busy = 2,
38}
39
40#[derive(Debug, Copy, Clone, PartialEq, Eq)]
42pub struct StartErrorType {
43 pub challenge: StartChallenge,
45 pub reason: StartErrorReason,
47}
48
49#[derive(Debug, Clone, PartialEq, Eq)]
58pub struct StartInitiation<T, C> {
59 pub challenge: StartChallenge,
61 pub target: T,
63 pub capabilities: C,
65 pub additional_data: u32,
67}
68
69#[derive(Debug, Clone, PartialEq, Eq)]
74pub struct StartEstablished<I> {
75 pub orig_challenge: StartChallenge,
77 pub session_id: I,
79}
80
81#[cfg_attr(doc, aquamarine::aquamarine)]
82#[derive(Debug, Clone, PartialEq, Eq, strum::EnumDiscriminants)]
106#[strum_discriminants(vis(pub))]
107#[strum_discriminants(derive(strum::FromRepr, strum::EnumCount), repr(u8))]
108pub enum StartProtocol<I, T, C> {
109 StartSession(StartInitiation<T, C>),
111 SessionEstablished(StartEstablished<I>),
113 SessionError(StartErrorType),
115 KeepAlive(KeepAliveMessage<I>),
117}
118
119#[derive(Debug, Clone, PartialEq, Eq)]
121pub struct KeepAliveMessage<I> {
122 pub session_id: I,
124 pub flags: KeepAliveFlags,
126 pub additional_data: u64,
128}
129
130pub type KeepAliveFlags = flagset::FlagSet<KeepAliveFlag>;
134
135flagset::flags! {
136 pub enum KeepAliveFlag: u8 {
138 BalancerTarget = 0x01,
145 BalancerState = 0x02,
152 }
153}
154
155impl<I> KeepAliveMessage<I> {
156 pub const MIN_SURBS_PER_MESSAGE: usize = HoprPacket::MAX_SURBS_IN_PACKET;
158}
159
160impl<I> From<I> for KeepAliveMessage<I> {
161 fn from(value: I) -> Self {
162 Self {
163 session_id: value,
164 flags: None.into(),
165 additional_data: 0,
166 }
167 }
168}
169
170impl<I, T, C> StartProtocol<I, T, C> {
171 pub const START_PROTOCOL_MESSAGE_TAG: Tag = Tag::Reserved(ReservedTag::SessionStart as u64);
173 pub const START_PROTOCOL_VERSION: u8 = 0x02;
175}
176
177impl<I, T, C> StartProtocol<I, T, C>
178where
179 I: serde::Serialize + for<'de> serde::Deserialize<'de>,
180 T: serde::Serialize + for<'de> serde::Deserialize<'de>,
181 C: Into<u8> + TryFrom<u8>,
182{
183 pub fn encode(self) -> errors::Result<(Tag, Box<[u8]>)> {
185 let mut out = Vec::with_capacity(ApplicationData::PAYLOAD_SIZE);
186 out.push(Self::START_PROTOCOL_VERSION);
187 out.push(StartProtocolDiscriminants::from(&self) as u8);
188
189 let mut data = Vec::with_capacity(ApplicationData::PAYLOAD_SIZE - 2);
190 match self {
191 StartProtocol::StartSession(init) => {
192 data.extend_from_slice(&init.challenge.to_be_bytes());
193 data.push(init.capabilities.into());
194 data.extend_from_slice(&init.additional_data.to_be_bytes());
195 let target = serde_cbor_2::to_vec(&init.target)?;
196 data.extend_from_slice(&target);
197 }
198 StartProtocol::SessionEstablished(est) => {
199 data.extend_from_slice(&est.orig_challenge.to_be_bytes());
200 let session_id = serde_cbor_2::to_vec(&est.session_id)?;
201 data.extend(session_id);
202 }
203 StartProtocol::SessionError(err) => {
204 data.extend_from_slice(&err.challenge.to_be_bytes());
205 data.push(err.reason as u8);
206 }
207 StartProtocol::KeepAlive(ping) => {
208 data.push(ping.flags.bits());
209 data.extend_from_slice(&ping.additional_data.to_be_bytes());
210 let session_id = serde_cbor_2::to_vec(&ping.session_id)?;
211 data.extend(session_id);
212 }
213 }
214
215 out.extend_from_slice(&(data.len() as u16).to_be_bytes());
216 out.extend(data);
217
218 Ok((Self::START_PROTOCOL_MESSAGE_TAG, out.into_boxed_slice()))
219 }
220
221 pub fn decode(tag: Tag, data: &[u8]) -> errors::Result<Self> {
226 if tag != Self::START_PROTOCOL_MESSAGE_TAG {
227 return Err(StartProtocolError::UnknownTag);
228 }
229
230 if data.len() < 5 {
231 return Err(StartProtocolError::InvalidLength);
232 }
233
234 if data[0] != Self::START_PROTOCOL_VERSION {
235 return Err(StartProtocolError::InvalidVersion);
236 }
237
238 let disc = data[1];
239 let len = u16::from_be_bytes(
240 data[2..4]
241 .try_into()
242 .map_err(|_| StartProtocolError::ParseError("len".into()))?,
243 ) as usize;
244 let data_offset = 2 + size_of::<u16>();
245
246 if data.len() < data_offset + len {
247 return Err(StartProtocolError::InvalidLength);
248 }
249
250 Ok(
251 match StartProtocolDiscriminants::from_repr(disc).ok_or(StartProtocolError::UnknownMessage)? {
252 StartProtocolDiscriminants::StartSession => {
253 if data.len() <= data_offset + size_of::<StartChallenge>() + 1 + size_of::<u32>() {
254 return Err(StartProtocolError::InvalidLength);
255 }
256
257 StartProtocol::StartSession(StartInitiation {
258 challenge: StartChallenge::from_be_bytes(
259 data[data_offset..data_offset + size_of::<StartChallenge>()]
260 .try_into()
261 .map_err(|_| StartProtocolError::ParseError("init.challenge".into()))?,
262 ),
263 capabilities: data[data_offset + size_of::<StartChallenge>()]
264 .try_into()
265 .map_err(|_| StartProtocolError::ParseError("init.capabilities".into()))?,
266 additional_data: u32::from_be_bytes(
267 data[data_offset + size_of::<StartChallenge>() + 1
268 ..data_offset + size_of::<StartChallenge>() + 1 + size_of::<u32>()]
269 .try_into()
270 .map_err(|_| StartProtocolError::ParseError("init.additional_data".into()))?,
271 ),
272 target: serde_cbor_2::from_slice(
273 &data[data_offset + size_of::<StartChallenge>() + 1 + size_of::<u32>()..],
274 )?,
275 })
276 }
277 StartProtocolDiscriminants::SessionEstablished => {
278 if data.len() <= data_offset + size_of::<StartChallenge>() {
279 return Err(StartProtocolError::InvalidLength);
280 }
281 StartProtocol::SessionEstablished(StartEstablished {
282 orig_challenge: StartChallenge::from_be_bytes(
283 data[data_offset..data_offset + size_of::<StartChallenge>()]
284 .try_into()
285 .map_err(|_| StartProtocolError::ParseError("est.challenge".into()))?,
286 ),
287 session_id: serde_cbor_2::from_slice(&data[data_offset + size_of::<StartChallenge>()..])?,
288 })
289 }
290 StartProtocolDiscriminants::SessionError => {
291 if data.len() < data_offset + size_of::<StartChallenge>() + 1 {
292 return Err(StartProtocolError::InvalidLength);
293 }
294 StartProtocol::SessionError(StartErrorType {
295 challenge: StartChallenge::from_be_bytes(
296 data[data_offset..data_offset + size_of::<StartChallenge>()]
297 .try_into()
298 .map_err(|_| StartProtocolError::ParseError("err.challenge".into()))?,
299 ),
300 reason: StartErrorReason::from_repr(data[data_offset + size_of::<StartChallenge>()])
301 .ok_or(StartProtocolError::ParseError("err.reason".into()))?,
302 })
303 }
304 StartProtocolDiscriminants::KeepAlive => {
305 if data.len() <= data_offset + size_of::<u32>() {
306 return Err(StartProtocolError::InvalidLength);
307 }
308
309 StartProtocol::KeepAlive(KeepAliveMessage {
310 flags: KeepAliveFlags::new(data[data_offset])
311 .map_err(|_| StartProtocolError::ParseError("ka.flags".into()))?,
312 additional_data: u64::from_be_bytes(
313 data[data_offset + 1..data_offset + 1 + size_of::<u64>()]
314 .try_into()
315 .map_err(|_| StartProtocolError::ParseError("ka.additional_data".into()))?,
316 ),
317 session_id: serde_cbor_2::from_slice(&data[data_offset + 1 + size_of::<u64>()..])?,
318 })
319 }
320 },
321 )
322 }
323}
324
325impl<I, T, C> TryFrom<StartProtocol<I, T, C>> for ApplicationData
326where
327 I: serde::Serialize + for<'de> serde::Deserialize<'de>,
328 T: serde::Serialize + for<'de> serde::Deserialize<'de>,
329 C: Into<u8> + TryFrom<u8>,
330{
331 type Error = StartProtocolError;
332
333 fn try_from(value: StartProtocol<I, T, C>) -> Result<Self, Self::Error> {
334 let (application_tag, plain_text) = value.encode()?;
335 Ok(ApplicationData::new(application_tag, plain_text.into_vec())?)
336 }
337}
338
339impl<I, T, C> TryFrom<ApplicationData> for StartProtocol<I, T, C>
340where
341 I: serde::Serialize + for<'de> serde::Deserialize<'de>,
342 T: serde::Serialize + for<'de> serde::Deserialize<'de>,
343 C: Into<u8> + TryFrom<u8>,
344{
345 type Error = StartProtocolError;
346
347 fn try_from(value: ApplicationData) -> Result<Self, Self::Error> {
348 Self::decode(value.application_tag, &value.plain_text)
349 }
350}
351
#[cfg(test)]
mod tests {
    use hopr_crypto_packet::prelude::HoprPacket;
    use hopr_protocol_app::prelude::Tag;

    use super::*;

    /// Concrete message type (session ID, target, capabilities) used by the round-trip tests.
    type TestMessage = StartProtocol<i32, String, u8>;

    /// Session ID long enough to stress the packet size limits.
    const LONG_SESSION_ID: &str = "example-of-a-very-very-long-session-id-that-should-still-fit-the-packet";

    /// Encodes `original`, checks the Start protocol tag is used, decodes the bytes again
    /// and asserts that the round trip is lossless.
    fn assert_round_trip(original: TestMessage) -> anyhow::Result<()> {
        let (tag, wire) = original.clone().encode()?;
        let expected: Tag = StartProtocol::<(), (), ()>::START_PROTOCOL_MESSAGE_TAG;
        assert_eq!(tag, expected);

        let decoded = TestMessage::decode(tag, &wire)?;
        assert_eq!(original, decoded);
        Ok(())
    }

    /// Size in bytes of the wire representation of `msg`.
    fn wire_len<I, T, C>(msg: StartProtocol<I, T, C>) -> anyhow::Result<usize>
    where
        I: serde::Serialize + for<'de> serde::Deserialize<'de>,
        T: serde::Serialize + for<'de> serde::Deserialize<'de>,
        C: Into<u8> + TryFrom<u8>,
    {
        let (_, wire) = msg.encode()?;
        Ok(wire.len())
    }

    #[test]
    fn start_protocol_start_session_message_should_encode_and_decode() -> anyhow::Result<()> {
        assert_round_trip(StartProtocol::StartSession(StartInitiation {
            challenge: 0,
            target: "127.0.0.1:1234".to_string(),
            capabilities: Default::default(),
            additional_data: 0x12345678,
        }))
    }

    #[test]
    fn start_protocol_message_start_session_message_should_allow_for_at_least_one_surb() -> anyhow::Result<()> {
        let len = wire_len(TestMessage::StartSession(StartInitiation {
            challenge: 0,
            target: "127.0.0.1:1234".to_string(),
            capabilities: 0xff,
            additional_data: 0xffffffff,
        }))?;

        assert!(
            HoprPacket::max_surbs_with_message(len) >= 1,
            "StartSession message size ({len}) must allow for at least 1 SURBs in packet",
        );
        Ok(())
    }

    #[test]
    fn start_protocol_session_established_message_should_encode_and_decode() -> anyhow::Result<()> {
        assert_round_trip(StartProtocol::SessionEstablished(StartEstablished {
            orig_challenge: 0,
            session_id: 10_i32,
        }))
    }

    #[test]
    fn start_protocol_session_error_message_should_encode_and_decode() -> anyhow::Result<()> {
        assert_round_trip(StartProtocol::SessionError(StartErrorType {
            challenge: 10,
            reason: StartErrorReason::NoSlotsAvailable,
        }))
    }

    #[test]
    fn start_protocol_keep_alive_message_should_encode_and_decode() -> anyhow::Result<()> {
        let flag_sets: [KeepAliveFlags; 2] = [None.into(), KeepAliveFlag::BalancerTarget.into()];
        for flags in flag_sets {
            assert_round_trip(StartProtocol::KeepAlive(KeepAliveMessage {
                session_id: 10_i32,
                flags,
                additional_data: 0xffffffff,
            }))?;
        }
        Ok(())
    }

    #[test]
    fn start_protocol_messages_must_fit_within_hopr_packet() -> anyhow::Result<()> {
        let start_session_len = wire_len(StartProtocol::<i32, String, u8>::StartSession(StartInitiation {
            challenge: StartChallenge::MAX,
            target: "example-of-a-very-very-long-second-level-name.on-a-very-very-long-domain-name.info:65530"
                .to_string(),
            capabilities: 0x80,
            additional_data: 0xffffffff,
        }))?;
        assert!(
            start_session_len <= HoprPacket::PAYLOAD_SIZE,
            "StartSession must fit within {}",
            HoprPacket::PAYLOAD_SIZE
        );

        let established_len = wire_len(StartProtocol::<String, String, u8>::SessionEstablished(StartEstablished {
            orig_challenge: StartChallenge::MAX,
            session_id: LONG_SESSION_ID.to_string(),
        }))?;
        assert!(
            established_len <= HoprPacket::PAYLOAD_SIZE,
            "SessionEstablished must fit within {}",
            HoprPacket::PAYLOAD_SIZE
        );

        let error_len = wire_len(StartProtocol::<String, String, u8>::SessionError(StartErrorType {
            challenge: StartChallenge::MAX,
            reason: StartErrorReason::NoSlotsAvailable,
        }))?;
        assert!(
            error_len <= HoprPacket::PAYLOAD_SIZE,
            "SessionError must fit within {}",
            HoprPacket::PAYLOAD_SIZE
        );

        let keep_alive_len = wire_len(StartProtocol::<String, String, u8>::KeepAlive(KeepAliveMessage {
            session_id: LONG_SESSION_ID.to_string(),
            flags: None.into(),
            additional_data: 0,
        }))?;
        assert!(
            keep_alive_len <= HoprPacket::PAYLOAD_SIZE,
            "KeepAlive must fit within {}",
            HoprPacket::PAYLOAD_SIZE
        );

        Ok(())
    }

    #[test]
    fn start_protocol_message_keep_alive_message_should_allow_for_maximum_surbs() -> anyhow::Result<()> {
        let min_surbs = KeepAliveMessage::<String>::MIN_SURBS_PER_MESSAGE;
        assert_eq!(min_surbs, HoprPacket::MAX_SURBS_IN_PACKET);

        let len = wire_len(StartProtocol::<String, String, u8>::KeepAlive(KeepAliveMessage {
            session_id: LONG_SESSION_ID.to_string(),
            flags: None.into(),
            additional_data: 0,
        }))?;
        assert!(
            HoprPacket::max_surbs_with_message(len) >= min_surbs,
            "KeepAlive message size ({}) must allow for at least {} SURBs in packet",
            len,
            min_surbs
        );

        Ok(())
    }
}