Skip to main content

hopr_transport_session/
utils.rs

1use std::time::Duration;
2
3use futures::{FutureExt, SinkExt, StreamExt, TryStreamExt};
4use hopr_async_runtime::AbortHandle;
5use hopr_network_types::prelude::DestinationRouting;
6use hopr_protocol_app::prelude::{ApplicationData, ApplicationDataOut};
7use hopr_protocol_start::{KeepAliveFlag, KeepAliveMessage};
8use tracing::{Instrument, debug, error, instrument};
9
10use crate::{
11    AtomicSurbFlowEstimator, SessionId,
12    balancer::{BalancerStateValues, RateController, RateLimitStreamExt, SurbFlowEstimator},
13    errors::TransportSessionError,
14    types::HoprStartProtocol,
15};
16
/// Copies data in both directions between a [`Session`](crate::HoprSession) and an arbitrary async IO
/// `stream`, using a buffer size of `max_buffer` for each direction.
///
/// Only compiled with the `runtime-tokio` feature; it is meant to be driven by a Tokio runtime and
/// may panic when run under a different runtime.
///
/// When `abort_stream` is provided, triggering it ends the transfer from the `stream` side:
/// 1. a graceful shutdown of `stream` is initiated,
/// 2. once that is done, a graceful shutdown of `session` is initiated,
/// 3. the function returns the number of bytes transferred in both directions.
///
/// The `session` side cannot be aborted through this mechanism.
#[cfg(feature = "runtime-tokio")]
pub async fn transfer_session<S>(
    session: &mut crate::HoprSession,
    stream: &mut S,
    max_buffer: usize,
    abort_stream: Option<futures::future::AbortRegistration>,
) -> std::io::Result<(usize, usize)>
where
    S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
{
    // Both directions share the same buffer size.
    let buffer_sizes = (max_buffer, max_buffer);
    tracing::debug!(
        session_id = ?session.id(),
        egress_buffer = max_buffer,
        ingress_buffer = max_buffer,
        "session buffers"
    );

    match abort_stream {
        // Aborting is only possible from the `stream` side. This matters for UDP-like streams, which
        // cannot be terminated by an outside signal (unlike e.g. TCP sockets, which can be closed).
        Some(stream_abort) => {
            // The handle half is dropped right away, so nothing can ever trigger the session-side
            // registration.
            let (_, session_abort) = futures::future::AbortHandle::new_pair();
            hopr_network_types::utils::copy_duplex_abortable(
                session,
                stream,
                buffer_sizes,
                (session_abort, stream_abort),
            )
            .await
            .map(|(first, second)| (first as usize, second as usize))
        }
        None => hopr_network_types::utils::copy_duplex(session, stream, buffer_sizes)
            .await
            .map(|(first, second)| (first as usize, second as usize)),
    }
}
62
63/// This function will use the given generator to generate an initial seeding key.
64/// It will check whether the given cache already contains a value for that key, and if not,
65/// calls the generator (with the previous value) to generate a new seeding key and retry.
66/// The function either finds a suitable free slot, inserting value generated by `value_fn` and returns the found key,
67/// or terminates with `None` when `gen` returns the initial seed again.
68pub(crate) async fn insert_into_next_slot<F, K, U, V>(
69    cache: &moka::future::Cache<K, V>,
70    mut generator: F,
71    value_fn: U,
72) -> Option<(K, V)>
73where
74    F: FnMut(Option<K>) -> K,
75    K: Copy + std::hash::Hash + Eq + Send + Sync + 'static,
76    U: FnOnce(K) -> V,
77    V: Clone + Send + Sync + 'static,
78{
79    cache.run_pending_tasks().await;
80
81    // Wrap the FnOnce so we can "consume" it exactly once,
82    // but only when we actually insert into a free slot.
83    let value_fn = std::sync::Arc::new(parking_lot::Mutex::new(Some(value_fn)));
84
85    let initial = generator(None);
86    let mut next = initial;
87    loop {
88        let value_fn = value_fn.clone();
89        let insertion_result = cache
90            .entry(next)
91            .and_try_compute_with(move |e| {
92                if e.is_none() {
93                    let f = value_fn
94                        .lock()
95                        .take()
96                        .expect("impossible: value_fn was already consumed");
97
98                    futures::future::ok::<_, ()>(moka::ops::compute::Op::Put(f(next)))
99                } else {
100                    futures::future::ok::<_, ()>(moka::ops::compute::Op::Nop)
101                }
102            })
103            .await;
104
105        // If we inserted successfully, break the loop and return the insertion key
106        if let Ok(moka::ops::compute::CompResult::Inserted(val)) = insertion_result {
107            return Some((next, val.into_value()));
108        }
109
110        // Otherwise, generate the next key
111        next = generator(Some(next));
112
113        // If generated keys made it to full loop, return failure
114        if next == initial {
115            return None;
116        }
117    }
118}
119
120/// Indicates whether the [keep-alive stream](spawn_keep_alive_stream) should notify the Session counterparty
121/// about the SURB target (Entry) or SURB level (Exit).
122#[derive(Debug, Clone)]
123pub(crate) enum SurbNotificationMode {
124    /// No keep-alive messages are sent to the Session counterparty.
125    DoNotNotify,
126    /// Session initiator notifies the Session recipient about the desired SURB target level.
127    Target,
128    /// Session recipient notifies the Session initiator about the current SURB level.
129    Level(AtomicSurbFlowEstimator),
130}
131
132/// Spawns a task for a rate-limited stream of Keep-Alive messages to the Session counterparty.
133#[instrument(level = "debug", skip(sender, routing, notification_mode, cfg))]
134pub(crate) fn spawn_keep_alive_stream<S>(
135    session_id: SessionId,
136    sender: S,
137    routing: DestinationRouting,
138    notification_mode: SurbNotificationMode,
139    cfg: std::sync::Arc<BalancerStateValues>,
140) -> (RateController, AbortHandle)
141where
142    S: futures::Sink<(DestinationRouting, ApplicationDataOut)> + Clone + Send + Sync + Unpin + 'static,
143    S::Error: std::error::Error + Send + Sync + 'static,
144{
145    // The stream is suspended until the caller sets a rate via the Controller
146    let controller = RateController::new(0, Duration::from_secs(1));
147
148    let (ka_stream, abort_handle) = futures::stream::abortable(
149        futures::stream::repeat_with(move || match &notification_mode {
150            SurbNotificationMode::Target => HoprStartProtocol::KeepAlive(KeepAliveMessage {
151                session_id,
152                flags: KeepAliveFlag::BalancerTarget.into(),
153                additional_data: cfg.target_surb_buffer_size.load(std::sync::atomic::Ordering::Relaxed),
154            }),
155            SurbNotificationMode::Level(estimator) => HoprStartProtocol::KeepAlive(KeepAliveMessage {
156                session_id,
157                flags: KeepAliveFlag::BalancerState.into(),
158                additional_data: estimator.saturating_diff(),
159            }),
160            SurbNotificationMode::DoNotNotify => HoprStartProtocol::KeepAlive(KeepAliveMessage {
161                session_id,
162                flags: None.into(),
163                additional_data: 0,
164            }),
165        })
166        .rate_limit_with_controller(&controller),
167    );
168
169    let sender_clone = sender.clone();
170    let fwd_routing_clone = routing.clone();
171
172    // This task will automatically terminate once the returned abort handle is used.
173    debug!(%session_id, "spawning keep-alive stream");
174    hopr_async_runtime::prelude::spawn(
175        ka_stream
176            .map(move |msg| {
177                ApplicationData::try_from(msg)
178                    .map(|data| (fwd_routing_clone.clone(), ApplicationDataOut::with_no_packet_info(data)))
179            })
180            .map_err(TransportSessionError::from)
181            .try_for_each_concurrent(None, move |msg| {
182                let mut sender_clone = sender_clone.clone();
183                async move {
184                    sender_clone
185                        .send(msg)
186                        .await
187                        .map_err(TransportSessionError::packet_sending)
188                }
189            })
190            .then(move |res| {
191                match res {
192                    Ok(_) => tracing::debug!(
193                        component = "session",
194                        %session_id,
195                        task = "session keepalive",
196                        "background task finished"
197                    ),
198                    Err(error) => error!(%session_id, %error, "keep-alive stream failed"),
199                }
200                futures::future::ready(())
201            })
202            .in_current_span(),
203    );
204
205    (controller, abort_handle)
206}
207
#[cfg(test)]
mod tests {
    use anyhow::anyhow;

    use super::*;

    /// Cycles through the keys `0..5`, starting at `0`.
    fn next_key(prev: Option<i32>) -> i32 {
        prev.map_or(0, |k| (k + 1) % 5)
    }

    #[tokio::test]
    async fn test_insert_into_next_slot() -> anyhow::Result<()> {
        let cache = moka::future::Cache::new(10);

        // Each call must claim the next free key in order.
        for expected in 0..5 {
            let (key, value) = insert_into_next_slot(&cache, next_key, |k| format!("foo_{k}"))
                .await
                .ok_or(anyhow!("should insert"))?;
            assert_eq!(key, expected);
            assert_eq!(format!("foo_{expected}"), value);
            assert_eq!(Some(value), cache.get(&expected).await);
        }

        // All five keys are now occupied, so the generator wraps around without finding a slot.
        let when_full = insert_into_next_slot(&cache, next_key, |_| "foo".to_string()).await;
        assert!(when_full.is_none(), "must not find slot when full");

        Ok(())
    }
}