hopr_protocol_session/socket/
mod.rs

1//! This module defines the socket-like interface for Session protocol.
2
3pub mod ack_state;
4pub mod state;
5
6use std::{
7    pin::Pin,
8    sync::{Arc, atomic::AtomicU32},
9    task::{Context, Poll},
10    time::Duration,
11};
12
13use futures::{FutureExt, SinkExt, StreamExt, TryStreamExt, future, future::AbortHandle};
14use futures_concurrency::stream::Merge;
15use state::SocketState;
16use tracing::{Instrument, instrument};
17
18use crate::{
19    errors::SessionError,
20    processing::{ReassemblerExt, SegmenterExt, SequencerExt, types::FrameInspector},
21    protocol::{OrderedFrame, SegmentRequest, SeqIndicator, SessionCodec, SessionMessage},
22    socket::state::{SocketComponents, Stateless},
23};
24
/// Configuration object for [`SessionSocket`].
///
/// Every field has a default value (see the individual fields), so partial configurations
/// can be built with struct update syntax on top of `Default::default()`.
#[derive(Debug, Copy, Clone, Eq, PartialEq, smart_default::SmartDefault)]
pub struct SessionSocketConfig {
    /// The maximum size of a frame on the read/write interface of the [`SessionSocket`].
    ///
    /// The configured value is clamped by the socket constructors, so the effective size is
    /// always greater or equal to the MTU `C` of the underlying transport, and less or equal to:
    /// - (`C` - `SessionMessage::SEGMENT_OVERHEAD`) * (`SeqIndicator::MAX` + 1) for stateless sockets, or
    /// - (`C` - `SessionMessage::SEGMENT_OVERHEAD`) * min(`SeqIndicator::MAX` + 1,
    ///   `SegmentRequest::MAX_MISSING_SEGMENTS_PER_FRAME`) for stateful sockets
    ///
    /// Default is 1500 bytes.
    #[default(1500)]
    pub frame_size: usize,
    /// The maximum time to wait for a frame to be fully received.
    ///
    /// The same timeout is used both for reassembling segments into frames and for
    /// putting frames into sequence.
    ///
    /// Default is 800 ms.
    #[default(Duration::from_millis(800))]
    pub frame_timeout: Duration,
    /// Maximum number of segments to buffer in the downstream transport.
    /// If 0 is given, the transport is unbuffered.
    ///
    /// Internally this sets the send high-water mark of the segment codec to
    /// `max_buffered_segments * C` bytes (but never less than 1 byte).
    ///
    /// Default is 0.
    #[default(0)]
    pub max_buffered_segments: usize,
    /// Capacity of the frame reconstructor, the maximum number of incomplete frames, before
    /// they are dropped.
    ///
    /// The same value is also used as the capacity of other internal buffers, e.g. the frame
    /// sequencer of stateless sockets, and the frame inspector and the outgoing segment
    /// channel of stateful sockets.
    ///
    /// Default is 8192.
    #[default(8192)]
    pub capacity: usize,

    /// Flushes data written to the socket immediately to the underlying transport.
    ///
    /// When enabled, every accepted write is followed by a flush, and the write is only
    /// reported as complete once that flush has finished.
    ///
    /// Default is false.
    #[default(false)]
    pub flush_immediately: bool,
}
63
/// State of the write path of a [`SessionSocket`].
///
/// The initial variant is chosen from `SessionSocketConfig::flush_immediately`:
/// `WriteOnly` when it is disabled, `Writing` when it is enabled.
enum WriteState {
    /// Writes are passed straight through, without an explicit flush.
    WriteOnly,
    /// Ready to accept the next write; a successful write moves to `Flushing`.
    Writing,
    /// A write of the contained number of bytes has been accepted and is being flushed;
    /// that byte count is reported to the caller once the flush completes.
    Flushing(usize),
}
69
/// Socket-like object implementing the Session protocol that can operate on any transport that
/// implements [`futures::io::AsyncRead`] and [`futures::io::AsyncWrite`].
///
/// The [`SocketState`] `S` given during instantiation can facilitate reliable or unreliable
/// behavior (see [`AcknowledgementState`](ack_state::AcknowledgementState)).
///
/// The constant argument `C` specifies the MTU in bytes of the underlying transport.
///
/// The socket itself implements [`futures::io::AsyncRead`] and [`futures::io::AsyncWrite`]
/// (and, with the `runtime-tokio` feature, the tokio I/O traits). Closing the socket calls
/// `stop` on the state and then closes the outgoing pipeline.
#[pin_project::pin_project]
pub struct SessionSocket<const C: usize, S> {
    // This is where upstream writes the to-be-segmented frame data to.
    // Written frames are split into segments for the MTU `C` before going downstream.
    upstream_frames_in: Pin<Box<dyn futures::io::AsyncWrite + Send>>,
    // This is where upstream reads the reconstructed frame data from.
    // Incoming segments are reassembled into frames and put into frame-id order first.
    downstream_frames_out: Pin<Box<dyn futures::io::AsyncRead + Send>>,
    // The socket state (e.g. `Stateless`); also provides the session id used in tracing spans.
    state: S,
    // Tracks whether an accepted write still has to be flushed (see `WriteState`).
    write_state: WriteState,
}
86
87impl<const C: usize> SessionSocket<C, Stateless<C>> {
88    /// Creates a new stateless socket suitable for fast UDP-like communication.
89    ///
90    /// Note that this results in a faster socket than if created via [`SessionSocket::new`] with
91    /// [`Stateless`]. This is because the frame inspector does not need to be instantiated.
92    pub fn new_stateless<T, I>(id: I, transport: T, cfg: SessionSocketConfig) -> Result<Self, SessionError>
93    where
94        T: futures::io::AsyncRead + futures::io::AsyncWrite + Send + Unpin + 'static,
95        I: std::fmt::Display + Clone,
96    {
97        // The maximum frame size in a stateless socket is only bounded by the size of the SeqIndicator
98        let frame_size = cfg.frame_size.clamp(
99            C,
100            (C - SessionMessage::<C>::SEGMENT_OVERHEAD) * (SeqIndicator::MAX + 1) as usize,
101        );
102
103        // Segment data incoming/outgoing using underlying transport
104        let mut framed = asynchronous_codec::Framed::new(transport, SessionCodec::<C>);
105
106        // Check if we allow sending multiple segments to downstream in a single write
107        // The HWM cannot be 0 bytes
108        framed.set_send_high_water_mark(1.max(cfg.max_buffered_segments * C));
109
110        // Downstream transport
111        let (packets_out, packets_in) = framed.split();
112
113        // Pipeline IN: Data incoming from Upstream
114        let upstream_frames_in = packets_out
115            .with(|segment| future::ok::<_, SessionError>(SessionMessage::<C>::Segment(segment)))
116            .segmenter_with_terminating_segment::<C>(frame_size);
117
118        let last_emitted_frame = Arc::new(AtomicU32::new(0));
119        let last_emitted_frame_clone = last_emitted_frame.clone();
120
121        let session_id_1 = id.to_string();
122        let session_id_2 = id.to_string();
123        let session_id_3 = id.to_string();
124
125        let (packets_in_abort_handle, packets_in_abort_reg) = AbortHandle::new_pair();
126
127        // Pipeline OUT: Packets incoming from Downstream
128        // Continue receiving packets from downstream, unless we received a terminating frame.
129        // Once the terminating frame is received, the `packets_in_abort_handle` is triggered, terminating the pipeline.
130        let downstream_frames_out = futures::stream::Abortable::new(packets_in, packets_in_abort_reg)
131            // Filter-out segments that we've seen already
132            .filter_map(move |packet| {
133                futures::future::ready(match packet {
134                    Ok(packet) => packet.try_as_segment().filter(|s| {
135                        // Filter old frame ids to save space in the Reassembler
136                        let last_emitted_id = last_emitted_frame.load(std::sync::atomic::Ordering::Relaxed);
137                        if s.frame_id <= last_emitted_id {
138                            tracing::warn!(frame_id = s.frame_id, last_emitted_id, "frame already seen");
139                            false
140                        } else {
141                            true
142                        }
143                    }),
144                    Err(error) => {
145                        tracing::error!(%error, "unparseable packet");
146                        None
147                    }
148                })
149                .instrument(tracing::debug_span!(
150                    "SessionSocket::packets_in::pre_reassembly",
151                    session_id = session_id_1
152                ))
153            })
154            // Reassemble the segments into frames
155            .reassembler(cfg.frame_timeout, cfg.capacity)
156            // Discard frames that we could not reassemble
157            .filter_map(move |maybe_frame| {
158                futures::future::ready(match maybe_frame {
159                    Ok(frame) => Some(OrderedFrame(frame)),
160                    Err(error) => {
161                        tracing::error!(%error, "failed to reassemble frame");
162                        None
163                    }
164                })
165                .instrument(tracing::debug_span!(
166                    "SessionSocket::packets_in::pre_sequencing",
167                    session_id = session_id_2
168                ))
169            })
170            // Put the frames into the correct sequence by Frame Ids
171            .sequencer(cfg.frame_timeout, cfg.capacity)
172            // Discard frames missing from the sequence
173            .filter_map(move |maybe_frame| {
174                future::ready(match maybe_frame {
175                    Ok(frame) => {
176                        last_emitted_frame_clone.store(frame.0.frame_id, std::sync::atomic::Ordering::Relaxed);
177                        if frame.0.is_terminating {
178                            tracing::warn!("terminating frame received");
179                            packets_in_abort_handle.abort();
180                        }
181                        Some(Ok(frame.0))
182                    }
183                    // Downstream skips discarded frames
184                    Err(SessionError::FrameDiscarded(frame_id)) | Err(SessionError::IncompleteFrame(frame_id)) => {
185                        tracing::error!(frame_id, "frame discarded");
186                        None
187                    }
188                    Err(err) => Some(Err(std::io::Error::other(err))),
189                })
190                .instrument(tracing::debug_span!(
191                    "SessionSocket::packets_in::post_sequencing",
192                    session_id = session_id_3
193                ))
194            })
195            .into_async_read();
196
197        Ok(Self {
198            state: Stateless::new(id),
199            upstream_frames_in: Box::pin(upstream_frames_in),
200            downstream_frames_out: Box::pin(downstream_frames_out),
201            write_state: if cfg.flush_immediately {
202                WriteState::Writing
203            } else {
204                WriteState::WriteOnly
205            },
206        })
207    }
208}
209
210impl<const C: usize, S: SocketState<C> + Clone + 'static> SessionSocket<C, S> {
211    /// Creates a stateful socket with frame inspection capabilities - suitable for communication
212    /// requiring TCP-like delivery guarantees.
213    pub fn new<T>(transport: T, mut state: S, cfg: SessionSocketConfig) -> Result<Self, SessionError>
214    where
215        T: futures::io::AsyncRead + futures::io::AsyncWrite + Send + Unpin + 'static,
216    {
217        // The maximum frame size is reduced due to the size of the missing segment bitmap in SegmentRequests
218        let frame_size = cfg.frame_size.clamp(
219            C,
220            (C - SessionMessage::<C>::SEGMENT_OVERHEAD)
221                * SegmentRequest::<C>::MAX_MISSING_SEGMENTS_PER_FRAME.min((SeqIndicator::MAX + 1) as usize),
222        );
223
224        // Segment data incoming/outgoing using underlying transport
225        let mut framed = asynchronous_codec::Framed::new(transport, SessionCodec::<C>);
226
227        // Check if we allow sending multiple segments to downstream in a single write
228        // The HWM cannot be 0 bytes
229        framed.set_send_high_water_mark(1.max(cfg.max_buffered_segments * C));
230
231        // Downstream transport
232        let (packets_out, packets_in) = framed.split();
233
234        let inspector = FrameInspector::new(cfg.capacity);
235
236        let ctl_channel_capacity = std::env::var("HOPR_INTERNAL_SESSION_CTL_CHANNEL_CAPACITY")
237            .ok()
238            .and_then(|s| s.trim().parse::<usize>().ok())
239            .filter(|&c| c > 0)
240            .unwrap_or(2048);
241
242        tracing::debug!(capacity = ctl_channel_capacity, "Creating session control channel");
243        let (ctl_tx, ctl_rx) = futures::channel::mpsc::channel(ctl_channel_capacity);
244        state.run(SocketComponents {
245            inspector: Some(inspector.clone()),
246            ctl_tx,
247        })?;
248
249        // Pipeline IN: Data incoming from Upstream
250        let (segments_tx, segments_rx) = futures::channel::mpsc::channel(cfg.capacity);
251        let mut st_1 = state.clone();
252        let upstream_frames_in = segments_tx
253            .with(move |segment| {
254                // The segment_sent event is raised only for segments coming from Upstream,
255                // not for the segments from the Control stream (= segment resends).
256                if let Err(error) = st_1.segment_sent(&segment) {
257                    tracing::debug!(%error, "outgoing segment state update failed");
258                }
259                future::ok::<_, futures::channel::mpsc::SendError>(SessionMessage::<C>::Segment(segment))
260            })
261            .segmenter_with_terminating_segment::<C>(frame_size);
262
263        // We have to merge the streams here and spawn a special task for it
264        // Since the control messages from the State can come independent of Upstream writes.
265        hopr_async_runtime::prelude::spawn(
266            (ctl_rx, segments_rx)
267                .merge()
268                .map(Ok)
269                .forward(packets_out)
270                .map(move |result| match result {
271                    Ok(_) => tracing::debug!("outgoing packet processing done"),
272                    Err(error) => {
273                        tracing::error!(%error, "error while processing outgoing packets")
274                    }
275                })
276                .instrument(tracing::debug_span!(
277                    "SessionSocket::packets_out",
278                    session_id = state.session_id()
279                )),
280        );
281
282        let last_emitted_frame = Arc::new(AtomicU32::new(0));
283        let last_emitted_frame_clone = last_emitted_frame.clone();
284
285        let (packets_in_abort_handle, packets_in_abort_reg) = AbortHandle::new_pair();
286
287        // Pipeline OUT: Packets incoming from Downstream
288        let mut st_1 = state.clone();
289        let mut st_2 = state.clone();
290        let mut st_3 = state.clone();
291
292        // Continue receiving packets from downstream, unless we received a terminating frame.
293        // Once the terminating frame is received, the `packets_in_abort_handle` is triggered, terminating the pipeline.
294        let downstream_frames_out = futures::stream::Abortable::new(packets_in, packets_in_abort_reg)
295            // Filter out Session control messages and update the State, pass only Segments onwards
296            .filter_map(move |packet| {
297                futures::future::ready(match packet {
298                    Ok(packet) => {
299                        if let Err(error) = match &packet {
300                            SessionMessage::Segment(s) => st_1.incoming_segment(&s.id(), s.seq_flags),
301                            SessionMessage::Request(r) => st_1.incoming_retransmission_request(r.clone()),
302                            SessionMessage::Acknowledge(a) => st_1.incoming_acknowledged_frames(a.clone()),
303                        } {
304                            tracing::debug!(%error, "incoming message state update failed");
305                        }
306                        // Filter old frame ids to save space in the Reassembler
307                        packet.try_as_segment().filter(|s| {
308                            let last_emitted_id = last_emitted_frame.load(std::sync::atomic::Ordering::Relaxed);
309                            if s.frame_id <= last_emitted_id {
310                                tracing::warn!(frame_id = s.frame_id, last_emitted_id, "frame already seen");
311                                false
312                            } else {
313                                true
314                            }
315                        })
316                    }
317                    Err(error) => {
318                        tracing::error!(%error, "unparseable packet");
319                        None
320                    }
321                })
322                .instrument(tracing::debug_span!(
323                    "SessionSocket::packets_in::pre_reassembly",
324                    session_id = st_1.session_id()
325                ))
326            })
327            // Reassemble segments into frames
328            .reassembler_with_inspector(cfg.frame_timeout, cfg.capacity, inspector)
329            // Notify State once a frame has been reassembled, discard frames that we could not reassemble
330            .filter_map(move |maybe_frame| {
331                futures::future::ready(match maybe_frame {
332                    Ok(frame) => {
333                        if let Err(error) = st_2.frame_complete(frame.frame_id) {
334                            tracing::error!(%error, "frame complete state update failed");
335                        }
336                        Some(OrderedFrame(frame))
337                    }
338                    Err(error) => {
339                        tracing::error!(%error, "failed to reassemble frame");
340                        None
341                    }
342                })
343                .instrument(tracing::debug_span!(
344                    "SessionSocket::packets_in::pre_sequencing",
345                    session_id = st_2.session_id()
346                ))
347            })
348            // Put the frames into the correct sequence by Frame Ids
349            .sequencer(cfg.frame_timeout, cfg.frame_size)
350            // Discard frames missing from the sequence and
351            // notify the State about emitted or discarded frames
352            .filter_map(move |maybe_frame| {
353                // Filter out discarded Frames and dispatch events to the State if needed
354                future::ready(match maybe_frame {
355                    Ok(frame) => {
356                        if let Err(error) = st_3.frame_emitted(frame.0.frame_id) {
357                            tracing::error!(%error, "frame received state update failed");
358                        }
359                        last_emitted_frame_clone.store(frame.0.frame_id, std::sync::atomic::Ordering::Relaxed);
360                        if frame.0.is_terminating {
361                            tracing::warn!("terminating frame received");
362                            packets_in_abort_handle.abort();
363                        }
364                        Some(Ok(frame.0))
365                    }
366                    Err(SessionError::FrameDiscarded(frame_id)) | Err(SessionError::IncompleteFrame(frame_id)) => {
367                        if let Err(error) = st_3.frame_discarded(frame_id) {
368                            tracing::error!(%error, "frame discarded state update failed");
369                        }
370                        None // Downstream skips discarded frames
371                    }
372                    Err(err) => Some(Err(std::io::Error::other(err))),
373                })
374                .instrument(tracing::debug_span!(
375                    "SessionSocket::packets_in::post_sequencing",
376                    session_id = st_3.session_id()
377                ))
378            })
379            .into_async_read();
380
381        Ok(Self {
382            state,
383            upstream_frames_in: Box::pin(upstream_frames_in),
384            downstream_frames_out: Box::pin(downstream_frames_out),
385            write_state: if cfg.flush_immediately {
386                WriteState::Writing
387            } else {
388                WriteState::WriteOnly
389            },
390        })
391    }
392}
393
394impl<const C: usize, S: SocketState<C> + Clone + 'static> futures::io::AsyncRead for SessionSocket<C, S> {
395    #[instrument(name = "SessionSocket::poll_read", level = "trace", skip(self, cx, buf), fields(session_id = self.state.session_id(), len = buf.len()))]
396    fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<std::io::Result<usize>> {
397        self.project().downstream_frames_out.as_mut().poll_read(cx, buf)
398    }
399}
400
401impl<const C: usize, S: SocketState<C> + Clone + 'static> futures::io::AsyncWrite for SessionSocket<C, S> {
402    #[instrument(name = "SessionSocket::poll_write", level = "trace", skip(self, cx, buf), fields(session_id = self.state.session_id(), len = buf.len()))]
403    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
404        let this = self.project();
405        loop {
406            match this.write_state {
407                WriteState::WriteOnly => {
408                    return this.upstream_frames_in.as_mut().poll_write(cx, buf);
409                }
410                WriteState::Writing => {
411                    let len = futures::ready!(this.upstream_frames_in.as_mut().poll_write(cx, buf))?;
412                    *this.write_state = WriteState::Flushing(len);
413                }
414                WriteState::Flushing(len) => {
415                    let res = futures::ready!(this.upstream_frames_in.as_mut().poll_flush(cx)).map(|_| *len);
416                    *this.write_state = WriteState::Writing;
417                    return Poll::Ready(res);
418                }
419            }
420        }
421    }
422
423    #[instrument(name = "SessionSocket::poll_flush", level = "trace", skip(self, cx), fields(session_id = self.state.session_id()))]
424    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
425        self.project().upstream_frames_in.as_mut().poll_flush(cx)
426    }
427
428    #[instrument(name = "SessionSocket::poll_close", level = "trace", skip(self, cx), fields(session_id = self.state.session_id()))]
429    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
430        let this = self.project();
431        let _ = this.state.stop();
432        this.upstream_frames_in.as_mut().poll_close(cx)
433    }
434}
435
#[cfg(feature = "runtime-tokio")]
impl<const C: usize, S: SocketState<C> + Clone + 'static> tokio::io::AsyncRead for SessionSocket<C, S> {
    /// Adapts the [`futures::io::AsyncRead`] implementation to tokio's [`tokio::io::ReadBuf`].
    ///
    /// The unfilled part of `buf` is initialized and read into, and then `buf` is advanced
    /// by the number of bytes read.
    #[instrument(name = "SessionSocket::poll_read", level = "trace", skip(self, cx, buf), fields(session_id = self.state.session_id()))]
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        let target = buf.initialize_unfilled();
        match futures::AsyncRead::poll_read(self.as_mut(), cx, target) {
            Poll::Ready(Ok(read)) => {
                buf.advance(read);
                Poll::Ready(Ok(()))
            }
            Poll::Ready(Err(error)) => Poll::Ready(Err(error)),
            Poll::Pending => Poll::Pending,
        }
    }
}
450
#[cfg(feature = "runtime-tokio")]
impl<const C: usize, S: SocketState<C> + Clone + 'static> tokio::io::AsyncWrite for SessionSocket<C, S> {
    /// Delegates to the [`futures::io::AsyncWrite`] write implementation.
    #[instrument(name = "SessionSocket::poll_write", level = "trace", skip(self, cx, buf), fields(session_id = self.state.session_id(), len = buf.len()))]
    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, std::io::Error>> {
        <Self as futures::io::AsyncWrite>::poll_write(self, cx, buf)
    }

    /// Delegates to the [`futures::io::AsyncWrite`] flush implementation.
    #[instrument(name = "SessionSocket::poll_flush", level = "trace", skip(self, cx), fields(session_id = self.state.session_id()))]
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
        <Self as futures::io::AsyncWrite>::poll_flush(self, cx)
    }

    /// Shuts the socket down via the [`futures::io::AsyncWrite`] close implementation.
    #[instrument(name = "SessionSocket::poll_shutdown", level = "trace", skip(self, cx), fields(session_id = self.state.session_id()))]
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
        <Self as futures::io::AsyncWrite>::poll_close(self, cx)
    }
}
468
469#[cfg(test)]
470mod tests {
471    use std::collections::HashSet;
472
473    use futures::{AsyncReadExt, AsyncWriteExt};
474    use futures_time::future::FutureExt;
475    use hopr_crypto_packet::prelude::HoprPacket;
476
477    use super::*;
478    use crate::{AcknowledgementState, AcknowledgementStateConfig, utils::test::*};
479
480    const MTU: usize = HoprPacket::PAYLOAD_SIZE;
481
482    const FRAME_SIZE: usize = 1500;
483
484    const DATA_SIZE: usize = 17 * MTU + 271; // Use some size not directly divisible by the MTU
485
486    #[test_log::test(tokio::test)]
487    async fn stateless_socket_unidirectional_should_work() -> anyhow::Result<()> {
488        let (alice, bob) = setup_alice_bob::<MTU>(FaultyNetworkConfig::default(), None, None);
489
490        let sock_cfg = SessionSocketConfig {
491            frame_size: FRAME_SIZE,
492            ..Default::default()
493        };
494
495        let mut alice_socket = SessionSocket::<MTU, _>::new_stateless("alice", alice, sock_cfg)?;
496        let mut bob_socket = SessionSocket::<MTU, _>::new_stateless("bob", bob, sock_cfg)?;
497
498        let data = hopr_crypto_random::random_bytes::<DATA_SIZE>();
499
500        alice_socket
501            .write_all(&data)
502            .timeout(futures_time::time::Duration::from_secs(2))
503            .await??;
504        alice_socket.flush().await?;
505
506        let mut bob_data = [0u8; DATA_SIZE];
507        bob_socket
508            .read_exact(&mut bob_data)
509            .timeout(futures_time::time::Duration::from_secs(2))
510            .await??;
511        assert_eq!(data, bob_data);
512
513        alice_socket.close().await?;
514        bob_socket.close().await?;
515
516        Ok(())
517    }
518
519    #[test_log::test(tokio::test)]
520    async fn stateful_socket_unidirectional_should_work() -> anyhow::Result<()> {
521        let (alice, bob) = setup_alice_bob::<MTU>(FaultyNetworkConfig::default(), None, None);
522
523        let sock_cfg = SessionSocketConfig {
524            frame_size: FRAME_SIZE,
525            ..Default::default()
526        };
527
528        let ack_cfg = AcknowledgementStateConfig {
529            expected_packet_latency: Duration::from_millis(2),
530            acknowledgement_delay: Duration::from_millis(5),
531            ..Default::default()
532        };
533
534        let mut alice_socket =
535            SessionSocket::<MTU, _>::new(alice, AcknowledgementState::new("alice", ack_cfg), sock_cfg)?;
536        let mut bob_socket = SessionSocket::<MTU, _>::new(bob, AcknowledgementState::new("bob", ack_cfg), sock_cfg)?;
537
538        let data = hopr_crypto_random::random_bytes::<DATA_SIZE>();
539
540        alice_socket
541            .write_all(&data)
542            .timeout(futures_time::time::Duration::from_secs(2))
543            .await??;
544        alice_socket.flush().await?;
545
546        let mut bob_data = [0u8; DATA_SIZE];
547        bob_socket
548            .read_exact(&mut bob_data)
549            .timeout(futures_time::time::Duration::from_secs(2))
550            .await??;
551        assert_eq!(data, bob_data);
552
553        alice_socket.close().await?;
554        bob_socket.close().await?;
555
556        Ok(())
557    }
558
559    #[test_log::test(tokio::test)]
560    async fn stateless_socket_bidirectional_should_work() -> anyhow::Result<()> {
561        let (alice, bob) = setup_alice_bob::<MTU>(FaultyNetworkConfig::default(), None, None);
562
563        let sock_cfg = SessionSocketConfig {
564            frame_size: FRAME_SIZE,
565            ..Default::default()
566        };
567
568        let mut alice_socket = SessionSocket::<MTU, _>::new_stateless("alice", alice, sock_cfg)?;
569        let mut bob_socket = SessionSocket::<MTU, _>::new_stateless("bob", bob, sock_cfg)?;
570
571        let alice_sent_data = hopr_crypto_random::random_bytes::<DATA_SIZE>();
572        alice_socket
573            .write_all(&alice_sent_data)
574            .timeout(futures_time::time::Duration::from_secs(2))
575            .await??;
576        alice_socket.flush().await?;
577
578        let bob_sent_data = hopr_crypto_random::random_bytes::<DATA_SIZE>();
579        bob_socket
580            .write_all(&bob_sent_data)
581            .timeout(futures_time::time::Duration::from_secs(2))
582            .await??;
583        bob_socket.flush().await?;
584
585        let mut bob_recv_data = [0u8; DATA_SIZE];
586        bob_socket
587            .read_exact(&mut bob_recv_data)
588            .timeout(futures_time::time::Duration::from_secs(2))
589            .await??;
590        assert_eq!(alice_sent_data, bob_recv_data);
591
592        let mut alice_recv_data = [0u8; DATA_SIZE];
593        alice_socket
594            .read_exact(&mut alice_recv_data)
595            .timeout(futures_time::time::Duration::from_secs(2))
596            .await??;
597        assert_eq!(bob_sent_data, alice_recv_data);
598
599        Ok(())
600    }
601
602    #[test_log::test(tokio::test)]
603    async fn stateful_socket_bidirectional_should_work() -> anyhow::Result<()> {
604        let (alice, bob) = setup_alice_bob::<MTU>(FaultyNetworkConfig::default(), None, None);
605
606        // use hopr_network_types::capture::PcapIoExt;
607        // let (alice, bob) = (alice.capture("alice.pcap"), bob.capture("bob.pcap"));
608
609        let sock_cfg = SessionSocketConfig {
610            frame_size: FRAME_SIZE,
611            ..Default::default()
612        };
613
614        let ack_cfg = AcknowledgementStateConfig {
615            expected_packet_latency: Duration::from_millis(2),
616            acknowledgement_delay: Duration::from_millis(10),
617            ..Default::default()
618        };
619
620        let mut alice_socket =
621            SessionSocket::<MTU, _>::new(alice, AcknowledgementState::new("alice", ack_cfg), sock_cfg)?;
622        let mut bob_socket = SessionSocket::<MTU, _>::new(bob, AcknowledgementState::new("bob", ack_cfg), sock_cfg)?;
623
624        let alice_sent_data = hopr_crypto_random::random_bytes::<DATA_SIZE>();
625        alice_socket
626            .write_all(&alice_sent_data)
627            .timeout(futures_time::time::Duration::from_secs(2))
628            .await??;
629        alice_socket.flush().await?;
630
631        let bob_sent_data = hopr_crypto_random::random_bytes::<DATA_SIZE>();
632        bob_socket
633            .write_all(&bob_sent_data)
634            .timeout(futures_time::time::Duration::from_secs(2))
635            .await??;
636        bob_socket.flush().await?;
637
638        let mut bob_recv_data = [0u8; DATA_SIZE];
639        bob_socket
640            .read_exact(&mut bob_recv_data)
641            .timeout(futures_time::time::Duration::from_secs(2))
642            .await??;
643        assert_eq!(alice_sent_data, bob_recv_data);
644
645        let mut alice_recv_data = [0u8; DATA_SIZE];
646        alice_socket
647            .read_exact(&mut alice_recv_data)
648            .timeout(futures_time::time::Duration::from_secs(2))
649            .await??;
650        assert_eq!(bob_sent_data, alice_recv_data);
651
652        Ok(())
653    }
654
655    #[test_log::test(tokio::test)]
656    async fn stateless_socket_unidirectional_should_work_with_mixing() -> anyhow::Result<()> {
657        let network_cfg = FaultyNetworkConfig {
658            mixing_factor: 10,
659            ..Default::default()
660        };
661
662        let (alice, bob) = setup_alice_bob::<MTU>(network_cfg, None, None);
663
664        let sock_cfg = SessionSocketConfig {
665            frame_size: FRAME_SIZE,
666            ..Default::default()
667        };
668
669        let mut alice_socket = SessionSocket::<MTU, _>::new_stateless("alice", alice, sock_cfg)?;
670        let mut bob_socket = SessionSocket::<MTU, _>::new_stateless("bob", bob, sock_cfg)?;
671
672        let data = hopr_crypto_random::random_bytes::<DATA_SIZE>();
673        alice_socket
674            .write_all(&data)
675            .timeout(futures_time::time::Duration::from_secs(2))
676            .await??;
677        alice_socket.flush().await?;
678
679        let mut bob_recv_data = [0u8; DATA_SIZE];
680        bob_socket
681            .read_exact(&mut bob_recv_data)
682            .timeout(futures_time::time::Duration::from_secs(2))
683            .await??;
684        assert_eq!(data, bob_recv_data);
685
686        alice_socket.close().await?;
687        bob_socket.close().await?;
688
689        Ok(())
690    }
691
692    #[test_log::test(tokio::test)]
693    async fn stateful_socket_unidirectional_should_work_with_mixing() -> anyhow::Result<()> {
694        let network_cfg = FaultyNetworkConfig {
695            mixing_factor: 10,
696            ..Default::default()
697        };
698
699        let (alice, bob) = setup_alice_bob::<MTU>(network_cfg, None, None);
700
701        let sock_cfg = SessionSocketConfig {
702            frame_size: FRAME_SIZE,
703            ..Default::default()
704        };
705
706        let ack_cfg = AcknowledgementStateConfig {
707            expected_packet_latency: Duration::from_millis(2),
708            acknowledgement_delay: Duration::from_millis(5),
709            ..Default::default()
710        };
711
712        let mut alice_socket =
713            SessionSocket::<MTU, _>::new(alice, AcknowledgementState::new("alice", ack_cfg), sock_cfg)?;
714        let mut bob_socket = SessionSocket::<MTU, _>::new(bob, AcknowledgementState::new("bob", ack_cfg), sock_cfg)?;
715
716        let data = hopr_crypto_random::random_bytes::<DATA_SIZE>();
717        alice_socket
718            .write_all(&data)
719            .timeout(futures_time::time::Duration::from_secs(2))
720            .await??;
721        alice_socket.flush().await?;
722
723        let mut bob_recv_data = [0u8; DATA_SIZE];
724        bob_socket
725            .read_exact(&mut bob_recv_data)
726            .timeout(futures_time::time::Duration::from_secs(2))
727            .await??;
728        assert_eq!(data, bob_recv_data);
729
730        alice_socket.close().await?;
731        bob_socket.close().await?;
732
733        Ok(())
734    }
735
736    #[test_log::test(tokio::test)]
737    async fn stateless_socket_bidirectional_should_work_with_mixing() -> anyhow::Result<()> {
738        let network_cfg = FaultyNetworkConfig {
739            mixing_factor: 10,
740            ..Default::default()
741        };
742
743        let (alice, bob) = setup_alice_bob::<MTU>(network_cfg, None, None);
744
745        let sock_cfg = SessionSocketConfig {
746            frame_size: FRAME_SIZE,
747            ..Default::default()
748        };
749
750        let mut alice_socket = SessionSocket::<MTU, _>::new_stateless("alice", alice, sock_cfg)?;
751        let mut bob_socket = SessionSocket::<MTU, _>::new_stateless("bob", bob, sock_cfg)?;
752
753        let alice_sent_data = hopr_crypto_random::random_bytes::<DATA_SIZE>();
754        alice_socket
755            .write_all(&alice_sent_data)
756            .timeout(futures_time::time::Duration::from_secs(2))
757            .await??;
758        alice_socket.flush().await?;
759
760        let bob_sent_data = hopr_crypto_random::random_bytes::<DATA_SIZE>();
761        bob_socket
762            .write_all(&bob_sent_data)
763            .timeout(futures_time::time::Duration::from_secs(2))
764            .await??;
765        bob_socket.flush().await?;
766
767        let mut bob_recv_data = [0u8; DATA_SIZE];
768        bob_socket
769            .read_exact(&mut bob_recv_data)
770            .timeout(futures_time::time::Duration::from_secs(2))
771            .await??;
772        assert_eq!(alice_sent_data, bob_recv_data);
773
774        let mut alice_recv_data = [0u8; DATA_SIZE];
775        alice_socket
776            .read_exact(&mut alice_recv_data)
777            .timeout(futures_time::time::Duration::from_secs(2))
778            .await??;
779        assert_eq!(bob_sent_data, alice_recv_data);
780
781        alice_socket.close().await?;
782        bob_socket.close().await?;
783
784        Ok(())
785    }
786
787    #[test_log::test(tokio::test)]
788    async fn stateful_socket_bidirectional_should_work_with_mixing() -> anyhow::Result<()> {
789        let network_cfg = FaultyNetworkConfig {
790            mixing_factor: 10,
791            ..Default::default()
792        };
793
794        let (alice, bob) = setup_alice_bob::<MTU>(network_cfg, None, None);
795
796        let sock_cfg = SessionSocketConfig {
797            frame_size: FRAME_SIZE,
798            ..Default::default()
799        };
800
801        let ack_cfg = AcknowledgementStateConfig {
802            expected_packet_latency: Duration::from_millis(2),
803            acknowledgement_delay: Duration::from_millis(5),
804            ..Default::default()
805        };
806
807        let mut alice_socket =
808            SessionSocket::<MTU, _>::new(alice, AcknowledgementState::new("alice", ack_cfg), sock_cfg)?;
809        let mut bob_socket = SessionSocket::<MTU, _>::new(bob, AcknowledgementState::new("bob", ack_cfg), sock_cfg)?;
810
811        let alice_sent_data = hopr_crypto_random::random_bytes::<DATA_SIZE>();
812        alice_socket
813            .write_all(&alice_sent_data)
814            .timeout(futures_time::time::Duration::from_secs(2))
815            .await??;
816        alice_socket.flush().await?;
817
818        let bob_sent_data = hopr_crypto_random::random_bytes::<DATA_SIZE>();
819        bob_socket
820            .write_all(&bob_sent_data)
821            .timeout(futures_time::time::Duration::from_secs(2))
822            .await??;
823        bob_socket.flush().await?;
824
825        let mut bob_recv_data = [0u8; DATA_SIZE];
826        bob_socket
827            .read_exact(&mut bob_recv_data)
828            .timeout(futures_time::time::Duration::from_secs(2))
829            .await??;
830        assert_eq!(alice_sent_data, bob_recv_data);
831
832        let mut alice_recv_data = [0u8; DATA_SIZE];
833        alice_socket
834            .read_exact(&mut alice_recv_data)
835            .timeout(futures_time::time::Duration::from_secs(2))
836            .await??;
837        assert_eq!(bob_sent_data, alice_recv_data);
838
839        alice_socket.close().await?;
840        bob_socket.close().await?;
841
842        Ok(())
843    }
844
845    #[test_log::test(tokio::test)]
846    async fn stateless_socket_unidirectional_should_should_skip_missing_frames() -> anyhow::Result<()> {
847        let (alice, bob) = setup_alice_bob::<MTU>(
848            FaultyNetworkConfig {
849                avg_delay: Duration::from_millis(10),
850                ids_to_drop: HashSet::from_iter([0_usize]),
851                ..Default::default()
852            },
853            None,
854            None,
855        );
856
857        let alice_cfg = SessionSocketConfig {
858            frame_size: FRAME_SIZE,
859            ..Default::default()
860        };
861
862        let bob_cfg = SessionSocketConfig {
863            frame_size: FRAME_SIZE,
864            frame_timeout: Duration::from_millis(55),
865            ..Default::default()
866        };
867
868        let mut alice_socket = SessionSocket::<MTU, _>::new_stateless("alice", alice, alice_cfg)?;
869        let mut bob_socket = SessionSocket::<MTU, _>::new_stateless("bob", bob, bob_cfg)?;
870
871        let data = hopr_crypto_random::random_bytes::<DATA_SIZE>();
872        alice_socket
873            .write_all(&data)
874            .timeout(futures_time::time::Duration::from_secs(2))
875            .await??;
876        alice_socket.flush().await?;
877        alice_socket.close().await?;
878
879        let mut bob_data = Vec::with_capacity(DATA_SIZE);
880        bob_socket
881            .read_to_end(&mut bob_data)
882            .timeout(futures_time::time::Duration::from_secs(2))
883            .await??;
884
885        // The whole first frame is discarded due to the missing first segment
886        assert_eq!(data.len() - 1500, bob_data.len());
887        assert_eq!(&data[1500..], &bob_data);
888
889        bob_socket.close().await?;
890
891        Ok(())
892    }
893
894    #[test_log::test(tokio::test)]
895    async fn stateful_socket_unidirectional_should_should_not_skip_missing_frames() -> anyhow::Result<()> {
896        let (alice, bob) = setup_alice_bob::<MTU>(
897            FaultyNetworkConfig {
898                avg_delay: Duration::from_millis(10),
899                ids_to_drop: HashSet::from_iter([0_usize]),
900                ..Default::default()
901            },
902            None,
903            None,
904        );
905
906        let alice_cfg = SessionSocketConfig {
907            frame_size: FRAME_SIZE,
908            ..Default::default()
909        };
910
911        let bob_cfg = SessionSocketConfig {
912            frame_size: FRAME_SIZE,
913            frame_timeout: Duration::from_millis(1000),
914            ..Default::default()
915        };
916
917        let ack_cfg = AcknowledgementStateConfig {
918            expected_packet_latency: Duration::from_millis(10),
919            acknowledgement_delay: Duration::from_millis(40),
920            ..Default::default()
921        };
922
923        let mut alice_socket =
924            SessionSocket::<MTU, _>::new(alice, AcknowledgementState::new("alice", ack_cfg), alice_cfg)?;
925        let mut bob_socket = SessionSocket::<MTU, _>::new(bob, AcknowledgementState::new("bob", ack_cfg), bob_cfg)?;
926
927        let data = hopr_crypto_random::random_bytes::<DATA_SIZE>();
928
929        let alice_jh = tokio::spawn(async move {
930            alice_socket
931                .write_all(&data)
932                .timeout(futures_time::time::Duration::from_secs(5))
933                .await??;
934
935            alice_socket.flush().await?;
936
937            // Alice has to keep reading so that it is ready for retransmitting
938            let mut vec = Vec::new();
939            alice_socket.read_to_end(&mut vec).await?;
940            alice_socket.close().await?;
941
942            Ok::<_, std::io::Error>(vec)
943        });
944
945        let mut bob_data = [0u8; DATA_SIZE];
946        bob_socket
947            .read_exact(&mut bob_data)
948            .timeout(futures_time::time::Duration::from_secs(5))
949            .await??;
950        assert_eq!(data, bob_data);
951
952        bob_socket.close().await?;
953
954        let alice_recv = alice_jh.await??;
955        assert!(alice_recv.is_empty());
956
957        Ok(())
958    }
959
960    #[test_log::test(tokio::test)]
961    async fn stateless_socket_bidirectional_should_should_skip_missing_frames() -> anyhow::Result<()> {
962        let (alice, bob) = setup_alice_bob::<MTU>(
963            FaultyNetworkConfig {
964                avg_delay: Duration::from_millis(10),
965                ids_to_drop: HashSet::from_iter([0_usize]),
966                ..Default::default()
967            },
968            None,
969            None,
970        );
971
972        let alice_cfg = SessionSocketConfig {
973            frame_size: FRAME_SIZE,
974            frame_timeout: Duration::from_millis(55),
975            ..Default::default()
976        };
977
978        let bob_cfg = SessionSocketConfig {
979            frame_size: FRAME_SIZE,
980            frame_timeout: Duration::from_millis(55),
981            ..Default::default()
982        };
983
984        let mut alice_socket = SessionSocket::<MTU, _>::new_stateless("alice", alice, alice_cfg)?;
985        let mut bob_socket = SessionSocket::<MTU, _>::new_stateless("bob", bob, bob_cfg)?;
986
987        let alice_sent_data = hopr_crypto_random::random_bytes::<DATA_SIZE>();
988        alice_socket
989            .write_all(&alice_sent_data)
990            .timeout(futures_time::time::Duration::from_secs(2))
991            .await??;
992        alice_socket.flush().await?;
993
994        let bob_sent_data = hopr_crypto_random::random_bytes::<DATA_SIZE>();
995        bob_socket
996            .write_all(&bob_sent_data)
997            .timeout(futures_time::time::Duration::from_secs(2))
998            .await??;
999        bob_socket.flush().await?;
1000
1001        alice_socket.close().await?;
1002        bob_socket.close().await?;
1003
1004        let mut alice_recv_data = Vec::with_capacity(DATA_SIZE);
1005        alice_socket
1006            .read_to_end(&mut alice_recv_data)
1007            .timeout(futures_time::time::Duration::from_secs(2))
1008            .await??;
1009
1010        let mut bob_recv_data = Vec::with_capacity(DATA_SIZE);
1011        bob_socket
1012            .read_to_end(&mut bob_recv_data)
1013            .timeout(futures_time::time::Duration::from_secs(2))
1014            .await??;
1015
1016        // The whole first frame is discarded due to the missing first segment
1017        assert_eq!(bob_sent_data.len() - 1500, alice_recv_data.len());
1018        assert_eq!(&bob_sent_data[1500..], &alice_recv_data);
1019
1020        assert_eq!(alice_sent_data.len() - 1500, bob_recv_data.len());
1021        assert_eq!(&alice_sent_data[1500..], &bob_recv_data);
1022
1023        Ok(())
1024    }
1025
1026    //#[test_log::test(tokio::test)]
1027    #[test_log::test(tokio::test(flavor = "multi_thread"))]
1028    async fn stateful_socket_bidirectional_should_should_not_skip_missing_frames() -> anyhow::Result<()> {
1029        let (alice, bob) = setup_alice_bob::<MTU>(
1030            FaultyNetworkConfig {
1031                avg_delay: Duration::from_millis(10),
1032                ids_to_drop: HashSet::from_iter([0_usize]),
1033                ..Default::default()
1034            },
1035            None,
1036            None,
1037        );
1038
1039        // use hopr_network_types::capture::PcapIoExt;
1040        // let (alice, bob) = (alice.capture("alice.pcap"), bob.capture("bob.pcap"));
1041
1042        let alice_cfg = SessionSocketConfig {
1043            frame_size: FRAME_SIZE,
1044            frame_timeout: Duration::from_millis(1000),
1045            ..Default::default()
1046        };
1047
1048        let bob_cfg = SessionSocketConfig {
1049            frame_size: FRAME_SIZE,
1050            frame_timeout: Duration::from_millis(1000),
1051            ..Default::default()
1052        };
1053
1054        let ack_cfg = AcknowledgementStateConfig {
1055            expected_packet_latency: Duration::from_millis(10),
1056            acknowledgement_delay: Duration::from_millis(40),
1057            ..Default::default()
1058        };
1059
1060        let (mut alice_rx, mut alice_tx) =
1061            SessionSocket::<MTU, _>::new(alice, AcknowledgementState::new("alice", ack_cfg), alice_cfg)?.split();
1062
1063        let (mut bob_rx, mut bob_tx) =
1064            SessionSocket::<MTU, _>::new(bob, AcknowledgementState::new("bob", ack_cfg), bob_cfg)?.split();
1065
1066        let alice_sent_data = hopr_crypto_random::random_bytes::<DATA_SIZE>();
1067        let (alice_data_tx, alice_recv_data) = futures::channel::oneshot::channel();
1068        let alice_rx_jh = tokio::spawn(async move {
1069            let mut alice_recv_data = vec![0u8; DATA_SIZE];
1070            alice_rx.read_exact(&mut alice_recv_data).await?;
1071            alice_data_tx
1072                .send(alice_recv_data)
1073                .map_err(|_| std::io::Error::other("tx error"))?;
1074
1075            // Keep reading until the socket is closed
1076            alice_rx.read_to_end(&mut Vec::new()).await?;
1077            Ok::<_, std::io::Error>(())
1078        });
1079
1080        let bob_sent_data = hopr_crypto_random::random_bytes::<DATA_SIZE>();
1081        let (bob_data_tx, bob_recv_data) = futures::channel::oneshot::channel();
1082        let bob_rx_jh = tokio::spawn(async move {
1083            let mut bob_recv_data = vec![0u8; DATA_SIZE];
1084            bob_rx.read_exact(&mut bob_recv_data).await?;
1085            bob_data_tx
1086                .send(bob_recv_data)
1087                .map_err(|_| std::io::Error::other("tx error"))?;
1088
1089            // Keep reading until the socket is closed
1090            bob_rx.read_to_end(&mut Vec::new()).await?;
1091            Ok::<_, std::io::Error>(())
1092        });
1093
1094        let alice_tx_jh = tokio::spawn(async move {
1095            alice_tx
1096                .write_all(&alice_sent_data)
1097                .timeout(futures_time::time::Duration::from_secs(2))
1098                .await??;
1099            alice_tx.flush().await?;
1100
1101            // Once all data is sent, wait for the other side to receive it and close the socket
1102            let out = alice_recv_data.await.map_err(|_| std::io::Error::other("rx error"))?;
1103            alice_tx.close().await?;
1104            tracing::info!("alice closed");
1105            Ok::<_, std::io::Error>(out)
1106        });
1107
1108        let bob_tx_jh = tokio::spawn(async move {
1109            bob_tx
1110                .write_all(&bob_sent_data)
1111                .timeout(futures_time::time::Duration::from_secs(2))
1112                .await??;
1113            bob_tx.flush().await?;
1114
1115            // Once all data is sent, wait for the other side to receive it and close the socket
1116            let out = bob_recv_data.await.map_err(|_| std::io::Error::other("rx error"))?;
1117            bob_tx.close().await?;
1118            tracing::info!("bob closed");
1119            Ok::<_, std::io::Error>(out)
1120        });
1121
1122        let (alice_recv_data, bob_recv_data, a, b) =
1123            futures::future::try_join4(alice_tx_jh, bob_tx_jh, alice_rx_jh, bob_rx_jh)
1124                .timeout(futures_time::time::Duration::from_secs(4))
1125                .await??;
1126
1127        assert_eq!(&alice_sent_data, bob_recv_data?.as_slice());
1128        assert_eq!(&bob_sent_data, alice_recv_data?.as_slice());
1129        assert!(a.is_ok());
1130        assert!(b.is_ok());
1131
1132        Ok(())
1133    }
1134}