hopr_transport_session/
utils.rs

1use std::time::Duration;
2
3use futures::{FutureExt, SinkExt, StreamExt, TryStreamExt};
4use hopr_async_runtime::AbortHandle;
5use hopr_crypto_packet::prelude::HoprPacket;
6use hopr_network_types::prelude::DestinationRouting;
7use hopr_protocol_app::prelude::ApplicationData;
8use tracing::{debug, error};
9
10use crate::{
11    SessionId,
12    balancer::{RateController, RateLimitStreamExt, SurbControllerWithCorrection},
13    errors::TransportSessionError,
14    types::HoprStartProtocol,
15};
16
/// Convenience function to copy data in both directions between a [`Session`](crate::Session)
/// and an arbitrary async IO `stream`.
///
/// Only compiled with the `runtime-tokio` feature, since `stream` must implement Tokio's
/// `AsyncRead` + `AsyncWrite`.
///
/// Both directions use an equally sized buffer of `max_buffer` bytes.
///
/// When `abort_stream` is given, aborting it terminates the transfer from the `stream` side, i.e.:
/// 1. Initiates graceful shutdown of `stream`
/// 2. Once done, initiates a graceful shutdown of `session`
/// 3. The function terminates, returning the number of bytes transferred in both directions.
///
/// Aborting from the `session` side is not possible. Without `abort_stream`, the transfer runs
/// until the underlying duplex copy finishes on its own.
///
/// The returned byte counts are saturated at `usize::MAX` rather than truncated on targets where
/// `usize` is narrower than the counters reported by the copy routine.
///
/// # Errors
/// Propagates any I/O error returned by the underlying duplex copy.
#[cfg(feature = "runtime-tokio")]
pub async fn transfer_session<S>(
    session: &mut crate::Session,
    stream: &mut S,
    max_buffer: usize,
    abort_stream: Option<futures::future::AbortRegistration>,
) -> std::io::Result<(usize, usize)>
where
    S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
{
    /// Converts a byte counter to `usize`, saturating instead of silently truncating
    /// (an `as` cast would wrap on targets where `usize` is narrower than the counter type).
    fn to_usize<T: std::convert::TryInto<usize>>(n: T) -> usize {
        n.try_into().unwrap_or(usize::MAX)
    }

    // We can use equally sized buffer for both directions
    tracing::debug!(
        session_id = ?session.id(),
        egress_buffer = max_buffer,
        ingress_buffer = max_buffer,
        "session buffers"
    );

    if let Some(abort_stream) = abort_stream {
        // We only allow aborting from the "stream" side, not from the "session side"
        // This is useful for UDP-like streams on the "stream" side, which cannot be terminated
        // by a signal from outside (e.g.: for TCP sockets such signal is socket closure).
        // The paired `AbortHandle` is dropped immediately, so the session-side registration
        // can never be triggered.
        let (_, dummy) = futures::future::AbortHandle::new_pair();
        hopr_network_types::utils::copy_duplex_abortable(
            session,
            stream,
            (max_buffer, max_buffer),
            (dummy, abort_stream),
        )
        .await
        .map(|(a, b)| (to_usize(a), to_usize(b)))
    } else {
        hopr_network_types::utils::copy_duplex(session, stream, (max_buffer, max_buffer))
            .await
            .map(|(a, b)| (to_usize(a), to_usize(b)))
    }
}
62
63/// This function will use the given generator to generate an initial seeding key.
64/// It will check whether the given cache already contains a value for that key, and if not,
65/// calls the generator (with the previous value) to generate a new seeding key and retry.
66/// The function either finds a suitable free slot, inserting the `value` and returns the found key,
67/// or terminates with `None` when `gen` returns the initial seed again.
68pub(crate) async fn insert_into_next_slot<K, V, F>(
69    cache: &moka::future::Cache<K, V>,
70    generator: F,
71    value: V,
72) -> Option<K>
73where
74    K: Copy + std::hash::Hash + Eq + Send + Sync + 'static,
75    V: Clone + Send + Sync + 'static,
76    F: Fn(Option<K>) -> K,
77{
78    cache.run_pending_tasks().await;
79
80    let initial = generator(None);
81    let mut next = initial;
82    loop {
83        let insertion_result = cache
84            .entry(next)
85            .and_try_compute_with(|e| {
86                if e.is_none() {
87                    futures::future::ok::<_, ()>(moka::ops::compute::Op::Put(value.clone()))
88                } else {
89                    futures::future::ok::<_, ()>(moka::ops::compute::Op::Nop)
90                }
91            })
92            .await;
93
94        // If we inserted successfully, break the loop and return the insertion key
95        if let Ok(moka::ops::compute::CompResult::Inserted(_)) = insertion_result {
96            return Some(next);
97        }
98
99        // Otherwise, generate the next key
100        next = generator(Some(next));
101
102        // If generated keys made it to full loop, return failure
103        if next == initial {
104            return None;
105        }
106    }
107}
108
109pub(crate) fn spawn_keep_alive_stream<S>(
110    session_id: SessionId,
111    sender: S,
112    routing: DestinationRouting,
113) -> (SurbControllerWithCorrection, AbortHandle)
114where
115    S: futures::Sink<(DestinationRouting, ApplicationData)> + Clone + Send + Sync + Unpin + 'static,
116    S::Error: std::error::Error + Send + Sync + 'static,
117{
118    let elem = HoprStartProtocol::KeepAlive(session_id.into());
119
120    // The stream is suspended until the caller sets a rate via the Controller
121    let controller = RateController::new(0, Duration::from_secs(1));
122
123    let (ka_stream, abort_handle) =
124        futures::stream::abortable(futures::stream::repeat(elem).rate_limit_with_controller(&controller));
125
126    let sender_clone = sender.clone();
127    let fwd_routing_clone = routing.clone();
128
129    // This task will automatically terminate once the returned abort handle is used.
130    debug!(%session_id, "spawning keep-alive stream");
131    hopr_async_runtime::prelude::spawn(
132        ka_stream
133            .map(move |msg| ApplicationData::try_from(msg).map(|m| (fwd_routing_clone.clone(), m)))
134            .map_err(TransportSessionError::from)
135            .try_for_each_concurrent(None, move |msg| {
136                let mut sender_clone = sender_clone.clone();
137                async move {
138                    sender_clone
139                        .send(msg)
140                        .await
141                        .map_err(|e| TransportSessionError::PacketSendingError(e.to_string()))
142                }
143            })
144            .then(move |res| {
145                match res {
146                    Ok(_) => debug!(%session_id, "keep-alive stream done"),
147                    Err(error) => error!(%session_id, %error, "keep-alive stream failed"),
148                }
149                futures::future::ready(())
150            }),
151    );
152
153    // Currently, a keep-alive message can bear `HoprPacket::MAX_SURBS_IN_PACKET` SURBs,
154    // so the correction by this factor is applied.
155    (
156        SurbControllerWithCorrection(controller, HoprPacket::MAX_SURBS_IN_PACKET as u32),
157        abort_handle,
158    )
159}
160
#[cfg(test)]
mod tests {
    use anyhow::anyhow;

    use super::*;

    /// Key generator for the tests: starts at `0` and cycles through `0..5`.
    fn next_slot(prev: Option<i32>) -> i32 {
        match prev {
            Some(v) => (v + 1) % 5,
            None => 0,
        }
    }

    #[tokio::test]
    async fn test_insert_into_next_slot() -> anyhow::Result<()> {
        let cache = moka::future::Cache::new(10);

        // Every insertion must land on the next free key, in generator order.
        for expected in 0..5 {
            let key = insert_into_next_slot(&cache, next_slot, "foo".to_string())
                .await
                .ok_or(anyhow!("should insert"))?;
            assert_eq!(key, expected);
            assert_eq!(Some("foo".to_string()), cache.get(&expected).await);
        }

        // With all keys of the cycle occupied, the generator wraps around and no slot is found.
        let overflow = insert_into_next_slot(&cache, next_slot, "foo".to_string()).await;
        assert!(overflow.is_none(), "must not find slot when full");

        Ok(())
    }
}