1pub mod ack_state;
4pub mod state;
5
6use std::{
7 pin::Pin,
8 sync::{Arc, atomic::AtomicU32},
9 task::{Context, Poll},
10 time::Duration,
11};
12
13use futures::{FutureExt, SinkExt, StreamExt, TryStreamExt, future, future::AbortHandle};
14use futures_concurrency::stream::Merge;
15use state::SocketState;
16use tracing::{Instrument, instrument};
17
18use crate::{
19 errors::SessionError,
20 processing::{ReassemblerExt, SegmenterExt, SequencerExt, types::FrameInspector},
21 protocol::{OrderedFrame, SegmentRequest, SeqIndicator, SessionCodec, SessionMessage},
22 socket::state::{SocketComponents, Stateless},
23};
24
/// Configuration of a [`SessionSocket`], shared by the stateless and the stateful constructor.
///
/// Default values (provided through `SmartDefault`): `frame_size = 1500`,
/// `frame_timeout = 800 ms`, `max_buffered_segments = 0`, `capacity = 8192`,
/// `flush_immediately = false`.
#[derive(Debug, Copy, Clone, Eq, PartialEq, smart_default::SmartDefault)]
pub struct SessionSocketConfig {
    /// Size in bytes of the frames that written data is cut into before segmentation.
    ///
    /// Both constructors clamp this value to at least `C`; the upper bound is derived from
    /// the per-segment overhead and the maximum number of segments a frame may consist of
    /// (the exact bound differs between the stateless and the stateful constructor).
    #[default(1500)]
    pub frame_size: usize,
    /// Timeout handed to the incoming frame reassembler and frame sequencer.
    #[default(Duration::from_millis(800))]
    pub frame_timeout: Duration,
    /// Multiplied by `C` and passed to the framed transport's `set_send_high_water_mark`
    /// (a value of 0 results in a high-water mark of 1).
    #[default(0)]
    pub max_buffered_segments: usize,
    /// Capacity passed to the incoming frame reassembler (and, in the stateless constructor,
    /// to the sequencer). Stateful sockets also size the frame inspector and the outgoing
    /// segment channel with it.
    #[default(8192)]
    pub capacity: usize,

    /// When `true`, every `poll_write` is followed by a flush of the outgoing sink before the
    /// number of written bytes is reported to the caller.
    #[default(false)]
    pub flush_immediately: bool,
}
63
/// Internal write mode / progress of [`SessionSocket`]'s `poll_write`.
enum WriteState {
    /// Writes are forwarded as-is; no flush is performed after a write
    /// (selected when `flush_immediately` is `false`).
    WriteOnly,
    /// Flush-after-write mode, idle: the next `poll_write` writes the buffer and then
    /// moves to `Flushing`.
    Writing,
    /// Flush-after-write mode: a write of the contained number of bytes has completed and
    /// the outgoing sink is being flushed. Once the flush finishes, `poll_write` returns this
    /// byte count and the state goes back to `Writing`.
    Flushing(usize),
}
69
/// Byte-stream socket on top of a segment-based session transport.
///
/// Data written to the socket is cut into frames, segmented and sent over the transport;
/// data read from the socket is the payload of incoming frames after reassembly and
/// sequencing. `C` is the segment/packet size parameter of the session codec (the tests
/// instantiate it with `HoprPacket::PAYLOAD_SIZE`). `S` is the socket state: `Stateless<C>`
/// for [`SessionSocket::new_stateless`] or a `SocketState` implementation for
/// [`SessionSocket::new`].
#[pin_project::pin_project]
pub struct SessionSocket<const C: usize, S> {
    // Write side: bytes written here are segmented and sent towards the transport
    // (directly for stateless sockets, via a channel drained by a spawned task for stateful ones).
    upstream_frames_in: Pin<Box<dyn futures::io::AsyncWrite + Send>>,
    // Read side: yields the payload of reassembled and sequenced incoming frames.
    downstream_frames_out: Pin<Box<dyn futures::io::AsyncRead + Send>>,
    // Socket state; `stop()` is invoked when the socket is closed.
    state: S,
    // Whether writes are followed by a flush, and the progress of such a write.
    write_state: WriteState,
}
86
87impl<const C: usize> SessionSocket<C, Stateless<C>> {
88 pub fn new_stateless<T, I>(id: I, transport: T, cfg: SessionSocketConfig) -> Result<Self, SessionError>
93 where
94 T: futures::io::AsyncRead + futures::io::AsyncWrite + Send + Unpin + 'static,
95 I: std::fmt::Display + Clone,
96 {
97 let frame_size = cfg.frame_size.clamp(
99 C,
100 (C - SessionMessage::<C>::SEGMENT_OVERHEAD) * (SeqIndicator::MAX + 1) as usize,
101 );
102
103 let mut framed = asynchronous_codec::Framed::new(transport, SessionCodec::<C>);
105
106 framed.set_send_high_water_mark(1.max(cfg.max_buffered_segments * C));
109
110 let (packets_out, packets_in) = framed.split();
112
113 let upstream_frames_in = packets_out
115 .with(|segment| future::ok::<_, SessionError>(SessionMessage::<C>::Segment(segment)))
116 .segmenter_with_terminating_segment::<C>(frame_size);
117
118 let last_emitted_frame = Arc::new(AtomicU32::new(0));
119 let last_emitted_frame_clone = last_emitted_frame.clone();
120
121 let session_id_1 = id.to_string();
122 let session_id_2 = id.to_string();
123 let session_id_3 = id.to_string();
124
125 let (packets_in_abort_handle, packets_in_abort_reg) = AbortHandle::new_pair();
126
127 let downstream_frames_out = futures::stream::Abortable::new(packets_in, packets_in_abort_reg)
131 .filter_map(move |packet| {
133 futures::future::ready(match packet {
134 Ok(packet) => packet.try_as_segment().filter(|s| {
135 let last_emitted_id = last_emitted_frame.load(std::sync::atomic::Ordering::Relaxed);
137 if s.frame_id <= last_emitted_id {
138 tracing::warn!(frame_id = s.frame_id, last_emitted_id, "frame already seen");
139 false
140 } else {
141 true
142 }
143 }),
144 Err(error) => {
145 tracing::error!(%error, "unparseable packet");
146 None
147 }
148 })
149 .instrument(tracing::debug_span!(
150 "SessionSocket::packets_in::pre_reassembly",
151 session_id = session_id_1
152 ))
153 })
154 .reassembler(cfg.frame_timeout, cfg.capacity)
156 .filter_map(move |maybe_frame| {
158 futures::future::ready(match maybe_frame {
159 Ok(frame) => Some(OrderedFrame(frame)),
160 Err(error) => {
161 tracing::error!(%error, "failed to reassemble frame");
162 None
163 }
164 })
165 .instrument(tracing::debug_span!(
166 "SessionSocket::packets_in::pre_sequencing",
167 session_id = session_id_2
168 ))
169 })
170 .sequencer(cfg.frame_timeout, cfg.capacity)
172 .filter_map(move |maybe_frame| {
174 future::ready(match maybe_frame {
175 Ok(frame) => {
176 last_emitted_frame_clone.store(frame.0.frame_id, std::sync::atomic::Ordering::Relaxed);
177 if frame.0.is_terminating {
178 tracing::warn!("terminating frame received");
179 packets_in_abort_handle.abort();
180 }
181 Some(Ok(frame.0))
182 }
183 Err(SessionError::FrameDiscarded(frame_id)) | Err(SessionError::IncompleteFrame(frame_id)) => {
185 tracing::error!(frame_id, "frame discarded");
186 None
187 }
188 Err(err) => Some(Err(std::io::Error::other(err))),
189 })
190 .instrument(tracing::debug_span!(
191 "SessionSocket::packets_in::post_sequencing",
192 session_id = session_id_3
193 ))
194 })
195 .into_async_read();
196
197 Ok(Self {
198 state: Stateless::new(id),
199 upstream_frames_in: Box::pin(upstream_frames_in),
200 downstream_frames_out: Box::pin(downstream_frames_out),
201 write_state: if cfg.flush_immediately {
202 WriteState::Writing
203 } else {
204 WriteState::WriteOnly
205 },
206 })
207 }
208}
209
210impl<const C: usize, S: SocketState<C> + Clone + 'static> SessionSocket<C, S> {
211 pub fn new<T>(transport: T, mut state: S, cfg: SessionSocketConfig) -> Result<Self, SessionError>
214 where
215 T: futures::io::AsyncRead + futures::io::AsyncWrite + Send + Unpin + 'static,
216 {
217 let frame_size = cfg.frame_size.clamp(
219 C,
220 (C - SessionMessage::<C>::SEGMENT_OVERHEAD)
221 * SegmentRequest::<C>::MAX_MISSING_SEGMENTS_PER_FRAME.min((SeqIndicator::MAX + 1) as usize),
222 );
223
224 let mut framed = asynchronous_codec::Framed::new(transport, SessionCodec::<C>);
226
227 framed.set_send_high_water_mark(1.max(cfg.max_buffered_segments * C));
230
231 let (packets_out, packets_in) = framed.split();
233
234 let inspector = FrameInspector::new(cfg.capacity);
235
236 let ctl_channel_capacity = std::env::var("HOPR_INTERNAL_SESSION_CTL_CHANNEL_CAPACITY")
237 .ok()
238 .and_then(|s| s.trim().parse::<usize>().ok())
239 .filter(|&c| c > 0)
240 .unwrap_or(2048);
241
242 tracing::debug!(capacity = ctl_channel_capacity, "Creating session control channel");
243 let (ctl_tx, ctl_rx) = futures::channel::mpsc::channel(ctl_channel_capacity);
244 state.run(SocketComponents {
245 inspector: Some(inspector.clone()),
246 ctl_tx,
247 })?;
248
249 let (segments_tx, segments_rx) = futures::channel::mpsc::channel(cfg.capacity);
251 let mut st_1 = state.clone();
252 let upstream_frames_in = segments_tx
253 .with(move |segment| {
254 if let Err(error) = st_1.segment_sent(&segment) {
257 tracing::debug!(%error, "outgoing segment state update failed");
258 }
259 future::ok::<_, futures::channel::mpsc::SendError>(SessionMessage::<C>::Segment(segment))
260 })
261 .segmenter_with_terminating_segment::<C>(frame_size);
262
263 hopr_async_runtime::prelude::spawn(
266 (ctl_rx, segments_rx)
267 .merge()
268 .map(Ok)
269 .forward(packets_out)
270 .map(move |result| match result {
271 Ok(_) => tracing::debug!("outgoing packet processing done"),
272 Err(error) => {
273 tracing::error!(%error, "error while processing outgoing packets")
274 }
275 })
276 .instrument(tracing::debug_span!(
277 "SessionSocket::packets_out",
278 session_id = state.session_id()
279 )),
280 );
281
282 let last_emitted_frame = Arc::new(AtomicU32::new(0));
283 let last_emitted_frame_clone = last_emitted_frame.clone();
284
285 let (packets_in_abort_handle, packets_in_abort_reg) = AbortHandle::new_pair();
286
287 let mut st_1 = state.clone();
289 let mut st_2 = state.clone();
290 let mut st_3 = state.clone();
291
292 let downstream_frames_out = futures::stream::Abortable::new(packets_in, packets_in_abort_reg)
295 .filter_map(move |packet| {
297 futures::future::ready(match packet {
298 Ok(packet) => {
299 if let Err(error) = match &packet {
300 SessionMessage::Segment(s) => st_1.incoming_segment(&s.id(), s.seq_flags),
301 SessionMessage::Request(r) => st_1.incoming_retransmission_request(r.clone()),
302 SessionMessage::Acknowledge(a) => st_1.incoming_acknowledged_frames(a.clone()),
303 } {
304 tracing::debug!(%error, "incoming message state update failed");
305 }
306 packet.try_as_segment().filter(|s| {
308 let last_emitted_id = last_emitted_frame.load(std::sync::atomic::Ordering::Relaxed);
309 if s.frame_id <= last_emitted_id {
310 tracing::warn!(frame_id = s.frame_id, last_emitted_id, "frame already seen");
311 false
312 } else {
313 true
314 }
315 })
316 }
317 Err(error) => {
318 tracing::error!(%error, "unparseable packet");
319 None
320 }
321 })
322 .instrument(tracing::debug_span!(
323 "SessionSocket::packets_in::pre_reassembly",
324 session_id = st_1.session_id()
325 ))
326 })
327 .reassembler_with_inspector(cfg.frame_timeout, cfg.capacity, inspector)
329 .filter_map(move |maybe_frame| {
331 futures::future::ready(match maybe_frame {
332 Ok(frame) => {
333 if let Err(error) = st_2.frame_complete(frame.frame_id) {
334 tracing::error!(%error, "frame complete state update failed");
335 }
336 Some(OrderedFrame(frame))
337 }
338 Err(error) => {
339 tracing::error!(%error, "failed to reassemble frame");
340 None
341 }
342 })
343 .instrument(tracing::debug_span!(
344 "SessionSocket::packets_in::pre_sequencing",
345 session_id = st_2.session_id()
346 ))
347 })
348 .sequencer(cfg.frame_timeout, cfg.frame_size)
350 .filter_map(move |maybe_frame| {
353 future::ready(match maybe_frame {
355 Ok(frame) => {
356 if let Err(error) = st_3.frame_emitted(frame.0.frame_id) {
357 tracing::error!(%error, "frame received state update failed");
358 }
359 last_emitted_frame_clone.store(frame.0.frame_id, std::sync::atomic::Ordering::Relaxed);
360 if frame.0.is_terminating {
361 tracing::warn!("terminating frame received");
362 packets_in_abort_handle.abort();
363 }
364 Some(Ok(frame.0))
365 }
366 Err(SessionError::FrameDiscarded(frame_id)) | Err(SessionError::IncompleteFrame(frame_id)) => {
367 if let Err(error) = st_3.frame_discarded(frame_id) {
368 tracing::error!(%error, "frame discarded state update failed");
369 }
370 None }
372 Err(err) => Some(Err(std::io::Error::other(err))),
373 })
374 .instrument(tracing::debug_span!(
375 "SessionSocket::packets_in::post_sequencing",
376 session_id = st_3.session_id()
377 ))
378 })
379 .into_async_read();
380
381 Ok(Self {
382 state,
383 upstream_frames_in: Box::pin(upstream_frames_in),
384 downstream_frames_out: Box::pin(downstream_frames_out),
385 write_state: if cfg.flush_immediately {
386 WriteState::Writing
387 } else {
388 WriteState::WriteOnly
389 },
390 })
391 }
392}
393
394impl<const C: usize, S: SocketState<C> + Clone + 'static> futures::io::AsyncRead for SessionSocket<C, S> {
395 #[instrument(name = "SessionSocket::poll_read", level = "trace", skip(self, cx, buf), fields(session_id = self.state.session_id(), len = buf.len()))]
396 fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<std::io::Result<usize>> {
397 self.project().downstream_frames_out.as_mut().poll_read(cx, buf)
398 }
399}
400
401impl<const C: usize, S: SocketState<C> + Clone + 'static> futures::io::AsyncWrite for SessionSocket<C, S> {
402 #[instrument(name = "SessionSocket::poll_write", level = "trace", skip(self, cx, buf), fields(session_id = self.state.session_id(), len = buf.len()))]
403 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
404 let this = self.project();
405 loop {
406 match this.write_state {
407 WriteState::WriteOnly => {
408 return this.upstream_frames_in.as_mut().poll_write(cx, buf);
409 }
410 WriteState::Writing => {
411 let len = futures::ready!(this.upstream_frames_in.as_mut().poll_write(cx, buf))?;
412 *this.write_state = WriteState::Flushing(len);
413 }
414 WriteState::Flushing(len) => {
415 let res = futures::ready!(this.upstream_frames_in.as_mut().poll_flush(cx)).map(|_| *len);
416 *this.write_state = WriteState::Writing;
417 return Poll::Ready(res);
418 }
419 }
420 }
421 }
422
423 #[instrument(name = "SessionSocket::poll_flush", level = "trace", skip(self, cx), fields(session_id = self.state.session_id()))]
424 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
425 self.project().upstream_frames_in.as_mut().poll_flush(cx)
426 }
427
428 #[instrument(name = "SessionSocket::poll_close", level = "trace", skip(self, cx), fields(session_id = self.state.session_id()))]
429 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
430 let this = self.project();
431 let _ = this.state.stop();
432 this.upstream_frames_in.as_mut().poll_close(cx)
433 }
434}
435
#[cfg(feature = "runtime-tokio")]
impl<const C: usize, S: SocketState<C> + Clone + 'static> tokio::io::AsyncRead for SessionSocket<C, S> {
    /// Adapts the `futures` read implementation to tokio's `ReadBuf` interface.
    ///
    /// The unfilled region of `buf` is initialized and passed to `futures::AsyncRead::poll_read`.
    /// On success the filled cursor advances by the reported byte count.
    #[instrument(name = "SessionSocket::poll_read", level = "trace", skip(self, cx, buf), fields(session_id = self.state.session_id()))]
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        let unfilled = buf.initialize_unfilled();
        match futures::AsyncRead::poll_read(self.as_mut(), cx, unfilled) {
            Poll::Ready(Ok(read)) => {
                buf.advance(read);
                Poll::Ready(Ok(()))
            }
            Poll::Ready(Err(error)) => Poll::Ready(Err(error)),
            Poll::Pending => Poll::Pending,
        }
    }
}
450
#[cfg(feature = "runtime-tokio")]
impl<const C: usize, S: SocketState<C> + Clone + 'static> tokio::io::AsyncWrite for SessionSocket<C, S> {
    /// Same as the `futures::AsyncWrite::poll_write` implementation.
    #[instrument(name = "SessionSocket::poll_write", level = "trace", skip(self, cx, buf), fields(session_id = self.state.session_id(), len = buf.len()))]
    fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, std::io::Error>> {
        <Self as futures::AsyncWrite>::poll_write(self.as_mut(), cx, buf)
    }

    /// Same as the `futures::AsyncWrite::poll_flush` implementation.
    #[instrument(name = "SessionSocket::poll_flush", level = "trace", skip(self, cx), fields(session_id = self.state.session_id()))]
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
        <Self as futures::AsyncWrite>::poll_flush(self.as_mut(), cx)
    }

    /// Maps tokio's shutdown onto `futures::AsyncWrite::poll_close` (stops the state, closes the writer).
    #[instrument(name = "SessionSocket::poll_shutdown", level = "trace", skip(self, cx), fields(session_id = self.state.session_id()))]
    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
        <Self as futures::AsyncWrite>::poll_close(self.as_mut(), cx)
    }
}
468
#[cfg(test)]
mod tests {
    //! End-to-end tests of `SessionSocket` (stateless and stateful) over the simulated network
    //! produced by `setup_alice_bob`, with and without packet mixing and packet loss.
    use std::collections::HashSet;

    use futures::{AsyncReadExt, AsyncWriteExt};
    use futures_time::future::FutureExt;
    use hopr_crypto_packet::prelude::HoprPacket;

    use super::*;
    use crate::{AcknowledgementState, AcknowledgementStateConfig, utils::test::*};

    const MTU: usize = HoprPacket::PAYLOAD_SIZE;

    const FRAME_SIZE: usize = 1500;

    // 17 full MTUs plus 271 bytes, i.e. not a multiple of MTU.
    const DATA_SIZE: usize = 17 * MTU + 271;

    #[test_log::test(tokio::test)]
    async fn stateless_socket_unidirectional_should_work() -> anyhow::Result<()> {
        let (alice, bob) = setup_alice_bob::<MTU>(FaultyNetworkConfig::default(), None, None);

        let sock_cfg = SessionSocketConfig {
            frame_size: FRAME_SIZE,
            ..Default::default()
        };

        let mut alice_socket = SessionSocket::<MTU, _>::new_stateless("alice", alice, sock_cfg)?;
        let mut bob_socket = SessionSocket::<MTU, _>::new_stateless("bob", bob, sock_cfg)?;

        let data = hopr_crypto_random::random_bytes::<DATA_SIZE>();

        alice_socket
            .write_all(&data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await??;
        alice_socket.flush().await?;

        let mut bob_data = [0u8; DATA_SIZE];
        bob_socket
            .read_exact(&mut bob_data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await??;
        assert_eq!(data, bob_data);

        alice_socket.close().await?;
        bob_socket.close().await?;

        Ok(())
    }

    #[test_log::test(tokio::test)]
    async fn stateful_socket_unidirectional_should_work() -> anyhow::Result<()> {
        let (alice, bob) = setup_alice_bob::<MTU>(FaultyNetworkConfig::default(), None, None);

        let sock_cfg = SessionSocketConfig {
            frame_size: FRAME_SIZE,
            ..Default::default()
        };

        let ack_cfg = AcknowledgementStateConfig {
            expected_packet_latency: Duration::from_millis(2),
            acknowledgement_delay: Duration::from_millis(5),
            ..Default::default()
        };

        let mut alice_socket =
            SessionSocket::<MTU, _>::new(alice, AcknowledgementState::new("alice", ack_cfg), sock_cfg)?;
        let mut bob_socket = SessionSocket::<MTU, _>::new(bob, AcknowledgementState::new("bob", ack_cfg), sock_cfg)?;

        let data = hopr_crypto_random::random_bytes::<DATA_SIZE>();

        alice_socket
            .write_all(&data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await??;
        alice_socket.flush().await?;

        let mut bob_data = [0u8; DATA_SIZE];
        bob_socket
            .read_exact(&mut bob_data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await??;
        assert_eq!(data, bob_data);

        alice_socket.close().await?;
        bob_socket.close().await?;

        Ok(())
    }

    #[test_log::test(tokio::test)]
    async fn stateless_socket_bidirectional_should_work() -> anyhow::Result<()> {
        let (alice, bob) = setup_alice_bob::<MTU>(FaultyNetworkConfig::default(), None, None);

        let sock_cfg = SessionSocketConfig {
            frame_size: FRAME_SIZE,
            ..Default::default()
        };

        let mut alice_socket = SessionSocket::<MTU, _>::new_stateless("alice", alice, sock_cfg)?;
        let mut bob_socket = SessionSocket::<MTU, _>::new_stateless("bob", bob, sock_cfg)?;

        let alice_sent_data = hopr_crypto_random::random_bytes::<DATA_SIZE>();
        alice_socket
            .write_all(&alice_sent_data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await??;
        alice_socket.flush().await?;

        let bob_sent_data = hopr_crypto_random::random_bytes::<DATA_SIZE>();
        bob_socket
            .write_all(&bob_sent_data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await??;
        bob_socket.flush().await?;

        let mut bob_recv_data = [0u8; DATA_SIZE];
        bob_socket
            .read_exact(&mut bob_recv_data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await??;
        assert_eq!(alice_sent_data, bob_recv_data);

        let mut alice_recv_data = [0u8; DATA_SIZE];
        alice_socket
            .read_exact(&mut alice_recv_data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await??;
        assert_eq!(bob_sent_data, alice_recv_data);

        // Close both ends, like every other test in this module.
        alice_socket.close().await?;
        bob_socket.close().await?;

        Ok(())
    }

    #[test_log::test(tokio::test)]
    async fn stateful_socket_bidirectional_should_work() -> anyhow::Result<()> {
        let (alice, bob) = setup_alice_bob::<MTU>(FaultyNetworkConfig::default(), None, None);

        let sock_cfg = SessionSocketConfig {
            frame_size: FRAME_SIZE,
            ..Default::default()
        };

        let ack_cfg = AcknowledgementStateConfig {
            expected_packet_latency: Duration::from_millis(2),
            acknowledgement_delay: Duration::from_millis(10),
            ..Default::default()
        };

        let mut alice_socket =
            SessionSocket::<MTU, _>::new(alice, AcknowledgementState::new("alice", ack_cfg), sock_cfg)?;
        let mut bob_socket = SessionSocket::<MTU, _>::new(bob, AcknowledgementState::new("bob", ack_cfg), sock_cfg)?;

        let alice_sent_data = hopr_crypto_random::random_bytes::<DATA_SIZE>();
        alice_socket
            .write_all(&alice_sent_data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await??;
        alice_socket.flush().await?;

        let bob_sent_data = hopr_crypto_random::random_bytes::<DATA_SIZE>();
        bob_socket
            .write_all(&bob_sent_data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await??;
        bob_socket.flush().await?;

        let mut bob_recv_data = [0u8; DATA_SIZE];
        bob_socket
            .read_exact(&mut bob_recv_data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await??;
        assert_eq!(alice_sent_data, bob_recv_data);

        let mut alice_recv_data = [0u8; DATA_SIZE];
        alice_socket
            .read_exact(&mut alice_recv_data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await??;
        assert_eq!(bob_sent_data, alice_recv_data);

        // Close both ends (this also stops the acknowledgement states).
        alice_socket.close().await?;
        bob_socket.close().await?;

        Ok(())
    }

    #[test_log::test(tokio::test)]
    async fn stateless_socket_unidirectional_should_work_with_mixing() -> anyhow::Result<()> {
        let network_cfg = FaultyNetworkConfig {
            mixing_factor: 10,
            ..Default::default()
        };

        let (alice, bob) = setup_alice_bob::<MTU>(network_cfg, None, None);

        let sock_cfg = SessionSocketConfig {
            frame_size: FRAME_SIZE,
            ..Default::default()
        };

        let mut alice_socket = SessionSocket::<MTU, _>::new_stateless("alice", alice, sock_cfg)?;
        let mut bob_socket = SessionSocket::<MTU, _>::new_stateless("bob", bob, sock_cfg)?;

        let data = hopr_crypto_random::random_bytes::<DATA_SIZE>();
        alice_socket
            .write_all(&data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await??;
        alice_socket.flush().await?;

        let mut bob_recv_data = [0u8; DATA_SIZE];
        bob_socket
            .read_exact(&mut bob_recv_data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await??;
        assert_eq!(data, bob_recv_data);

        alice_socket.close().await?;
        bob_socket.close().await?;

        Ok(())
    }

    #[test_log::test(tokio::test)]
    async fn stateful_socket_unidirectional_should_work_with_mixing() -> anyhow::Result<()> {
        let network_cfg = FaultyNetworkConfig {
            mixing_factor: 10,
            ..Default::default()
        };

        let (alice, bob) = setup_alice_bob::<MTU>(network_cfg, None, None);

        let sock_cfg = SessionSocketConfig {
            frame_size: FRAME_SIZE,
            ..Default::default()
        };

        let ack_cfg = AcknowledgementStateConfig {
            expected_packet_latency: Duration::from_millis(2),
            acknowledgement_delay: Duration::from_millis(5),
            ..Default::default()
        };

        let mut alice_socket =
            SessionSocket::<MTU, _>::new(alice, AcknowledgementState::new("alice", ack_cfg), sock_cfg)?;
        let mut bob_socket = SessionSocket::<MTU, _>::new(bob, AcknowledgementState::new("bob", ack_cfg), sock_cfg)?;

        let data = hopr_crypto_random::random_bytes::<DATA_SIZE>();
        alice_socket
            .write_all(&data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await??;
        alice_socket.flush().await?;

        let mut bob_recv_data = [0u8; DATA_SIZE];
        bob_socket
            .read_exact(&mut bob_recv_data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await??;
        assert_eq!(data, bob_recv_data);

        alice_socket.close().await?;
        bob_socket.close().await?;

        Ok(())
    }

    #[test_log::test(tokio::test)]
    async fn stateless_socket_bidirectional_should_work_with_mixing() -> anyhow::Result<()> {
        let network_cfg = FaultyNetworkConfig {
            mixing_factor: 10,
            ..Default::default()
        };

        let (alice, bob) = setup_alice_bob::<MTU>(network_cfg, None, None);

        let sock_cfg = SessionSocketConfig {
            frame_size: FRAME_SIZE,
            ..Default::default()
        };

        let mut alice_socket = SessionSocket::<MTU, _>::new_stateless("alice", alice, sock_cfg)?;
        let mut bob_socket = SessionSocket::<MTU, _>::new_stateless("bob", bob, sock_cfg)?;

        let alice_sent_data = hopr_crypto_random::random_bytes::<DATA_SIZE>();
        alice_socket
            .write_all(&alice_sent_data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await??;
        alice_socket.flush().await?;

        let bob_sent_data = hopr_crypto_random::random_bytes::<DATA_SIZE>();
        bob_socket
            .write_all(&bob_sent_data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await??;
        bob_socket.flush().await?;

        let mut bob_recv_data = [0u8; DATA_SIZE];
        bob_socket
            .read_exact(&mut bob_recv_data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await??;
        assert_eq!(alice_sent_data, bob_recv_data);

        let mut alice_recv_data = [0u8; DATA_SIZE];
        alice_socket
            .read_exact(&mut alice_recv_data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await??;
        assert_eq!(bob_sent_data, alice_recv_data);

        alice_socket.close().await?;
        bob_socket.close().await?;

        Ok(())
    }

    #[test_log::test(tokio::test)]
    async fn stateful_socket_bidirectional_should_work_with_mixing() -> anyhow::Result<()> {
        let network_cfg = FaultyNetworkConfig {
            mixing_factor: 10,
            ..Default::default()
        };

        let (alice, bob) = setup_alice_bob::<MTU>(network_cfg, None, None);

        let sock_cfg = SessionSocketConfig {
            frame_size: FRAME_SIZE,
            ..Default::default()
        };

        let ack_cfg = AcknowledgementStateConfig {
            expected_packet_latency: Duration::from_millis(2),
            acknowledgement_delay: Duration::from_millis(5),
            ..Default::default()
        };

        let mut alice_socket =
            SessionSocket::<MTU, _>::new(alice, AcknowledgementState::new("alice", ack_cfg), sock_cfg)?;
        let mut bob_socket = SessionSocket::<MTU, _>::new(bob, AcknowledgementState::new("bob", ack_cfg), sock_cfg)?;

        let alice_sent_data = hopr_crypto_random::random_bytes::<DATA_SIZE>();
        alice_socket
            .write_all(&alice_sent_data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await??;
        alice_socket.flush().await?;

        let bob_sent_data = hopr_crypto_random::random_bytes::<DATA_SIZE>();
        bob_socket
            .write_all(&bob_sent_data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await??;
        bob_socket.flush().await?;

        let mut bob_recv_data = [0u8; DATA_SIZE];
        bob_socket
            .read_exact(&mut bob_recv_data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await??;
        assert_eq!(alice_sent_data, bob_recv_data);

        let mut alice_recv_data = [0u8; DATA_SIZE];
        alice_socket
            .read_exact(&mut alice_recv_data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await??;
        assert_eq!(bob_sent_data, alice_recv_data);

        alice_socket.close().await?;
        bob_socket.close().await?;

        Ok(())
    }

    /// With packet id 0 dropped and no retransmission, bob is expected to receive everything
    /// except the first `FRAME_SIZE` bytes once his short `frame_timeout` expires.
    #[test_log::test(tokio::test)]
    async fn stateless_socket_unidirectional_should_skip_missing_frames() -> anyhow::Result<()> {
        let (alice, bob) = setup_alice_bob::<MTU>(
            FaultyNetworkConfig {
                avg_delay: Duration::from_millis(10),
                ids_to_drop: HashSet::from_iter([0_usize]),
                ..Default::default()
            },
            None,
            None,
        );

        let alice_cfg = SessionSocketConfig {
            frame_size: FRAME_SIZE,
            ..Default::default()
        };

        let bob_cfg = SessionSocketConfig {
            frame_size: FRAME_SIZE,
            frame_timeout: Duration::from_millis(55),
            ..Default::default()
        };

        let mut alice_socket = SessionSocket::<MTU, _>::new_stateless("alice", alice, alice_cfg)?;
        let mut bob_socket = SessionSocket::<MTU, _>::new_stateless("bob", bob, bob_cfg)?;

        let data = hopr_crypto_random::random_bytes::<DATA_SIZE>();
        alice_socket
            .write_all(&data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await??;
        alice_socket.flush().await?;
        alice_socket.close().await?;

        let mut bob_data = Vec::with_capacity(DATA_SIZE);
        bob_socket
            .read_to_end(&mut bob_data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await??;

        assert_eq!(data.len() - FRAME_SIZE, bob_data.len());
        assert_eq!(&data[FRAME_SIZE..], &bob_data);

        bob_socket.close().await?;

        Ok(())
    }

    /// With packet id 0 dropped, the acknowledgement state is expected to recover the lost
    /// data within bob's long `frame_timeout`, so bob receives the complete payload.
    #[test_log::test(tokio::test)]
    async fn stateful_socket_unidirectional_should_not_skip_missing_frames() -> anyhow::Result<()> {
        let (alice, bob) = setup_alice_bob::<MTU>(
            FaultyNetworkConfig {
                avg_delay: Duration::from_millis(10),
                ids_to_drop: HashSet::from_iter([0_usize]),
                ..Default::default()
            },
            None,
            None,
        );

        let alice_cfg = SessionSocketConfig {
            frame_size: FRAME_SIZE,
            ..Default::default()
        };

        let bob_cfg = SessionSocketConfig {
            frame_size: FRAME_SIZE,
            frame_timeout: Duration::from_millis(1000),
            ..Default::default()
        };

        let ack_cfg = AcknowledgementStateConfig {
            expected_packet_latency: Duration::from_millis(10),
            acknowledgement_delay: Duration::from_millis(40),
            ..Default::default()
        };

        let mut alice_socket =
            SessionSocket::<MTU, _>::new(alice, AcknowledgementState::new("alice", ack_cfg), alice_cfg)?;
        let mut bob_socket = SessionSocket::<MTU, _>::new(bob, AcknowledgementState::new("bob", ack_cfg), bob_cfg)?;

        let data = hopr_crypto_random::random_bytes::<DATA_SIZE>();

        // Alice keeps reading until bob closes, so her socket stays alive to serve
        // retransmission requests.
        let alice_jh = tokio::spawn(async move {
            alice_socket
                .write_all(&data)
                .timeout(futures_time::time::Duration::from_secs(5))
                .await??;

            alice_socket.flush().await?;

            let mut vec = Vec::new();
            alice_socket.read_to_end(&mut vec).await?;
            alice_socket.close().await?;

            Ok::<_, std::io::Error>(vec)
        });

        let mut bob_data = [0u8; DATA_SIZE];
        bob_socket
            .read_exact(&mut bob_data)
            .timeout(futures_time::time::Duration::from_secs(5))
            .await??;
        assert_eq!(data, bob_data);

        bob_socket.close().await?;

        let alice_recv = alice_jh.await??;
        assert!(alice_recv.is_empty());

        Ok(())
    }

    /// Bidirectional variant of the stateless skip test: each side loses its first frame.
    #[test_log::test(tokio::test)]
    async fn stateless_socket_bidirectional_should_skip_missing_frames() -> anyhow::Result<()> {
        let (alice, bob) = setup_alice_bob::<MTU>(
            FaultyNetworkConfig {
                avg_delay: Duration::from_millis(10),
                ids_to_drop: HashSet::from_iter([0_usize]),
                ..Default::default()
            },
            None,
            None,
        );

        let alice_cfg = SessionSocketConfig {
            frame_size: FRAME_SIZE,
            frame_timeout: Duration::from_millis(55),
            ..Default::default()
        };

        let bob_cfg = SessionSocketConfig {
            frame_size: FRAME_SIZE,
            frame_timeout: Duration::from_millis(55),
            ..Default::default()
        };

        let mut alice_socket = SessionSocket::<MTU, _>::new_stateless("alice", alice, alice_cfg)?;
        let mut bob_socket = SessionSocket::<MTU, _>::new_stateless("bob", bob, bob_cfg)?;

        let alice_sent_data = hopr_crypto_random::random_bytes::<DATA_SIZE>();
        alice_socket
            .write_all(&alice_sent_data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await??;
        alice_socket.flush().await?;

        let bob_sent_data = hopr_crypto_random::random_bytes::<DATA_SIZE>();
        bob_socket
            .write_all(&bob_sent_data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await??;
        bob_socket.flush().await?;

        alice_socket.close().await?;
        bob_socket.close().await?;

        let mut alice_recv_data = Vec::with_capacity(DATA_SIZE);
        alice_socket
            .read_to_end(&mut alice_recv_data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await??;

        let mut bob_recv_data = Vec::with_capacity(DATA_SIZE);
        bob_socket
            .read_to_end(&mut bob_recv_data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await??;

        assert_eq!(bob_sent_data.len() - FRAME_SIZE, alice_recv_data.len());
        assert_eq!(&bob_sent_data[FRAME_SIZE..], &alice_recv_data);

        assert_eq!(alice_sent_data.len() - FRAME_SIZE, bob_recv_data.len());
        assert_eq!(&alice_sent_data[FRAME_SIZE..], &bob_recv_data);

        Ok(())
    }

    /// Bidirectional variant of the stateful no-skip test. Each socket is split so that reading
    /// and writing run in separate tasks; a side closes its writer only after it has received
    /// all of the peer's data.
    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn stateful_socket_bidirectional_should_not_skip_missing_frames() -> anyhow::Result<()> {
        let (alice, bob) = setup_alice_bob::<MTU>(
            FaultyNetworkConfig {
                avg_delay: Duration::from_millis(10),
                ids_to_drop: HashSet::from_iter([0_usize]),
                ..Default::default()
            },
            None,
            None,
        );

        let alice_cfg = SessionSocketConfig {
            frame_size: FRAME_SIZE,
            frame_timeout: Duration::from_millis(1000),
            ..Default::default()
        };

        let bob_cfg = SessionSocketConfig {
            frame_size: FRAME_SIZE,
            frame_timeout: Duration::from_millis(1000),
            ..Default::default()
        };

        let ack_cfg = AcknowledgementStateConfig {
            expected_packet_latency: Duration::from_millis(10),
            acknowledgement_delay: Duration::from_millis(40),
            ..Default::default()
        };

        let (mut alice_rx, mut alice_tx) =
            SessionSocket::<MTU, _>::new(alice, AcknowledgementState::new("alice", ack_cfg), alice_cfg)?.split();

        let (mut bob_rx, mut bob_tx) =
            SessionSocket::<MTU, _>::new(bob, AcknowledgementState::new("bob", ack_cfg), bob_cfg)?.split();

        let alice_sent_data = hopr_crypto_random::random_bytes::<DATA_SIZE>();
        let (alice_data_tx, alice_recv_data) = futures::channel::oneshot::channel();
        let alice_rx_jh = tokio::spawn(async move {
            let mut alice_recv_data = vec![0u8; DATA_SIZE];
            alice_rx.read_exact(&mut alice_recv_data).await?;
            alice_data_tx
                .send(alice_recv_data)
                .map_err(|_| std::io::Error::other("tx error"))?;

            // Drain until the peer closes.
            alice_rx.read_to_end(&mut Vec::new()).await?;
            Ok::<_, std::io::Error>(())
        });

        let bob_sent_data = hopr_crypto_random::random_bytes::<DATA_SIZE>();
        let (bob_data_tx, bob_recv_data) = futures::channel::oneshot::channel();
        let bob_rx_jh = tokio::spawn(async move {
            let mut bob_recv_data = vec![0u8; DATA_SIZE];
            bob_rx.read_exact(&mut bob_recv_data).await?;
            bob_data_tx
                .send(bob_recv_data)
                .map_err(|_| std::io::Error::other("tx error"))?;

            // Drain until the peer closes.
            bob_rx.read_to_end(&mut Vec::new()).await?;
            Ok::<_, std::io::Error>(())
        });

        let alice_tx_jh = tokio::spawn(async move {
            alice_tx
                .write_all(&alice_sent_data)
                .timeout(futures_time::time::Duration::from_secs(2))
                .await??;
            alice_tx.flush().await?;

            // Only close once everything from bob has arrived.
            let out = alice_recv_data.await.map_err(|_| std::io::Error::other("rx error"))?;
            alice_tx.close().await?;
            tracing::info!("alice closed");
            Ok::<_, std::io::Error>(out)
        });

        let bob_tx_jh = tokio::spawn(async move {
            bob_tx
                .write_all(&bob_sent_data)
                .timeout(futures_time::time::Duration::from_secs(2))
                .await??;
            bob_tx.flush().await?;

            // Only close once everything from alice has arrived.
            let out = bob_recv_data.await.map_err(|_| std::io::Error::other("rx error"))?;
            bob_tx.close().await?;
            tracing::info!("bob closed");
            Ok::<_, std::io::Error>(out)
        });

        let (alice_recv_data, bob_recv_data, a, b) =
            futures::future::try_join4(alice_tx_jh, bob_tx_jh, alice_rx_jh, bob_rx_jh)
                .timeout(futures_time::time::Duration::from_secs(4))
                .await??;

        assert_eq!(&alice_sent_data, bob_recv_data?.as_slice());
        assert_eq!(&bob_sent_data, alice_recv_data?.as_slice());
        assert!(a.is_ok());
        assert!(b.is_ok());

        Ok(())
    }
}