hopr_protocol_start/
lib.rs

1//! This crate defines the Start sub-protocol used for HOPR Session initiation and management.
2//!
//! The Start protocol is used to establish a Session as described in HOPR
//! [`RFC-0012`](https://github.com/hoprnet/rfc/tree/main/rfcs/RFC-0012-session-start-protocol)
//! and is implemented via the [`StartProtocol`] enum.
6//!
7//! The protocol is defined via generic arguments `I` (for Session ID), `T` (for Session Target)
8//! and `C` (for Session capabilities).
9//!
10//! Per `RFC-0012`, the types `I` and `T` are serialized/deserialized to the CBOR binary format
11//! (see [`RFC7049`](https://datatracker.ietf.org/doc/html/rfc7049)) and therefore must implement
12//! `serde::Serialize + serde::Deserialize`.
13//! The capability type `C` must be expressible as a single unsigned byte.
14//!
15//! See [`StartProtocol`] docs for the protocol diagram.
16
17/// Contains errors raised by the Start protocol.
18pub mod errors;
19
20use hopr_crypto_packet::prelude::HoprPacket;
21use hopr_protocol_app::prelude::{ApplicationData, ReservedTag, Tag};
22
23use crate::errors::StartProtocolError;
24
/// Numeric challenge carried by a Start protocol session initiation message.
///
/// Replies to an initiation (session established / session error) carry the same
/// value, so they can be matched to the initiation they answer.
pub type StartChallenge = u64;
27
28/// Lists all Start protocol error reasons.
29#[repr(u8)]
30#[derive(Debug, Copy, Clone, PartialEq, Eq, strum::Display, strum::FromRepr)]
31pub enum StartErrorReason {
32    /// Unknown error.
33    Unknown = 0,
34    /// No more slots are available at the recipient.
35    NoSlotsAvailable = 1,
36    /// Recipient is busy.
37    Busy = 2,
38}
39
40/// Error message in the Start protocol.
41#[derive(Debug, Copy, Clone, PartialEq, Eq)]
42pub struct StartErrorType {
43    /// Challenge that relates to this error.
44    pub challenge: StartChallenge,
45    /// The [reason](StartErrorReason) of this error.
46    pub reason: StartErrorReason,
47}
48
49/// The session initiation message of the Start protocol.
50///
51/// ## Generic parameters
52/// - `T` is the session target
53/// - `C` are session capabilities
54///
55/// The `additional_data` are set dependent on the `capabilities`
56/// or set to `0x00000000` to be ignored.
57#[derive(Debug, Clone, PartialEq, Eq)]
58pub struct StartInitiation<T, C> {
59    /// Random challenge for this initiation.
60    pub challenge: StartChallenge,
61    /// Target of the session, i.e., what should the other party do with the traffic.
62    pub target: T,
63    /// Capabilities of the session.
64    pub capabilities: C,
65    /// Additional options (might be `capabilities` dependent), ignored if `0x00000000`.
66    pub additional_data: u32,
67}
68
69/// Message of the Start protocol that confirms the establishment of a session.
70///
71/// ## Generic parameters
72/// `I` is for session identifier.
73#[derive(Debug, Clone, PartialEq, Eq)]
74pub struct StartEstablished<I> {
75    /// Challenge that was used in the [initiation message](StartInitiation) to establish correspondence.
76    pub orig_challenge: StartChallenge,
77    /// Session ID that was selected by the recipient.
78    pub session_id: I,
79}
80
81#[cfg_attr(doc, aquamarine::aquamarine)]
82/// Lists all messages of the Start protocol for a session establishment.
83///
84/// ## Generic parameters
85/// - `I` is the session identifier.
86/// - `T` is the session target.
87/// - `C` are session capabilities.
88/// # Diagram of the protocol
89/// ```mermaid
90/// sequenceDiagram
91///     Entry->>Exit: SessionInitiation (Challenge)
92///     alt If Exit can accept a new session
93///     Note right of Exit: SessionID [Pseudonym, Tag]
94///     Exit->>Entry: SessionEstablished (Challenge, SessionID_Entry)
95///     Note left of Entry: SessionID [Pseudonym, Tag]
96///     Entry->>Exit: KeepAlive (SessionID)
97///     Note over Entry,Exit: Data
98///     else If Exit cannot accept a new session
99///     Exit->>Entry: SessionError (Challenge, Reason)
100///     end
101///     opt If initiation attempt times out
102///     Note left of Entry: Failure
103///     end
104/// ```
105#[derive(Debug, Clone, PartialEq, Eq, strum::EnumDiscriminants)]
106#[strum_discriminants(vis(pub))]
107#[strum_discriminants(derive(strum::FromRepr, strum::EnumCount), repr(u8))]
108pub enum StartProtocol<I, T, C> {
109    /// Request to initiate a new session.
110    StartSession(StartInitiation<T, C>),
111    /// Confirmation that a new session has been established by the counterparty.
112    SessionEstablished(StartEstablished<I>),
113    /// Counterparty could not establish a new session due to an error.
114    SessionError(StartErrorType),
115    /// A ping message to keep the session alive.
116    KeepAlive(KeepAliveMessage<I>),
117}
118
/// Keep-alive message for a Session identified by a value of type `I`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct KeepAliveMessage<I> {
    /// Identifier of the Session being kept alive.
    pub session_id: I,
    /// Reserved for future use, currently always zero.
    pub flags: u8,
    /// Additional data (possibly dependent on `flags`), ignored if `0x00000000`.
    pub additional_data: u64,
}
129
130impl<I> KeepAliveMessage<I> {
131    /// The minimum number of SURBs a [`KeepAliveMessage`] must be able to carry.
132    pub const MIN_SURBS_PER_MESSAGE: usize = HoprPacket::MAX_SURBS_IN_PACKET;
133}
134
135impl<I> From<I> for KeepAliveMessage<I> {
136    fn from(value: I) -> Self {
137        Self {
138            session_id: value,
139            flags: 0,
140            additional_data: 0,
141        }
142    }
143}
144
145impl<I, T, C> StartProtocol<I, T, C> {
146    /// Fixed [`Tag`] of every protocol message.
147    pub const START_PROTOCOL_MESSAGE_TAG: Tag = Tag::Reserved(ReservedTag::SessionStart as u64);
148    /// Current version of the Start protocol.
149    pub const START_PROTOCOL_VERSION: u8 = 0x02;
150}
151
152impl<I, T, C> StartProtocol<I, T, C>
153where
154    I: serde::Serialize + for<'de> serde::Deserialize<'de>,
155    T: serde::Serialize + for<'de> serde::Deserialize<'de>,
156    C: Into<u8> + TryFrom<u8>,
157{
158    /// Tries to encode the message into binary format and [`Tag`]
159    pub fn encode(self) -> errors::Result<(Tag, Box<[u8]>)> {
160        let mut out = Vec::with_capacity(ApplicationData::PAYLOAD_SIZE);
161        out.push(Self::START_PROTOCOL_VERSION);
162        out.push(StartProtocolDiscriminants::from(&self) as u8);
163
164        let mut data = Vec::with_capacity(ApplicationData::PAYLOAD_SIZE - 2);
165        match self {
166            StartProtocol::StartSession(init) => {
167                data.extend_from_slice(&init.challenge.to_be_bytes());
168                data.push(init.capabilities.into());
169                data.extend_from_slice(&init.additional_data.to_be_bytes());
170                let target = serde_cbor_2::to_vec(&init.target)?;
171                data.extend_from_slice(&target);
172            }
173            StartProtocol::SessionEstablished(est) => {
174                data.extend_from_slice(&est.orig_challenge.to_be_bytes());
175                let session_id = serde_cbor_2::to_vec(&est.session_id)?;
176                data.extend(session_id);
177            }
178            StartProtocol::SessionError(err) => {
179                data.extend_from_slice(&err.challenge.to_be_bytes());
180                data.push(err.reason as u8);
181            }
182            StartProtocol::KeepAlive(ping) => {
183                data.push(ping.flags);
184                data.extend_from_slice(&ping.additional_data.to_be_bytes());
185                let session_id = serde_cbor_2::to_vec(&ping.session_id)?;
186                data.extend(session_id);
187            }
188        }
189
190        out.extend_from_slice(&(data.len() as u16).to_be_bytes());
191        out.extend(data);
192
193        Ok((Self::START_PROTOCOL_MESSAGE_TAG, out.into_boxed_slice()))
194    }
195
196    /// Tries to decode the message from the binary representation and [`Tag`].
197    ///
198    /// The `tag` must be currently [`START_PROTOCOL_MESSAGE_TAG`](Self::START_PROTOCOL_MESSAGE_TAG)
199    /// and version [`START_PROTOCOL_VERSION`](Self::START_PROTOCOL_VERSION).
200    pub fn decode(tag: Tag, data: &[u8]) -> errors::Result<Self> {
201        if tag != Self::START_PROTOCOL_MESSAGE_TAG {
202            return Err(StartProtocolError::UnknownTag);
203        }
204
205        if data.len() < 5 {
206            return Err(StartProtocolError::InvalidLength);
207        }
208
209        if data[0] != Self::START_PROTOCOL_VERSION {
210            return Err(StartProtocolError::InvalidVersion);
211        }
212
213        let disc = data[1];
214        let len = u16::from_be_bytes(
215            data[2..4]
216                .try_into()
217                .map_err(|_| StartProtocolError::ParseError("len".into()))?,
218        ) as usize;
219        let data_offset = 2 + size_of::<u16>();
220
221        if data.len() < data_offset + len {
222            return Err(StartProtocolError::InvalidLength);
223        }
224
225        Ok(
226            match StartProtocolDiscriminants::from_repr(disc).ok_or(StartProtocolError::UnknownMessage)? {
227                StartProtocolDiscriminants::StartSession => {
228                    if data.len() <= data_offset + size_of::<StartChallenge>() + 1 + size_of::<u32>() {
229                        return Err(StartProtocolError::InvalidLength);
230                    }
231
232                    StartProtocol::StartSession(StartInitiation {
233                        challenge: StartChallenge::from_be_bytes(
234                            data[data_offset..data_offset + size_of::<StartChallenge>()]
235                                .try_into()
236                                .map_err(|_| StartProtocolError::ParseError("init.challenge".into()))?,
237                        ),
238                        capabilities: data[data_offset + size_of::<StartChallenge>()]
239                            .try_into()
240                            .map_err(|_| StartProtocolError::ParseError("init.capabilities".into()))?,
241                        additional_data: u32::from_be_bytes(
242                            data[data_offset + size_of::<StartChallenge>() + 1
243                                ..data_offset + size_of::<StartChallenge>() + 1 + size_of::<u32>()]
244                                .try_into()
245                                .map_err(|_| StartProtocolError::ParseError("init.additional_data".into()))?,
246                        ),
247                        target: serde_cbor_2::from_slice(
248                            &data[data_offset + size_of::<StartChallenge>() + 1 + size_of::<u32>()..],
249                        )?,
250                    })
251                }
252                StartProtocolDiscriminants::SessionEstablished => {
253                    if data.len() <= data_offset + size_of::<StartChallenge>() {
254                        return Err(StartProtocolError::InvalidLength);
255                    }
256                    StartProtocol::SessionEstablished(StartEstablished {
257                        orig_challenge: StartChallenge::from_be_bytes(
258                            data[data_offset..data_offset + size_of::<StartChallenge>()]
259                                .try_into()
260                                .map_err(|_| StartProtocolError::ParseError("est.challenge".into()))?,
261                        ),
262                        session_id: serde_cbor_2::from_slice(&data[data_offset + size_of::<StartChallenge>()..])?,
263                    })
264                }
265                StartProtocolDiscriminants::SessionError => {
266                    if data.len() < data_offset + size_of::<StartChallenge>() + 1 {
267                        return Err(StartProtocolError::InvalidLength);
268                    }
269                    StartProtocol::SessionError(StartErrorType {
270                        challenge: StartChallenge::from_be_bytes(
271                            data[data_offset..data_offset + size_of::<StartChallenge>()]
272                                .try_into()
273                                .map_err(|_| StartProtocolError::ParseError("err.challenge".into()))?,
274                        ),
275                        reason: StartErrorReason::from_repr(data[data_offset + size_of::<StartChallenge>()])
276                            .ok_or(StartProtocolError::ParseError("err.reason".into()))?,
277                    })
278                }
279                StartProtocolDiscriminants::KeepAlive => {
280                    if data.len() <= data_offset + size_of::<u32>() {
281                        return Err(StartProtocolError::InvalidLength);
282                    }
283
284                    StartProtocol::KeepAlive(KeepAliveMessage {
285                        flags: data[data_offset],
286                        additional_data: u64::from_be_bytes(
287                            data[data_offset + 1..data_offset + 1 + size_of::<u64>()]
288                                .try_into()
289                                .map_err(|_| StartProtocolError::ParseError("ka.additional_data".into()))?,
290                        ),
291                        session_id: serde_cbor_2::from_slice(&data[data_offset + 1 + size_of::<u64>()..])?,
292                    })
293                }
294            },
295        )
296    }
297}
298
299impl<I, T, C> TryFrom<StartProtocol<I, T, C>> for ApplicationData
300where
301    I: serde::Serialize + for<'de> serde::Deserialize<'de>,
302    T: serde::Serialize + for<'de> serde::Deserialize<'de>,
303    C: Into<u8> + TryFrom<u8>,
304{
305    type Error = StartProtocolError;
306
307    fn try_from(value: StartProtocol<I, T, C>) -> Result<Self, Self::Error> {
308        let (application_tag, plain_text) = value.encode()?;
309        Ok(ApplicationData::new(application_tag, plain_text.into_vec())?)
310    }
311}
312
313impl<I, T, C> TryFrom<ApplicationData> for StartProtocol<I, T, C>
314where
315    I: serde::Serialize + for<'de> serde::Deserialize<'de>,
316    T: serde::Serialize + for<'de> serde::Deserialize<'de>,
317    C: Into<u8> + TryFrom<u8>,
318{
319    type Error = StartProtocolError;
320
321    fn try_from(value: ApplicationData) -> Result<Self, Self::Error> {
322        Self::decode(value.application_tag, &value.plain_text)
323    }
324}
325
#[cfg(test)]
mod tests {
    use hopr_crypto_packet::prelude::HoprPacket;
    use hopr_protocol_app::prelude::Tag;

    use super::*;

    /// Start protocol instantiation with `i32` session IDs, `String` targets and `u8` capabilities.
    type IntIdProtocol = StartProtocol<i32, String, u8>;

    /// Start protocol instantiation with `String` session IDs, `String` targets and `u8` capabilities.
    type StringIdProtocol = StartProtocol<String, String, u8>;

    /// Session ID used to exercise messages with a long identifier.
    const LONG_SESSION_ID: &str = "example-of-a-very-very-long-session-id-that-should-still-fit-the-packet";

    /// Encodes the message, checks the tag, decodes the bytes again
    /// and checks that the decoded message equals the input.
    fn assert_encode_decode_roundtrip(original: IntIdProtocol) -> anyhow::Result<()> {
        let (tag, bytes) = original.clone().encode()?;
        let expected: Tag = StartProtocol::<(), (), ()>::START_PROTOCOL_MESSAGE_TAG;
        assert_eq!(tag, expected);

        let decoded = IntIdProtocol::decode(tag, &bytes)?;

        assert_eq!(original, decoded);
        Ok(())
    }

    /// Returns the number of bytes the message occupies once encoded.
    fn encoded_len<I, T, C>(message: StartProtocol<I, T, C>) -> anyhow::Result<usize>
    where
        I: serde::Serialize + for<'de> serde::Deserialize<'de>,
        T: serde::Serialize + for<'de> serde::Deserialize<'de>,
        C: Into<u8> + TryFrom<u8>,
    {
        Ok(message.encode()?.1.len())
    }

    #[test]
    fn start_protocol_start_session_message_should_encode_and_decode() -> anyhow::Result<()> {
        assert_encode_decode_roundtrip(StartProtocol::StartSession(StartInitiation {
            challenge: 0,
            target: "127.0.0.1:1234".to_string(),
            capabilities: Default::default(),
            additional_data: 0x12345678,
        }))
    }

    #[test]
    fn start_protocol_message_start_session_message_should_allow_for_at_least_one_surb() -> anyhow::Result<()> {
        let len = encoded_len(IntIdProtocol::StartSession(StartInitiation {
            challenge: 0,
            target: "127.0.0.1:1234".to_string(),
            capabilities: 0xff,
            additional_data: 0xffffffff,
        }))?;

        assert!(
            HoprPacket::max_surbs_with_message(len) >= 1,
            "StartSession message size ({len}) must allow for at least 1 SURBs in packet",
        );

        Ok(())
    }

    #[test]
    fn start_protocol_session_established_message_should_encode_and_decode() -> anyhow::Result<()> {
        assert_encode_decode_roundtrip(StartProtocol::SessionEstablished(StartEstablished {
            orig_challenge: 0,
            session_id: 10_i32,
        }))
    }

    #[test]
    fn start_protocol_session_error_message_should_encode_and_decode() -> anyhow::Result<()> {
        assert_encode_decode_roundtrip(StartProtocol::SessionError(StartErrorType {
            challenge: 10,
            reason: StartErrorReason::NoSlotsAvailable,
        }))
    }

    #[test]
    fn start_protocol_keep_alive_message_should_encode_and_decode() -> anyhow::Result<()> {
        assert_encode_decode_roundtrip(StartProtocol::KeepAlive(KeepAliveMessage {
            session_id: 10_i32,
            flags: 0,
            additional_data: 0xffffffff,
        }))
    }

    #[test]
    fn start_protocol_messages_must_fit_within_hopr_packet() -> anyhow::Result<()> {
        let start_session_len = encoded_len(IntIdProtocol::StartSession(StartInitiation {
            challenge: StartChallenge::MAX,
            target: "example-of-a-very-very-long-second-level-name.on-a-very-very-long-domain-name.info:65530"
                .to_string(),
            capabilities: 0x80,
            additional_data: 0xffffffff,
        }))?;
        assert!(
            start_session_len <= HoprPacket::PAYLOAD_SIZE,
            "StartSession must fit within {}",
            HoprPacket::PAYLOAD_SIZE
        );

        let session_established_len = encoded_len(StringIdProtocol::SessionEstablished(StartEstablished {
            orig_challenge: StartChallenge::MAX,
            session_id: LONG_SESSION_ID.to_string(),
        }))?;
        assert!(
            session_established_len <= HoprPacket::PAYLOAD_SIZE,
            "SessionEstablished must fit within {}",
            HoprPacket::PAYLOAD_SIZE
        );

        let session_error_len = encoded_len(StringIdProtocol::SessionError(StartErrorType {
            challenge: StartChallenge::MAX,
            reason: StartErrorReason::NoSlotsAvailable,
        }))?;
        assert!(
            session_error_len <= HoprPacket::PAYLOAD_SIZE,
            "SessionError must fit within {}",
            HoprPacket::PAYLOAD_SIZE
        );

        let keep_alive_len = encoded_len(StringIdProtocol::KeepAlive(KeepAliveMessage {
            session_id: LONG_SESSION_ID.to_string(),
            flags: 0xff,
            additional_data: 0xffffffff,
        }))?;
        assert!(
            keep_alive_len <= HoprPacket::PAYLOAD_SIZE,
            "KeepAlive must fit within {}",
            HoprPacket::PAYLOAD_SIZE
        );

        Ok(())
    }

    #[test]
    fn start_protocol_message_keep_alive_message_should_allow_for_maximum_surbs() -> anyhow::Result<()> {
        let len = encoded_len(StringIdProtocol::KeepAlive(KeepAliveMessage {
            session_id: LONG_SESSION_ID.to_string(),
            flags: 0xff,
            additional_data: 0xffffffff,
        }))?;

        assert_eq!(
            KeepAliveMessage::<String>::MIN_SURBS_PER_MESSAGE,
            HoprPacket::MAX_SURBS_IN_PACKET
        );
        assert!(
            HoprPacket::max_surbs_with_message(len) >= KeepAliveMessage::<String>::MIN_SURBS_PER_MESSAGE,
            "KeepAlive message size ({}) must allow for at least {} SURBs in packet",
            len,
            KeepAliveMessage::<String>::MIN_SURBS_PER_MESSAGE
        );

        Ok(())
    }
}