1use std::time::Duration;
2
3use futures::{FutureExt, SinkExt, StreamExt, TryStreamExt};
4use hopr_api::types::internal::routing::DestinationRouting;
5use hopr_protocol_app::prelude::{ApplicationData, ApplicationDataOut};
6use hopr_protocol_start::{KeepAliveFlag, KeepAliveMessage};
7use hopr_utils::runtime::AbortHandle;
8use tracing::{Instrument, debug, error, instrument};
9
10use crate::{
11 AtomicSurbFlowEstimator, SessionId,
12 balancer::{BalancerStateValues, RateController, RateLimitStreamExt, SurbFlowEstimator},
13 errors::TransportSessionError,
14 types::HoprStartProtocol,
15};
16
/// Copies bytes in both directions between a HOPR session and an arbitrary async byte stream.
///
/// Both the egress and the ingress direction use a buffer of `max_buffer` bytes.
///
/// When `abort_stream` is given, the copy is performed by
/// `hopr_utils::network_types::utils::copy_duplex_abortable` and can be stopped through the
/// `AbortHandle` paired with that registration. Otherwise the plain
/// `hopr_utils::network_types::utils::copy_duplex` is used.
///
/// Returns the pair of byte counts reported by the underlying copy helper, converted to `usize`.
///
/// # Errors
///
/// Propagates any I/O error raised by the underlying copy helper.
#[cfg(feature = "runtime-tokio")]
pub async fn transfer_session<S>(
    session: &mut crate::HoprSession,
    stream: &mut S,
    max_buffer: usize,
    abort_stream: Option<futures::future::AbortRegistration>,
) -> std::io::Result<(usize, usize)>
where
    S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
{
    tracing::debug!(
        session_id = ?session.id(),
        egress_buffer = max_buffer,
        ingress_buffer = max_buffer,
        "session buffers"
    );

    let buffer_sizes = (max_buffer, max_buffer);

    match abort_stream {
        None => hopr_utils::network_types::utils::copy_duplex(session, stream, buffer_sizes)
            .await
            .map(|(first, second)| (first as usize, second as usize)),
        Some(caller_registration) => {
            // Only the caller-supplied registration is meant to abort anything. The first slot is
            // filled with a fresh registration whose handle is dropped right away and never used
            // to abort.
            let (_, unused_registration) = futures::future::AbortHandle::new_pair();
            hopr_utils::network_types::utils::copy_duplex_abortable(
                session,
                stream,
                buffer_sizes,
                (unused_registration, caller_registration),
            )
            .await
            .map(|(first, second)| (first as usize, second as usize))
        }
    }
}
62
63pub(crate) async fn insert_into_next_slot<F, K, U, V>(
69 cache: &moka::future::Cache<K, V>,
70 mut generator: F,
71 value_fn: U,
72) -> Option<(K, V)>
73where
74 F: FnMut(Option<K>) -> K,
75 K: Copy + std::hash::Hash + Eq + Send + Sync + 'static,
76 U: FnOnce(K) -> V,
77 V: Clone + Send + Sync + 'static,
78{
79 cache.run_pending_tasks().await;
80
81 let value_fn = std::sync::Arc::new(parking_lot::Mutex::new(Some(value_fn)));
84
85 let initial = generator(None);
86 let mut next = initial;
87 loop {
88 let value_fn = value_fn.clone();
89 let insertion_result = cache
90 .entry(next)
91 .and_try_compute_with(move |e| {
92 if e.is_none() {
93 let f = value_fn
94 .lock()
95 .take()
96 .expect("impossible: value_fn was already consumed");
97
98 futures::future::ok::<_, ()>(moka::ops::compute::Op::Put(f(next)))
99 } else {
100 futures::future::ok::<_, ()>(moka::ops::compute::Op::Nop)
101 }
102 })
103 .await;
104
105 if let Ok(moka::ops::compute::CompResult::Inserted(val)) = insertion_result {
107 return Some((next, val.into_value()));
108 }
109
110 next = generator(Some(next));
112
113 if next == initial {
115 return None;
116 }
117 }
118}
119
#[derive(Debug, Clone)]
/// Selects what SURB-balancer information, if any, is attached to the keep-alive messages
/// produced by `spawn_keep_alive_stream`.
pub(crate) enum SurbNotificationMode {
    /// Keep-alive messages carry no flag (`None.into()`) and `additional_data` of `0`.
    DoNotNotify,
    /// Keep-alive messages carry the `KeepAliveFlag::BalancerTarget` flag. Their
    /// `additional_data` is the current `BalancerStateValues::target_surb_buffer_size`.
    Target,
    /// Keep-alive messages carry the `KeepAliveFlag::BalancerState` flag. Their
    /// `additional_data` is the wrapped estimator's `saturating_diff()`.
    Level(AtomicSurbFlowEstimator),
}
131
132#[instrument(level = "debug", skip(sender, routing, notification_mode, cfg))]
134pub(crate) fn spawn_keep_alive_stream<S>(
135 session_id: SessionId,
136 sender: S,
137 routing: DestinationRouting,
138 notification_mode: SurbNotificationMode,
139 cfg: std::sync::Arc<BalancerStateValues>,
140) -> (RateController, AbortHandle)
141where
142 S: futures::Sink<(DestinationRouting, ApplicationDataOut)> + Clone + Send + Sync + Unpin + 'static,
143 S::Error: std::error::Error + Send + Sync + 'static,
144{
145 let controller = RateController::new(0, Duration::from_secs(1));
147
148 let (ka_stream, abort_handle) = futures::stream::abortable(
149 futures::stream::repeat_with(move || match ¬ification_mode {
150 SurbNotificationMode::Target => HoprStartProtocol::KeepAlive(KeepAliveMessage {
151 session_id,
152 flags: KeepAliveFlag::BalancerTarget.into(),
153 additional_data: cfg.target_surb_buffer_size.load(std::sync::atomic::Ordering::Relaxed),
154 }),
155 SurbNotificationMode::Level(estimator) => HoprStartProtocol::KeepAlive(KeepAliveMessage {
156 session_id,
157 flags: KeepAliveFlag::BalancerState.into(),
158 additional_data: estimator.saturating_diff(),
159 }),
160 SurbNotificationMode::DoNotNotify => HoprStartProtocol::KeepAlive(KeepAliveMessage {
161 session_id,
162 flags: None.into(),
163 additional_data: 0,
164 }),
165 })
166 .rate_limit_with_controller(&controller),
167 );
168
169 let sender_clone = sender.clone();
170 let fwd_routing_clone = routing.clone();
171
172 debug!(%session_id, "spawning keep-alive stream");
174 let keep_alive_diag = hopr_utils::runtime::diagnostics::ConcurrentDiagnostics::new(
175 "session_keep_alive_try_for_each_concurrent",
176 module_path!(),
177 file!(),
178 line!(),
179 );
180 hopr_utils::runtime::prelude::spawn(hopr_utils::runtime::diagnostics::instrument(
181 ka_stream
182 .map(move |msg| {
183 ApplicationData::try_from(msg)
184 .map(|data| (fwd_routing_clone.clone(), ApplicationDataOut::with_no_packet_info(data)))
185 })
186 .map_err(TransportSessionError::from)
187 .try_for_each_concurrent(None, move |msg| {
188 let mut sender_clone = sender_clone.clone();
189 let keep_alive_diag = keep_alive_diag.clone();
190 keep_alive_diag.wrap(async move {
191 sender_clone
192 .send(msg)
193 .await
194 .map_err(TransportSessionError::packet_sending)
195 })
196 })
197 .then(move |res| {
198 match res {
199 Ok(_) => tracing::debug!(
200 component = "session",
201 %session_id,
202 task = "session keepalive",
203 "background task finished"
204 ),
205 Err(error) => error!(%session_id, %error, "keep-alive stream failed"),
206 }
207 futures::future::ready(())
208 })
209 .in_current_span(),
210 "session_keep_alive",
211 module_path!(),
212 file!(),
213 line!(),
214 ));
215
216 (controller, abort_handle)
217}
218
#[cfg(test)]
mod tests {
    use anyhow::anyhow;

    use super::*;

    /// Key generator cycling through the ring `0, 1, 2, 3, 4`, starting at `0`.
    fn ring_of_five(prev: Option<i32>) -> i32 {
        match prev {
            Some(last) => (last + 1) % 5,
            None => 0,
        }
    }

    #[tokio::test]
    async fn test_insert_into_next_slot() -> anyhow::Result<()> {
        let cache = moka::future::Cache::new(10);

        // Occupy every slot of the ring, one insertion per iteration, in ring order.
        for expected_key in 0..5 {
            let (key, value) = insert_into_next_slot(&cache, ring_of_five, |k| format!("foo_{k}"))
                .await
                .ok_or(anyhow!("should insert"))?;

            assert_eq!(key, expected_key);
            assert_eq!(format!("foo_{expected_key}"), value);
            assert_eq!(Some(value), cache.get(&expected_key).await);
        }

        // Every slot is now taken, so the search wraps around and gives up.
        let overflow = insert_into_next_slot(&cache, ring_of_five, |_| "foo".to_string()).await;
        assert!(overflow.is_none(), "must not find slot when full");

        Ok(())
    }
}