1use std::{
3 collections::VecDeque,
4 pin::Pin,
5 task::{Context, Poll},
6};
7
8use tracing::instrument;
9
10use crate::{
11 protocol::{FrameId, Segment, SeqIndicator, SessionMessage},
12 utils::segment_into,
13};
14
/// Adapter that accepts bytes through [`futures::io::AsyncWrite`] and forwards them as
/// [`Segment`]s into an inner [`futures::Sink`].
///
/// Written bytes are accumulated in an internal frame buffer. A frame is split into segments
/// (via `segment_into`) either once the buffer has reached `frame_size` bytes and more data
/// is written, or when the writer is flushed. Each produced frame consumes one frame ID.
///
/// The const parameter `C` is used together with `SessionMessage::<C>::SEGMENT_OVERHEAD`:
/// the value `C - SEGMENT_OVERHEAD` is what gets passed to `segment_into` as the segment size.
#[must_use = "sinks do nothing unless polled"]
#[pin_project::pin_project]
pub struct Segmenter<const C: usize, S> {
    /// Sink that receives the produced segments.
    #[pin]
    inner: S,
    /// Tells `poll_write` whether it is filling the frame buffer or draining `ready_segments`.
    state: State,
    /// Bytes of the frame that is currently being assembled.
    frame: Vec<u8>,
    /// Segments that were produced from a frame but not yet handed to `inner`.
    ready_segments: VecDeque<Segment>,
    /// Number of buffered bytes at which a frame is considered full (clamped in `new`).
    frame_size: usize,
    /// ID given to the next frame that gets segmented; starts at 1 and is incremented per frame.
    frame_id: FrameId,
    /// Set by `poll_close`; once set, `poll_write` and `poll_flush` fail with `BrokenPipe`.
    is_closed: bool,
    /// When `true`, `poll_close` sends `Segment::terminating` before closing `inner`.
    send_terminating_segment: bool,
}
49
/// Phase of the `poll_write` state machine of the `Segmenter`.
enum State {
    /// Incoming bytes are appended to the frame buffer until it reaches the frame size.
    BufferingFrame,
    /// A full frame has been segmented and its segments are being pushed into the inner sink.
    WritingFrame,
}
54
55impl<const C: usize, S> Segmenter<C, S>
56where
57 S: futures::Sink<Segment>,
58 S::Error: std::error::Error + Send + Sync + 'static,
59{
60 fn new(inner: S, frame_size: usize, send_terminating_segment: bool) -> Self {
61 let frame_size = frame_size.clamp(
66 C,
67 (C - SessionMessage::<C>::SEGMENT_OVERHEAD) * (SeqIndicator::MAX + 1) as usize,
68 );
69
70 Self {
71 inner,
72 state: State::BufferingFrame,
73 frame: Vec::with_capacity(frame_size),
74 ready_segments: VecDeque::with_capacity(frame_size.div_ceil(C - SessionMessage::<C>::SEGMENT_OVERHEAD)),
75 frame_size,
76 frame_id: 1,
77 is_closed: false,
78 send_terminating_segment,
79 }
80 }
81}
82
83impl<const C: usize, S> futures::io::AsyncWrite for Segmenter<C, S>
84where
85 S: futures::Sink<Segment>,
86 S::Error: std::error::Error + Send + Sync + 'static,
87{
88 #[instrument(name = "Segmenter::poll_write", level = "trace", skip(self, cx, buf), fields(frame_id = self.frame_id, buf_len = buf.len(), frame_size = self.frame.len(), ready_segments = self.ready_segments.len()), ret)]
89 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
90 if self.is_closed {
91 return Poll::Ready(Err(std::io::Error::new(
92 std::io::ErrorKind::BrokenPipe,
93 "segmenter closed",
94 )));
95 }
96
97 let mut this = self.project();
98 loop {
99 match this.state {
100 State::BufferingFrame => {
101 if *this.frame_size > this.frame.len() {
103 let to_write = buf.len().min(*this.frame_size - this.frame.len());
104 this.frame.extend_from_slice(&buf[..to_write]);
105
106 return Poll::Ready(Ok(to_write));
107 } else {
108 segment_into(
111 this.frame.as_slice(),
112 C - SessionMessage::<C>::SEGMENT_OVERHEAD,
113 *this.frame_id,
114 this.ready_segments,
115 )
116 .map_err(std::io::Error::other)?;
117
118 tracing::trace!(num_segments = this.ready_segments.len(), "frame ready");
119
120 this.frame.clear();
121 *this.frame_id += 1;
122 *this.state = State::WritingFrame;
123 }
124 }
125 State::WritingFrame => {
126 if !this.ready_segments.is_empty() {
127 futures::ready!(this.inner.as_mut().poll_ready(cx).map_err(std::io::Error::other))?;
129
130 let segment = this.ready_segments.pop_front().unwrap();
131 tracing::trace!(seg_id = %segment.id(), "segment goes out");
132 this.inner.as_mut().start_send(segment).map_err(std::io::Error::other)?;
133 } else {
134 *this.state = State::BufferingFrame;
136 tracing::trace!("all segments out");
137 }
138 }
139 }
140 }
141 }
142
143 #[instrument(name = "Segmenter::poll_flush", level = "trace", skip(self, cx), fields(frame_id = self.frame_id, frame_size = self.frame.len(), ready_segments = self.ready_segments.len()), ret)]
144 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
145 if self.is_closed {
146 return Poll::Ready(Err(std::io::Error::new(
147 std::io::ErrorKind::BrokenPipe,
148 "segmenter closed",
149 )));
150 }
151
152 let mut this = self.project();
153 loop {
154 if !this.frame.is_empty() {
156 futures::ready!(this.inner.as_mut().poll_flush(cx).map_err(std::io::Error::other))?;
158
159 segment_into(
163 this.frame.as_slice(),
164 C - SessionMessage::<C>::SEGMENT_OVERHEAD,
165 *this.frame_id,
166 this.ready_segments,
167 )
168 .map_err(std::io::Error::other)?;
169
170 tracing::trace!(num_segments = this.ready_segments.len(), "flushed frame ready");
171
172 this.frame.clear();
173 *this.frame_id += 1;
174 } else if !this.ready_segments.is_empty() {
175 futures::ready!(this.inner.as_mut().poll_ready(cx).map_err(std::io::Error::other))?;
176
177 let segment = this.ready_segments.pop_front().unwrap();
178 tracing::trace!(seg_id = %segment.id(), "segment flushing out");
179
180 this.inner.as_mut().start_send(segment).map_err(std::io::Error::other)?;
181 } else {
182 futures::ready!(this.inner.as_mut().poll_flush(cx).map_err(std::io::Error::other))?;
184
185 tracing::trace!("all segments flushed out");
186 return Poll::Ready(Ok(()));
187 }
188 }
189 }
190
191 #[instrument(name = "Segmenter::poll_close", level = "trace", skip(self, cx), fields(frame_id = self.frame_id) , ret)]
192 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
193 let mut this = self.project();
194
195 if *this.send_terminating_segment && !*this.is_closed {
196 futures::ready!(this.inner.as_mut().poll_ready(cx).map_err(std::io::Error::other))?;
197 let dummy = Segment::terminating(*this.frame_id);
198 this.inner.as_mut().start_send(dummy).map_err(std::io::Error::other)?;
199 tracing::trace!("sent terminating segment");
200 }
201
202 *this.is_closed = true;
203 this.inner.as_mut().poll_close(cx).map_err(std::io::Error::other)
204 }
205}
206
207pub trait SegmenterExt: futures::Sink<Segment> {
209 fn segmenter<const C: usize>(self, frame_size: usize) -> Segmenter<C, Self>
211 where
212 Self: Sized,
213 Self::Error: std::error::Error + Send + Sync + 'static,
214 {
215 Segmenter::new(self, frame_size, false)
216 }
217
218 fn segmenter_with_terminating_segment<const C: usize>(self, frame_size: usize) -> Segmenter<C, Self>
221 where
222 Self: Sized,
223 Self::Error: std::error::Error + Send + Sync + 'static,
224 {
225 Segmenter::new(self, frame_size, true)
226 }
227}
228
229impl<T: ?Sized> SegmenterExt for T where T: futures::Sink<Segment> {}
230
#[cfg(test)]
mod tests {
    use anyhow::{Context, anyhow};
    use futures::{AsyncWriteExt, Stream, StreamExt, pin_mut};
    use futures_time::future::FutureExt;

    use super::*;
    use crate::{protocol::SeqNum, utils::segment};

    const MTU: usize = 1000;
    /// Payload bytes carried by a single segment.
    const SMTU: usize = MTU - SessionMessage::<MTU>::SEGMENT_OVERHEAD;
    const FRAME_SIZE: usize = 1500;

    /// Number of segments a full frame is split into (segments carry `SMTU` bytes, not `MTU`).
    const SEGMENTS_PER_FRAME: usize = FRAME_SIZE.div_ceil(SMTU);

    /// Pulls `num_frames` full frames worth of segments from `segments` and checks their
    /// frame IDs (starting at `start_frame_id`), sequence numbers and payload against `data`.
    ///
    /// `data` must contain the payload of the checked frames back to back, starting with the
    /// frame that has ID `start_frame_id`.
    async fn assert_frame_segments(
        start_frame_id: FrameId,
        num_frames: usize,
        segments: &mut (impl Stream<Item = Segment> + Unpin),
        data: &[u8],
    ) -> anyhow::Result<()> {
        let start_frame_id = start_frame_id as usize;

        for i in 0..num_frames * SEGMENTS_PER_FRAME {
            let frame_offset = i / SEGMENTS_PER_FRAME;
            let seq_idx = i % SEGMENTS_PER_FRAME;
            let frame_id = frame_offset + start_frame_id;
            tracing::debug!("testing frame id {frame_id} {}", seq_idx as SeqNum);

            let seg = segments
                .next()
                .timeout(futures_time::time::Duration::from_millis(500))
                .await
                .with_context(|| format!("assert_frame_segments {i}"))?
                .ok_or_else(|| anyhow!("no more segments"))?;

            assert_eq!(frame_id as FrameId, seg.frame_id);
            assert_eq!(seq_idx as SeqNum, seg.seq_idx);
            assert_eq!(SEGMENTS_PER_FRAME as SeqNum, seg.seq_flags.seq_len());

            // Every segment carries SMTU bytes, except the last one of a frame which carries
            // whatever is left of the frame.
            let start = frame_offset * FRAME_SIZE + seq_idx * SMTU;
            let end = (start + SMTU).min((frame_offset + 1) * FRAME_SIZE);
            assert_eq!(end - start, seg.data.len());
            assert_eq!(&data[start..end], seg.data.as_ref());
        }

        Ok(())
    }

    #[tokio::test]
    async fn segmenter_should_not_segment_small_data_unless_flushed() -> anyhow::Result<()> {
        let (segments_tx, segments) = futures::channel::mpsc::unbounded();
        let mut writer = segments_tx.segmenter::<MTU>(FRAME_SIZE);

        writer.write_all(b"test").await?;

        // Less than a frame was written, so nothing must come out before the flush.
        pin_mut!(segments);
        segments
            .next()
            .timeout(futures_time::time::Duration::from_millis(10))
            .await
            .expect_err("should time out");

        writer.flush().await?;

        let seg = segments.next().await.ok_or_else(|| anyhow!("no more segments"))?;
        assert_eq!(1, seg.frame_id);
        assert_eq!(1, seg.seq_flags.seq_len());
        assert_eq!(0, seg.seq_idx);
        assert_eq!(b"test", seg.data.as_ref());

        Ok(())
    }

    #[parameterized::parameterized(num_frames = { 1, 3, 5, 11 })]
    #[parameterized_macro(tokio::test)]
    async fn segmenter_should_segment_complete_frames(num_frames: usize) -> anyhow::Result<()> {
        let (segments_tx, segments) = futures::channel::mpsc::unbounded();
        let mut writer = segments_tx.segmenter::<MTU>(FRAME_SIZE);

        let mut all_data = Vec::new();
        for _ in 0..num_frames {
            let data = hopr_types::crypto_random::random_bytes::<FRAME_SIZE>();
            writer.write_all(&data).await?;
            all_data.extend(data);
        }

        writer.flush().await?;

        pin_mut!(segments);
        assert_frame_segments(1, num_frames, &mut segments, &all_data).await?;

        writer.close().await?;

        assert_eq!(None, segments.next().await);
        Ok(())
    }

    #[tokio::test]
    async fn segmenter_full_frame_segmentation_must_be_consistent_with_segment_function() -> anyhow::Result<()> {
        let (segments_tx, segments) = futures::channel::mpsc::unbounded();
        let mut writer = segments_tx.segmenter::<MTU>(FRAME_SIZE);

        let data = hopr_types::crypto_random::random_bytes::<FRAME_SIZE>();

        writer.write_all(&data).await?;
        writer.flush().await?;
        writer.close().await?;

        let expected = segment(data, SMTU, 1)?;
        let actual = segments.collect::<Vec<_>>().await;

        assert_eq!(expected, actual);

        Ok(())
    }

    #[test_log::test(tokio::test)]
    async fn segmenter_full_frame_segmentation_must_also_include_terminating_segment() -> anyhow::Result<()> {
        let (segments_tx, segments) = futures::channel::mpsc::unbounded();
        let mut writer = segments_tx.segmenter_with_terminating_segment::<MTU>(FRAME_SIZE);

        let data = hopr_types::crypto_random::random_bytes::<FRAME_SIZE>();

        writer.write_all(&data).await?;
        writer.flush().await?;
        writer.close().await?;

        // The terminating segment takes the frame ID following the single data frame.
        let mut expected = segment(data, SMTU, 1)?;
        expected.push(Segment::terminating(2));
        let actual = segments.collect::<Vec<_>>().await;

        assert_eq!(expected, actual);

        Ok(())
    }

    #[test_log::test(tokio::test)]
    async fn segmenter_should_segment_complete_frame_with_misaligned_mtu() -> anyhow::Result<()> {
        let (segments_tx, segments) = futures::channel::mpsc::unbounded();
        let mut writer = segments_tx.segmenter::<MTU>(FRAME_SIZE);

        // The frame must not be a multiple of the segment payload for this test to make sense.
        assert_ne!(0, FRAME_SIZE % SMTU);

        let data = hopr_types::crypto_random::random_bytes::<FRAME_SIZE>();
        writer.write_all(&data).await?;
        writer.flush().await?;
        writer.close().await?;

        pin_mut!(segments);

        // All segments but the last one are completely filled.
        for i in 0..(FRAME_SIZE / SMTU) {
            let seg = segments.next().await.ok_or_else(|| anyhow!("no more segments"))?;
            assert_eq!(1, seg.frame_id);
            assert_eq!(i as SeqNum, seg.seq_idx);
            assert_eq!(SEGMENTS_PER_FRAME as SeqNum, seg.seq_flags.seq_len());
            assert_eq!(SMTU, seg.data.len());
            assert_eq!(&data[i * SMTU..i * SMTU + SMTU], seg.data.as_ref());
        }

        // The last segment carries the remainder of the frame.
        let seg = segments.next().await.ok_or_else(|| anyhow!("no more segments"))?;
        assert_eq!(1, seg.frame_id);
        assert_eq!((FRAME_SIZE / SMTU) as SeqNum, seg.seq_idx);
        assert_eq!(SEGMENTS_PER_FRAME as SeqNum, seg.seq_flags.seq_len());
        assert_eq!(FRAME_SIZE % SMTU, seg.data.len());
        assert_eq!(&data[FRAME_SIZE - FRAME_SIZE % SMTU..], seg.data.as_ref());

        assert_eq!(None, segments.next().await);
        Ok(())
    }

    #[test_log::test(tokio::test)]
    async fn segmenter_should_segment_multiple_complete_frames_and_incomplete_frame_on_flush() -> anyhow::Result<()> {
        let (segments_tx, segments) = futures::channel::mpsc::unbounded();
        let mut writer = segments_tx.segmenter::<MTU>(FRAME_SIZE);

        let data = hopr_types::crypto_random::random_bytes::<{ FRAME_SIZE + 4 }>();
        writer.write_all(&data).await?;

        pin_mut!(segments);

        // The complete frame comes out without flushing.
        assert_frame_segments(1, 1, &mut segments, &data).await?;

        // The 4 trailing bytes stay buffered until the flush.
        segments
            .next()
            .timeout(futures_time::time::Duration::from_millis(10))
            .await
            .expect_err("should time out");

        writer.flush().await?;

        let seg = segments
            .next()
            .timeout(futures_time::time::Duration::from_millis(500))
            .await?
            .ok_or_else(|| anyhow!("no more segments"))?;
        assert_eq!(2, seg.frame_id);
        assert_eq!(0, seg.seq_idx);
        assert_eq!(1, seg.seq_flags.seq_len());
        assert_eq!(4, seg.data.len());
        assert_eq!(&data[FRAME_SIZE..], seg.data.as_ref());

        // The segmenter keeps working after a flush of an incomplete frame.
        let data = hopr_types::crypto_random::random_bytes::<FRAME_SIZE>();
        writer.write_all(&data).await?;
        writer.flush().await?;

        assert_frame_segments(3, 1, &mut segments, &data).await?;

        Ok(())
    }

    #[test_log::test(tokio::test)]
    async fn segmenter_should_work_with_buffering_backend() -> anyhow::Result<()> {
        // Bounded channel with a delayed receiver to exercise the back-pressure path.
        let (tx, rx) = futures::channel::mpsc::channel(5);
        let mut writer = tx.segmenter::<MTU>(FRAME_SIZE);

        let data = hopr_types::crypto_random::random_bytes::<{ 10 * FRAME_SIZE }>();

        let jh_recv = tokio::task::spawn(
            rx.collect::<Vec<_>>()
                .delay(futures_time::time::Duration::from_millis(200)),
        );
        let jh_send = tokio::task::spawn(async move {
            writer.write_all(&data).await?;
            writer.flush().await?;
            writer.close().await?;
            Ok::<_, std::io::Error>(())
        });

        let (segments, send_res) = futures::future::try_join(jh_recv, jh_send).await?;
        send_res?;

        assert_frame_segments(1, 10, &mut futures::stream::iter(segments), &data).await
    }
}