1use std::{
5 future::Future,
6 pin::Pin,
7 task::{Context, Poll},
8 time::{Duration, Instant},
9};
10
11use futures_time::future::Timer;
12use tracing::instrument;
13
14use crate::{
15 errors::SessionError,
16 processing::types::{
17 FrameBuilder, FrameDashMap, FrameHashMap, FrameInspector, FrameMap, FrameMapEntry, FrameMapOccupiedEntry,
18 FrameMapVacantEntry,
19 },
20 protocol::{Frame, FrameId, Segment},
21};
22
// Histogram of frame reassembly times (in milliseconds).
// Compiled only outside of tests and only when the `telemetry` feature is enabled.
#[cfg(all(not(test), feature = "telemetry"))]
lazy_static::lazy_static! {
    static ref METRIC_TIME_TO_FRAME_FINISH: hopr_types::telemetry::SimpleHistogram =
        hopr_types::telemetry::SimpleHistogram::new(
            "hopr_session_time_to_finish_frame",
            "Measures time in milliseconds it takes a frame to be reassembled",
            // Histogram buckets (the observed values are milliseconds)
            vec![1.0, 2.0, 5.0, 10.0, 25.0, 50.0, 75.0, 100.0, 150.0, 200.0, 250.0, 300.0, 400.0, 500.0],
        )
        .unwrap();
}
32
33#[must_use = "streams do nothing unless polled"]
59#[pin_project::pin_project]
60pub struct Reassembler<S, M> {
61 #[pin]
62 inner: S,
63 #[pin]
64 timer: futures_time::task::Sleep,
65 incomplete_frames: M,
66 expired_frames: Vec<FrameId>,
67 max_age: Duration,
68 capacity: usize,
69 last_expiration: Option<Instant>,
70}
71
72impl<S: futures::Stream<Item = Segment>, M: FrameMap> Reassembler<S, M> {
73 fn new(inner: S, incomplete_frames: M, max_age: Duration, capacity: usize) -> Self {
74 Self {
75 inner,
76 timer: futures_time::task::sleep(
77 (max_age + Duration::from_millis(1))
78 .max(Duration::from_millis(1))
79 .into(),
80 ),
81 incomplete_frames,
82 expired_frames: Vec::with_capacity(capacity),
83 last_expiration: None,
84 max_age,
85 capacity,
86 }
87 }
88
89 fn expire_frames(incomplete_frames: &mut M, expired_frames: &mut Vec<FrameId>, max_age: Duration) {
90 incomplete_frames.retain(|id, builder| {
91 if builder.last_recv.elapsed() >= max_age {
92 expired_frames.push(*id);
93 false
94 } else {
95 true
96 }
97 });
98 }
99}
100
101impl<S: futures::Stream<Item = Segment>, M: FrameMap> futures::Stream for Reassembler<S, M> {
102 type Item = Result<Frame, SessionError>;
103
104 #[instrument(name = "Reassembler::poll_next", level = "trace", skip(self, cx), fields(num_incomplete = self.incomplete_frames.len()), ret)]
105 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
106 let mut this = self.project();
107 loop {
108 if let Some(frame_id) = this.expired_frames.pop() {
109 tracing::trace!(frame_id, "emit discarded frame");
110 return Poll::Ready(Some(Err(SessionError::FrameDiscarded(frame_id))));
111 }
112
113 let timer_poll = if this.incomplete_frames.len() > 0 {
115 this.timer.as_mut().poll(cx)
116 } else {
117 Poll::Pending
118 };
119
120 let inner_poll = if this.incomplete_frames.len() < *this.capacity {
122 this.inner.as_mut().poll_next(cx)
123 } else {
124 tracing::warn!("reassembler has reached its capacity");
126 Poll::Pending
127 };
128
129 tracing::trace!("polling next");
130 match (inner_poll, timer_poll) {
131 (Poll::Ready(Some(item)), timer) => {
132 if timer.is_ready() {
133 this.timer.as_mut().reset_timer();
134 }
135
136 tracing::trace!(
137 frame_id = item.frame_id,
138 seq_idx = item.seq_idx,
139 seq_len = %item.seq_flags,
140 "received segment"
141 );
142
143 match this.incomplete_frames.entry(item.frame_id) {
144 FrameMapEntry::Occupied(mut e) => {
145 let builder = e.get_builder_mut();
146 let seg_id = item.id();
147 match builder.add_segment(item) {
148 Ok(_) => {
149 tracing::trace!(frame_id = builder.frame_id(), %seg_id, "added segment");
150 if builder.is_complete() {
151 #[cfg(all(not(test), feature = "telemetry"))]
152 METRIC_TIME_TO_FRAME_FINISH
153 .observe(builder.created.elapsed().as_millis() as f64);
154
155 tracing::trace!(frame_id = builder.frame_id(), "frame is complete");
156 return Poll::Ready(Some(e.finalize().try_into()));
157 }
158 }
159 Err(error) => {
160 tracing::error!(%error, %seg_id, "encountered invalid segment");
161 }
162 }
163 }
164 FrameMapEntry::Vacant(e) => {
165 let builder = FrameBuilder::from(item);
166 if builder.is_complete() {
167 #[cfg(all(not(test), feature = "telemetry"))]
168 METRIC_TIME_TO_FRAME_FINISH.observe(builder.created.elapsed().as_millis() as f64);
169
170 tracing::trace!(frame_id = builder.frame_id(), "segment frame is complete");
171 return Poll::Ready(Some(builder.try_into()));
172 } else {
173 tracing::trace!(frame_id = builder.frame_id(), "added segment for new frame");
174 e.insert_builder(builder);
175 }
176 }
177 };
178
179 if this.last_expiration.is_none_or(|e| e.elapsed() >= *this.max_age) {
182 Self::expire_frames(this.incomplete_frames, this.expired_frames, *this.max_age);
183 *this.last_expiration = Some(Instant::now());
184 }
185 }
186 (Poll::Ready(None), _) => {
187 tracing::trace!("inner stream closed, dumping incomplete frames");
189 if this.incomplete_frames.len() > 0 {
190 this.incomplete_frames.retain(|id, _| {
191 this.expired_frames.push(*id);
192 false
193 });
194 } else {
195 tracing::trace!("done");
196 return Poll::Ready(None);
197 }
198 }
199 (Poll::Pending, Poll::Ready(_)) => {
200 Self::expire_frames(this.incomplete_frames, this.expired_frames, *this.max_age);
202 *this.last_expiration = Some(Instant::now());
203 this.timer.as_mut().reset_timer();
204 }
205 (Poll::Pending, Poll::Pending) => return Poll::Pending,
206 }
207 }
208 }
209}
210
211pub trait ReassemblerExt: futures::Stream<Item = Segment> {
213 fn reassembler(self, timeout: Duration, capacity: usize) -> Reassembler<Self, FrameHashMap>
216 where
217 Self: Sized,
218 {
219 Reassembler::new(
221 self,
222 FrameHashMap::with_capacity(FrameInspector::INCOMPLETE_FRAME_RATIO * capacity + 1),
223 timeout,
224 capacity,
225 )
226 }
227
228 fn reassembler_with_inspector(
234 self,
235 timeout: Duration,
236 capacity: usize,
237 inspector: FrameInspector,
238 ) -> Reassembler<Self, FrameDashMap>
239 where
240 Self: Sized,
241 {
242 Reassembler::new(self, inspector.0.clone(), timeout, capacity)
243 }
244}
245
246impl<T: ?Sized> ReassemblerExt for T where T: futures::Stream<Item = Segment> {}
247
248#[cfg(test)]
249mod tests {
250 use std::cmp::Ordering;
251
252 use anyhow::anyhow;
253 use futures::{SinkExt, StreamExt, TryStreamExt, pin_mut};
254 use futures_time::future::FutureExt;
255 use hex_literal::hex;
256 use rand::{SeedableRng, prelude::SliceRandom, rngs::StdRng};
257
258 use super::*;
259 use crate::utils::segment;
260
261 const RNG_SEED: [u8; 32] = hex!("d8a471f1c20490a3442b96fdde9d1807428096e1601b0cef0eea7e6d44a24c01");
262
263 fn result_comparator(a: &Result<Frame, SessionError>, b: &Result<Frame, SessionError>) -> Ordering {
264 match (a, b) {
265 (Ok(a), Ok(b)) => a.frame_id.cmp(&b.frame_id),
266 (Err(SessionError::FrameDiscarded(a)), Ok(b)) => a.cmp(&b.frame_id),
267 (Ok(a), Err(SessionError::FrameDiscarded(b))) => a.frame_id.cmp(b),
268 (Err(SessionError::FrameDiscarded(a)), Err(SessionError::FrameDiscarded(b))) => a.cmp(b),
269 _ => panic!("unexpected result"),
270 }
271 }
272
273 #[test_log::test(tokio::test)]
274 pub async fn reassembler_should_reassemble_frames() -> anyhow::Result<()> {
275 let expected = (1u32..=10)
276 .map(|frame_id| Frame {
277 frame_id,
278 data: hopr_types::crypto_random::random_bytes::<100>().into(),
279 is_terminating: false,
280 })
281 .collect::<Vec<_>>();
282
283 let (r_sink, r_stream) = futures::channel::mpsc::unbounded();
284 let r_stream = r_stream.reassembler(Duration::from_secs(5), 1024);
285
286 let mut segments = expected
287 .iter()
288 .cloned()
289 .flat_map(|f| segment(f.data, 22, f.frame_id).unwrap())
290 .collect::<Vec<_>>();
291
292 let mut rng = StdRng::from_seed(RNG_SEED);
293 segments.shuffle(&mut rng);
294
295 let jh = hopr_utils::runtime::prelude::spawn(futures::stream::iter(segments).map(Ok).forward(r_sink));
296
297 let mut actual = r_stream
298 .try_collect::<Vec<_>>()
299 .timeout(futures_time::time::Duration::from_secs(5))
300 .await??;
301
302 assert_eq!(actual.len(), expected.len());
303
304 actual.sort_by_key(|a| a.frame_id);
305 assert_eq!(actual, expected);
306
307 let _ = jh.await?;
308 Ok(())
309 }
310
311 #[test_log::test(tokio::test)]
312 pub async fn reassembler_should_discard_incomplete_frames_on_expiration() -> anyhow::Result<()> {
313 let expected = (1u32..=10)
314 .map(|frame_id| Frame {
315 frame_id,
316 data: hopr_types::crypto_random::random_bytes::<100>().into(),
317 is_terminating: false,
318 })
319 .collect::<Vec<_>>();
320
321 let (r_sink, r_stream) = futures::channel::mpsc::unbounded();
322 let r_stream = r_stream.reassembler(Duration::from_millis(45), 1024);
323
324 let mut segments = expected
325 .iter()
326 .cloned()
327 .flat_map(|f| segment(f.data, 22, f.frame_id).unwrap())
328 .filter(|s| s.frame_id != 2 || s.seq_idx != 1)
329 .collect::<Vec<_>>();
330
331 let mut rng = StdRng::from_seed(RNG_SEED);
332 segments.shuffle(&mut rng);
333
334 pin_mut!(r_sink);
335 r_sink.send_all(&mut futures::stream::iter(segments).map(Ok)).await?;
336
337 let mut actual = Vec::new();
338 pin_mut!(r_stream);
339 for _ in 0..expected.len() {
340 actual.push(r_stream.next().await.ok_or(anyhow!("missing frame"))?);
341 }
342 r_sink.close().await?;
343 assert_eq!(None, r_stream.try_next().await?);
344
345 actual.sort_by(result_comparator);
346
347 assert_eq!(actual.len(), expected.len());
348
349 for i in 0..expected.len() {
350 if i != 1 {
351 assert!(matches!(&actual[i], Ok(f) if *f == expected[i]));
352 } else {
353 assert!(matches!(actual[i], Err(SessionError::FrameDiscarded(2))));
355 }
356 }
357
358 Ok(())
359 }
360
361 #[test_log::test(tokio::test)]
362 pub async fn reassembler_should_discard_incomplete_frames_on_close() -> anyhow::Result<()> {
363 let expected = (1u32..=10)
364 .map(|frame_id| Frame {
365 frame_id,
366 data: hopr_types::crypto_random::random_bytes::<100>().into(),
367 is_terminating: false,
368 })
369 .collect::<Vec<_>>();
370
371 let (r_sink, r_stream) = futures::channel::mpsc::unbounded();
372 let r_stream = r_stream.reassembler(Duration::from_millis(100), 1024);
373
374 let mut segments = expected
375 .iter()
376 .cloned()
377 .flat_map(|f| segment(f.data, 22, f.frame_id).unwrap())
378 .filter(|s| s.frame_id != 5 || s.seq_idx != 2)
379 .collect::<Vec<_>>();
380
381 let mut rng = StdRng::from_seed(RNG_SEED);
382 segments.shuffle(&mut rng);
383
384 let jh = hopr_utils::runtime::prelude::spawn(futures::stream::iter(segments).map(Ok).forward(r_sink));
385
386 let mut actual = r_stream
387 .collect::<Vec<_>>()
388 .timeout(futures_time::time::Duration::from_secs(5))
389 .await?;
390
391 assert_eq!(actual.len(), expected.len());
393
394 actual.sort_by(result_comparator);
395
396 for i in 0..expected.len() {
397 if i != 4 {
398 assert!(matches!(&actual[i], Ok(f) if *f == expected[i]));
399 } else {
400 assert!(matches!(actual[i], Err(SessionError::FrameDiscarded(5))));
402 }
403 }
404
405 let _ = jh.await?;
406 Ok(())
407 }
408
409 #[test_log::test(tokio::test)]
410 pub async fn reassembler_should_wait_and_discard_if_full() -> anyhow::Result<()> {
411 let expected = (1u32..=5)
412 .map(|frame_id| Frame {
413 frame_id,
414 data: hopr_types::crypto_random::random_bytes::<30>().into(),
415 is_terminating: false,
416 })
417 .collect::<Vec<_>>();
418
419 let (r_sink, r_stream) = futures::channel::mpsc::unbounded();
420 let r_stream = r_stream.reassembler(Duration::from_millis(200), 3);
421
422 pin_mut!(r_sink);
423 pin_mut!(r_stream);
424
425 let segments = expected
427 .iter()
428 .cloned()
429 .flat_map(|f| segment(f.data, 20, f.frame_id).unwrap())
430 .collect::<Vec<_>>();
431
432 let to_send = [
433 segments[1].clone(),
435 segments[2].clone(),
437 segments[5].clone(),
439 ];
440
441 let start = Instant::now();
442
443 r_sink.send_all(&mut futures::stream::iter(to_send).map(Ok)).await?;
445
446 assert!(
448 r_stream
449 .next()
450 .timeout(futures_time::time::Duration::from_millis(20))
451 .await
452 .is_err()
453 );
454
455 r_sink
457 .send_all(
458 &mut futures::stream::iter([
459 segments[6].clone(),
460 segments[7].clone(),
461 segments[8].clone(),
462 segments[9].clone(),
463 ])
464 .map(Ok),
465 )
466 .await?;
467
468 let mut reassembled = Vec::new();
469 for _ in 0..5 {
470 reassembled.push(r_stream.next().await.ok_or(anyhow!("missing frame"))?);
471 }
472 reassembled.sort_by(result_comparator);
473
474 assert!(
475 matches!(reassembled[0], Err(SessionError::FrameDiscarded(1))),
476 "{:?} must be discarded ID 1",
477 reassembled[0]
478 );
479 assert!(
480 matches!(reassembled[1], Err(SessionError::FrameDiscarded(2))),
481 "{:?} must be discarded ID 2",
482 reassembled[1]
483 );
484 assert!(
485 matches!(reassembled[2], Err(SessionError::FrameDiscarded(3))),
486 "{:?} must be discarded ID 3",
487 reassembled[2]
488 );
489 assert!(
490 matches!(&reassembled[3], Ok(f) if f == &expected[3].clone()),
491 "{:?} (idx 3) must be {:?}",
492 reassembled[3],
493 expected[3]
494 );
495 assert!(
496 matches!(&reassembled[4], Ok(f) if f == &expected[4].clone()),
497 "{:?} (idx 3) must be {:?}",
498 reassembled[4],
499 expected[4]
500 );
501
502 r_sink.close().await?;
503 assert_eq!(None, r_stream.try_next().await?);
504
505 assert!(start.elapsed() >= Duration::from_millis(200));
506
507 Ok(())
508 }
509}