1use std::time::Duration;
2
3use futures::{FutureExt, SinkExt, StreamExt, TryStreamExt};
4use hopr_async_runtime::AbortHandle;
5use hopr_network_types::prelude::DestinationRouting;
6use hopr_protocol_app::prelude::{ApplicationData, ApplicationDataOut};
7use hopr_protocol_start::{KeepAliveFlag, KeepAliveMessage};
8use tracing::{Instrument, debug, error, instrument};
9
10use crate::{
11 AtomicSurbFlowEstimator, SessionId,
12 balancer::{BalancerStateValues, RateController, RateLimitStreamExt, SurbFlowEstimator},
13 errors::TransportSessionError,
14 types::HoprStartProtocol,
15};
16
/// Copies data in both directions between a HOPR `session` and a local `stream`.
///
/// The same `max_buffer` size is used for both directions (it is also logged as the egress and
/// ingress buffer size).
///
/// When `abort_stream` is `None`, the transfer is delegated to
/// `hopr_network_types::utils::copy_duplex`. Otherwise it is delegated to
/// `hopr_network_types::utils::copy_duplex_abortable`. In that case `abort_stream` is passed in the
/// second position of the registration pair. The first position gets a registration from a fresh
/// `AbortHandle::new_pair()`, whose handle is discarded immediately, so nothing here can abort it.
///
/// Returns the pair of counts reported by the underlying copy routine, cast to `usize`
/// (NOTE(review): presumably bytes copied per direction; confirm against `hopr_network_types`).
///
/// # Errors
/// Propagates any `std::io::Error` returned by the underlying copy routine.
#[cfg(feature = "runtime-tokio")]
pub async fn transfer_session<S>(
    session: &mut crate::HoprSession,
    stream: &mut S,
    max_buffer: usize,
    abort_stream: Option<futures::future::AbortRegistration>,
) -> std::io::Result<(usize, usize)>
where
    S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
{
    tracing::debug!(
        session_id = ?session.id(),
        egress_buffer = max_buffer,
        ingress_buffer = max_buffer,
        "session buffers"
    );

    // Both directions share the same buffer capacity.
    let buffer_sizes = (max_buffer, max_buffer);

    match abort_stream {
        None => hopr_network_types::utils::copy_duplex(session, stream, buffer_sizes)
            .await
            .map(|(first, second)| (first as usize, second as usize)),
        Some(registration) => {
            // The handle half is dropped right away (`_`), leaving a registration nobody aborts.
            let (_, never_aborted) = futures::future::AbortHandle::new_pair();
            hopr_network_types::utils::copy_duplex_abortable(
                session,
                stream,
                buffer_sizes,
                (never_aborted, registration),
            )
            .await
            .map(|(first, second)| (first as usize, second as usize))
        }
    }
}
62
63pub(crate) async fn insert_into_next_slot<F, K, U, V>(
69 cache: &moka::future::Cache<K, V>,
70 mut generator: F,
71 value_fn: U,
72) -> Option<(K, V)>
73where
74 F: FnMut(Option<K>) -> K,
75 K: Copy + std::hash::Hash + Eq + Send + Sync + 'static,
76 U: FnOnce(K) -> V,
77 V: Clone + Send + Sync + 'static,
78{
79 cache.run_pending_tasks().await;
80
81 let value_fn = std::sync::Arc::new(parking_lot::Mutex::new(Some(value_fn)));
84
85 let initial = generator(None);
86 let mut next = initial;
87 loop {
88 let value_fn = value_fn.clone();
89 let insertion_result = cache
90 .entry(next)
91 .and_try_compute_with(move |e| {
92 if e.is_none() {
93 let f = value_fn
94 .lock()
95 .take()
96 .expect("impossible: value_fn was already consumed");
97
98 futures::future::ok::<_, ()>(moka::ops::compute::Op::Put(f(next)))
99 } else {
100 futures::future::ok::<_, ()>(moka::ops::compute::Op::Nop)
101 }
102 })
103 .await;
104
105 if let Ok(moka::ops::compute::CompResult::Inserted(val)) = insertion_result {
107 return Some((next, val.into_value()));
108 }
109
110 next = generator(Some(next));
112
113 if next == initial {
115 return None;
116 }
117 }
118}
119
120#[derive(Debug, Clone)]
123pub(crate) enum SurbNotificationMode {
124 DoNotNotify,
126 Target,
128 Level(AtomicSurbFlowEstimator),
130}
131
132#[instrument(level = "debug", skip(sender, routing, notification_mode, cfg))]
134pub(crate) fn spawn_keep_alive_stream<S>(
135 session_id: SessionId,
136 sender: S,
137 routing: DestinationRouting,
138 notification_mode: SurbNotificationMode,
139 cfg: std::sync::Arc<BalancerStateValues>,
140) -> (RateController, AbortHandle)
141where
142 S: futures::Sink<(DestinationRouting, ApplicationDataOut)> + Clone + Send + Sync + Unpin + 'static,
143 S::Error: std::error::Error + Send + Sync + 'static,
144{
145 let controller = RateController::new(0, Duration::from_secs(1));
147
148 let (ka_stream, abort_handle) = futures::stream::abortable(
149 futures::stream::repeat_with(move || match ¬ification_mode {
150 SurbNotificationMode::Target => HoprStartProtocol::KeepAlive(KeepAliveMessage {
151 session_id,
152 flags: KeepAliveFlag::BalancerTarget.into(),
153 additional_data: cfg.target_surb_buffer_size.load(std::sync::atomic::Ordering::Relaxed),
154 }),
155 SurbNotificationMode::Level(estimator) => HoprStartProtocol::KeepAlive(KeepAliveMessage {
156 session_id,
157 flags: KeepAliveFlag::BalancerState.into(),
158 additional_data: estimator.saturating_diff(),
159 }),
160 SurbNotificationMode::DoNotNotify => HoprStartProtocol::KeepAlive(KeepAliveMessage {
161 session_id,
162 flags: None.into(),
163 additional_data: 0,
164 }),
165 })
166 .rate_limit_with_controller(&controller),
167 );
168
169 let sender_clone = sender.clone();
170 let fwd_routing_clone = routing.clone();
171
172 debug!(%session_id, "spawning keep-alive stream");
174 hopr_async_runtime::prelude::spawn(
175 ka_stream
176 .map(move |msg| {
177 ApplicationData::try_from(msg)
178 .map(|data| (fwd_routing_clone.clone(), ApplicationDataOut::with_no_packet_info(data)))
179 })
180 .map_err(TransportSessionError::from)
181 .try_for_each_concurrent(None, move |msg| {
182 let mut sender_clone = sender_clone.clone();
183 async move {
184 sender_clone
185 .send(msg)
186 .await
187 .map_err(TransportSessionError::packet_sending)
188 }
189 })
190 .then(move |res| {
191 match res {
192 Ok(_) => tracing::debug!(
193 component = "session",
194 %session_id,
195 task = "session keepalive",
196 "background task finished"
197 ),
198 Err(error) => error!(%session_id, %error, "keep-alive stream failed"),
199 }
200 futures::future::ready(())
201 })
202 .in_current_span(),
203 );
204
205 (controller, abort_handle)
206}
207
#[cfg(test)]
mod tests {
    use anyhow::anyhow;

    use super::*;

    /// Slot generator cycling through keys `0..5`, starting at `0`.
    fn next_of_five(prev: Option<i32>) -> i32 {
        prev.map(|v| (v + 1) % 5).unwrap_or(0)
    }

    #[tokio::test]
    async fn test_insert_into_next_slot() -> anyhow::Result<()> {
        let cache = moka::future::Cache::new(10);

        for i in 0..5 {
            let (k, v) = insert_into_next_slot(&cache, next_of_five, |k| format!("foo_{k}"))
                .await
                .ok_or(anyhow!("should insert"))?;
            assert_eq!(k, i);
            assert_eq!(format!("foo_{i}"), v);
            assert_eq!(Some(v), cache.get(&i).await);
        }

        assert!(
            insert_into_next_slot(&cache, next_of_five, |_| "foo".to_string())
                .await
                .is_none(),
            "must not find slot when full"
        );

        Ok(())
    }

    /// Once a slot in a full cache is freed, the next insertion must land exactly in that slot.
    #[tokio::test]
    async fn test_insert_into_next_slot_reuses_freed_slot() -> anyhow::Result<()> {
        let cache = moka::future::Cache::new(10);

        for _ in 0..5 {
            insert_into_next_slot(&cache, next_of_five, |k| format!("foo_{k}"))
                .await
                .ok_or(anyhow!("should insert"))?;
        }

        cache.invalidate(&2).await;

        let (k, v) = insert_into_next_slot(&cache, next_of_five, |k| format!("bar_{k}"))
            .await
            .ok_or(anyhow!("should insert into the freed slot"))?;
        assert_eq!(2, k);
        assert_eq!("bar_2", v);
        assert_eq!(Some(v), cache.get(&2).await);

        Ok(())
    }
}