Skip to main content

hopr_protocol_session/processing/
segmenter.rs

1//! This module defines a [`Segmenter`] adaptor for [`futures::Sink`].
2use std::{
3    collections::VecDeque,
4    pin::Pin,
5    task::{Context, Poll},
6};
7
8use tracing::instrument;
9
10use crate::{
11    protocol::{FrameId, Segment, SeqIndicator, SessionMessage},
12    utils::segment_into,
13};
14
15/// Segmenter is an adaptor to [`futures::Sink`] of [`Segment`] items
16/// that turns it into [`futures::io::AsyncWrite`].
17///
18/// Bytes written to the Segmenter are buffered up and/or chopped into [`Segment`]
19/// of at most `C` in payload size (the `data` member).
20///
21/// The bytes are written to the
22/// underlying Sink once more than `frame_size` is written (or unless flushed),
23/// so the Segmenter naturally acts as a buffered writer.
24/// Any unflushed bytes written to the Segmenter will be lost when it is closed.
25///
26/// The data are grouped into [`Frames`](crate::protocol::Frame) of the size given by the `frame_size`
27/// parameter, segments in each such group share the same [`FrameId`].
28/// This acts as a natural buffering feature of a Segmenter.
29///
30/// Segmenter can optionally send a [terminating](Segment::terminating) when `poll_close`
31/// is called.
32///
33/// Segmenter is essentially inverse of [`Reassembler`](super::reassembly::Reassembler).
34///
35/// Use [`SegmenterExt`] to turn a `Segment` sink into an `AsyncWrite` object using the `Segmenter`.
36#[must_use = "sinks do nothing unless polled"]
37#[pin_project::pin_project]
38pub struct Segmenter<const C: usize, S> {
39    #[pin]
40    inner: S,
41    state: State,
42    frame: Vec<u8>,
43    ready_segments: VecDeque<Segment>,
44    frame_size: usize,
45    frame_id: FrameId,
46    is_closed: bool,
47    send_terminating_segment: bool,
48}
49
/// Phase of the write path of a [`Segmenter`].
enum State {
    /// Incoming bytes are appended to the frame buffer.
    BufferingFrame,
    /// The segments of a completed frame are being pushed to the underlying sink.
    WritingFrame,
}
54
55impl<const C: usize, S> Segmenter<C, S>
56where
57    S: futures::Sink<Segment>,
58    S::Error: std::error::Error + Send + Sync + 'static,
59{
60    fn new(inner: S, frame_size: usize, send_terminating_segment: bool) -> Self {
61        // We must clamp to at most SeqIndicator::MAX + 1 segments per frame.
62        // This is true for Session protocol without partial acknowledgements.
63        // When partial acknowledgements are enabled, the maximum frame size is less,
64        // and the caller must take care of it.
65        let frame_size = frame_size.clamp(
66            C,
67            (C - SessionMessage::<C>::SEGMENT_OVERHEAD) * (SeqIndicator::MAX + 1) as usize,
68        );
69
70        Self {
71            inner,
72            state: State::BufferingFrame,
73            frame: Vec::with_capacity(frame_size),
74            ready_segments: VecDeque::with_capacity(frame_size.div_ceil(C - SessionMessage::<C>::SEGMENT_OVERHEAD)),
75            frame_size,
76            frame_id: 1,
77            is_closed: false,
78            send_terminating_segment,
79        }
80    }
81}
82
83impl<const C: usize, S> futures::io::AsyncWrite for Segmenter<C, S>
84where
85    S: futures::Sink<Segment>,
86    S::Error: std::error::Error + Send + Sync + 'static,
87{
88    #[instrument(name = "Segmenter::poll_write", level = "trace", skip(self, cx, buf), fields(frame_id = self.frame_id, buf_len = buf.len(), frame_size = self.frame.len(), ready_segments = self.ready_segments.len()), ret)]
89    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
90        if self.is_closed {
91            return Poll::Ready(Err(std::io::Error::new(
92                std::io::ErrorKind::BrokenPipe,
93                "segmenter closed",
94            )));
95        }
96
97        let mut this = self.project();
98        loop {
99            match this.state {
100                State::BufferingFrame => {
101                    // If there's space in the frame, keep writing to it
102                    if *this.frame_size > this.frame.len() {
103                        let to_write = buf.len().min(*this.frame_size - this.frame.len());
104                        this.frame.extend_from_slice(&buf[..to_write]);
105
106                        return Poll::Ready(Ok(to_write));
107                    } else {
108                        // No more space in the frame buffer, we need to segment it
109                        // and write segments to the downstream
110                        segment_into(
111                            this.frame.as_slice(),
112                            C - SessionMessage::<C>::SEGMENT_OVERHEAD,
113                            *this.frame_id,
114                            this.ready_segments,
115                        )
116                        .map_err(std::io::Error::other)?;
117
118                        tracing::trace!(num_segments = this.ready_segments.len(), "frame ready");
119
120                        this.frame.clear();
121                        *this.frame_id += 1;
122                        *this.state = State::WritingFrame;
123                    }
124                }
125                State::WritingFrame => {
126                    if !this.ready_segments.is_empty() {
127                        // Keep writing segments to downstream
128                        futures::ready!(this.inner.as_mut().poll_ready(cx).map_err(std::io::Error::other))?;
129
130                        let segment = this.ready_segments.pop_front().unwrap();
131                        tracing::trace!(seg_id = %segment.id(), "segment goes out");
132                        this.inner.as_mut().start_send(segment).map_err(std::io::Error::other)?;
133                    } else {
134                        // Once we're done, we can buffer another frame
135                        *this.state = State::BufferingFrame;
136                        tracing::trace!("all segments out");
137                    }
138                }
139            }
140        }
141    }
142
143    #[instrument(name = "Segmenter::poll_flush", level = "trace", skip(self, cx), fields(frame_id = self.frame_id, frame_size = self.frame.len(), ready_segments = self.ready_segments.len()), ret)]
144    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
145        if self.is_closed {
146            return Poll::Ready(Err(std::io::Error::new(
147                std::io::ErrorKind::BrokenPipe,
148                "segmenter closed",
149            )));
150        }
151
152        let mut this = self.project();
153        loop {
154            // If there's any data in the unfinished frame, segment it
155            if !this.frame.is_empty() {
156                // Flush the downstream sink first
157                futures::ready!(this.inner.as_mut().poll_flush(cx).map_err(std::io::Error::other))?;
158
159                // Segment whatever data is in the frame
160                // At this point ready_segments must be empty,
161                // because poll_write always makes sure it is before returning Ready
162                segment_into(
163                    this.frame.as_slice(),
164                    C - SessionMessage::<C>::SEGMENT_OVERHEAD,
165                    *this.frame_id,
166                    this.ready_segments,
167                )
168                .map_err(std::io::Error::other)?;
169
170                tracing::trace!(num_segments = this.ready_segments.len(), "flushed frame ready");
171
172                this.frame.clear();
173                *this.frame_id += 1;
174            } else if !this.ready_segments.is_empty() {
175                futures::ready!(this.inner.as_mut().poll_ready(cx).map_err(std::io::Error::other))?;
176
177                let segment = this.ready_segments.pop_front().unwrap();
178                tracing::trace!(seg_id = %segment.id(), "segment flushing out");
179
180                this.inner.as_mut().start_send(segment).map_err(std::io::Error::other)?;
181            } else {
182                // Both buffers are empty, so only flush the downstream
183                futures::ready!(this.inner.as_mut().poll_flush(cx).map_err(std::io::Error::other))?;
184
185                tracing::trace!("all segments flushed out");
186                return Poll::Ready(Ok(()));
187            }
188        }
189    }
190
191    #[instrument(name = "Segmenter::poll_close", level = "trace", skip(self, cx), fields(frame_id = self.frame_id) , ret)]
192    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
193        let mut this = self.project();
194
195        if *this.send_terminating_segment && !*this.is_closed {
196            futures::ready!(this.inner.as_mut().poll_ready(cx).map_err(std::io::Error::other))?;
197            let dummy = Segment::terminating(*this.frame_id);
198            this.inner.as_mut().start_send(dummy).map_err(std::io::Error::other)?;
199            tracing::trace!("sent terminating segment");
200        }
201
202        *this.is_closed = true;
203        this.inner.as_mut().poll_close(cx).map_err(std::io::Error::other)
204    }
205}
206
207/// Sink extension methods for segmenting binary data into a sink.
208pub trait SegmenterExt: futures::Sink<Segment> {
209    /// Attaches a [`Segmenter`] to the underlying sink.
210    fn segmenter<const C: usize>(self, frame_size: usize) -> Segmenter<C, Self>
211    where
212        Self: Sized,
213        Self::Error: std::error::Error + Send + Sync + 'static,
214    {
215        Segmenter::new(self, frame_size, false)
216    }
217
218    /// Attaches a [`Segmenter`] to the underlying sink.
219    /// The `Segmenter` also sends a [terminating](Segment::terminating) when closed.
220    fn segmenter_with_terminating_segment<const C: usize>(self, frame_size: usize) -> Segmenter<C, Self>
221    where
222        Self: Sized,
223        Self::Error: std::error::Error + Send + Sync + 'static,
224    {
225        Segmenter::new(self, frame_size, true)
226    }
227}
228
229impl<T: ?Sized> SegmenterExt for T where T: futures::Sink<Segment> {}
230
#[cfg(test)]
mod tests {
    use anyhow::{Context, anyhow};
    use futures::{AsyncWriteExt, Stream, StreamExt, pin_mut};
    use futures_time::future::FutureExt;

    use super::*;
    use crate::{protocol::SeqNum, utils::segment};

    /// Value of the `C` parameter used by the tested segmenters.
    const MTU: usize = 1000;
    /// Payload bytes carried by a single full segment.
    const SMTU: usize = MTU - SessionMessage::<MTU>::SEGMENT_OVERHEAD;
    const FRAME_SIZE: usize = 1500;

    /// Number of segments a complete frame is split into.
    ///
    /// This is derived from the segment payload size (`SMTU`), not from `MTU`,
    /// because the segmenter splits frames by payload size.
    const SEGMENTS_PER_FRAME: usize = FRAME_SIZE.div_ceil(SMTU);

    /// Takes `num_frames` complete frames worth of segments from `segments` and checks
    /// their frame ids (starting at `start_frame_id`), sequence numbers and payloads
    /// against `data`, which must hold the bytes of those frames back to back.
    async fn assert_frame_segments(
        start_frame_id: FrameId,
        num_frames: usize,
        segments: &mut (impl Stream<Item = Segment> + Unpin),
        data: &[u8],
    ) -> anyhow::Result<()> {
        for i in 0..num_frames * SEGMENTS_PER_FRAME {
            let frame_offset = i / SEGMENTS_PER_FRAME;
            let seq_idx = i % SEGMENTS_PER_FRAME;
            let frame_id = start_frame_id as usize + frame_offset;
            tracing::debug!("testing frame id {frame_id} {}", seq_idx as SeqNum);

            let seg = segments
                .next()
                .timeout(futures_time::time::Duration::from_millis(500))
                .await
                .with_context(|| format!("assert_frame_segments {i}"))?
                .ok_or_else(|| anyhow!("no more segments"))?;

            // All segments carry a full payload, except for the last one in the frame,
            // which carries whatever remains of the frame.
            let data_start = frame_offset * FRAME_SIZE + seq_idx * SMTU;
            let expected_len = SMTU.min(FRAME_SIZE - seq_idx * SMTU);

            assert_eq!(frame_id as FrameId, seg.frame_id);
            assert_eq!(seq_idx as SeqNum, seg.seq_idx);
            assert_eq!(SEGMENTS_PER_FRAME as SeqNum, seg.seq_flags.seq_len());
            assert_eq!(expected_len, seg.data.len());
            assert_eq!(&data[data_start..data_start + expected_len], seg.data.as_ref());
        }

        Ok(())
    }

    #[tokio::test]
    async fn segmenter_should_not_segment_small_data_unless_flushed() -> anyhow::Result<()> {
        let (segments_tx, segments) = futures::channel::mpsc::unbounded();
        let mut writer = segments_tx.segmenter::<MTU>(FRAME_SIZE);

        writer.write_all(b"test").await?;

        pin_mut!(segments);
        segments
            .next()
            .timeout(futures_time::time::Duration::from_millis(10))
            .await
            .expect_err("should time out");

        writer.flush().await?;

        let seg = segments.next().await.ok_or_else(|| anyhow!("no more segments"))?;
        assert_eq!(1, seg.frame_id);
        assert_eq!(1, seg.seq_flags.seq_len());
        assert_eq!(0, seg.seq_idx);
        assert_eq!(b"test", seg.data.as_ref());

        Ok(())
    }

    #[parameterized::parameterized(num_frames = { 1, 3, 5, 11 })]
    #[parameterized_macro(tokio::test)]
    async fn segmenter_should_segment_complete_frames(num_frames: usize) -> anyhow::Result<()> {
        let (segments_tx, segments) = futures::channel::mpsc::unbounded();
        let mut writer = segments_tx.segmenter::<MTU>(FRAME_SIZE);

        let mut all_data = Vec::new();
        for _ in 0..num_frames {
            let data = hopr_types::crypto_random::random_bytes::<FRAME_SIZE>();
            writer.write_all(&data).await?;
            all_data.extend(data);
        }

        writer.flush().await?;

        pin_mut!(segments);
        assert_frame_segments(1, num_frames, &mut segments, &all_data).await?;

        writer.close().await?;

        assert_eq!(None, segments.next().await);
        Ok(())
    }

    #[tokio::test]
    async fn segmenter_full_frame_segmentation_must_be_consistent_with_segment_function() -> anyhow::Result<()> {
        let (segments_tx, segments) = futures::channel::mpsc::unbounded();
        let mut writer = segments_tx.segmenter::<MTU>(FRAME_SIZE);

        let data = hopr_types::crypto_random::random_bytes::<FRAME_SIZE>();

        writer.write_all(&data).await?;
        writer.flush().await?;
        writer.close().await?;

        // Segmenter already takes into account the SessionMessage overhead
        let expected = segment(data, SMTU, 1)?;
        let actual = segments.collect::<Vec<_>>().await;

        assert_eq!(expected, actual);

        Ok(())
    }

    #[test_log::test(tokio::test)]
    async fn segmenter_full_frame_segmentation_must_also_include_terminating_segment() -> anyhow::Result<()> {
        let (segments_tx, segments) = futures::channel::mpsc::unbounded();
        let mut writer = segments_tx.segmenter_with_terminating_segment::<MTU>(FRAME_SIZE);

        let data = hopr_types::crypto_random::random_bytes::<FRAME_SIZE>();

        writer.write_all(&data).await?;
        writer.flush().await?;
        writer.close().await?;

        // Segmenter already takes into account the SessionMessage overhead
        let mut expected = segment(data, SMTU, 1)?;
        expected.push(Segment::terminating(2));
        let actual = segments.collect::<Vec<_>>().await;

        assert_eq!(expected, actual);

        Ok(())
    }

    #[test_log::test(tokio::test)]
    async fn segmenter_should_segment_complete_frame_with_misaligned_mtu() -> anyhow::Result<()> {
        let (segments_tx, segments) = futures::channel::mpsc::unbounded();
        let mut writer = segments_tx.segmenter::<MTU>(FRAME_SIZE);

        // Make sure the FRAME_SIZE is not a multiple of the segment payload size,
        // so that the last segment of the frame is shorter than the others
        assert_ne!(0, FRAME_SIZE % SMTU);

        let data = hopr_types::crypto_random::random_bytes::<FRAME_SIZE>();
        writer.write_all(&data).await?;
        writer.flush().await?;
        writer.close().await?;

        pin_mut!(segments);

        // All the segments carrying a full payload come first
        for i in 0..(FRAME_SIZE / SMTU) {
            let seg = segments.next().await.ok_or_else(|| anyhow!("no more segments"))?;
            assert_eq!(1, seg.frame_id);
            assert_eq!(i as SeqNum, seg.seq_idx);
            assert_eq!(SEGMENTS_PER_FRAME as SeqNum, seg.seq_flags.seq_len());
            assert_eq!(SMTU, seg.data.len());
            assert_eq!(&data[i * SMTU..i * SMTU + SMTU], seg.data.as_ref());
        }

        // The last segment carries the remainder of the frame
        let seg = segments.next().await.ok_or_else(|| anyhow!("no more segments"))?;
        assert_eq!(1, seg.frame_id);
        assert_eq!((FRAME_SIZE / SMTU) as SeqNum, seg.seq_idx);
        assert_eq!(SEGMENTS_PER_FRAME as SeqNum, seg.seq_flags.seq_len());
        assert_eq!(FRAME_SIZE % SMTU, seg.data.len());
        assert_eq!(&data[FRAME_SIZE - FRAME_SIZE % SMTU..], seg.data.as_ref());

        assert_eq!(None, segments.next().await);
        Ok(())
    }

    #[test_log::test(tokio::test)]
    async fn segmenter_should_segment_multiple_complete_frames_and_incomplete_frame_on_flush() -> anyhow::Result<()> {
        let (segments_tx, segments) = futures::channel::mpsc::unbounded();
        let mut writer = segments_tx.segmenter::<MTU>(FRAME_SIZE);

        let data = hopr_types::crypto_random::random_bytes::<{ FRAME_SIZE + 4 }>();
        writer.write_all(&data).await?;

        pin_mut!(segments);

        // The first frame should come out even without a flush
        assert_frame_segments(1, 1, &mut segments, &data).await?;

        // And no more segment comes out for the remaining bytes
        segments
            .next()
            .timeout(futures_time::time::Duration::from_millis(10))
            .await
            .expect_err("should time out");

        // ... until it is flushed
        writer.flush().await?;

        let seg = segments
            .next()
            .timeout(futures_time::time::Duration::from_millis(500))
            .await?
            .ok_or_else(|| anyhow!("no more segments"))?;
        assert_eq!(2, seg.frame_id);
        assert_eq!(0, seg.seq_idx);
        assert_eq!(1, seg.seq_flags.seq_len());
        assert_eq!(4, seg.data.len());
        assert_eq!(&data[FRAME_SIZE..], seg.data.as_ref());

        // The next full frame should come out normally after a flush
        let data = hopr_types::crypto_random::random_bytes::<FRAME_SIZE>();
        writer.write_all(&data).await?;
        writer.flush().await?;

        assert_frame_segments(3, 1, &mut segments, &data).await?;

        Ok(())
    }

    #[test_log::test(tokio::test)]
    async fn segmenter_should_work_with_buffering_backend() -> anyhow::Result<()> {
        let (tx, rx) = futures::channel::mpsc::channel(5);
        let mut writer = tx.segmenter::<MTU>(FRAME_SIZE);

        let data = hopr_types::crypto_random::random_bytes::<{ 10 * FRAME_SIZE }>();

        // The receiver starts draining late, so the sender runs into the bounded channel's back-pressure
        let jh_recv = tokio::task::spawn(
            rx.collect::<Vec<_>>()
                .delay(futures_time::time::Duration::from_millis(200)),
        );
        let jh_send = tokio::task::spawn(async move {
            writer.write_all(&data).await?;
            writer.flush().await?;
            writer.close().await?;
            Ok::<_, std::io::Error>(())
        });

        let (segments, send_res) = futures::future::try_join(jh_recv, jh_send).await?;
        send_res?;

        assert_frame_segments(1, 10, &mut futures::stream::iter(segments), &data).await
    }
}