hopr_network_types/
timeout.rs

1use std::{
2    future::Future,
3    pin::Pin,
4    task::{Context, Poll},
5};
6
7use futures::{Stream, StreamExt};
8
9/// Represents a sink that will time out after a given duration if an item
10/// cannot be sent.
11#[pin_project::pin_project]
12pub struct TimeoutSink<S> {
13    #[pin]
14    inner: S,
15    #[pin]
16    timer: Option<futures_time::task::Sleep>,
17    timeout: std::time::Duration,
18}
19
20/// Error type for [`TimeoutSink`].
21#[derive(Debug, thiserror::Error, strum::EnumTryAs)]
22pub enum SinkTimeoutError<E> {
23    /// Inner sink could not make progress within the timeout.
24    #[error("sink timed out")]
25    Timeout,
26    /// Inner sink returned an error.
27    #[error("inner sink error: {0}")]
28    Inner(E),
29}
30
31impl<I, S: futures::Sink<I>> futures::Sink<I> for TimeoutSink<S> {
32    type Error = SinkTimeoutError<S::Error>;
33
34    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
35        let mut this = self.project();
36
37        // First, see if we can make progress on the inner sink.
38        match this.inner.poll_ready(cx) {
39            Poll::Ready(res) => {
40                // The inner sink is ready, so we can clear the timer.
41                this.timer.set(None);
42                Poll::Ready(res.map_err(SinkTimeoutError::Inner))
43            }
44            Poll::Pending => {
45                if this.timer.is_none() {
46                    // If no timer is present, create one with the given timeout.
47                    this.timer
48                        .set(Some(futures_time::task::sleep(futures_time::time::Duration::from(
49                            *this.timeout,
50                        ))));
51                }
52
53                // If a timer is present, poll it as well
54                if let Some(timer) = this.timer.as_mut().as_pin_mut() {
55                    futures::ready!(timer.poll(cx));
56                    this.timer.set(None);
57                    // The timer has expired, so we won't poll the inner sink again
58                    // and return an error.
59                    Poll::Ready(Err(SinkTimeoutError::Timeout))
60                } else {
61                    // Cannot happen as the timer is always set at this point
62                    unreachable!();
63                }
64            }
65        }
66    }
67
68    fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
69        self.project().inner.start_send(item).map_err(SinkTimeoutError::Inner)
70    }
71
72    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
73        self.project().inner.poll_flush(cx).map_err(SinkTimeoutError::Inner)
74    }
75
76    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
77        self.project().inner.poll_close(cx).map_err(SinkTimeoutError::Inner)
78    }
79}
80
81impl<S: Clone> Clone for TimeoutSink<S> {
82    fn clone(&self) -> Self {
83        Self {
84            inner: self.inner.clone(),
85            timer: None,
86            timeout: self.timeout,
87        }
88    }
89}
90
91/// [`futures::Sink`] adaptor that adds timeout.
92pub trait TimeoutSinkExt<I>: futures::Sink<I> {
93    /// Attaches a timeout onto this [`futures::Sink`]'s `poll_ready` function.
94    ///
95    /// The returned `Sink` will return an error if `poll_ready` does not
96    /// return within the given `timeout`.
97    fn with_timeout(self, timeout: std::time::Duration) -> TimeoutSink<Self>
98    where
99        Self: Sized,
100    {
101        TimeoutSink {
102            inner: self,
103            timer: None,
104            timeout,
105        }
106    }
107}
108
109impl<T: ?Sized, I> TimeoutSinkExt<I> for T where T: futures::Sink<I> {}
110
111#[pin_project::pin_project]
112pub struct ForwardWithTimeout<St, Si, Item> {
113    #[pin]
114    sink: Option<Si>,
115    #[pin]
116    stream: futures::stream::Fuse<St>,
117    buffered_item: Option<Item>,
118}
119
120impl<St: futures::Stream, Si, Item> ForwardWithTimeout<St, Si, Item> {
121    pub(crate) fn new(stream: St, sink: Si) -> Self {
122        Self {
123            sink: Some(sink),
124            stream: stream.fuse(),
125            buffered_item: None,
126        }
127    }
128}
129
130impl<St, Si, Item, E> futures::future::FusedFuture for ForwardWithTimeout<St, Si, Item>
131where
132    Si: futures::Sink<Item, Error = SinkTimeoutError<E>>,
133    St: Stream<Item = Result<Item, E>>,
134{
135    fn is_terminated(&self) -> bool {
136        self.sink.is_none()
137    }
138}
139
140impl<St, Si, Item, E> Future for ForwardWithTimeout<St, Si, Item>
141where
142    Si: futures::Sink<Item, Error = SinkTimeoutError<E>>,
143    St: Stream<Item = Result<Item, E>>,
144{
145    type Output = Result<(), E>;
146
147    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
148        let mut this = self.project();
149        let mut si = this
150            .sink
151            .as_mut()
152            .as_pin_mut()
153            .expect("polled `Forward` after completion");
154
155        loop {
156            // If we've got an item buffered already, we need to try to write it to the
157            // sink before we can do anything else
158            if this.buffered_item.is_some() {
159                match futures::ready!(si.as_mut().poll_ready(cx)) {
160                    Ok(_) => {
161                        si.as_mut()
162                            .start_send(this.buffered_item.take().unwrap())
163                            .map_err(|e| e.try_as_inner().unwrap())?;
164                    }
165                    Err(SinkTimeoutError::Timeout) => {
166                        // If there was a timeout, drop the buffered item and continue
167                        // polling the stream for the next one.
168                        *this.buffered_item = None;
169                        continue;
170                    }
171                    Err(SinkTimeoutError::Inner(e)) => return Poll::Ready(Err(e)),
172                }
173            }
174
175            match this.stream.as_mut().poll_next(cx)? {
176                Poll::Ready(Some(item)) => {
177                    *this.buffered_item = Some(item);
178                }
179                Poll::Ready(None) => {
180                    futures::ready!(si.poll_close(cx)).map_err(|e| e.try_as_inner().unwrap())?;
181                    this.sink.set(None);
182                    return Poll::Ready(Ok(()));
183                }
184                Poll::Pending => {
185                    futures::ready!(si.poll_flush(cx)).map_err(|e| e.try_as_inner().unwrap())?;
186                    return Poll::Pending;
187                }
188            }
189        }
190    }
191}
192
193/// [`futures::TryStream`] extension that allows forwarding items to a sink with a timeout while
194/// discarding timed out items.
195pub trait TimeoutStreamExt: futures::TryStream {
196    /// Specialization of [`StreamExt::forward`] for Sinks using the [`SinkTimeoutError`].
197    ///
198    /// If the `sink` returns [`SinkTimeoutError::Timeout`], the current item from this
199    /// stream is discarded and the forwarding process continues with the next item
200    /// until the stream is depleted.
201    ///
202    /// This is in contrast to [`StreamExt::forward`] which would terminate with [`SinkTimeoutError::Timeout`].
203    ///
204    /// Errors other than [`SinkTimeoutError::Timeout`] cause the forwarding to terminate
205    /// with that error (as in the original behavior of [`StreamExt::forward`]).
206    fn forward_to_timeout<S>(self, sink: S) -> ForwardWithTimeout<Self, S, Self::Ok>
207    where
208        S: futures::Sink<Self::Ok, Error = SinkTimeoutError<Self::Error>>,
209        Self: Sized,
210    {
211        ForwardWithTimeout::new(self, sink)
212    }
213}
214
215impl<T: ?Sized> TimeoutStreamExt for T where T: futures::TryStream {}
216
#[cfg(test)]
mod tests {
    use futures::SinkExt;

    use super::*;

    /// Test sink that accepts at most `N` items and then stays pending forever.
    #[derive(Default)]
    struct FixedSink<const N: usize, I>(Vec<I>);

    // `FixedSink` never pin-projects into its buffer, so it is sound for it to be `Unpin`
    // for any `I`. That lets `start_send` use the safe `Pin::get_mut` instead of `unsafe`.
    impl<const N: usize, I> Unpin for FixedSink<N, I> {}

    impl<const N: usize, I> futures::Sink<I> for FixedSink<N, I> {
        type Error = std::convert::Infallible;

        fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            if self.0.len() < N {
                Poll::Ready(Ok(()))
            } else {
                // Never wakes the task: once full, the sink stays pending for good.
                Poll::Pending
            }
        }

        fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
            self.get_mut().0.push(item);
            Ok(())
        }

        fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }
    }

    #[test_log::test(tokio::test)]
    async fn test_timeout_sink() -> anyhow::Result<()> {
        let mut sink = FixedSink::<1, i32>::default();

        {
            let mut timed_sink = (&mut sink).with_timeout(std::time::Duration::from_millis(10));

            timed_sink.send(10).await?;
            assert!(matches!(timed_sink.send(20).await, Err(SinkTimeoutError::Timeout)));
        }

        assert_eq!(1, sink.0.len());
        sink.0.clear();

        // After freeing capacity, a fresh wrapper must again accept one item and then time out.
        {
            let mut timed_sink = (&mut sink).with_timeout(std::time::Duration::from_millis(10));

            timed_sink.send(10).await?;
            assert!(matches!(timed_sink.send(20).await, Err(SinkTimeoutError::Timeout)));
        }

        Ok(())
    }

    #[test_log::test(tokio::test)]
    async fn test_forward_with_timeout() -> anyhow::Result<()> {
        let stream = futures::stream::iter([1, 2, 3, 4, 5]).map(Ok);

        let mut sink = FixedSink::<2, i32>::default();

        let start = std::time::Instant::now();
        stream
            .forward_to_timeout((&mut sink).with_timeout(std::time::Duration::from_millis(10)))
            .await?;
        // Items 3, 4 and 5 each wait out one 10 ms timeout before being dropped.
        assert!(
            start.elapsed() > std::time::Duration::from_millis(29),
            "should've taken at least 30ms"
        );

        assert_eq!(2, sink.0.len());
        assert_eq!(1, sink.0[0]);
        assert_eq!(2, sink.0[1]);

        Ok(())
    }

    #[test_log::test(tokio::test)]
    async fn test_forward_with_timeout_delivers_everything_when_sink_has_capacity() -> anyhow::Result<()> {
        let stream = futures::stream::iter([1, 2, 3]).map(Ok);

        let mut sink = FixedSink::<3, i32>::default();

        stream
            .forward_to_timeout((&mut sink).with_timeout(std::time::Duration::from_millis(10)))
            .await?;

        assert_eq!(vec![1, 2, 3], sink.0);

        Ok(())
    }
}