1use std::{
2 future::Future,
3 pin::Pin,
4 task::{Context, Poll},
5};
6
7use futures::{Stream, StreamExt};
8
9#[pin_project::pin_project]
12pub struct TimeoutSink<S> {
13 #[pin]
14 inner: S,
15 #[pin]
16 timer: Option<futures_time::task::Sleep>,
17 timeout: std::time::Duration,
18}
19
20#[derive(Debug, thiserror::Error, strum::EnumTryAs)]
22pub enum SinkTimeoutError<E> {
23 #[error("sink timed out")]
25 Timeout,
26 #[error("inner sink error: {0}")]
28 Inner(E),
29}
30
31impl<I, S: futures::Sink<I>> futures::Sink<I> for TimeoutSink<S> {
32 type Error = SinkTimeoutError<S::Error>;
33
34 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
35 let mut this = self.project();
36
37 match this.inner.poll_ready(cx) {
39 Poll::Ready(res) => {
40 this.timer.set(None);
42 Poll::Ready(res.map_err(SinkTimeoutError::Inner))
43 }
44 Poll::Pending => {
45 if this.timer.is_none() {
46 this.timer
48 .set(Some(futures_time::task::sleep(futures_time::time::Duration::from(
49 *this.timeout,
50 ))));
51 }
52
53 if let Some(timer) = this.timer.as_mut().as_pin_mut() {
55 futures::ready!(timer.poll(cx));
56 this.timer.set(None);
57 Poll::Ready(Err(SinkTimeoutError::Timeout))
60 } else {
61 unreachable!();
63 }
64 }
65 }
66 }
67
68 fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
69 self.project().inner.start_send(item).map_err(SinkTimeoutError::Inner)
70 }
71
72 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
73 self.project().inner.poll_flush(cx).map_err(SinkTimeoutError::Inner)
74 }
75
76 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
77 self.project().inner.poll_close(cx).map_err(SinkTimeoutError::Inner)
78 }
79}
80
81impl<S: Clone> Clone for TimeoutSink<S> {
82 fn clone(&self) -> Self {
83 Self {
84 inner: self.inner.clone(),
85 timer: None,
86 timeout: self.timeout,
87 }
88 }
89}
90
91pub trait TimeoutSinkExt<I>: futures::Sink<I> {
93 fn with_timeout(self, timeout: std::time::Duration) -> TimeoutSink<Self>
98 where
99 Self: Sized,
100 {
101 TimeoutSink {
102 inner: self,
103 timer: None,
104 timeout,
105 }
106 }
107}
108
109impl<T: ?Sized, I> TimeoutSinkExt<I> for T where T: futures::Sink<I> {}
110
111#[pin_project::pin_project]
112pub struct ForwardWithTimeout<St, Si, Item> {
113 #[pin]
114 sink: Option<Si>,
115 #[pin]
116 stream: futures::stream::Fuse<St>,
117 buffered_item: Option<Item>,
118}
119
120impl<St: futures::Stream, Si, Item> ForwardWithTimeout<St, Si, Item> {
121 pub(crate) fn new(stream: St, sink: Si) -> Self {
122 Self {
123 sink: Some(sink),
124 stream: stream.fuse(),
125 buffered_item: None,
126 }
127 }
128}
129
130impl<St, Si, Item, E> futures::future::FusedFuture for ForwardWithTimeout<St, Si, Item>
131where
132 Si: futures::Sink<Item, Error = SinkTimeoutError<E>>,
133 St: Stream<Item = Result<Item, E>>,
134{
135 fn is_terminated(&self) -> bool {
136 self.sink.is_none()
137 }
138}
139
140impl<St, Si, Item, E> Future for ForwardWithTimeout<St, Si, Item>
141where
142 Si: futures::Sink<Item, Error = SinkTimeoutError<E>>,
143 St: Stream<Item = Result<Item, E>>,
144{
145 type Output = Result<(), E>;
146
147 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
148 let mut this = self.project();
149 let mut si = this
150 .sink
151 .as_mut()
152 .as_pin_mut()
153 .expect("polled `Forward` after completion");
154
155 loop {
156 if this.buffered_item.is_some() {
159 match futures::ready!(si.as_mut().poll_ready(cx)) {
160 Ok(_) => {
161 si.as_mut()
162 .start_send(this.buffered_item.take().unwrap())
163 .map_err(|e| e.try_as_inner().unwrap())?;
164 }
165 Err(SinkTimeoutError::Timeout) => {
166 *this.buffered_item = None;
169 continue;
170 }
171 Err(SinkTimeoutError::Inner(e)) => return Poll::Ready(Err(e)),
172 }
173 }
174
175 match this.stream.as_mut().poll_next(cx)? {
176 Poll::Ready(Some(item)) => {
177 *this.buffered_item = Some(item);
178 }
179 Poll::Ready(None) => {
180 futures::ready!(si.poll_close(cx)).map_err(|e| e.try_as_inner().unwrap())?;
181 this.sink.set(None);
182 return Poll::Ready(Ok(()));
183 }
184 Poll::Pending => {
185 futures::ready!(si.poll_flush(cx)).map_err(|e| e.try_as_inner().unwrap())?;
186 return Poll::Pending;
187 }
188 }
189 }
190 }
191}
192
193pub trait TimeoutStreamExt: futures::TryStream {
196 fn forward_to_timeout<S>(self, sink: S) -> ForwardWithTimeout<Self, S, Self::Ok>
207 where
208 S: futures::Sink<Self::Ok, Error = SinkTimeoutError<Self::Error>>,
209 Self: Sized,
210 {
211 ForwardWithTimeout::new(self, sink)
212 }
213}
214
215impl<T: ?Sized> TimeoutStreamExt for T where T: futures::TryStream {}
216
#[cfg(test)]
mod tests {
    use futures::SinkExt;

    use super::*;

    /// Test sink that accepts up to `N` items and then stays pending forever, recording
    /// every item it is sent.
    ///
    /// Its `poll_ready` never registers a waker. A pending readiness wait therefore only
    /// makes progress if something else, such as the `TimeoutSink` timer, wakes the task.
    #[derive(Default)]
    struct FixedSink<const N: usize, I>(Vec<I>);

    // `FixedSink` has no structurally pinned fields, so it is sound for it to be `Unpin` for
    // every `I`. This lets `start_send` use the safe `Pin::get_mut` instead of
    // `unsafe { get_unchecked_mut() }`.
    impl<const N: usize, I> Unpin for FixedSink<N, I> {}

    impl<const N: usize, I> futures::Sink<I> for FixedSink<N, I> {
        type Error = std::convert::Infallible;

        fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            if self.0.len() < N {
                Poll::Ready(Ok(()))
            } else {
                Poll::Pending
            }
        }

        fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
            self.get_mut().0.push(item);
            Ok(())
        }

        fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }
    }

    #[test_log::test(tokio::test)]
    async fn test_timeout_sink() -> anyhow::Result<()> {
        let mut sink = FixedSink::<1, i32>::default();

        // One item fits; the second send waits for readiness and times out.
        {
            let mut timed_sink = (&mut sink).with_timeout(std::time::Duration::from_millis(10));

            timed_sink.send(10).await?;
            assert!(matches!(timed_sink.send(20).await, Err(SinkTimeoutError::Timeout)));
        }

        assert_eq!(1, sink.0.len());
        sink.0.remove(0);

        // A fresh wrapper over the drained sink behaves the same way.
        {
            let mut timed_sink = (&mut sink).with_timeout(std::time::Duration::from_millis(10));

            timed_sink.send(10).await?;
            assert!(matches!(timed_sink.send(20).await, Err(SinkTimeoutError::Timeout)));
        }

        Ok(())
    }

    #[test_log::test(tokio::test)]
    async fn test_forward_with_timeout() -> anyhow::Result<()> {
        let stream = futures::stream::iter([1, 2, 3, 4, 5]).map(Ok);

        let mut sink = FixedSink::<2, i32>::default();

        // Items 1 and 2 fit. Items 3, 4 and 5 each wait out the 10ms timeout and are dropped.
        let start = std::time::Instant::now();
        stream
            .forward_to_timeout((&mut sink).with_timeout(std::time::Duration::from_millis(10)))
            .await?;
        assert!(
            start.elapsed() > std::time::Duration::from_millis(29),
            "should've taken at least 30ms"
        );

        assert_eq!(2, sink.0.len());
        assert_eq!(1, sink.0[0]);
        assert_eq!(2, sink.0[1]);

        Ok(())
    }
}