hopr_transport_session/
types.rs

1use std::{
2    fmt::{Debug, Display, Formatter},
3    hash::{Hash, Hasher},
4    pin::Pin,
5    str::FromStr,
6    task::{Context, Poll},
7    time::Duration,
8};
9
10use futures::{SinkExt, StreamExt, TryStreamExt};
11use hopr_internal_types::prelude::HoprPseudonym;
12use hopr_network_types::{
13    prelude::{DestinationRouting, SealedHost},
14    utils::{AsyncWriteSink, DuplexIO},
15};
16use hopr_primitive_types::{
17    errors::GeneralError,
18    prelude::{BytesRepresentable, ToHex},
19};
20use hopr_protocol_app::prelude::{ApplicationData, ApplicationDataIn, ApplicationDataOut, Tag};
21use hopr_protocol_session::{
22    AcknowledgementMode, AcknowledgementState, AcknowledgementStateConfig, ReliableSocket, SessionSocketConfig,
23    UnreliableSocket,
24};
25use hopr_protocol_start::StartProtocol;
26use tracing::{debug, instrument};
27
28use crate::{Capabilities, Capability, errors::TransportSessionError};
29
30/// Wrapper for [`Capabilities`] that makes conversion to/from `u8` possible.
31#[derive(Clone, Copy, Debug, PartialEq, Eq)]
32pub struct ByteCapabilities(pub Capabilities);
33
34impl TryFrom<u8> for ByteCapabilities {
35    type Error = GeneralError;
36
37    fn try_from(value: u8) -> Result<Self, Self::Error> {
38        Capabilities::new(value)
39            .map(Self)
40            .map_err(|_| GeneralError::ParseError("capabilities".into()))
41    }
42}
43
44impl From<ByteCapabilities> for u8 {
45    fn from(value: ByteCapabilities) -> Self {
46        *value.0.as_ref()
47    }
48}
49
50impl From<ByteCapabilities> for Capabilities {
51    fn from(value: ByteCapabilities) -> Self {
52        value.0
53    }
54}
55
56impl From<Capabilities> for ByteCapabilities {
57    fn from(value: Capabilities) -> Self {
58        Self(value)
59    }
60}
61
62impl AsRef<Capabilities> for ByteCapabilities {
63    fn as_ref(&self) -> &Capabilities {
64        &self.0
65    }
66}
67
68/// Start protocol instantiation for HOPR.
69pub type HoprStartProtocol = StartProtocol<SessionId, SessionTarget, ByteCapabilities>;
70
/// Returns how many decimal digits are needed to print the largest unsigned integer that
/// fits into `n` bytes.
///
/// Evaluates ⌈8n · log₁₀(2)⌉ using integer arithmetic only, so it can run in `const` context.
/// log₁₀(2) ≈ 0.301029995664 is approximated by 0.301030, which is slightly larger than the
/// true value, so the result never underestimates.
const fn max_decimal_digits_for_n_bytes(n: usize) -> usize {
    // log10(2) rounded up to 6 decimal places, expressed in millionths.
    const LOG10_2_MILLIONTHS: u64 = 301030;
    const ONE_MILLION: u64 = 1_000_000;

    let bit_count = 8 * (n as u64);
    let digits_millionths = bit_count * LOG10_2_MILLIONTHS;

    digits_millionths.div_ceil(ONE_MILLION) as usize
}
85
86// Enough to fit HoprPseudonym in hex (with 0x prefix), delimiter and tag number
87const MAX_SESSION_ID_STR_LEN: usize = 2 + 2 * HoprPseudonym::SIZE + 1 + max_decimal_digits_for_n_bytes(Tag::SIZE);
88
89/// Unique ID of a specific Session in a certain direction.
90///
91/// Simple wrapper around the maximum range of the port like session unique identifier.
92/// It is a simple combination of an application tag for the Session and
93/// a [`HoprPseudonym`].
94#[derive(Clone, Copy)]
95pub struct SessionId {
96    tag: Tag,
97    pseudonym: HoprPseudonym,
98    // Since this SessionId is commonly represented as a string,
99    // we cache its string representation here.
100    // Also, by using a statically allocated ArrayString, we allow the SessionId to remain Copy.
101    // This representation is possibly truncated to MAX_SESSION_ID_STR_LEN.
102    // This member is always computed and is therefore not serialized.
103    cached: arrayvec::ArrayString<MAX_SESSION_ID_STR_LEN>,
104}
105
106impl SessionId {
107    const DELIMITER: char = ':';
108
109    pub fn new<T: Into<Tag>>(tag: T, pseudonym: HoprPseudonym) -> Self {
110        let tag = tag.into();
111        let mut cached = format!("{pseudonym}{}{tag}", Self::DELIMITER);
112        cached.truncate(MAX_SESSION_ID_STR_LEN);
113
114        Self {
115            tag,
116            pseudonym,
117            cached: cached.parse().expect("cannot fail due to truncation"),
118        }
119    }
120
121    pub fn tag(&self) -> Tag {
122        self.tag
123    }
124
125    pub fn pseudonym(&self) -> &HoprPseudonym {
126        &self.pseudonym
127    }
128
129    pub fn as_str(&self) -> &str {
130        &self.cached
131    }
132}
133
134impl FromStr for SessionId {
135    type Err = TransportSessionError;
136
137    fn from_str(s: &str) -> Result<Self, Self::Err> {
138        s.split_once(Self::DELIMITER)
139            .ok_or(TransportSessionError::InvalidSessionId)
140            .and_then(
141                |(pseudonym, tag)| match (HoprPseudonym::from_hex(pseudonym), Tag::from_str(tag)) {
142                    (Ok(p), Ok(t)) => Ok(Self::new(t, p)),
143                    _ => Err(TransportSessionError::InvalidSessionId),
144                },
145            )
146    }
147}
148
149impl serde::Serialize for SessionId {
150    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
151    where
152        S: serde::Serializer,
153    {
154        use serde::ser::SerializeStruct;
155        let mut state = serializer.serialize_struct("SessionId", 2)?;
156        state.serialize_field("tag", &self.tag)?;
157        state.serialize_field("pseudonym", &self.pseudonym)?;
158        state.end()
159    }
160}
161
162impl<'de> serde::Deserialize<'de> for SessionId {
163    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
164    where
165        D: serde::Deserializer<'de>,
166    {
167        use serde::de;
168
169        #[derive(serde::Deserialize)]
170        #[serde(field_identifier, rename_all = "lowercase")]
171        enum Field {
172            Tag,
173            Pseudonym,
174        }
175
176        struct SessionIdVisitor;
177
178        impl<'de> de::Visitor<'de> for SessionIdVisitor {
179            type Value = SessionId;
180
181            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
182                formatter.write_str("struct SessionId")
183            }
184
185            fn visit_seq<A>(self, mut seq: A) -> Result<SessionId, A::Error>
186            where
187                A: de::SeqAccess<'de>,
188            {
189                Ok(SessionId::new(
190                    seq.next_element::<Tag>()?
191                        .ok_or_else(|| de::Error::invalid_length(0, &self))?,
192                    seq.next_element()?.ok_or_else(|| de::Error::invalid_length(1, &self))?,
193                ))
194            }
195
196            fn visit_map<V>(self, mut map: V) -> Result<SessionId, V::Error>
197            where
198                V: de::MapAccess<'de>,
199            {
200                let mut tag: Option<Tag> = None;
201                let mut pseudonym: Option<HoprPseudonym> = None;
202                while let Some(key) = map.next_key()? {
203                    match key {
204                        Field::Tag => {
205                            if tag.is_some() {
206                                return Err(de::Error::duplicate_field("tag"));
207                            }
208                            tag = Some(map.next_value()?);
209                        }
210                        Field::Pseudonym => {
211                            if pseudonym.is_some() {
212                                return Err(de::Error::duplicate_field("pseudonym"));
213                            }
214                            pseudonym = Some(map.next_value()?);
215                        }
216                    }
217                }
218
219                Ok(SessionId::new(
220                    tag.ok_or_else(|| de::Error::missing_field("tag"))?,
221                    pseudonym.ok_or_else(|| de::Error::missing_field("pseudonym"))?,
222                ))
223            }
224        }
225
226        const FIELDS: &[&str] = &["tag", "pseudonym"];
227        deserializer.deserialize_struct("SessionId", FIELDS, SessionIdVisitor)
228    }
229}
230
231impl Display for SessionId {
232    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
233        write!(f, "{}", self.as_str())
234    }
235}
236
237impl Debug for SessionId {
238    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
239        write!(f, "{}", self.as_str())
240    }
241}
242
243impl PartialEq for SessionId {
244    fn eq(&self, other: &Self) -> bool {
245        self.tag == other.tag && self.pseudonym == other.pseudonym
246    }
247}
248
249impl Eq for SessionId {}
250
251impl Hash for SessionId {
252    fn hash<H: Hasher>(&self, state: &mut H) {
253        self.tag.hash(state);
254        self.pseudonym.hash(state);
255    }
256}
257
258fn caps_to_ack_mode(caps: Capabilities) -> AcknowledgementMode {
259    if caps.contains(Capability::RetransmissionAck | Capability::RetransmissionNack) {
260        AcknowledgementMode::Both
261    } else if caps.contains(Capability::RetransmissionAck) {
262        AcknowledgementMode::Full
263    } else {
264        AcknowledgementMode::Partial
265    }
266}
267
268/// Indicates the closure reason of a [`HoprSession`].
269#[derive(Clone, Copy, Debug, PartialEq, Eq, strum::Display)]
270pub enum ClosureReason {
271    /// Write-half of the Session has been closed.
272    WriteClosed,
273    /// Read-part of the Session has been closed (encountered empty read).
274    EmptyRead,
275    /// Session has been evicted from the cache due to inactivity or capacity reasons.
276    Eviction,
277}
278
279/// Helper trait to allow Box aliasing
280trait AsyncReadWrite: futures::AsyncWrite + futures::AsyncRead + Send + Unpin {}
281impl<T: futures::AsyncWrite + futures::AsyncRead + Send + Unpin> AsyncReadWrite for T {}
282
/// Identifier of a service hosted directly on the Exit node.
///
/// Such services are addressed via [`SessionTarget::ExitNode`]. They are local to the Exit
/// node and serve special purposes, such as Cover Traffic.
///
/// Unlike host-based targets, these cannot be [sealed](SealedHost) by the Entry node.
pub type ServiceId = u32;
289
290/// Defines what should happen with the data at the recipient where the
291/// data from the established session are supposed to be forwarded to some `target`.
292#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
293pub enum SessionTarget {
294    /// Target is running over UDP with the given IP address and port.
295    UdpStream(SealedHost),
296    /// Target is running over TCP with the given address and port.
297    TcpStream(SealedHost),
298    /// Target is a service directly at the exit node with the given service ID.
299    ExitNode(ServiceId),
300}
301
302/// Wrapper for incoming [`HoprSession`] along with other information
303/// extracted from the Start protocol during the session establishment.
304#[derive(Debug)]
305pub struct IncomingSession {
306    /// Actual incoming session.
307    pub session: HoprSession,
308    /// Desired [target](SessionTarget) of the data received over the session.
309    pub target: SessionTarget,
310}
311
312/// Configures the Session protocol socket over HOPR.
313#[derive(Copy, Clone, Debug, PartialEq, Eq, smart_default::SmartDefault)]
314pub struct HoprSessionConfig {
315    /// Capabilities of the Session protocol socket.
316    ///
317    /// Default is no capabilities.
318    #[default(Capabilities::empty())]
319    pub capabilities: Capabilities,
320    /// Expected frame size of the Session protocol socket.
321    ///
322    /// Default is 1500.
323    #[default(1500)]
324    pub frame_mtu: usize,
325    /// Maximum amount of time an incomplete frame can be kept in the buffer.
326    ///
327    /// Default is 800 ms
328    #[default(Duration::from_millis(800))]
329    pub frame_timeout: Duration,
330}
331
332/// Represents the Session protocol socket over HOPR.
333///
334/// This is essentially a HOPR-specific wrapper for [`ReliableSocket`] and [`UnreliableSocket`]
335/// Session protocol sockets.
336#[pin_project::pin_project]
337pub struct HoprSession {
338    id: SessionId,
339    #[pin]
340    inner: Box<dyn AsyncReadWrite>,
341    routing: DestinationRouting,
342    cfg: HoprSessionConfig,
343    on_close: Option<Box<dyn FnOnce(SessionId, ClosureReason) + Send + Sync>>,
344}
345
346impl HoprSession {
347    /// Creates a new HOPR Session.
348    ///
349    /// It builds an [`futures::io::AsyncRead`] + [`futures::io::AsyncWrite`] transport
350    /// from the given `hopr` interface and passing it to the appropriate [`UnreliableSocket`] or [`ReliableSocket`]
351    /// based on the given `capabilities`.
352    ///
353    /// The `on_close` closure can be optionally called when the Session has been closed via `poll_close`.
354    #[tracing::instrument(skip(hopr, on_close), fields(session_id = %id))]
355    pub fn new<Tx, Rx>(
356        id: SessionId,
357        routing: DestinationRouting,
358        cfg: HoprSessionConfig,
359        hopr: (Tx, Rx),
360        on_close: Option<Box<dyn FnOnce(SessionId, ClosureReason) + Send + Sync>>,
361    ) -> Result<Self, TransportSessionError>
362    where
363        Tx: futures::Sink<(DestinationRouting, ApplicationDataOut)> + Send + Sync + Unpin + 'static,
364        Rx: futures::Stream<Item = ApplicationDataIn> + Send + Sync + Unpin + 'static,
365        Tx::Error: std::error::Error + Send + Sync,
366    {
367        let routing_clone = routing.clone();
368
369        // Wrap the HOPR transport so that it appears as regular transport to the SessionSocket
370        let transport = DuplexIO(
371            AsyncWriteSink::<{ ApplicationData::PAYLOAD_SIZE }, _>(hopr.0.sink_map_err(std::io::Error::other).with(
372                move |buf: Box<[u8]>| {
373                    // The Session protocol does not set any packet info on outgoing packets.
374                    // However, the SessionManager on top usually overrides this.
375                    futures::future::ready(
376                        ApplicationData::new(id.tag(), buf.into_vec())
377                            .map(|data| (routing_clone.clone(), ApplicationDataOut::with_no_packet_info(data)))
378                            .map_err(std::io::Error::other),
379                    )
380                },
381            )),
382            // The Session protocol ignores the packet info on incoming packets.
383            // It is typically SessionManager's job to interpret those.
384            hopr.1
385                .map(|data| Ok::<_, std::io::Error>(data.data.plain_text))
386                .into_async_read(),
387        );
388
389        // Based on the requested capabilities, see if we should use the Session protocol
390        let inner: Box<dyn AsyncReadWrite> = if cfg.capabilities.contains(Capability::Segmentation) {
391            let socket_cfg = SessionSocketConfig {
392                frame_size: cfg.frame_mtu,
393                frame_timeout: cfg.frame_timeout,
394                capacity: 16384,
395                flush_immediately: cfg.capabilities.contains(Capability::NoDelay),
396                ..Default::default()
397            };
398
399            // Need to test the capabilities separately, because any Retransmission capability
400            // implies Segmentation, and therefore `is_disjoint` would fail
401            if cfg.capabilities.contains(Capability::RetransmissionAck)
402                || cfg.capabilities.contains(Capability::RetransmissionNack)
403            {
404                // TODO: update config values
405                let ack_cfg = AcknowledgementStateConfig {
406                    // This is a very coarse assumption, that a single 3-hop packet
407                    // takes on average 200 ms to deliver.
408                    // We can no longer base this timeout on the number of hops because
409                    // it is not known for SURB-based routing.
410                    expected_packet_latency: Duration::from_millis(200),
411                    mode: caps_to_ack_mode(cfg.capabilities),
412                    backoff_base: 0.2,
413                    max_incoming_frame_retries: 1,
414                    max_outgoing_frame_retries: 2,
415                    ..Default::default()
416                };
417
418                debug!(?socket_cfg, ?ack_cfg, "opening new stateful session socket");
419
420                Box::new(ReliableSocket::new(
421                    transport,
422                    AcknowledgementState::<{ ApplicationData::PAYLOAD_SIZE }>::new(id, ack_cfg),
423                    socket_cfg,
424                )?)
425            } else {
426                debug!(?socket_cfg, "opening new stateless session socket");
427
428                Box::new(UnreliableSocket::<{ ApplicationData::PAYLOAD_SIZE }>::new_stateless(
429                    id, transport, socket_cfg,
430                )?)
431            }
432        } else {
433            debug!("opening raw session socket");
434            Box::new(transport)
435        };
436
437        Ok(Self {
438            id,
439            inner,
440            routing,
441            cfg,
442            on_close,
443        })
444    }
445
446    /// ID of this Session.
447    pub fn id(&self) -> &SessionId {
448        &self.id
449    }
450
451    /// Routing options used to deliver data.
452    pub fn routing(&self) -> &DestinationRouting {
453        &self.routing
454    }
455
456    /// Configuration of this Session.
457    pub fn config(&self) -> &HoprSessionConfig {
458        &self.cfg
459    }
460}
461
462impl std::fmt::Debug for HoprSession {
463    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
464        f.debug_struct("Session")
465            .field("id", &self.id)
466            .field("routing", &self.routing)
467            .finish_non_exhaustive()
468    }
469}
470
471impl futures::AsyncRead for HoprSession {
472    #[instrument(name = "Session::poll_read", level = "trace", skip(self, cx, buf), fields(session_id = %self.id), ret)]
473    fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<std::io::Result<usize>> {
474        let this = self.project();
475        let read = futures::ready!(this.inner.poll_read(cx, buf))?;
476        if read == 0 {
477            tracing::trace!("hopr session empty read");
478            // Empty read signals end of the socket, notify if needed
479            if let Some(notifier) = this.on_close.take() {
480                tracing::trace!("notifying read half closure of session");
481                notifier(*this.id, ClosureReason::EmptyRead);
482            }
483        }
484        Poll::Ready(Ok(read))
485    }
486}
487
488impl futures::AsyncWrite for HoprSession {
489    #[instrument(name = "Session::poll_write", level = "trace", skip(self, cx, buf), fields(session_id = %self.id), ret)]
490    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
491        self.project().inner.poll_write(cx, buf)
492    }
493
494    #[instrument(name = "Session::poll_flush", level = "trace", skip(self, cx), fields(session_id = %self.id), ret)]
495    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
496        self.project().inner.poll_flush(cx)
497    }
498
499    #[instrument(name = "Session::poll_close", level = "trace", skip(self, cx), fields(session_id = %self.id), ret)]
500    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
501        let this = self.project();
502        futures::ready!(this.inner.poll_close(cx))?;
503        tracing::trace!("hopr session closed");
504
505        if let Some(notifier) = this.on_close.take() {
506            tracing::trace!("notifying write half closure of session");
507            notifier(*this.id, ClosureReason::WriteClosed);
508        }
509
510        Poll::Ready(Ok(()))
511    }
512}
513
#[cfg(feature = "runtime-tokio")]
impl tokio::io::AsyncRead for HoprSession {
    /// Adapts [`futures::AsyncRead::poll_read`] to Tokio's [`tokio::io::ReadBuf`] API.
    ///
    /// When `buf` has no remaining capacity, this returns immediately without polling the
    /// Session. Polling with the resulting empty slice could produce a 0-byte read, which
    /// the futures `poll_read` implementation may treat as the end of the Session.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        if buf.remaining() == 0 {
            return Poll::Ready(Ok(()));
        }

        let slice = buf.initialize_unfilled();
        let n = std::task::ready!(futures::AsyncRead::poll_read(self.as_mut(), cx, slice))?;
        buf.advance(n);
        Poll::Ready(Ok(()))
    }
}
527
#[cfg(feature = "runtime-tokio")]
impl tokio::io::AsyncWrite for HoprSession {
    /// Delegates to [`futures::AsyncWrite::poll_write`].
    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, std::io::Error>> {
        <Self as futures::AsyncWrite>::poll_write(self, cx, buf)
    }

    /// Delegates to [`futures::AsyncWrite::poll_flush`].
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
        <Self as futures::AsyncWrite>::poll_flush(self, cx)
    }

    /// Maps Tokio's shutdown onto [`futures::AsyncWrite::poll_close`].
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
        <Self as futures::AsyncWrite>::poll_close(self, cx)
    }
}
542
#[cfg(test)]
mod tests {
    use anyhow::Context;
    use futures::{AsyncReadExt, AsyncWriteExt};
    use hopr_crypto_random::Randomizable;
    use hopr_crypto_types::prelude::*;
    use hopr_network_types::prelude::*;
    use hopr_primitive_types::prelude::*;

    use super::*;

    /// Number of bytes sent in each direction by the bidirectional flow tests.
    const DATA_LEN: usize = 5000;

    /// Time limit for each individual write or read step of the bidirectional flow tests.
    const STEP_TIMEOUT: Duration = Duration::from_secs(1);

    #[test]
    fn test_session_id_to_str_from_str() -> anyhow::Result<()> {
        let id = SessionId::new(1234_u64, HoprPseudonym::random());
        assert_eq!(id.as_str(), id.to_string());
        assert_eq!(id, SessionId::from_str(id.as_str())?);

        Ok(())
    }

    #[test]
    fn test_max_decimal_digits_for_n_bytes() {
        assert_eq!(3, max_decimal_digits_for_n_bytes(size_of::<u8>()));
        assert_eq!(5, max_decimal_digits_for_n_bytes(size_of::<u16>()));
        assert_eq!(10, max_decimal_digits_for_n_bytes(size_of::<u32>()));
        assert_eq!(20, max_decimal_digits_for_n_bytes(size_of::<u64>()));
        assert_eq!(39, max_decimal_digits_for_n_bytes(size_of::<u128>()));
    }

    #[test]
    fn standard_session_id_must_fit_within_limit() {
        let id = format!("{}:{}", SimplePseudonym::random(), Tag::Application(Tag::MAX));
        assert!(id.len() <= MAX_SESSION_ID_STR_LEN);
    }

    #[test]
    fn session_id_should_serialize_and_deserialize_correctly() -> anyhow::Result<()> {
        let pseudonym = HoprPseudonym::random();
        let tag: Tag = 1234u64.into();

        let session_id_1 = SessionId::new(tag, pseudonym);
        let data = serde_cbor_2::to_vec(&session_id_1)?;
        let session_id_2: SessionId = serde_cbor_2::from_slice(&data)?;

        assert_eq!(tag, session_id_2.tag());
        assert_eq!(pseudonym, *session_id_2.pseudonym());

        assert_eq!(session_id_1.as_str(), session_id_2.as_str());
        assert_eq!(session_id_1, session_id_2);

        Ok(())
    }

    /// Creates a pair of Sessions (Alice, Bob) that share the same ID and `cfg`.
    ///
    /// The two Sessions are connected to each other through in-memory channels.
    fn create_session_pair(cfg: HoprSessionConfig) -> anyhow::Result<(HoprSession, HoprSession)> {
        let dst: Address = (&ChainKeypair::random()).into();
        let id = SessionId::new(1234_u64, HoprPseudonym::random());

        let (alice_tx, bob_rx) = futures::channel::mpsc::unbounded::<(DestinationRouting, ApplicationDataOut)>();
        let (bob_tx, alice_rx) = futures::channel::mpsc::unbounded::<(DestinationRouting, ApplicationDataOut)>();

        let alice_session = HoprSession::new(
            id,
            DestinationRouting::forward_only(dst, RoutingOptions::Hops(0.try_into()?)),
            cfg,
            (
                alice_tx,
                alice_rx
                    .map(|(_, data)| ApplicationDataIn {
                        data: data.data,
                        packet_info: Default::default(),
                    })
                    .inspect(|d| debug!("alice rcvd: {}", d.data.total_len())),
            ),
            None,
        )?;

        let bob_session = HoprSession::new(
            id,
            DestinationRouting::Return(id.pseudonym().into()),
            cfg,
            (
                bob_tx,
                bob_rx
                    .map(|(_, data)| ApplicationDataIn {
                        data: data.data,
                        packet_info: Default::default(),
                    })
                    .inspect(|d| debug!("bob rcvd: {}", d.data.total_len())),
            ),
            None,
        )?;

        Ok((alice_session, bob_session))
    }

    /// Sends `DATA_LEN` random bytes in each direction.
    ///
    /// Verifies that both sides receive the data intact.
    async fn assert_bidirectional_flow(
        mut alice_session: HoprSession,
        mut bob_session: HoprSession,
    ) -> anyhow::Result<()> {
        let alice_sent = hopr_crypto_random::random_bytes::<DATA_LEN>();
        let bob_sent = hopr_crypto_random::random_bytes::<DATA_LEN>();

        let mut bob_recv = [0u8; DATA_LEN];
        let mut alice_recv = [0u8; DATA_LEN];

        // `timeout(..).await` yields `Result<io::Result<_>, Elapsed>`: the outer error is the
        // timeout, the inner one is the actual I/O failure.
        tokio::time::timeout(STEP_TIMEOUT, alice_session.write_all(&alice_sent))
            .await
            .context("alice write timed out")?
            .context("alice write failed")?;
        alice_session.flush().await?;

        tokio::time::timeout(STEP_TIMEOUT, bob_session.write_all(&bob_sent))
            .await
            .context("bob write timed out")?
            .context("bob write failed")?;
        bob_session.flush().await?;

        tokio::time::timeout(STEP_TIMEOUT, bob_session.read_exact(&mut bob_recv))
            .await
            .context("bob read timed out")?
            .context("bob read failed")?;

        tokio::time::timeout(STEP_TIMEOUT, alice_session.read_exact(&mut alice_recv))
            .await
            .context("alice read timed out")?
            .context("alice read failed")?;

        assert_eq!(alice_sent, bob_recv);
        assert_eq!(bob_sent, alice_recv);

        Ok(())
    }

    #[test_log::test(tokio::test)]
    async fn test_session_bidirectional_flow_without_segmentation() -> anyhow::Result<()> {
        let (alice_session, bob_session) = create_session_pair(Default::default())?;
        assert_bidirectional_flow(alice_session, bob_session).await
    }

    #[test_log::test(tokio::test)]
    async fn test_session_bidirectional_flow_with_segmentation() -> anyhow::Result<()> {
        let (alice_session, bob_session) = create_session_pair(HoprSessionConfig {
            capabilities: Capability::Segmentation.into(),
            ..Default::default()
        })?;
        assert_bidirectional_flow(alice_session, bob_session).await
    }
}