1use std::{
89 collections::{BTreeSet, HashSet},
90 fmt::{Debug, Display},
91 future::Future,
92 pin::Pin,
93 sync::{
94 Arc,
95 atomic::{AtomicU32, Ordering},
96 },
97 task::{Context, Poll},
98 time::{Duration, Instant},
99};
100
101use crossbeam_queue::ArrayQueue;
102use crossbeam_skiplist::SkipMap;
103use dashmap::{DashMap, mapref::entry::Entry};
104use futures::{
105 AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, FutureExt, Sink, SinkExt, StreamExt, TryStreamExt,
106 channel::mpsc::UnboundedSender, future::BoxFuture, pin_mut,
107};
108use governor::Quota;
109use hopr_async_runtime::prelude::spawn;
110use smart_default::SmartDefault;
111use tracing::{debug, error, trace, warn};
112
113use crate::{
114 errors::NetworkTypeError,
115 prelude::protocol::SessionMessageIter,
116 session::{
117 errors::SessionError,
118 frame::{FrameId, FrameReassembler, Segment, SegmentId, segment},
119 protocol::{FrameAcknowledgements, SegmentRequest, SessionMessage},
120 utils::{RetryResult, RetryToken},
121 },
122 utils::AsyncReadStreamer,
123};
124
// Prometheus metrics of the session layer; only compiled with the `prometheus` feature and never in tests.
#[cfg(all(feature = "prometheus", not(test)))]
lazy_static::lazy_static! {
    // Histogram of the time (in seconds) between sending a frame and receiving its acknowledgement,
    // labelled by session ID. Observed when an acknowledgement for a tracked outgoing frame arrives.
    static ref METRIC_TIME_TO_ACK: hopr_metrics::MultiHistogram =
        hopr_metrics::MultiHistogram::new(
            "hopr_session_time_to_ack",
            "Time in seconds until a complete frame gets acknowledged by the recipient",
            vec![0.5, 1.0, 2.0, 4.0, 8.0, 16.0, 32.0, 64.0, 128.0, 256.0],
            &["session_id"]
        ).unwrap();
}
135
/// Optional behaviours of a session that can be switched on via [`SessionConfig::enabled_features`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SessionFeature {
    /// Receiver side: periodically request retransmission of missing segments of incomplete frames.
    RequestIncompleteFrames,
    /// Sender side: periodically re-send frames that have not been acknowledged yet.
    RetransmitFrames,
    /// Receiver side: acknowledge every completely reassembled frame to the sender.
    AcknowledgeFrames,
    /// Disable buffering of outgoing messages, so every message is written downstream on its own.
    NoDelay,
}

impl SessionFeature {
    /// Features enabled by default: acknowledgements and both directions of retransmission
    /// (everything except [`SessionFeature::NoDelay`]).
    fn default_features() -> Vec<SessionFeature> {
        [
            Self::AcknowledgeFrames,
            Self::RequestIncompleteFrames,
            Self::RetransmitFrames,
        ]
        .to_vec()
    }
}
166
/// Configuration of a session (see `SessionSocket::new`).
#[derive(Debug, Clone, SmartDefault, validator::Validate)]
pub struct SessionConfig {
    /// Maximum number of sent segments kept in the lookbehind buffer so they can be retransmitted.
    /// When the buffer grows beyond this, the entries with the lowest segment IDs are evicted first.
    ///
    /// Default is 50 000.
    #[default = 50_000]
    pub max_buffered_segments: usize,

    /// Capacity of the queue of completed incoming frame IDs waiting to be acknowledged.
    /// When the queue is full, pushing a new frame ID overwrites the oldest pending one.
    ///
    /// Must be at least 1. If validation is bypassed and the value is 0, the queue is still created
    /// with capacity 1, but sender-side retransmission of unacknowledged frames is skipped entirely.
    ///
    /// Default is 1024.
    #[default = 1024]
    #[validate(range(min = 1))]
    pub acknowledged_frames_buffer: usize,

    /// Maximum age of a frame. It is passed to the frame reassembler, and it is the expiration age
    /// for retry tokens of both incomplete incoming frames and unacknowledged outgoing frames.
    /// The reassembler eviction runs at most once every tenth of this period.
    ///
    /// Default is 30 seconds.
    #[default(Duration::from_secs(30))]
    pub frame_expiration_age: Duration,

    /// Base retransmission timeout used when requesting missing segments of incomplete incoming
    /// frames (see [`SessionFeature::RequestIncompleteFrames`]). It is also the minimum period
    /// between two rounds of such requests.
    ///
    /// Default is 1 second.
    #[default(Duration::from_millis(1000))]
    pub rto_base_receiver: Duration,

    /// Base retransmission timeout used when re-sending unacknowledged outgoing frames
    /// (see [`SessionFeature::RetransmitFrames`]). It is also the minimum period between two rounds
    /// of such re-sends.
    ///
    /// Default is 1.5 seconds.
    #[default(Duration::from_millis(1500))]
    pub rto_base_sender: Duration,

    /// Base of the backoff applied to successive retries (passed to `RetryToken`; presumably
    /// an exponential backoff — TODO confirm in `session::utils`). Must be at least 1.0.
    ///
    /// Default is 2.0.
    #[default(2.0)]
    #[validate(range(min = 1.0))]
    pub backoff_base: f64,

    /// Jitter applied to retransmission timeouts (passed to `RetryToken::check`; presumably a
    /// relative fraction of the timeout — TODO confirm). Must lie within `[0.0, 0.25]`.
    ///
    /// Default is 0.05.
    #[default(0.05)]
    #[validate(range(min = 0.0, max = 0.25))]
    pub rto_jitter: f64,

    /// Set of enabled [`SessionFeature`]s.
    ///
    /// Default is acknowledgements plus both directions of retransmission (no `NoDelay`).
    #[default(_code = "HashSet::from_iter(SessionFeature::default_features())")]
    pub enabled_features: HashSet<SessionFeature>,
}
245
/// Shared protocol state of one session.
///
/// All collections are behind `Arc`, so clones of this struct (used by the reader task, the
/// state-maintenance task and the socket itself) operate on the same underlying state.
#[derive(Debug, Clone)]
pub struct SessionState<const C: usize> {
    // Human-readable identifier used in logs and metrics.
    session_id: String,
    // Copies of sent segments, kept so they can be retransmitted on request or on timeout.
    lookbehind: Arc<SkipMap<SegmentId, Segment>>,
    // IDs of completely received frames that still need to be acknowledged to the peer.
    to_acknowledge: Arc<ArrayQueue<FrameId>>,
    // Retry tokens of incomplete incoming frames (drive requests for missing segments).
    incoming_frame_retries: Arc<DashMap<FrameId, RetryToken>>,
    // Retry tokens of sent but not yet acknowledged frames (drive automatic re-sending).
    outgoing_frame_resends: Arc<DashMap<FrameId, RetryToken>>,
    // ID to assign to the next outgoing frame.
    outgoing_frame_id: Arc<AtomicU32>,
    // Reassembles incoming segments into frames.
    frame_reassembler: Arc<FrameReassembler>,
    cfg: SessionConfig,
    // Outgoing messages; the receiving end is forwarded to the downstream transport.
    segment_egress_send: UnboundedSender<SessionMessage<C>>,
}
266
267fn maybe_fused_future<'a, F>(condition: bool, future: F) -> futures::future::Fuse<BoxFuture<'a, ()>>
268where
269 F: Future<Output = ()> + Send + Sync + 'a,
270{
271 if condition {
272 future.boxed()
273 } else {
274 futures::future::pending().boxed()
275 }
276 .fuse()
277}
278
279impl<const C: usize> SessionState<C> {
280 const MAX_WRITE_SIZE: usize = SessionMessage::<C>::MAX_SEGMENTS_PER_FRAME * Self::PAYLOAD_CAPACITY;
282 const PAYLOAD_CAPACITY: usize = C - SessionMessage::<C>::SEGMENT_OVERHEAD;
284
285 fn consume_segment(&mut self, segment: Segment) -> crate::errors::Result<()> {
286 let id = segment.id();
287
288 trace!(session_id = self.session_id, segment = %id, "received segment");
289
290 match self.frame_reassembler.push_segment(segment) {
291 Ok(_) => {
292 match self.incoming_frame_retries.entry(id.0) {
293 Entry::Occupied(e) => {
294 let rt = *e.get();
296 e.replace_entry(rt.replenish(Instant::now(), self.cfg.backoff_base));
297 }
298 Entry::Vacant(v) => {
299 v.insert(RetryToken::new(Instant::now(), self.cfg.backoff_base));
301 }
302 }
303 trace!(session_id = self.session_id, segment = %id, "received segment pushed");
304 }
305 Err(e) => warn!(session_id = self.session_id, ?id, error = %e, "segment not pushed"),
307 }
308
309 Ok(())
310 }
311
312 fn retransmit_segments(&mut self, request: SegmentRequest<C>) -> crate::errors::Result<()> {
313 trace!(
314 session_id = self.session_id,
315 count_of_segments = request.len(),
316 "received request",
317 );
318
319 let mut count = 0;
320 request
321 .into_iter()
322 .filter_map(|segment_id| {
323 self.outgoing_frame_resends.remove(&segment_id.0);
325 let ret = self
326 .lookbehind
327 .get(&segment_id)
328 .map(|e| SessionMessage::<C>::Segment(e.value().clone()));
329 if ret.is_some() {
330 trace!(
331 session_id = self.session_id,
332 %segment_id,
333 "SENDING: retransmitted segment"
334 );
335 count += 1;
336 } else {
337 warn!(
338 session_id = self.session_id,
339 id = ?segment_id,
340 "segment not in lookbehind buffer anymore",
341 );
342 }
343 ret
344 })
345 .try_for_each(|msg| self.segment_egress_send.unbounded_send(msg))
346 .map_err(|e| SessionError::ProcessingError(e.to_string()))?;
347
348 trace!(session_id = self.session_id, count, "retransmitted requested segments");
349
350 Ok(())
351 }
352
353 fn acknowledged_frames(&mut self, acked: FrameAcknowledgements<C>) -> crate::errors::Result<()> {
354 trace!(
355 session_id = self.session_id,
356 count = acked.len(),
357 "received acknowledgement frames",
358 );
359
360 for frame_id in acked {
361 if let Some((_, rt)) = self.outgoing_frame_resends.remove(&frame_id) {
363 let to_ack = rt.time_since_creation();
364 trace!(
365 session_id = self.session_id,
366 frame_id,
367 duration_in_ms = to_ack.as_millis(),
368 "frame acknowledgement duratin"
369 );
370
371 #[cfg(all(feature = "prometheus", not(test)))]
372 METRIC_TIME_TO_ACK.observe(&[self.session_id()], to_ack.as_secs_f64())
373 }
374
375 for seg in self.lookbehind.iter().filter(|s| frame_id == s.key().0) {
376 seg.remove();
377 }
378 }
379
380 Ok(())
381 }
382
383 async fn request_missing_segments(&mut self) -> crate::errors::Result<usize> {
389 let tracked_incomplete = self.frame_reassembler.incomplete_frames();
390 trace!(
391 session_id = self.session_id,
392 count = tracked_incomplete.len(),
393 "tracking incomplete frames",
394 );
395
396 let mut to_retry = Vec::with_capacity(tracked_incomplete.len());
398 let now = Instant::now();
399 for info in tracked_incomplete {
400 match self.incoming_frame_retries.entry(info.frame_id) {
401 Entry::Occupied(e) => {
402 let rto_check = e.get().check(
404 now,
405 self.cfg.rto_base_receiver,
406 self.cfg.frame_expiration_age,
407 self.cfg.rto_jitter,
408 );
409 match rto_check {
410 RetryResult::RetryNow(next_rto) => {
411 trace!(
413 session_id = self.session_id,
414 frame_id = info.frame_id,
415 retransmission_number = next_rto.num_retry,
416 "performing frame retransmission",
417 );
418 e.replace_entry(next_rto);
419 to_retry.push(info);
420 }
421 RetryResult::Expired => {
422 debug!(
424 session_id = self.session_id,
425 frame_id = info.frame_id,
426 "frame is already expired and will be evicted"
427 );
428 e.remove();
429 }
430 RetryResult::Wait(d) => trace!(
431 session_id = self.session_id,
432 frame_id = info.frame_id,
433 timeout_in_ms = d.as_millis(),
434 next_retransmission_request_number = e.get().num_retry,
435 "frame needs to wait for next retransmission request",
436 ),
437 }
438 }
439 Entry::Vacant(v) => {
440 debug!(
442 session_id = self.session_id,
443 frame_id = info.frame_id,
444 "frame does not have a retry token"
445 );
446 v.insert(RetryToken::new(now, self.cfg.backoff_base));
447 to_retry.push(info);
448 }
449 }
450 }
451
452 let mut sent = 0;
453 let to_retry = to_retry
454 .chunks(SegmentRequest::<C>::MAX_ENTRIES)
455 .map(|chunk| Ok(SessionMessage::<C>::Request(chunk.iter().cloned().collect())))
456 .inspect(|r| {
457 trace!(
458 session_id = self.session_id,
459 result = ?r,
460 "SENDING: retransmission request"
461 );
462 sent += 1;
463 })
464 .collect::<Vec<_>>();
465
466 self.segment_egress_send
467 .send_all(&mut futures::stream::iter(to_retry))
468 .await
469 .map_err(|e| SessionError::ProcessingError(e.to_string()))?;
470
471 trace!(
472 session_id = self.session_id,
473 count = sent,
474 "RETRANSMISSION BATCH COMPLETE: sent {sent} re-send requests",
475 );
476 Ok(sent)
477 }
478
479 async fn acknowledge_segments(&mut self) -> crate::errors::Result<usize> {
490 let mut len = 0;
491 let mut msgs = 0;
492
493 while !self.to_acknowledge.is_empty() {
494 let mut ack_frames = FrameAcknowledgements::<C>::default();
495
496 while !ack_frames.is_full() && !self.to_acknowledge.is_empty() {
497 if let Some(ack_id) = self.to_acknowledge.pop() {
498 ack_frames.push(ack_id);
499 len += 1;
500 }
501 }
502
503 trace!(
504 session_id = self.session_id,
505 count = ack_frames.len(),
506 "SENDING: acknowledgements of frames",
507 );
508 self.segment_egress_send
509 .feed(SessionMessage::Acknowledge(ack_frames))
510 .await
511 .map_err(|e| SessionError::ProcessingError(e.to_string()))?;
512 msgs += 1;
513 }
514 self.segment_egress_send
515 .flush()
516 .await
517 .map_err(|e| SessionError::ProcessingError(e.to_string()))?;
518
519 trace!(
520 session_id = self.session_id,
521 count = len,
522 messages = msgs,
523 "ACK BATCH COMPLETE: sent acks in messages",
524 );
525 Ok(len)
526 }
527
528 async fn retransmit_unacknowledged_frames(&mut self) -> crate::errors::Result<usize> {
534 if self.cfg.acknowledged_frames_buffer == 0 {
535 return Ok(0);
536 }
537
538 let now = Instant::now();
539
540 let mut frames_to_resend = BTreeSet::new();
542 self.outgoing_frame_resends.retain(|frame_id, retry_log| {
543 let check_res = retry_log.check(
544 now,
545 self.cfg.rto_base_sender,
546 self.cfg.frame_expiration_age,
547 self.cfg.rto_jitter,
548 );
549 match check_res {
550 RetryResult::Wait(d) => {
551 trace!(
552 session_id = self.session_id,
553 frame_id,
554 wait_timeout_in_ms = d.as_millis(),
555 "frame will retransmit"
556 );
557 true
558 }
559 RetryResult::RetryNow(next_retry) => {
560 frames_to_resend.insert(*frame_id);
562 *retry_log = next_retry;
563 debug!(session_id = self.session_id, frame_id, "frame will self-resend now");
564 true
565 }
566 RetryResult::Expired => {
567 debug!(session_id = self.session_id, frame_id, "frame expired");
568 false
569 }
570 }
571 });
572
573 trace!(
574 session_id = self.session_id,
575 count = frames_to_resend.len(),
576 "frames will auto-resend",
577 );
578
579 let mut count = 0;
582 let frames_to_resend = frames_to_resend
583 .into_iter()
584 .flat_map(|f| self.lookbehind.iter().filter(move |e| e.key().0 == f))
585 .inspect(|e| {
586 trace!(
587 session_id = self.session_id,
588 key = ?e.key(),
589 "SENDING: auto-retransmitted"
590 );
591 count += 1
592 })
593 .map(|e| Ok(SessionMessage::<C>::Segment(e.value().clone())))
594 .collect::<Vec<_>>();
595
596 self.segment_egress_send
597 .send_all(&mut futures::stream::iter(frames_to_resend))
598 .await
599 .map_err(|e| SessionError::ProcessingError(e.to_string()))?;
600
601 trace!(
602 session_id = self.session_id,
603 count, "AUTO-RETRANSMIT BATCH COMPLETE: re-sent segments",
604 );
605
606 Ok(count)
607 }
608
609 pub async fn send_frame_data(&mut self, data: &[u8]) -> crate::errors::Result<()> {
618 if !(1..=Self::MAX_WRITE_SIZE).contains(&data.len()) {
619 return Err(SessionError::IncorrectMessageLength.into());
620 }
621
622 let frame_id = self.outgoing_frame_id.fetch_add(1, Ordering::SeqCst);
623 let segments = segment(data, Self::PAYLOAD_CAPACITY, frame_id)?;
624 let count = segments.len();
625
626 for segment in segments {
627 let msg = SessionMessage::<C>::Segment(segment.clone());
628 trace!(session_id = self.session_id, id = ?segment.id(), "SENDING: segment");
629 self.segment_egress_send
630 .feed(msg)
631 .await
632 .map_err(|e| SessionError::ProcessingError(e.to_string()))?;
633
634 self.lookbehind.insert((&segment).into(), segment.clone());
636 while self.lookbehind.len() > self.cfg.max_buffered_segments {
637 self.lookbehind.pop_front();
638 }
639 }
640
641 self.segment_egress_send
642 .flush()
643 .await
644 .map_err(|e| SessionError::ProcessingError(e.to_string()))?;
645 self.outgoing_frame_resends
646 .insert(frame_id, RetryToken::new(Instant::now(), self.cfg.backoff_base));
647
648 trace!(
649 session_id = self.session_id,
650 frame_id, count, "FRAME SEND COMPLETE: sent segments",
651 );
652
653 Ok(())
654 }
655
656 async fn state_loop(&mut self) -> crate::errors::Result<()> {
661 let eviction_limiter =
664 governor::RateLimiter::direct(Quota::with_period(self.cfg.frame_expiration_age / 10).ok_or(
665 NetworkTypeError::Other("rate limiter frame_expiration_age invalid".into()),
666 )?);
667
668 let ack_rate_limiter = governor::RateLimiter::direct(
672 Quota::with_period(self.cfg.rto_base_sender.min(self.cfg.rto_base_receiver) / 4)
673 .ok_or(NetworkTypeError::Other("rate limiter ack rate invalid".into()))?,
674 );
675
676 let sender_retransmit = governor::RateLimiter::direct(
678 Quota::with_period(self.cfg.rto_base_sender)
679 .ok_or(NetworkTypeError::Other("rate limiter rto sender invalid".into()))?,
680 );
681
682 let receiver_retransmit = governor::RateLimiter::direct(
684 Quota::with_period(self.cfg.rto_base_receiver)
685 .ok_or(NetworkTypeError::Other("rate limiter rto receiver invalid".into()))?,
686 );
687
688 loop {
689 let mut evict_fut = eviction_limiter.until_ready().boxed().fuse();
690 let mut ack_fut = maybe_fused_future(
691 self.cfg.enabled_features.contains(&SessionFeature::AcknowledgeFrames),
692 ack_rate_limiter.until_ready(),
693 );
694 let mut r_snd_fut = maybe_fused_future(
695 self.cfg.enabled_features.contains(&SessionFeature::RetransmitFrames),
696 sender_retransmit.until_ready(),
697 );
698 let mut r_rcv_fut = maybe_fused_future(
699 self.cfg
700 .enabled_features
701 .contains(&SessionFeature::RequestIncompleteFrames),
702 receiver_retransmit.until_ready(),
703 );
704 let mut is_done = maybe_fused_future(self.segment_egress_send.is_closed(), futures::future::ready(()));
705
706 if let Err(e) = futures::select_biased! {
710 _ = is_done => {
711 Err(NetworkTypeError::Other("session writer has been closed".into()))
712 },
713 _ = r_rcv_fut => {
714 self.request_missing_segments().await
715 },
716 _ = r_snd_fut => {
717 self.retransmit_unacknowledged_frames().await
718 },
719 _ = ack_fut => {
720 self.acknowledge_segments().await
721 },
722 _ = evict_fut => {
723 self.frame_reassembler.evict().map_err(NetworkTypeError::from)
724 },
725 } {
726 debug!(session_id = self.session_id, "session is closing: {e}");
727 break;
728 }
729 }
730
731 Ok(())
732 }
733
734 pub fn session_id(&self) -> &str {
736 &self.session_id
737 }
738}
739
740impl<const C: usize> Sink<SessionMessage<C>> for SessionState<C> {
742 type Error = NetworkTypeError;
743
744 fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
745 Poll::Ready(Ok(()))
746 }
747
748 fn start_send(mut self: Pin<&mut Self>, item: SessionMessage<C>) -> Result<(), Self::Error> {
749 match item {
750 SessionMessage::Segment(s) => self.consume_segment(s),
751 SessionMessage::Request(r) => self.retransmit_segments(r),
752 SessionMessage::Acknowledge(f) => self.acknowledged_frames(f),
753 }
754 }
755
756 fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
757 Poll::Ready(Ok(()))
758 }
759
760 fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
761 self.frame_reassembler.close();
762 Poll::Ready(Ok(()))
763 }
764}
765
#[cfg_attr(doc, aquamarine::aquamarine)]
/// Socket-like object providing reliability on top of an unreliable message transport.
///
/// Data written through [`AsyncWrite`] is split into frames of at most `MAX_WRITE_SIZE` bytes.
/// Each frame is split into segments sent over the transport. Reassembled frame payloads are
/// returned through [`AsyncRead`].
///
/// Acknowledgements and retransmissions are handled by background tasks spawned in
/// `SessionSocket::new`, according to the enabled [`SessionFeature`]s.
///
/// `C` is the MTU: the maximum size of a single message on the underlying transport.
pub struct SessionSocket<const C: usize> {
    // Shared protocol state (also held by the background tasks).
    state: SessionState<C>,
    // Stream of reassembled frame payloads exposed as an `AsyncRead`.
    frame_egress: Box<dyn AsyncRead + Send + Unpin>,
}
869
870impl<const C: usize> SessionSocket<C> {
871 pub const MAX_WRITE_SIZE: usize = SessionState::<C>::MAX_WRITE_SIZE;
873 pub const PAYLOAD_CAPACITY: usize = SessionState::<C>::PAYLOAD_CAPACITY;
875
876 pub fn new<T, I>(id: I, transport: T, cfg: SessionConfig) -> Self
879 where
880 T: AsyncWrite + AsyncRead + Send + 'static,
881 I: Display + Send + 'static,
882 {
883 assert!(
884 C >= SessionMessage::<C>::minimum_message_size(),
885 "given MTU is too small"
886 );
887
888 let (reassembler, egress) = FrameReassembler::new(cfg.frame_expiration_age);
889
890 let to_acknowledge = Arc::new(ArrayQueue::new(cfg.acknowledged_frames_buffer.max(1)));
891 let incoming_frame_retries = Arc::new(DashMap::new());
892
893 let incoming_frame_retries_clone = incoming_frame_retries.clone();
894 let id_clone = id.to_string().clone();
895 let to_acknowledge_clone = to_acknowledge.clone();
896 let ack_enabled = cfg.enabled_features.contains(&SessionFeature::AcknowledgeFrames);
897
898 let frame_egress = Box::new(
899 egress
900 .filter_map(move |maybe_frame| {
901 match maybe_frame {
902 Ok(frame) => {
903 trace!(session_id = id_clone, frame_id = frame.frame_id, "frame completed");
904 incoming_frame_retries_clone.remove(&frame.frame_id);
906 if ack_enabled {
907 to_acknowledge_clone.force_push(frame.frame_id);
910 }
911 futures::future::ready(Some(Ok(frame)))
912 }
913 Err(SessionError::FrameDiscarded(fid)) | Err(SessionError::IncompleteFrame(fid)) => {
914 incoming_frame_retries_clone.remove(&fid);
916 warn!(session_id = id_clone, frame_id = fid, "frame skipped");
917 futures::future::ready(None) }
919 Err(e) => {
920 error!(session_id = id_clone, "error on frame reassembly: {e}");
921 futures::future::ready(Some(Err(std::io::Error::other(e))))
922 }
923 }
924 })
925 .into_async_read(),
926 );
927
928 let (segment_egress_send, segment_egress_recv) = futures::channel::mpsc::unbounded();
929
930 let (downstream_read, downstream_write) = transport.split();
931
932 let downstream_write = futures::io::BufWriter::with_capacity(
934 if !cfg.enabled_features.contains(&SessionFeature::NoDelay) {
935 C
936 } else {
937 0
938 },
939 downstream_write,
940 );
941
942 let state = SessionState {
943 lookbehind: Arc::new(SkipMap::new()),
944 outgoing_frame_id: Arc::new(AtomicU32::new(1)),
945 frame_reassembler: Arc::new(reassembler),
946 outgoing_frame_resends: Arc::new(DashMap::new()),
947 session_id: id.to_string(),
948 to_acknowledge,
949 incoming_frame_retries,
950 segment_egress_send,
951 cfg,
952 };
953
954 spawn(async move {
956 if let Err(e) = segment_egress_recv
957 .map(|m: SessionMessage<C>| Ok(m.into_encoded()))
958 .forward(downstream_write.into_sink())
959 .await
960 {
961 error!(session_id = %id, error = %e, "FINISHED: forwarding to downstream terminated with error")
962 } else {
963 debug!(session_id = %id, "FINISHED: forwarding to downstream done");
964 }
965 });
966
967 spawn(
969 AsyncReadStreamer::<C, _>(downstream_read)
970 .map_err(|e| NetworkTypeError::SessionProtocolError(SessionError::ProcessingError(e.to_string())))
971 .and_then(|m| futures::future::ok(futures::stream::iter(SessionMessageIter::from(m.into_vec()))))
972 .try_flatten()
973 .forward(state.clone()),
974 );
975
976 let mut state_clone = state.clone();
978 spawn(async move {
979 let loop_done = state_clone.state_loop().await;
980 debug!(
981 session_id = state_clone.session_id,
982 "FINISHED: state loop {loop_done:?}"
983 );
984 });
985
986 Self { state, frame_egress }
987 }
988
989 pub fn state(&self) -> &SessionState<C> {
991 &self.state
992 }
993
994 pub fn state_mut(&mut self) -> &mut SessionState<C> {
996 &mut self.state
997 }
998}
999
1000impl<const C: usize> AsyncWrite for SessionSocket<C> {
1001 fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
1002 let len_to_write = Self::MAX_WRITE_SIZE.min(buf.len());
1003 tracing::trace!(
1004 session_id = self.state.session_id(),
1005 number_of_bytes = len_to_write,
1006 "polling write of bytes on socket reader inside session",
1007 );
1008
1009 if len_to_write == 0 {
1011 return Poll::Ready(Ok(0));
1012 }
1013
1014 let mut socket_future = self.state.send_frame_data(&buf[..len_to_write]).boxed();
1015 match Pin::new(&mut socket_future).poll(cx) {
1016 Poll::Ready(Ok(())) => Poll::Ready(Ok(len_to_write)),
1017 Poll::Ready(Err(e)) => Poll::Ready(Err(std::io::Error::other(e))),
1018 Poll::Pending => Poll::Pending,
1019 }
1020 }
1021
1022 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
1023 tracing::trace!(
1024 session_id = self.state.session_id(),
1025 "polling flush on socket reader inside session"
1026 );
1027 let inner = &mut self.state.segment_egress_send;
1028 pin_mut!(inner);
1029 inner.poll_flush(cx).map_err(std::io::Error::other)
1030 }
1031
1032 fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
1033 tracing::trace!(
1034 session_id = self.state.session_id(),
1035 "polling close on socket reader inside session"
1036 );
1037 self.state.segment_egress_send.close_channel();
1039 Poll::Ready(Ok(()))
1040 }
1041}
1042
1043impl<const C: usize> AsyncRead for SessionSocket<C> {
1044 fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<std::io::Result<usize>> {
1045 tracing::trace!(
1046 session_id = self.state.session_id(),
1047 "polling read on socket reader inside session"
1048 );
1049 let inner = &mut self.frame_egress;
1050 pin_mut!(inner);
1051 inner.poll_read(cx, buf)
1052 }
1053}
1054
// End-to-end tests of `SessionSocket` over a simulated (optionally lossy and reordering) network.
#[cfg(test)]
mod tests {
    use std::iter::Extend;

    use futures::{
        future::Either,
        io::{AsyncReadExt, AsyncWriteExt},
        pin_mut,
    };
    use hex_literal::hex;
    use parameterized::parameterized;
    use rand::{Rng, SeedableRng, rngs::StdRng};
    use test_log::test;

    use super::*;
    use crate::{
        session::utils::{FaultyNetwork, FaultyNetworkConfig, NetworkStats},
        utils::DuplexIO,
    };

    // MTU of all sessions created in these tests.
    const MTU: usize = 466;
    // Fixed seed, so that randomized frame sizes are reproducible across runs.
    const RNG_SEED: [u8; 32] = hex!("d8a471f1c20490a3442b96fdde9d1807428096e1601b0cef0eea7e6d44a24c01");

    // Creates two connected sessions, "alice" and "bob", over two simulated networks (one per direction).
    fn setup_alice_bob(
        cfg: SessionConfig,
        network_cfg: FaultyNetworkConfig,
        alice_stats: Option<NetworkStats>,
        bob_stats: Option<NetworkStats>,
    ) -> (SessionSocket<MTU>, SessionSocket<MTU>) {
        // Alice writes into Bob's network and reads from her own (and vice versa). The counters are
        // re-wired so that each caller-supplied `NetworkStats` counts what that party sent and received.
        // NOTE(review): this assumes `NetworkStats` counters are shared handles, so the caller's
        // clones observe the updates — confirm in `session::utils`.
        let (alice_stats, bob_stats) = alice_stats
            .zip(bob_stats)
            .map(|(alice, bob)| {
                (
                    NetworkStats {
                        packets_sent: bob.packets_sent,
                        bytes_sent: bob.bytes_sent,
                        packets_received: alice.packets_received,
                        bytes_received: alice.bytes_received,
                    },
                    NetworkStats {
                        packets_sent: alice.packets_sent,
                        bytes_sent: alice.bytes_sent,
                        packets_received: bob.packets_received,
                        bytes_received: bob.bytes_received,
                    },
                )
            })
            .unzip();

        let (alice_reader, alice_writer) = FaultyNetwork::<MTU>::new(network_cfg, alice_stats).split();
        let (bob_reader, bob_writer) = FaultyNetwork::<MTU>::new(network_cfg, bob_stats).split();

        let alice_to_bob = SessionSocket::new("alice", DuplexIO(alice_reader, bob_writer), cfg.clone());
        let bob_to_alice = SessionSocket::new("bob", DuplexIO(bob_reader, alice_writer), cfg.clone());

        (alice_to_bob, bob_to_alice)
    }

    // Sends `num_frames` random frames between `alice` and `bob` and asserts that each side receives
    // exactly the bytes the other side sent. With `alice_to_bob_only`, only Alice sends and only Bob
    // receives. With `randomized_frame_sizes`, frame sizes are drawn from a seeded normal distribution
    // and clamped to `[10, 2 * frame_size]`. Panics on error or when `timeout` elapses.
    async fn send_and_recv<S>(
        num_frames: usize,
        frame_size: usize,
        alice: S,
        bob: S,
        timeout: Duration,
        alice_to_bob_only: bool,
        randomized_frame_sizes: bool,
    ) where
        S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
    {
        #[derive(PartialEq, Eq)]
        enum Direction {
            Send,
            Recv,
            Both,
        }

        let frame_sizes = if randomized_frame_sizes {
            let norm_dist = rand_distr::Normal::new(frame_size as f64 * 0.75, frame_size as f64 / 4.0).unwrap();
            StdRng::from_seed(RNG_SEED)
                .sample_iter(norm_dist)
                .map(|s| (s as usize).max(10).min(2 * frame_size))
                .take(num_frames)
                .collect::<Vec<_>>()
        } else {
            std::iter::repeat_n(frame_size, num_frames).collect::<Vec<_>>()
        };

        // Sends and/or receives all frames on one socket and returns the (sent, received) bytes.
        let socket_worker = |mut socket: S, d: Direction| {
            let frame_sizes = frame_sizes.clone();
            let frame_sizes_total = frame_sizes.iter().sum();
            async move {
                let mut received = Vec::with_capacity(frame_sizes_total);
                let mut sent = Vec::with_capacity(frame_sizes_total);

                if d == Direction::Send || d == Direction::Both {
                    for frame_size in &frame_sizes {
                        let mut write = vec![0u8; *frame_size];
                        hopr_crypto_random::random_fill(&mut write);
                        // NOTE(review): the number of bytes actually written is ignored. This assumes
                        // every frame fits into a single write (<= MAX_WRITE_SIZE) — TODO confirm.
                        let _ = socket.write(&write).await?;
                        sent.extend(write);
                    }
                }

                if d == Direction::Recv || d == Direction::Both {
                    while received.len() < frame_sizes_total {
                        let mut buffer = [0u8; 2048];
                        let read = socket.read(&mut buffer).await?;
                        received.extend(buffer.into_iter().take(read));
                    }
                }

                Ok::<_, std::io::Error>((sent, received))
            }
        };

        let alice_worker = tokio::task::spawn(socket_worker(
            alice,
            if alice_to_bob_only {
                Direction::Send
            } else {
                Direction::Both
            },
        ));
        let bob_worker = tokio::task::spawn(socket_worker(
            bob,
            if alice_to_bob_only {
                Direction::Recv
            } else {
                Direction::Both
            },
        ));

        let send_recv = futures::future::join(
            async move { alice_worker.await.expect("alice should not fail") },
            async move { bob_worker.await.expect("bob should not fail") },
        );
        let timeout = tokio::time::sleep(timeout);

        pin_mut!(send_recv);
        pin_mut!(timeout);

        match futures::future::select(send_recv, timeout).await {
            Either::Left(((Ok((alice_sent, alice_recv)), Ok((bob_sent, bob_recv))), _)) => {
                assert_eq!(
                    hex::encode(alice_sent),
                    hex::encode(bob_recv),
                    "alice sent must be equal to bob received"
                );
                assert_eq!(
                    hex::encode(bob_sent),
                    hex::encode(alice_recv),
                    "bob sent must be equal to alice received",
                );
            }
            Either::Left(((Err(e), _), _)) => panic!("alice send recv error: {e}"),
            Either::Left(((_, Err(e)), _)) => panic!("bob send recv error: {e}"),
            Either::Right(_) => panic!("timeout"),
        }
    }

    // Reliable network with every feature disabled (no acks, no retransmissions).
    #[parameterized(num_frames = {10, 100, 1000}, frame_size = {1500, 1500, 1500})]
    #[parameterized_macro(tokio::test)]
    async fn reliable_send_recv_with_no_acks(num_frames: usize, frame_size: usize) {
        let cfg = SessionConfig {
            enabled_features: HashSet::new(),
            ..Default::default()
        };

        let (alice_to_bob, bob_to_alice) = setup_alice_bob(cfg, Default::default(), None, None);

        send_and_recv(
            num_frames,
            frame_size,
            alice_to_bob,
            bob_to_alice,
            Duration::from_secs(10),
            false,
            false,
        )
        .await;
    }

    // Reliable network with the default features (acks and retransmissions enabled).
    #[parameterized(num_frames = {10, 100, 1000}, frame_size = {1500, 1500, 1500})]
    #[parameterized_macro(tokio::test)]
    async fn reliable_send_recv_with_with_acks(num_frames: usize, frame_size: usize) {
        let cfg = SessionConfig { ..Default::default() };

        let (alice_to_bob, bob_to_alice) = setup_alice_bob(cfg, Default::default(), None, None);

        send_and_recv(
            num_frames,
            frame_size,
            alice_to_bob,
            bob_to_alice,
            Duration::from_secs(10),
            false,
            false,
        )
        .await;
    }

    // Lossy network (`fault_prob` 0.33); delivery relies on retransmissions with short RTOs and nearly
    // constant backoff.
    #[parameterized(num_frames = {10, 100, 1000}, frame_size = {1500, 1500, 1500})]
    #[parameterized_macro(tokio::test)]
    async fn unreliable_send_recv(num_frames: usize, frame_size: usize) {
        let cfg = SessionConfig {
            rto_base_receiver: Duration::from_millis(10),
            rto_base_sender: Duration::from_millis(500),
            frame_expiration_age: Duration::from_secs(30),
            backoff_base: 1.001,
            ..Default::default()
        };

        let net_cfg = FaultyNetworkConfig {
            fault_prob: 0.33,
            ..Default::default()
        };

        let (alice_to_bob, bob_to_alice) = setup_alice_bob(cfg, net_cfg, None, None);

        send_and_recv(
            num_frames,
            frame_size,
            alice_to_bob,
            bob_to_alice,
            Duration::from_secs(30),
            false,
            false,
        )
        .await;
    }

    // Lossy network with message mixing (`mixing_factor` 2); ignored by default.
    #[ignore]
    #[parameterized(num_frames = {10, 100, 1000}, frame_size = {1500, 1500, 1500})]
    #[parameterized_macro(tokio::test)]
    async fn unreliable_send_recv_with_mixing(num_frames: usize, frame_size: usize) {
        let cfg = SessionConfig {
            rto_base_receiver: Duration::from_millis(10),
            rto_base_sender: Duration::from_millis(500),
            frame_expiration_age: Duration::from_secs(30),
            backoff_base: 1.001,
            ..Default::default()
        };

        let net_cfg = FaultyNetworkConfig {
            fault_prob: 0.20,
            mixing_factor: 2,
            ..Default::default()
        };

        let (alice_to_bob, bob_to_alice) = setup_alice_bob(cfg, net_cfg, None, None);

        send_and_recv(
            num_frames,
            frame_size,
            alice_to_bob,
            bob_to_alice,
            Duration::from_secs(30),
            false,
            false,
        )
        .await;
    }

    // Slightly lossy network (`fault_prob` 0.1) with message mixing; ignored by default.
    #[ignore]
    #[parameterized(num_frames = {10, 100, 1000}, frame_size = {1500, 1500, 1500})]
    #[parameterized_macro(tokio::test)]
    async fn almost_reliable_send_recv_with_mixing(num_frames: usize, frame_size: usize) {
        let cfg = SessionConfig {
            rto_base_sender: Duration::from_millis(500),
            rto_base_receiver: Duration::from_millis(10),
            frame_expiration_age: Duration::from_secs(30),
            backoff_base: 1.001,
            ..Default::default()
        };

        let net_cfg = FaultyNetworkConfig {
            fault_prob: 0.1,
            mixing_factor: 2,
            ..Default::default()
        };

        let (alice_to_bob, bob_to_alice) = setup_alice_bob(cfg, net_cfg, None, None);

        send_and_recv(
            num_frames,
            frame_size,
            alice_to_bob,
            bob_to_alice,
            Duration::from_secs(30),
            false,
            false,
        )
        .await;
    }

    // Network with message mixing and default fault probability; ignored by default.
    #[ignore]
    #[parameterized(num_frames = {10, 100, 1000}, frame_size = {1500, 1500, 1500})]
    #[parameterized_macro(tokio::test)]
    async fn reliable_send_recv_with_mixing(num_frames: usize, frame_size: usize) {
        let cfg = SessionConfig {
            rto_base_sender: Duration::from_millis(500),
            rto_base_receiver: Duration::from_millis(10),
            frame_expiration_age: Duration::from_secs(30),
            backoff_base: 1.001,
            ..Default::default()
        };

        let net_cfg = FaultyNetworkConfig {
            mixing_factor: 2,
            ..Default::default()
        };

        let (alice_to_bob, bob_to_alice) = setup_alice_bob(cfg, net_cfg, None, None);

        send_and_recv(
            num_frames,
            frame_size,
            alice_to_bob,
            bob_to_alice,
            Duration::from_secs(30),
            false,
            false,
        )
        .await;
    }

    // With only `NoDelay` enabled (no buffering, no acks, no retransmissions), every small frame must
    // travel as exactly one transport packet.
    #[test(tokio::test)]
    async fn small_frames_should_be_sent_as_single_transport_msgs_with_buffering_disabled() {
        const NUM_FRAMES: usize = 10;
        const FRAME_SIZE: usize = 64;

        let cfg = SessionConfig {
            enabled_features: HashSet::from_iter([SessionFeature::NoDelay]),
            ..Default::default()
        };

        let alice_stats = NetworkStats::default();
        let bob_stats = NetworkStats::default();

        let (alice_to_bob, bob_to_alice) = setup_alice_bob(
            cfg,
            FaultyNetworkConfig::default(),
            alice_stats.clone().into(),
            bob_stats.clone().into(),
        );

        send_and_recv(
            NUM_FRAMES,
            FRAME_SIZE,
            alice_to_bob,
            bob_to_alice,
            Duration::from_secs(30),
            true,
            false,
        )
        .await;

        assert_eq!(bob_stats.packets_received.load(Ordering::Relaxed), NUM_FRAMES);
        assert_eq!(alice_stats.packets_sent.load(Ordering::Relaxed), NUM_FRAMES);

        assert_eq!(
            alice_stats.bytes_sent.load(Ordering::Relaxed),
            NUM_FRAMES * (FRAME_SIZE + SessionMessage::<MTU>::SEGMENT_OVERHEAD)
        );
        assert_eq!(
            bob_stats.bytes_received.load(Ordering::Relaxed),
            NUM_FRAMES * (FRAME_SIZE + SessionMessage::<MTU>::SEGMENT_OVERHEAD)
        );
    }

    // With buffering enabled (`NoDelay` off), several small frames must share transport packets,
    // while the total number of bytes stays the same.
    #[test(tokio::test)]
    async fn small_frames_should_be_sent_batched_in_transport_msgs_with_buffering_enabled() {
        const NUM_FRAMES: usize = 10;
        const FRAME_SIZE: usize = 64;

        let cfg = SessionConfig {
            enabled_features: HashSet::new(),
            ..Default::default()
        };

        let alice_stats = NetworkStats::default();
        let bob_stats = NetworkStats::default();

        let (alice_to_bob, bob_to_alice) = setup_alice_bob(
            cfg,
            FaultyNetworkConfig::default(),
            alice_stats.clone().into(),
            bob_stats.clone().into(),
        );

        send_and_recv(
            NUM_FRAMES,
            FRAME_SIZE,
            alice_to_bob,
            bob_to_alice,
            Duration::from_secs(30),
            true,
            false,
        )
        .await;

        assert!(bob_stats.packets_received.load(Ordering::Relaxed) < NUM_FRAMES);
        assert!(alice_stats.packets_sent.load(Ordering::Relaxed) < NUM_FRAMES);

        assert_eq!(
            alice_stats.bytes_sent.load(Ordering::Relaxed),
            NUM_FRAMES * (FRAME_SIZE + SessionMessage::<MTU>::SEGMENT_OVERHEAD)
        );
        assert_eq!(
            bob_stats.bytes_received.load(Ordering::Relaxed),
            NUM_FRAMES * (FRAME_SIZE + SessionMessage::<MTU>::SEGMENT_OVERHEAD)
        );
    }

    // With `fault_prob` 1.0 nothing should get through, so reading must not complete before the timeout.
    #[test(tokio::test)]
    async fn receiving_on_disconnected_network_should_timeout() -> anyhow::Result<()> {
        let cfg = SessionConfig {
            rto_base_sender: Duration::from_millis(250),
            rto_base_receiver: Duration::from_millis(300),
            frame_expiration_age: Duration::from_secs(2),
            ..Default::default()
        };

        let net_cfg = FaultyNetworkConfig {
            fault_prob: 1.0,
            mixing_factor: 0,
            ..Default::default()
        };

        let (mut alice_to_bob, mut bob_to_alice) = setup_alice_bob(cfg, net_cfg, None, None);
        let data = b"will not be delivered!";

        let _ = alice_to_bob.write(data.as_ref()).await?;

        let mut out = vec![0u8; data.len()];
        let f1 = bob_to_alice.read_exact(&mut out);
        let f2 = tokio::time::sleep(Duration::from_secs(3));
        pin_mut!(f1);
        pin_mut!(f2);

        match futures::future::select(f1, f2).await {
            Either::Left(_) => panic!("should timeout: {:?}", out),
            Either::Right(_) => {}
        }

        Ok(())
    }

    // With `fault_prob` 0.5, a single small frame must still arrive within the timeout thanks to
    // retransmissions.
    #[test(tokio::test)]
    async fn single_frame_resend_should_be_resent_on_unreliable_network() -> anyhow::Result<()> {
        let cfg = SessionConfig {
            rto_base_sender: Duration::from_millis(250),
            rto_base_receiver: Duration::from_millis(300),
            frame_expiration_age: Duration::from_secs(10),
            ..Default::default()
        };

        let net_cfg = FaultyNetworkConfig {
            fault_prob: 0.5,
            mixing_factor: 0,
            ..Default::default()
        };

        let (mut alice_to_bob, mut bob_to_alice) = setup_alice_bob(cfg, net_cfg, None, None);
        let data = b"will be re-delivered!";

        let _ = alice_to_bob.write(data.as_ref()).await?;

        let mut out = vec![0u8; data.len()];
        let f1 = bob_to_alice.read_exact(&mut out);
        let f2 = tokio::time::sleep(Duration::from_secs(5));
        pin_mut!(f1);
        pin_mut!(f2);

        match futures::future::select(f1, f2).await {
            Either::Left(_) => {}
            Either::Right(_) => panic!("timeout"),
        }

        Ok(())
    }
}