hopr_transport_session/
initiation.rs

1//! This module defines the Start sub-protocol used for HOPR Session initiation and management.
2
3use hopr_transport_packet::prelude::{ApplicationData, ReservedTag, Tag};
4
5use crate::{Capabilities, errors::TransportSessionError, types::SessionTarget};
6
/// Challenge that identifies a Start initiation protocol message.
///
/// The initiator sends it in [`StartInitiation::challenge`]. The recipient refers back to it in
/// [`StartEstablished::orig_challenge`] or [`StartErrorType::challenge`], so that replies can be
/// matched with the initiation they answer.
pub type StartChallenge = u64;
9
/// Lists all Start protocol error reasons.
///
/// Carried in [`StartErrorType::reason`] of a [`StartProtocol::SessionError`] message.
// NOTE: when serialized with serde/bincode, variants are identified by their position,
// so new reasons should be appended at the end to keep the encoded form stable.
#[repr(u8)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, strum::Display)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum StartErrorReason {
    /// No more slots are available at the recipient.
    NoSlotsAvailable,
    /// Recipient is busy.
    Busy,
}
20
/// Error message in the Start protocol.
///
/// Sent back by the recipient (as [`StartProtocol::SessionError`]) when it cannot accept the
/// session requested by a [`StartInitiation`].
// Field order is part of the serialized form; do not reorder.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct StartErrorType {
    /// Challenge that relates to this error.
    pub challenge: StartChallenge,
    /// The [reason](StartErrorReason) of this error.
    pub reason: StartErrorReason,
}
30
/// The session initiation message of the Start protocol.
///
/// Carried by [`StartProtocol::StartSession`]. It is answered with either
/// [`StartProtocol::SessionEstablished`] or [`StartProtocol::SessionError`].
// Field order is part of the serialized form; do not reorder.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct StartInitiation {
    /// Random challenge for this initiation.
    pub challenge: StartChallenge,
    /// [Target](SessionTarget) of the session, i.e., what should the other party do with the traffic.
    pub target: SessionTarget,
    /// Capabilities of the session.
    pub capabilities: Capabilities,
}
42
/// Message of the Start protocol that confirms the establishment of a session.
///
/// Carried by [`StartProtocol::SessionEstablished`]; `T` is the session identifier type.
// Field order is part of the serialized form; do not reorder.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct StartEstablished<T> {
    /// Challenge that was used in the [initiation message](StartInitiation) to establish correspondence.
    pub orig_challenge: StartChallenge,
    /// Session ID that was selected by the recipient.
    pub session_id: T,
}
52
#[cfg_attr(doc, aquamarine::aquamarine)]
/// Lists all messages of the Start protocol for a session establishment
/// with `T` as session identifier.
///
/// Messages are converted to and from their wire form only through [`StartProtocol::encode`] and
/// [`StartProtocol::decode`] (also used by the `TryFrom` conversions with `ApplicationData`).
/// The declaration order of the variants determines the discriminant byte written on the wire,
/// so variants must not be reordered and new ones should be appended at the end.
///
/// In the diagram below, `SessionInitiation` corresponds to [`StartProtocol::StartSession`],
/// `SessionEstablished` to [`StartProtocol::SessionEstablished`], `SessionError` to
/// [`StartProtocol::SessionError`], `KeepAlive` to [`StartProtocol::KeepAlive`] and
/// `Close Session` to [`StartProtocol::CloseSession`].
///
/// # Diagram of the protocol
/// ```mermaid
/// sequenceDiagram
///     Entry->>Exit: SessionInitiation (Challenge)
///     alt If Exit can accept a new session
///     Note right of Exit: SessionID [Pseudonym, Tag]
///     Exit->>Entry: SessionEstablished (Challenge, SessionID_Entry)
///     Note left of Entry: SessionID [Pseudonym, Tag]
///     Entry->>Exit: KeepAlive (SessionID)
///     Note over Entry,Exit: Data
///     Entry->>Exit: Close Session (SessionID)
///     Exit->>Entry: Close Session (SessionID)
///     else If Exit cannot accept a new session
///     Exit->>Entry: SessionError (Challenge, Reason)
///     end
///     opt If initiation attempt times out
///     Note left of Entry: Failure
///     end
/// ```
// Do not implement Serialize,Deserialize -> enforce serialization via encode/decode
#[derive(Debug, Clone, PartialEq, Eq, strum::EnumDiscriminants)]
#[strum_discriminants(vis(pub(crate)))]
#[strum_discriminants(derive(strum::FromRepr, strum::EnumCount), repr(u8))]
pub enum StartProtocol<T> {
    /// Request to initiate a new session.
    StartSession(StartInitiation),
    /// Confirmation that a new session has been established by the counterparty.
    SessionEstablished(StartEstablished<T>),
    /// Counterparty could not establish a new session due to an error.
    SessionError(StartErrorType),
    /// Counterparty has closed the session.
    CloseSession(T),
    /// A ping message to keep the session alive.
    KeepAlive(KeepAliveMessage<T>),
}
92
/// Payload of the [`StartProtocol::KeepAlive`] message.
///
/// Can be created from a bare session ID via `From<T>`, which sets `flags` to zero.
// Field order is part of the serialized form; do not reorder.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct KeepAliveMessage<T> {
    /// Session ID.
    pub id: T,
    /// Reserved for future use, always zero currently.
    pub flags: u8,
}
101
102impl<T> From<T> for KeepAliveMessage<T> {
103    fn from(value: T) -> Self {
104        Self { id: value, flags: 0 }
105    }
106}
107
108impl<T> StartProtocol<T> {
109    pub(crate) const START_PROTOCOL_MESSAGE_TAG: Tag = Tag::Reserved(ReservedTag::SessionStart as u64);
110    const START_PROTOCOL_VERSION: u8 = 0x01;
111}
112
// TODO: implement this without Serde, see #7145
#[cfg(feature = "serde")]
impl<T: serde::Serialize + for<'de> serde::Deserialize<'de>> StartProtocol<T> {
    /// `bincode` configuration shared by [`Self::encode`] and [`Self::decode`].
    const SESSION_BINCODE_CONFIGURATION: bincode::config::Configuration = bincode::config::standard()
        .with_little_endian()
        .with_variable_int_encoding();

    /// Serialize the message into a message tag and message data.
    ///
    /// The data consists of the protocol version byte, the message discriminant byte and the
    /// `bincode`-serialized message body.
    ///
    /// # Errors
    /// Returns an error if `bincode` fails to serialize the message body.
    pub fn encode(self) -> crate::errors::Result<(Tag, Box<[u8]>)> {
        let mut out = Vec::with_capacity(ApplicationData::PAYLOAD_SIZE);
        out.push(Self::START_PROTOCOL_VERSION);
        out.push(StartProtocolDiscriminants::from(&self) as u8);

        match &self {
            StartProtocol::StartSession(init) => Self::encode_body(init, &mut out),
            StartProtocol::SessionEstablished(est) => Self::encode_body(est, &mut out),
            StartProtocol::SessionError(err) => Self::encode_body(err, &mut out),
            StartProtocol::CloseSession(id) => Self::encode_body(id, &mut out),
            StartProtocol::KeepAlive(msg) => Self::encode_body(msg, &mut out),
        }?;

        Ok((Self::START_PROTOCOL_MESSAGE_TAG, out.into_boxed_slice()))
    }

    /// Deserialize the message from message tag and message data.
    ///
    /// The data is expected in the format produced by [`Self::encode`]. Bytes following the
    /// `bincode`-encoded message body are ignored.
    ///
    /// # Errors
    /// Returns [`TransportSessionError::StartProtocolError`] in these cases:
    /// - `tag` is not the Start protocol tag;
    /// - `data` is shorter than 3 bytes;
    /// - the version byte is unsupported;
    /// - the discriminant byte is unknown.
    ///
    /// If the message body cannot be deserialized, the `bincode` error converted into
    /// [`TransportSessionError`] is returned.
    pub fn decode(tag: Tag, data: &[u8]) -> crate::errors::Result<Self> {
        if tag != Self::START_PROTOCOL_MESSAGE_TAG {
            return Err(TransportSessionError::StartProtocolError("unknown message tag".into()));
        }

        // Header (version + discriminant) must be followed by a non-empty body.
        let (version, discriminant, body) = match data {
            [version, discriminant, body @ ..] if !body.is_empty() => (*version, *discriminant, body),
            _ => return Err(TransportSessionError::StartProtocolError("message too short".into())),
        };

        if version != Self::START_PROTOCOL_VERSION {
            return Err(TransportSessionError::StartProtocolError(
                "unknown message version".into(),
            ));
        }

        // Lazily build the error so the happy path does not allocate the message.
        let kind = StartProtocolDiscriminants::from_repr(discriminant)
            .ok_or_else(|| TransportSessionError::StartProtocolError("unknown message".into()))?;

        Ok(match kind {
            StartProtocolDiscriminants::StartSession => StartProtocol::StartSession(Self::decode_body(body)?),
            StartProtocolDiscriminants::SessionEstablished => {
                StartProtocol::SessionEstablished(Self::decode_body(body)?)
            }
            StartProtocolDiscriminants::SessionError => StartProtocol::SessionError(Self::decode_body(body)?),
            StartProtocolDiscriminants::CloseSession => StartProtocol::CloseSession(Self::decode_body(body)?),
            StartProtocolDiscriminants::KeepAlive => StartProtocol::KeepAlive(Self::decode_body(body)?),
        })
    }

    /// Serializes `body` with `bincode` and appends the bytes to `out`.
    fn encode_body<M: serde::Serialize>(body: &M, out: &mut Vec<u8>) -> crate::errors::Result<()> {
        bincode::serde::encode_into_std_write(body, out, Self::SESSION_BINCODE_CONFIGURATION)?;
        Ok(())
    }

    /// Deserializes one `bincode`-encoded value from the start of `body`, ignoring trailing bytes.
    fn decode_body<M: for<'de> serde::Deserialize<'de>>(body: &[u8]) -> crate::errors::Result<M> {
        let (value, _consumed) =
            bincode::serde::borrow_decode_from_slice(body, Self::SESSION_BINCODE_CONFIGURATION)?;
        Ok(value)
    }
}
191
192#[cfg(not(feature = "serde"))]
193impl<T> StartProtocol<T> {
194    pub fn encode(self) -> crate::errors::Result<(u16, Box<[u8]>)> {
195        unimplemented!()
196    }
197
198    pub fn decode(_tag: u16, _data: &[u8]) -> crate::errors::Result<Self> {
199        unimplemented!()
200    }
201}
202
#[cfg(feature = "serde")]
impl<T: serde::Serialize + for<'de> serde::Deserialize<'de>> TryFrom<StartProtocol<T>> for ApplicationData {
    type Error = TransportSessionError;

    /// Encodes the Start protocol message and wraps the resulting tag and bytes as application data.
    fn try_from(value: StartProtocol<T>) -> Result<Self, Self::Error> {
        value
            .encode()
            .map(|(application_tag, plain_text)| ApplicationData {
                application_tag,
                plain_text,
            })
    }
}
215
216#[cfg(not(feature = "serde"))]
217impl<T> TryFrom<StartProtocol<T>> for ApplicationData {
218    type Error = TransportSessionError;
219
220    fn try_from(value: StartProtocol<T>) -> Result<Self, Self::Error> {
221        let (application_tag, plain_text) = value.encode()?;
222        Ok(ApplicationData {
223            application_tag,
224            plain_text,
225        })
226    }
227}
228
#[cfg(feature = "serde")]
impl<T: serde::Serialize + for<'de> serde::Deserialize<'de>> TryFrom<ApplicationData> for StartProtocol<T> {
    type Error = TransportSessionError;

    /// Decodes a Start protocol message from the tag and payload of the application data.
    fn try_from(value: ApplicationData) -> Result<Self, Self::Error> {
        let ApplicationData {
            application_tag,
            plain_text,
        } = value;
        Self::decode(application_tag, &plain_text)
    }
}
237
238#[cfg(not(feature = "serde"))]
239impl<T> TryFrom<ApplicationData> for StartProtocol<T> {
240    type Error = TransportSessionError;
241
242    fn try_from(value: ApplicationData) -> Result<Self, Self::Error> {
243        Self::decode(value.application_tag, &value.plain_text)
244    }
245}
246
// Every test here calls `encode`/`decode`, which only work with the `serde` feature,
// so the whole module is gated on it.
#[cfg(all(test, feature = "serde"))]
mod tests {
    use hopr_crypto_packet::prelude::HoprPacket;
    use hopr_crypto_random::Randomizable;
    use hopr_internal_types::prelude::HoprPseudonym;
    use hopr_network_types::prelude::SealedHost;
    use hopr_transport_packet::prelude::Tag;

    use super::*;
    use crate::{Capability, SessionId};

    #[test]
    fn start_protocol_start_session_message_should_encode_and_decode() -> anyhow::Result<()> {
        let msg_1 = StartProtocol::<i32>::StartSession(StartInitiation {
            challenge: 0,
            target: SessionTarget::TcpStream(SealedHost::Plain("127.0.0.1:1234".parse()?)),
            capabilities: Default::default(),
        });

        let (tag, msg) = msg_1.clone().encode()?;
        let expected: Tag = StartProtocol::<()>::START_PROTOCOL_MESSAGE_TAG;
        assert_eq!(tag, expected);

        let msg_2 = StartProtocol::<i32>::decode(tag, &msg)?;

        assert_eq!(msg_1, msg_2);
        Ok(())
    }

    #[test]
    fn start_protocol_message_start_session_message_should_allow_for_at_least_one_surb() -> anyhow::Result<()> {
        let msg = StartProtocol::<SessionId>::StartSession(StartInitiation {
            challenge: 0,
            target: SessionTarget::TcpStream(SealedHost::Plain("127.0.0.1:1234".parse()?)),
            capabilities: Default::default(),
        });

        let len = msg.encode()?.1.len();
        assert!(
            HoprPacket::max_surbs_with_message(len) >= 1,
            "StartSession message size ({}) must allow for at least 1 SURB in packet",
            len,
        );

        Ok(())
    }

    #[test]
    fn start_protocol_session_established_message_should_encode_and_decode() -> anyhow::Result<()> {
        let msg_1 = StartProtocol::<i32>::SessionEstablished(StartEstablished {
            orig_challenge: 0,
            session_id: 10,
        });

        let (tag, msg) = msg_1.clone().encode()?;
        let expected: Tag = StartProtocol::<()>::START_PROTOCOL_MESSAGE_TAG;
        assert_eq!(tag, expected);

        let msg_2 = StartProtocol::<i32>::decode(tag, &msg)?;

        assert_eq!(msg_1, msg_2);
        Ok(())
    }

    #[test]
    fn start_protocol_session_error_message_should_encode_and_decode() -> anyhow::Result<()> {
        let msg_1 = StartProtocol::<i32>::SessionError(StartErrorType {
            challenge: 10,
            reason: StartErrorReason::NoSlotsAvailable,
        });

        let (tag, msg) = msg_1.clone().encode()?;
        let expected: Tag = StartProtocol::<()>::START_PROTOCOL_MESSAGE_TAG;
        assert_eq!(tag, expected);

        let msg_2 = StartProtocol::<i32>::decode(tag, &msg)?;

        assert_eq!(msg_1, msg_2);
        Ok(())
    }

    #[test]
    fn start_protocol_close_session_message_should_encode_and_decode() -> anyhow::Result<()> {
        let msg_1 = StartProtocol::<i32>::CloseSession(10);

        let (tag, msg) = msg_1.clone().encode()?;
        let expected: Tag = StartProtocol::<()>::START_PROTOCOL_MESSAGE_TAG;
        assert_eq!(tag, expected);

        let msg_2 = StartProtocol::<i32>::decode(tag, &msg)?;

        assert_eq!(msg_1, msg_2);
        Ok(())
    }

    #[test]
    fn start_protocol_keep_alive_message_should_encode_and_decode() -> anyhow::Result<()> {
        let msg_1 = StartProtocol::<i32>::KeepAlive(10.into());

        let (tag, msg) = msg_1.clone().encode()?;
        let expected: Tag = StartProtocol::<()>::START_PROTOCOL_MESSAGE_TAG;
        assert_eq!(tag, expected);

        let msg_2 = StartProtocol::<i32>::decode(tag, &msg)?;

        assert_eq!(msg_1, msg_2);
        Ok(())
    }

    #[test]
    fn start_protocol_decode_should_reject_malformed_messages() {
        let version = StartProtocol::<i32>::START_PROTOCOL_VERSION;

        // Header only, no message body.
        assert!(StartProtocol::<i32>::decode(StartProtocol::<i32>::START_PROTOCOL_MESSAGE_TAG, &[version, 0]).is_err());

        // Unsupported protocol version.
        assert!(
            StartProtocol::<i32>::decode(
                StartProtocol::<i32>::START_PROTOCOL_MESSAGE_TAG,
                &[version.wrapping_add(1), 0, 0]
            )
            .is_err()
        );

        // Unknown message discriminant.
        assert!(
            StartProtocol::<i32>::decode(StartProtocol::<i32>::START_PROTOCOL_MESSAGE_TAG, &[version, u8::MAX, 0])
                .is_err()
        );
    }

    #[test]
    fn start_protocol_messages_must_fit_within_hopr_packet() -> anyhow::Result<()> {
        let msg = StartProtocol::<SessionId>::StartSession(StartInitiation {
            challenge: StartChallenge::MAX,
            target: SessionTarget::TcpStream(SealedHost::Plain(
                "example-of-a-very-very-long-second-level-name.on-a-very-very-long-domain-name.info:65530".parse()?,
            )),
            capabilities: Capability::RetransmissionAck | Capability::RetransmissionNack | Capability::Segmentation,
        });

        assert!(
            msg.encode()?.1.len() <= HoprPacket::PAYLOAD_SIZE,
            "StartSession must fit within {}",
            HoprPacket::PAYLOAD_SIZE
        );

        let msg = StartProtocol::SessionEstablished(StartEstablished {
            orig_challenge: StartChallenge::MAX,
            session_id: SessionId::new(Tag::MAX, HoprPseudonym::random()),
        });

        assert!(
            msg.encode()?.1.len() <= HoprPacket::PAYLOAD_SIZE,
            "SessionEstablished must fit within {}",
            HoprPacket::PAYLOAD_SIZE
        );

        let msg = StartProtocol::<i32>::SessionError(StartErrorType {
            challenge: StartChallenge::MAX,
            reason: StartErrorReason::NoSlotsAvailable,
        });

        assert!(
            msg.encode()?.1.len() <= HoprPacket::PAYLOAD_SIZE,
            "SessionError must fit within {}",
            HoprPacket::PAYLOAD_SIZE
        );

        let msg = StartProtocol::CloseSession(SessionId::new(Tag::MAX, HoprPseudonym::random()));
        assert!(
            msg.encode()?.1.len() <= HoprPacket::PAYLOAD_SIZE,
            "CloseSession must fit within {}",
            HoprPacket::PAYLOAD_SIZE
        );

        let msg = StartProtocol::KeepAlive(SessionId::new(Tag::MAX, HoprPseudonym::random()).into());
        assert!(
            msg.encode()?.1.len() <= HoprPacket::PAYLOAD_SIZE,
            "KeepAlive must fit within {}",
            HoprPacket::PAYLOAD_SIZE
        );

        Ok(())
    }

    #[test]
    fn start_protocol_message_keep_alive_message_should_allow_for_maximum_surbs() -> anyhow::Result<()> {
        let msg = StartProtocol::KeepAlive(SessionId::new(Tag::MAX, HoprPseudonym::random()).into());
        let len = msg.encode()?.1.len();
        assert!(
            HoprPacket::max_surbs_with_message(len) >= HoprPacket::MAX_SURBS_IN_PACKET,
            "KeepAlive message size ({}) must allow for at least {} SURBs in packet",
            len,
            HoprPacket::MAX_SURBS_IN_PACKET
        );

        Ok(())
    }
}