Skip to main content

hopr_protocol_start/
lib.rs

1//! This crate defines the Start sub-protocol used for HOPR Session initiation and management.
2//!
//! The Start protocol is used to establish a Session as described in HOPR
//! [`RFC-0012`](https://github.com/hoprnet/rfc/tree/main/rfcs/RFC-0012-session-start-protocol)
//! and is implemented via the [`StartProtocol`] enum.
6//!
7//! The protocol is defined via generic arguments `I` (for Session ID), `T` (for Session Target)
8//! and `C` (for Session capabilities).
9//!
10//! Per `RFC-0012`, the types `I` and `T` are serialized/deserialized to the CBOR binary format
11//! (see [`RFC7049`](https://datatracker.ietf.org/doc/html/rfc7049)) and therefore must implement
12//! `serde::Serialize + serde::Deserialize`.
13//! The capability type `C` must be expressible as a single unsigned byte.
14//!
15//! See [`StartProtocol`] docs for the protocol diagram.
16
17/// Contains errors raised by the Start protocol.
18pub mod errors;
19
20use hopr_crypto_packet::prelude::HoprPacket;
21use hopr_protocol_app::prelude::{ApplicationData, ReservedTag, Tag};
22
23use crate::errors::StartProtocolError;
24
/// Random 64-bit value carried by a [`StartInitiation`], used to match the reply
/// (establishment confirmation or error) to the initiation it answers.
pub type StartChallenge = u64;
27
28/// Lists all Start protocol error reasons.
29#[repr(u8)]
30#[derive(Debug, Copy, Clone, PartialEq, Eq, strum::Display, strum::FromRepr)]
31pub enum StartErrorReason {
32    /// Unknown error.
33    Unknown = 0,
34    /// No more slots are available at the recipient.
35    NoSlotsAvailable = 1,
36    /// Recipient is busy.
37    Busy = 2,
38}
39
40/// Error message in the Start protocol.
41#[derive(Debug, Copy, Clone, PartialEq, Eq)]
42pub struct StartErrorType {
43    /// Challenge that relates to this error.
44    pub challenge: StartChallenge,
45    /// The [reason](StartErrorReason) of this error.
46    pub reason: StartErrorReason,
47}
48
49/// The session initiation message of the Start protocol.
50///
51/// ## Generic parameters
52/// - `T` is the session target
53/// - `C` are session capabilities
54///
55/// The `additional_data` are set dependent on the `capabilities`
56/// or set to `0x00000000` to be ignored.
57#[derive(Debug, Clone, PartialEq, Eq)]
58pub struct StartInitiation<T, C> {
59    /// Random challenge for this initiation.
60    pub challenge: StartChallenge,
61    /// Target of the session, i.e., what should the other party do with the traffic.
62    pub target: T,
63    /// Capabilities of the session.
64    pub capabilities: C,
65    /// Additional options (might be `capabilities` dependent), ignored if `0x00000000`.
66    pub additional_data: u32,
67}
68
69/// Message of the Start protocol that confirms the establishment of a session.
70///
71/// ## Generic parameters
72/// `I` is for session identifier.
73#[derive(Debug, Clone, PartialEq, Eq)]
74pub struct StartEstablished<I> {
75    /// Challenge that was used in the [initiation message](StartInitiation) to establish correspondence.
76    pub orig_challenge: StartChallenge,
77    /// Session ID that was selected by the recipient.
78    pub session_id: I,
79}
80
81#[cfg_attr(doc, aquamarine::aquamarine)]
82/// Lists all messages of the Start protocol for a session establishment.
83///
84/// ## Generic parameters
85/// - `I` is the session identifier.
86/// - `T` is the session target.
87/// - `C` are session capabilities.
88/// # Diagram of the protocol
89/// ```mermaid
90/// sequenceDiagram
91///     Entry->>Exit: SessionInitiation (Challenge)
92///     alt If Exit can accept a new session
93///     Note right of Exit: SessionID [Pseudonym, Tag]
94///     Exit->>Entry: SessionEstablished (Challenge, SessionID_Entry)
95///     Note left of Entry: SessionID [Pseudonym, Tag]
96///     Entry->>Exit: KeepAlive (SessionID)
97///     Note over Entry,Exit: Data
98///     else If Exit cannot accept a new session
99///     Exit->>Entry: SessionError (Challenge, Reason)
100///     end
101///     opt If initiation attempt times out
102///     Note left of Entry: Failure
103///     end
104/// ```
105#[derive(Debug, Clone, PartialEq, Eq, strum::EnumDiscriminants)]
106#[strum_discriminants(vis(pub))]
107#[strum_discriminants(derive(strum::FromRepr, strum::EnumCount), repr(u8))]
108pub enum StartProtocol<I, T, C> {
109    /// Request to initiate a new session.
110    StartSession(StartInitiation<T, C>),
111    /// Confirmation that a new session has been established by the counterparty.
112    SessionEstablished(StartEstablished<I>),
113    /// Counterparty could not establish a new session due to an error.
114    SessionError(StartErrorType),
115    /// A ping message to keep the session alive.
116    KeepAlive(KeepAliveMessage<I>),
117}
118
119/// Keep-alive message for a Session with the identifier `T`.
120#[derive(Debug, Clone, PartialEq, Eq)]
121pub struct KeepAliveMessage<I> {
122    /// Session ID.
123    pub session_id: I,
124    /// Additional flags that govern how the `additional_data` field is interpreted, or 0.
125    pub flags: KeepAliveFlags,
126    /// Additional data (usually `flags` dependent), ignored if `0x00000000`.
127    pub additional_data: u64,
128}
129
130/// [Flags](KeepAliveFlag) that can be sent via the [`KeepAliveMessage`].
131///
132/// The flags can define the meaning of the `additional_data` field.
133pub type KeepAliveFlags = flagset::FlagSet<KeepAliveFlag>;
134
135flagset::flags! {
136    /// Individual flags that can be set in a [`KeepAliveMessage`].
137    pub enum KeepAliveFlag: u8 {
138        /// The `additional_data` field contains load balancer target information.
139        ///
140        /// The value of `additional_data` represents the optimal number of SURBs that the
141        /// Session Initiator wishes to maintain at the Session Recipient.
142        ///
143        /// Mutually exclusive with `BalancerState`.
144        BalancerTarget = 0x01,
145        /// The `additional_data` field contains load balancer state information.
146        ///
147        /// The value of `additional_data` represents the current number of SURBs
148        /// that the Session Recipient estimates to have.
149        ///
150        /// Mutually exclusive with `BalancerTarget`.
151        BalancerState = 0x02,
152    }
153}
154
155impl<I> KeepAliveMessage<I> {
156    /// The minimum number of SURBs a [`KeepAliveMessage`] must be able to carry.
157    pub const MIN_SURBS_PER_MESSAGE: usize = HoprPacket::MAX_SURBS_IN_PACKET;
158}
159
160impl<I> From<I> for KeepAliveMessage<I> {
161    fn from(value: I) -> Self {
162        Self {
163            session_id: value,
164            flags: None.into(),
165            additional_data: 0,
166        }
167    }
168}
169
170impl<I, T, C> StartProtocol<I, T, C> {
171    /// Fixed [`Tag`] of every protocol message.
172    pub const START_PROTOCOL_MESSAGE_TAG: Tag = Tag::Reserved(ReservedTag::SessionStart as u64);
173    /// Current version of the Start protocol.
174    pub const START_PROTOCOL_VERSION: u8 = 0x02;
175}
176
177impl<I, T, C> StartProtocol<I, T, C>
178where
179    I: serde::Serialize + for<'de> serde::Deserialize<'de>,
180    T: serde::Serialize + for<'de> serde::Deserialize<'de>,
181    C: Into<u8> + TryFrom<u8>,
182{
183    /// Tries to encode the message into binary format and [`Tag`]
184    pub fn encode(self) -> errors::Result<(Tag, Box<[u8]>)> {
185        let mut out = Vec::with_capacity(ApplicationData::PAYLOAD_SIZE);
186        out.push(Self::START_PROTOCOL_VERSION);
187        out.push(StartProtocolDiscriminants::from(&self) as u8);
188
189        let mut data = Vec::with_capacity(ApplicationData::PAYLOAD_SIZE - 2);
190        match self {
191            StartProtocol::StartSession(init) => {
192                data.extend_from_slice(&init.challenge.to_be_bytes());
193                data.push(init.capabilities.into());
194                data.extend_from_slice(&init.additional_data.to_be_bytes());
195                let target = serde_cbor_2::to_vec(&init.target)?;
196                data.extend_from_slice(&target);
197            }
198            StartProtocol::SessionEstablished(est) => {
199                data.extend_from_slice(&est.orig_challenge.to_be_bytes());
200                let session_id = serde_cbor_2::to_vec(&est.session_id)?;
201                data.extend(session_id);
202            }
203            StartProtocol::SessionError(err) => {
204                data.extend_from_slice(&err.challenge.to_be_bytes());
205                data.push(err.reason as u8);
206            }
207            StartProtocol::KeepAlive(ping) => {
208                data.push(ping.flags.bits());
209                data.extend_from_slice(&ping.additional_data.to_be_bytes());
210                let session_id = serde_cbor_2::to_vec(&ping.session_id)?;
211                data.extend(session_id);
212            }
213        }
214
215        out.extend_from_slice(&(data.len() as u16).to_be_bytes());
216        out.extend(data);
217
218        Ok((Self::START_PROTOCOL_MESSAGE_TAG, out.into_boxed_slice()))
219    }
220
221    /// Tries to decode the message from the binary representation and [`Tag`].
222    ///
223    /// The `tag` must be currently [`START_PROTOCOL_MESSAGE_TAG`](Self::START_PROTOCOL_MESSAGE_TAG)
224    /// and version [`START_PROTOCOL_VERSION`](Self::START_PROTOCOL_VERSION).
225    pub fn decode(tag: Tag, data: &[u8]) -> errors::Result<Self> {
226        if tag != Self::START_PROTOCOL_MESSAGE_TAG {
227            return Err(StartProtocolError::UnknownTag);
228        }
229
230        if data.len() < 5 {
231            return Err(StartProtocolError::InvalidLength);
232        }
233
234        if data[0] != Self::START_PROTOCOL_VERSION {
235            return Err(StartProtocolError::InvalidVersion);
236        }
237
238        let disc = data[1];
239        let len = u16::from_be_bytes(
240            data[2..4]
241                .try_into()
242                .map_err(|_| StartProtocolError::ParseError("len".into()))?,
243        ) as usize;
244        let data_offset = 2 + size_of::<u16>();
245
246        if data.len() < data_offset + len {
247            return Err(StartProtocolError::InvalidLength);
248        }
249
250        Ok(
251            match StartProtocolDiscriminants::from_repr(disc).ok_or(StartProtocolError::UnknownMessage)? {
252                StartProtocolDiscriminants::StartSession => {
253                    if data.len() <= data_offset + size_of::<StartChallenge>() + 1 + size_of::<u32>() {
254                        return Err(StartProtocolError::InvalidLength);
255                    }
256
257                    StartProtocol::StartSession(StartInitiation {
258                        challenge: StartChallenge::from_be_bytes(
259                            data[data_offset..data_offset + size_of::<StartChallenge>()]
260                                .try_into()
261                                .map_err(|_| StartProtocolError::ParseError("init.challenge".into()))?,
262                        ),
263                        capabilities: data[data_offset + size_of::<StartChallenge>()]
264                            .try_into()
265                            .map_err(|_| StartProtocolError::ParseError("init.capabilities".into()))?,
266                        additional_data: u32::from_be_bytes(
267                            data[data_offset + size_of::<StartChallenge>() + 1
268                                ..data_offset + size_of::<StartChallenge>() + 1 + size_of::<u32>()]
269                                .try_into()
270                                .map_err(|_| StartProtocolError::ParseError("init.additional_data".into()))?,
271                        ),
272                        target: serde_cbor_2::from_slice(
273                            &data[data_offset + size_of::<StartChallenge>() + 1 + size_of::<u32>()..],
274                        )?,
275                    })
276                }
277                StartProtocolDiscriminants::SessionEstablished => {
278                    if data.len() <= data_offset + size_of::<StartChallenge>() {
279                        return Err(StartProtocolError::InvalidLength);
280                    }
281                    StartProtocol::SessionEstablished(StartEstablished {
282                        orig_challenge: StartChallenge::from_be_bytes(
283                            data[data_offset..data_offset + size_of::<StartChallenge>()]
284                                .try_into()
285                                .map_err(|_| StartProtocolError::ParseError("est.challenge".into()))?,
286                        ),
287                        session_id: serde_cbor_2::from_slice(&data[data_offset + size_of::<StartChallenge>()..])?,
288                    })
289                }
290                StartProtocolDiscriminants::SessionError => {
291                    if data.len() < data_offset + size_of::<StartChallenge>() + 1 {
292                        return Err(StartProtocolError::InvalidLength);
293                    }
294                    StartProtocol::SessionError(StartErrorType {
295                        challenge: StartChallenge::from_be_bytes(
296                            data[data_offset..data_offset + size_of::<StartChallenge>()]
297                                .try_into()
298                                .map_err(|_| StartProtocolError::ParseError("err.challenge".into()))?,
299                        ),
300                        reason: StartErrorReason::from_repr(data[data_offset + size_of::<StartChallenge>()])
301                            .ok_or(StartProtocolError::ParseError("err.reason".into()))?,
302                    })
303                }
304                StartProtocolDiscriminants::KeepAlive => {
305                    if data.len() <= data_offset + size_of::<u32>() {
306                        return Err(StartProtocolError::InvalidLength);
307                    }
308
309                    StartProtocol::KeepAlive(KeepAliveMessage {
310                        flags: KeepAliveFlags::new(data[data_offset])
311                            .map_err(|_| StartProtocolError::ParseError("ka.flags".into()))?,
312                        additional_data: u64::from_be_bytes(
313                            data[data_offset + 1..data_offset + 1 + size_of::<u64>()]
314                                .try_into()
315                                .map_err(|_| StartProtocolError::ParseError("ka.additional_data".into()))?,
316                        ),
317                        session_id: serde_cbor_2::from_slice(&data[data_offset + 1 + size_of::<u64>()..])?,
318                    })
319                }
320            },
321        )
322    }
323}
324
325impl<I, T, C> TryFrom<StartProtocol<I, T, C>> for ApplicationData
326where
327    I: serde::Serialize + for<'de> serde::Deserialize<'de>,
328    T: serde::Serialize + for<'de> serde::Deserialize<'de>,
329    C: Into<u8> + TryFrom<u8>,
330{
331    type Error = StartProtocolError;
332
333    fn try_from(value: StartProtocol<I, T, C>) -> Result<Self, Self::Error> {
334        let (application_tag, plain_text) = value.encode()?;
335        Ok(ApplicationData::new(application_tag, plain_text.into_vec())?)
336    }
337}
338
339impl<I, T, C> TryFrom<ApplicationData> for StartProtocol<I, T, C>
340where
341    I: serde::Serialize + for<'de> serde::Deserialize<'de>,
342    T: serde::Serialize + for<'de> serde::Deserialize<'de>,
343    C: Into<u8> + TryFrom<u8>,
344{
345    type Error = StartProtocolError;
346
347    fn try_from(value: ApplicationData) -> Result<Self, Self::Error> {
348        Self::decode(value.application_tag, &value.plain_text)
349    }
350}
351
#[cfg(test)]
mod tests {
    use hopr_crypto_packet::prelude::HoprPacket;
    use hopr_protocol_app::prelude::{ReservedTag, Tag};

    use super::*;
    use crate::errors::StartProtocolError;

    #[test]
    fn start_protocol_start_session_message_should_encode_and_decode() -> anyhow::Result<()> {
        let msg_1 = StartProtocol::StartSession(StartInitiation {
            challenge: 0,
            target: "127.0.0.1:1234".to_string(),
            capabilities: Default::default(),
            additional_data: 0x12345678,
        });

        let (tag, msg) = msg_1.clone().encode()?;
        let expected: Tag = StartProtocol::<(), (), ()>::START_PROTOCOL_MESSAGE_TAG;
        assert_eq!(tag, expected);

        let msg_2 = StartProtocol::<i32, String, u8>::decode(tag, &msg)?;

        assert_eq!(msg_1, msg_2);
        Ok(())
    }

    #[test]
    fn start_protocol_message_start_session_message_should_allow_for_at_least_one_surb() -> anyhow::Result<()> {
        let msg = StartProtocol::<i32, String, u8>::StartSession(StartInitiation {
            challenge: 0,
            target: "127.0.0.1:1234".to_string(),
            capabilities: 0xff,
            additional_data: 0xffffffff,
        });

        let len = msg.encode()?.1.len();
        assert!(
            HoprPacket::max_surbs_with_message(len) >= 1,
            "StartSession message size ({len}) must allow for at least 1 SURBs in packet",
        );

        Ok(())
    }

    #[test]
    fn start_protocol_session_established_message_should_encode_and_decode() -> anyhow::Result<()> {
        let msg_1 = StartProtocol::SessionEstablished(StartEstablished {
            orig_challenge: 0,
            session_id: 10_i32,
        });

        let (tag, msg) = msg_1.clone().encode()?;
        let expected: Tag = StartProtocol::<(), (), ()>::START_PROTOCOL_MESSAGE_TAG;
        assert_eq!(tag, expected);

        let msg_2 = StartProtocol::<i32, String, u8>::decode(tag, &msg)?;

        assert_eq!(msg_1, msg_2);
        Ok(())
    }

    #[test]
    fn start_protocol_session_error_message_should_encode_and_decode() -> anyhow::Result<()> {
        let msg_1 = StartProtocol::SessionError(StartErrorType {
            challenge: 10,
            reason: StartErrorReason::NoSlotsAvailable,
        });

        let (tag, msg) = msg_1.clone().encode()?;
        let expected: Tag = StartProtocol::<(), (), ()>::START_PROTOCOL_MESSAGE_TAG;
        assert_eq!(tag, expected);

        let msg_2 = StartProtocol::<i32, String, u8>::decode(tag, &msg)?;

        assert_eq!(msg_1, msg_2);
        Ok(())
    }

    #[test]
    fn start_protocol_keep_alive_message_should_encode_and_decode() -> anyhow::Result<()> {
        let msg_1 = StartProtocol::KeepAlive(KeepAliveMessage {
            session_id: 10_i32,
            flags: None.into(),
            additional_data: 0xffffffff,
        });

        let (tag, msg) = msg_1.clone().encode()?;
        let expected: Tag = StartProtocol::<(), (), ()>::START_PROTOCOL_MESSAGE_TAG;
        assert_eq!(tag, expected);

        let msg_2 = StartProtocol::<i32, String, u8>::decode(tag, &msg)?;

        assert_eq!(msg_1, msg_2);

        let msg_1 = StartProtocol::KeepAlive(KeepAliveMessage {
            session_id: 10_i32,
            flags: KeepAliveFlag::BalancerTarget.into(),
            additional_data: 0xffffffff,
        });

        let (tag, msg) = msg_1.clone().encode()?;
        let expected: Tag = StartProtocol::<(), (), ()>::START_PROTOCOL_MESSAGE_TAG;
        assert_eq!(tag, expected);

        let msg_2 = StartProtocol::<i32, String, u8>::decode(tag, &msg)?;

        assert_eq!(msg_1, msg_2);
        Ok(())
    }

    /// The keep-alive `additional_data` is a full `u64`, so the whole range must round-trip.
    #[test]
    fn start_protocol_keep_alive_message_should_encode_and_decode_max_additional_data() -> anyhow::Result<()> {
        let msg_1 = StartProtocol::KeepAlive(KeepAliveMessage {
            session_id: 10_i32,
            flags: KeepAliveFlag::BalancerState.into(),
            additional_data: u64::MAX,
        });

        let (tag, msg) = msg_1.clone().encode()?;
        let msg_2 = StartProtocol::<i32, String, u8>::decode(tag, &msg)?;

        assert_eq!(msg_1, msg_2);
        Ok(())
    }

    /// Malformed headers and fields must be reported as errors, never accepted.
    #[test]
    fn start_protocol_decode_should_reject_malformed_messages() -> anyhow::Result<()> {
        type Msg = StartProtocol<i32, String, u8>;

        let (_, msg) = Msg::SessionError(StartErrorType {
            challenge: 10,
            reason: StartErrorReason::Busy,
        })
        .encode()?;

        // Foreign tag.
        assert!(matches!(
            Msg::decode(Tag::Reserved(ReservedTag::SessionStart as u64 + 1), &msg),
            Err(StartProtocolError::UnknownTag)
        ));

        // Header only, no payload.
        assert!(matches!(
            Msg::decode(Msg::START_PROTOCOL_MESSAGE_TAG, &msg[..4]),
            Err(StartProtocolError::InvalidLength)
        ));

        // Payload shorter than announced in the header.
        assert!(matches!(
            Msg::decode(Msg::START_PROTOCOL_MESSAGE_TAG, &msg[..msg.len() - 1]),
            Err(StartProtocolError::InvalidLength)
        ));

        let mut bad_version = msg.to_vec();
        bad_version[0] = Msg::START_PROTOCOL_VERSION.wrapping_add(1);
        assert!(matches!(
            Msg::decode(Msg::START_PROTOCOL_MESSAGE_TAG, &bad_version),
            Err(StartProtocolError::InvalidVersion)
        ));

        let mut bad_disc = msg.to_vec();
        bad_disc[1] = 0xff;
        assert!(matches!(
            Msg::decode(Msg::START_PROTOCOL_MESSAGE_TAG, &bad_disc),
            Err(StartProtocolError::UnknownMessage)
        ));

        // The last byte of a SessionError is the reason; 0xff is not a valid one.
        let mut bad_reason = msg.to_vec();
        let last = bad_reason.len() - 1;
        bad_reason[last] = 0xff;
        assert!(matches!(
            Msg::decode(Msg::START_PROTOCOL_MESSAGE_TAG, &bad_reason),
            Err(StartProtocolError::ParseError(_))
        ));

        Ok(())
    }

    #[test]
    fn start_protocol_messages_must_fit_within_hopr_packet() -> anyhow::Result<()> {
        let msg = StartProtocol::<i32, String, u8>::StartSession(StartInitiation {
            challenge: StartChallenge::MAX,
            target: "example-of-a-very-very-long-second-level-name.on-a-very-very-long-domain-name.info:65530"
                .to_string(),
            capabilities: 0x80,
            additional_data: 0xffffffff,
        });

        assert!(
            msg.encode()?.1.len() <= HoprPacket::PAYLOAD_SIZE,
            "StartSession must fit within {}",
            HoprPacket::PAYLOAD_SIZE
        );

        let msg = StartProtocol::<String, String, u8>::SessionEstablished(StartEstablished {
            orig_challenge: StartChallenge::MAX,
            session_id: "example-of-a-very-very-long-session-id-that-should-still-fit-the-packet".to_string(),
        });

        assert!(
            msg.encode()?.1.len() <= HoprPacket::PAYLOAD_SIZE,
            "SessionEstablished must fit within {}",
            HoprPacket::PAYLOAD_SIZE
        );

        let msg = StartProtocol::<String, String, u8>::SessionError(StartErrorType {
            challenge: StartChallenge::MAX,
            reason: StartErrorReason::NoSlotsAvailable,
        });

        assert!(
            msg.encode()?.1.len() <= HoprPacket::PAYLOAD_SIZE,
            "SessionError must fit within {}",
            HoprPacket::PAYLOAD_SIZE
        );

        let msg = StartProtocol::<String, String, u8>::KeepAlive(KeepAliveMessage {
            session_id: "example-of-a-very-very-long-session-id-that-should-still-fit-the-packet".to_string(),
            flags: None.into(),
            additional_data: 0,
        });
        assert!(
            msg.encode()?.1.len() <= HoprPacket::PAYLOAD_SIZE,
            "KeepAlive must fit within {}",
            HoprPacket::PAYLOAD_SIZE
        );

        Ok(())
    }

    #[test]
    fn start_protocol_message_keep_alive_message_should_allow_for_maximum_surbs() -> anyhow::Result<()> {
        let msg = StartProtocol::<String, String, u8>::KeepAlive(KeepAliveMessage {
            session_id: "example-of-a-very-very-long-session-id-that-should-still-fit-the-packet".to_string(),
            flags: None.into(),
            additional_data: 0,
        });
        let len = msg.encode()?.1.len();
        assert_eq!(
            KeepAliveMessage::<String>::MIN_SURBS_PER_MESSAGE,
            HoprPacket::MAX_SURBS_IN_PACKET
        );
        assert!(
            HoprPacket::max_surbs_with_message(len) >= KeepAliveMessage::<String>::MIN_SURBS_PER_MESSAGE,
            "KeepAlive message size ({}) must allow for at least {} SURBs in packet",
            len,
            KeepAliveMessage::<String>::MIN_SURBS_PER_MESSAGE
        );

        Ok(())
    }
}