Skip to main content

hopr_utils_session/
lib.rs

1//! Session-related utilities for HOPR
2//!
3//! This module provides utility functions and structures for managing sessions,
4//! including session lifecycle management, session data handling, and common
5//! session operations.
6
7use std::{
8    collections::VecDeque, fmt::Formatter, future::Future, hash::Hash, net::SocketAddr, num::NonZeroUsize,
9    str::FromStr, sync::Arc,
10};
11
12use anyhow::anyhow;
13use base64::Engine;
14use bytesize::ByteSize;
15use dashmap::DashMap;
16use futures::{
17    FutureExt, StreamExt, TryStreamExt,
18    future::{AbortHandle, AbortRegistration},
19};
20use hopr_api::{
21    chain::HoprChainApi,
22    graph::{
23        NetworkGraphTraverse, NetworkGraphUpdate, NetworkGraphView, NetworkGraphWrite,
24        traits::{EdgeObservableRead, EdgeObservableWrite},
25    },
26    network::NetworkStreamControl,
27};
28use hopr_async_runtime::Abortable;
29use hopr_lib::{
30    Address, Hopr, HoprSession, HoprSessionClientConfig, NetworkView, OffchainPublicKey, RoutingOptions, SURB_SIZE,
31    ServiceId, SessionId, SessionTarget, errors::HoprLibError, transfer_session,
32};
33use hopr_network_types::{
34    prelude::{ConnectedUdpStream, IpOrHost, IpProtocol, SealedHost, UdpStreamParallelism},
35    udp::ForeignDataMode,
36};
37use human_bandwidth::re::bandwidth::Bandwidth;
38use serde::{Deserialize, Serialize};
39use serde_with::serde_as;
40use tokio::net::TcpListener;
41use tracing::{debug, error, info};
42
/// Size (in bytes) of the buffer for forwarding data to/from a TCP stream.
pub const HOPR_TCP_BUFFER_SIZE: usize = 4096;

/// Size (in bytes) of the buffer for forwarding data to/from a UDP stream.
pub const HOPR_UDP_BUFFER_SIZE: usize = 16384;

/// Size (in elements) of the queue (back-pressure) for data incoming from a UDP stream.
pub const HOPR_UDP_QUEUE_SIZE: usize = 8192;
51
// Gauge counting the Session clients currently connected at this Entry node,
// labeled by transport type ("tcp"/"udp"). Only compiled in when the
// `telemetry` feature is enabled and not building tests.
#[cfg(all(feature = "telemetry", not(test)))]
lazy_static::lazy_static! {
    static ref METRIC_ACTIVE_CLIENTS: hopr_metrics::MultiGauge = hopr_metrics::MultiGauge::new(
        "hopr_session_hoprd_clients",
        "Number of clients connected at this Entry node",
        &["type"]
    ).unwrap();
}
60
#[serde_as]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
/// Session target specification.
///
/// The string form (see `Display`/`FromStr`) uses a `$$` prefix for sealed
/// targets and a `#` prefix for service targets; anything else is plain.
pub enum SessionTargetSpec {
    /// Plain target string, later parsed as an `IpOrHost` for TCP/UDP targets.
    Plain(String),
    /// Opaque sealed target bytes, rendered as URL-safe Base64 in string form.
    Sealed(#[serde_as(as = "serde_with::base64::Base64")] Vec<u8>),
    /// Identifier of a service on the Session counterparty (exit node).
    Service(ServiceId),
}
69
70impl std::fmt::Display for SessionTargetSpec {
71    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
72        match self {
73            SessionTargetSpec::Plain(t) => write!(f, "{t}"),
74            SessionTargetSpec::Sealed(t) => write!(f, "$${}", base64::prelude::BASE64_URL_SAFE.encode(t)),
75            SessionTargetSpec::Service(t) => write!(f, "#{t}"),
76        }
77    }
78}
79
80impl FromStr for SessionTargetSpec {
81    type Err = HoprLibError;
82
83    fn from_str(s: &str) -> Result<Self, Self::Err> {
84        Ok(if let Some(stripped) = s.strip_prefix("$$") {
85            Self::Sealed(
86                base64::prelude::BASE64_URL_SAFE
87                    .decode(stripped)
88                    .map_err(|e| HoprLibError::Other(e.into()))?,
89            )
90        } else if let Some(stripped) = s.strip_prefix("#") {
91            Self::Service(
92                stripped
93                    .parse()
94                    .map_err(|_| HoprLibError::GeneralError("cannot parse service id".into()))?,
95            )
96        } else {
97            Self::Plain(s.to_owned())
98        })
99    }
100}
101
102impl SessionTargetSpec {
103    pub fn into_target(self, protocol: IpProtocol) -> Result<SessionTarget, HoprLibError> {
104        Ok(match (protocol, self) {
105            (IpProtocol::TCP, SessionTargetSpec::Plain(plain)) => {
106                SessionTarget::TcpStream(IpOrHost::from_str(&plain).map(SealedHost::from)?)
107            }
108            (IpProtocol::UDP, SessionTargetSpec::Plain(plain)) => {
109                SessionTarget::UdpStream(IpOrHost::from_str(&plain).map(SealedHost::from)?)
110            }
111            (IpProtocol::TCP, SessionTargetSpec::Sealed(enc)) => {
112                SessionTarget::TcpStream(SealedHost::Sealed(enc.into_boxed_slice()))
113            }
114            (IpProtocol::UDP, SessionTargetSpec::Sealed(enc)) => {
115                SessionTarget::UdpStream(SealedHost::Sealed(enc.into_boxed_slice()))
116            }
117            (_, SessionTargetSpec::Service(id)) => SessionTarget::ExitNode(id),
118        })
119    }
120}
121
/// Entry stored in the session registry table.
#[derive(Debug)]
pub struct StoredSessionEntry {
    /// Destination address of the Session counterparty.
    pub destination: Address,
    /// Target of the Session.
    pub target: SessionTargetSpec,
    /// Forward path used for the Session.
    pub forward_path: RoutingOptions,
    /// Return path used for the Session.
    pub return_path: RoutingOptions,
    /// The maximum number of client sessions that the listener can spawn.
    pub max_client_sessions: usize,
    /// The maximum number of SURB packets that can be sent upstream.
    pub max_surb_upstream: Option<human_bandwidth::re::bandwidth::Bandwidth>,
    /// The amount of response data the Session counterparty can deliver back to us, without us
    /// having to request it.
    pub response_buffer: Option<bytesize::ByteSize>,
    /// How many Sessions to pool for clients.
    pub session_pool: Option<usize>,
    /// The abort handle for the Session processing.
    pub abort_handle: AbortHandle,

    /// Active client sessions spawned by this listener, keyed by Session ID;
    /// each value holds the client's socket address and the abort handle of
    /// its data-transfer task.
    clients: Arc<DashMap<SessionId, (SocketAddr, AbortHandle)>>,
}
147
impl StoredSessionEntry {
    /// Returns the map of currently active client sessions
    /// (Session ID → client socket address and abort handle).
    pub fn get_clients(&self) -> &Arc<DashMap<SessionId, (SocketAddr, AbortHandle)>> {
        &self.clients
    }
}
153
154/// This function first tries to parse `requested` as the `ip:port` host pair.
155/// If that does not work, it tries to parse `requested` as a single IP address
156/// and as a `:` prefixed port number. Whichever of those fails, is replaced by the corresponding
157/// part from the given `default`.
158pub fn build_binding_host(requested: Option<&str>, default: std::net::SocketAddr) -> std::net::SocketAddr {
159    match requested.map(|r| std::net::SocketAddr::from_str(r).map_err(|_| r)) {
160        Some(Err(requested)) => {
161            // If the requested host is not parseable as a whole as `SocketAddr`, try only its parts
162            debug!(requested, %default, "using partially default listen host");
163            std::net::SocketAddr::new(
164                requested.parse().unwrap_or(default.ip()),
165                requested
166                    .strip_prefix(":")
167                    .and_then(|p| u16::from_str(p).ok())
168                    .unwrap_or(default.port()),
169            )
170        }
171        Some(Ok(requested)) => {
172            debug!(%requested, "using requested listen host");
173            requested
174        }
175        None => {
176            debug!(%default, "using default listen host");
177            default
178        }
179    }
180}
181
/// Unique identifier of a client listener: transport protocol plus the bound socket address.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub struct ListenerId(pub IpProtocol, pub std::net::SocketAddr);
184
185impl std::fmt::Display for ListenerId {
186    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
187        write!(f, "{}://{}:{}", self.0, self.1.ip(), self.1.port())
188    }
189}
190
/// Registry of all open client listeners, keyed by [`ListenerId`].
#[derive(Default)]
pub struct ListenerJoinHandles(pub DashMap<ListenerId, StoredSessionEntry>);
193
194impl Abortable for ListenerJoinHandles {
195    fn abort_task(&self) {
196        self.0.alter_all(|_, v| {
197            v.abort_handle.abort();
198            v
199        });
200    }
201
202    fn was_aborted(&self) -> bool {
203        self.0.iter().all(|v| v.abort_handle.is_aborted())
204    }
205}
206
/// A small pool of pre-established Sessions to a fixed destination/target,
/// kept alive by a background keep-alive task (see [`SessionPool::new`]).
pub struct SessionPool {
    /// Pooled sessions; `None` when the pool was created empty.
    pool: Option<Arc<parking_lot::Mutex<VecDeque<HoprSession>>>>,
    /// Abort handle of the keep-alive task; aborted on drop.
    ah: Option<AbortHandle>,
}
211
impl SessionPool {
    /// Hard upper bound on the pool size; larger requested sizes are clamped.
    pub const MAX_SESSION_POOL_SIZE: usize = 5;

    /// Creates a pool of up to `size` (clamped to [`Self::MAX_SESSION_POOL_SIZE`])
    /// Sessions to `dst` with the given `target` and client config, opening them
    /// concurrently. Fails if any single session cannot be established.
    ///
    /// If at least one session was created, a background task is spawned that
    /// periodically sends keep-alive messages to each pooled session and evicts
    /// sessions whose keep-alive fails; the task ends once the pool is empty.
    pub async fn new<Chain, Graph, Net>(
        size: usize,
        dst: Address,
        target: SessionTarget,
        cfg: HoprSessionClientConfig,
        hopr: Arc<Hopr<Chain, Graph, Net>>,
    ) -> Result<Self, anyhow::Error>
    where
        Chain: HoprChainApi + Clone + Send + Sync + 'static,
        Graph: NetworkGraphView<NodeId = OffchainPublicKey>
            + NetworkGraphUpdate
            + NetworkGraphWrite<NodeId = OffchainPublicKey>
            + NetworkGraphTraverse<NodeId = OffchainPublicKey>
            + Clone
            + Send
            + Sync
            + 'static,
        <Graph as NetworkGraphTraverse>::Observed: EdgeObservableRead + Send + 'static,
        <Graph as NetworkGraphWrite>::Observed: EdgeObservableWrite + Send,
        Net: NetworkView + NetworkStreamControl + Send + Sync + Clone + 'static,
    {
        let pool = Arc::new(parking_lot::Mutex::new(VecDeque::with_capacity(size)));
        let hopr_clone = hopr.clone();
        let pool_clone = pool.clone();
        // Open all sessions concurrently; a single failure aborts the whole pool creation.
        futures::stream::iter(0..size.min(Self::MAX_SESSION_POOL_SIZE))
            .map(Ok)
            .try_for_each_concurrent(Self::MAX_SESSION_POOL_SIZE, move |i| {
                let pool = pool_clone.clone();
                let hopr = hopr_clone.clone();
                let target = target.clone();
                let cfg = cfg.clone();
                async move {
                    match hopr.connect_to(dst, target.clone(), cfg.clone()).await {
                        Ok(s) => {
                            debug!(session_id = %s.id(), num_session = i, "created a new session in pool");
                            pool.lock().push_back(s);
                            Ok(())
                        }
                        Err(error) => {
                            error!(%error, num_session = i, "failed to establish session for pool");
                            Err(anyhow!("failed to establish session #{i} in pool to {dst}: {error}"))
                        }
                    }
                }
            })
            .await?;

        // Spawn a task that periodically sends keep alive messages to the Session in the pool.
        if !pool.lock().is_empty() {
            let pool_clone_1 = pool.clone();
            let pool_clone_2 = pool.clone();
            let pool_clone_3 = pool.clone();
            Ok(Self {
                pool: Some(pool),
                ah: Some(hopr_async_runtime::spawn_as_abortable!(
                    // Tick at half the idle timeout (but at least once per second)
                    // so pooled sessions never reach the idle cutoff.
                    futures_time::stream::interval(futures_time::time::Duration::from(
                        std::time::Duration::from_secs(1).max(hopr.config().protocol.session.idle_timeout / 2)
                    ))
                    .take_while(move |_| {
                        // Continue the infinite interval stream until there are sessions in the pool
                        futures::future::ready(!pool_clone_1.lock().is_empty())
                    })
                    .flat_map(move |_| {
                        // Get all SessionIds of the remaining Sessions in the pool
                        let ids = pool_clone_2.lock().iter().map(|s| *s.id()).collect::<Vec<_>>();
                        futures::stream::iter(ids)
                    })
                    .for_each(move |id| {
                        let hopr = hopr.clone();
                        let pool = pool_clone_3.clone();
                        async move {
                            // Make sure the Session is still alive, otherwise remove it from the pool
                            if let Err(error) = hopr.keep_alive_session(&id).await {
                                error!(%error, %dst, session_id = %id, "session in pool is not alive, removing from pool");
                                pool.lock().retain(|s| *s.id() != id);
                            }
                        }
                    })
                ))
            })
        } else {
            Ok(Self { pool: None, ah: None })
        }
    }

    /// Takes the oldest session out of the pool, or `None` if the pool is
    /// empty (or was never populated).
    pub fn pop(&mut self) -> Option<HoprSession> {
        self.pool.as_ref().and_then(|pool| pool.lock().pop_front())
    }
}
304
305impl Drop for SessionPool {
306    fn drop(&mut self) {
307        if let Some(ah) = self.ah.take() {
308            ah.abort();
309        }
310    }
311}
312
/// Binds a TCP listener on `bind_host` (optionally restricted to `port_range`) and,
/// for every accepted TCP connection, opens a Session to `destination` (taking one
/// from the optional session pool first) and bridges the two until either side closes.
///
/// Returns the actually bound address, `None` (a TCP listener has no single
/// associated Session ID), and the effective maximum number of client sessions.
/// The listener is registered in `open_listeners` together with an abort handle
/// that terminates it and all of its client sessions.
#[allow(clippy::too_many_arguments)]
pub async fn create_tcp_client_binding<Chain, Graph, Net>(
    bind_host: std::net::SocketAddr,
    port_range: Option<String>,
    hopr: Arc<Hopr<Chain, Graph, Net>>,
    open_listeners: Arc<ListenerJoinHandles>,
    destination: Address,
    target_spec: SessionTargetSpec,
    config: HoprSessionClientConfig,
    use_session_pool: Option<usize>,
    max_client_sessions: Option<usize>,
) -> Result<(std::net::SocketAddr, Option<SessionId>, usize), BindError>
where
    Chain: HoprChainApi + Clone + Send + Sync + 'static,
    Graph: NetworkGraphView<NodeId = OffchainPublicKey>
        + NetworkGraphUpdate
        + NetworkGraphWrite<NodeId = OffchainPublicKey>
        + NetworkGraphTraverse<NodeId = OffchainPublicKey>
        + Clone
        + Send
        + Sync
        + 'static,
    <Graph as NetworkGraphTraverse>::Observed: EdgeObservableRead + Send + 'static,
    <Graph as NetworkGraphWrite>::Observed: EdgeObservableWrite + Send,
    Net: NetworkView + NetworkStreamControl + Send + Sync + Clone + 'static,
{
    // Bind the TCP socket first
    let (bound_host, tcp_listener) = tcp_listen_on(bind_host, port_range).await.map_err(|e| {
        if e.kind() == std::io::ErrorKind::AddrInUse {
            BindError::ListenHostAlreadyUsed
        } else {
            BindError::UnknownFailure(format!("failed to start TCP listener on {bind_host}: {e}"))
        }
    })?;
    info!(%bound_host, "TCP session listener bound");

    // For each new TCP connection coming to the listener,
    // open a Session with the same parameters
    let target = target_spec
        .clone()
        .into_target(IpProtocol::TCP)
        .map_err(|e| BindError::UnknownFailure(e.to_string()))?;

    // Create a session pool if requested
    let session_pool_size = use_session_pool.unwrap_or(0);
    let mut session_pool = SessionPool::new(
        session_pool_size,
        destination,
        target.clone(),
        config.clone(),
        hopr.clone(),
    )
    .await
    .map_err(|e| BindError::UnknownFailure(e.to_string()))?;

    let active_sessions = Arc::new(DashMap::new());
    // At least 1 client is always allowed, and the quota can never be smaller
    // than the number of sessions already pooled.
    let mut max_clients = max_client_sessions.unwrap_or(5).max(1);

    if max_clients < session_pool_size {
        max_clients = session_pool_size;
    }

    let config_clone = config.clone();
    // Create an abort handler for the listener
    let (abort_handle, abort_reg) = AbortHandle::new_pair();
    let active_sessions_clone = active_sessions.clone();
    hopr_async_runtime::prelude::spawn(async move {
        let active_sessions_clone_2 = active_sessions_clone.clone();

        // Accept-loop: runs until the abort handle fires or the listener fails.
        futures::stream::Abortable::new(tokio_stream::wrappers::TcpListenerStream::new(tcp_listener), abort_reg)
            .and_then(|sock| async { Ok((sock.peer_addr()?, sock)) })
            .for_each(move |accepted_client| {
                let data = config_clone.clone();
                let target = target.clone();
                let hopr = hopr.clone();
                let active_sessions = active_sessions_clone_2.clone();

                // Try to pop from the pool only if a client was accepted
                let maybe_pooled_session = accepted_client.is_ok().then(|| session_pool.pop()).flatten();
                async move {
                    match accepted_client {
                        Ok((sock_addr, mut stream)) => {
                            debug!(?sock_addr, "incoming TCP connection");

                            // Check that we are still within the quota,
                            // otherwise shutdown the new client immediately
                            if active_sessions.len() >= max_clients {
                                error!(?bind_host, "no more client slots available at listener");
                                use tokio::io::AsyncWriteExt;
                                if let Err(error) = stream.shutdown().await {
                                    error!(%error, ?sock_addr, "failed to shutdown TCP connection");
                                }
                                return;
                            }

                            // See if we still have some session pooled
                            let session = match maybe_pooled_session {
                                Some(s) => {
                                    debug!(session_id = %s.id(), "using pooled session");
                                    s
                                }
                                None => {
                                    debug!("no more active sessions in the pool, creating a new one");
                                    match hopr.connect_to(destination, target, data).await {
                                        Ok(s) => s,
                                        Err(error) => {
                                            // Session could not be opened: drop the client connection.
                                            error!(%error, "failed to establish session");
                                            return;
                                        }
                                    }
                                }
                            };

                            let session_id = *session.id();
                            debug!(?sock_addr, %session_id, "new session for incoming TCP connection");

                            let (abort_handle, abort_reg) = AbortHandle::new_pair();
                            active_sessions.insert(session_id, (sock_addr, abort_handle));

                            #[cfg(all(feature = "telemetry", not(test)))]
                            METRIC_ACTIVE_CLIENTS.increment(&["tcp"], 1.0);

                            hopr_async_runtime::prelude::spawn(
                                // The stream either terminates naturally (by the client closing the TCP connection)
                                // or is terminated via the abort handle.
                                bind_session_to_stream(session, stream, HOPR_TCP_BUFFER_SIZE, Some(abort_reg)).then(
                                    move |_| async move {
                                        // Regardless how the session ended, remove the abort handle
                                        // from the map
                                        active_sessions.remove(&session_id);

                                        debug!(%session_id, "tcp session has ended");

                                        #[cfg(all(feature = "telemetry", not(test)))]
                                        METRIC_ACTIVE_CLIENTS.decrement(&["tcp"], 1.0);
                                    },
                                ),
                            );
                        }
                        Err(error) => error!(%error, "failed to accept connection"),
                    }
                }
            })
            .await;

        // Once the listener is done, abort all active sessions created by the listener
        active_sessions_clone.iter().for_each(|entry| {
            let (sock_addr, handle) = entry.value();
            debug!(session_id = %entry.key(), ?sock_addr, "aborting opened TCP session after listener has been closed");
            handle.abort()
        });
    });

    // Register the listener so it can be looked up and aborted later.
    open_listeners.0.insert(
        ListenerId(hopr_network_types::types::IpProtocol::TCP, bound_host),
        StoredSessionEntry {
            destination,
            target: target_spec,
            forward_path: config.forward_path.into(),
            return_path: config.return_path.into(),
            clients: active_sessions,
            max_client_sessions: max_clients,
            max_surb_upstream: config
                .surb_management
                .map(|v| Bandwidth::from_bps(v.max_surbs_per_sec * SURB_SIZE as u64)),
            response_buffer: config
                .surb_management
                .map(|v| ByteSize::b(v.target_surb_buffer_size * SURB_SIZE as u64)),
            session_pool: Some(session_pool_size),
            abort_handle,
        },
    );
    Ok((bound_host, None, max_clients))
}
487
/// Errors that can occur while creating a client listener binding.
#[derive(Debug, thiserror::Error)]
pub enum BindError {
    /// The requested listen host could not be bound because it is already in use.
    #[error("conflict detected: listen host already in use")]
    ListenHostAlreadyUsed,

    /// Any other failure during binding or Session establishment.
    #[error("unknown failure: {0}")]
    UnknownFailure(String),
}
496
/// Binds a UDP socket on `bind_host` (optionally restricted to `port_range`)
/// and bridges it to a single, newly created Session towards `destination`.
///
/// Returns the actually bound address, the ID of the created Session, and the
/// maximum number of clients for this binding (currently always 1). The binding
/// is registered in `open_listeners` and removes itself once the Session ends.
pub async fn create_udp_client_binding<Chain, Graph, Net>(
    bind_host: std::net::SocketAddr,
    port_range: Option<String>,
    hopr: Arc<Hopr<Chain, Graph, Net>>,
    open_listeners: Arc<ListenerJoinHandles>,
    destination: Address,
    target_spec: SessionTargetSpec,
    config: HoprSessionClientConfig,
) -> Result<(std::net::SocketAddr, Option<SessionId>, usize), BindError>
where
    Chain: HoprChainApi + Clone + Send + Sync + 'static,
    Graph: NetworkGraphView<NodeId = OffchainPublicKey>
        + NetworkGraphUpdate
        + NetworkGraphWrite<NodeId = OffchainPublicKey>
        + NetworkGraphTraverse<NodeId = OffchainPublicKey>
        + Clone
        + Send
        + Sync
        + 'static,
    <Graph as NetworkGraphTraverse>::Observed: EdgeObservableRead + Send + 'static,
    <Graph as NetworkGraphWrite>::Observed: EdgeObservableWrite + Send,
    Net: NetworkView + NetworkStreamControl + Send + Sync + Clone + 'static,
{
    // Bind the UDP socket first
    let (bound_host, udp_socket) = udp_bind_to(bind_host, port_range).await.map_err(|e| {
        if e.kind() == std::io::ErrorKind::AddrInUse {
            BindError::ListenHostAlreadyUsed
        } else {
            BindError::UnknownFailure(format!("failed to start UDP listener on {bind_host}: {e}"))
        }
    })?;

    info!(%bound_host, "UDP session listener bound");

    let target = target_spec
        .clone()
        .into_target(IpProtocol::UDP)
        .map_err(|e| BindError::UnknownFailure(e.to_string()))?;

    // Create a single session for the UDP socket
    let session = hopr
        .connect_to(destination, target, config.clone())
        .await
        .map_err(|e| BindError::UnknownFailure(e.to_string()))?;

    let open_listeners_clone = open_listeners.clone();
    let listener_id = ListenerId(hopr_network_types::types::IpProtocol::UDP, bound_host);

    // Create an abort handle so that the Session can be terminated by aborting
    // the UDP stream first. Because under the hood, the bind_session_to_stream uses
    // `transfer_session` which in turn uses `copy_duplex_abortable`, aborting the
    // `udp_socket` will:
    //
    // 1. Initiate graceful shutdown of `udp_socket`
    // 2. Once done, initiate a graceful shutdown of `session`
    // 3. Finally, return from the `bind_session_to_stream` which will terminate the spawned task
    //
    // This is needed because the `udp_socket` cannot terminate by itself.
    let (abort_handle, abort_reg) = AbortHandle::new_pair();
    let clients = Arc::new(DashMap::new());
    let max_clients: usize = 1; // Maximum number of clients for this session. Currently always 1.

    // TODO: add multiple client support to UDP sessions (#7370)
    let session_id = *session.id();
    clients.insert(session_id, (bind_host, abort_handle.clone()));
    hopr_async_runtime::prelude::spawn(async move {
        #[cfg(all(feature = "telemetry", not(test)))]
        METRIC_ACTIVE_CLIENTS.increment(&["udp"], 1.0);

        // Bridges the UDP socket and the Session until one side terminates.
        bind_session_to_stream(session, udp_socket, HOPR_UDP_BUFFER_SIZE, Some(abort_reg)).await;

        #[cfg(all(feature = "telemetry", not(test)))]
        METRIC_ACTIVE_CLIENTS.decrement(&["udp"], 1.0);

        // Once the Session closes, remove it from the list
        open_listeners_clone.0.remove(&listener_id);
    });

    // Register the binding so it can be looked up and aborted later.
    open_listeners.0.insert(
        listener_id,
        StoredSessionEntry {
            destination,
            target: target_spec,
            forward_path: config.forward_path.into(),
            return_path: config.return_path.into(),
            max_client_sessions: max_clients,
            max_surb_upstream: config
                .surb_management
                .map(|v| Bandwidth::from_bps(v.max_surbs_per_sec * SURB_SIZE as u64)),
            response_buffer: config
                .surb_management
                .map(|v| ByteSize::b(v.target_surb_buffer_size * SURB_SIZE as u64)),
            session_pool: None,
            abort_handle,
            clients,
        },
    );
    Ok((bound_host, Some(session_id), max_clients))
}
596
597async fn try_restricted_bind<F, S, Fut>(
598    addrs: Vec<std::net::SocketAddr>,
599    range_str: &str,
600    binder: F,
601) -> std::io::Result<S>
602where
603    F: Fn(Vec<std::net::SocketAddr>) -> Fut,
604    Fut: Future<Output = std::io::Result<S>>,
605{
606    if addrs.is_empty() {
607        return Err(std::io::Error::other("no valid socket addresses found"));
608    }
609
610    let range = range_str
611        .split_once(":")
612        .and_then(
613            |(a, b)| match u16::from_str(a).and_then(|a| Ok((a, u16::from_str(b)?))) {
614                Ok((a, b)) if a <= b => Some(a..=b),
615                _ => None,
616            },
617        )
618        .ok_or(std::io::Error::other(format!("invalid port range {range_str}")))?;
619
620    for port in range {
621        let addrs = addrs
622            .iter()
623            .map(|addr| std::net::SocketAddr::new(addr.ip(), port))
624            .collect::<Vec<_>>();
625        match binder(addrs).await {
626            Ok(listener) => return Ok(listener),
627            Err(error) => debug!(%error, "listen address not usable"),
628        }
629    }
630
631    Err(std::io::Error::new(
632        std::io::ErrorKind::AddrNotAvailable,
633        format!("no valid socket addresses found within range: {range_str}"),
634    ))
635}
636
637/// Listen on a specified address with a port from an optional port range for TCP connections.
638async fn tcp_listen_on<A: std::net::ToSocketAddrs>(
639    address: A,
640    port_range: Option<String>,
641) -> std::io::Result<(std::net::SocketAddr, TcpListener)> {
642    let addrs = address.to_socket_addrs()?.collect::<Vec<_>>();
643
644    // If automatic port allocation is requested and there's a restriction on the port range
645    // (via HOPRD_SESSION_PORT_RANGE), try to find an address within that range.
646    if addrs.iter().all(|a| a.port() == 0)
647        && let Some(range_str) = port_range
648    {
649        let tcp_listener = try_restricted_bind(
650            addrs,
651            &range_str,
652            |a| async move { TcpListener::bind(a.as_slice()).await },
653        )
654        .await?;
655        return Ok((tcp_listener.local_addr()?, tcp_listener));
656    }
657
658    let tcp_listener = TcpListener::bind(addrs.as_slice()).await?;
659    Ok((tcp_listener.local_addr()?, tcp_listener))
660}
661
/// Binds a [`ConnectedUdpStream`] on a specified address with a port from an
/// optional port range (the UDP counterpart of `tcp_listen_on`).
///
/// Returns the actually bound address together with the connected stream.
pub async fn udp_bind_to<A: std::net::ToSocketAddrs>(
    address: A,
    port_range: Option<String>,
) -> std::io::Result<(std::net::SocketAddr, ConnectedUdpStream)> {
    let addrs = address.to_socket_addrs()?.collect::<Vec<_>>();

    // Receiver parallelism can be overridden via env var; defaults to Auto.
    let builder = ConnectedUdpStream::builder()
        .with_buffer_size(HOPR_UDP_BUFFER_SIZE)
        .with_foreign_data_mode(ForeignDataMode::Discard) // discard data from UDP clients other than the first one served
        .with_queue_size(HOPR_UDP_QUEUE_SIZE)
        .with_receiver_parallelism(
            std::env::var("HOPRD_SESSION_ENTRY_UDP_RX_PARALLELISM")
                .ok()
                .and_then(|s| s.parse::<NonZeroUsize>().ok())
                .map(UdpStreamParallelism::Specific)
                .unwrap_or(UdpStreamParallelism::Auto),
        );

    // If automatic port allocation is requested and there's a restriction on the port range
    // (via HOPRD_SESSION_PORT_RANGE), try to find an address within that range.
    if addrs.iter().all(|a| a.port() == 0)
        && let Some(range_str) = port_range
    {
        let udp_listener = try_restricted_bind(addrs, &range_str, |addrs| {
            // `build` is synchronous, so wrap its result in a ready future.
            futures::future::ready(builder.clone().build(addrs.as_slice()))
        })
        .await?;

        return Ok((*udp_listener.bound_address(), udp_listener));
    }

    let udp_socket = builder.build(address)?;
    Ok((*udp_socket.bound_address(), udp_socket))
}
696
697async fn bind_session_to_stream<T>(
698    mut session: HoprSession,
699    mut stream: T,
700    max_buf: usize,
701    abort_reg: Option<AbortRegistration>,
702) where
703    T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
704{
705    let session_id = *session.id();
706    match transfer_session(&mut session, &mut stream, max_buf, abort_reg).await {
707        Ok((session_to_stream_bytes, stream_to_session_bytes)) => info!(
708            session_id = ?session_id,
709            session_to_stream_bytes, stream_to_session_bytes, "client session ended",
710        ),
711        Err(error) => error!(
712            session_id = ?session_id,
713            %error,
714            "error during data transfer"
715        ),
716    }
717}
718
719#[cfg(test)]
720mod tests {
721    use anyhow::Context;
722    use futures::{
723        FutureExt, StreamExt,
724        channel::mpsc::{UnboundedReceiver, UnboundedSender},
725    };
726    use futures_time::future::FutureExt as TimeFutureExt;
727    use hopr_api::types::crypto::crypto_traits::Randomizable;
728    use hopr_lib::{
729        Address, ApplicationData, ApplicationDataIn, ApplicationDataOut, HoprPseudonym, HoprSession, SessionId,
730        exports::types::internal::routing::{DestinationRouting, RoutingOptions},
731    };
732    use hopr_transport::session::HoprSessionConfig;
733    use tokio::io::{AsyncReadExt, AsyncWriteExt};
734
735    use super::*;
736
    /// Creates an in-process loopback "transport" for tests: everything pushed
    /// into the returned sender is echoed out of the returned receiver, with
    /// the routing information dropped and default packet info attached.
    fn loopback_transport() -> (
        UnboundedSender<(DestinationRouting, ApplicationDataOut)>,
        UnboundedReceiver<ApplicationDataIn>,
    ) {
        let (input_tx, input_rx) = futures::channel::mpsc::unbounded::<(DestinationRouting, ApplicationDataOut)>();
        let (output_tx, output_rx) = futures::channel::mpsc::unbounded::<ApplicationDataIn>();
        // Forwarding task ends when the sender side is dropped.
        tokio::task::spawn(
            input_rx
                .map(|(_, data)| {
                    Ok(ApplicationDataIn {
                        data: data.data,
                        packet_info: Default::default(),
                    })
                })
                .forward(output_tx)
                .map(|e| tracing::debug!(?e, "loopback transport completed")),
        );

        (input_tx, output_rx)
    }
757
758    #[tokio::test]
759    async fn hoprd_session_connection_should_create_a_working_tcp_socket_through_which_data_can_be_sent_and_received()
760    -> anyhow::Result<()> {
761        let session_id = SessionId::new(4567u64, HoprPseudonym::random());
762        let peer: Address = "0x5112D584a1C72Fc250176B57aEba5fFbbB287D8F".parse()?;
763        let cfg = HoprSessionConfig::default();
764        let session = HoprSession::new(
765            session_id,
766            DestinationRouting::forward_only(peer, RoutingOptions::IntermediatePath(Default::default())),
767            cfg,
768            loopback_transport(),
769            None,
770        )?;
771
772        let (bound_addr, tcp_listener) = tcp_listen_on(("127.0.0.1", 0), None)
773            .await
774            .context("listen_on failed")?;
775
776        tokio::task::spawn(async move {
777            match tcp_listener.accept().await {
778                Ok((stream, _)) => bind_session_to_stream(session, stream, HOPR_TCP_BUFFER_SIZE, None).await,
779                Err(e) => error!("failed to accept connection: {e}"),
780            }
781        });
782
783        let mut tcp_stream = tokio::net::TcpStream::connect(bound_addr)
784            .await
785            .context("connect failed")?;
786
787        let data = vec![b"hello", b"world", b"this ", b"is   ", b"    a", b" test"];
788
789        for d in data.clone().into_iter() {
790            tcp_stream.write_all(d).await.context("write failed")?;
791        }
792
793        for d in data.iter() {
794            let mut buf = vec![0; d.len()];
795            tcp_stream.read_exact(&mut buf).await.context("read failed")?;
796        }
797
798        Ok(())
799    }
800
801    #[test_log::test(tokio::test)]
802    async fn hoprd_session_connection_should_create_a_working_udp_socket_through_which_data_can_be_sent_and_received()
803    -> anyhow::Result<()> {
804        let session_id = SessionId::new(4567u64, HoprPseudonym::random());
805        let peer: Address = "0x5112D584a1C72Fc250176B57aEba5fFbbB287D8F".parse()?;
806        let cfg = HoprSessionConfig::default();
807        let session = HoprSession::new(
808            session_id,
809            DestinationRouting::forward_only(peer, RoutingOptions::IntermediatePath(Default::default())),
810            cfg,
811            loopback_transport(),
812            None,
813        )?;
814
815        let (listen_addr, udp_listener) = udp_bind_to(("127.0.0.1", 0), None)
816            .await
817            .context("udp_bind_to failed")?;
818
819        let (abort_handle, abort_registration) = AbortHandle::new_pair();
820        let jh = tokio::task::spawn(bind_session_to_stream(
821            session,
822            udp_listener,
823            ApplicationData::PAYLOAD_SIZE,
824            Some(abort_registration),
825        ));
826
827        let mut udp_stream = ConnectedUdpStream::builder()
828            .with_buffer_size(ApplicationData::PAYLOAD_SIZE)
829            .with_queue_size(HOPR_UDP_QUEUE_SIZE)
830            .with_counterparty(listen_addr)
831            .build(("127.0.0.1", 0))
832            .context("bind failed")?;
833
834        let data = vec![b"hello", b"world", b"this ", b"is   ", b"    a", b" test"];
835
836        for d in data.clone().into_iter() {
837            udp_stream.write_all(d).await.context("write failed")?;
838            // ConnectedUdpStream performs flush with each write
839        }
840
841        for d in data.iter() {
842            let mut buf = vec![0; d.len()];
843            udp_stream.read_exact(&mut buf).await.context("read failed")?;
844        }
845
846        // Once aborted, the bind_session_to_stream task must terminate too
847        abort_handle.abort();
848        jh.timeout(futures_time::time::Duration::from_millis(200)).await??;
849
850        Ok(())
851    }
852
853    #[test]
854    fn build_binding_address() {
855        let default = "10.0.0.1:10000".parse().unwrap();
856
857        let result = build_binding_host(Some("127.0.0.1:10000"), default);
858        assert_eq!(result, "127.0.0.1:10000".parse::<std::net::SocketAddr>().unwrap());
859
860        let result = build_binding_host(None, default);
861        assert_eq!(result, "10.0.0.1:10000".parse::<std::net::SocketAddr>().unwrap());
862
863        let result = build_binding_host(Some("127.0.0.1"), default);
864        assert_eq!(result, "127.0.0.1:10000".parse::<std::net::SocketAddr>().unwrap());
865
866        let result = build_binding_host(Some(":1234"), default);
867        assert_eq!(result, "10.0.0.1:1234".parse::<std::net::SocketAddr>().unwrap());
868
869        let result = build_binding_host(Some(":"), default);
870        assert_eq!(result, "10.0.0.1:10000".parse::<std::net::SocketAddr>().unwrap());
871
872        let result = build_binding_host(Some(""), default);
873        assert_eq!(result, "10.0.0.1:10000".parse::<std::net::SocketAddr>().unwrap());
874    }
875}