1use std::{
2 fmt::{Debug, Display, Formatter},
3 hash::{Hash, Hasher},
4 pin::Pin,
5 str::FromStr,
6 task::{Context, Poll},
7 time::Duration,
8};
9
10use futures::{SinkExt, StreamExt, TryStreamExt};
11use hopr_network_types::{
12 prelude::SealedHost,
13 utils::{AsyncWriteSink, DuplexIO},
14};
15use hopr_protocol_app::prelude::{ApplicationData, ApplicationDataIn, ApplicationDataOut, Tag};
16use hopr_protocol_session::{
17 AcknowledgementMode, AcknowledgementState, AcknowledgementStateConfig, ReliableSocket, SessionSocketConfig,
18 UnreliableSocket,
19};
20use hopr_protocol_start::StartProtocol;
21use hopr_types::{
22 internal::{prelude::HoprPseudonym, routing::DestinationRouting},
23 primitive::{
24 errors::GeneralError,
25 prelude::{BytesRepresentable, ToHex},
26 },
27};
28use tracing::{debug, instrument};
29
30use crate::{Capabilities, Capability, errors::TransportSessionError};
31
32#[derive(Clone, Copy, Debug, PartialEq, Eq)]
34pub struct ByteCapabilities(pub Capabilities);
35
36impl TryFrom<u8> for ByteCapabilities {
37 type Error = GeneralError;
38
39 fn try_from(value: u8) -> Result<Self, Self::Error> {
40 Capabilities::new(value)
41 .map(Self)
42 .map_err(|_| GeneralError::ParseError("capabilities".into()))
43 }
44}
45
46impl From<ByteCapabilities> for u8 {
47 fn from(value: ByteCapabilities) -> Self {
48 *value.0.as_ref()
49 }
50}
51
52impl From<ByteCapabilities> for Capabilities {
53 fn from(value: ByteCapabilities) -> Self {
54 value.0
55 }
56}
57
58impl From<Capabilities> for ByteCapabilities {
59 fn from(value: Capabilities) -> Self {
60 Self(value)
61 }
62}
63
64impl AsRef<Capabilities> for ByteCapabilities {
65 fn as_ref(&self) -> &Capabilities {
66 &self.0
67 }
68}
69
70pub type HoprStartProtocol = StartProtocol<SessionId, SessionTarget, ByteCapabilities>;
72
/// Returns an upper bound on the number of decimal digits needed to print
/// an unsigned integer that is `n` bytes wide.
///
/// The bound is `ceil(8 * n * log10(2))`, evaluated in fixed-point arithmetic
/// with `log10(2)` rounded up to `0.301030`, so the result never underestimates.
const fn max_decimal_digits_for_n_bytes(n: usize) -> usize {
    // Fixed-point denominator and log10(2) expressed in that scale.
    const SCALE: u64 = 1_000_000;
    const LOG10_2_SCALED: u64 = 301030;
    const BITS_PER_BYTE: u64 = 8;

    let bit_count = BITS_PER_BYTE * n as u64;
    let scaled_digits = bit_count * LOG10_2_SCALED;

    // Ceiling division spelled out as quotient plus a carry for any remainder.
    let whole = scaled_digits / SCALE;
    let digits = if scaled_digits % SCALE == 0 { whole } else { whole + 1 };

    digits as usize
}
87
88const MAX_SESSION_ID_STR_LEN: usize = 2 + 2 * HoprPseudonym::SIZE + 1 + max_decimal_digits_for_n_bytes(Tag::SIZE);
90
91#[derive(Clone, Copy)]
97pub struct SessionId {
98 tag: Tag,
99 pseudonym: HoprPseudonym,
100 cached: arrayvec::ArrayString<MAX_SESSION_ID_STR_LEN>,
106}
107
108impl SessionId {
109 const DELIMITER: char = ':';
110
111 pub fn new<T: Into<Tag>>(tag: T, pseudonym: HoprPseudonym) -> Self {
112 let tag = tag.into();
113 let mut cached = format!("{pseudonym}{}{tag}", Self::DELIMITER);
114 cached.truncate(MAX_SESSION_ID_STR_LEN);
115
116 Self {
117 tag,
118 pseudonym,
119 cached: cached.parse().expect("cannot fail due to truncation"),
120 }
121 }
122
123 pub fn tag(&self) -> Tag {
124 self.tag
125 }
126
127 pub fn pseudonym(&self) -> &HoprPseudonym {
128 &self.pseudonym
129 }
130
131 pub fn as_str(&self) -> &str {
132 &self.cached
133 }
134}
135
136impl FromStr for SessionId {
137 type Err = TransportSessionError;
138
139 fn from_str(s: &str) -> Result<Self, Self::Err> {
140 s.split_once(Self::DELIMITER)
141 .ok_or(TransportSessionError::InvalidSessionId)
142 .and_then(
143 |(pseudonym, tag)| match (HoprPseudonym::from_hex(pseudonym), Tag::from_str(tag)) {
144 (Ok(p), Ok(t)) => Ok(Self::new(t, p)),
145 _ => Err(TransportSessionError::InvalidSessionId),
146 },
147 )
148 }
149}
150
151impl serde::Serialize for SessionId {
152 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
153 where
154 S: serde::Serializer,
155 {
156 use serde::ser::SerializeStruct;
157 let mut state = serializer.serialize_struct("SessionId", 2)?;
158 state.serialize_field("tag", &self.tag)?;
159 state.serialize_field("pseudonym", &self.pseudonym)?;
160 state.end()
161 }
162}
163
164impl<'de> serde::Deserialize<'de> for SessionId {
165 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
166 where
167 D: serde::Deserializer<'de>,
168 {
169 use serde::de;
170
171 #[derive(serde::Deserialize)]
172 #[serde(field_identifier, rename_all = "lowercase")]
173 enum Field {
174 Tag,
175 Pseudonym,
176 }
177
178 struct SessionIdVisitor;
179
180 impl<'de> de::Visitor<'de> for SessionIdVisitor {
181 type Value = SessionId;
182
183 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
184 formatter.write_str("struct SessionId")
185 }
186
187 fn visit_seq<A>(self, mut seq: A) -> Result<SessionId, A::Error>
188 where
189 A: de::SeqAccess<'de>,
190 {
191 Ok(SessionId::new(
192 seq.next_element::<Tag>()?
193 .ok_or_else(|| de::Error::invalid_length(0, &self))?,
194 seq.next_element()?.ok_or_else(|| de::Error::invalid_length(1, &self))?,
195 ))
196 }
197
198 fn visit_map<V>(self, mut map: V) -> Result<SessionId, V::Error>
199 where
200 V: de::MapAccess<'de>,
201 {
202 let mut tag: Option<Tag> = None;
203 let mut pseudonym: Option<HoprPseudonym> = None;
204 while let Some(key) = map.next_key()? {
205 match key {
206 Field::Tag => {
207 if tag.is_some() {
208 return Err(de::Error::duplicate_field("tag"));
209 }
210 tag = Some(map.next_value()?);
211 }
212 Field::Pseudonym => {
213 if pseudonym.is_some() {
214 return Err(de::Error::duplicate_field("pseudonym"));
215 }
216 pseudonym = Some(map.next_value()?);
217 }
218 }
219 }
220
221 Ok(SessionId::new(
222 tag.ok_or_else(|| de::Error::missing_field("tag"))?,
223 pseudonym.ok_or_else(|| de::Error::missing_field("pseudonym"))?,
224 ))
225 }
226 }
227
228 const FIELDS: &[&str] = &["tag", "pseudonym"];
229 deserializer.deserialize_struct("SessionId", FIELDS, SessionIdVisitor)
230 }
231}
232
233impl Display for SessionId {
234 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
235 write!(f, "{}", self.as_str())
236 }
237}
238
239impl Debug for SessionId {
240 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
241 write!(f, "{}", self.as_str())
242 }
243}
244
245impl PartialEq for SessionId {
246 fn eq(&self, other: &Self) -> bool {
247 self.tag == other.tag && self.pseudonym == other.pseudonym
248 }
249}
250
251impl Eq for SessionId {}
252
253impl Hash for SessionId {
254 fn hash<H: Hasher>(&self, state: &mut H) {
255 self.tag.hash(state);
256 self.pseudonym.hash(state);
257 }
258}
259
260pub(crate) fn caps_to_ack_mode(caps: Capabilities) -> AcknowledgementMode {
261 if caps.contains(Capability::RetransmissionAck | Capability::RetransmissionNack) {
262 AcknowledgementMode::Both
263 } else if caps.contains(Capability::RetransmissionAck) {
264 AcknowledgementMode::Full
265 } else {
266 AcknowledgementMode::Partial
267 }
268}
269
270#[derive(Clone, Copy, Debug, PartialEq, Eq, strum::Display)]
272pub enum ClosureReason {
273 WriteClosed,
275 EmptyRead,
277 Eviction,
279}
280
281trait AsyncReadWrite: futures::AsyncWrite + futures::AsyncRead + Send + Unpin {}
283impl<T: futures::AsyncWrite + futures::AsyncRead + Send + Unpin> AsyncReadWrite for T {}
284
/// Numeric identifier carried by [`SessionTarget::ExitNode`].
pub type ServiceId = u32;
291
292#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
295pub enum SessionTarget {
296 UdpStream(SealedHost),
298 TcpStream(SealedHost),
300 ExitNode(ServiceId),
302}
303
304#[derive(Debug)]
307pub struct IncomingSession {
308 pub session: HoprSession,
310 pub target: SessionTarget,
312}
313
314#[derive(Copy, Clone, Debug, PartialEq, Eq, smart_default::SmartDefault, serde::Serialize)]
316pub struct HoprSessionConfig {
317 #[default(Capabilities::empty())]
321 pub capabilities: Capabilities,
322 #[default(1500)]
326 pub frame_mtu: usize,
327 #[default(Duration::from_millis(800))]
331 #[serde(with = "humantime_serde")]
332 pub frame_timeout: Duration,
333}
334
335#[pin_project::pin_project]
340pub struct HoprSession {
341 id: SessionId,
342 #[pin]
343 inner: Box<dyn AsyncReadWrite>,
344 routing: DestinationRouting,
345 cfg: HoprSessionConfig,
346 on_close: Option<Box<dyn FnOnce(SessionId, ClosureReason) + Send + Sync>>,
347}
348
/// Value passed as `capacity` to the session socket configuration.
pub(crate) const SESSION_SOCKET_CAPACITY: usize = 16_384;
350
351impl HoprSession {
352 #[tracing::instrument(skip_all, fields(id, routing, cfg, session_id = %id))]
360 pub fn new<Tx, Rx>(
361 id: SessionId,
362 routing: DestinationRouting,
363 cfg: HoprSessionConfig,
364 hopr: (Tx, Rx),
365 on_close: Option<Box<dyn FnOnce(SessionId, ClosureReason) + Send + Sync>>,
366 ) -> Result<Self, TransportSessionError>
367 where
368 Tx: futures::Sink<(DestinationRouting, ApplicationDataOut)> + Send + Sync + Unpin + 'static,
369 Rx: futures::Stream<Item = ApplicationDataIn> + Send + Sync + Unpin + 'static,
370 Tx::Error: std::error::Error + Send + Sync,
371 {
372 let routing_clone = routing.clone();
373
374 #[cfg(feature = "telemetry")]
375 let (session_id_write, session_id_read) = (id, id);
376
377 let transport = DuplexIO(
379 AsyncWriteSink::<{ ApplicationData::PAYLOAD_SIZE }, _>(hopr.0.sink_map_err(std::io::Error::other).with(
380 move |buf: Box<[u8]>| {
381 #[cfg(feature = "telemetry")]
382 crate::telemetry::record_session_write(&session_id_write, buf.len());
383 futures::future::ready(
386 ApplicationData::new(id.tag(), buf.into_vec())
387 .map(|data| (routing_clone.clone(), ApplicationDataOut::with_no_packet_info(data)))
388 .map_err(std::io::Error::other),
389 )
390 },
391 )),
392 hopr.1
395 .map(move |data| {
396 #[cfg(feature = "telemetry")]
397 crate::telemetry::record_session_read(&session_id_read, data.data.plain_text.len());
398 Ok::<_, std::io::Error>(data.data.plain_text)
399 })
400 .into_async_read(),
401 );
402
403 let inner: Box<dyn AsyncReadWrite> = if cfg.capabilities.contains(Capability::Segmentation) {
405 let socket_cfg = SessionSocketConfig {
406 frame_size: cfg.frame_mtu,
407 frame_timeout: cfg.frame_timeout,
408 capacity: SESSION_SOCKET_CAPACITY,
409 flush_immediately: cfg.capabilities.contains(Capability::NoDelay),
410 ..Default::default()
411 };
412
413 if cfg.capabilities.contains(Capability::RetransmissionAck)
416 || cfg.capabilities.contains(Capability::RetransmissionNack)
417 {
418 let ack_cfg = AcknowledgementStateConfig {
420 expected_packet_latency: Duration::from_millis(200),
425 mode: caps_to_ack_mode(cfg.capabilities),
426 backoff_base: 0.2,
427 max_incoming_frame_retries: 1,
428 max_outgoing_frame_retries: 2,
429 ..Default::default()
430 };
431
432 debug!(?socket_cfg, ?ack_cfg, "opening new stateful session socket");
433
434 Box::new(ReliableSocket::new(
435 transport,
436 AcknowledgementState::<{ ApplicationData::PAYLOAD_SIZE }>::new(id, ack_cfg),
437 socket_cfg,
438 #[cfg(feature = "telemetry")]
439 id,
440 )?)
441 } else {
442 debug!(?socket_cfg, "opening new stateless session socket");
443
444 Box::new(UnreliableSocket::<{ ApplicationData::PAYLOAD_SIZE }>::new_stateless(
445 id,
446 transport,
447 socket_cfg,
448 #[cfg(feature = "telemetry")]
449 id,
450 )?)
451 }
452 } else {
453 debug!("opening raw session socket");
454 Box::new(transport)
455 };
456
457 Ok(Self {
458 id,
459 inner,
460 routing,
461 cfg,
462 on_close,
463 })
464 }
465
466 pub fn id(&self) -> &SessionId {
468 &self.id
469 }
470
471 pub fn routing(&self) -> &DestinationRouting {
473 &self.routing
474 }
475
476 pub fn config(&self) -> &HoprSessionConfig {
478 &self.cfg
479 }
480}
481
482impl std::fmt::Debug for HoprSession {
483 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
484 f.debug_struct("Session")
485 .field("id", &self.id)
486 .field("routing", &self.routing)
487 .finish_non_exhaustive()
488 }
489}
490
491impl futures::AsyncRead for HoprSession {
492 #[instrument(name = "Session::poll_read", level = "trace", skip_all, fields(session_id = %self.id), ret)]
493 fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<std::io::Result<usize>> {
494 let this = self.project();
495 let read = futures::ready!(this.inner.poll_read(cx, buf))?;
496 if read == 0 {
497 tracing::trace!("hopr session empty read");
498 if let Some(notifier) = this.on_close.take() {
500 tracing::trace!("notifying read half closure of session");
501 notifier(*this.id, ClosureReason::EmptyRead);
502 }
503 }
504 Poll::Ready(Ok(read))
505 }
506}
507
508impl futures::AsyncWrite for HoprSession {
509 #[instrument(name = "Session::poll_write", level = "trace", skip_all, fields(session_id = %self.id), ret)]
510 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
511 self.project().inner.poll_write(cx, buf)
512 }
513
514 #[instrument(name = "Session::poll_flush", level = "trace", skip_all, fields(session_id = %self.id), ret)]
515 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
516 self.project().inner.poll_flush(cx)
517 }
518
519 #[instrument(name = "Session::poll_close", level = "trace", skip_all, fields(session_id = %self.id), ret)]
520 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
521 let this = self.project();
522 futures::ready!(this.inner.poll_close(cx))?;
523 tracing::trace!("hopr session closed");
524
525 #[cfg(feature = "telemetry")]
526 crate::telemetry::set_session_state(this.id, crate::telemetry::SessionLifecycleState::Closing);
527
528 if let Some(notifier) = this.on_close.take() {
529 tracing::trace!("notifying write half closure of session");
530 notifier(*this.id, ClosureReason::WriteClosed);
531 }
532
533 Poll::Ready(Ok(()))
534 }
535}
536
/// Tokio adapter over the `futures::AsyncRead` implementation.
#[cfg(feature = "runtime-tokio")]
impl tokio::io::AsyncRead for HoprSession {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        // The futures-style API needs an initialized byte slice to write into.
        let unfilled = buf.initialize_unfilled();

        match <Self as futures::AsyncRead>::poll_read(self, cx, unfilled) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
            Poll::Ready(Ok(count)) => {
                // Mark the bytes just written as filled.
                buf.advance(count);
                Poll::Ready(Ok(()))
            }
        }
    }
}
550
/// Tokio adapter over the `futures::AsyncWrite` implementation.
#[cfg(feature = "runtime-tokio")]
impl tokio::io::AsyncWrite for HoprSession {
    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, std::io::Error>> {
        <Self as futures::AsyncWrite>::poll_write(self, cx, buf)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
        <Self as futures::AsyncWrite>::poll_flush(self, cx)
    }

    /// Tokio's shutdown maps onto the futures-style close.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
        <Self as futures::AsyncWrite>::poll_close(self, cx)
    }
}
565
#[cfg(test)]
mod tests {
    use anyhow::Context;
    use futures::{AsyncReadExt, AsyncWriteExt};
    use hopr_types::{
        crypto::prelude::*, crypto_random::Randomizable, internal::routing::RoutingOptions, primitive::prelude::*,
    };

    use super::*;

    /// Upper bound for every individual I/O step of the flow tests.
    const IO_TIMEOUT: Duration = Duration::from_secs(1);

    /// Connects two in-memory sessions created with `cfg` back to back and
    /// checks that data written on either side arrives intact on the other.
    ///
    /// For each step the outer `Result` of `tokio::time::timeout` carries the
    /// timeout error and the inner one carries the I/O error, so the context
    /// messages are attached in that order.
    async fn assert_bidirectional_flow(cfg: HoprSessionConfig) -> anyhow::Result<()> {
        const DATA_LEN: usize = 5000;

        let dst: Address = (&ChainKeypair::random()).into();
        let id = SessionId::new(1234_u64, HoprPseudonym::random());

        let (alice_tx, bob_rx) = futures::channel::mpsc::unbounded::<(DestinationRouting, ApplicationDataOut)>();
        let (bob_tx, alice_rx) = futures::channel::mpsc::unbounded::<(DestinationRouting, ApplicationDataOut)>();

        let mut alice_session = HoprSession::new(
            id,
            DestinationRouting::forward_only(dst, RoutingOptions::Hops(0.try_into()?)),
            cfg,
            (
                alice_tx,
                alice_rx
                    .map(|(_, data)| ApplicationDataIn {
                        data: data.data,
                        packet_info: Default::default(),
                    })
                    .inspect(|d| debug!("alice rcvd: {}", d.data.total_len())),
            ),
            None,
        )?;

        let mut bob_session = HoprSession::new(
            id,
            DestinationRouting::Return(id.pseudonym().into()),
            cfg,
            (
                bob_tx,
                bob_rx
                    .map(|(_, data)| ApplicationDataIn {
                        data: data.data,
                        packet_info: Default::default(),
                    })
                    .inspect(|d| debug!("bob rcvd: {}", d.data.total_len())),
            ),
            None,
        )?;

        let alice_sent = hopr_types::crypto_random::random_bytes::<DATA_LEN>();
        let bob_sent = hopr_types::crypto_random::random_bytes::<DATA_LEN>();

        let mut bob_recv = [0u8; DATA_LEN];
        let mut alice_recv = [0u8; DATA_LEN];

        tokio::time::timeout(IO_TIMEOUT, alice_session.write_all(&alice_sent))
            .await
            .context("alice write timed out")?
            .context("alice write failed")?;
        alice_session.flush().await?;

        tokio::time::timeout(IO_TIMEOUT, bob_session.write_all(&bob_sent))
            .await
            .context("bob write timed out")?
            .context("bob write failed")?;
        bob_session.flush().await?;

        tokio::time::timeout(IO_TIMEOUT, bob_session.read_exact(&mut bob_recv))
            .await
            .context("bob read timed out")?
            .context("bob read failed")?;

        tokio::time::timeout(IO_TIMEOUT, alice_session.read_exact(&mut alice_recv))
            .await
            .context("alice read timed out")?
            .context("alice read failed")?;

        assert_eq!(alice_sent, bob_recv);
        assert_eq!(bob_sent, alice_recv);

        Ok(())
    }

    // ---- ByteCapabilities -------------------------------------------------

    #[test]
    fn byte_capabilities_roundtrip_via_u8() -> anyhow::Result<()> {
        let flags: Capabilities = Capability::Segmentation.into();
        let caps = ByteCapabilities::from(flags);
        let byte_val: u8 = caps.into();
        let restored = ByteCapabilities::try_from(byte_val)?;
        assert_eq!(caps, restored);
        Ok(())
    }

    #[test]
    fn byte_capabilities_invalid_bits_are_rejected() {
        assert!(ByteCapabilities::try_from(0xFF_u8).is_err());
    }

    #[test]
    fn byte_capabilities_empty_is_zero() {
        let caps = ByteCapabilities::from(Capabilities::empty());
        let byte_val: u8 = caps.into();
        assert_eq!(byte_val, 0);
    }

    #[test]
    fn byte_capabilities_combined_flags() -> anyhow::Result<()> {
        let caps: Capabilities = Capability::Segmentation | Capability::NoRateControl;
        let byte_caps = ByteCapabilities::from(caps);
        let byte_val: u8 = byte_caps.into();
        let restored = ByteCapabilities::try_from(byte_val)?;
        assert_eq!(*restored.as_ref(), caps);
        Ok(())
    }

    // ---- caps_to_ack_mode -------------------------------------------------

    #[test]
    fn caps_to_ack_mode_both_when_ack_and_nack() {
        let caps: Capabilities = Capability::RetransmissionAck | Capability::RetransmissionNack;
        assert_eq!(caps_to_ack_mode(caps), AcknowledgementMode::Both);
    }

    #[test]
    fn caps_to_ack_mode_full_when_only_ack() {
        let caps: Capabilities = Capability::RetransmissionAck.into();
        assert_eq!(caps_to_ack_mode(caps), AcknowledgementMode::Full);
    }

    #[test]
    fn caps_to_ack_mode_partial_when_no_retransmission() {
        let caps: Capabilities = Capability::Segmentation.into();
        assert_eq!(caps_to_ack_mode(caps), AcknowledgementMode::Partial);
    }

    #[test]
    fn caps_to_ack_mode_partial_when_empty() {
        assert_eq!(caps_to_ack_mode(Capabilities::empty()), AcknowledgementMode::Partial);
    }

    #[test]
    fn caps_to_ack_mode_should_be_partial_when_only_nack() {
        let caps: Capabilities = Capability::RetransmissionNack.into();
        assert_eq!(caps_to_ack_mode(caps), AcknowledgementMode::Partial);
    }

    // ---- Snapshots (test names determine snapshot names; do not rename) ----

    #[test]
    fn closure_reason_display_values_are_stable() {
        let reasons = [
            ClosureReason::WriteClosed,
            ClosureReason::EmptyRead,
            ClosureReason::Eviction,
        ];
        insta::assert_debug_snapshot!(reasons);
    }

    #[test]
    fn hopr_session_config_default_snapshot() {
        let cfg = HoprSessionConfig::default();
        insta::assert_yaml_snapshot!(cfg);
    }

    #[test]
    fn session_target_variants_debug_snapshot() -> anyhow::Result<()> {
        let targets: Vec<SessionTarget> = vec![
            SessionTarget::UdpStream(SealedHost::Plain(
                "127.0.0.1:8080".parse().context("parsing UDP target")?,
            )),
            SessionTarget::TcpStream(SealedHost::Plain("10.0.0.1:443".parse().context("parsing TCP target")?)),
            SessionTarget::ExitNode(42),
        ];
        insta::assert_debug_snapshot!(targets);
        Ok(())
    }

    // ---- SessionId ----------------------------------------------------------

    #[test]
    fn session_id_from_str_rejects_missing_delimiter() {
        assert!(SessionId::from_str("nodelmiter").is_err());
    }

    #[test]
    fn session_id_from_str_rejects_invalid_pseudonym() {
        assert!(SessionId::from_str("notahexvalue:1234").is_err());
    }

    #[test]
    fn session_id_from_str_should_reject_invalid_tag() {
        let pseudonym = HoprPseudonym::random();
        let hex = pseudonym.to_hex();
        let bad = format!("{hex}:not_a_number");
        assert!(SessionId::from_str(&bad).is_err());
    }

    #[test]
    fn session_id_display_and_debug_should_be_identical() {
        let id = SessionId::new(42_u64, HoprPseudonym::random());
        assert_eq!(format!("{id}"), format!("{id:?}"));
    }

    #[test]
    fn session_id_hash_eq_consistency() {
        use std::collections::HashSet;
        let pseudonym = HoprPseudonym::random();
        let id1 = SessionId::new(1234_u64, pseudonym);
        let id2 = SessionId::new(1234_u64, pseudonym);
        let id3 = SessionId::new(5678_u64, pseudonym);
        let id4 = SessionId::new(1234_u64, HoprPseudonym::random());

        let mut set = HashSet::new();
        set.insert(id1);
        assert!(set.contains(&id2));
        assert!(!set.contains(&id3));
        assert!(!set.contains(&id4), "same id but different pseudonym should not match");
    }

    #[test]
    fn test_session_id_to_str_from_str() -> anyhow::Result<()> {
        let id = SessionId::new(1234_u64, HoprPseudonym::random());
        assert_eq!(id.as_str(), id.to_string());
        assert_eq!(id, SessionId::from_str(id.as_str())?);

        Ok(())
    }

    #[test]
    fn test_max_decimal_digits_for_n_bytes() {
        assert_eq!(3, max_decimal_digits_for_n_bytes(size_of::<u8>()));
        assert_eq!(5, max_decimal_digits_for_n_bytes(size_of::<u16>()));
        assert_eq!(10, max_decimal_digits_for_n_bytes(size_of::<u32>()));
        assert_eq!(20, max_decimal_digits_for_n_bytes(size_of::<u64>()));
    }

    #[test]
    fn standard_session_id_must_fit_within_limit() {
        let id = format!("{}:{}", SimplePseudonym::random(), Tag::Application(Tag::MAX));
        assert!(id.len() <= MAX_SESSION_ID_STR_LEN);
    }

    #[test]
    fn session_id_should_serialize_and_deserialize_correctly() -> anyhow::Result<()> {
        let pseudonym = HoprPseudonym::random();
        let tag: Tag = 1234u64.into();

        let session_id_1 = SessionId::new(tag, pseudonym);
        let data = serde_cbor_2::to_vec(&session_id_1)?;
        let session_id_2: SessionId = serde_cbor_2::from_slice(&data)?;

        assert_eq!(tag, session_id_2.tag());
        assert_eq!(pseudonym, *session_id_2.pseudonym());

        assert_eq!(session_id_1.as_str(), session_id_2.as_str());
        assert_eq!(session_id_1, session_id_2);

        Ok(())
    }

    // ---- End-to-end flow ----------------------------------------------------

    #[test_log::test(tokio::test)]
    async fn test_session_bidirectional_flow_without_segmentation() -> anyhow::Result<()> {
        assert_bidirectional_flow(HoprSessionConfig::default()).await
    }

    #[test_log::test(tokio::test)]
    async fn test_session_bidirectional_flow_with_segmentation() -> anyhow::Result<()> {
        assert_bidirectional_flow(HoprSessionConfig {
            capabilities: Capability::Segmentation.into(),
            ..Default::default()
        })
        .await
    }
}