1pub mod ack_state;
4pub mod state;
5
6#[cfg(feature = "telemetry")]
8pub mod telemetry;
9
10use std::{
11 pin::Pin,
12 sync::{Arc, atomic::AtomicU32},
13 task::{Context, Poll},
14 time::Duration,
15};
16
17use futures::{FutureExt, SinkExt, StreamExt, TryStreamExt, future, future::AbortHandle};
18use futures_concurrency::stream::Merge;
19use state::{SocketComponents, SocketState, Stateless};
20use tracing::{Instrument, instrument};
21#[cfg(feature = "telemetry")]
22use {
23 strum::IntoDiscriminant,
24 telemetry::{SessionMessageDiscriminants, SessionTelemetryTracker},
25};
26
27use crate::{
28 errors::SessionError,
29 processing::{ReassemblerExt, SegmenterExt, SequencerExt, types::FrameInspector},
30 protocol::{OrderedFrame, SegmentRequest, SeqIndicator, SessionCodec, SessionMessage},
31};
32
/// Configuration of a [`SessionSocket`].
///
/// The default value of each field is given by its `#[default(...)]` attribute
/// (via `smart_default::SmartDefault`).
#[derive(Debug, Copy, Clone, Eq, PartialEq, smart_default::SmartDefault)]
pub struct SessionSocketConfig {
    /// Requested frame size of the socket.
    ///
    /// The constructors clamp this value before passing it to the segmenter:
    /// the lower bound is the socket's const parameter `C`, the upper bound is
    /// `(C - SessionMessage::<C>::SEGMENT_OVERHEAD)` multiplied by
    /// `SeqIndicator::MAX + 1` (stateless socket), or by the smaller of that and
    /// `SegmentRequest::<C>::MAX_MISSING_SEGMENTS_PER_FRAME` (stateful socket).
    ///
    /// Default is 1500.
    #[default(1500)]
    pub frame_size: usize,
    /// Timeout handed to both the reassembler and the sequencer stage of the
    /// incoming pipeline.
    ///
    /// Default is 800 ms.
    #[default(Duration::from_millis(800))]
    pub frame_timeout: Duration,
    /// Number of segments (each counted as `C` bytes) used to compute the send
    /// high-water mark of the framed transport.
    ///
    /// The mark is set to `max(1, max_buffered_segments * C)`, so 0 yields a mark of 1.
    ///
    /// Default is 0.
    #[default(0)]
    pub max_buffered_segments: usize,
    /// Capacity passed to the reassembler and sequencer stages; the stateful
    /// socket additionally uses it for the frame inspector and for the outgoing
    /// segment channel.
    ///
    /// Default is 8192.
    #[default(8192)]
    pub capacity: usize,
    /// When `true`, the socket starts in the flush-after-write mode, i.e. every
    /// `poll_write` is followed by a flush of the upstream writer before it completes.
    ///
    /// Default is `false`.
    #[default(false)]
    pub flush_immediately: bool,
    /// Requested capacity of the control message channel of a stateful socket.
    ///
    /// The stateful constructor raises values below 128 to 128.
    ///
    /// Default is 2048.
    #[default(2048)]
    pub control_channel_capacity: usize,
}
77
/// Write-side state of a [`SessionSocket`], driven by its `poll_write` implementation.
enum WriteState {
    /// Writes are forwarded to the upstream writer and never flushed implicitly.
    ///
    /// Selected when `SessionSocketConfig::flush_immediately` is `false`.
    WriteOnly,
    /// Flush-after-write mode, waiting for the next write to be accepted by the upstream writer.
    ///
    /// Selected initially when `SessionSocketConfig::flush_immediately` is `true`.
    Writing,
    /// A write of the given number of bytes was accepted and the following flush has not
    /// completed yet. The stored length is what `poll_write` reports once the flush is done.
    Flushing(usize),
}
83
/// Socket-like object exposing a byte-oriented read/write interface on top of a
/// transport that carries session messages of at most `C` bytes.
///
/// `S` is the socket state: either [`Stateless`] (see `new_stateless`) or any
/// [`SocketState`] implementation (see `new`).
#[pin_project::pin_project]
pub struct SessionSocket<const C: usize, S> {
    /// State of the socket; provides the session ID used in tracing and is stopped on close.
    state: S,
    /// Write half: bytes written here are cut into segments and sent towards the transport.
    upstream_frames_in: Pin<Box<dyn futures::io::AsyncWrite + Send>>,
    /// Read half: yields the bytes of frames that were reassembled and sequenced
    /// from the segments received from the transport.
    downstream_frames_out: Pin<Box<dyn futures::io::AsyncRead + Send>>,
    /// Current write mode/state, see [`WriteState`].
    write_state: WriteState,
}
100
impl<const C: usize> SessionSocket<C, Stateless<C>> {
    /// Creates a socket with the [`Stateless`] state on top of the given `transport`.
    ///
    /// * `id` - identifier of the session; used for the tracing spans and to create the state.
    /// * `transport` - underlying transport carrying encoded session messages.
    /// * `cfg` - socket configuration.
    /// * `stats` - telemetry tracker (only with the `telemetry` feature).
    ///
    /// Only `Segment` messages are processed on the incoming side; all other
    /// messages are dropped by the first pipeline stage.
    ///
    /// The visible body contains no fallible operation, so the result is always `Ok`.
    pub fn new_stateless<T, I>(
        id: I,
        transport: T,
        cfg: SessionSocketConfig,
        #[cfg(feature = "telemetry")] stats: impl SessionTelemetryTracker + Clone + Send + 'static,
    ) -> Result<Self, SessionError>
    where
        T: futures::io::AsyncRead + futures::io::AsyncWrite + Send + Unpin + 'static,
        I: std::fmt::Display + Clone,
    {
        // Keep the frame size between `C` and the maximum number of payload bytes
        // addressable by the segment sequence indicator.
        let frame_size = cfg.frame_size.clamp(
            C,
            (C - SessionMessage::<C>::SEGMENT_OVERHEAD) * (SeqIndicator::MAX + 1) as usize,
        );

        // Encode/decode session messages directly on the transport.
        let mut framed = asynchronous_codec::Framed::new(transport, SessionCodec::<C>);

        // The high-water mark is at least 1, even when no buffering was requested.
        framed.set_send_high_water_mark(1.max(cfg.max_buffered_segments * C));

        let (packets_out, packets_in) = framed.split();

        // One tracker clone per closure that records telemetry.
        #[cfg(feature = "telemetry")]
        let (s0, s1, s2, s3) = { (stats.clone(), stats.clone(), stats.clone(), stats.clone()) };

        // Outgoing path: bytes -> segments -> `SessionMessage::Segment` -> transport.
        let upstream_frames_in = packets_out
            .with(move |segment| {
                #[cfg(feature = "telemetry")]
                s0.outgoing_message(SessionMessageDiscriminants::Segment);

                future::ok::<_, SessionError>(SessionMessage::<C>::Segment(segment))
            })
            .segmenter_with_terminating_segment::<C>(frame_size);

        // ID of the last frame emitted to the reader; shared between the first stage
        // (which reads it) and the last stage (which updates it).
        let last_emitted_frame = Arc::new(AtomicU32::new(0));
        let last_emitted_frame_clone = last_emitted_frame.clone();

        let stage1_span = tracing::debug_span!("SessionSocket::packets_in::pre_reassembly", session_id = %id);
        let stage2_span = tracing::debug_span!("SessionSocket::packets_in::pre_sequencing", session_id = %id);
        let stage3_span = tracing::debug_span!("SessionSocket::packets_in::post_sequencing", session_id = %id);

        // Allows the last stage to terminate the incoming stream once a terminating frame arrives.
        let (packets_in_abort_handle, packets_in_abort_reg) = AbortHandle::new_pair();

        // Incoming path: transport -> segments -> reassembled frames -> ordered frames -> bytes.
        let downstream_frames_out = futures::stream::Abortable::new(packets_in, packets_in_abort_reg)
            // Stage 1: keep only segments of frames that were not emitted yet; drop unparseable packets.
            .filter_map(move |packet| {
                let _span = stage1_span.enter();
                futures::future::ready(match packet {
                    Ok(packet) => {
                        packet.try_as_segment().filter(|s| {
                            #[cfg(feature = "telemetry")]
                            s1.incoming_message(SessionMessageDiscriminants::Segment);

                            let last_emitted_id = last_emitted_frame.load(std::sync::atomic::Ordering::Relaxed);
                            // Segments of frames at or below the last emitted frame ID are stale.
                            if s.frame_id <= last_emitted_id {
                                tracing::warn!(frame_id = s.frame_id, last_emitted_id, "frame already seen");
                                false
                            } else {
                                true
                            }
                        })
                    }
                    Err(error) => {
                        tracing::error!(%error, "unparseable packet");
                        #[cfg(feature = "telemetry")]
                        s1.error();
                        None
                    }
                })
            })
            .reassembler(cfg.frame_timeout, cfg.capacity)
            // Stage 2: pass on reassembled frames; frames that failed to reassemble are logged and dropped.
            .filter_map(move |maybe_frame| {
                let _span = stage2_span.enter();
                futures::future::ready(match maybe_frame {
                    Ok(frame) => {
                        #[cfg(feature = "telemetry")]
                        s2.frame_completed();
                        Some(OrderedFrame(frame))
                    }
                    Err(error) => {
                        tracing::error!(%error, "failed to reassemble frame");
                        #[cfg(feature = "telemetry")]
                        s2.incomplete_frame();
                        None
                    }
                })
            })
            .sequencer(cfg.frame_timeout, cfg.capacity)
            // Stage 3: emit sequenced frames to the reader; discarded/incomplete frames are skipped,
            // any other error is surfaced to the reader as an I/O error.
            .filter_map(move |maybe_frame| {
                let _span = stage3_span.enter();
                future::ready(match maybe_frame {
                    Ok(frame) => {
                        last_emitted_frame_clone.store(frame.0.frame_id, std::sync::atomic::Ordering::Relaxed);
                        if frame.0.is_terminating {
                            tracing::warn!("terminating frame received");
                            // Stop reading further packets from the transport.
                            packets_in_abort_handle.abort();
                        }
                        #[cfg(feature = "telemetry")]
                        s3.frame_emitted();
                        Some(Ok(frame.0))
                    }
                    Err(SessionError::FrameDiscarded(frame_id)) | Err(SessionError::IncompleteFrame(frame_id)) => {
                        tracing::error!(frame_id, "frame discarded");
                        #[cfg(feature = "telemetry")]
                        s3.frame_discarded();
                        None
                    }
                    Err(err) => {
                        #[cfg(feature = "telemetry")]
                        s3.error();
                        Some(Err(std::io::Error::other(err)))
                    }
                })
            })
            .into_async_read();

        Ok(Self {
            state: Stateless::new(id),
            upstream_frames_in: Box::pin(upstream_frames_in),
            downstream_frames_out: Box::pin(downstream_frames_out),
            // `flush_immediately` selects the flush-after-write mode of `poll_write`.
            write_state: if cfg.flush_immediately {
                WriteState::Writing
            } else {
                WriteState::WriteOnly
            },
        })
    }
}
250
251impl<const C: usize, S: SocketState<C> + Clone + 'static> SessionSocket<C, S> {
252 pub fn new<T>(
255 transport: T,
256 mut state: S,
257 cfg: SessionSocketConfig,
258 #[cfg(feature = "telemetry")] stats: impl SessionTelemetryTracker + Clone + Send + 'static,
259 ) -> Result<Self, SessionError>
260 where
261 T: futures::io::AsyncRead + futures::io::AsyncWrite + Send + Unpin + 'static,
262 {
263 let frame_size = cfg.frame_size.clamp(
265 C,
266 (C - SessionMessage::<C>::SEGMENT_OVERHEAD)
267 * SegmentRequest::<C>::MAX_MISSING_SEGMENTS_PER_FRAME.min((SeqIndicator::MAX + 1) as usize),
268 );
269
270 let mut framed = asynchronous_codec::Framed::new(transport, SessionCodec::<C>);
272
273 framed.set_send_high_water_mark(1.max(cfg.max_buffered_segments * C));
276
277 #[cfg(feature = "telemetry")]
279 let (s0, s1, s2, s3) = { (stats.clone(), stats.clone(), stats.clone(), stats.clone()) };
280
281 let (packets_out, packets_in) = framed.split();
283
284 let inspector = FrameInspector::new(cfg.capacity);
285
286 tracing::debug!(
287 capacity = cfg.control_channel_capacity,
288 "creating session control channel"
289 );
290 let (ctl_tx, ctl_rx) = futures::channel::mpsc::channel(cfg.control_channel_capacity.max(128));
291 state.run(SocketComponents {
292 inspector: Some(inspector.clone()),
293 ctl_tx,
294 })?;
295
296 let (segments_tx, segments_rx) = futures::channel::mpsc::channel(cfg.capacity);
298 let mut st_1 = state.clone();
299 let upstream_frames_in = segments_tx
300 .with(move |segment| {
301 let _span =
302 tracing::debug_span!("SessionSocket::packets_out::segmenter", session_id = st_1.session_id())
303 .entered();
304 if let Err(error) = st_1.segment_sent(&segment) {
307 tracing::debug!(%error, "outgoing segment state update failed");
308 }
309 future::ok::<_, futures::channel::mpsc::SendError>(SessionMessage::<C>::Segment(segment))
310 })
311 .segmenter_with_terminating_segment::<C>(frame_size);
312
313 hopr_async_runtime::prelude::spawn(
316 (ctl_rx, segments_rx)
317 .merge()
318 .map(move |msg| {
319 #[cfg(feature = "telemetry")]
320 s0.outgoing_message(msg.discriminant());
321 Ok(msg)
322 })
323 .forward(packets_out)
324 .map(move |result| match result {
325 Ok(_) => tracing::debug!("outgoing packet processing done"),
326 Err(error) => {
327 tracing::error!(%error, "error while processing outgoing packets")
328 }
329 })
330 .instrument(tracing::debug_span!(
331 "SessionSocket::packets_out",
332 session_id = state.session_id()
333 )),
334 );
335
336 let last_emitted_frame = Arc::new(AtomicU32::new(0));
337 let last_emitted_frame_clone = last_emitted_frame.clone();
338
339 let (packets_in_abort_handle, packets_in_abort_reg) = AbortHandle::new_pair();
340
341 let mut st_1 = state.clone();
343 let mut st_2 = state.clone();
344 let mut st_3 = state.clone();
345
346 let downstream_frames_out = futures::stream::Abortable::new(packets_in, packets_in_abort_reg)
349 .filter_map(move |packet| {
351 let _span = tracing::debug_span!(
352 "SessionSocket::packets_in::pre_reassembly",
353 session_id = st_1.session_id()
354 )
355 .entered();
356 futures::future::ready(match packet {
357 Ok(packet) => {
358 if let Err(error) = st_1.incoming_message(&packet) {
359 tracing::debug!(%error, "incoming message state update failed");
360 }
361 #[cfg(feature = "telemetry")]
362 s1.incoming_message(packet.discriminant());
363
364 packet.try_as_segment().filter(|s| {
366 let last_emitted_id = last_emitted_frame.load(std::sync::atomic::Ordering::Relaxed);
367 if s.frame_id <= last_emitted_id {
368 tracing::warn!(frame_id = s.frame_id, last_emitted_id, "frame already seen");
369 false
370 } else {
371 true
372 }
373 })
374 }
375 Err(error) => {
376 tracing::error!(%error, "unparseable packet");
377 #[cfg(feature = "telemetry")]
378 s1.error();
379 None
380 }
381 })
382 })
383 .reassembler_with_inspector(cfg.frame_timeout, cfg.capacity, inspector)
385 .filter_map(move |maybe_frame| {
387 let _span = tracing::debug_span!(
388 "SessionSocket::packets_in::pre_sequencing",
389 session_id = st_2.session_id()
390 )
391 .entered();
392 futures::future::ready(match maybe_frame {
393 Ok(frame) => {
394 if let Err(error) = st_2.frame_complete(frame.frame_id) {
395 tracing::error!(%error, "frame complete state update failed");
396 }
397 #[cfg(feature = "telemetry")]
398 s2.frame_completed();
399 Some(OrderedFrame(frame))
400 }
401 Err(error) => {
402 tracing::error!(%error, "failed to reassemble frame");
403 #[cfg(feature = "telemetry")]
404 s2.incomplete_frame();
405 None
406 }
407 })
408 })
409 .sequencer(cfg.frame_timeout, cfg.frame_size)
411 .filter_map(move |maybe_frame| {
414 let _span = tracing::debug_span!(
415 "SessionSocket::packets_in::post_sequencing",
416 session_id = st_3.session_id()
417 )
418 .entered();
419 future::ready(match maybe_frame {
421 Ok(frame) => {
422 if let Err(error) = st_3.frame_emitted(frame.0.frame_id) {
423 tracing::error!(%error, "frame received state update failed");
424 }
425 last_emitted_frame_clone.store(frame.0.frame_id, std::sync::atomic::Ordering::Relaxed);
426 if frame.0.is_terminating {
427 tracing::warn!("terminating frame received");
428 packets_in_abort_handle.abort();
429 }
430 #[cfg(feature = "telemetry")]
431 s3.frame_emitted();
432 Some(Ok(frame.0))
433 }
434 Err(SessionError::FrameDiscarded(frame_id)) | Err(SessionError::IncompleteFrame(frame_id)) => {
435 if let Err(error) = st_3.frame_discarded(frame_id) {
436 tracing::error!(%error, "frame discarded state update failed");
437 }
438 #[cfg(feature = "telemetry")]
439 s3.frame_discarded();
440 None }
442 Err(err) => {
443 #[cfg(feature = "telemetry")]
444 s3.error();
445 Some(Err(std::io::Error::other(err)))
446 }
447 })
448 })
449 .into_async_read();
450
451 Ok(Self {
452 state,
453 upstream_frames_in: Box::pin(upstream_frames_in),
454 downstream_frames_out: Box::pin(downstream_frames_out),
455 write_state: if cfg.flush_immediately {
456 WriteState::Writing
457 } else {
458 WriteState::WriteOnly
459 },
460 })
461 }
462}
463
464impl<const C: usize, S: SocketState<C> + Clone + 'static> futures::io::AsyncRead for SessionSocket<C, S> {
465 #[instrument(name = "SessionSocket::poll_read", level = "trace", skip(self, cx, buf), fields(session_id = self.state.session_id(), len = buf.len()))]
466 fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<std::io::Result<usize>> {
467 self.project().downstream_frames_out.as_mut().poll_read(cx, buf)
468 }
469}
470
471impl<const C: usize, S: SocketState<C> + Clone + 'static> futures::io::AsyncWrite for SessionSocket<C, S> {
472 #[instrument(name = "SessionSocket::poll_write", level = "trace", skip(self, cx, buf), fields(session_id = self.state.session_id(), len = buf.len()))]
473 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
474 let this = self.project();
475 loop {
476 match this.write_state {
477 WriteState::WriteOnly => {
478 return this.upstream_frames_in.as_mut().poll_write(cx, buf);
479 }
480 WriteState::Writing => {
481 let len = futures::ready!(this.upstream_frames_in.as_mut().poll_write(cx, buf))?;
482 *this.write_state = WriteState::Flushing(len);
483 }
484 WriteState::Flushing(len) => {
485 let res = futures::ready!(this.upstream_frames_in.as_mut().poll_flush(cx)).map(|_| *len);
486 *this.write_state = WriteState::Writing;
487 return Poll::Ready(res);
488 }
489 }
490 }
491 }
492
493 #[instrument(name = "SessionSocket::poll_flush", level = "trace", skip(self, cx), fields(session_id = self.state.session_id()))]
494 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
495 self.project().upstream_frames_in.as_mut().poll_flush(cx)
496 }
497
498 #[instrument(name = "SessionSocket::poll_close", level = "trace", skip(self, cx), fields(session_id = self.state.session_id()))]
499 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
500 let this = self.project();
501 let _ = this.state.stop();
502 this.upstream_frames_in.as_mut().poll_close(cx)
503 }
504}
505
#[cfg(feature = "runtime-tokio")]
impl<const C: usize, S: SocketState<C> + Clone + 'static> tokio::io::AsyncRead for SessionSocket<C, S> {
    /// Adapts the `futures` read implementation to Tokio's `ReadBuf`-based interface.
    #[instrument(name = "SessionSocket::poll_read", level = "trace", skip(self, cx, buf), fields(session_id = self.state.session_id()))]
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        let unfilled = buf.initialize_unfilled();
        match <Self as futures::AsyncRead>::poll_read(self.as_mut(), cx, unfilled) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Err(error)) => Poll::Ready(Err(error)),
            Poll::Ready(Ok(read)) => {
                // Mark the bytes just read as filled.
                buf.advance(read);
                Poll::Ready(Ok(()))
            }
        }
    }
}
520
#[cfg(feature = "runtime-tokio")]
impl<const C: usize, S: SocketState<C> + Clone + 'static> tokio::io::AsyncWrite for SessionSocket<C, S> {
    /// Delegates to the `futures` write implementation of the socket.
    #[instrument(name = "SessionSocket::poll_write", level = "trace", skip(self, cx, buf), fields(session_id = self.state.session_id(), len = buf.len()))]
    fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, std::io::Error>> {
        <Self as futures::AsyncWrite>::poll_write(self.as_mut(), cx, buf)
    }

    /// Delegates to the `futures` flush implementation of the socket.
    #[instrument(name = "SessionSocket::poll_flush", level = "trace", skip(self, cx), fields(session_id = self.state.session_id()))]
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
        <Self as futures::AsyncWrite>::poll_flush(self.as_mut(), cx)
    }

    /// Tokio's shutdown maps to the `futures` close implementation of the socket.
    #[instrument(name = "SessionSocket::poll_shutdown", level = "trace", skip(self, cx), fields(session_id = self.state.session_id()))]
    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
        <Self as futures::AsyncWrite>::poll_close(self.as_mut(), cx)
    }
}
538
539#[cfg(test)]
540mod tests {
541 use std::collections::HashSet;
542
543 use futures::{AsyncReadExt, AsyncWriteExt};
544 use futures_time::future::FutureExt;
545 use hopr_crypto_packet::prelude::HoprPacket;
546
547 use super::*;
548 #[cfg(feature = "telemetry")]
549 use crate::socket::telemetry::{NoopTracker, tests::TestTelemetryTracker};
550 use crate::{AcknowledgementState, AcknowledgementStateConfig, utils::test::*};
551
552 const MTU: usize = HoprPacket::PAYLOAD_SIZE;
553
554 const FRAME_SIZE: usize = 1500;
555
    /// Size of the test payload: 17 full MTUs plus 271 extra bytes.
    const DATA_SIZE: usize = 17 * MTU + 271;

    /// Alice writes `DATA_SIZE` random bytes over a stateless socket and Bob must read
    /// exactly the same bytes. With the `telemetry` feature, the trackers of both
    /// sides are compared against stored snapshots.
    #[test_log::test(tokio::test)]
    async fn stateless_socket_unidirectional_should_work() -> anyhow::Result<()> {
        // Default network configuration - presumably a fault-free link (see `setup_alice_bob`).
        let (alice, bob) = setup_alice_bob::<MTU>(FaultyNetworkConfig::default(), None, None);

        let sock_cfg = SessionSocketConfig {
            frame_size: FRAME_SIZE,
            ..Default::default()
        };

        #[cfg(feature = "telemetry")]
        let (alice_tracker, bob_tracker) = (TestTelemetryTracker::default(), TestTelemetryTracker::default());

        let mut alice_socket = SessionSocket::<MTU, _>::new_stateless(
            "alice",
            alice,
            sock_cfg,
            #[cfg(feature = "telemetry")]
            alice_tracker.clone(),
        )?;
        let mut bob_socket = SessionSocket::<MTU, _>::new_stateless(
            "bob",
            bob,
            sock_cfg,
            #[cfg(feature = "telemetry")]
            bob_tracker.clone(),
        )?;

        let data = hopr_crypto_random::random_bytes::<DATA_SIZE>();

        // Alice -> Bob; the write must finish within 2 seconds.
        alice_socket
            .write_all(&data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await??;
        alice_socket.flush().await?;

        // Bob must receive the complete payload within 2 seconds.
        let mut bob_data = [0u8; DATA_SIZE];
        bob_socket
            .read_exact(&mut bob_data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await??;
        assert_eq!(data, bob_data);

        alice_socket.close().await?;
        bob_socket.close().await?;

        #[cfg(feature = "telemetry")]
        {
            insta::assert_yaml_snapshot!(alice_tracker);
            insta::assert_yaml_snapshot!(bob_tracker);
        }

        Ok(())
    }
611
    /// Alice writes `DATA_SIZE` random bytes over a socket with the acknowledgement
    /// state and Bob must read exactly the same bytes.
    #[test_log::test(tokio::test)]
    async fn stateful_socket_unidirectional_should_work() -> anyhow::Result<()> {
        // Default network configuration - presumably a fault-free link (see `setup_alice_bob`).
        let (alice, bob) = setup_alice_bob::<MTU>(FaultyNetworkConfig::default(), None, None);

        let sock_cfg = SessionSocketConfig {
            frame_size: FRAME_SIZE,
            ..Default::default()
        };

        // Short latency/delay values keep the test fast.
        let ack_cfg = AcknowledgementStateConfig {
            expected_packet_latency: Duration::from_millis(2),
            acknowledgement_delay: Duration::from_millis(5),
            ..Default::default()
        };

        let mut alice_socket = SessionSocket::<MTU, _>::new(
            alice,
            AcknowledgementState::new("alice", ack_cfg),
            sock_cfg,
            #[cfg(feature = "telemetry")]
            NoopTracker,
        )?;
        let mut bob_socket = SessionSocket::<MTU, _>::new(
            bob,
            AcknowledgementState::new("bob", ack_cfg),
            sock_cfg,
            #[cfg(feature = "telemetry")]
            NoopTracker,
        )?;

        let data = hopr_crypto_random::random_bytes::<DATA_SIZE>();

        // Alice -> Bob; the write must finish within 2 seconds.
        alice_socket
            .write_all(&data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await??;
        alice_socket.flush().await?;

        // Bob must receive the complete payload within 2 seconds.
        let mut bob_data = [0u8; DATA_SIZE];
        bob_socket
            .read_exact(&mut bob_data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await??;
        assert_eq!(data, bob_data);

        alice_socket.close().await?;
        bob_socket.close().await?;

        Ok(())
    }
662
    /// Both peers write `DATA_SIZE` random bytes over stateless sockets and each must
    /// read exactly what the other one sent. With the `telemetry` feature, the
    /// trackers of both sides are compared against stored snapshots.
    #[test_log::test(tokio::test)]
    async fn stateless_socket_bidirectional_should_work() -> anyhow::Result<()> {
        // Default network configuration - presumably a fault-free link (see `setup_alice_bob`).
        let (alice, bob) = setup_alice_bob::<MTU>(FaultyNetworkConfig::default(), None, None);

        let sock_cfg = SessionSocketConfig {
            frame_size: FRAME_SIZE,
            ..Default::default()
        };

        #[cfg(feature = "telemetry")]
        let (alice_tracker, bob_tracker) = (TestTelemetryTracker::default(), TestTelemetryTracker::default());

        let mut alice_socket = SessionSocket::<MTU, _>::new_stateless(
            "alice",
            alice,
            sock_cfg,
            #[cfg(feature = "telemetry")]
            alice_tracker.clone(),
        )?;
        let mut bob_socket = SessionSocket::<MTU, _>::new_stateless(
            "bob",
            bob,
            sock_cfg,
            #[cfg(feature = "telemetry")]
            bob_tracker.clone(),
        )?;

        // Alice -> Bob
        let alice_sent_data = hopr_crypto_random::random_bytes::<DATA_SIZE>();
        alice_socket
            .write_all(&alice_sent_data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await??;
        alice_socket.flush().await?;

        // Bob -> Alice
        let bob_sent_data = hopr_crypto_random::random_bytes::<DATA_SIZE>();
        bob_socket
            .write_all(&bob_sent_data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await??;
        bob_socket.flush().await?;

        // Each side must receive the other side's payload unchanged.
        let mut bob_recv_data = [0u8; DATA_SIZE];
        bob_socket
            .read_exact(&mut bob_recv_data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await??;
        assert_eq!(alice_sent_data, bob_recv_data);

        let mut alice_recv_data = [0u8; DATA_SIZE];
        alice_socket
            .read_exact(&mut alice_recv_data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await??;
        assert_eq!(bob_sent_data, alice_recv_data);

        #[cfg(feature = "telemetry")]
        {
            insta::assert_yaml_snapshot!(alice_tracker);
            insta::assert_yaml_snapshot!(bob_tracker);
        }

        Ok(())
    }
726
    /// Both peers write `DATA_SIZE` random bytes over sockets with the acknowledgement
    /// state and each must read exactly what the other one sent.
    #[test_log::test(tokio::test)]
    async fn stateful_socket_bidirectional_should_work() -> anyhow::Result<()> {
        // Default network configuration - presumably a fault-free link (see `setup_alice_bob`).
        let (alice, bob) = setup_alice_bob::<MTU>(FaultyNetworkConfig::default(), None, None);

        let sock_cfg = SessionSocketConfig {
            frame_size: FRAME_SIZE,
            ..Default::default()
        };

        // Short latency/delay values keep the test fast.
        let ack_cfg = AcknowledgementStateConfig {
            expected_packet_latency: Duration::from_millis(2),
            acknowledgement_delay: Duration::from_millis(10),
            ..Default::default()
        };

        let mut alice_socket = SessionSocket::<MTU, _>::new(
            alice,
            AcknowledgementState::new("alice", ack_cfg),
            sock_cfg,
            #[cfg(feature = "telemetry")]
            NoopTracker,
        )?;
        let mut bob_socket = SessionSocket::<MTU, _>::new(
            bob,
            AcknowledgementState::new("bob", ack_cfg),
            sock_cfg,
            #[cfg(feature = "telemetry")]
            NoopTracker,
        )?;

        // Alice -> Bob
        let alice_sent_data = hopr_crypto_random::random_bytes::<DATA_SIZE>();
        alice_socket
            .write_all(&alice_sent_data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await??;
        alice_socket.flush().await?;

        // Bob -> Alice
        let bob_sent_data = hopr_crypto_random::random_bytes::<DATA_SIZE>();
        bob_socket
            .write_all(&bob_sent_data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await??;
        bob_socket.flush().await?;

        // Each side must receive the other side's payload unchanged.
        let mut bob_recv_data = [0u8; DATA_SIZE];
        bob_socket
            .read_exact(&mut bob_recv_data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await??;
        assert_eq!(alice_sent_data, bob_recv_data);

        let mut alice_recv_data = [0u8; DATA_SIZE];
        alice_socket
            .read_exact(&mut alice_recv_data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await??;
        assert_eq!(bob_sent_data, alice_recv_data);

        Ok(())
    }
790
    /// Same as the plain unidirectional stateless test, but the network is configured
    /// with `mixing_factor: 10`; Bob must still read the payload unchanged and in order.
    #[test_log::test(tokio::test)]
    async fn stateless_socket_unidirectional_should_work_with_mixing() -> anyhow::Result<()> {
        // NOTE(review): `mixing_factor` presumably makes the network reorder packets - confirm
        // against `FaultyNetworkConfig`.
        let network_cfg = FaultyNetworkConfig {
            mixing_factor: 10,
            ..Default::default()
        };

        let (alice, bob) = setup_alice_bob::<MTU>(network_cfg, None, None);

        let sock_cfg = SessionSocketConfig {
            frame_size: FRAME_SIZE,
            ..Default::default()
        };

        let mut alice_socket = SessionSocket::<MTU, _>::new_stateless(
            "alice",
            alice,
            sock_cfg,
            #[cfg(feature = "telemetry")]
            NoopTracker,
        )?;
        let mut bob_socket = SessionSocket::<MTU, _>::new_stateless(
            "bob",
            bob,
            sock_cfg,
            #[cfg(feature = "telemetry")]
            NoopTracker,
        )?;

        // Alice -> Bob; the write must finish within 2 seconds.
        let data = hopr_crypto_random::random_bytes::<DATA_SIZE>();
        alice_socket
            .write_all(&data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await??;
        alice_socket.flush().await?;

        // Bob must receive the complete payload within 2 seconds.
        let mut bob_recv_data = [0u8; DATA_SIZE];
        bob_socket
            .read_exact(&mut bob_recv_data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await??;
        assert_eq!(data, bob_recv_data);

        alice_socket.close().await?;
        bob_socket.close().await?;

        Ok(())
    }
839
    /// Same as the plain unidirectional stateful test, but the network is configured
    /// with `mixing_factor: 10`; Bob must still read the payload unchanged and in order.
    #[test_log::test(tokio::test)]
    async fn stateful_socket_unidirectional_should_work_with_mixing() -> anyhow::Result<()> {
        // NOTE(review): `mixing_factor` presumably makes the network reorder packets - confirm
        // against `FaultyNetworkConfig`.
        let network_cfg = FaultyNetworkConfig {
            mixing_factor: 10,
            ..Default::default()
        };

        let (alice, bob) = setup_alice_bob::<MTU>(network_cfg, None, None);

        let sock_cfg = SessionSocketConfig {
            frame_size: FRAME_SIZE,
            ..Default::default()
        };

        // Short latency/delay values keep the test fast.
        let ack_cfg = AcknowledgementStateConfig {
            expected_packet_latency: Duration::from_millis(2),
            acknowledgement_delay: Duration::from_millis(5),
            ..Default::default()
        };

        let mut alice_socket = SessionSocket::<MTU, _>::new(
            alice,
            AcknowledgementState::new("alice", ack_cfg),
            sock_cfg,
            #[cfg(feature = "telemetry")]
            NoopTracker,
        )?;
        let mut bob_socket = SessionSocket::<MTU, _>::new(
            bob,
            AcknowledgementState::new("bob", ack_cfg),
            sock_cfg,
            #[cfg(feature = "telemetry")]
            NoopTracker,
        )?;

        // Alice -> Bob; the write must finish within 2 seconds.
        let data = hopr_crypto_random::random_bytes::<DATA_SIZE>();
        alice_socket
            .write_all(&data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await??;
        alice_socket.flush().await?;

        // Bob must receive the complete payload within 2 seconds.
        let mut bob_recv_data = [0u8; DATA_SIZE];
        bob_socket
            .read_exact(&mut bob_recv_data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await??;
        assert_eq!(data, bob_recv_data);

        alice_socket.close().await?;
        bob_socket.close().await?;

        Ok(())
    }
894
    /// Same as the plain bidirectional stateless test, but the network is configured
    /// with `mixing_factor: 10`; each side must still read the other side's payload unchanged.
    #[test_log::test(tokio::test)]
    async fn stateless_socket_bidirectional_should_work_with_mixing() -> anyhow::Result<()> {
        // NOTE(review): `mixing_factor` presumably makes the network reorder packets - confirm
        // against `FaultyNetworkConfig`.
        let network_cfg = FaultyNetworkConfig {
            mixing_factor: 10,
            ..Default::default()
        };

        let (alice, bob) = setup_alice_bob::<MTU>(network_cfg, None, None);

        let sock_cfg = SessionSocketConfig {
            frame_size: FRAME_SIZE,
            ..Default::default()
        };

        let mut alice_socket = SessionSocket::<MTU, _>::new_stateless(
            "alice",
            alice,
            sock_cfg,
            #[cfg(feature = "telemetry")]
            NoopTracker,
        )?;
        let mut bob_socket = SessionSocket::<MTU, _>::new_stateless(
            "bob",
            bob,
            sock_cfg,
            #[cfg(feature = "telemetry")]
            NoopTracker,
        )?;

        // Alice -> Bob
        let alice_sent_data = hopr_crypto_random::random_bytes::<DATA_SIZE>();
        alice_socket
            .write_all(&alice_sent_data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await??;
        alice_socket.flush().await?;

        // Bob -> Alice
        let bob_sent_data = hopr_crypto_random::random_bytes::<DATA_SIZE>();
        bob_socket
            .write_all(&bob_sent_data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await??;
        bob_socket.flush().await?;

        // Each side must receive the other side's payload unchanged.
        let mut bob_recv_data = [0u8; DATA_SIZE];
        bob_socket
            .read_exact(&mut bob_recv_data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await??;
        assert_eq!(alice_sent_data, bob_recv_data);

        let mut alice_recv_data = [0u8; DATA_SIZE];
        alice_socket
            .read_exact(&mut alice_recv_data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await??;
        assert_eq!(bob_sent_data, alice_recv_data);

        alice_socket.close().await?;
        bob_socket.close().await?;

        Ok(())
    }
957
    /// Same as the plain bidirectional stateful test, but the network is configured
    /// with `mixing_factor: 10`; each side must still read the other side's payload unchanged.
    #[test_log::test(tokio::test)]
    async fn stateful_socket_bidirectional_should_work_with_mixing() -> anyhow::Result<()> {
        // NOTE(review): `mixing_factor` presumably makes the network reorder packets - confirm
        // against `FaultyNetworkConfig`.
        let network_cfg = FaultyNetworkConfig {
            mixing_factor: 10,
            ..Default::default()
        };

        let (alice, bob) = setup_alice_bob::<MTU>(network_cfg, None, None);

        let sock_cfg = SessionSocketConfig {
            frame_size: FRAME_SIZE,
            ..Default::default()
        };

        // Short latency/delay values keep the test fast.
        let ack_cfg = AcknowledgementStateConfig {
            expected_packet_latency: Duration::from_millis(2),
            acknowledgement_delay: Duration::from_millis(5),
            ..Default::default()
        };

        let mut alice_socket = SessionSocket::<MTU, _>::new(
            alice,
            AcknowledgementState::new("alice", ack_cfg),
            sock_cfg,
            #[cfg(feature = "telemetry")]
            NoopTracker,
        )?;
        let mut bob_socket = SessionSocket::<MTU, _>::new(
            bob,
            AcknowledgementState::new("bob", ack_cfg),
            sock_cfg,
            #[cfg(feature = "telemetry")]
            NoopTracker,
        )?;

        // Alice -> Bob
        let alice_sent_data = hopr_crypto_random::random_bytes::<DATA_SIZE>();
        alice_socket
            .write_all(&alice_sent_data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await??;
        alice_socket.flush().await?;

        // Bob -> Alice
        let bob_sent_data = hopr_crypto_random::random_bytes::<DATA_SIZE>();
        bob_socket
            .write_all(&bob_sent_data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await??;
        bob_socket.flush().await?;

        // Each side must receive the other side's payload unchanged.
        let mut bob_recv_data = [0u8; DATA_SIZE];
        bob_socket
            .read_exact(&mut bob_recv_data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await??;
        assert_eq!(alice_sent_data, bob_recv_data);

        let mut alice_recv_data = [0u8; DATA_SIZE];
        alice_socket
            .read_exact(&mut alice_recv_data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await??;
        assert_eq!(bob_sent_data, alice_recv_data);

        alice_socket.close().await?;
        bob_socket.close().await?;

        Ok(())
    }
1026
1027 #[test_log::test(tokio::test)]
1028 async fn stateless_socket_unidirectional_should_should_skip_missing_frames() -> anyhow::Result<()> {
1029 let (alice, bob) = setup_alice_bob::<MTU>(
1030 FaultyNetworkConfig {
1031 avg_delay: Duration::from_millis(10),
1032 ids_to_drop: HashSet::from_iter([0_usize]),
1033 ..Default::default()
1034 },
1035 None,
1036 None,
1037 );
1038
1039 let alice_cfg = SessionSocketConfig {
1040 frame_size: FRAME_SIZE,
1041 ..Default::default()
1042 };
1043
1044 let bob_cfg = SessionSocketConfig {
1045 frame_size: FRAME_SIZE,
1046 frame_timeout: Duration::from_millis(55),
1047 ..Default::default()
1048 };
1049
1050 #[cfg(feature = "telemetry")]
1051 let (alice_tracker, bob_tracker) = (TestTelemetryTracker::default(), TestTelemetryTracker::default());
1052
1053 let mut alice_socket = SessionSocket::<MTU, _>::new_stateless(
1054 "alice",
1055 alice,
1056 alice_cfg,
1057 #[cfg(feature = "telemetry")]
1058 alice_tracker.clone(),
1059 )?;
1060 let mut bob_socket = SessionSocket::<MTU, _>::new_stateless(
1061 "bob",
1062 bob,
1063 bob_cfg,
1064 #[cfg(feature = "telemetry")]
1065 bob_tracker.clone(),
1066 )?;
1067
1068 let data = hopr_crypto_random::random_bytes::<DATA_SIZE>();
1069 alice_socket
1070 .write_all(&data)
1071 .timeout(futures_time::time::Duration::from_secs(2))
1072 .await??;
1073 alice_socket.flush().await?;
1074 alice_socket.close().await?;
1075
1076 let mut bob_data = Vec::with_capacity(DATA_SIZE);
1077 bob_socket
1078 .read_to_end(&mut bob_data)
1079 .timeout(futures_time::time::Duration::from_secs(2))
1080 .await??;
1081
1082 assert_eq!(data.len() - 1500, bob_data.len());
1084 assert_eq!(&data[1500..], &bob_data);
1085
1086 bob_socket.close().await?;
1087
1088 #[cfg(feature = "telemetry")]
1089 {
1090 insta::assert_yaml_snapshot!(alice_tracker);
1091 insta::assert_yaml_snapshot!(bob_tracker);
1092 }
1093
1094 Ok(())
1095 }
1096
    /// With the acknowledgement state, no data may be lost even though the network is
    /// configured to drop ID 0: Bob must read the complete payload, and Alice must
    /// receive no data at all.
    #[test_log::test(tokio::test)]
    async fn stateful_socket_unidirectional_should_should_not_skip_missing_frames() -> anyhow::Result<()> {
        // NOTE(review): `ids_to_drop` presumably refers to the index of the packet on the
        // wire - confirm against `FaultyNetworkConfig`.
        let (alice, bob) = setup_alice_bob::<MTU>(
            FaultyNetworkConfig {
                avg_delay: Duration::from_millis(10),
                ids_to_drop: HashSet::from_iter([0_usize]),
                ..Default::default()
            },
            None,
            None,
        );

        let alice_cfg = SessionSocketConfig {
            frame_size: FRAME_SIZE,
            ..Default::default()
        };

        // Bob uses a longer frame timeout than the stateless variant of this test;
        // presumably to leave time for the missing data to be recovered.
        let bob_cfg = SessionSocketConfig {
            frame_size: FRAME_SIZE,
            frame_timeout: Duration::from_millis(1000),
            ..Default::default()
        };

        let ack_cfg = AcknowledgementStateConfig {
            expected_packet_latency: Duration::from_millis(10),
            acknowledgement_delay: Duration::from_millis(40),
            ..Default::default()
        };

        let mut alice_socket = SessionSocket::<MTU, _>::new(
            alice,
            AcknowledgementState::new("alice", ack_cfg),
            alice_cfg,
            #[cfg(feature = "telemetry")]
            NoopTracker,
        )?;
        let mut bob_socket = SessionSocket::<MTU, _>::new(
            bob,
            AcknowledgementState::new("bob", ack_cfg),
            bob_cfg,
            #[cfg(feature = "telemetry")]
            NoopTracker,
        )?;

        let data = hopr_crypto_random::random_bytes::<DATA_SIZE>();

        // Alice runs in her own task: she writes the payload and then keeps reading
        // until her incoming stream ends.
        let alice_jh = tokio::spawn(async move {
            alice_socket
                .write_all(&data)
                .timeout(futures_time::time::Duration::from_secs(5))
                .await??;

            alice_socket.flush().await?;

            let mut vec = Vec::new();
            alice_socket.read_to_end(&mut vec).await?;
            alice_socket.close().await?;

            Ok::<_, std::io::Error>(vec)
        });

        // Bob must receive the complete payload within 5 seconds.
        let mut bob_data = [0u8; DATA_SIZE];
        bob_socket
            .read_exact(&mut bob_data)
            .timeout(futures_time::time::Duration::from_secs(5))
            .await??;
        assert_eq!(data, bob_data);

        bob_socket.close().await?;

        // Bob never wrote any payload, so Alice must have read nothing.
        let alice_recv = alice_jh.await??;
        assert!(alice_recv.is_empty());

        Ok(())
    }
1173
1174 #[test_log::test(tokio::test)]
1175 async fn stateless_socket_bidirectional_should_should_skip_missing_frames() -> anyhow::Result<()> {
1176 let (alice, bob) = setup_alice_bob::<MTU>(
1177 FaultyNetworkConfig {
1178 avg_delay: Duration::from_millis(10),
1179 ids_to_drop: HashSet::from_iter([0_usize]),
1180 ..Default::default()
1181 },
1182 None,
1183 None,
1184 );
1185
1186 let alice_cfg = SessionSocketConfig {
1187 frame_size: FRAME_SIZE,
1188 frame_timeout: Duration::from_millis(55),
1189 ..Default::default()
1190 };
1191
1192 let bob_cfg = SessionSocketConfig {
1193 frame_size: FRAME_SIZE,
1194 frame_timeout: Duration::from_millis(55),
1195 ..Default::default()
1196 };
1197
1198 #[cfg(feature = "telemetry")]
1199 let (alice_tracker, bob_tracker) = (TestTelemetryTracker::default(), TestTelemetryTracker::default());
1200
1201 let mut alice_socket = SessionSocket::<MTU, _>::new_stateless(
1202 "alice",
1203 alice,
1204 alice_cfg,
1205 #[cfg(feature = "telemetry")]
1206 alice_tracker.clone(),
1207 )?;
1208 let mut bob_socket = SessionSocket::<MTU, _>::new_stateless(
1209 "bob",
1210 bob,
1211 bob_cfg,
1212 #[cfg(feature = "telemetry")]
1213 bob_tracker.clone(),
1214 )?;
1215
1216 let alice_sent_data = hopr_crypto_random::random_bytes::<DATA_SIZE>();
1217 alice_socket
1218 .write_all(&alice_sent_data)
1219 .timeout(futures_time::time::Duration::from_secs(2))
1220 .await??;
1221 alice_socket.flush().await?;
1222
1223 let bob_sent_data = hopr_crypto_random::random_bytes::<DATA_SIZE>();
1224 bob_socket
1225 .write_all(&bob_sent_data)
1226 .timeout(futures_time::time::Duration::from_secs(2))
1227 .await??;
1228 bob_socket.flush().await?;
1229
1230 alice_socket.close().await?;
1231 bob_socket.close().await?;
1232
1233 let mut alice_recv_data = Vec::with_capacity(DATA_SIZE);
1234 alice_socket
1235 .read_to_end(&mut alice_recv_data)
1236 .timeout(futures_time::time::Duration::from_secs(2))
1237 .await??;
1238
1239 let mut bob_recv_data = Vec::with_capacity(DATA_SIZE);
1240 bob_socket
1241 .read_to_end(&mut bob_recv_data)
1242 .timeout(futures_time::time::Duration::from_secs(2))
1243 .await??;
1244
1245 assert_eq!(bob_sent_data.len() - 1500, alice_recv_data.len());
1247 assert_eq!(&bob_sent_data[1500..], &alice_recv_data);
1248
1249 assert_eq!(alice_sent_data.len() - 1500, bob_recv_data.len());
1250 assert_eq!(&alice_sent_data[1500..], &bob_recv_data);
1251
1252 #[cfg(feature = "telemetry")]
1253 {
1254 insta::assert_yaml_snapshot!(alice_tracker);
1255 insta::assert_yaml_snapshot!(bob_tracker);
1256 }
1257
1258 Ok(())
1259 }
1260
1261 #[test_log::test(tokio::test(flavor = "multi_thread"))]
1263 async fn stateful_socket_bidirectional_should_should_not_skip_missing_frames() -> anyhow::Result<()> {
1264 let (alice, bob) = setup_alice_bob::<MTU>(
1265 FaultyNetworkConfig {
1266 avg_delay: Duration::from_millis(10),
1267 ids_to_drop: HashSet::from_iter([0_usize]),
1268 ..Default::default()
1269 },
1270 None,
1271 None,
1272 );
1273
1274 let alice_cfg = SessionSocketConfig {
1278 frame_size: FRAME_SIZE,
1279 frame_timeout: Duration::from_millis(1000),
1280 ..Default::default()
1281 };
1282
1283 let bob_cfg = SessionSocketConfig {
1284 frame_size: FRAME_SIZE,
1285 frame_timeout: Duration::from_millis(1000),
1286 ..Default::default()
1287 };
1288
1289 let ack_cfg = AcknowledgementStateConfig {
1290 expected_packet_latency: Duration::from_millis(10),
1291 acknowledgement_delay: Duration::from_millis(40),
1292 ..Default::default()
1293 };
1294
1295 let (mut alice_rx, mut alice_tx) = SessionSocket::<MTU, _>::new(
1296 alice,
1297 AcknowledgementState::new("alice", ack_cfg),
1298 alice_cfg,
1299 #[cfg(feature = "telemetry")]
1300 NoopTracker,
1301 )?
1302 .split();
1303
1304 let (mut bob_rx, mut bob_tx) = SessionSocket::<MTU, _>::new(
1305 bob,
1306 AcknowledgementState::new("bob", ack_cfg),
1307 bob_cfg,
1308 #[cfg(feature = "telemetry")]
1309 NoopTracker,
1310 )?
1311 .split();
1312
1313 let alice_sent_data = hopr_crypto_random::random_bytes::<DATA_SIZE>();
1314 let (alice_data_tx, alice_recv_data) = futures::channel::oneshot::channel();
1315 let alice_rx_jh = tokio::spawn(async move {
1316 let mut alice_recv_data = vec![0u8; DATA_SIZE];
1317 alice_rx.read_exact(&mut alice_recv_data).await?;
1318 alice_data_tx
1319 .send(alice_recv_data)
1320 .map_err(|_| std::io::Error::other("tx error"))?;
1321
1322 alice_rx.read_to_end(&mut Vec::new()).await?;
1324 Ok::<_, std::io::Error>(())
1325 });
1326
1327 let bob_sent_data = hopr_crypto_random::random_bytes::<DATA_SIZE>();
1328 let (bob_data_tx, bob_recv_data) = futures::channel::oneshot::channel();
1329 let bob_rx_jh = tokio::spawn(async move {
1330 let mut bob_recv_data = vec![0u8; DATA_SIZE];
1331 bob_rx.read_exact(&mut bob_recv_data).await?;
1332 bob_data_tx
1333 .send(bob_recv_data)
1334 .map_err(|_| std::io::Error::other("tx error"))?;
1335
1336 bob_rx.read_to_end(&mut Vec::new()).await?;
1338 Ok::<_, std::io::Error>(())
1339 });
1340
1341 let alice_tx_jh = tokio::spawn(async move {
1342 alice_tx
1343 .write_all(&alice_sent_data)
1344 .timeout(futures_time::time::Duration::from_secs(2))
1345 .await??;
1346 alice_tx.flush().await?;
1347
1348 let out = alice_recv_data.await.map_err(|_| std::io::Error::other("rx error"))?;
1350 alice_tx.close().await?;
1351 tracing::info!("alice closed");
1352 Ok::<_, std::io::Error>(out)
1353 });
1354
1355 let bob_tx_jh = tokio::spawn(async move {
1356 bob_tx
1357 .write_all(&bob_sent_data)
1358 .timeout(futures_time::time::Duration::from_secs(2))
1359 .await??;
1360 bob_tx.flush().await?;
1361
1362 let out = bob_recv_data.await.map_err(|_| std::io::Error::other("rx error"))?;
1364 bob_tx.close().await?;
1365 tracing::info!("bob closed");
1366 Ok::<_, std::io::Error>(out)
1367 });
1368
1369 let (alice_recv_data, bob_recv_data, a, b) =
1370 futures::future::try_join4(alice_tx_jh, bob_tx_jh, alice_rx_jh, bob_rx_jh)
1371 .timeout(futures_time::time::Duration::from_secs(4))
1372 .await??;
1373
1374 assert_eq!(&alice_sent_data, bob_recv_data?.as_slice());
1375 assert_eq!(&bob_sent_data, alice_recv_data?.as_slice());
1376 assert!(a.is_ok());
1377 assert!(b.is_ok());
1378
1379 Ok(())
1380 }
1381}