1pub mod errors;
19
20use hopr_crypto_packet::prelude::HoprPacket;
21use hopr_protocol_app::prelude::{ApplicationData, ReservedTag, Tag};
22
23use crate::errors::StartProtocolError;
24
/// Challenge value carried by Start protocol messages.
///
/// Encoded on the wire as 8 big-endian bytes (`to_be_bytes` / `from_be_bytes`).
pub type StartChallenge = u64;
27
/// Reason code carried by a [`StartProtocol::SessionError`] message.
///
/// Encoded on the wire as a single byte (the `repr(u8)` discriminant). Decoding maps the byte
/// back with `from_repr` and rejects bytes that do not correspond to a variant, so the explicit
/// discriminant values below are part of the wire format and must stay stable.
#[repr(u8)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, strum::Display, strum::FromRepr)]
pub enum StartErrorReason {
    /// Wire value `0`: unknown / unspecified reason.
    Unknown = 0,
    /// Wire value `1`: no slots available.
    NoSlotsAvailable = 1,
    /// Wire value `2`: busy.
    Busy = 2,
}
39
/// Payload of a [`StartProtocol::SessionError`] message.
///
/// Wire payload: `challenge (u64 BE) | reason (u8)`.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct StartErrorType {
    /// Challenge the error refers to; presumably the `challenge` of the rejected
    /// [`StartInitiation`] (TODO confirm against the session manager).
    pub challenge: StartChallenge,
    /// Reason for the error, encoded as a single byte.
    pub reason: StartErrorReason,
}
48
/// Payload of a [`StartProtocol::StartSession`] message.
///
/// Wire payload: `challenge (u64 BE) | capabilities (u8) | additional_data (u32 BE) | CBOR(target)`.
///
/// - `T`: target type, serialized with CBOR (the tests use a `host:port` string).
/// - `C`: capabilities type, converted to and from a single byte.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StartInitiation<T, C> {
    /// Challenge of this initiation request.
    pub challenge: StartChallenge,
    /// Session target, CBOR-encoded at the end of the payload.
    pub target: T,
    /// Requested capabilities, encoded as one byte via `Into<u8>` / `TryFrom<u8>`.
    pub capabilities: C,
    /// Additional 32-bit data, encoded big-endian; its meaning is defined by the caller.
    pub additional_data: u32,
}
68
/// Payload of a [`StartProtocol::SessionEstablished`] message.
///
/// Wire payload: `orig_challenge (u64 BE) | CBOR(session_id)`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StartEstablished<I> {
    /// Challenge this message answers; presumably echoes the `challenge` of the originating
    /// [`StartInitiation`] (TODO confirm).
    pub orig_challenge: StartChallenge,
    /// Identifier of the established session, CBOR-encoded.
    pub session_id: I,
}
80
/// Messages of the Session Start protocol.
///
/// Every message is encoded (see [`StartProtocol::encode`]) as
/// `[version: u8][discriminant: u8][payload length: u16 BE][payload]` and sent under the
/// [`StartProtocol::START_PROTOCOL_MESSAGE_TAG`] application tag.
///
/// The discriminant byte is the `repr(u8)` value of the generated [`StartProtocolDiscriminants`],
/// which follows declaration order (`StartSession` = 0, `SessionEstablished` = 1,
/// `SessionError` = 2, `KeepAlive` = 3). Reordering variants therefore changes the wire format.
///
/// - `I`: session identifier type (CBOR-encoded),
/// - `T`: session target type (CBOR-encoded),
/// - `C`: capabilities type (encoded as one byte).
#[cfg_attr(doc, aquamarine::aquamarine)]
#[derive(Debug, Clone, PartialEq, Eq, strum::EnumDiscriminants)]
#[strum_discriminants(vis(pub))]
#[strum_discriminants(derive(strum::FromRepr, strum::EnumCount), repr(u8))]
pub enum StartProtocol<I, T, C> {
    /// Request to start a new session.
    StartSession(StartInitiation<T, C>),
    /// Confirmation that a session was established.
    SessionEstablished(StartEstablished<I>),
    /// Report that a session could not be started.
    SessionError(StartErrorType),
    /// Keep-alive message for an existing session.
    KeepAlive(KeepAliveMessage<I>),
}
118
/// Payload of a [`StartProtocol::KeepAlive`] message.
///
/// Wire payload: `flags (u8) | additional_data (u64 BE) | CBOR(session_id)`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct KeepAliveMessage<I> {
    /// Identifier of the session being kept alive, CBOR-encoded.
    pub session_id: I,
    /// Flag bits; their meaning is defined by the caller (zero when built via `From<I>`).
    pub flags: u8,
    /// Additional 64-bit data, encoded big-endian (zero when built via `From<I>`).
    pub additional_data: u64,
}
129
impl<I> KeepAliveMessage<I> {
    /// Minimum number of SURBs that a packet carrying a KeepAlive message must be able to hold.
    ///
    /// Equal to [`HoprPacket::MAX_SURBS_IN_PACKET`]: an encoded KeepAlive message is expected to
    /// leave room for the maximum number of SURBs a single packet can carry.
    pub const MIN_SURBS_PER_MESSAGE: usize = HoprPacket::MAX_SURBS_IN_PACKET;
}
134
135impl<I> From<I> for KeepAliveMessage<I> {
136 fn from(value: I) -> Self {
137 Self {
138 session_id: value,
139 flags: 0,
140 additional_data: 0,
141 }
142 }
143}
144
impl<I, T, C> StartProtocol<I, T, C> {
    /// Application tag under which all Start protocol messages are sent.
    ///
    /// `decode` rejects any other tag with `StartProtocolError::UnknownTag`.
    pub const START_PROTOCOL_MESSAGE_TAG: Tag = Tag::Reserved(ReservedTag::SessionStart as u64);
    /// Wire format version, written as the first byte of every encoded message.
    ///
    /// `decode` rejects messages carrying a different version with
    /// `StartProtocolError::InvalidVersion`.
    pub const START_PROTOCOL_VERSION: u8 = 0x02;
}
151
152impl<I, T, C> StartProtocol<I, T, C>
153where
154 I: serde::Serialize + for<'de> serde::Deserialize<'de>,
155 T: serde::Serialize + for<'de> serde::Deserialize<'de>,
156 C: Into<u8> + TryFrom<u8>,
157{
158 pub fn encode(self) -> errors::Result<(Tag, Box<[u8]>)> {
160 let mut out = Vec::with_capacity(ApplicationData::PAYLOAD_SIZE);
161 out.push(Self::START_PROTOCOL_VERSION);
162 out.push(StartProtocolDiscriminants::from(&self) as u8);
163
164 let mut data = Vec::with_capacity(ApplicationData::PAYLOAD_SIZE - 2);
165 match self {
166 StartProtocol::StartSession(init) => {
167 data.extend_from_slice(&init.challenge.to_be_bytes());
168 data.push(init.capabilities.into());
169 data.extend_from_slice(&init.additional_data.to_be_bytes());
170 let target = serde_cbor_2::to_vec(&init.target)?;
171 data.extend_from_slice(&target);
172 }
173 StartProtocol::SessionEstablished(est) => {
174 data.extend_from_slice(&est.orig_challenge.to_be_bytes());
175 let session_id = serde_cbor_2::to_vec(&est.session_id)?;
176 data.extend(session_id);
177 }
178 StartProtocol::SessionError(err) => {
179 data.extend_from_slice(&err.challenge.to_be_bytes());
180 data.push(err.reason as u8);
181 }
182 StartProtocol::KeepAlive(ping) => {
183 data.push(ping.flags);
184 data.extend_from_slice(&ping.additional_data.to_be_bytes());
185 let session_id = serde_cbor_2::to_vec(&ping.session_id)?;
186 data.extend(session_id);
187 }
188 }
189
190 out.extend_from_slice(&(data.len() as u16).to_be_bytes());
191 out.extend(data);
192
193 Ok((Self::START_PROTOCOL_MESSAGE_TAG, out.into_boxed_slice()))
194 }
195
196 pub fn decode(tag: Tag, data: &[u8]) -> errors::Result<Self> {
201 if tag != Self::START_PROTOCOL_MESSAGE_TAG {
202 return Err(StartProtocolError::UnknownTag);
203 }
204
205 if data.len() < 5 {
206 return Err(StartProtocolError::InvalidLength);
207 }
208
209 if data[0] != Self::START_PROTOCOL_VERSION {
210 return Err(StartProtocolError::InvalidVersion);
211 }
212
213 let disc = data[1];
214 let len = u16::from_be_bytes(
215 data[2..4]
216 .try_into()
217 .map_err(|_| StartProtocolError::ParseError("len".into()))?,
218 ) as usize;
219 let data_offset = 2 + size_of::<u16>();
220
221 if data.len() < data_offset + len {
222 return Err(StartProtocolError::InvalidLength);
223 }
224
225 Ok(
226 match StartProtocolDiscriminants::from_repr(disc).ok_or(StartProtocolError::UnknownMessage)? {
227 StartProtocolDiscriminants::StartSession => {
228 if data.len() <= data_offset + size_of::<StartChallenge>() + 1 + size_of::<u32>() {
229 return Err(StartProtocolError::InvalidLength);
230 }
231
232 StartProtocol::StartSession(StartInitiation {
233 challenge: StartChallenge::from_be_bytes(
234 data[data_offset..data_offset + size_of::<StartChallenge>()]
235 .try_into()
236 .map_err(|_| StartProtocolError::ParseError("init.challenge".into()))?,
237 ),
238 capabilities: data[data_offset + size_of::<StartChallenge>()]
239 .try_into()
240 .map_err(|_| StartProtocolError::ParseError("init.capabilities".into()))?,
241 additional_data: u32::from_be_bytes(
242 data[data_offset + size_of::<StartChallenge>() + 1
243 ..data_offset + size_of::<StartChallenge>() + 1 + size_of::<u32>()]
244 .try_into()
245 .map_err(|_| StartProtocolError::ParseError("init.additional_data".into()))?,
246 ),
247 target: serde_cbor_2::from_slice(
248 &data[data_offset + size_of::<StartChallenge>() + 1 + size_of::<u32>()..],
249 )?,
250 })
251 }
252 StartProtocolDiscriminants::SessionEstablished => {
253 if data.len() <= data_offset + size_of::<StartChallenge>() {
254 return Err(StartProtocolError::InvalidLength);
255 }
256 StartProtocol::SessionEstablished(StartEstablished {
257 orig_challenge: StartChallenge::from_be_bytes(
258 data[data_offset..data_offset + size_of::<StartChallenge>()]
259 .try_into()
260 .map_err(|_| StartProtocolError::ParseError("est.challenge".into()))?,
261 ),
262 session_id: serde_cbor_2::from_slice(&data[data_offset + size_of::<StartChallenge>()..])?,
263 })
264 }
265 StartProtocolDiscriminants::SessionError => {
266 if data.len() < data_offset + size_of::<StartChallenge>() + 1 {
267 return Err(StartProtocolError::InvalidLength);
268 }
269 StartProtocol::SessionError(StartErrorType {
270 challenge: StartChallenge::from_be_bytes(
271 data[data_offset..data_offset + size_of::<StartChallenge>()]
272 .try_into()
273 .map_err(|_| StartProtocolError::ParseError("err.challenge".into()))?,
274 ),
275 reason: StartErrorReason::from_repr(data[data_offset + size_of::<StartChallenge>()])
276 .ok_or(StartProtocolError::ParseError("err.reason".into()))?,
277 })
278 }
279 StartProtocolDiscriminants::KeepAlive => {
280 if data.len() <= data_offset + size_of::<u32>() {
281 return Err(StartProtocolError::InvalidLength);
282 }
283
284 StartProtocol::KeepAlive(KeepAliveMessage {
285 flags: data[data_offset],
286 additional_data: u64::from_be_bytes(
287 data[data_offset + 1..data_offset + 1 + size_of::<u64>()]
288 .try_into()
289 .map_err(|_| StartProtocolError::ParseError("ka.additional_data".into()))?,
290 ),
291 session_id: serde_cbor_2::from_slice(&data[data_offset + 1 + size_of::<u64>()..])?,
292 })
293 }
294 },
295 )
296 }
297}
298
299impl<I, T, C> TryFrom<StartProtocol<I, T, C>> for ApplicationData
300where
301 I: serde::Serialize + for<'de> serde::Deserialize<'de>,
302 T: serde::Serialize + for<'de> serde::Deserialize<'de>,
303 C: Into<u8> + TryFrom<u8>,
304{
305 type Error = StartProtocolError;
306
307 fn try_from(value: StartProtocol<I, T, C>) -> Result<Self, Self::Error> {
308 let (application_tag, plain_text) = value.encode()?;
309 Ok(ApplicationData::new(application_tag, plain_text.into_vec())?)
310 }
311}
312
313impl<I, T, C> TryFrom<ApplicationData> for StartProtocol<I, T, C>
314where
315 I: serde::Serialize + for<'de> serde::Deserialize<'de>,
316 T: serde::Serialize + for<'de> serde::Deserialize<'de>,
317 C: Into<u8> + TryFrom<u8>,
318{
319 type Error = StartProtocolError;
320
321 fn try_from(value: ApplicationData) -> Result<Self, Self::Error> {
322 Self::decode(value.application_tag, &value.plain_text)
323 }
324}
325
#[cfg(test)]
mod tests {
    use hopr_crypto_packet::prelude::HoprPacket;
    use hopr_protocol_app::prelude::Tag;

    use super::*;

    /// Session identifier long enough to exercise the packet size limits.
    const LONG_SESSION_ID: &str = "example-of-a-very-very-long-session-id-that-should-still-fit-the-packet";

    /// Encodes `original`, checks that the Start protocol tag is used, then decodes the bytes
    /// again and asserts the round trip is lossless.
    fn assert_round_trip(original: StartProtocol<i32, String, u8>) -> anyhow::Result<()> {
        let (tag, encoded) = original.clone().encode()?;
        let expected: Tag = StartProtocol::<(), (), ()>::START_PROTOCOL_MESSAGE_TAG;
        assert_eq!(tag, expected);

        let decoded = StartProtocol::<i32, String, u8>::decode(tag, &encoded)?;
        assert_eq!(original, decoded);
        Ok(())
    }

    #[test]
    fn start_protocol_start_session_message_should_encode_and_decode() -> anyhow::Result<()> {
        assert_round_trip(StartProtocol::StartSession(StartInitiation {
            challenge: 0,
            target: "127.0.0.1:1234".to_string(),
            capabilities: Default::default(),
            additional_data: 0x12345678,
        }))
    }

    #[test]
    fn start_protocol_message_start_session_message_should_allow_for_at_least_one_surb() -> anyhow::Result<()> {
        let (_, encoded) = StartProtocol::<i32, String, u8>::StartSession(StartInitiation {
            challenge: 0,
            target: "127.0.0.1:1234".to_string(),
            capabilities: 0xff,
            additional_data: 0xffffffff,
        })
        .encode()?;
        let len = encoded.len();

        assert!(
            HoprPacket::max_surbs_with_message(len) >= 1,
            "StartSession message size ({len}) must allow for at least 1 SURBs in packet",
        );
        Ok(())
    }

    #[test]
    fn start_protocol_session_established_message_should_encode_and_decode() -> anyhow::Result<()> {
        assert_round_trip(StartProtocol::SessionEstablished(StartEstablished {
            orig_challenge: 0,
            session_id: 10_i32,
        }))
    }

    #[test]
    fn start_protocol_session_error_message_should_encode_and_decode() -> anyhow::Result<()> {
        assert_round_trip(StartProtocol::SessionError(StartErrorType {
            challenge: 10,
            reason: StartErrorReason::NoSlotsAvailable,
        }))
    }

    #[test]
    fn start_protocol_keep_alive_message_should_encode_and_decode() -> anyhow::Result<()> {
        assert_round_trip(StartProtocol::KeepAlive(KeepAliveMessage {
            session_id: 10_i32,
            flags: 0,
            additional_data: 0xffffffff,
        }))
    }

    #[test]
    fn start_protocol_messages_must_fit_within_hopr_packet() -> anyhow::Result<()> {
        let encoded_sizes = [
            (
                "StartSession",
                StartProtocol::<i32, String, u8>::StartSession(StartInitiation {
                    challenge: StartChallenge::MAX,
                    target: "example-of-a-very-very-long-second-level-name.on-a-very-very-long-domain-name.info:65530"
                        .to_string(),
                    capabilities: 0x80,
                    additional_data: 0xffffffff,
                })
                .encode()?
                .1
                .len(),
            ),
            (
                "SessionEstablished",
                StartProtocol::<String, String, u8>::SessionEstablished(StartEstablished {
                    orig_challenge: StartChallenge::MAX,
                    session_id: LONG_SESSION_ID.to_string(),
                })
                .encode()?
                .1
                .len(),
            ),
            (
                "SessionError",
                StartProtocol::<String, String, u8>::SessionError(StartErrorType {
                    challenge: StartChallenge::MAX,
                    reason: StartErrorReason::NoSlotsAvailable,
                })
                .encode()?
                .1
                .len(),
            ),
            (
                "KeepAlive",
                StartProtocol::<String, String, u8>::KeepAlive(KeepAliveMessage {
                    session_id: LONG_SESSION_ID.to_string(),
                    flags: 0xff,
                    additional_data: 0xffffffff,
                })
                .encode()?
                .1
                .len(),
            ),
        ];

        for (name, size) in encoded_sizes {
            assert!(
                size <= HoprPacket::PAYLOAD_SIZE,
                "{name} must fit within {}",
                HoprPacket::PAYLOAD_SIZE
            );
        }
        Ok(())
    }

    #[test]
    fn start_protocol_message_keep_alive_message_should_allow_for_maximum_surbs() -> anyhow::Result<()> {
        let min_surbs = KeepAliveMessage::<String>::MIN_SURBS_PER_MESSAGE;
        assert_eq!(min_surbs, HoprPacket::MAX_SURBS_IN_PACKET);

        let len = StartProtocol::<String, String, u8>::KeepAlive(KeepAliveMessage {
            session_id: LONG_SESSION_ID.to_string(),
            flags: 0xff,
            additional_data: 0xffffffff,
        })
        .encode()?
        .1
        .len();

        assert!(
            HoprPacket::max_surbs_with_message(len) >= min_surbs,
            "KeepAlive message size ({}) must allow for at least {} SURBs in packet",
            len,
            min_surbs
        );
        Ok(())
    }
}