1pub mod ack_state;
4pub mod state;
5
6#[cfg(feature = "telemetry")]
8pub mod telemetry;
9
10use std::{
11 pin::Pin,
12 sync::{Arc, atomic::AtomicU32},
13 task::{Context, Poll},
14 time::Duration,
15};
16
17use futures::{FutureExt, SinkExt, StreamExt, TryStreamExt, future, future::AbortHandle};
18use futures_concurrency::stream::Merge;
19use state::{SocketComponents, SocketState, Stateless};
20use tracing::{Instrument, instrument};
21#[cfg(feature = "telemetry")]
22use {
23 strum::IntoDiscriminant,
24 telemetry::{SessionMessageDiscriminants, SessionTelemetryTracker},
25};
26
27use crate::{
28 errors::SessionError,
29 processing::{ReassemblerExt, SegmenterExt, SequencerExt, types::FrameInspector},
30 protocol::{OrderedFrame, SegmentRequest, SeqIndicator, SessionCodec, SessionMessage},
31};
32
33#[derive(Debug, Copy, Clone, Eq, PartialEq, smart_default::SmartDefault)]
35pub struct SessionSocketConfig {
36 #[default(1500)]
46 pub frame_size: usize,
47 #[default(Duration::from_millis(800))]
51 pub frame_timeout: Duration,
52 #[default(0)]
57 pub max_buffered_segments: usize,
58 #[default(8192)]
63 pub capacity: usize,
64 #[default(false)]
68 pub flush_immediately: bool,
69 #[default(2048)]
75 pub control_channel_capacity: usize,
76}
77
/// State machine driving `poll_write` of the socket.
///
/// The derives make the state cheap to copy, comparable and printable,
/// which helps when logging or asserting on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum WriteState {
    /// Writes are passed through as-is; flushing is left entirely to the caller.
    WriteOnly,
    /// A write is attempted next; once it succeeds, a flush follows.
    Writing,
    /// A write of the given number of bytes succeeded and its flush is still pending.
    Flushing(usize),
}
83
84#[pin_project::pin_project(PinnedDrop)]
92pub struct SessionSocket<const C: usize, S: SocketState<C>> {
93 state: S,
94 upstream_frames_in: Pin<Box<dyn futures::io::AsyncWrite + Send>>,
96 downstream_frames_out: Pin<Box<dyn futures::io::AsyncRead + Send>>,
98 write_state: WriteState,
99}
100
101#[pin_project::pinned_drop]
107impl<const C: usize, S: SocketState<C>> PinnedDrop for SessionSocket<C, S> {
108 fn drop(self: Pin<&mut Self>) {
109 let this = self.project();
110 if let Err(error) = this.state.stop() {
111 tracing::debug!(%error, "state.stop on SessionSocket drop failed");
112 }
113 }
114}
115
116impl<const C: usize> SessionSocket<C, Stateless<C>> {
117 pub fn new_stateless<T, I>(
122 id: I,
123 transport: T,
124 cfg: SessionSocketConfig,
125 #[cfg(feature = "telemetry")] stats: impl SessionTelemetryTracker + Clone + Send + 'static,
126 ) -> Result<Self, SessionError>
127 where
128 T: futures::io::AsyncRead + futures::io::AsyncWrite + Send + Unpin + 'static,
129 I: std::fmt::Display + Clone,
130 {
131 let frame_size = cfg.frame_size.clamp(
133 C,
134 (C - SessionMessage::<C>::SEGMENT_OVERHEAD) * (SeqIndicator::MAX + 1) as usize,
135 );
136
137 let mut framed = asynchronous_codec::Framed::new(transport, SessionCodec::<C>);
139
140 framed.set_send_high_water_mark(1.max(cfg.max_buffered_segments * C));
143
144 let (packets_out, packets_in) = framed.split();
146
147 #[cfg(feature = "telemetry")]
149 let (s0, s1, s2, s3) = { (stats.clone(), stats.clone(), stats.clone(), stats.clone()) };
150
151 let upstream_frames_in = packets_out
153 .with(move |segment| {
154 #[cfg(feature = "telemetry")]
155 s0.outgoing_message(SessionMessageDiscriminants::Segment);
156
157 future::ok::<_, SessionError>(SessionMessage::<C>::Segment(segment))
158 })
159 .segmenter_with_terminating_segment::<C>(frame_size);
160
161 let last_emitted_frame = Arc::new(AtomicU32::new(0));
162 let last_emitted_frame_clone = last_emitted_frame.clone();
163
164 let stage1_span = tracing::debug_span!("SessionSocket::packets_in::pre_reassembly", session_id = %id);
166 let stage2_span = tracing::debug_span!("SessionSocket::packets_in::pre_sequencing", session_id = %id);
167 let stage3_span = tracing::debug_span!("SessionSocket::packets_in::post_sequencing", session_id = %id);
168
169 let (packets_in_abort_handle, packets_in_abort_reg) = AbortHandle::new_pair();
170
171 let downstream_frames_out = futures::stream::Abortable::new(packets_in, packets_in_abort_reg)
175 .filter_map(move |packet| {
177 let _span = stage1_span.enter();
178 futures::future::ready(match packet {
179 Ok(packet) => {
180 packet.try_as_segment().filter(|s| {
181 #[cfg(feature = "telemetry")]
182 s1.incoming_message(SessionMessageDiscriminants::Segment);
183
184 let last_emitted_id = last_emitted_frame.load(std::sync::atomic::Ordering::Relaxed);
186 if s.frame_id <= last_emitted_id {
187 tracing::warn!(frame_id = s.frame_id, last_emitted_id, "frame already seen");
188 false
189 } else {
190 true
191 }
192 })
193 }
194 Err(error) => {
195 tracing::error!(%error, "unparseable packet");
196 #[cfg(feature = "telemetry")]
197 s1.error();
198 None
199 }
200 })
201 })
202 .reassembler(cfg.frame_timeout, cfg.capacity)
204 .filter_map(move |maybe_frame| {
206 let _span = stage2_span.enter();
207 futures::future::ready(match maybe_frame {
208 Ok(frame) => {
209 #[cfg(feature = "telemetry")]
210 s2.frame_completed();
211 Some(OrderedFrame(frame))
212 }
213 Err(error) => {
214 tracing::error!(%error, "failed to reassemble frame");
215 #[cfg(feature = "telemetry")]
216 s2.incomplete_frame();
217 None
218 }
219 })
220 })
221 .sequencer(cfg.frame_timeout, cfg.capacity)
223 .filter_map(move |maybe_frame| {
225 let _span = stage3_span.enter();
226 future::ready(match maybe_frame {
227 Ok(frame) => {
228 last_emitted_frame_clone.store(frame.0.frame_id, std::sync::atomic::Ordering::Relaxed);
229 if frame.0.is_terminating {
230 tracing::warn!("terminating frame received");
231 packets_in_abort_handle.abort();
232 }
233 #[cfg(feature = "telemetry")]
234 s3.frame_emitted();
235 Some(Ok(frame.0))
236 }
237 Err(SessionError::FrameDiscarded(frame_id)) | Err(SessionError::IncompleteFrame(frame_id)) => {
239 tracing::error!(frame_id, "frame discarded");
240 #[cfg(feature = "telemetry")]
241 s3.frame_discarded();
242 None
243 }
244 Err(err) => {
245 #[cfg(feature = "telemetry")]
246 s3.error();
247 Some(Err(std::io::Error::other(err)))
248 }
249 })
250 })
251 .into_async_read();
252
253 Ok(Self {
254 state: Stateless::new(id),
255 upstream_frames_in: Box::pin(upstream_frames_in),
256 downstream_frames_out: Box::pin(downstream_frames_out),
257 write_state: if cfg.flush_immediately {
258 WriteState::Writing
259 } else {
260 WriteState::WriteOnly
261 },
262 })
263 }
264}
265
266impl<const C: usize, S: SocketState<C> + Clone + 'static> SessionSocket<C, S> {
267 pub fn new<T>(
270 transport: T,
271 mut state: S,
272 cfg: SessionSocketConfig,
273 #[cfg(feature = "telemetry")] stats: impl SessionTelemetryTracker + Clone + Send + 'static,
274 ) -> Result<Self, SessionError>
275 where
276 T: futures::io::AsyncRead + futures::io::AsyncWrite + Send + Unpin + 'static,
277 {
278 let frame_size = cfg.frame_size.clamp(
280 C,
281 (C - SessionMessage::<C>::SEGMENT_OVERHEAD)
282 * SegmentRequest::<C>::MAX_MISSING_SEGMENTS_PER_FRAME.min((SeqIndicator::MAX + 1) as usize),
283 );
284
285 let mut framed = asynchronous_codec::Framed::new(transport, SessionCodec::<C>);
287
288 framed.set_send_high_water_mark(1.max(cfg.max_buffered_segments * C));
291
292 #[cfg(feature = "telemetry")]
294 let (s0, s1, s2, s3) = { (stats.clone(), stats.clone(), stats.clone(), stats.clone()) };
295
296 let (packets_out, packets_in) = framed.split();
298
299 let inspector = FrameInspector::new(cfg.capacity);
300
301 tracing::debug!(
302 capacity = cfg.control_channel_capacity,
303 "creating session control channel"
304 );
305 let (ctl_tx, ctl_rx) = futures::channel::mpsc::channel(cfg.control_channel_capacity.max(128));
306 state.run(SocketComponents {
307 inspector: Some(inspector.clone()),
308 ctl_tx,
309 })?;
310
311 let (segments_tx, segments_rx) = futures::channel::mpsc::channel(cfg.capacity);
313 let mut st_1 = state.clone();
314 let upstream_frames_in = segments_tx
315 .with(move |segment| {
316 let _span =
317 tracing::debug_span!("SessionSocket::packets_out::segmenter", session_id = st_1.session_id())
318 .entered();
319 if let Err(error) = st_1.segment_sent(&segment) {
322 tracing::debug!(session_id = st_1.session_id(), %error, "outgoing segment state update failed");
323 }
324 future::ok::<_, futures::channel::mpsc::SendError>(SessionMessage::<C>::Segment(segment))
325 })
326 .segmenter_with_terminating_segment::<C>(frame_size);
327
328 hopr_utils::runtime::prelude::spawn(
331 (ctl_rx, segments_rx)
332 .merge()
333 .map(move |msg| {
334 #[cfg(feature = "telemetry")]
335 s0.outgoing_message(msg.discriminant());
336 Ok(msg)
337 })
338 .forward(packets_out)
339 .map(move |result| match result {
340 Ok(_) => tracing::debug!("outgoing packet processing done"),
341 Err(error) => {
342 tracing::error!(%error, "error while processing outgoing packets")
343 }
344 })
345 .instrument(tracing::debug_span!(
346 "SessionSocket::packets_out",
347 session_id = state.session_id()
348 )),
349 );
350
351 let last_emitted_frame = Arc::new(AtomicU32::new(0));
352 let last_emitted_frame_clone = last_emitted_frame.clone();
353
354 let (packets_in_abort_handle, packets_in_abort_reg) = AbortHandle::new_pair();
355
356 let mut st_1 = state.clone();
358 let mut st_2 = state.clone();
359 let mut st_3 = state.clone();
360
361 let downstream_frames_out = futures::stream::Abortable::new(packets_in, packets_in_abort_reg)
364 .filter_map(move |packet| {
366 let _span = tracing::debug_span!(
367 "SessionSocket::packets_in::pre_reassembly",
368 session_id = st_1.session_id()
369 )
370 .entered();
371 futures::future::ready(match packet {
372 Ok(packet) => {
373 if let Err(error) = st_1.incoming_message(&packet) {
374 tracing::debug!(%error, "incoming message state update failed");
375 }
376 #[cfg(feature = "telemetry")]
377 s1.incoming_message(packet.discriminant());
378
379 packet.try_as_segment().filter(|s| {
381 let last_emitted_id = last_emitted_frame.load(std::sync::atomic::Ordering::Relaxed);
382 if s.frame_id <= last_emitted_id {
383 tracing::warn!(frame_id = s.frame_id, last_emitted_id, "frame already seen");
384 false
385 } else {
386 true
387 }
388 })
389 }
390 Err(error) => {
391 tracing::error!(%error, "unparseable packet");
392 #[cfg(feature = "telemetry")]
393 s1.error();
394 None
395 }
396 })
397 })
398 .reassembler_with_inspector(cfg.frame_timeout, cfg.capacity, inspector)
400 .filter_map(move |maybe_frame| {
402 let _span = tracing::debug_span!(
403 "SessionSocket::packets_in::pre_sequencing",
404 session_id = st_2.session_id()
405 )
406 .entered();
407 futures::future::ready(match maybe_frame {
408 Ok(frame) => {
409 if let Err(error) = st_2.frame_complete(frame.frame_id) {
410 tracing::error!(%error, "frame complete state update failed");
411 }
412 #[cfg(feature = "telemetry")]
413 s2.frame_completed();
414 Some(OrderedFrame(frame))
415 }
416 Err(error) => {
417 tracing::error!(%error, "failed to reassemble frame");
418 #[cfg(feature = "telemetry")]
419 s2.incomplete_frame();
420 None
421 }
422 })
423 })
424 .sequencer(cfg.frame_timeout, cfg.frame_size)
426 .filter_map(move |maybe_frame| {
429 let _span = tracing::debug_span!(
430 "SessionSocket::packets_in::post_sequencing",
431 session_id = st_3.session_id()
432 )
433 .entered();
434 future::ready(match maybe_frame {
436 Ok(frame) => {
437 if let Err(error) = st_3.frame_emitted(frame.0.frame_id) {
438 tracing::error!(%error, "frame received state update failed");
439 }
440 last_emitted_frame_clone.store(frame.0.frame_id, std::sync::atomic::Ordering::Relaxed);
441 if frame.0.is_terminating {
442 tracing::warn!("terminating frame received");
443 packets_in_abort_handle.abort();
444 }
445 #[cfg(feature = "telemetry")]
446 s3.frame_emitted();
447 Some(Ok(frame.0))
448 }
449 Err(SessionError::FrameDiscarded(frame_id)) | Err(SessionError::IncompleteFrame(frame_id)) => {
450 if let Err(error) = st_3.frame_discarded(frame_id) {
451 tracing::error!(%error, "frame discarded state update failed");
452 }
453 #[cfg(feature = "telemetry")]
454 s3.frame_discarded();
455 None }
457 Err(err) => {
458 #[cfg(feature = "telemetry")]
459 s3.error();
460 Some(Err(std::io::Error::other(err)))
461 }
462 })
463 })
464 .into_async_read();
465
466 Ok(Self {
467 state,
468 upstream_frames_in: Box::pin(upstream_frames_in),
469 downstream_frames_out: Box::pin(downstream_frames_out),
470 write_state: if cfg.flush_immediately {
471 WriteState::Writing
472 } else {
473 WriteState::WriteOnly
474 },
475 })
476 }
477}
478
479impl<const C: usize, S: SocketState<C> + Clone + 'static> futures::io::AsyncRead for SessionSocket<C, S> {
480 #[instrument(name = "SessionSocket::poll_read", level = "trace", skip(self, cx, buf), fields(session_id = self.state.session_id(), len = buf.len()))]
481 fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<std::io::Result<usize>> {
482 self.project().downstream_frames_out.as_mut().poll_read(cx, buf)
483 }
484}
485
486impl<const C: usize, S: SocketState<C> + Clone + 'static> futures::io::AsyncWrite for SessionSocket<C, S> {
487 #[instrument(name = "SessionSocket::poll_write", level = "trace", skip(self, cx, buf), fields(session_id = self.state.session_id(), len = buf.len()))]
488 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
489 let this = self.project();
490 loop {
491 match this.write_state {
492 WriteState::WriteOnly => {
493 return this.upstream_frames_in.as_mut().poll_write(cx, buf);
494 }
495 WriteState::Writing => {
496 let len = futures::ready!(this.upstream_frames_in.as_mut().poll_write(cx, buf))?;
497 *this.write_state = WriteState::Flushing(len);
498 }
499 WriteState::Flushing(len) => {
500 let res = futures::ready!(this.upstream_frames_in.as_mut().poll_flush(cx)).map(|_| *len);
501 *this.write_state = WriteState::Writing;
502 return Poll::Ready(res);
503 }
504 }
505 }
506 }
507
508 #[instrument(name = "SessionSocket::poll_flush", level = "trace", skip(self, cx), fields(session_id = self.state.session_id()))]
509 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
510 self.project().upstream_frames_in.as_mut().poll_flush(cx)
511 }
512
513 #[instrument(name = "SessionSocket::poll_close", level = "trace", skip(self, cx), fields(session_id = self.state.session_id()))]
514 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
515 let this = self.project();
516 let _ = this.state.stop();
517 this.upstream_frames_in.as_mut().poll_close(cx)
518 }
519}
520
#[cfg(feature = "runtime-tokio")]
impl<const C: usize, S: SocketState<C> + Clone + 'static> tokio::io::AsyncRead for SessionSocket<C, S> {
    /// Adapts the `futures` read implementation of the socket to the Tokio `ReadBuf` interface.
    #[instrument(name = "SessionSocket::poll_read", level = "trace", skip(self, cx, buf), fields(session_id = self.state.session_id()))]
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        // The `futures` implementation needs an initialized byte slice to read into
        let unfilled = buf.initialize_unfilled();
        let polled = futures::AsyncRead::poll_read(self.as_mut(), cx, unfilled);
        match polled {
            Poll::Ready(Ok(read)) => {
                // Mark the bytes that were just read as filled
                buf.advance(read);
                Poll::Ready(Ok(()))
            }
            Poll::Ready(Err(error)) => Poll::Ready(Err(error)),
            Poll::Pending => Poll::Pending,
        }
    }
}
535
#[cfg(feature = "runtime-tokio")]
impl<const C: usize, S: SocketState<C> + Clone + 'static> tokio::io::AsyncWrite for SessionSocket<C, S> {
    /// Delegates to the `futures` write implementation of the socket.
    #[instrument(name = "SessionSocket::poll_write", level = "trace", skip(self, cx, buf), fields(session_id = self.state.session_id(), len = buf.len()))]
    fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, std::io::Error>> {
        <Self as futures::AsyncWrite>::poll_write(self.as_mut(), cx, buf)
    }

    /// Delegates to the `futures` flush implementation of the socket.
    #[instrument(name = "SessionSocket::poll_flush", level = "trace", skip(self, cx), fields(session_id = self.state.session_id()))]
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
        <Self as futures::AsyncWrite>::poll_flush(self.as_mut(), cx)
    }

    /// Shutting down maps to closing the socket via its `futures` implementation.
    #[instrument(name = "SessionSocket::poll_shutdown", level = "trace", skip(self, cx), fields(session_id = self.state.session_id()))]
    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
        <Self as futures::AsyncWrite>::poll_close(self.as_mut(), cx)
    }
}
553
554#[cfg(test)]
555mod tests {
556 use std::collections::HashSet;
557
558 use futures::{AsyncReadExt, AsyncWriteExt};
559 use futures_time::future::FutureExt;
560 use hopr_crypto_packet::prelude::HoprPacket;
561
562 use super::*;
563 #[cfg(feature = "telemetry")]
564 use crate::socket::telemetry::{NoopTracker, tests::TestTelemetryTracker};
565 use crate::{AcknowledgementState, AcknowledgementStateConfig, utils::test::*};
566
567 const MTU: usize = HoprPacket::PAYLOAD_SIZE;
568
569 const FRAME_SIZE: usize = 1500;
570
571 const DATA_SIZE: usize = 17 * MTU + 271; #[test_log::test(tokio::test)]
574 async fn stateless_socket_unidirectional_should_work() -> anyhow::Result<()> {
575 let (alice, bob) = setup_alice_bob::<MTU>(FaultyNetworkConfig::default(), None, None);
576
577 let sock_cfg = SessionSocketConfig {
578 frame_size: FRAME_SIZE,
579 ..Default::default()
580 };
581
582 #[cfg(feature = "telemetry")]
583 let (alice_tracker, bob_tracker) = (TestTelemetryTracker::default(), TestTelemetryTracker::default());
584
585 let mut alice_socket = SessionSocket::<MTU, _>::new_stateless(
586 "alice",
587 alice,
588 sock_cfg,
589 #[cfg(feature = "telemetry")]
590 alice_tracker.clone(),
591 )?;
592 let mut bob_socket = SessionSocket::<MTU, _>::new_stateless(
593 "bob",
594 bob,
595 sock_cfg,
596 #[cfg(feature = "telemetry")]
597 bob_tracker.clone(),
598 )?;
599
600 let data = hopr_types::crypto_random::random_bytes::<DATA_SIZE>();
601
602 alice_socket
603 .write_all(&data)
604 .timeout(futures_time::time::Duration::from_secs(2))
605 .await??;
606 alice_socket.flush().await?;
607
608 let mut bob_data = [0u8; DATA_SIZE];
609 bob_socket
610 .read_exact(&mut bob_data)
611 .timeout(futures_time::time::Duration::from_secs(2))
612 .await??;
613 assert_eq!(data, bob_data);
614
615 alice_socket.close().await?;
616 bob_socket.close().await?;
617
618 #[cfg(feature = "telemetry")]
619 {
620 insta::assert_yaml_snapshot!(alice_tracker);
621 insta::assert_yaml_snapshot!(bob_tracker);
622 }
623
624 Ok(())
625 }
626
627 #[test_log::test(tokio::test)]
628 async fn stateful_socket_unidirectional_should_work() -> anyhow::Result<()> {
629 let (alice, bob) = setup_alice_bob::<MTU>(FaultyNetworkConfig::default(), None, None);
630
631 let sock_cfg = SessionSocketConfig {
632 frame_size: FRAME_SIZE,
633 ..Default::default()
634 };
635
636 let ack_cfg = AcknowledgementStateConfig {
637 expected_packet_latency: Duration::from_millis(2),
638 acknowledgement_delay: Duration::from_millis(5),
639 ..Default::default()
640 };
641
642 let mut alice_socket = SessionSocket::<MTU, _>::new(
643 alice,
644 AcknowledgementState::new("alice", ack_cfg),
645 sock_cfg,
646 #[cfg(feature = "telemetry")]
647 NoopTracker,
648 )?;
649 let mut bob_socket = SessionSocket::<MTU, _>::new(
650 bob,
651 AcknowledgementState::new("bob", ack_cfg),
652 sock_cfg,
653 #[cfg(feature = "telemetry")]
654 NoopTracker,
655 )?;
656
657 let data = hopr_types::crypto_random::random_bytes::<DATA_SIZE>();
658
659 alice_socket
660 .write_all(&data)
661 .timeout(futures_time::time::Duration::from_secs(2))
662 .await??;
663 alice_socket.flush().await?;
664
665 let mut bob_data = [0u8; DATA_SIZE];
666 bob_socket
667 .read_exact(&mut bob_data)
668 .timeout(futures_time::time::Duration::from_secs(2))
669 .await??;
670 assert_eq!(data, bob_data);
671
672 alice_socket.close().await?;
673 bob_socket.close().await?;
674
675 Ok(())
676 }
677
678 #[test_log::test(tokio::test)]
679 async fn stateless_socket_bidirectional_should_work() -> anyhow::Result<()> {
680 let (alice, bob) = setup_alice_bob::<MTU>(FaultyNetworkConfig::default(), None, None);
681
682 let sock_cfg = SessionSocketConfig {
683 frame_size: FRAME_SIZE,
684 ..Default::default()
685 };
686
687 #[cfg(feature = "telemetry")]
688 let (alice_tracker, bob_tracker) = (TestTelemetryTracker::default(), TestTelemetryTracker::default());
689
690 let mut alice_socket = SessionSocket::<MTU, _>::new_stateless(
691 "alice",
692 alice,
693 sock_cfg,
694 #[cfg(feature = "telemetry")]
695 alice_tracker.clone(),
696 )?;
697 let mut bob_socket = SessionSocket::<MTU, _>::new_stateless(
698 "bob",
699 bob,
700 sock_cfg,
701 #[cfg(feature = "telemetry")]
702 bob_tracker.clone(),
703 )?;
704
705 let alice_sent_data = hopr_types::crypto_random::random_bytes::<DATA_SIZE>();
706 alice_socket
707 .write_all(&alice_sent_data)
708 .timeout(futures_time::time::Duration::from_secs(2))
709 .await??;
710 alice_socket.flush().await?;
711
712 let bob_sent_data = hopr_types::crypto_random::random_bytes::<DATA_SIZE>();
713 bob_socket
714 .write_all(&bob_sent_data)
715 .timeout(futures_time::time::Duration::from_secs(2))
716 .await??;
717 bob_socket.flush().await?;
718
719 let mut bob_recv_data = [0u8; DATA_SIZE];
720 bob_socket
721 .read_exact(&mut bob_recv_data)
722 .timeout(futures_time::time::Duration::from_secs(2))
723 .await??;
724 assert_eq!(alice_sent_data, bob_recv_data);
725
726 let mut alice_recv_data = [0u8; DATA_SIZE];
727 alice_socket
728 .read_exact(&mut alice_recv_data)
729 .timeout(futures_time::time::Duration::from_secs(2))
730 .await??;
731 assert_eq!(bob_sent_data, alice_recv_data);
732
733 #[cfg(feature = "telemetry")]
734 {
735 insta::assert_yaml_snapshot!(alice_tracker);
736 insta::assert_yaml_snapshot!(bob_tracker);
737 }
738
739 Ok(())
740 }
741
742 #[test_log::test(tokio::test)]
743 async fn stateful_socket_bidirectional_should_work() -> anyhow::Result<()> {
744 let (alice, bob) = setup_alice_bob::<MTU>(FaultyNetworkConfig::default(), None, None);
745
746 let sock_cfg = SessionSocketConfig {
750 frame_size: FRAME_SIZE,
751 ..Default::default()
752 };
753
754 let ack_cfg = AcknowledgementStateConfig {
755 expected_packet_latency: Duration::from_millis(2),
756 acknowledgement_delay: Duration::from_millis(10),
757 ..Default::default()
758 };
759
760 let mut alice_socket = SessionSocket::<MTU, _>::new(
761 alice,
762 AcknowledgementState::new("alice", ack_cfg),
763 sock_cfg,
764 #[cfg(feature = "telemetry")]
765 NoopTracker,
766 )?;
767 let mut bob_socket = SessionSocket::<MTU, _>::new(
768 bob,
769 AcknowledgementState::new("bob", ack_cfg),
770 sock_cfg,
771 #[cfg(feature = "telemetry")]
772 NoopTracker,
773 )?;
774
775 let alice_sent_data = hopr_types::crypto_random::random_bytes::<DATA_SIZE>();
776 alice_socket
777 .write_all(&alice_sent_data)
778 .timeout(futures_time::time::Duration::from_secs(2))
779 .await??;
780 alice_socket.flush().await?;
781
782 let bob_sent_data = hopr_types::crypto_random::random_bytes::<DATA_SIZE>();
783 bob_socket
784 .write_all(&bob_sent_data)
785 .timeout(futures_time::time::Duration::from_secs(2))
786 .await??;
787 bob_socket.flush().await?;
788
789 let mut bob_recv_data = [0u8; DATA_SIZE];
790 bob_socket
791 .read_exact(&mut bob_recv_data)
792 .timeout(futures_time::time::Duration::from_secs(2))
793 .await??;
794 assert_eq!(alice_sent_data, bob_recv_data);
795
796 let mut alice_recv_data = [0u8; DATA_SIZE];
797 alice_socket
798 .read_exact(&mut alice_recv_data)
799 .timeout(futures_time::time::Duration::from_secs(2))
800 .await??;
801 assert_eq!(bob_sent_data, alice_recv_data);
802
803 Ok(())
804 }
805
806 #[test_log::test(tokio::test)]
807 async fn stateless_socket_unidirectional_should_work_with_mixing() -> anyhow::Result<()> {
808 let network_cfg = FaultyNetworkConfig {
809 mixing_factor: 10,
810 ..Default::default()
811 };
812
813 let (alice, bob) = setup_alice_bob::<MTU>(network_cfg, None, None);
814
815 let sock_cfg = SessionSocketConfig {
816 frame_size: FRAME_SIZE,
817 ..Default::default()
818 };
819
820 let mut alice_socket = SessionSocket::<MTU, _>::new_stateless(
821 "alice",
822 alice,
823 sock_cfg,
824 #[cfg(feature = "telemetry")]
825 NoopTracker,
826 )?;
827 let mut bob_socket = SessionSocket::<MTU, _>::new_stateless(
828 "bob",
829 bob,
830 sock_cfg,
831 #[cfg(feature = "telemetry")]
832 NoopTracker,
833 )?;
834
835 let data = hopr_types::crypto_random::random_bytes::<DATA_SIZE>();
836 alice_socket
837 .write_all(&data)
838 .timeout(futures_time::time::Duration::from_secs(2))
839 .await??;
840 alice_socket.flush().await?;
841
842 let mut bob_recv_data = [0u8; DATA_SIZE];
843 bob_socket
844 .read_exact(&mut bob_recv_data)
845 .timeout(futures_time::time::Duration::from_secs(2))
846 .await??;
847 assert_eq!(data, bob_recv_data);
848
849 alice_socket.close().await?;
850 bob_socket.close().await?;
851
852 Ok(())
853 }
854
855 #[test_log::test(tokio::test)]
856 async fn stateful_socket_unidirectional_should_work_with_mixing() -> anyhow::Result<()> {
857 let network_cfg = FaultyNetworkConfig {
858 mixing_factor: 10,
859 ..Default::default()
860 };
861
862 let (alice, bob) = setup_alice_bob::<MTU>(network_cfg, None, None);
863
864 let sock_cfg = SessionSocketConfig {
865 frame_size: FRAME_SIZE,
866 ..Default::default()
867 };
868
869 let ack_cfg = AcknowledgementStateConfig {
870 expected_packet_latency: Duration::from_millis(2),
871 acknowledgement_delay: Duration::from_millis(5),
872 ..Default::default()
873 };
874
875 let mut alice_socket = SessionSocket::<MTU, _>::new(
876 alice,
877 AcknowledgementState::new("alice", ack_cfg),
878 sock_cfg,
879 #[cfg(feature = "telemetry")]
880 NoopTracker,
881 )?;
882 let mut bob_socket = SessionSocket::<MTU, _>::new(
883 bob,
884 AcknowledgementState::new("bob", ack_cfg),
885 sock_cfg,
886 #[cfg(feature = "telemetry")]
887 NoopTracker,
888 )?;
889
890 let data = hopr_types::crypto_random::random_bytes::<DATA_SIZE>();
891 alice_socket
892 .write_all(&data)
893 .timeout(futures_time::time::Duration::from_secs(2))
894 .await??;
895 alice_socket.flush().await?;
896
897 let mut bob_recv_data = [0u8; DATA_SIZE];
898 bob_socket
899 .read_exact(&mut bob_recv_data)
900 .timeout(futures_time::time::Duration::from_secs(2))
901 .await??;
902 assert_eq!(data, bob_recv_data);
903
904 alice_socket.close().await?;
905 bob_socket.close().await?;
906
907 Ok(())
908 }
909
910 #[test_log::test(tokio::test)]
911 async fn stateless_socket_bidirectional_should_work_with_mixing() -> anyhow::Result<()> {
912 let network_cfg = FaultyNetworkConfig {
913 mixing_factor: 10,
914 ..Default::default()
915 };
916
917 let (alice, bob) = setup_alice_bob::<MTU>(network_cfg, None, None);
918
919 let sock_cfg = SessionSocketConfig {
920 frame_size: FRAME_SIZE,
921 ..Default::default()
922 };
923
924 let mut alice_socket = SessionSocket::<MTU, _>::new_stateless(
925 "alice",
926 alice,
927 sock_cfg,
928 #[cfg(feature = "telemetry")]
929 NoopTracker,
930 )?;
931 let mut bob_socket = SessionSocket::<MTU, _>::new_stateless(
932 "bob",
933 bob,
934 sock_cfg,
935 #[cfg(feature = "telemetry")]
936 NoopTracker,
937 )?;
938
939 let alice_sent_data = hopr_types::crypto_random::random_bytes::<DATA_SIZE>();
940 alice_socket
941 .write_all(&alice_sent_data)
942 .timeout(futures_time::time::Duration::from_secs(2))
943 .await??;
944 alice_socket.flush().await?;
945
946 let bob_sent_data = hopr_types::crypto_random::random_bytes::<DATA_SIZE>();
947 bob_socket
948 .write_all(&bob_sent_data)
949 .timeout(futures_time::time::Duration::from_secs(2))
950 .await??;
951 bob_socket.flush().await?;
952
953 let mut bob_recv_data = [0u8; DATA_SIZE];
954 bob_socket
955 .read_exact(&mut bob_recv_data)
956 .timeout(futures_time::time::Duration::from_secs(2))
957 .await??;
958 assert_eq!(alice_sent_data, bob_recv_data);
959
960 let mut alice_recv_data = [0u8; DATA_SIZE];
961 alice_socket
962 .read_exact(&mut alice_recv_data)
963 .timeout(futures_time::time::Duration::from_secs(2))
964 .await??;
965 assert_eq!(bob_sent_data, alice_recv_data);
966
967 alice_socket.close().await?;
968 bob_socket.close().await?;
969
970 Ok(())
971 }
972
973 #[test_log::test(tokio::test)]
974 async fn stateful_socket_bidirectional_should_work_with_mixing() -> anyhow::Result<()> {
975 let network_cfg = FaultyNetworkConfig {
976 mixing_factor: 10,
977 ..Default::default()
978 };
979
980 let (alice, bob) = setup_alice_bob::<MTU>(network_cfg, None, None);
981
982 let sock_cfg = SessionSocketConfig {
983 frame_size: FRAME_SIZE,
984 ..Default::default()
985 };
986
987 let ack_cfg = AcknowledgementStateConfig {
988 expected_packet_latency: Duration::from_millis(2),
989 acknowledgement_delay: Duration::from_millis(5),
990 ..Default::default()
991 };
992
993 let mut alice_socket = SessionSocket::<MTU, _>::new(
994 alice,
995 AcknowledgementState::new("alice", ack_cfg),
996 sock_cfg,
997 #[cfg(feature = "telemetry")]
998 NoopTracker,
999 )?;
1000 let mut bob_socket = SessionSocket::<MTU, _>::new(
1001 bob,
1002 AcknowledgementState::new("bob", ack_cfg),
1003 sock_cfg,
1004 #[cfg(feature = "telemetry")]
1005 NoopTracker,
1006 )?;
1007
1008 let alice_sent_data = hopr_types::crypto_random::random_bytes::<DATA_SIZE>();
1009 alice_socket
1010 .write_all(&alice_sent_data)
1011 .timeout(futures_time::time::Duration::from_secs(2))
1012 .await??;
1013 alice_socket.flush().await?;
1014
1015 let bob_sent_data = hopr_types::crypto_random::random_bytes::<DATA_SIZE>();
1016 bob_socket
1017 .write_all(&bob_sent_data)
1018 .timeout(futures_time::time::Duration::from_secs(2))
1019 .await??;
1020 bob_socket.flush().await?;
1021
1022 let mut bob_recv_data = [0u8; DATA_SIZE];
1023 bob_socket
1024 .read_exact(&mut bob_recv_data)
1025 .timeout(futures_time::time::Duration::from_secs(2))
1026 .await??;
1027 assert_eq!(alice_sent_data, bob_recv_data);
1028
1029 let mut alice_recv_data = [0u8; DATA_SIZE];
1030 alice_socket
1031 .read_exact(&mut alice_recv_data)
1032 .timeout(futures_time::time::Duration::from_secs(2))
1033 .await??;
1034 assert_eq!(bob_sent_data, alice_recv_data);
1035
1036 alice_socket.close().await?;
1037 bob_socket.close().await?;
1038
1039 Ok(())
1040 }
1041
1042 #[test_log::test(tokio::test)]
1043 async fn stateless_socket_unidirectional_should_should_skip_missing_frames() -> anyhow::Result<()> {
1044 let (alice, bob) = setup_alice_bob::<MTU>(
1045 FaultyNetworkConfig {
1046 avg_delay: Duration::from_millis(10),
1047 ids_to_drop: HashSet::from_iter([0_usize]),
1048 ..Default::default()
1049 },
1050 None,
1051 None,
1052 );
1053
1054 let alice_cfg = SessionSocketConfig {
1055 frame_size: FRAME_SIZE,
1056 ..Default::default()
1057 };
1058
1059 let bob_cfg = SessionSocketConfig {
1060 frame_size: FRAME_SIZE,
1061 frame_timeout: Duration::from_millis(55),
1062 ..Default::default()
1063 };
1064
1065 #[cfg(feature = "telemetry")]
1066 let (alice_tracker, bob_tracker) = (TestTelemetryTracker::default(), TestTelemetryTracker::default());
1067
1068 let mut alice_socket = SessionSocket::<MTU, _>::new_stateless(
1069 "alice",
1070 alice,
1071 alice_cfg,
1072 #[cfg(feature = "telemetry")]
1073 alice_tracker.clone(),
1074 )?;
1075 let mut bob_socket = SessionSocket::<MTU, _>::new_stateless(
1076 "bob",
1077 bob,
1078 bob_cfg,
1079 #[cfg(feature = "telemetry")]
1080 bob_tracker.clone(),
1081 )?;
1082
1083 let data = hopr_types::crypto_random::random_bytes::<DATA_SIZE>();
1084 alice_socket
1085 .write_all(&data)
1086 .timeout(futures_time::time::Duration::from_secs(2))
1087 .await??;
1088 alice_socket.flush().await?;
1089 alice_socket.close().await?;
1090
1091 let mut bob_data = Vec::with_capacity(DATA_SIZE);
1092 bob_socket
1093 .read_to_end(&mut bob_data)
1094 .timeout(futures_time::time::Duration::from_secs(2))
1095 .await??;
1096
1097 assert_eq!(data.len() - 1500, bob_data.len());
1099 assert_eq!(&data[1500..], &bob_data);
1100
1101 bob_socket.close().await?;
1102
1103 #[cfg(feature = "telemetry")]
1104 {
1105 insta::assert_yaml_snapshot!(alice_tracker);
1106 insta::assert_yaml_snapshot!(bob_tracker);
1107 }
1108
1109 Ok(())
1110 }
1111
1112 #[test_log::test(tokio::test)]
1113 async fn stateful_socket_unidirectional_should_should_not_skip_missing_frames() -> anyhow::Result<()> {
1114 let (alice, bob) = setup_alice_bob::<MTU>(
1115 FaultyNetworkConfig {
1116 avg_delay: Duration::from_millis(10),
1117 ids_to_drop: HashSet::from_iter([0_usize]),
1118 ..Default::default()
1119 },
1120 None,
1121 None,
1122 );
1123
1124 let alice_cfg = SessionSocketConfig {
1125 frame_size: FRAME_SIZE,
1126 ..Default::default()
1127 };
1128
1129 let bob_cfg = SessionSocketConfig {
1130 frame_size: FRAME_SIZE,
1131 frame_timeout: Duration::from_millis(1000),
1132 ..Default::default()
1133 };
1134
1135 let ack_cfg = AcknowledgementStateConfig {
1136 expected_packet_latency: Duration::from_millis(10),
1137 acknowledgement_delay: Duration::from_millis(40),
1138 ..Default::default()
1139 };
1140
1141 let mut alice_socket = SessionSocket::<MTU, _>::new(
1142 alice,
1143 AcknowledgementState::new("alice", ack_cfg),
1144 alice_cfg,
1145 #[cfg(feature = "telemetry")]
1146 NoopTracker,
1147 )?;
1148 let mut bob_socket = SessionSocket::<MTU, _>::new(
1149 bob,
1150 AcknowledgementState::new("bob", ack_cfg),
1151 bob_cfg,
1152 #[cfg(feature = "telemetry")]
1153 NoopTracker,
1154 )?;
1155
1156 let data = hopr_types::crypto_random::random_bytes::<DATA_SIZE>();
1157
1158 let alice_jh = tokio::spawn(async move {
1159 alice_socket
1160 .write_all(&data)
1161 .timeout(futures_time::time::Duration::from_secs(5))
1162 .await??;
1163
1164 alice_socket.flush().await?;
1165
1166 let mut vec = Vec::new();
1168 alice_socket.read_to_end(&mut vec).await?;
1169 alice_socket.close().await?;
1170
1171 Ok::<_, std::io::Error>(vec)
1172 });
1173
1174 let mut bob_data = [0u8; DATA_SIZE];
1175 bob_socket
1176 .read_exact(&mut bob_data)
1177 .timeout(futures_time::time::Duration::from_secs(5))
1178 .await??;
1179 assert_eq!(data, bob_data);
1180
1181 bob_socket.close().await?;
1182
1183 let alice_recv = alice_jh.await??;
1184 assert!(alice_recv.is_empty());
1185
1186 Ok(())
1187 }
1188
1189 #[test_log::test(tokio::test)]
1190 async fn stateless_socket_bidirectional_should_should_skip_missing_frames() -> anyhow::Result<()> {
1191 let (alice, bob) = setup_alice_bob::<MTU>(
1192 FaultyNetworkConfig {
1193 avg_delay: Duration::from_millis(10),
1194 ids_to_drop: HashSet::from_iter([0_usize]),
1195 ..Default::default()
1196 },
1197 None,
1198 None,
1199 );
1200
1201 let alice_cfg = SessionSocketConfig {
1202 frame_size: FRAME_SIZE,
1203 frame_timeout: Duration::from_millis(55),
1204 ..Default::default()
1205 };
1206
1207 let bob_cfg = SessionSocketConfig {
1208 frame_size: FRAME_SIZE,
1209 frame_timeout: Duration::from_millis(55),
1210 ..Default::default()
1211 };
1212
1213 #[cfg(feature = "telemetry")]
1214 let (alice_tracker, bob_tracker) = (TestTelemetryTracker::default(), TestTelemetryTracker::default());
1215
1216 let mut alice_socket = SessionSocket::<MTU, _>::new_stateless(
1217 "alice",
1218 alice,
1219 alice_cfg,
1220 #[cfg(feature = "telemetry")]
1221 alice_tracker.clone(),
1222 )?;
1223 let mut bob_socket = SessionSocket::<MTU, _>::new_stateless(
1224 "bob",
1225 bob,
1226 bob_cfg,
1227 #[cfg(feature = "telemetry")]
1228 bob_tracker.clone(),
1229 )?;
1230
1231 let alice_sent_data = hopr_types::crypto_random::random_bytes::<DATA_SIZE>();
1232 alice_socket
1233 .write_all(&alice_sent_data)
1234 .timeout(futures_time::time::Duration::from_secs(2))
1235 .await??;
1236 alice_socket.flush().await?;
1237
1238 let bob_sent_data = hopr_types::crypto_random::random_bytes::<DATA_SIZE>();
1239 bob_socket
1240 .write_all(&bob_sent_data)
1241 .timeout(futures_time::time::Duration::from_secs(2))
1242 .await??;
1243 bob_socket.flush().await?;
1244
1245 alice_socket.close().await?;
1246 bob_socket.close().await?;
1247
1248 let mut alice_recv_data = Vec::with_capacity(DATA_SIZE);
1249 alice_socket
1250 .read_to_end(&mut alice_recv_data)
1251 .timeout(futures_time::time::Duration::from_secs(2))
1252 .await??;
1253
1254 let mut bob_recv_data = Vec::with_capacity(DATA_SIZE);
1255 bob_socket
1256 .read_to_end(&mut bob_recv_data)
1257 .timeout(futures_time::time::Duration::from_secs(2))
1258 .await??;
1259
1260 assert_eq!(bob_sent_data.len() - 1500, alice_recv_data.len());
1262 assert_eq!(&bob_sent_data[1500..], &alice_recv_data);
1263
1264 assert_eq!(alice_sent_data.len() - 1500, bob_recv_data.len());
1265 assert_eq!(&alice_sent_data[1500..], &bob_recv_data);
1266
1267 #[cfg(feature = "telemetry")]
1268 {
1269 insta::assert_yaml_snapshot!(alice_tracker);
1270 insta::assert_yaml_snapshot!(bob_tracker);
1271 }
1272
1273 Ok(())
1274 }
1275
1276 #[test_log::test(tokio::test(flavor = "multi_thread"))]
1278 async fn stateful_socket_bidirectional_should_should_not_skip_missing_frames() -> anyhow::Result<()> {
1279 let (alice, bob) = setup_alice_bob::<MTU>(
1280 FaultyNetworkConfig {
1281 avg_delay: Duration::from_millis(10),
1282 ids_to_drop: HashSet::from_iter([0_usize]),
1283 ..Default::default()
1284 },
1285 None,
1286 None,
1287 );
1288
1289 let alice_cfg = SessionSocketConfig {
1293 frame_size: FRAME_SIZE,
1294 frame_timeout: Duration::from_millis(1000),
1295 ..Default::default()
1296 };
1297
1298 let bob_cfg = SessionSocketConfig {
1299 frame_size: FRAME_SIZE,
1300 frame_timeout: Duration::from_millis(1000),
1301 ..Default::default()
1302 };
1303
1304 let ack_cfg = AcknowledgementStateConfig {
1305 expected_packet_latency: Duration::from_millis(10),
1306 acknowledgement_delay: Duration::from_millis(40),
1307 ..Default::default()
1308 };
1309
1310 let (mut alice_rx, mut alice_tx) = SessionSocket::<MTU, _>::new(
1311 alice,
1312 AcknowledgementState::new("alice", ack_cfg),
1313 alice_cfg,
1314 #[cfg(feature = "telemetry")]
1315 NoopTracker,
1316 )?
1317 .split();
1318
1319 let (mut bob_rx, mut bob_tx) = SessionSocket::<MTU, _>::new(
1320 bob,
1321 AcknowledgementState::new("bob", ack_cfg),
1322 bob_cfg,
1323 #[cfg(feature = "telemetry")]
1324 NoopTracker,
1325 )?
1326 .split();
1327
1328 let alice_sent_data = hopr_types::crypto_random::random_bytes::<DATA_SIZE>();
1329 let (alice_data_tx, alice_recv_data) = futures::channel::oneshot::channel();
1330 let alice_rx_jh = tokio::spawn(async move {
1331 let mut alice_recv_data = vec![0u8; DATA_SIZE];
1332 alice_rx.read_exact(&mut alice_recv_data).await?;
1333 alice_data_tx
1334 .send(alice_recv_data)
1335 .map_err(|_| std::io::Error::other("tx error"))?;
1336
1337 alice_rx.read_to_end(&mut Vec::new()).await?;
1339 Ok::<_, std::io::Error>(())
1340 });
1341
1342 let bob_sent_data = hopr_types::crypto_random::random_bytes::<DATA_SIZE>();
1343 let (bob_data_tx, bob_recv_data) = futures::channel::oneshot::channel();
1344 let bob_rx_jh = tokio::spawn(async move {
1345 let mut bob_recv_data = vec![0u8; DATA_SIZE];
1346 bob_rx.read_exact(&mut bob_recv_data).await?;
1347 bob_data_tx
1348 .send(bob_recv_data)
1349 .map_err(|_| std::io::Error::other("tx error"))?;
1350
1351 bob_rx.read_to_end(&mut Vec::new()).await?;
1353 Ok::<_, std::io::Error>(())
1354 });
1355
1356 let alice_tx_jh = tokio::spawn(async move {
1357 alice_tx
1358 .write_all(&alice_sent_data)
1359 .timeout(futures_time::time::Duration::from_secs(2))
1360 .await??;
1361 alice_tx.flush().await?;
1362
1363 let out = alice_recv_data.await.map_err(|_| std::io::Error::other("rx error"))?;
1365 alice_tx.close().await?;
1366 tracing::info!("alice closed");
1367 Ok::<_, std::io::Error>(out)
1368 });
1369
1370 let bob_tx_jh = tokio::spawn(async move {
1371 bob_tx
1372 .write_all(&bob_sent_data)
1373 .timeout(futures_time::time::Duration::from_secs(2))
1374 .await??;
1375 bob_tx.flush().await?;
1376
1377 let out = bob_recv_data.await.map_err(|_| std::io::Error::other("rx error"))?;
1379 bob_tx.close().await?;
1380 tracing::info!("bob closed");
1381 Ok::<_, std::io::Error>(out)
1382 });
1383
1384 let (alice_recv_data, bob_recv_data, a, b) =
1385 futures::future::try_join4(alice_tx_jh, bob_tx_jh, alice_rx_jh, bob_rx_jh)
1386 .timeout(futures_time::time::Duration::from_secs(4))
1387 .await??;
1388
1389 assert_eq!(&alice_sent_data, bob_recv_data?.as_slice());
1390 assert_eq!(&bob_sent_data, alice_recv_data?.as_slice());
1391 assert!(a.is_ok());
1392 assert!(b.is_ok());
1393
1394 Ok(())
1395 }
1396}