Skip to main content

hopr_transport_session/
types.rs

1use std::{
2    fmt::{Debug, Display, Formatter},
3    hash::{Hash, Hasher},
4    pin::Pin,
5    str::FromStr,
6    task::{Context, Poll},
7    time::Duration,
8};
9
10use futures::{SinkExt, StreamExt, TryStreamExt};
11use hopr_network_types::{
12    prelude::SealedHost,
13    utils::{AsyncWriteSink, DuplexIO},
14};
15use hopr_protocol_app::prelude::{ApplicationData, ApplicationDataIn, ApplicationDataOut, Tag};
16use hopr_protocol_session::{
17    AcknowledgementMode, AcknowledgementState, AcknowledgementStateConfig, ReliableSocket, SessionSocketConfig,
18    UnreliableSocket,
19};
20use hopr_protocol_start::StartProtocol;
21use hopr_types::{
22    internal::{prelude::HoprPseudonym, routing::DestinationRouting},
23    primitive::{
24        errors::GeneralError,
25        prelude::{BytesRepresentable, ToHex},
26    },
27};
28use tracing::{debug, instrument};
29
30use crate::{Capabilities, Capability, errors::TransportSessionError};
31
32/// Wrapper for [`Capabilities`] that makes conversion to/from `u8` possible.
33#[derive(Clone, Copy, Debug, PartialEq, Eq)]
34pub struct ByteCapabilities(pub Capabilities);
35
36impl TryFrom<u8> for ByteCapabilities {
37    type Error = GeneralError;
38
39    fn try_from(value: u8) -> Result<Self, Self::Error> {
40        Capabilities::new(value)
41            .map(Self)
42            .map_err(|_| GeneralError::ParseError("capabilities".into()))
43    }
44}
45
46impl From<ByteCapabilities> for u8 {
47    fn from(value: ByteCapabilities) -> Self {
48        *value.0.as_ref()
49    }
50}
51
52impl From<ByteCapabilities> for Capabilities {
53    fn from(value: ByteCapabilities) -> Self {
54        value.0
55    }
56}
57
58impl From<Capabilities> for ByteCapabilities {
59    fn from(value: Capabilities) -> Self {
60        Self(value)
61    }
62}
63
64impl AsRef<Capabilities> for ByteCapabilities {
65    fn as_ref(&self) -> &Capabilities {
66        &self.0
67    }
68}
69
/// Start protocol instantiation for HOPR.
///
/// Parameterized with [`SessionId`], [`SessionTarget`] and [`ByteCapabilities`]
/// (the `u8`-convertible wrapper around [`Capabilities`]).
pub type HoprStartProtocol = StartProtocol<SessionId, SessionTarget, ByteCapabilities>;
72
/// Returns the maximum number of decimal digits needed to print any unsigned
/// integer that is `n` bytes wide.
///
/// Evaluates ⌈8n · log₁₀(2)⌉ with fixed-point integer arithmetic, so the function
/// can be used in `const` contexts (e.g. to size fixed-capacity string buffers).
const fn max_decimal_digits_for_n_bytes(n: usize) -> usize {
    // Fixed-point scale and log₁₀(2) ≈ 0.301029995664 multiplied by that scale.
    const SCALE: u64 = 1_000_000;
    const LOG10_2_SCALED: u64 = 301030;

    let total_bits = 8 * n as u64;
    let scaled_digits = total_bits * LOG10_2_SCALED;

    u64::div_ceil(scaled_digits, SCALE) as usize
}
87
// Maximum length of the cached `SessionId` string: enough to fit the HoprPseudonym in hex
// with the `0x` prefix (2 + 2 chars per byte), the 1-char delimiter and the tag number
// written in decimal.
const MAX_SESSION_ID_STR_LEN: usize = 2 + 2 * HoprPseudonym::SIZE + 1 + max_decimal_digits_for_n_bytes(Tag::SIZE);
90
/// Unique ID of a specific Session in a certain direction.
///
/// It is the combination of an application [`Tag`] for the Session and a
/// [`HoprPseudonym`]. Its canonical string form is `<pseudonym>:<tag>`; that form
/// is produced by [`Display`] and parsed back by [`FromStr`].
///
/// Equality and hashing only take the tag and the pseudonym into account.
#[derive(Clone, Copy)]
pub struct SessionId {
    tag: Tag,
    pseudonym: HoprPseudonym,
    // Since this SessionId is commonly represented as a string,
    // we cache its string representation here.
    // Also, by using a statically allocated ArrayString, we allow the SessionId to remain Copy.
    // This representation is possibly truncated to MAX_SESSION_ID_STR_LEN.
    // This member is always computed and is therefore not serialized.
    cached: arrayvec::ArrayString<MAX_SESSION_ID_STR_LEN>,
}
107
108impl SessionId {
109    const DELIMITER: char = ':';
110
111    pub fn new<T: Into<Tag>>(tag: T, pseudonym: HoprPseudonym) -> Self {
112        let tag = tag.into();
113        let mut cached = format!("{pseudonym}{}{tag}", Self::DELIMITER);
114        cached.truncate(MAX_SESSION_ID_STR_LEN);
115
116        Self {
117            tag,
118            pseudonym,
119            cached: cached.parse().expect("cannot fail due to truncation"),
120        }
121    }
122
123    pub fn tag(&self) -> Tag {
124        self.tag
125    }
126
127    pub fn pseudonym(&self) -> &HoprPseudonym {
128        &self.pseudonym
129    }
130
131    pub fn as_str(&self) -> &str {
132        &self.cached
133    }
134}
135
136impl FromStr for SessionId {
137    type Err = TransportSessionError;
138
139    fn from_str(s: &str) -> Result<Self, Self::Err> {
140        s.split_once(Self::DELIMITER)
141            .ok_or(TransportSessionError::InvalidSessionId)
142            .and_then(
143                |(pseudonym, tag)| match (HoprPseudonym::from_hex(pseudonym), Tag::from_str(tag)) {
144                    (Ok(p), Ok(t)) => Ok(Self::new(t, p)),
145                    _ => Err(TransportSessionError::InvalidSessionId),
146                },
147            )
148    }
149}
150
151impl serde::Serialize for SessionId {
152    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
153    where
154        S: serde::Serializer,
155    {
156        use serde::ser::SerializeStruct;
157        let mut state = serializer.serialize_struct("SessionId", 2)?;
158        state.serialize_field("tag", &self.tag)?;
159        state.serialize_field("pseudonym", &self.pseudonym)?;
160        state.end()
161    }
162}
163
164impl<'de> serde::Deserialize<'de> for SessionId {
165    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
166    where
167        D: serde::Deserializer<'de>,
168    {
169        use serde::de;
170
171        #[derive(serde::Deserialize)]
172        #[serde(field_identifier, rename_all = "lowercase")]
173        enum Field {
174            Tag,
175            Pseudonym,
176        }
177
178        struct SessionIdVisitor;
179
180        impl<'de> de::Visitor<'de> for SessionIdVisitor {
181            type Value = SessionId;
182
183            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
184                formatter.write_str("struct SessionId")
185            }
186
187            fn visit_seq<A>(self, mut seq: A) -> Result<SessionId, A::Error>
188            where
189                A: de::SeqAccess<'de>,
190            {
191                Ok(SessionId::new(
192                    seq.next_element::<Tag>()?
193                        .ok_or_else(|| de::Error::invalid_length(0, &self))?,
194                    seq.next_element()?.ok_or_else(|| de::Error::invalid_length(1, &self))?,
195                ))
196            }
197
198            fn visit_map<V>(self, mut map: V) -> Result<SessionId, V::Error>
199            where
200                V: de::MapAccess<'de>,
201            {
202                let mut tag: Option<Tag> = None;
203                let mut pseudonym: Option<HoprPseudonym> = None;
204                while let Some(key) = map.next_key()? {
205                    match key {
206                        Field::Tag => {
207                            if tag.is_some() {
208                                return Err(de::Error::duplicate_field("tag"));
209                            }
210                            tag = Some(map.next_value()?);
211                        }
212                        Field::Pseudonym => {
213                            if pseudonym.is_some() {
214                                return Err(de::Error::duplicate_field("pseudonym"));
215                            }
216                            pseudonym = Some(map.next_value()?);
217                        }
218                    }
219                }
220
221                Ok(SessionId::new(
222                    tag.ok_or_else(|| de::Error::missing_field("tag"))?,
223                    pseudonym.ok_or_else(|| de::Error::missing_field("pseudonym"))?,
224                ))
225            }
226        }
227
228        const FIELDS: &[&str] = &["tag", "pseudonym"];
229        deserializer.deserialize_struct("SessionId", FIELDS, SessionIdVisitor)
230    }
231}
232
233impl Display for SessionId {
234    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
235        write!(f, "{}", self.as_str())
236    }
237}
238
239impl Debug for SessionId {
240    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
241        write!(f, "{}", self.as_str())
242    }
243}
244
245impl PartialEq for SessionId {
246    fn eq(&self, other: &Self) -> bool {
247        self.tag == other.tag && self.pseudonym == other.pseudonym
248    }
249}
250
251impl Eq for SessionId {}
252
253impl Hash for SessionId {
254    fn hash<H: Hasher>(&self, state: &mut H) {
255        self.tag.hash(state);
256        self.pseudonym.hash(state);
257    }
258}
259
260pub(crate) fn caps_to_ack_mode(caps: Capabilities) -> AcknowledgementMode {
261    if caps.contains(Capability::RetransmissionAck | Capability::RetransmissionNack) {
262        AcknowledgementMode::Both
263    } else if caps.contains(Capability::RetransmissionAck) {
264        AcknowledgementMode::Full
265    } else {
266        AcknowledgementMode::Partial
267    }
268}
269
/// Indicates the closure reason of a [`HoprSession`].
///
/// Passed to the optional `on_close` notifier given to [`HoprSession::new`].
/// The `Display` implementation is derived via `strum`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, strum::Display)]
pub enum ClosureReason {
    /// Write-half of the Session has been closed (reported after a successful `poll_close`).
    WriteClosed,
    /// Read-part of the Session has been closed (encountered empty read).
    EmptyRead,
    /// Session has been evicted from the cache due to inactivity or capacity reasons.
    Eviction,
}
280
/// Helper trait to allow Box aliasing.
///
/// Bundles [`futures::AsyncRead`] + [`futures::AsyncWrite`] + `Send` + `Unpin` into a single
/// trait, so that the different socket implementations can be stored behind one
/// `Box<dyn AsyncReadWrite>` (see [`HoprSession`]). It is blanket-implemented for every type
/// satisfying those bounds.
trait AsyncReadWrite: futures::AsyncWrite + futures::AsyncRead + Send + Unpin {}
impl<T: futures::AsyncWrite + futures::AsyncRead + Send + Unpin> AsyncReadWrite for T {}
284
/// Identifies a node service target.
///
/// Services are the targets addressed by [`SessionTarget::ExitNode`]: they are local to the
/// Exit node and serve dedicated purposes, such as Cover Traffic.
///
/// These targets cannot be [sealed](SealedHost) from the Entry node.
pub type ServiceId = u32;
291
/// Defines what should happen with the data at the recipient, where the
/// data from the established Session are supposed to be forwarded to some `target`.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum SessionTarget {
    /// Target is running over UDP with the given IP address and port (given as a [`SealedHost`]).
    UdpStream(SealedHost),
    /// Target is running over TCP with the given address and port (given as a [`SealedHost`]).
    TcpStream(SealedHost),
    /// Target is a service directly at the exit node with the given service ID.
    ExitNode(ServiceId),
}
303
/// Wrapper for an incoming [`HoprSession`], along with other information
/// extracted from the Start protocol during the Session establishment.
#[derive(Debug)]
pub struct IncomingSession {
    /// Actual incoming Session.
    pub session: HoprSession,
    /// Desired [target](SessionTarget) of the data received over the Session.
    pub target: SessionTarget,
}
313
/// Configures the Session protocol socket over HOPR.
#[derive(Copy, Clone, Debug, PartialEq, Eq, smart_default::SmartDefault, serde::Serialize)]
pub struct HoprSessionConfig {
    /// Capabilities of the Session protocol socket.
    ///
    /// In [`HoprSession::new`] they select the socket implementation:
    /// - no `Segmentation`: the raw HOPR transport;
    /// - `Segmentation` without retransmission: a stateless socket;
    /// - `Segmentation` plus `RetransmissionAck` or `RetransmissionNack`: a reliable socket.
    ///
    /// Default is no capabilities.
    #[default(Capabilities::empty())]
    pub capabilities: Capabilities,
    /// Expected frame size of the Session protocol socket.
    ///
    /// [`HoprSession::new`] only applies it when `Segmentation` is set.
    ///
    /// Default is 1500.
    #[default(1500)]
    pub frame_mtu: usize,
    /// Maximum amount of time an incomplete frame can be kept in the buffer.
    ///
    /// [`HoprSession::new`] only applies it when `Segmentation` is set.
    /// Serialized through `humantime_serde`.
    ///
    /// Default is 800 ms.
    #[default(Duration::from_millis(800))]
    #[serde(with = "humantime_serde")]
    pub frame_timeout: Duration,
}
334
/// Represents the Session protocol socket over HOPR.
///
/// This is essentially a HOPR-specific wrapper for [`ReliableSocket`] and [`UnreliableSocket`]
/// Session protocol sockets (or for the raw HOPR transport when segmentation is not requested).
/// It implements [`futures::AsyncRead`] and [`futures::AsyncWrite`], and also the tokio
/// counterparts when the `runtime-tokio` feature is enabled.
#[pin_project::pin_project]
pub struct HoprSession {
    /// Identifier of this Session.
    id: SessionId,
    /// Underlying socket (reliable, stateless or raw transport) that all I/O is delegated to.
    #[pin]
    inner: Box<dyn AsyncReadWrite>,
    /// Routing options the Session was created with.
    routing: DestinationRouting,
    /// Configuration the Session was created with.
    cfg: HoprSessionConfig,
    /// Optional closure notifier. It is `take`n when invoked, so it runs at most once.
    on_close: Option<Box<dyn FnOnce(SessionId, ClosureReason) + Send + Sync>>,
}

/// Value of [`SessionSocketConfig::capacity`] used when [`HoprSession::new`] builds a segmenting socket.
pub(crate) const SESSION_SOCKET_CAPACITY: usize = 16384;
350
impl HoprSession {
    /// Creates a new HOPR Session.
    ///
    /// It builds an [`futures::io::AsyncRead`] + [`futures::io::AsyncWrite`] transport
    /// from the given `hopr` interface and passes it to the appropriate [`UnreliableSocket`] or [`ReliableSocket`]
    /// based on the given `capabilities`:
    /// - without [`Capability::Segmentation`], the raw transport is used directly;
    /// - with `Segmentation` and `RetransmissionAck` or `RetransmissionNack`, a [`ReliableSocket`] is used;
    /// - with `Segmentation` only, a stateless [`UnreliableSocket`] is used.
    ///
    /// Each outgoing buffer is wrapped into [`ApplicationData`] carrying this Session's tag and
    /// sent along a clone of `routing`. Incoming packets are unwrapped to their plain-text payload.
    ///
    /// The optional `on_close` closure is invoked at most once: either after the Session has been
    /// closed via `poll_close` ([`ClosureReason::WriteClosed`]) or when a read returns zero bytes
    /// ([`ClosureReason::EmptyRead`]), whichever happens first.
    ///
    /// # Errors
    /// Returns an error if the Session protocol socket cannot be constructed.
    // NOTE(review): with `skip_all`, confirm whether `id`, `routing` and `cfg` are actually
    // recorded on the span or only declared as (empty) fields.
    #[tracing::instrument(skip_all, fields(id, routing, cfg, session_id = %id))]
    pub fn new<Tx, Rx>(
        id: SessionId,
        routing: DestinationRouting,
        cfg: HoprSessionConfig,
        hopr: (Tx, Rx),
        on_close: Option<Box<dyn FnOnce(SessionId, ClosureReason) + Send + Sync>>,
    ) -> Result<Self, TransportSessionError>
    where
        Tx: futures::Sink<(DestinationRouting, ApplicationDataOut)> + Send + Sync + Unpin + 'static,
        Rx: futures::Stream<Item = ApplicationDataIn> + Send + Sync + Unpin + 'static,
        Tx::Error: std::error::Error + Send + Sync,
    {
        // The outgoing closure below needs its own copy; the original is stored in the Session.
        let routing_clone = routing.clone();

        #[cfg(feature = "telemetry")]
        let (session_id_write, session_id_read) = (id, id);

        // Wrap the HOPR transport so that it appears as regular transport to the SessionSocket
        let transport = DuplexIO(
            // Write half: every buffer becomes `ApplicationData` with this Session's tag;
            // sink errors and `ApplicationData::new` failures are mapped to `std::io::Error`.
            AsyncWriteSink::<{ ApplicationData::PAYLOAD_SIZE }, _>(hopr.0.sink_map_err(std::io::Error::other).with(
                move |buf: Box<[u8]>| {
                    #[cfg(feature = "telemetry")]
                    crate::telemetry::record_session_write(&session_id_write, buf.len());
                    // The Session protocol does not set any packet info on outgoing packets.
                    // However, the SessionManager on top usually overrides this.
                    futures::future::ready(
                        ApplicationData::new(id.tag(), buf.into_vec())
                            .map(|data| (routing_clone.clone(), ApplicationDataOut::with_no_packet_info(data)))
                            .map_err(std::io::Error::other),
                    )
                },
            )),
            // The Session protocol ignores the packet info on incoming packets.
            // It is typically SessionManager's job to interpret those.
            hopr.1
                .map(move |data| {
                    #[cfg(feature = "telemetry")]
                    crate::telemetry::record_session_read(&session_id_read, data.data.plain_text.len());
                    Ok::<_, std::io::Error>(data.data.plain_text)
                })
                .into_async_read(),
        );

        // Based on the requested capabilities, see if we should use the Session protocol
        let inner: Box<dyn AsyncReadWrite> = if cfg.capabilities.contains(Capability::Segmentation) {
            let socket_cfg = SessionSocketConfig {
                frame_size: cfg.frame_mtu,
                frame_timeout: cfg.frame_timeout,
                capacity: SESSION_SOCKET_CAPACITY,
                flush_immediately: cfg.capabilities.contains(Capability::NoDelay),
                ..Default::default()
            };

            // Need to test the capabilities separately, because any Retransmission capability
            // implies Segmentation, and therefore `is_disjoint` would fail
            if cfg.capabilities.contains(Capability::RetransmissionAck)
                || cfg.capabilities.contains(Capability::RetransmissionNack)
            {
                // TODO: update config values
                let ack_cfg = AcknowledgementStateConfig {
                    // This is a very coarse assumption, that a single 3-hop packet
                    // takes on average 200 ms to deliver.
                    // We can no longer base this timeout on the number of hops because
                    // it is not known for SURB-based routing.
                    expected_packet_latency: Duration::from_millis(200),
                    mode: caps_to_ack_mode(cfg.capabilities),
                    backoff_base: 0.2,
                    max_incoming_frame_retries: 1,
                    max_outgoing_frame_retries: 2,
                    ..Default::default()
                };

                debug!(?socket_cfg, ?ack_cfg, "opening new stateful session socket");

                Box::new(ReliableSocket::new(
                    transport,
                    AcknowledgementState::<{ ApplicationData::PAYLOAD_SIZE }>::new(id, ack_cfg),
                    socket_cfg,
                    #[cfg(feature = "telemetry")]
                    id,
                )?)
            } else {
                debug!(?socket_cfg, "opening new stateless session socket");

                Box::new(UnreliableSocket::<{ ApplicationData::PAYLOAD_SIZE }>::new_stateless(
                    id,
                    transport,
                    socket_cfg,
                    #[cfg(feature = "telemetry")]
                    id,
                )?)
            }
        } else {
            debug!("opening raw session socket");
            Box::new(transport)
        };

        Ok(Self {
            id,
            inner,
            routing,
            cfg,
            on_close,
        })
    }

    /// ID of this Session.
    pub fn id(&self) -> &SessionId {
        &self.id
    }

    /// Routing options used to deliver data.
    pub fn routing(&self) -> &DestinationRouting {
        &self.routing
    }

    /// Configuration of this Session.
    pub fn config(&self) -> &HoprSessionConfig {
        &self.cfg
    }
}
481
482impl std::fmt::Debug for HoprSession {
483    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
484        f.debug_struct("Session")
485            .field("id", &self.id)
486            .field("routing", &self.routing)
487            .finish_non_exhaustive()
488    }
489}
490
491impl futures::AsyncRead for HoprSession {
492    #[instrument(name = "Session::poll_read", level = "trace", skip_all, fields(session_id = %self.id), ret)]
493    fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<std::io::Result<usize>> {
494        let this = self.project();
495        let read = futures::ready!(this.inner.poll_read(cx, buf))?;
496        if read == 0 {
497            tracing::trace!("hopr session empty read");
498            // Empty read signals end of the socket, notify if needed
499            if let Some(notifier) = this.on_close.take() {
500                tracing::trace!("notifying read half closure of session");
501                notifier(*this.id, ClosureReason::EmptyRead);
502            }
503        }
504        Poll::Ready(Ok(read))
505    }
506}
507
508impl futures::AsyncWrite for HoprSession {
509    #[instrument(name = "Session::poll_write", level = "trace", skip_all, fields(session_id = %self.id), ret)]
510    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
511        self.project().inner.poll_write(cx, buf)
512    }
513
514    #[instrument(name = "Session::poll_flush", level = "trace", skip_all, fields(session_id = %self.id), ret)]
515    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
516        self.project().inner.poll_flush(cx)
517    }
518
519    #[instrument(name = "Session::poll_close", level = "trace", skip_all, fields(session_id = %self.id), ret)]
520    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
521        let this = self.project();
522        futures::ready!(this.inner.poll_close(cx))?;
523        tracing::trace!("hopr session closed");
524
525        #[cfg(feature = "telemetry")]
526        crate::telemetry::set_session_state(this.id, crate::telemetry::SessionLifecycleState::Closing);
527
528        if let Some(notifier) = this.on_close.take() {
529            tracing::trace!("notifying write half closure of session");
530            notifier(*this.id, ClosureReason::WriteClosed);
531        }
532
533        Poll::Ready(Ok(()))
534    }
535}
536
#[cfg(feature = "runtime-tokio")]
impl tokio::io::AsyncRead for HoprSession {
    /// Adapts [`futures::AsyncRead::poll_read`] to tokio's [`tokio::io::ReadBuf`] API.
    ///
    /// When `buf` has no remaining capacity, this returns `Ready(Ok(()))` immediately without
    /// polling the Session, as permitted by tokio's `AsyncRead` contract. Otherwise, the unfilled
    /// part of `buf` is initialized, read into, and `buf` is advanced by the number of bytes read.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        // Nothing to read into: never hand an empty slice down, where a zero-byte
        // result could be mistaken for end-of-stream and close the Session.
        if buf.remaining() == 0 {
            return Poll::Ready(Ok(()));
        }

        let slice = buf.initialize_unfilled();
        let n = std::task::ready!(futures::AsyncRead::poll_read(self.as_mut(), cx, slice))?;
        buf.advance(n);
        Poll::Ready(Ok(()))
    }
}
550
#[cfg(feature = "runtime-tokio")]
impl tokio::io::AsyncWrite for HoprSession {
    /// Delegates to [`futures::AsyncWrite::poll_write`].
    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, std::io::Error>> {
        <Self as futures::AsyncWrite>::poll_write(self, cx, buf)
    }

    /// Delegates to [`futures::AsyncWrite::poll_flush`].
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
        <Self as futures::AsyncWrite>::poll_flush(self, cx)
    }

    /// Maps tokio's shutdown onto [`futures::AsyncWrite::poll_close`].
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
        <Self as futures::AsyncWrite>::poll_close(self, cx)
    }
}
565
#[cfg(test)]
mod tests {
    use anyhow::Context;
    use futures::{AsyncReadExt, AsyncWriteExt};
    use hopr_types::{
        crypto::prelude::*, crypto_random::Randomizable, internal::routing::RoutingOptions, primitive::prelude::*,
    };

    use super::*;

    // --- ByteCapabilities tests ---

    #[test]
    fn byte_capabilities_roundtrip_via_u8() -> anyhow::Result<()> {
        let flags: Capabilities = Capability::Segmentation.into();
        let caps = ByteCapabilities::from(flags);
        let byte_val: u8 = caps.into();
        let restored = ByteCapabilities::try_from(byte_val)?;
        assert_eq!(caps, restored);
        Ok(())
    }

    #[test]
    fn byte_capabilities_invalid_bits_are_rejected() {
        // 0xFF has bits set that don't correspond to any Capability
        assert!(ByteCapabilities::try_from(0xFF_u8).is_err());
    }

    #[test]
    fn byte_capabilities_empty_is_zero() {
        let caps = ByteCapabilities::from(Capabilities::empty());
        let byte_val: u8 = caps.into();
        assert_eq!(byte_val, 0);
    }

    #[test]
    fn byte_capabilities_combined_flags() -> anyhow::Result<()> {
        let caps: Capabilities = Capability::Segmentation | Capability::NoRateControl;
        let byte_caps = ByteCapabilities::from(caps);
        let byte_val: u8 = byte_caps.into();
        let restored = ByteCapabilities::try_from(byte_val)?;
        assert_eq!(*restored.as_ref(), caps);
        Ok(())
    }

    // --- caps_to_ack_mode tests ---

    #[test]
    fn caps_to_ack_mode_both_when_ack_and_nack() {
        let caps: Capabilities = Capability::RetransmissionAck | Capability::RetransmissionNack;
        assert_eq!(caps_to_ack_mode(caps), AcknowledgementMode::Both);
    }

    #[test]
    fn caps_to_ack_mode_full_when_only_ack() {
        let caps: Capabilities = Capability::RetransmissionAck.into();
        assert_eq!(caps_to_ack_mode(caps), AcknowledgementMode::Full);
    }

    #[test]
    fn caps_to_ack_mode_partial_when_no_retransmission() {
        let caps: Capabilities = Capability::Segmentation.into();
        assert_eq!(caps_to_ack_mode(caps), AcknowledgementMode::Partial);
    }

    #[test]
    fn caps_to_ack_mode_partial_when_empty() {
        assert_eq!(caps_to_ack_mode(Capabilities::empty()), AcknowledgementMode::Partial);
    }

    #[test]
    fn caps_to_ack_mode_should_be_partial_when_only_nack() {
        let caps: Capabilities = Capability::RetransmissionNack.into();
        assert_eq!(caps_to_ack_mode(caps), AcknowledgementMode::Partial);
    }

    // --- ClosureReason tests ---

    // Note: this snapshots the `Debug` output; the test name is kept so the
    // existing insta snapshot file still matches.
    #[test]
    fn closure_reason_display_values_are_stable() {
        let reasons = [
            ClosureReason::WriteClosed,
            ClosureReason::EmptyRead,
            ClosureReason::Eviction,
        ];
        insta::assert_debug_snapshot!(reasons);
    }

    // --- HoprSessionConfig tests ---

    #[test]
    fn hopr_session_config_default_snapshot() {
        let cfg = HoprSessionConfig::default();
        insta::assert_yaml_snapshot!(cfg);
    }

    // --- SessionTarget tests ---

    #[test]
    fn session_target_variants_debug_snapshot() -> anyhow::Result<()> {
        let targets: Vec<SessionTarget> = vec![
            SessionTarget::UdpStream(SealedHost::Plain(
                "127.0.0.1:8080".parse().context("parsing UDP target")?,
            )),
            SessionTarget::TcpStream(SealedHost::Plain("10.0.0.1:443".parse().context("parsing TCP target")?)),
            SessionTarget::ExitNode(42),
        ];
        insta::assert_debug_snapshot!(targets);
        Ok(())
    }

    // --- SessionId edge cases ---

    #[test]
    fn session_id_from_str_rejects_missing_delimiter() {
        assert!(SessionId::from_str("nodelimiter").is_err());
    }

    #[test]
    fn session_id_from_str_rejects_invalid_pseudonym() {
        assert!(SessionId::from_str("notahexvalue:1234").is_err());
    }

    #[test]
    fn session_id_from_str_should_reject_invalid_tag() {
        let pseudonym = HoprPseudonym::random();
        let hex = pseudonym.to_hex();
        let bad = format!("{hex}:not_a_number");
        assert!(SessionId::from_str(&bad).is_err());
    }

    #[test]
    fn session_id_display_and_debug_should_be_identical() {
        let id = SessionId::new(42_u64, HoprPseudonym::random());
        assert_eq!(format!("{id}"), format!("{id:?}"));
    }

    #[test]
    fn session_id_hash_eq_consistency() {
        use std::collections::HashSet;
        let pseudonym = HoprPseudonym::random();
        let id1 = SessionId::new(1234_u64, pseudonym);
        let id2 = SessionId::new(1234_u64, pseudonym);
        let id3 = SessionId::new(5678_u64, pseudonym);
        let id4 = SessionId::new(1234_u64, HoprPseudonym::random());

        let mut set = HashSet::new();
        set.insert(id1);
        assert!(set.contains(&id2));
        assert!(!set.contains(&id3), "same pseudonym but different tag should not match");
        assert!(!set.contains(&id4), "same tag but different pseudonym should not match");
    }

    // --- Existing tests ---

    #[test]
    fn test_session_id_to_str_from_str() -> anyhow::Result<()> {
        let id = SessionId::new(1234_u64, HoprPseudonym::random());
        assert_eq!(id.as_str(), id.to_string());
        assert_eq!(id, SessionId::from_str(id.as_str())?);

        Ok(())
    }

    #[test]
    fn test_max_decimal_digits_for_n_bytes() {
        assert_eq!(3, max_decimal_digits_for_n_bytes(size_of::<u8>()));
        assert_eq!(5, max_decimal_digits_for_n_bytes(size_of::<u16>()));
        assert_eq!(10, max_decimal_digits_for_n_bytes(size_of::<u32>()));
        assert_eq!(20, max_decimal_digits_for_n_bytes(size_of::<u64>()));
    }

    #[test]
    fn standard_session_id_must_fit_within_limit() {
        let id = format!("{}:{}", SimplePseudonym::random(), Tag::Application(Tag::MAX));
        assert!(id.len() <= MAX_SESSION_ID_STR_LEN);
    }

    #[test]
    fn session_id_should_serialize_and_deserialize_correctly() -> anyhow::Result<()> {
        let pseudonym = HoprPseudonym::random();
        let tag: Tag = 1234u64.into();

        let session_id_1 = SessionId::new(tag, pseudonym);
        let data = serde_cbor_2::to_vec(&session_id_1)?;
        let session_id_2: SessionId = serde_cbor_2::from_slice(&data)?;

        assert_eq!(tag, session_id_2.tag());
        assert_eq!(pseudonym, *session_id_2.pseudonym());

        assert_eq!(session_id_1.as_str(), session_id_2.as_str());
        assert_eq!(session_id_1, session_id_2);

        Ok(())
    }

    // --- Bidirectional flow tests ---

    /// Turns the raw loopback channel items into the `ApplicationDataIn` items expected
    /// by `HoprSession`, logging each received packet under `name`.
    fn loopback_rx(
        rx: futures::channel::mpsc::UnboundedReceiver<(DestinationRouting, ApplicationDataOut)>,
        name: &'static str,
    ) -> impl futures::Stream<Item = ApplicationDataIn> + Send + Sync + Unpin + 'static {
        rx.map(|(_, data)| ApplicationDataIn {
            data: data.data,
            packet_info: Default::default(),
        })
        .inspect(move |d| debug!("{name} rcvd: {}", d.data.total_len()))
    }

    /// Connects two Sessions (alice and bob) back-to-back over in-memory channels, both
    /// using `cfg`, and checks that random payloads are delivered intact in both directions.
    async fn run_bidirectional_flow(cfg: HoprSessionConfig) -> anyhow::Result<()> {
        const DATA_LEN: usize = 5000;
        const TIMEOUT: Duration = Duration::from_secs(1);

        let dst: Address = (&ChainKeypair::random()).into();
        let id = SessionId::new(1234_u64, HoprPseudonym::random());

        let (alice_tx, bob_rx) = futures::channel::mpsc::unbounded::<(DestinationRouting, ApplicationDataOut)>();
        let (bob_tx, alice_rx) = futures::channel::mpsc::unbounded::<(DestinationRouting, ApplicationDataOut)>();

        let mut alice_session = HoprSession::new(
            id,
            DestinationRouting::forward_only(dst, RoutingOptions::Hops(0.try_into()?)),
            cfg,
            (alice_tx, loopback_rx(alice_rx, "alice")),
            None,
        )?;

        let mut bob_session = HoprSession::new(
            id,
            DestinationRouting::Return(id.pseudonym().into()),
            cfg,
            (bob_tx, loopback_rx(bob_rx, "bob")),
            None,
        )?;

        let alice_sent = hopr_types::crypto_random::random_bytes::<DATA_LEN>();
        let bob_sent = hopr_types::crypto_random::random_bytes::<DATA_LEN>();

        let mut bob_recv = [0u8; DATA_LEN];
        let mut alice_recv = [0u8; DATA_LEN];

        // `timeout(..).await` is `Err(Elapsed)` on timeout and `Ok(io::Result)` otherwise:
        // the first `.context` therefore labels the timeout, the second the I/O failure.
        tokio::time::timeout(TIMEOUT, alice_session.write_all(&alice_sent))
            .await
            .context("alice write timed out")?
            .context("alice write failed")?;
        alice_session.flush().await?;

        tokio::time::timeout(TIMEOUT, bob_session.write_all(&bob_sent))
            .await
            .context("bob write timed out")?
            .context("bob write failed")?;
        bob_session.flush().await?;

        tokio::time::timeout(TIMEOUT, bob_session.read_exact(&mut bob_recv))
            .await
            .context("bob read timed out")?
            .context("bob read failed")?;

        tokio::time::timeout(TIMEOUT, alice_session.read_exact(&mut alice_recv))
            .await
            .context("alice read timed out")?
            .context("alice read failed")?;

        assert_eq!(alice_sent, bob_recv);
        assert_eq!(bob_sent, alice_recv);

        Ok(())
    }

    #[test_log::test(tokio::test)]
    async fn test_session_bidirectional_flow_without_segmentation() -> anyhow::Result<()> {
        run_bidirectional_flow(HoprSessionConfig::default()).await
    }

    #[test_log::test(tokio::test)]
    async fn test_session_bidirectional_flow_with_segmentation() -> anyhow::Result<()> {
        run_bidirectional_flow(HoprSessionConfig {
            capabilities: Capability::Segmentation.into(),
            ..Default::default()
        })
        .await
    }
}