1use crossbeam_queue::ArrayQueue;
89use crossbeam_skiplist::SkipMap;
90use dashmap::mapref::entry::Entry;
91use dashmap::DashMap;
92use futures::channel::mpsc::UnboundedSender;
93use futures::future::BoxFuture;
94use futures::{
95 pin_mut, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, FutureExt, Sink, SinkExt, StreamExt, TryStreamExt,
96};
97use governor::Quota;
98use smart_default::SmartDefault;
99use std::collections::{BTreeSet, HashSet};
100use std::fmt::{Debug, Display};
101use std::future::Future;
102use std::pin::Pin;
103use std::sync::atomic::{AtomicU32, Ordering};
104use std::sync::Arc;
105use std::task::{Context, Poll};
106use std::time::{Duration, Instant};
107use tracing::{debug, error, trace, warn};
108
109use hopr_async_runtime::prelude::spawn;
110
111use crate::errors::NetworkTypeError;
112use crate::prelude::protocol::SessionMessageIter;
113use crate::session::errors::SessionError;
114use crate::session::frame::{segment, FrameId, FrameReassembler, Segment, SegmentId};
115use crate::session::protocol::{FrameAcknowledgements, SegmentRequest, SessionMessage};
116use crate::session::utils::{RetryResult, RetryToken};
117use crate::utils::AsyncReadStreamer;
118
// Histogram of the time-to-acknowledgement of outgoing frames, in seconds, labelled by session ID.
// The observed value is `RetryToken::time_since_creation()` of the acknowledged frame's retry token
// (see `SessionState::acknowledged_frames`). Only compiled with the `prometheus` feature and
// outside of tests.
#[cfg(all(feature = "prometheus", not(test)))]
lazy_static::lazy_static! {
    static ref METRIC_TIME_TO_ACK: hopr_metrics::MultiHistogram =
        hopr_metrics::MultiHistogram::new(
            "hopr_session_time_to_ack",
            "Time in seconds until a complete frame gets acknowledged by the recipient",
            vec![0.5, 1.0, 2.0, 4.0, 8.0, 16.0, 32.0, 64.0, 128.0, 256.0],
            &["session_id"]
        ).unwrap();
}
129
/// Optional behaviours of a session, switched on by listing them in
/// `SessionConfig::enabled_features`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum SessionFeature {
    /// Periodically ask the peer to re-send the missing segments of incomplete incoming frames.
    RequestIncompleteFrames,
    /// Periodically re-send outgoing frames that have not been acknowledged by the peer.
    RetransmitFrames,
    /// Acknowledge every completely received frame back to the peer.
    AcknowledgeFrames,
    /// Disable write buffering towards the transport (the writer buffer capacity becomes zero).
    NoDelay,
}
143
144impl SessionFeature {
145 fn default_features() -> Vec<SessionFeature> {
152 vec![
153 SessionFeature::AcknowledgeFrames,
154 SessionFeature::RequestIncompleteFrames,
155 SessionFeature::RetransmitFrames,
156 ]
157 }
158}
159
/// Configuration of a session (`SessionSocket` and its `SessionState`).
///
/// Default values are provided through `SmartDefault`. The `#[validate(...)]` constraints are
/// declared for `validator::Validate`; note that `SessionSocket::new` in this module does not
/// call `validate()` itself.
#[derive(Debug, Clone, SmartDefault, validator::Validate)]
pub struct SessionConfig {
    #[default = 50_000]
    /// Maximum number of sent segments kept in the look-behind buffer for retransmission.
    ///
    /// Whenever the buffer grows beyond this size, the entries with the lowest segment ID
    /// are dropped first.
    pub max_buffered_segments: usize,

    #[default = 1024]
    /// Capacity of the queue of completely received frame IDs waiting to be acknowledged.
    ///
    /// When the queue is full, a newly completed frame overwrites the oldest queued ID
    /// (`ArrayQueue::force_push`). A value of 0 (rejected by validation) would make
    /// `retransmit_unacknowledged_frames` a no-op; the queue itself is created with a
    /// capacity of at least 1.
    #[validate(range(min = 1))]
    pub acknowledged_frames_buffer: usize,

    #[default(Duration::from_secs(30))]
    /// Age after which frames are considered expired.
    ///
    /// Passed to the `FrameReassembler` and used as the expiration limit of retry tokens
    /// for both incoming and outgoing frames. Reassembler eviction runs at most once per
    /// `frame_expiration_age / 10`.
    pub frame_expiration_age: Duration,

    #[default(Duration::from_millis(1000))]
    /// Base retransmission timeout of the receiving side.
    ///
    /// Governs how often missing segments of incomplete incoming frames are requested from
    /// the peer (with `SessionFeature::RequestIncompleteFrames`). It is also the period of the
    /// rate limiter driving those requests, and contributes to the acknowledgement period
    /// `min(rto_base_sender, rto_base_receiver) / 4`.
    pub rto_base_receiver: Duration,

    #[default(Duration::from_millis(1500))]
    /// Base retransmission timeout of the sending side.
    ///
    /// Governs how often unacknowledged outgoing frames are automatically re-sent
    /// (with `SessionFeature::RetransmitFrames`). It is also the period of the rate limiter
    /// driving those retransmissions, and contributes to the acknowledgement period
    /// `min(rto_base_sender, rto_base_receiver) / 4`.
    pub rto_base_sender: Duration,

    #[default(2.0)]
    /// Base of the back-off applied between consecutive retries (passed to `RetryToken`).
    ///
    /// NOTE(review): presumably the timeout grows roughly as `rto_base * backoff_base^n`;
    /// confirm against `RetryToken`.
    #[validate(range(min = 1.0))]
    pub backoff_base: f64,

    #[default(0.05)]
    /// Relative jitter applied to retransmission timeouts (passed to `RetryToken::check`).
    ///
    /// Validated to lie within `0.0..=0.25`.
    #[validate(range(min = 0.0, max = 0.25))]
    pub rto_jitter: f64,

    #[default(_code = "HashSet::from_iter(SessionFeature::default_features())")]
    /// Set of enabled session features.
    ///
    /// Defaults to acknowledgements, requesting incomplete frames and frame retransmission
    /// (`NoDelay` is off by default).
    pub enabled_features: HashSet<SessionFeature>,
}
238
/// Internal state of a session, shared between the `SessionSocket` and its background tasks.
///
/// All buffers and counters live behind `Arc`, so clones of this state operate on the same
/// data. `C` is the MTU: the assertion in `SessionSocket::new` rejects values that are too
/// small for a session message.
#[derive(Debug, Clone)]
pub struct SessionState<const C: usize> {
    // Session identifier, used to label logs (and metrics).
    session_id: String,
    // Already sent segments, kept so that they can be retransmitted.
    lookbehind: Arc<SkipMap<SegmentId, Segment>>,
    // IDs of completely received frames that still need to be acknowledged to the peer.
    to_acknowledge: Arc<ArrayQueue<FrameId>>,
    // Retry tokens of incomplete incoming frames (drive requests for missing segments).
    incoming_frame_retries: Arc<DashMap<FrameId, RetryToken>>,
    // Retry tokens of sent frames not yet acknowledged (drive auto-retransmission).
    outgoing_frame_resends: Arc<DashMap<FrameId, RetryToken>>,
    // ID to assign to the next outgoing frame.
    outgoing_frame_id: Arc<AtomicU32>,
    // Reassembles incoming segments into frames.
    frame_reassembler: Arc<FrameReassembler>,
    // Session configuration.
    cfg: SessionConfig,
    // Sending half of the channel whose messages are encoded and written to the transport.
    segment_egress_send: UnboundedSender<SessionMessage<C>>,
}
259
260fn maybe_fused_future<'a, F>(condition: bool, future: F) -> futures::future::Fuse<BoxFuture<'a, ()>>
261where
262 F: Future<Output = ()> + Send + Sync + 'a,
263{
264 if condition {
265 future.boxed()
266 } else {
267 futures::future::pending().boxed()
268 }
269 .fuse()
270}
271
272impl<const C: usize> SessionState<C> {
273 fn consume_segment(&mut self, segment: Segment) -> crate::errors::Result<()> {
274 let id = segment.id();
275
276 match self.frame_reassembler.push_segment(segment) {
277 Ok(_) => {
278 trace!(session_id = self.session_id, segment = %id, "RECEIVED: segment");
279 match self.incoming_frame_retries.entry(id.0) {
280 Entry::Occupied(e) => {
281 let rt = *e.get();
283 e.replace_entry(rt.replenish(Instant::now(), self.cfg.backoff_base));
284 }
285 Entry::Vacant(v) => {
286 v.insert(RetryToken::new(Instant::now(), self.cfg.backoff_base));
288 }
289 }
290 }
291 Err(e) => warn!(session_id = self.session_id, ?id, error = %e, "segment not pushed"),
293 }
294
295 Ok(())
296 }
297
298 fn retransmit_segments(&mut self, request: SegmentRequest<C>) -> crate::errors::Result<()> {
299 trace!(
300 session_id = self.session_id,
301 count_of_segments = request.len(),
302 "RECEIVED: request",
303 );
304
305 let mut count = 0;
306 request
307 .into_iter()
308 .filter_map(|segment_id| {
309 self.outgoing_frame_resends.remove(&segment_id.0);
311 let ret = self
312 .lookbehind
313 .get(&segment_id)
314 .map(|e| SessionMessage::<C>::Segment(e.value().clone()));
315 if ret.is_some() {
316 trace!(
317 session_id = self.session_id,
318 %segment_id,
319 "SENDING: retransmitted segment"
320 );
321 count += 1;
322 } else {
323 warn!(
324 session_id = self.session_id,
325 id = ?segment_id,
326 "segment not in lookbehind buffer anymore",
327 );
328 }
329 ret
330 })
331 .try_for_each(|msg| self.segment_egress_send.unbounded_send(msg))
332 .map_err(|e| SessionError::ProcessingError(e.to_string()))?;
333
334 trace!(session_id = self.session_id, count, "retransmitted requested segments");
335
336 Ok(())
337 }
338
339 fn acknowledged_frames(&mut self, acked: FrameAcknowledgements<C>) -> crate::errors::Result<()> {
340 trace!(
341 session_id = self.session_id,
342 count = acked.len(),
343 "RECEIVED: acknowledgement frames",
344 );
345
346 for frame_id in acked {
347 if let Some((_, rt)) = self.outgoing_frame_resends.remove(&frame_id) {
349 let to_ack = rt.time_since_creation();
350 trace!(
351 session_id = self.session_id,
352 frame_id,
353 duration_in_ms = to_ack.as_millis(),
354 "frame acknowledgement duratin"
355 );
356
357 #[cfg(all(feature = "prometheus", not(test)))]
358 METRIC_TIME_TO_ACK.observe(&[self.session_id()], to_ack.as_secs_f64())
359 }
360
361 for seg in self.lookbehind.iter().filter(|s| frame_id == s.key().0) {
362 seg.remove();
363 }
364 }
365
366 Ok(())
367 }
368
369 async fn request_missing_segments(&mut self) -> crate::errors::Result<usize> {
375 let tracked_incomplete = self.frame_reassembler.incomplete_frames();
376 trace!(
377 session_id = self.session_id,
378 count = tracked_incomplete.len(),
379 "tracking incomplete frames",
380 );
381
382 let mut to_retry = Vec::with_capacity(tracked_incomplete.len());
384 let now = Instant::now();
385 for info in tracked_incomplete {
386 match self.incoming_frame_retries.entry(info.frame_id) {
387 Entry::Occupied(e) => {
388 let rto_check = e.get().check(
390 now,
391 self.cfg.rto_base_receiver,
392 self.cfg.frame_expiration_age,
393 self.cfg.rto_jitter,
394 );
395 match rto_check {
396 RetryResult::RetryNow(next_rto) => {
397 trace!(
399 session_id = self.session_id,
400 frame_id = info.frame_id,
401 retransmission_number = next_rto.num_retry,
402 "performing frame retransmission",
403 );
404 e.replace_entry(next_rto);
405 to_retry.push(info);
406 }
407 RetryResult::Expired => {
408 debug!(
410 session_id = self.session_id,
411 frame_id = info.frame_id,
412 "frame is already expired and will be evicted"
413 );
414 e.remove();
415 }
416 RetryResult::Wait(d) => trace!(
417 session_id = self.session_id,
418 frame_id = info.frame_id,
419 timeout_in_ms = d.as_millis(),
420 next_retransmission_request_number = e.get().num_retry,
421 "frame needs to wait for next retransmission request",
422 ),
423 }
424 }
425 Entry::Vacant(v) => {
426 debug!(
428 session_id = self.session_id,
429 frame_id = info.frame_id,
430 "frame does not have a retry token"
431 );
432 v.insert(RetryToken::new(now, self.cfg.backoff_base));
433 to_retry.push(info);
434 }
435 }
436 }
437
438 let mut sent = 0;
439 let to_retry = to_retry
440 .chunks(SegmentRequest::<C>::MAX_ENTRIES)
441 .map(|chunk| Ok(SessionMessage::<C>::Request(chunk.iter().cloned().collect())))
442 .inspect(|r| {
443 trace!(
444 session_id = self.session_id,
445 result = ?r,
446 "SENDING: retransmission request"
447 );
448 sent += 1;
449 })
450 .collect::<Vec<_>>();
451
452 self.segment_egress_send
453 .send_all(&mut futures::stream::iter(to_retry))
454 .await
455 .map_err(|e| SessionError::ProcessingError(e.to_string()))?;
456
457 trace!(
458 session_id = self.session_id,
459 count = sent,
460 "RETRANSMISSION BATCH COMPLETE: sent {sent} re-send requests",
461 );
462 Ok(sent)
463 }
464
465 async fn acknowledge_segments(&mut self) -> crate::errors::Result<usize> {
476 let mut len = 0;
477 let mut msgs = 0;
478
479 while !self.to_acknowledge.is_empty() {
480 let mut ack_frames = FrameAcknowledgements::<C>::default();
481
482 while !ack_frames.is_full() && !self.to_acknowledge.is_empty() {
483 if let Some(ack_id) = self.to_acknowledge.pop() {
484 ack_frames.push(ack_id);
485 len += 1;
486 }
487 }
488
489 trace!(
490 session_id = self.session_id,
491 count = ack_frames.len(),
492 "SENDING: acknowledgements of frames",
493 );
494 self.segment_egress_send
495 .feed(SessionMessage::Acknowledge(ack_frames))
496 .await
497 .map_err(|e| SessionError::ProcessingError(e.to_string()))?;
498 msgs += 1;
499 }
500 self.segment_egress_send
501 .flush()
502 .await
503 .map_err(|e| SessionError::ProcessingError(e.to_string()))?;
504
505 trace!(
506 session_id = self.session_id,
507 count = len,
508 messages = msgs,
509 "ACK BATCH COMPLETE: sent acks in messages",
510 );
511 Ok(len)
512 }
513
514 async fn retransmit_unacknowledged_frames(&mut self) -> crate::errors::Result<usize> {
520 if self.cfg.acknowledged_frames_buffer == 0 {
521 return Ok(0);
522 }
523
524 let now = Instant::now();
525
526 let mut frames_to_resend = BTreeSet::new();
528 self.outgoing_frame_resends.retain(|frame_id, retry_log| {
529 let check_res = retry_log.check(
530 now,
531 self.cfg.rto_base_sender,
532 self.cfg.frame_expiration_age,
533 self.cfg.rto_jitter,
534 );
535 match check_res {
536 RetryResult::Wait(d) => {
537 trace!(
538 session_id = self.session_id,
539 frame_id,
540 wait_timeout_in_ms = d.as_millis(),
541 "frame will retransmit"
542 );
543 true
544 }
545 RetryResult::RetryNow(next_retry) => {
546 frames_to_resend.insert(*frame_id);
548 *retry_log = next_retry;
549 debug!(session_id = self.session_id, frame_id, "frame will self-resend now");
550 true
551 }
552 RetryResult::Expired => {
553 debug!(session_id = self.session_id, frame_id, "frame expired");
554 false
555 }
556 }
557 });
558
559 trace!(
560 session_id = self.session_id,
561 count = frames_to_resend.len(),
562 "frames will auto-resend",
563 );
564
565 let mut count = 0;
568 let frames_to_resend = frames_to_resend
569 .into_iter()
570 .flat_map(|f| self.lookbehind.iter().filter(move |e| e.key().0 == f))
571 .inspect(|e| {
572 trace!(
573 session_id = self.session_id,
574 key = ?e.key(),
575 "SENDING: auto-retransmitted"
576 );
577 count += 1
578 })
579 .map(|e| Ok(SessionMessage::<C>::Segment(e.value().clone())))
580 .collect::<Vec<_>>();
581
582 self.segment_egress_send
583 .send_all(&mut futures::stream::iter(frames_to_resend))
584 .await
585 .map_err(|e| SessionError::ProcessingError(e.to_string()))?;
586
587 trace!(
588 session_id = self.session_id,
589 count,
590 "AUTO-RETRANSMIT BATCH COMPLETE: re-sent segments",
591 );
592
593 Ok(count)
594 }
595
596 const PAYLOAD_CAPACITY: usize = C - SessionMessage::<C>::SEGMENT_OVERHEAD;
598
599 const MAX_WRITE_SIZE: usize = SessionMessage::<C>::MAX_SEGMENTS_PER_FRAME * Self::PAYLOAD_CAPACITY;
601
602 pub async fn send_frame_data(&mut self, data: &[u8]) -> crate::errors::Result<()> {
611 if !(1..=Self::MAX_WRITE_SIZE).contains(&data.len()) {
612 return Err(SessionError::IncorrectMessageLength.into());
613 }
614
615 let frame_id = self.outgoing_frame_id.fetch_add(1, Ordering::SeqCst);
616 let segments = segment(data, Self::PAYLOAD_CAPACITY, frame_id)?;
617 let count = segments.len();
618
619 for segment in segments {
620 let msg = SessionMessage::<C>::Segment(segment.clone());
621 trace!(session_id = self.session_id, id = ?segment.id(), "SENDING: segment");
622 self.segment_egress_send
623 .feed(msg)
624 .await
625 .map_err(|e| SessionError::ProcessingError(e.to_string()))?;
626
627 self.lookbehind.insert((&segment).into(), segment.clone());
629 while self.lookbehind.len() > self.cfg.max_buffered_segments {
630 self.lookbehind.pop_front();
631 }
632 }
633
634 self.segment_egress_send
635 .flush()
636 .await
637 .map_err(|e| SessionError::ProcessingError(e.to_string()))?;
638 self.outgoing_frame_resends
639 .insert(frame_id, RetryToken::new(Instant::now(), self.cfg.backoff_base));
640
641 trace!(
642 session_id = self.session_id,
643 frame_id,
644 count,
645 "FRAME SEND COMPLETE: sent segments",
646 );
647
648 Ok(())
649 }
650
651 async fn state_loop(&mut self) -> crate::errors::Result<()> {
657 let eviction_limiter =
660 governor::RateLimiter::direct(Quota::with_period(self.cfg.frame_expiration_age / 10).ok_or(
661 NetworkTypeError::Other("rate limiter frame_expiration_age invalid".into()),
662 )?);
663
664 let ack_rate_limiter = governor::RateLimiter::direct(
668 Quota::with_period(self.cfg.rto_base_sender.min(self.cfg.rto_base_receiver) / 4)
669 .ok_or(NetworkTypeError::Other("rate limiter ack rate invalid".into()))?,
670 );
671
672 let sender_retransmit = governor::RateLimiter::direct(
674 Quota::with_period(self.cfg.rto_base_sender)
675 .ok_or(NetworkTypeError::Other("rate limiter rto sender invalid".into()))?,
676 );
677
678 let receiver_retransmit = governor::RateLimiter::direct(
680 Quota::with_period(self.cfg.rto_base_receiver)
681 .ok_or(NetworkTypeError::Other("rate limiter rto receiver invalid".into()))?,
682 );
683
684 loop {
685 let mut evict_fut = eviction_limiter.until_ready().boxed().fuse();
686 let mut ack_fut = maybe_fused_future(
687 self.cfg.enabled_features.contains(&SessionFeature::AcknowledgeFrames),
688 ack_rate_limiter.until_ready(),
689 );
690 let mut r_snd_fut = maybe_fused_future(
691 self.cfg.enabled_features.contains(&SessionFeature::RetransmitFrames),
692 sender_retransmit.until_ready(),
693 );
694 let mut r_rcv_fut = maybe_fused_future(
695 self.cfg
696 .enabled_features
697 .contains(&SessionFeature::RequestIncompleteFrames),
698 receiver_retransmit.until_ready(),
699 );
700 let mut is_done = maybe_fused_future(self.segment_egress_send.is_closed(), futures::future::ready(()));
701
702 if let Err(e) = futures::select_biased! {
706 _ = is_done => {
707 Err(NetworkTypeError::Other("session writer has been closed".into()))
708 },
709 _ = r_rcv_fut => {
710 self.request_missing_segments().await
711 },
712 _ = r_snd_fut => {
713 self.retransmit_unacknowledged_frames().await
714 },
715 _ = ack_fut => {
716 self.acknowledge_segments().await
717 },
718 _ = evict_fut => {
719 self.frame_reassembler.evict().map_err(NetworkTypeError::from)
720 },
721 } {
722 debug!(session_id = self.session_id, "session is closing: {e}");
723 break;
724 }
725 }
726
727 Ok(())
728 }
729
730 pub fn session_id(&self) -> &str {
732 &self.session_id
733 }
734}
735
736impl<const C: usize> Sink<SessionMessage<C>> for SessionState<C> {
738 type Error = NetworkTypeError;
739
740 fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
741 Poll::Ready(Ok(()))
742 }
743
744 fn start_send(mut self: Pin<&mut Self>, item: SessionMessage<C>) -> Result<(), Self::Error> {
745 match item {
746 SessionMessage::Segment(s) => self.consume_segment(s),
747 SessionMessage::Request(r) => self.retransmit_segments(r),
748 SessionMessage::Acknowledge(f) => self.acknowledged_frames(f),
749 }
750 }
751
752 fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
753 Poll::Ready(Ok(()))
754 }
755
756 fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
757 self.frame_reassembler.close();
758 Poll::Ready(Ok(()))
759 }
760}
761
#[cfg_attr(doc, aquamarine::aquamarine)]
/// Byte-stream socket implementing `AsyncRead` and `AsyncWrite` on top of a message transport.
///
/// Each write of up to `MAX_WRITE_SIZE` bytes becomes one frame. The frame is split into
/// segments that fit the MTU `C`. Reads return the data of reassembled incoming frames.
/// Depending on `SessionConfig::enabled_features`, frames are acknowledged, retransmitted,
/// and missing segments are requested by background tasks spawned in `SessionSocket::new`.
pub struct SessionSocket<const C: usize> {
    // Session state shared with the background tasks.
    state: SessionState<C>,
    // Byte stream of reassembled incoming frames.
    frame_egress: Box<dyn AsyncRead + Send + Unpin>,
}
865
866impl<const C: usize> SessionSocket<C> {
867 pub const PAYLOAD_CAPACITY: usize = SessionState::<C>::PAYLOAD_CAPACITY;
869
870 pub const MAX_WRITE_SIZE: usize = SessionState::<C>::MAX_WRITE_SIZE;
872
873 pub fn new<T, I>(id: I, transport: T, cfg: SessionConfig) -> Self
876 where
877 T: AsyncWrite + AsyncRead + Send + 'static,
878 I: Display + Send + 'static,
879 {
880 assert!(
881 C >= SessionMessage::<C>::minimum_message_size(),
882 "given MTU is too small"
883 );
884
885 let (reassembler, egress) = FrameReassembler::new(cfg.frame_expiration_age);
886
887 let to_acknowledge = Arc::new(ArrayQueue::new(cfg.acknowledged_frames_buffer.max(1)));
888 let incoming_frame_retries = Arc::new(DashMap::new());
889
890 let incoming_frame_retries_clone = incoming_frame_retries.clone();
891 let id_clone = id.to_string().clone();
892 let to_acknowledge_clone = to_acknowledge.clone();
893 let ack_enabled = cfg.enabled_features.contains(&SessionFeature::AcknowledgeFrames);
894
895 let frame_egress = Box::new(
896 egress
897 .filter_map(move |maybe_frame| {
898 match maybe_frame {
899 Ok(frame) => {
900 trace!(session_id = id_clone, frame_id = frame.frame_id, "frame completed");
901 incoming_frame_retries_clone.remove(&frame.frame_id);
903 if ack_enabled {
904 to_acknowledge_clone.force_push(frame.frame_id);
907 }
908 futures::future::ready(Some(Ok(frame)))
909 }
910 Err(SessionError::FrameDiscarded(fid)) | Err(SessionError::IncompleteFrame(fid)) => {
911 incoming_frame_retries_clone.remove(&fid);
913 warn!(session_id = id_clone, frame_id = fid, "frame skipped");
914 futures::future::ready(None) }
916 Err(e) => {
917 error!(session_id = id_clone, "error on frame reassembly: {e}");
918 futures::future::ready(Some(Err(std::io::Error::other(e))))
919 }
920 }
921 })
922 .into_async_read(),
923 );
924
925 let (segment_egress_send, segment_egress_recv) = futures::channel::mpsc::unbounded();
926
927 let (downstream_read, downstream_write) = transport.split();
928
929 let downstream_write = futures::io::BufWriter::with_capacity(
931 if !cfg.enabled_features.contains(&SessionFeature::NoDelay) {
932 C
933 } else {
934 0
935 },
936 downstream_write,
937 );
938
939 let state = SessionState {
940 lookbehind: Arc::new(SkipMap::new()),
941 outgoing_frame_id: Arc::new(AtomicU32::new(1)),
942 frame_reassembler: Arc::new(reassembler),
943 outgoing_frame_resends: Arc::new(DashMap::new()),
944 session_id: id.to_string(),
945 to_acknowledge,
946 incoming_frame_retries,
947 segment_egress_send,
948 cfg,
949 };
950
951 spawn(async move {
953 if let Err(e) = segment_egress_recv
954 .map(|m: SessionMessage<C>| Ok(m.into_encoded()))
955 .forward(downstream_write.into_sink())
956 .await
957 {
958 error!(session_id = %id, error = %e, "FINISHED: forwarding to downstream terminated with error")
959 } else {
960 debug!(session_id = %id, "FINISHED: forwarding to downstream done");
961 }
962 });
963
964 spawn(
966 AsyncReadStreamer::<C, _>(downstream_read)
967 .map_err(|e| NetworkTypeError::SessionProtocolError(SessionError::ProcessingError(e.to_string())))
968 .and_then(|m| futures::future::ok(futures::stream::iter(SessionMessageIter::from(m.into_vec()))))
969 .try_flatten()
970 .forward(state.clone()),
971 );
972
973 let mut state_clone = state.clone();
975 spawn(async move {
976 let loop_done = state_clone.state_loop().await;
977 debug!(
978 session_id = state_clone.session_id,
979 "FINISHED: state loop {loop_done:?}"
980 );
981 });
982
983 Self { state, frame_egress }
984 }
985
986 pub fn state(&self) -> &SessionState<C> {
988 &self.state
989 }
990
991 pub fn state_mut(&mut self) -> &mut SessionState<C> {
993 &mut self.state
994 }
995}
996
997impl<const C: usize> AsyncWrite for SessionSocket<C> {
998 fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
999 let len_to_write = Self::MAX_WRITE_SIZE.min(buf.len());
1000 tracing::trace!(
1001 session_id = self.state.session_id(),
1002 number_of_bytes = len_to_write,
1003 "polling write of bytes on socket reader inside session",
1004 );
1005
1006 if len_to_write == 0 {
1008 return Poll::Ready(Ok(0));
1009 }
1010
1011 let mut socket_future = self.state.send_frame_data(&buf[..len_to_write]).boxed();
1012 match Pin::new(&mut socket_future).poll(cx) {
1013 Poll::Ready(Ok(())) => Poll::Ready(Ok(len_to_write)),
1014 Poll::Ready(Err(e)) => Poll::Ready(Err(std::io::Error::other(e))),
1015 Poll::Pending => Poll::Pending,
1016 }
1017 }
1018
1019 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
1020 tracing::trace!(
1021 session_id = self.state.session_id(),
1022 "polling flush on socket reader inside session"
1023 );
1024 let inner = &mut self.state.segment_egress_send;
1025 pin_mut!(inner);
1026 inner.poll_flush(cx).map_err(std::io::Error::other)
1027 }
1028
1029 fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
1030 tracing::trace!(
1031 session_id = self.state.session_id(),
1032 "polling close on socket reader inside session"
1033 );
1034 self.state.segment_egress_send.close_channel();
1036 Poll::Ready(Ok(()))
1037 }
1038}
1039
1040impl<const C: usize> AsyncRead for SessionSocket<C> {
1041 fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<std::io::Result<usize>> {
1042 tracing::trace!(
1043 session_id = self.state.session_id(),
1044 "polling read on socket reader inside session"
1045 );
1046 let inner = &mut self.frame_egress;
1047 pin_mut!(inner);
1048 inner.poll_read(cx, buf)
1049 }
1050}
1051
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::DuplexIO;
    use futures::future::Either;
    use futures::io::{AsyncReadExt, AsyncWriteExt};
    use futures::pin_mut;
    use hex_literal::hex;
    use parameterized::parameterized;
    use rand::rngs::StdRng;
    use rand::{Rng, SeedableRng};
    use std::iter::Extend;
    use test_log::test;

    use crate::session::utils::{FaultyNetwork, FaultyNetworkConfig, NetworkStats};

    const MTU: usize = 466;
    const RNG_SEED: [u8; 32] = hex!("d8a471f1c20490a3442b96fdde9d1807428096e1601b0cef0eea7e6d44a24c01");

    /// Creates two connected sessions ("alice" and "bob") over a simulated, possibly faulty
    /// network. If both stats are given, they are cross-wired so that each side's counters
    /// describe its own sending and receiving.
    fn setup_alice_bob(
        cfg: SessionConfig,
        network_cfg: FaultyNetworkConfig,
        alice_stats: Option<NetworkStats>,
        bob_stats: Option<NetworkStats>,
    ) -> (SessionSocket<MTU>, SessionSocket<MTU>) {
        let (alice_stats, bob_stats) = alice_stats
            .zip(bob_stats)
            .map(|(alice, bob)| {
                (
                    NetworkStats {
                        packets_sent: bob.packets_sent,
                        bytes_sent: bob.bytes_sent,
                        packets_received: alice.packets_received,
                        bytes_received: alice.bytes_received,
                    },
                    NetworkStats {
                        packets_sent: alice.packets_sent,
                        bytes_sent: alice.bytes_sent,
                        packets_received: bob.packets_received,
                        bytes_received: bob.bytes_received,
                    },
                )
            })
            .unzip();

        let (alice_reader, alice_writer) = FaultyNetwork::<MTU>::new(network_cfg, alice_stats).split();
        let (bob_reader, bob_writer) = FaultyNetwork::<MTU>::new(network_cfg, bob_stats).split();

        let alice_to_bob = SessionSocket::new("alice", DuplexIO(alice_reader, bob_writer), cfg.clone());
        let bob_to_alice = SessionSocket::new("bob", DuplexIO(bob_reader, alice_writer), cfg.clone());

        (alice_to_bob, bob_to_alice)
    }

    /// Sends `num_frames` random frames between `alice` and `bob` (one way or both ways) and
    /// asserts that the received bytes equal the sent bytes. Panics if this does not finish
    /// within `timeout`.
    async fn send_and_recv<S>(
        num_frames: usize,
        frame_size: usize,
        alice: S,
        bob: S,
        timeout: Duration,
        alice_to_bob_only: bool,
        randomized_frame_sizes: bool,
    ) where
        S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
    {
        #[derive(PartialEq, Eq)]
        enum Direction {
            Send,
            Recv,
            Both,
        }

        let frame_sizes = if randomized_frame_sizes {
            let norm_dist = rand_distr::Normal::new(frame_size as f64 * 0.75, frame_size as f64 / 4.0).unwrap();
            StdRng::from_seed(RNG_SEED)
                .sample_iter(norm_dist)
                .map(|s| (s as usize).max(10).min(2 * frame_size))
                .take(num_frames)
                .collect::<Vec<_>>()
        } else {
            vec![frame_size; num_frames]
        };

        let socket_worker = |mut socket: S, d: Direction| {
            let frame_sizes = frame_sizes.clone();
            let frame_sizes_total = frame_sizes.iter().sum();
            async move {
                let mut received = Vec::with_capacity(frame_sizes_total);
                let mut sent = Vec::with_capacity(frame_sizes_total);

                if d == Direction::Send || d == Direction::Both {
                    for frame_size in &frame_sizes {
                        let mut write = vec![0u8; *frame_size];
                        hopr_crypto_random::random_fill(&mut write);
                        // `write` may accept fewer bytes than given (at most `MAX_WRITE_SIZE`);
                        // `write_all` guarantees that everything recorded in `sent` was written.
                        socket.write_all(&write).await?;
                        sent.extend(write);
                    }
                }

                if d == Direction::Recv || d == Direction::Both {
                    while received.len() < frame_sizes_total {
                        let mut buffer = [0u8; 2048];
                        let read = socket.read(&mut buffer).await?;
                        if read == 0 {
                            // Fail fast instead of spinning on EOF until the test times out.
                            return Err(std::io::Error::new(
                                std::io::ErrorKind::UnexpectedEof,
                                "session closed before all data was received",
                            ));
                        }
                        received.extend_from_slice(&buffer[..read]);
                    }
                }

                Ok::<_, std::io::Error>((sent, received))
            }
        };

        let alice_worker = async_std::task::spawn(socket_worker(
            alice,
            if alice_to_bob_only {
                Direction::Send
            } else {
                Direction::Both
            },
        ));
        let bob_worker = async_std::task::spawn(socket_worker(
            bob,
            if alice_to_bob_only {
                Direction::Recv
            } else {
                Direction::Both
            },
        ));

        let send_recv = futures::future::join(alice_worker, bob_worker);
        let timeout = async_std::task::sleep(timeout);

        pin_mut!(send_recv);
        pin_mut!(timeout);

        match futures::future::select(send_recv, timeout).await {
            Either::Left(((Ok((alice_sent, alice_recv)), Ok((bob_sent, bob_recv))), _)) => {
                assert_eq!(
                    hex::encode(alice_sent),
                    hex::encode(bob_recv),
                    "alice sent must be equal to bob received"
                );
                assert_eq!(
                    hex::encode(bob_sent),
                    hex::encode(alice_recv),
                    "bob sent must be equal to alice received",
                );
            }
            Either::Left(((Err(e), _), _)) => panic!("alice send recv error: {e}"),
            Either::Left(((_, Err(e)), _)) => panic!("bob send recv error: {e}"),
            Either::Right(_) => panic!("timeout"),
        }
    }

    #[parameterized(num_frames = {10, 100, 1000}, frame_size = {1500, 1500, 1500})]
    #[parameterized_macro(async_std::test)]
    async fn reliable_send_recv_with_no_acks(num_frames: usize, frame_size: usize) {
        let cfg = SessionConfig {
            enabled_features: HashSet::new(),
            ..Default::default()
        };

        let (alice_to_bob, bob_to_alice) = setup_alice_bob(cfg, Default::default(), None, None);

        send_and_recv(
            num_frames,
            frame_size,
            alice_to_bob,
            bob_to_alice,
            Duration::from_secs(10),
            false,
            false,
        )
        .await;
    }

    #[parameterized(num_frames = {10, 100, 1000}, frame_size = {1500, 1500, 1500})]
    #[parameterized_macro(async_std::test)]
    async fn reliable_send_recv_with_acks(num_frames: usize, frame_size: usize) {
        let cfg = SessionConfig { ..Default::default() };

        let (alice_to_bob, bob_to_alice) = setup_alice_bob(cfg, Default::default(), None, None);

        send_and_recv(
            num_frames,
            frame_size,
            alice_to_bob,
            bob_to_alice,
            Duration::from_secs(10),
            false,
            false,
        )
        .await;
    }

    #[parameterized(num_frames = {10, 100, 1000}, frame_size = {1500, 1500, 1500})]
    #[parameterized_macro(async_std::test)]
    async fn unreliable_send_recv(num_frames: usize, frame_size: usize) {
        let cfg = SessionConfig {
            rto_base_receiver: Duration::from_millis(10),
            rto_base_sender: Duration::from_millis(500),
            frame_expiration_age: Duration::from_secs(30),
            backoff_base: 1.001,
            ..Default::default()
        };

        let net_cfg = FaultyNetworkConfig {
            fault_prob: 0.33,
            ..Default::default()
        };

        let (alice_to_bob, bob_to_alice) = setup_alice_bob(cfg, net_cfg, None, None);

        send_and_recv(
            num_frames,
            frame_size,
            alice_to_bob,
            bob_to_alice,
            Duration::from_secs(30),
            false,
            false,
        )
        .await;
    }

    #[parameterized(num_frames = {10, 100, 1000}, frame_size = {1500, 1500, 1500})]
    #[parameterized_macro(async_std::test)]
    async fn unreliable_send_recv_with_mixing(num_frames: usize, frame_size: usize) {
        let cfg = SessionConfig {
            rto_base_receiver: Duration::from_millis(10),
            rto_base_sender: Duration::from_millis(500),
            frame_expiration_age: Duration::from_secs(30),
            backoff_base: 1.001,
            ..Default::default()
        };

        let net_cfg = FaultyNetworkConfig {
            fault_prob: 0.33,
            mixing_factor: 2,
            ..Default::default()
        };

        let (alice_to_bob, bob_to_alice) = setup_alice_bob(cfg, net_cfg, None, None);

        send_and_recv(
            num_frames,
            frame_size,
            alice_to_bob,
            bob_to_alice,
            Duration::from_secs(30),
            false,
            false,
        )
        .await;
    }

    #[parameterized(num_frames = {10, 100, 1000}, frame_size = {1500, 1500, 1500})]
    #[parameterized_macro(async_std::test)]
    async fn almost_reliable_send_recv_with_mixing(num_frames: usize, frame_size: usize) {
        let cfg = SessionConfig {
            rto_base_sender: Duration::from_millis(500),
            rto_base_receiver: Duration::from_millis(10),
            frame_expiration_age: Duration::from_secs(30),
            backoff_base: 1.001,
            ..Default::default()
        };

        let net_cfg = FaultyNetworkConfig {
            fault_prob: 0.1,
            mixing_factor: 2,
            ..Default::default()
        };

        let (alice_to_bob, bob_to_alice) = setup_alice_bob(cfg, net_cfg, None, None);

        send_and_recv(
            num_frames,
            frame_size,
            alice_to_bob,
            bob_to_alice,
            Duration::from_secs(30),
            false,
            false,
        )
        .await;
    }

    #[parameterized(num_frames = {10, 100, 1000}, frame_size = {1500, 1500, 1500})]
    #[parameterized_macro(async_std::test)]
    async fn reliable_send_recv_with_mixing(num_frames: usize, frame_size: usize) {
        let cfg = SessionConfig {
            rto_base_sender: Duration::from_millis(500),
            rto_base_receiver: Duration::from_millis(10),
            frame_expiration_age: Duration::from_secs(30),
            backoff_base: 1.001,
            ..Default::default()
        };

        let net_cfg = FaultyNetworkConfig {
            mixing_factor: 2,
            ..Default::default()
        };

        let (alice_to_bob, bob_to_alice) = setup_alice_bob(cfg, net_cfg, None, None);

        send_and_recv(
            num_frames,
            frame_size,
            alice_to_bob,
            bob_to_alice,
            Duration::from_secs(30),
            false,
            false,
        )
        .await;
    }

    #[test(async_std::test)]
    async fn small_frames_should_be_sent_as_single_transport_msgs_with_buffering_disabled() {
        const NUM_FRAMES: usize = 10;
        const FRAME_SIZE: usize = 64;

        let cfg = SessionConfig {
            enabled_features: HashSet::from_iter([SessionFeature::NoDelay]),
            ..Default::default()
        };

        let alice_stats = NetworkStats::default();
        let bob_stats = NetworkStats::default();

        let (alice_to_bob, bob_to_alice) = setup_alice_bob(
            cfg,
            FaultyNetworkConfig::default(),
            alice_stats.clone().into(),
            bob_stats.clone().into(),
        );

        send_and_recv(
            NUM_FRAMES,
            FRAME_SIZE,
            alice_to_bob,
            bob_to_alice,
            Duration::from_secs(30),
            true,
            false,
        )
        .await;

        assert_eq!(bob_stats.packets_received.load(Ordering::Relaxed), NUM_FRAMES);
        assert_eq!(alice_stats.packets_sent.load(Ordering::Relaxed), NUM_FRAMES);

        assert_eq!(
            alice_stats.bytes_sent.load(Ordering::Relaxed),
            NUM_FRAMES * (FRAME_SIZE + SessionMessage::<MTU>::SEGMENT_OVERHEAD)
        );
        assert_eq!(
            bob_stats.bytes_received.load(Ordering::Relaxed),
            NUM_FRAMES * (FRAME_SIZE + SessionMessage::<MTU>::SEGMENT_OVERHEAD)
        );
    }

    #[test(async_std::test)]
    async fn small_frames_should_be_sent_batched_in_transport_msgs_with_buffering_enabled() {
        const NUM_FRAMES: usize = 10;
        const FRAME_SIZE: usize = 64;

        let cfg = SessionConfig {
            enabled_features: HashSet::new(),
            ..Default::default()
        };

        let alice_stats = NetworkStats::default();
        let bob_stats = NetworkStats::default();

        let (alice_to_bob, bob_to_alice) = setup_alice_bob(
            cfg,
            FaultyNetworkConfig::default(),
            alice_stats.clone().into(),
            bob_stats.clone().into(),
        );

        send_and_recv(
            NUM_FRAMES,
            FRAME_SIZE,
            alice_to_bob,
            bob_to_alice,
            Duration::from_secs(30),
            true,
            false,
        )
        .await;

        assert!(bob_stats.packets_received.load(Ordering::Relaxed) < NUM_FRAMES);
        assert!(alice_stats.packets_sent.load(Ordering::Relaxed) < NUM_FRAMES);

        assert_eq!(
            alice_stats.bytes_sent.load(Ordering::Relaxed),
            NUM_FRAMES * (FRAME_SIZE + SessionMessage::<MTU>::SEGMENT_OVERHEAD)
        );
        assert_eq!(
            bob_stats.bytes_received.load(Ordering::Relaxed),
            NUM_FRAMES * (FRAME_SIZE + SessionMessage::<MTU>::SEGMENT_OVERHEAD)
        );
    }

    #[test(async_std::test)]
    async fn receiving_on_disconnected_network_should_timeout() {
        let cfg = SessionConfig {
            rto_base_sender: Duration::from_millis(250),
            rto_base_receiver: Duration::from_millis(300),
            frame_expiration_age: Duration::from_secs(2),
            ..Default::default()
        };

        let net_cfg = FaultyNetworkConfig {
            fault_prob: 1.0,
            mixing_factor: 0,
            ..Default::default()
        };

        let (mut alice_to_bob, mut bob_to_alice) = setup_alice_bob(cfg, net_cfg, None, None);
        let data = b"will not be delivered!";

        alice_to_bob.write_all(data.as_ref()).await.unwrap();

        let mut out = vec![0u8; data.len()];
        let f1 = bob_to_alice.read_exact(&mut out);
        let f2 = async_std::task::sleep(Duration::from_secs(3));
        pin_mut!(f1);
        pin_mut!(f2);

        match futures::future::select(f1, f2).await {
            Either::Left(_) => panic!("should timeout: {:?}", out),
            Either::Right(_) => {}
        }
    }

    #[test(async_std::test)]
    async fn single_frame_resend_should_be_resent_on_unreliable_network() {
        let cfg = SessionConfig {
            rto_base_sender: Duration::from_millis(250),
            rto_base_receiver: Duration::from_millis(300),
            frame_expiration_age: Duration::from_secs(10),
            ..Default::default()
        };

        let net_cfg = FaultyNetworkConfig {
            fault_prob: 0.5,
            mixing_factor: 0,
            ..Default::default()
        };

        let (mut alice_to_bob, mut bob_to_alice) = setup_alice_bob(cfg, net_cfg, None, None);
        let data = b"will be re-delivered!";

        alice_to_bob.write_all(data.as_ref()).await.unwrap();

        let mut out = vec![0u8; data.len()];
        let f1 = bob_to_alice.read_exact(&mut out);
        let f2 = async_std::task::sleep(Duration::from_secs(5));
        pin_mut!(f1);
        pin_mut!(f2);

        match futures::future::select(f1, f2).await {
            Either::Left(_) => {}
            Either::Right(_) => panic!("timeout"),
        }
    }
}