Skip to main content

hopr_utils_session/
lib.rs

1//! Session-related utilities for HOPR
2//!
3//! This module provides utility functions and structures for managing sessions,
4//! including session lifecycle management, session data handling, and common
5//! session operations.
6
7use std::{
8    collections::VecDeque, fmt::Formatter, future::Future, hash::Hash, net::SocketAddr, num::NonZeroUsize,
9    str::FromStr, sync::Arc,
10};
11
12use anyhow::anyhow;
13use base64::Engine;
14use bytesize::ByteSize;
15use dashmap::DashMap;
16use futures::{
17    FutureExt, StreamExt, TryStreamExt,
18    future::{AbortHandle, AbortRegistration},
19};
20use hopr_api::{
21    chain::HoprChainApi,
22    graph::{
23        NetworkGraphTraverse, NetworkGraphUpdate, NetworkGraphView, NetworkGraphWrite,
24        traits::{EdgeObservableRead, EdgeObservableWrite},
25    },
26    network::NetworkStreamControl,
27};
28#[cfg(not(feature = "explicit-path"))]
29use hopr_lib::HopRouting;
30#[cfg(feature = "explicit-path")]
31use hopr_lib::HoprSessionClientExplicitPathConfig;
32#[cfg(feature = "explicit-path")]
33use hopr_lib::api::types::internal::routing::RoutingOptions;
34use hopr_lib::{
35    Hopr, HoprSessionClientConfig,
36    api::{network::NetworkView, node::HoprSessionClientOperations, types::primitive::prelude::Address},
37    errors::HoprLibError,
38    exports::transport::{
39        HoprSession, HoprSessionConfigurator, OffchainPublicKey, SURB_SIZE, ServiceId, SessionId, SessionTarget,
40        transfer_session,
41    },
42};
43use hopr_utils::{
44    network_types::{
45        prelude::{ConnectedUdpStream, IpOrHost, IpProtocol, SealedHost, UdpStreamParallelism},
46        udp::ForeignDataMode,
47    },
48    runtime::Abortable,
49};
50use human_bandwidth::re::bandwidth::Bandwidth;
51use serde::{Deserialize, Serialize};
52use serde_with::serde_as;
53use tokio::net::TcpListener;
54use tracing::{debug, error, info};
55
/// Size of the buffer used to forward data between a TCP stream and a Session.
pub const HOPR_TCP_BUFFER_SIZE: usize = 4_096;

/// Size of the buffer used to forward data between a UDP stream and a Session.
pub const HOPR_UDP_BUFFER_SIZE: usize = 16_384;

/// Capacity of the back-pressure queue for data arriving from a UDP stream.
pub const HOPR_UDP_QUEUE_SIZE: usize = 8_192;
64
// Gauge of clients connected to Session listeners at this Entry node, labelled by `type`.
// In this file it is incremented/decremented with the label values `"tcp"` and `"udp"`.
// Only compiled with the `telemetry` feature and outside of tests; creation failure panics
// (via `unwrap`) on first access.
#[cfg(all(feature = "telemetry", not(test)))]
lazy_static::lazy_static! {
    static ref METRIC_ACTIVE_CLIENTS: hopr_types::telemetry::MultiGauge = hopr_types::telemetry::MultiGauge::new(
        "hopr_session_hoprd_clients",
        "Number of clients connected at this Entry node",
        &["type"]
    ).unwrap();
}
73
#[cfg(feature = "explicit-path")]
/// Temporary compatibility alias while stored listener metadata is shared between
/// hop-count and explicit-path session APIs.
pub type Routing = RoutingOptions;

#[cfg(not(feature = "explicit-path"))]
/// Routing description stored with listener metadata when the `explicit-path` feature is
/// disabled; an alias for [`HopRouting`].
pub type Routing = HopRouting;
81
#[serde_as]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
/// Session target specification.
///
/// Textual form (see the `Display`/`FromStr` implementations): plain targets are written
/// verbatim, sealed targets as `$$` followed by URL-safe base64, and services as `#` followed
/// by the service id.
pub enum SessionTargetSpec {
    /// Unsealed target, parsed with `IpOrHost::from_str` when converted to a `SessionTarget`.
    Plain(String),
    /// Sealed (opaque) target bytes; serialized with serde as base64.
    Sealed(#[serde_as(as = "serde_with::base64::Base64")] Vec<u8>),
    /// Service identifier, converted to `SessionTarget::ExitNode`.
    Service(ServiceId),
}
90
91impl std::fmt::Display for SessionTargetSpec {
92    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
93        match self {
94            SessionTargetSpec::Plain(t) => write!(f, "{t}"),
95            SessionTargetSpec::Sealed(t) => write!(f, "$${}", base64::prelude::BASE64_URL_SAFE.encode(t)),
96            SessionTargetSpec::Service(t) => write!(f, "#{t}"),
97        }
98    }
99}
100
101impl FromStr for SessionTargetSpec {
102    type Err = HoprLibError;
103
104    fn from_str(s: &str) -> Result<Self, Self::Err> {
105        Ok(if let Some(stripped) = s.strip_prefix("$$") {
106            Self::Sealed(
107                base64::prelude::BASE64_URL_SAFE
108                    .decode(stripped)
109                    .map_err(|e| HoprLibError::Other(e.into()))?,
110            )
111        } else if let Some(stripped) = s.strip_prefix("#") {
112            Self::Service(
113                stripped
114                    .parse()
115                    .map_err(|_| HoprLibError::GeneralError("cannot parse service id".into()))?,
116            )
117        } else {
118            Self::Plain(s.to_owned())
119        })
120    }
121}
122
123impl SessionTargetSpec {
124    pub fn into_target(self, protocol: IpProtocol) -> Result<SessionTarget, HoprLibError> {
125        Ok(match (protocol, self) {
126            (IpProtocol::TCP, SessionTargetSpec::Plain(plain)) => {
127                SessionTarget::TcpStream(IpOrHost::from_str(&plain).map(SealedHost::from)?)
128            }
129            (IpProtocol::UDP, SessionTargetSpec::Plain(plain)) => {
130                SessionTarget::UdpStream(IpOrHost::from_str(&plain).map(SealedHost::from)?)
131            }
132            (IpProtocol::TCP, SessionTargetSpec::Sealed(enc)) => {
133                SessionTarget::TcpStream(SealedHost::Sealed(enc.into_boxed_slice()))
134            }
135            (IpProtocol::UDP, SessionTargetSpec::Sealed(enc)) => {
136                SessionTarget::UdpStream(SealedHost::Sealed(enc.into_boxed_slice()))
137            }
138            (_, SessionTargetSpec::Service(id)) => SessionTarget::ExitNode(id),
139        })
140    }
141}
142
/// A single client connected to a session listener.
///
/// For TCP listeners `sock_addr` is the peer address of the accepted connection; for UDP
/// listeners it is currently the listener's own bound address.
#[derive(Debug)]
pub struct ClientEntry {
    /// The socket address of the connected client.
    pub sock_addr: SocketAddr,
    /// The abort handle for the client's session processing task.
    pub abort_handle: AbortHandle,
    /// The per-session configurator.
    pub configurator: HoprSessionConfigurator,
}
153
/// Entry stored in the session registry table ([`ListenerJoinHandles`]), one per listener.
#[derive(Debug)]
pub struct StoredSessionEntry {
    /// Destination address of the Session counterparty.
    pub destination: Address,
    /// Target of the Session.
    pub target: SessionTargetSpec,
    /// Forward routing options used for the Session.
    pub forward_path: Routing,
    /// Return routing options used for the Session.
    pub return_path: Routing,
    /// The maximum number of client sessions that the listener can spawn.
    pub max_client_sessions: usize,
    /// Upper bound on the upstream SURB bandwidth (a rate, not a packet count), as derived by
    /// the factory's `listener_limits` from the Session configuration; `None` if not set.
    pub max_surb_upstream: Option<human_bandwidth::re::bandwidth::Bandwidth>,
    /// The amount of response data the Session counterparty can deliver back to us, without us
    /// having to request it.
    pub response_buffer: Option<bytesize::ByteSize>,
    /// How many Sessions were requested to be pooled for clients (`None` for UDP listeners).
    pub session_pool: Option<usize>,
    /// The abort handle for the Session processing.
    pub abort_handle: AbortHandle,

    /// Clients of this listener, keyed by their Session ID.
    clients: Arc<DashMap<SessionId, ClientEntry>>,
}

impl StoredSessionEntry {
    /// Returns the shared map of this listener's clients, keyed by Session ID.
    pub fn get_clients(&self) -> &Arc<DashMap<SessionId, ClientEntry>> {
        &self.clients
    }
}
185
186/// This function first tries to parse `requested` as the `ip:port` host pair.
187/// If that does not work, it tries to parse `requested` as a single IP address
188/// and as a `:` prefixed port number. Whichever of those fails, is replaced by the corresponding
189/// part from the given `default`.
190pub fn build_binding_host(requested: Option<&str>, default: std::net::SocketAddr) -> std::net::SocketAddr {
191    match requested.map(|r| std::net::SocketAddr::from_str(r).map_err(|_| r)) {
192        Some(Err(requested)) => {
193            // If the requested host is not parseable as a whole as `SocketAddr`, try only its parts
194            debug!(requested, %default, "using partially default listen host");
195            std::net::SocketAddr::new(
196                requested.parse().unwrap_or(default.ip()),
197                requested
198                    .strip_prefix(":")
199                    .and_then(|p| u16::from_str(p).ok())
200                    .unwrap_or(default.port()),
201            )
202        }
203        Some(Ok(requested)) => {
204            debug!(%requested, "using requested listen host");
205            requested
206        }
207        None => {
208            debug!(%default, "using default listen host");
209            default
210        }
211    }
212}
213
214#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
215pub struct ListenerId(pub IpProtocol, pub std::net::SocketAddr);
216
217impl std::fmt::Display for ListenerId {
218    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
219        write!(f, "{}://{}:{}", self.0, self.1.ip(), self.1.port())
220    }
221}
222
223#[derive(Default)]
224pub struct ListenerJoinHandles(pub DashMap<ListenerId, StoredSessionEntry>);
225
226impl ListenerJoinHandles {
227    /// Finds the [`HoprSessionConfigurator`] for the given session ID across all listeners.
228    pub fn find_configurator(&self, session_id: &SessionId) -> Option<HoprSessionConfigurator> {
229        self.0.iter().find_map(|entry| {
230            entry
231                .value()
232                .get_clients()
233                .get(session_id)
234                .map(|client| client.value().configurator.clone())
235        })
236    }
237}
238
239impl Abortable for ListenerJoinHandles {
240    fn abort_task(&self) {
241        self.0.alter_all(|_, v| {
242            v.abort_handle.abort();
243            v
244        });
245    }
246
247    fn was_aborted(&self) -> bool {
248        self.0.iter().all(|v| v.abort_handle.is_aborted())
249    }
250}
251
252// ---------------------------------------------------------------------------
253// Generic SessionFactory adapters
254// ---------------------------------------------------------------------------
255
/// Abstraction over how client Sessions are opened, so that the TCP/UDP listener bindings and
/// [`SessionPool`] can work with different Session configurations.
#[async_trait::async_trait]
pub trait SessionFactory: Clone + Send + Sync + 'static {
    /// Configuration passed to [`SessionFactory::create_session`] for each new Session.
    type Cfg: Clone + Send + 'static;

    /// Creates a new Session with the given destination, target and configuration.
    async fn create_session(
        &self,
        dest: Address,
        target: SessionTarget,
        cfg: Self::Cfg,
    ) -> Result<(HoprSession, HoprSessionConfigurator), anyhow::Error>;

    /// Derives the forward and return routing options from the given configuration.
    ///
    /// The listener bindings record these in the stored listener entry.
    fn routing_from_cfg(&self, cfg: &Self::Cfg) -> Result<(Routing, Routing), anyhow::Error>;

    /// Derives the listener limits (max SURB upstream and response buffer) from the given configuration.
    fn listener_limits(&self, cfg: &Self::Cfg)
    -> (Option<human_bandwidth::re::bandwidth::Bandwidth>, Option<ByteSize>);

    /// Returns the idle timeout duration for sessions created by this factory, if any.
    ///
    /// [`SessionPool`] uses it to decide how often pooled Sessions are pinged.
    fn session_idle_timeout(&self) -> Option<std::time::Duration>;
}
278
279pub struct HopSessionFactory<Chain, Graph, Net, TMgr> {
280    hopr: Arc<Hopr<Chain, Graph, Net, TMgr>>,
281}
282
283impl<Chain, Graph, Net, TMgr> HopSessionFactory<Chain, Graph, Net, TMgr> {
284    pub fn new(hopr: Arc<Hopr<Chain, Graph, Net, TMgr>>) -> Self {
285        Self { hopr }
286    }
287}
288
289impl<Chain, Graph, Net, TMgr> Clone for HopSessionFactory<Chain, Graph, Net, TMgr> {
290    fn clone(&self) -> Self {
291        Self {
292            hopr: self.hopr.clone(),
293        }
294    }
295}
296
297#[async_trait::async_trait]
298impl<Chain, Graph, Net, TMgr> SessionFactory for HopSessionFactory<Chain, Graph, Net, TMgr>
299where
300    Chain: HoprChainApi + Clone + Send + Sync + 'static,
301    Graph: NetworkGraphView<NodeId = OffchainPublicKey>
302        + NetworkGraphUpdate
303        + NetworkGraphWrite<NodeId = OffchainPublicKey>
304        + NetworkGraphTraverse<NodeId = OffchainPublicKey>
305        + Clone
306        + Send
307        + Sync
308        + 'static,
309    <Graph as NetworkGraphTraverse>::Observed: EdgeObservableRead + Send + 'static,
310    <Graph as NetworkGraphWrite>::Observed: EdgeObservableWrite + Send,
311    Net: NetworkView + NetworkStreamControl + Send + Sync + Clone + 'static,
312    TMgr: Send + Sync + 'static,
313{
314    type Cfg = HoprSessionClientConfig;
315
316    async fn create_session(
317        &self,
318        dest: Address,
319        target: SessionTarget,
320        cfg: Self::Cfg,
321    ) -> Result<(HoprSession, HoprSessionConfigurator), anyhow::Error> {
322        Ok(HoprSessionClientOperations::connect_to(self.hopr.as_ref(), dest, target, cfg).await?)
323    }
324
325    fn routing_from_cfg(&self, cfg: &Self::Cfg) -> Result<(Routing, Routing), anyhow::Error> {
326        Ok((cfg.forward_path.into(), cfg.return_path.into()))
327    }
328
329    fn listener_limits(
330        &self,
331        cfg: &Self::Cfg,
332    ) -> (Option<human_bandwidth::re::bandwidth::Bandwidth>, Option<ByteSize>) {
333        (
334            cfg.surb_management
335                .map(|v| Bandwidth::from_bps(v.max_surbs_per_sec * SURB_SIZE as u64)),
336            cfg.surb_management
337                .map(|v| ByteSize::b(v.target_surb_buffer_size * SURB_SIZE as u64)),
338        )
339    }
340
341    fn session_idle_timeout(&self) -> Option<std::time::Duration> {
342        Some(self.hopr.config().protocol.session.idle_timeout)
343    }
344}
345
#[cfg(feature = "explicit-path")]
/// [`SessionFactory`] that opens Sessions along caller-specified explicit paths through
/// `Hopr::connect_to_using_explicit_path` on a shared [`Hopr`] node instance.
pub struct ExplicitPathSessionFactory<Chain, Graph, Net, TMgr> {
    hopr: Arc<Hopr<Chain, Graph, Net, TMgr>>,
}

#[cfg(feature = "explicit-path")]
impl<Chain, Graph, Net, TMgr> ExplicitPathSessionFactory<Chain, Graph, Net, TMgr> {
    /// Wraps a shared HOPR node handle.
    pub fn new(hopr: Arc<Hopr<Chain, Graph, Net, TMgr>>) -> Self {
        ExplicitPathSessionFactory { hopr }
    }
}

// Implemented by hand: `#[derive(Clone)]` would add `Clone` bounds on every type parameter,
// while cloning only needs to bump the `Arc` reference count.
#[cfg(feature = "explicit-path")]
impl<Chain, Graph, Net, TMgr> Clone for ExplicitPathSessionFactory<Chain, Graph, Net, TMgr> {
    fn clone(&self) -> Self {
        let hopr = Arc::clone(&self.hopr);
        ExplicitPathSessionFactory { hopr }
    }
}
366
#[cfg(feature = "explicit-path")]
#[async_trait::async_trait]
impl<Chain, Graph, Net, TMgr> SessionFactory for ExplicitPathSessionFactory<Chain, Graph, Net, TMgr>
where
    Chain: HoprChainApi + Clone + Send + Sync + 'static,
    Graph: NetworkGraphView<NodeId = OffchainPublicKey>
        + NetworkGraphUpdate
        + NetworkGraphWrite<NodeId = OffchainPublicKey>
        + NetworkGraphTraverse<NodeId = OffchainPublicKey>
        + Clone
        + Send
        + Sync
        + 'static,
    <Graph as NetworkGraphTraverse>::Observed: EdgeObservableRead + Send + 'static,
    <Graph as NetworkGraphWrite>::Observed: EdgeObservableWrite + Send,
    Net: NetworkView + NetworkStreamControl + Send + Sync + Clone + 'static,
    TMgr: Send + Sync + 'static,
{
    type Cfg = HoprSessionClientExplicitPathConfig;

    /// Opens a Session via `connect_to_using_explicit_path` on the wrapped node.
    async fn create_session(
        &self,
        dest: Address,
        target: SessionTarget,
        cfg: Self::Cfg,
    ) -> Result<(HoprSession, HoprSessionConfigurator), anyhow::Error> {
        Ok(self.hopr.connect_to_using_explicit_path(dest, target, cfg).await?)
    }

    /// Converts the explicit forward and return paths into
    /// [`RoutingOptions::IntermediatePath`] values.
    ///
    /// # Errors
    /// Fails if either path cannot be converted into an intermediate path.
    fn routing_from_cfg(&self, cfg: &Self::Cfg) -> Result<(Routing, Routing), anyhow::Error> {
        let forward_path = RoutingOptions::IntermediatePath(
            cfg.forward_path
                .clone()
                .try_into()
                .map_err(|e| anyhow!("invalid explicit forward path: {e}"))?,
        );
        let return_path = RoutingOptions::IntermediatePath(
            cfg.return_path
                .clone()
                .try_into()
                .map_err(|e| anyhow!("invalid explicit return path: {e}"))?,
        );
        Ok((forward_path, return_path))
    }

    /// Derives the listener limits from the SURB management settings, if present.
    ///
    /// The upstream limit is `max_surbs_per_sec * SURB_SIZE` (passed to `Bandwidth::from_bps`)
    /// and the response buffer is `target_surb_buffer_size * SURB_SIZE` bytes. Both products
    /// saturate at `u64::MAX`, so extreme configuration values cannot overflow (which would
    /// panic in debug builds and silently wrap in release builds).
    fn listener_limits(
        &self,
        cfg: &Self::Cfg,
    ) -> (Option<human_bandwidth::re::bandwidth::Bandwidth>, Option<ByteSize>) {
        let surb_size = SURB_SIZE as u64;
        (
            cfg.surb_management
                .map(|v| Bandwidth::from_bps(v.max_surbs_per_sec.saturating_mul(surb_size))),
            cfg.surb_management
                .map(|v| ByteSize::b(v.target_surb_buffer_size.saturating_mul(surb_size))),
        )
    }

    /// Returns the idle timeout from the node's Session protocol configuration.
    fn session_idle_timeout(&self) -> Option<std::time::Duration> {
        Some(self.hopr.config().protocol.session.idle_timeout)
    }
}
428
429type SessionPoolInner = Arc<parking_lot::Mutex<VecDeque<(HoprSession, HoprSessionConfigurator)>>>;
430
431pub struct SessionPool {
432    pool: Option<SessionPoolInner>,
433    ah: Option<AbortHandle>,
434}
435
436impl SessionPool {
437    pub const MAX_SESSION_POOL_SIZE: usize = 5;
438
439    pub async fn new<T: SessionFactory>(
440        size: usize,
441        dst: Address,
442        target: SessionTarget,
443        cfg: T::Cfg,
444        factory: T,
445    ) -> Result<Self, anyhow::Error> {
446        let pool = Arc::new(parking_lot::Mutex::new(VecDeque::with_capacity(size)));
447        let factory_clone = factory.clone();
448        let pool_clone = pool.clone();
449        futures::stream::iter(0..size.min(Self::MAX_SESSION_POOL_SIZE))
450            .map(Ok)
451            .try_for_each_concurrent(Self::MAX_SESSION_POOL_SIZE, move |i| {
452                let pool = pool_clone.clone();
453                let factory = factory_clone.clone();
454                let target = target.clone();
455                let cfg = cfg.clone();
456                async move {
457                    match factory.create_session(dst, target.clone(), cfg.clone()).await {
458                        Ok((session, configurator)) => {
459                            debug!(session_id = %session.id(), num_session = i, "created a new session in pool");
460                            pool.lock().push_back((session, configurator));
461                            Ok(())
462                        }
463                        Err(error) => {
464                            error!(%error, num_session = i, "failed to establish session for pool");
465                            Err(anyhow!("failed to establish session #{i} in pool to {dst}: {error}"))
466                        }
467                    }
468                }
469            })
470            .await?;
471
472        if let Some(timeout) = factory.session_idle_timeout().filter(|_| !pool.lock().is_empty()) {
473            let pool_clone_1 = pool.clone();
474            let pool_clone_2 = pool.clone();
475            Ok(Self {
476                pool: Some(pool),
477                ah: Some(hopr_utils::spawn_as_abortable!(
478                    futures_time::stream::interval(futures_time::time::Duration::from(
479                        std::time::Duration::from_secs(1).max(timeout / 2)
480                    ))
481                    .take_while(move |_| futures::future::ready(!pool_clone_1.lock().is_empty()))
482                    .for_each(move |_| {
483                        let pool = pool_clone_2.clone();
484                        async move {
485                            let configurators: Vec<_> = pool.lock().iter().map(|(_, cfg)| cfg.clone()).collect();
486                            let mut dead_ids = Vec::new();
487                            for configurator in &configurators {
488                                if let Err(error) = configurator.ping().await {
489                                    let id = *configurator.id();
490                                    error!(%error, session_id = %id, "session in pool is not alive, will remove");
491                                    dead_ids.push(id);
492                                }
493                            }
494                            if !dead_ids.is_empty() {
495                                pool.lock().retain(|(_, cfg)| !dead_ids.contains(cfg.id()));
496                            }
497                        }
498                    })
499                )),
500            })
501        } else {
502            Ok(Self { pool: None, ah: None })
503        }
504    }
505
506    pub fn pop(&mut self) -> Option<(HoprSession, HoprSessionConfigurator)> {
507        self.pool.as_ref().and_then(|pool| pool.lock().pop_front())
508    }
509}
510
511impl Drop for SessionPool {
512    fn drop(&mut self) {
513        if let Some(ah) = self.ah.take() {
514            ah.abort();
515        }
516    }
517}
518
/// Binds a TCP listener on `bind_host` (or on a port from `port_range`) and, for every accepted
/// TCP connection, opens a Session to `destination` / `target_spec` using `factory` and
/// `config`, forwarding data between the connection and the Session.
///
/// - `use_session_pool`: number of Sessions to pre-establish via [`SessionPool`] (`None` = 0).
/// - `max_client_sessions`: maximum number of concurrently served clients; defaults to 5, is at
///   least 1 and never lower than the requested pool size.
///
/// The listener is registered in `open_listeners` under a TCP [`ListenerId`]. Aborting the
/// stored `abort_handle` stops accepting connections, after which all client Sessions spawned
/// by this listener are aborted.
///
/// Returns the bound address, `None` (no Session is created up-front for TCP) and the effective
/// maximum number of clients.
///
/// # Errors
/// - [`BindError::ListenHostAlreadyUsed`] if binding fails with `AddrInUse`,
/// - [`BindError::UnknownFailure`] for any other bind failure, an invalid target or routing
///   configuration, or when the session pool cannot be established.
#[allow(clippy::too_many_arguments)]
pub async fn create_tcp_client_binding<T: SessionFactory>(
    bind_host: std::net::SocketAddr,
    port_range: Option<String>,
    factory: T,
    open_listeners: Arc<ListenerJoinHandles>,
    destination: Address,
    target_spec: SessionTargetSpec,
    config: T::Cfg,
    use_session_pool: Option<usize>,
    max_client_sessions: Option<usize>,
) -> Result<(std::net::SocketAddr, Option<SessionId>, usize), BindError> {
    // Bind the TCP socket first
    let (bound_host, tcp_listener) = tcp_listen_on(bind_host, port_range).await.map_err(|e| {
        if e.kind() == std::io::ErrorKind::AddrInUse {
            BindError::ListenHostAlreadyUsed
        } else {
            BindError::UnknownFailure(format!("failed to start TCP listener on {bind_host}: {e}"))
        }
    })?;
    info!(%bound_host, "TCP session listener bound");

    // For each new TCP connection coming to the listener,
    // open a Session with the same parameters
    let target = target_spec
        .clone()
        .into_target(IpProtocol::TCP)
        .map_err(|e| BindError::UnknownFailure(e.to_string()))?;
    let (forward_path, return_path) = factory
        .routing_from_cfg(&config)
        .map_err(|e| BindError::UnknownFailure(e.to_string()))?;
    let (max_surb_upstream, response_buffer) = factory.listener_limits(&config);

    // Create a session pool if requested
    // (`SessionPool::new` establishes at most `SessionPool::MAX_SESSION_POOL_SIZE` Sessions)
    let session_pool_size = use_session_pool.unwrap_or(0);
    let mut session_pool = SessionPool::new(
        session_pool_size,
        destination,
        target.clone(),
        config.clone(),
        factory.clone(),
    )
    .await
    .map_err(|e| BindError::UnknownFailure(e.to_string()))?;

    // Clients of this listener keyed by Session ID; shared with the stored listener entry.
    let active_sessions = Arc::new(DashMap::new());
    // Default to 5 clients, allow at least one, and never fewer than the requested pool size.
    let mut max_clients = max_client_sessions.unwrap_or(5).max(1);

    if max_clients < session_pool_size {
        max_clients = session_pool_size;
    }

    let config_clone = config.clone();
    // Create an abort handler for the listener
    let (abort_handle, abort_reg) = AbortHandle::new_pair();
    let active_sessions_clone = active_sessions.clone();
    hopr_utils::runtime::prelude::spawn(async move {
        let active_sessions_clone_2 = active_sessions_clone.clone();

        // Accepted connections are handled one after another: the next one is taken only
        // after the future for the previous one has completed.
        futures::stream::Abortable::new(tokio_stream::wrappers::TcpListenerStream::new(tcp_listener), abort_reg)
            .and_then(|sock| async { Ok((sock.peer_addr()?, sock)) })
            .for_each(move |accepted_client| {
                let data = config_clone.clone();
                let target = target.clone();
                let factory = factory.clone();
                let active_sessions = active_sessions_clone_2.clone();
                // The pool is owned by this closure, so a pooled Session is taken here, before
                // the per-connection future is created, and only if a client slot is free.
                let has_capacity = accepted_client.is_ok() && active_sessions.len() < max_clients;
                let maybe_pooled = has_capacity.then(|| session_pool.pop()).flatten();

                async move {
                    match accepted_client {
                        Ok((sock_addr, mut stream)) => {
                            debug!(?sock_addr, "incoming TCP connection");

                            // Check that we are still within the quota,
                            // otherwise shutdown the new client immediately
                            if active_sessions.len() >= max_clients {
                                error!(?bind_host, "no more client slots available at listener");
                                use tokio::io::AsyncWriteExt;
                                if let Err(error) = stream.shutdown().await {
                                    error!(%error, ?sock_addr, "failed to shutdown TCP connection");
                                }
                                return;
                            }

                            // See if we still have some session pooled
                            let (session, configurator) = match maybe_pooled {
                                Some((s, c)) => {
                                    debug!(session_id = %s.id(), "using pooled session");
                                    (s, c)
                                }
                                None => {
                                    debug!("no more active sessions in the pool, creating a new one");
                                    match factory.create_session(destination, target, data).await {
                                        Ok((s, c)) => (s, c),
                                        Err(error) => {
                                            error!(%error, "failed to establish session");
                                            return;
                                        }
                                    }
                                }
                            };

                            let session_id = *session.id();
                            debug!(?sock_addr, %session_id, "new session for incoming TCP connection");

                            // Register the client before spawning its task, so the removal
                            // in the task's completion handler always finds the entry.
                            let (abort_handle, abort_reg) = AbortHandle::new_pair();
                            active_sessions.insert(
                                session_id,
                                ClientEntry {
                                    sock_addr,
                                    abort_handle,
                                    configurator,
                                },
                            );

                            #[cfg(all(feature = "telemetry", not(test)))]
                            METRIC_ACTIVE_CLIENTS.increment(&["tcp"], 1.0);

                            hopr_utils::runtime::prelude::spawn(
                                // The stream either terminates naturally (by the client closing the TCP connection)
                                // or is terminated via the abort handle.
                                bind_session_to_stream(session, stream, HOPR_TCP_BUFFER_SIZE, Some(abort_reg)).then(
                                    move |_| async move {
                                        // Regardless how the session ended, remove the abort handle
                                        // from the map
                                        active_sessions.remove(&session_id);

                                        debug!(%session_id, "tcp session has ended");

                                        #[cfg(all(feature = "telemetry", not(test)))]
                                        METRIC_ACTIVE_CLIENTS.decrement(&["tcp"], 1.0);
                                    },
                                ),
                            );
                        }
                        Err(error) => error!(%error, "failed to accept connection"),
                    }
                }
            })
            .await;

        // Once the listener is done, abort all active sessions created by the listener
        active_sessions_clone.iter().for_each(|entry| {
            let client = entry.value();
            debug!(session_id = %entry.key(), sock_addr = ?client.sock_addr, "aborting opened TCP session after listener has been closed");
            client.abort_handle.abort()
        });
    });

    open_listeners.0.insert(
        ListenerId(hopr_utils::network_types::types::IpProtocol::TCP, bound_host),
        StoredSessionEntry {
            destination,
            target: target_spec,
            forward_path,
            return_path,
            clients: active_sessions,
            max_client_sessions: max_clients,
            max_surb_upstream,
            response_buffer,
            session_pool: Some(session_pool_size),
            abort_handle,
        },
    );
    Ok((bound_host, None, max_clients))
}
686
/// Errors returned when creating a TCP or UDP client binding.
#[derive(Debug, thiserror::Error)]
pub enum BindError {
    /// Binding the listen socket failed with `std::io::ErrorKind::AddrInUse`.
    #[error("conflict detected: listen host already in use")]
    ListenHostAlreadyUsed,

    /// Any other failure, carrying a human-readable description.
    #[error("unknown failure: {0}")]
    UnknownFailure(String),
}
695
696pub async fn create_udp_client_binding<T: SessionFactory>(
697    bind_host: std::net::SocketAddr,
698    port_range: Option<String>,
699    factory: T,
700    open_listeners: Arc<ListenerJoinHandles>,
701    destination: Address,
702    target_spec: SessionTargetSpec,
703    config: T::Cfg,
704) -> Result<(std::net::SocketAddr, Option<SessionId>, usize), BindError> {
705    // Bind the UDP socket first
706    let (bound_host, udp_socket) = udp_bind_to(bind_host, port_range).await.map_err(|e| {
707        if e.kind() == std::io::ErrorKind::AddrInUse {
708            BindError::ListenHostAlreadyUsed
709        } else {
710            BindError::UnknownFailure(format!("failed to start UDP listener on {bind_host}: {e}"))
711        }
712    })?;
713
714    info!(%bound_host, "UDP session listener bound");
715
716    let target = target_spec
717        .clone()
718        .into_target(IpProtocol::UDP)
719        .map_err(|e| BindError::UnknownFailure(e.to_string()))?;
720    let (forward_path, return_path) = factory
721        .routing_from_cfg(&config)
722        .map_err(|e| BindError::UnknownFailure(e.to_string()))?;
723    let (max_surb_upstream, response_buffer) = factory.listener_limits(&config);
724
725    // Create a single session for the UDP socket
726    let (session, configurator) = factory
727        .create_session(destination, target, config.clone())
728        .await
729        .map_err(|e| BindError::UnknownFailure(e.to_string()))?;
730
731    let open_listeners_clone = open_listeners.clone();
732    let listener_id = ListenerId(hopr_utils::network_types::types::IpProtocol::UDP, bound_host);
733
734    // Create an abort handle so that the Session can be terminated by aborting
735    // the UDP stream first. Because under the hood, the bind_session_to_stream uses
736    // `transfer_session` which in turn uses `copy_duplex_abortable`, aborting the
737    // `udp_socket` will:
738    //
739    // 1. Initiate graceful shutdown of `udp_socket`
740    // 2. Once done, initiate a graceful shutdown of `session`
741    // 3. Finally, return from the `bind_session_to_stream` which will terminate the spawned task
742    //
743    // This is needed because the `udp_socket` cannot terminate by itself.
744    let (abort_handle, abort_reg) = AbortHandle::new_pair();
745    let clients = Arc::new(DashMap::new());
746    let max_clients: usize = 1; // Maximum number of clients for this session. Currently always 1.
747
748    // TODO: add multiple client support to UDP sessions (#7370)
749    let session_id = *session.id();
750    clients.insert(
751        session_id,
752        ClientEntry {
753            sock_addr: bound_host,
754            abort_handle: abort_handle.clone(),
755            configurator,
756        },
757    );
758    hopr_utils::runtime::prelude::spawn(async move {
759        #[cfg(all(feature = "telemetry", not(test)))]
760        METRIC_ACTIVE_CLIENTS.increment(&["udp"], 1.0);
761
762        bind_session_to_stream(session, udp_socket, HOPR_UDP_BUFFER_SIZE, Some(abort_reg)).await;
763
764        #[cfg(all(feature = "telemetry", not(test)))]
765        METRIC_ACTIVE_CLIENTS.decrement(&["udp"], 1.0);
766
767        // Once the Session closes, remove it from the list
768        open_listeners_clone.0.remove(&listener_id);
769    });
770
771    open_listeners.0.insert(
772        listener_id,
773        StoredSessionEntry {
774            destination,
775            target: target_spec,
776            forward_path,
777            return_path,
778            max_client_sessions: max_clients,
779            max_surb_upstream,
780            response_buffer,
781            session_pool: None,
782            abort_handle,
783            clients,
784        },
785    );
786    Ok((bound_host, Some(session_id), max_clients))
787}
788
789async fn try_restricted_bind<F, S, Fut>(
790    addrs: Vec<std::net::SocketAddr>,
791    range_str: &str,
792    binder: F,
793) -> std::io::Result<S>
794where
795    F: Fn(Vec<std::net::SocketAddr>) -> Fut,
796    Fut: Future<Output = std::io::Result<S>>,
797{
798    if addrs.is_empty() {
799        return Err(std::io::Error::other("no valid socket addresses found"));
800    }
801
802    let range = range_str
803        .split_once(":")
804        .and_then(
805            |(a, b)| match u16::from_str(a).and_then(|a| Ok((a, u16::from_str(b)?))) {
806                Ok((a, b)) if a <= b => Some(a..=b),
807                _ => None,
808            },
809        )
810        .ok_or(std::io::Error::other(format!("invalid port range {range_str}")))?;
811
812    for port in range {
813        let addrs = addrs
814            .iter()
815            .map(|addr| std::net::SocketAddr::new(addr.ip(), port))
816            .collect::<Vec<_>>();
817        match binder(addrs).await {
818            Ok(listener) => return Ok(listener),
819            Err(error) => debug!(%error, "listen address not usable"),
820        }
821    }
822
823    Err(std::io::Error::new(
824        std::io::ErrorKind::AddrNotAvailable,
825        format!("no valid socket addresses found within range: {range_str}"),
826    ))
827}
828
829/// Listen on a specified address with a port from an optional port range for TCP connections.
830async fn tcp_listen_on<A: std::net::ToSocketAddrs>(
831    address: A,
832    port_range: Option<String>,
833) -> std::io::Result<(std::net::SocketAddr, TcpListener)> {
834    let addrs = address.to_socket_addrs()?.collect::<Vec<_>>();
835
836    // If automatic port allocation is requested and there's a restriction on the port range
837    // (via HOPRD_SESSION_PORT_RANGE), try to find an address within that range.
838    if addrs.iter().all(|a| a.port() == 0)
839        && let Some(range_str) = port_range
840    {
841        let tcp_listener = try_restricted_bind(
842            addrs,
843            &range_str,
844            |a| async move { TcpListener::bind(a.as_slice()).await },
845        )
846        .await?;
847        return Ok((tcp_listener.local_addr()?, tcp_listener));
848    }
849
850    let tcp_listener = TcpListener::bind(addrs.as_slice()).await?;
851    Ok((tcp_listener.local_addr()?, tcp_listener))
852}
853
854pub async fn udp_bind_to<A: std::net::ToSocketAddrs>(
855    address: A,
856    port_range: Option<String>,
857) -> std::io::Result<(std::net::SocketAddr, ConnectedUdpStream)> {
858    let addrs = address.to_socket_addrs()?.collect::<Vec<_>>();
859
860    let builder = ConnectedUdpStream::builder()
861        .with_buffer_size(HOPR_UDP_BUFFER_SIZE)
862        .with_foreign_data_mode(ForeignDataMode::Discard) // discard data from UDP clients other than the first one served
863        .with_queue_size(HOPR_UDP_QUEUE_SIZE)
864        .with_receiver_parallelism(
865            std::env::var("HOPRD_SESSION_ENTRY_UDP_RX_PARALLELISM")
866                .ok()
867                .and_then(|s| s.parse::<NonZeroUsize>().ok())
868                .map(UdpStreamParallelism::Specific)
869                .unwrap_or(UdpStreamParallelism::Auto),
870        );
871
872    // If automatic port allocation is requested and there's a restriction on the port range
873    // (via HOPRD_SESSION_PORT_RANGE), try to find an address within that range.
874    if addrs.iter().all(|a| a.port() == 0)
875        && let Some(range_str) = port_range
876    {
877        let udp_listener = try_restricted_bind(addrs, &range_str, |addrs| {
878            futures::future::ready(builder.clone().build(addrs.as_slice()))
879        })
880        .await?;
881
882        return Ok((*udp_listener.bound_address(), udp_listener));
883    }
884
885    let udp_socket = builder.build(address)?;
886    Ok((*udp_socket.bound_address(), udp_socket))
887}
888
889async fn bind_session_to_stream<T>(
890    mut session: HoprSession,
891    mut stream: T,
892    max_buf: usize,
893    abort_reg: Option<AbortRegistration>,
894) where
895    T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
896{
897    let session_id = *session.id();
898    match transfer_session(&mut session, &mut stream, max_buf, abort_reg).await {
899        Ok((session_to_stream_bytes, stream_to_session_bytes)) => info!(
900            session_id = ?session_id,
901            session_to_stream_bytes, stream_to_session_bytes, "client session ended",
902        ),
903        Err(error) => error!(
904            session_id = ?session_id,
905            %error,
906            "error during data transfer"
907        ),
908    }
909}
910
#[cfg(test)]
mod tests {
    use anyhow::Context;
    use futures::{
        FutureExt, StreamExt,
        channel::mpsc::{UnboundedReceiver, UnboundedSender},
    };
    use futures_time::future::FutureExt as TimeFutureExt;
    use hopr_api::types::crypto::crypto_traits::Randomizable;
    use hopr_lib::{
        HopRouting,
        api::types::{
            internal::{
                prelude::HoprPseudonym,
                routing::{DestinationRouting, RoutingOptions},
            },
            primitive::prelude::Address,
        },
        exports::transport::{ApplicationData, ApplicationDataIn, ApplicationDataOut, HoprSession, SessionId},
    };
    use hopr_transport::session::HoprSessionConfig;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    use super::*;

    /// Builds an in-memory transport that echoes every outgoing packet back as an incoming one.
    ///
    /// The routing part of each outgoing item is dropped and its payload is re-wrapped as
    /// `ApplicationDataIn` with default packet info, so a session wired to this pair receives
    /// what it sends.
    fn loopback_transport() -> (
        UnboundedSender<(DestinationRouting, ApplicationDataOut)>,
        UnboundedReceiver<ApplicationDataIn>,
    ) {
        let (input_tx, input_rx) = futures::channel::mpsc::unbounded::<(DestinationRouting, ApplicationDataOut)>();
        let (output_tx, output_rx) = futures::channel::mpsc::unbounded::<ApplicationDataIn>();
        tokio::task::spawn(
            input_rx
                .map(|(_, data)| {
                    Ok(ApplicationDataIn {
                        data: data.data,
                        packet_info: Default::default(),
                    })
                })
                .forward(output_tx)
                .map(|e| tracing::debug!(?e, "loopback transport completed")),
        );

        (input_tx, output_rx)
    }

    #[tokio::test]
    async fn hoprd_session_connection_should_create_a_working_tcp_socket_through_which_data_can_be_sent_and_received()
    -> anyhow::Result<()> {
        let session_id = SessionId::new(4567u64, HoprPseudonym::random());
        let peer: Address = "0x5112D584a1C72Fc250176B57aEba5fFbbB287D8F".parse()?;
        let cfg = HoprSessionConfig::default();
        let session = HoprSession::new(
            session_id,
            DestinationRouting::forward_only(peer, RoutingOptions::IntermediatePath(Default::default())),
            cfg,
            loopback_transport(),
            None,
        )?;

        let (bound_addr, tcp_listener) = tcp_listen_on(("127.0.0.1", 0), None)
            .await
            .context("listen_on failed")?;

        tokio::task::spawn(async move {
            match tcp_listener.accept().await {
                Ok((stream, _)) => bind_session_to_stream(session, stream, HOPR_TCP_BUFFER_SIZE, None).await,
                Err(e) => error!("failed to accept connection: {e}"),
            }
        });

        let mut tcp_stream = tokio::net::TcpStream::connect(bound_addr)
            .await
            .context("connect failed")?;

        let data = vec![b"hello", b"world", b"this ", b"is   ", b"    a", b" test"];

        for d in data.clone().into_iter() {
            tcp_stream.write_all(d).await.context("write failed")?;
        }

        for d in data.iter() {
            let mut buf = vec![0; d.len()];
            tcp_stream.read_exact(&mut buf).await.context("read failed")?;
            // The loopback transport echoes everything back, so the bytes must round-trip unchanged.
            assert_eq!(buf.as_slice(), d.as_slice(), "echoed data must match what was sent");
        }

        Ok(())
    }

    #[test_log::test(tokio::test)]
    async fn hoprd_session_connection_should_create_a_working_udp_socket_through_which_data_can_be_sent_and_received()
    -> anyhow::Result<()> {
        let session_id = SessionId::new(4567u64, HoprPseudonym::random());
        let peer: Address = "0x5112D584a1C72Fc250176B57aEba5fFbbB287D8F".parse()?;
        let cfg = HoprSessionConfig::default();
        let session = HoprSession::new(
            session_id,
            DestinationRouting::forward_only(peer, RoutingOptions::IntermediatePath(Default::default())),
            cfg,
            loopback_transport(),
            None,
        )?;

        let (listen_addr, udp_listener) = udp_bind_to(("127.0.0.1", 0), None)
            .await
            .context("udp_bind_to failed")?;

        let (abort_handle, abort_registration) = AbortHandle::new_pair();
        let jh = tokio::task::spawn(bind_session_to_stream(
            session,
            udp_listener,
            ApplicationData::PAYLOAD_SIZE,
            Some(abort_registration),
        ));

        let mut udp_stream = ConnectedUdpStream::builder()
            .with_buffer_size(ApplicationData::PAYLOAD_SIZE)
            .with_queue_size(HOPR_UDP_QUEUE_SIZE)
            .with_counterparty(listen_addr)
            .build(("127.0.0.1", 0))
            .context("bind failed")?;

        let data = vec![b"hello", b"world", b"this ", b"is   ", b"    a", b" test"];

        for d in data.clone().into_iter() {
            udp_stream.write_all(d).await.context("write failed")?;
            // ConnectedUdpStream performs flush with each write
        }

        for d in data.iter() {
            let mut buf = vec![0; d.len()];
            udp_stream.read_exact(&mut buf).await.context("read failed")?;
            // The loopback transport echoes everything back, so the bytes must round-trip unchanged.
            assert_eq!(buf.as_slice(), d.as_slice(), "echoed data must match what was sent");
        }

        // Once aborted, the bind_session_to_stream task must terminate too
        abort_handle.abort();
        jh.timeout(futures_time::time::Duration::from_millis(200)).await??;

        Ok(())
    }

    /// Creates a `StoredSessionEntry` with placeholder values, no session pool and no clients.
    ///
    /// The abort registration paired with the handle is dropped immediately.
    fn stub_stored_entry() -> StoredSessionEntry {
        let (abort_handle, _) = AbortHandle::new_pair();
        StoredSessionEntry {
            destination: Address::default(),
            target: SessionTargetSpec::Plain("localhost:8080".into()),
            forward_path: Default::default(),
            return_path: Default::default(),
            max_client_sessions: 5,
            max_surb_upstream: None,
            response_buffer: None,
            session_pool: None,
            abort_handle,
            clients: Arc::new(DashMap::new()),
        }
    }

    #[test]
    fn find_configurator_should_return_none_when_no_listeners() {
        let handles = ListenerJoinHandles::default();
        let session_id = SessionId::new(1234u64, HoprPseudonym::random());
        assert!(handles.find_configurator(&session_id).is_none());
    }

    #[test]
    fn find_configurator_should_return_none_when_session_not_tracked() {
        let handles = ListenerJoinHandles::default();
        let listener_id = ListenerId(IpProtocol::TCP, "127.0.0.1:9091".parse().unwrap());
        handles.0.insert(listener_id, stub_stored_entry());

        let session_id = SessionId::new(5678u64, HoprPseudonym::random());
        assert!(handles.find_configurator(&session_id).is_none());
    }

    #[test]
    fn stored_session_entry_clients_should_start_empty() {
        let entry = stub_stored_entry();
        assert!(entry.get_clients().is_empty());
        assert_eq!(entry.max_client_sessions, 5);
    }

    #[test]
    fn session_target_spec_plain_roundtrip() {
        let spec = SessionTargetSpec::Plain("localhost:8080".into());
        let s = spec.to_string();
        assert_eq!(s, "localhost:8080");
        assert_eq!(
            SessionTargetSpec::from_str(&s).unwrap(),
            SessionTargetSpec::Plain("localhost:8080".into())
        );
    }

    #[test]
    fn session_target_spec_sealed_roundtrip() {
        let data = vec![0xde, 0xad, 0xbe, 0xef];
        let spec = SessionTargetSpec::Sealed(data.clone());
        let s = spec.to_string();
        assert!(s.starts_with("$$"));
        assert_eq!(
            SessionTargetSpec::from_str(&s).unwrap(),
            SessionTargetSpec::Sealed(data)
        );
    }

    #[test]
    fn session_target_spec_service_roundtrip() {
        let spec = SessionTargetSpec::Service(42);
        let s = spec.to_string();
        assert_eq!(s, "#42");
        assert_eq!(SessionTargetSpec::from_str(&s).unwrap(), SessionTargetSpec::Service(42));
    }

    #[test]
    fn build_binding_address() {
        let default = "10.0.0.1:10000".parse().unwrap();

        let result = build_binding_host(Some("127.0.0.1:10000"), default);
        assert_eq!(result, "127.0.0.1:10000".parse::<std::net::SocketAddr>().unwrap());

        let result = build_binding_host(None, default);
        assert_eq!(result, "10.0.0.1:10000".parse::<std::net::SocketAddr>().unwrap());

        let result = build_binding_host(Some("127.0.0.1"), default);
        assert_eq!(result, "127.0.0.1:10000".parse::<std::net::SocketAddr>().unwrap());

        let result = build_binding_host(Some(":1234"), default);
        assert_eq!(result, "10.0.0.1:1234".parse::<std::net::SocketAddr>().unwrap());

        let result = build_binding_host(Some(":"), default);
        assert_eq!(result, "10.0.0.1:10000".parse::<std::net::SocketAddr>().unwrap());

        let result = build_binding_host(Some(""), default);
        assert_eq!(result, "10.0.0.1:10000".parse::<std::net::SocketAddr>().unwrap());
    }
}