1use std::{
2 fmt::{Debug, Display, Formatter},
3 hash::{Hash, Hasher},
4 pin::Pin,
5 str::FromStr,
6 task::{Context, Poll},
7 time::Duration,
8};
9
10use futures::{SinkExt, StreamExt, TryStreamExt};
11use hopr_internal_types::prelude::HoprPseudonym;
12use hopr_network_types::{
13 prelude::{DestinationRouting, SealedHost},
14 utils::{AsyncWriteSink, DuplexIO},
15};
16use hopr_primitive_types::{
17 errors::GeneralError,
18 prelude::{BytesRepresentable, ToHex},
19};
20use hopr_protocol_app::prelude::{ApplicationData, Tag};
21use hopr_protocol_session::{
22 AcknowledgementMode, AcknowledgementState, AcknowledgementStateConfig, ReliableSocket, SessionSocketConfig,
23 UnreliableSocket,
24};
25use hopr_protocol_start::StartProtocol;
26use tracing::{debug, instrument};
27
28use crate::{Capabilities, Capability, errors::TransportSessionError};
29
30#[derive(Clone, Copy, Debug, PartialEq, Eq)]
32pub struct ByteCapabilities(pub Capabilities);
33
34impl TryFrom<u8> for ByteCapabilities {
35 type Error = GeneralError;
36
37 fn try_from(value: u8) -> Result<Self, Self::Error> {
38 Capabilities::new(value)
39 .map(Self)
40 .map_err(|_| GeneralError::ParseError("capabilities".into()))
41 }
42}
43
44impl From<ByteCapabilities> for u8 {
45 fn from(value: ByteCapabilities) -> Self {
46 *value.0.as_ref()
47 }
48}
49
50impl From<ByteCapabilities> for Capabilities {
51 fn from(value: ByteCapabilities) -> Self {
52 value.0
53 }
54}
55
56impl From<Capabilities> for ByteCapabilities {
57 fn from(value: Capabilities) -> Self {
58 Self(value)
59 }
60}
61
62impl AsRef<Capabilities> for ByteCapabilities {
63 fn as_ref(&self) -> &Capabilities {
64 &self.0
65 }
66}
67
/// The Start protocol instantiated for HOPR sessions.
///
/// Session identifiers are [`SessionId`], requested targets are [`SessionTarget`], and capabilities
/// are carried as [`ByteCapabilities`] (presumably encoded as a single byte on the wire — confirm in
/// `hopr_protocol_start`).
pub type HoprStartProtocol = StartProtocol<SessionId, SessionTarget, ByteCapabilities>;
70
/// Upper bound on the number of decimal digits needed to print an unsigned integer that is
/// `n` bytes wide.
///
/// Evaluates `ceil(8 * n * log10(2))` using fixed-point integer math so it can be a `const fn`.
/// Because the constant below rounds `log10(2)` *up*, the result never underestimates for
/// `n >= 1` (e.g. 1 -> 3, 2 -> 5, 4 -> 10, 8 -> 20). For `n == 0` it returns 0.
const fn max_decimal_digits_for_n_bytes(n: usize) -> usize {
    // log10(2) = 0.30102999..., expressed in millionths and rounded up.
    const LOG10_2_MICROS: u64 = 301_030;
    const MICROS_PER_UNIT: u64 = 1_000_000;

    let bit_width = 8 * n as u64;
    let digits_micros = bit_width * LOG10_2_MICROS;
    digits_micros.div_ceil(MICROS_PER_UNIT) as usize
}
85
/// Capacity of the inline string cached in every [`SessionId`].
///
/// Sized for `<pseudonym-hex><':'><tag-decimal>`:
/// - 2 characters, presumably the `0x` prefix of the hex pseudonym rendering (TODO confirm);
/// - 2 hex characters per pseudonym byte;
/// - 1 character for the delimiter;
/// - the maximum number of decimal digits of a `Tag::SIZE`-byte tag.
const MAX_SESSION_ID_STR_LEN: usize = 2 + 2 * HoprPseudonym::SIZE + 1 + max_decimal_digits_for_n_bytes(Tag::SIZE);
88
/// Identifies a session by its [`Tag`] and a [`HoprPseudonym`].
///
/// The type is `Copy` and carries a pre-rendered string form, so `as_str`/`Display`/`Debug`
/// never allocate. Equality and hashing consider only `tag` and `pseudonym`.
#[derive(Clone, Copy)]
pub struct SessionId {
    tag: Tag,
    pseudonym: HoprPseudonym,
    // `pseudonym:tag` rendered once in `SessionId::new` and truncated to the inline capacity.
    // Fully determined by the two fields above, so it is excluded from `PartialEq`/`Hash`.
    cached: arrayvec::ArrayString<MAX_SESSION_ID_STR_LEN>,
}
105
106impl SessionId {
107 const DELIMITER: char = ':';
108
109 pub fn new<T: Into<Tag>>(tag: T, pseudonym: HoprPseudonym) -> Self {
110 let tag = tag.into();
111 let mut cached = format!("{pseudonym}{}{tag}", Self::DELIMITER);
112 cached.truncate(MAX_SESSION_ID_STR_LEN);
113
114 Self {
115 tag,
116 pseudonym,
117 cached: cached.parse().expect("cannot fail due to truncation"),
118 }
119 }
120
121 pub fn tag(&self) -> Tag {
122 self.tag
123 }
124
125 pub fn pseudonym(&self) -> &HoprPseudonym {
126 &self.pseudonym
127 }
128
129 pub fn as_str(&self) -> &str {
130 &self.cached
131 }
132}
133
134impl FromStr for SessionId {
135 type Err = TransportSessionError;
136
137 fn from_str(s: &str) -> Result<Self, Self::Err> {
138 s.split_once(Self::DELIMITER)
139 .ok_or(TransportSessionError::InvalidSessionId)
140 .and_then(
141 |(pseudonym, tag)| match (HoprPseudonym::from_hex(pseudonym), Tag::from_str(tag)) {
142 (Ok(p), Ok(t)) => Ok(Self::new(t, p)),
143 _ => Err(TransportSessionError::InvalidSessionId),
144 },
145 )
146 }
147}
148
149impl serde::Serialize for SessionId {
150 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
151 where
152 S: serde::Serializer,
153 {
154 use serde::ser::SerializeStruct;
155 let mut state = serializer.serialize_struct("SessionId", 2)?;
156 state.serialize_field("tag", &self.tag)?;
157 state.serialize_field("pseudonym", &self.pseudonym)?;
158 state.end()
159 }
160}
161
impl<'de> serde::Deserialize<'de> for SessionId {
    /// Deserializes the `{ tag, pseudonym }` struct written by the `Serialize` impl.
    ///
    /// Both the map form (named fields) and the sequence form (`[tag, pseudonym]`, in that order)
    /// are accepted. The cached string form is not part of the encoding; it is rebuilt by
    /// `SessionId::new`.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        use serde::de;

        // Field-name identifiers for the map form. There is no `#[serde(other)]` catch-all, so keys
        // other than `tag`/`pseudonym` presumably fail to deserialize — confirm if relied upon.
        #[derive(serde::Deserialize)]
        #[serde(field_identifier, rename_all = "lowercase")]
        enum Field {
            Tag,
            Pseudonym,
        }

        struct SessionIdVisitor;

        impl<'de> de::Visitor<'de> for SessionIdVisitor {
            type Value = SessionId;

            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                formatter.write_str("struct SessionId")
            }

            // Sequence form: positional elements, `tag` first, then `pseudonym`.
            fn visit_seq<A>(self, mut seq: A) -> Result<SessionId, A::Error>
            where
                A: de::SeqAccess<'de>,
            {
                Ok(SessionId::new(
                    seq.next_element::<Tag>()?
                        .ok_or_else(|| de::Error::invalid_length(0, &self))?,
                    seq.next_element()?.ok_or_else(|| de::Error::invalid_length(1, &self))?,
                ))
            }

            // Map form: each field must appear exactly once; duplicates and omissions are errors.
            fn visit_map<V>(self, mut map: V) -> Result<SessionId, V::Error>
            where
                V: de::MapAccess<'de>,
            {
                let mut tag: Option<Tag> = None;
                let mut pseudonym: Option<HoprPseudonym> = None;
                while let Some(key) = map.next_key()? {
                    match key {
                        Field::Tag => {
                            if tag.is_some() {
                                return Err(de::Error::duplicate_field("tag"));
                            }
                            tag = Some(map.next_value()?);
                        }
                        Field::Pseudonym => {
                            if pseudonym.is_some() {
                                return Err(de::Error::duplicate_field("pseudonym"));
                            }
                            pseudonym = Some(map.next_value()?);
                        }
                    }
                }

                Ok(SessionId::new(
                    tag.ok_or_else(|| de::Error::missing_field("tag"))?,
                    pseudonym.ok_or_else(|| de::Error::missing_field("pseudonym"))?,
                ))
            }
        }

        const FIELDS: &[&str] = &["tag", "pseudonym"];
        deserializer.deserialize_struct("SessionId", FIELDS, SessionIdVisitor)
    }
}
230
231impl Display for SessionId {
232 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
233 write!(f, "{}", self.as_str())
234 }
235}
236
237impl Debug for SessionId {
238 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
239 write!(f, "{}", self.as_str())
240 }
241}
242
243impl PartialEq for SessionId {
244 fn eq(&self, other: &Self) -> bool {
245 self.tag == other.tag && self.pseudonym == other.pseudonym
246 }
247}
248
249impl Eq for SessionId {}
250
251impl Hash for SessionId {
252 fn hash<H: Hasher>(&self, state: &mut H) {
253 self.tag.hash(state);
254 self.pseudonym.hash(state);
255 }
256}
257
258fn caps_to_ack_mode(caps: Capabilities) -> AcknowledgementMode {
259 if caps.contains(Capability::RetransmissionAck | Capability::RetransmissionNack) {
260 AcknowledgementMode::Both
261 } else if caps.contains(Capability::RetransmissionAck) {
262 AcknowledgementMode::Full
263 } else {
264 AcknowledgementMode::Partial
265 }
266}
267
/// Why a session's `on_close` notifier was invoked.
#[derive(Clone, Copy, Debug, PartialEq, Eq, strum::Display)]
pub enum ClosureReason {
    /// The write half was closed successfully (`poll_close` / `poll_shutdown`).
    WriteClosed,
    /// A read returned 0 bytes, i.e. the read half reached end-of-stream.
    EmptyRead,
    /// Not produced in this file; presumably raised by a session manager evicting a session
    /// (e.g. on idle timeout) — TODO confirm at the call site.
    Eviction,
}
278
279trait AsyncReadWrite: futures::AsyncWrite + futures::AsyncRead + Send + Unpin {}
281impl<T: futures::AsyncWrite + futures::AsyncRead + Send + Unpin> AsyncReadWrite for T {}
282
/// Numeric identifier of a service on an exit node (see [`SessionTarget::ExitNode`]).
pub type ServiceId = u32;

/// Where the exit side of a session should forward its traffic.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum SessionTarget {
    /// A UDP endpoint; the host is kept as a [`SealedHost`] (presumably possibly encrypted — confirm).
    UdpStream(SealedHost),
    /// A TCP endpoint; the host is kept as a [`SealedHost`].
    TcpStream(SealedHost),
    /// A service identified by [`ServiceId`] on the exit node itself.
    ExitNode(ServiceId),
}

/// A [`Session`] paired with the [`SessionTarget`] it should be connected to.
///
/// Naming suggests this is the form in which sessions opened by a remote peer are handed out —
/// confirm at the producer.
#[derive(Debug)]
pub struct IncomingSession {
    /// The established session.
    pub session: Session,
    /// The target requested for this session.
    pub target: SessionTarget,
}
311
/// A bidirectional byte stream (`AsyncRead + AsyncWrite`) carried over HOPR application packets.
///
/// `inner` is either the raw packet transport or a segmenting (optionally retransmitting)
/// socket built on top of it, as selected from `capabilities` in `Session::new`.
#[pin_project::pin_project]
pub struct Session {
    id: SessionId,
    #[pin]
    inner: Box<dyn AsyncReadWrite>,
    routing: DestinationRouting,
    capabilities: Capabilities,
    // Taken (hence invoked at most once) on the first end-of-stream read or successful close.
    on_close: Option<Box<dyn FnOnce(SessionId, ClosureReason) + Send + Sync>>,
}
325
326impl Session {
327 #[tracing::instrument(skip(hopr, on_close), fields(session_id = %id))]
333 pub fn new<Tx, Rx, C>(
334 id: SessionId,
335 routing: DestinationRouting,
336 capabilities: C,
337 hopr: (Tx, Rx),
338 on_close: Option<Box<dyn FnOnce(SessionId, ClosureReason) + Send + Sync>>,
339 ) -> Result<Self, TransportSessionError>
340 where
341 Tx: futures::Sink<(DestinationRouting, ApplicationData)> + Send + Sync + Unpin + 'static,
342 Rx: futures::Stream<Item = Box<[u8]>> + Send + Sync + Unpin + 'static,
343 C: Into<Capabilities> + std::fmt::Debug,
344 Tx::Error: std::error::Error + Send + Sync,
345 {
346 let capabilities = capabilities.into();
347 let routing_clone = routing.clone();
348 let transport = DuplexIO(
349 AsyncWriteSink::<{ ApplicationData::PAYLOAD_SIZE }, _>(hopr.0.sink_map_err(std::io::Error::other).with(
350 move |buf| {
351 futures::future::ok::<_, std::io::Error>((
352 routing_clone.clone(),
353 ApplicationData::new_from_owned(id.tag(), buf),
354 ))
355 },
356 )),
357 hopr.1.map(Ok::<_, std::io::Error>).into_async_read(),
358 );
359
360 let inner: Box<dyn AsyncReadWrite> = if capabilities.contains(Capability::Segmentation) {
362 let socket_cfg = SessionSocketConfig {
364 frame_size: 1500,
365 frame_timeout: Duration::from_millis(800),
366 capacity: 16384,
367 flush_immediately: capabilities.contains(Capability::NoDelay),
368 ..Default::default()
369 };
370
371 if capabilities.contains(Capability::RetransmissionAck | Capability::RetransmissionNack) {
372 let ack_cfg = AcknowledgementStateConfig {
374 expected_packet_latency: Duration::from_millis(200),
379 mode: caps_to_ack_mode(capabilities),
380 backoff_base: 0.2,
381 max_incoming_frame_retries: 1,
382 max_outgoing_frame_retries: 2,
383 ..Default::default()
384 };
385
386 debug!(?socket_cfg, ?ack_cfg, "opening new stateful session socket");
387
388 Box::new(ReliableSocket::new(
389 transport,
390 AcknowledgementState::<{ ApplicationData::PAYLOAD_SIZE }>::new(id, ack_cfg),
391 socket_cfg,
392 )?)
393 } else {
394 debug!(?socket_cfg, "opening new stateless session socket");
395
396 Box::new(UnreliableSocket::<{ ApplicationData::PAYLOAD_SIZE }>::new_stateless(
397 id, transport, socket_cfg,
398 )?)
399 }
400 } else {
401 debug!("opening raw session socket");
402 Box::new(transport)
403 };
404
405 Ok(Self {
406 id,
407 inner,
408 routing,
409 capabilities,
410 on_close,
411 })
412 }
413
414 pub fn id(&self) -> &SessionId {
416 &self.id
417 }
418
419 pub fn routing(&self) -> &DestinationRouting {
421 &self.routing
422 }
423
424 pub fn capabilities(&self) -> &Capabilities {
426 &self.capabilities
427 }
428}
429
430impl std::fmt::Debug for Session {
431 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
432 f.debug_struct("Session")
433 .field("id", &self.id)
434 .field("routing", &self.routing)
435 .finish_non_exhaustive()
436 }
437}
438
439impl futures::AsyncRead for Session {
440 #[instrument(name = "Session::poll_read", level = "trace", skip(self, cx, buf), fields(session_id = %self.id), ret)]
441 fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<std::io::Result<usize>> {
442 let this = self.project();
443 let read = futures::ready!(this.inner.poll_read(cx, buf))?;
444 if read == 0 {
445 tracing::trace!("hopr session empty read");
446 if let Some(notifier) = this.on_close.take() {
448 tracing::trace!("notifying read half closure of session");
449 notifier(*this.id, ClosureReason::EmptyRead);
450 }
451 }
452 Poll::Ready(Ok(read))
453 }
454}
455
456impl futures::AsyncWrite for Session {
457 #[instrument(name = "Session::poll_write", level = "trace", skip(self, cx, buf), fields(session_id = %self.id), ret)]
458 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
459 self.project().inner.poll_write(cx, buf)
460 }
461
462 #[instrument(name = "Session::poll_flush", level = "trace", skip(self, cx), fields(session_id = %self.id), ret)]
463 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
464 self.project().inner.poll_flush(cx)
465 }
466
467 #[instrument(name = "Session::poll_close", level = "trace", skip(self, cx), fields(session_id = %self.id), ret)]
468 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
469 let this = self.project();
470 futures::ready!(this.inner.poll_close(cx))?;
471 tracing::trace!("hopr session closed");
472
473 if let Some(notifier) = this.on_close.take() {
474 tracing::trace!("notifying write half closure of session");
475 notifier(*this.id, ClosureReason::WriteClosed);
476 }
477
478 Poll::Ready(Ok(()))
479 }
480}
481
#[cfg(feature = "runtime-tokio")]
impl tokio::io::AsyncRead for Session {
    /// Adapts the `futures::AsyncRead` implementation to tokio's `ReadBuf` interface.
    ///
    /// Reads into the (zero-initialized) unfilled part of `buf` and advances it by the number of
    /// bytes read.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        let unfilled = buf.initialize_unfilled();
        match futures::AsyncRead::poll_read(self.as_mut(), cx, unfilled) {
            Poll::Ready(Ok(count)) => {
                buf.advance(count);
                Poll::Ready(Ok(()))
            }
            Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
            Poll::Pending => Poll::Pending,
        }
    }
}
495
#[cfg(feature = "runtime-tokio")]
impl tokio::io::AsyncWrite for Session {
    // Every method delegates to the `futures::AsyncWrite` implementation, so both runtimes share
    // the same behaviour (including the close notification on shutdown).

    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, std::io::Error>> {
        <Self as futures::AsyncWrite>::poll_write(self, cx, buf)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
        <Self as futures::AsyncWrite>::poll_flush(self, cx)
    }

    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
        <Self as futures::AsyncWrite>::poll_close(self, cx)
    }
}
510
#[cfg(test)]
mod tests {
    use anyhow::Context;
    use futures::{AsyncReadExt, AsyncWriteExt};
    use hopr_crypto_random::Randomizable;
    use hopr_crypto_types::prelude::*;
    use hopr_network_types::prelude::*;
    use hopr_primitive_types::prelude::*;

    use super::*;

    /// Awaits `fut` for at most one second and labels both failure modes with `what`.
    ///
    /// `tokio::time::timeout` yields `Result<io::Result<T>, Elapsed>`: the *outer* error is the
    /// timeout and the *inner* one the I/O failure, so the context labels must follow that order.
    async fn within_a_second<T>(
        what: &str,
        fut: impl std::future::Future<Output = std::io::Result<T>>,
    ) -> anyhow::Result<T> {
        tokio::time::timeout(Duration::from_secs(1), fut)
            .await
            .with_context(|| format!("{what} timed out"))?
            .with_context(|| format!("{what} failed"))
    }

    /// Sends random data in both directions between two connected sessions and checks that each
    /// side receives exactly what the other sent.
    async fn assert_bidirectional_exchange(mut alice: Session, mut bob: Session) -> anyhow::Result<()> {
        const DATA_LEN: usize = 5000;

        let alice_sent = hopr_crypto_random::random_bytes::<DATA_LEN>();
        let bob_sent = hopr_crypto_random::random_bytes::<DATA_LEN>();

        let mut bob_recv = [0u8; DATA_LEN];
        let mut alice_recv = [0u8; DATA_LEN];

        within_a_second("alice write", alice.write_all(&alice_sent)).await?;
        alice.flush().await?;

        within_a_second("bob write", bob.write_all(&bob_sent)).await?;
        bob.flush().await?;

        within_a_second("bob read", bob.read_exact(&mut bob_recv)).await?;
        within_a_second("alice read", alice.read_exact(&mut alice_recv)).await?;

        assert_eq!(alice_sent, bob_recv);
        assert_eq!(bob_sent, alice_recv);

        Ok(())
    }

    #[test]
    fn test_session_id_to_str_from_str() -> anyhow::Result<()> {
        let id = SessionId::new(1234_u64, HoprPseudonym::random());
        assert_eq!(id.as_str(), id.to_string());
        assert_eq!(id, SessionId::from_str(id.as_str())?);

        Ok(())
    }

    #[test]
    fn test_max_decimal_digits_for_n_bytes() {
        assert_eq!(3, max_decimal_digits_for_n_bytes(size_of::<u8>()));
        assert_eq!(5, max_decimal_digits_for_n_bytes(size_of::<u16>()));
        assert_eq!(10, max_decimal_digits_for_n_bytes(size_of::<u32>()));
        assert_eq!(20, max_decimal_digits_for_n_bytes(size_of::<u64>()));
    }

    #[test]
    fn standard_session_id_must_fit_within_limit() {
        let id = format!("{}:{}", SimplePseudonym::random(), Tag::Application(Tag::MAX));
        assert!(id.len() <= MAX_SESSION_ID_STR_LEN);
    }

    #[test]
    fn session_id_should_serialize_and_deserialize_correctly() -> anyhow::Result<()> {
        let pseudonym = HoprPseudonym::random();
        let tag: Tag = 1234u64.into();

        let session_id_1 = SessionId::new(tag, pseudonym);
        let data = serde_cbor_2::to_vec(&session_id_1)?;
        let session_id_2: SessionId = serde_cbor_2::from_slice(&data)?;

        assert_eq!(tag, session_id_2.tag());
        assert_eq!(pseudonym, *session_id_2.pseudonym());

        assert_eq!(session_id_1.as_str(), session_id_2.as_str());
        assert_eq!(session_id_1, session_id_2);

        Ok(())
    }

    #[test_log::test(tokio::test)]
    async fn test_session_bidirectional_flow_without_segmentation() -> anyhow::Result<()> {
        let dst: Address = (&ChainKeypair::random()).into();
        let id = SessionId::new(1234_u64, HoprPseudonym::random());

        let (alice_tx, bob_rx) = futures::channel::mpsc::unbounded::<(DestinationRouting, ApplicationData)>();
        let (bob_tx, alice_rx) = futures::channel::mpsc::unbounded::<(DestinationRouting, ApplicationData)>();

        let alice_session = Session::new(
            id,
            DestinationRouting::forward_only(dst, RoutingOptions::Hops(0.try_into()?)),
            None,
            (
                alice_tx,
                alice_rx
                    .map(|(_, data)| data.plain_text)
                    .inspect(|d| debug!("alice rcvd: {}", d.len())),
            ),
            None,
        )?;

        let bob_session = Session::new(
            id,
            DestinationRouting::Return(id.pseudonym().into()),
            None,
            (
                bob_tx,
                bob_rx
                    .map(|(_, data)| data.plain_text)
                    .inspect(|d| debug!("bob rcvd: {}", d.len())),
            ),
            None,
        )?;

        assert_bidirectional_exchange(alice_session, bob_session).await
    }

    #[test_log::test(tokio::test)]
    async fn test_session_bidirectional_flow_with_segmentation() -> anyhow::Result<()> {
        let dst: Address = (&ChainKeypair::random()).into();
        let id = SessionId::new(1234_u64, HoprPseudonym::random());

        let (alice_tx, bob_rx) = futures::channel::mpsc::unbounded::<(DestinationRouting, ApplicationData)>();
        let (bob_tx, alice_rx) = futures::channel::mpsc::unbounded::<(DestinationRouting, ApplicationData)>();

        let alice_session = Session::new(
            id,
            DestinationRouting::forward_only(dst, RoutingOptions::Hops(0.try_into()?)),
            Capability::Segmentation,
            (
                alice_tx,
                alice_rx
                    .map(|(_, data)| data.plain_text)
                    .inspect(|d| debug!("alice rcvd: {}", d.len())),
            ),
            None,
        )?;

        let bob_session = Session::new(
            id,
            DestinationRouting::Return(id.pseudonym().into()),
            Capability::Segmentation,
            (
                bob_tx,
                bob_rx
                    .map(|(_, data)| data.plain_text)
                    .inspect(|d| debug!("bob rcvd: {}", d.len())),
            ),
            None,
        )?;

        assert_bidirectional_exchange(alice_session, bob_session).await
    }
}