Skip to main content

hopr_transport_session/balancer/
rate_limiting.rs

1use std::{
2    future::Future,
3    pin::Pin,
4    sync::{
5        Arc,
6        atomic::{AtomicU64, Ordering},
7    },
8    task::{Context, Poll},
9    time::{Duration, Instant},
10};
11
12use futures_time::task::Sleep;
13use pin_project::pin_project;
14
/// Shared, dynamically adjustable rate limit for [`RateLimitedStream`] and [`RateLimitedSink`].
///
/// Internally stores the delay between two consecutive elements in microseconds,
/// where `0` means "zero rate" (nothing is let through until a non-zero rate is set).
/// Cloning the controller yields another handle to the same limit, so a rate change made
/// through any clone is observed by every stream or sink created from it.
#[derive(Clone, Debug)]
pub struct RateController(Arc<AtomicU64>);
18
19impl Default for RateController {
20    fn default() -> Self {
21        Self(Arc::new(AtomicU64::new(0)))
22    }
23}
24
/// Converts a per-element delay in microseconds into a rate in elements per second.
///
/// A delay of `0` is the "zero rate" sentinel and maps to `0.0`.
fn rate_from_delay(delay_micros: u64) -> f64 {
    const MICROS_PER_SEC: f64 = 1_000_000.0;

    if delay_micros == 0 {
        return 0.0;
    }
    // A single division is correctly rounded; going through seconds first
    // (and then inverting) would round twice.
    MICROS_PER_SEC / delay_micros as f64
}
32
33#[allow(unused)]
34impl RateController {
35    const MIN_DELAY: Duration = Duration::from_micros(100);
36
37    pub fn new(elements_per_unit: usize, unit: Duration) -> Self {
38        let rc = RateController::default();
39        rc.set_rate_per_unit(elements_per_unit, unit);
40        rc
41    }
42
43    /// Update the rate limit (elements per unit).
44    pub fn set_rate_per_unit(&self, elements_per_unit: usize, unit: Duration) {
45        assert!(unit > Duration::ZERO, "unit must be greater than zero");
46
47        if elements_per_unit > 0 {
48            // Calculate the next allowable time based on the rate
49            let rate_per_sec = (elements_per_unit as f64 / unit.as_secs_f64());
50
51            // Convert to duration (seconds per element)
52            let new_rate = Duration::from_secs_f64(1.0 / rate_per_sec)
53                .max(Self::MIN_DELAY)
54                .as_micros()
55                .min(u64::MAX as u128) as u64; // Clamp to u64 to avoid overflow
56
57            self.0.store(new_rate, Ordering::Relaxed);
58        } else {
59            self.0.store(0, Ordering::Relaxed);
60        }
61    }
62
63    /// Checks whether the rate is set to 0 at this controller.
64    pub fn has_zero_rate(&self) -> bool {
65        self.0.load(Ordering::Relaxed) == 0
66    }
67
68    /// Gets the delay per element (inverse rate).
69    pub fn get_delay_per_element(&self) -> Duration {
70        Duration::from_micros(self.0.load(Ordering::Relaxed))
71    }
72
73    /// Get the current rate limit per time unit.
74    pub fn get_rate_per_sec(&self) -> f64 {
75        rate_from_delay(self.0.load(Ordering::Relaxed))
76    }
77}
78
/// Internal state machine of [`RateLimitedStream`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum StreamState {
    /// Polling the inner stream for the next element.
    Read,
    /// The rate is zero: the controller is periodically re-checked while any
    /// already-fetched element is held back.
    NoRate,
    /// An element has been fetched and is held back until the delay timer fires.
    Wait,
}
84
85/// A stream adapter that yields elements at a controlled rate, with dynamic rate adjustment.
86///
87/// See [`RateLimitStreamExt::rate_limit_per_unit`].
88#[must_use = "streams do nothing unless polled"]
89#[pin_project]
90pub struct RateLimitedStream<S: futures::Stream> {
91    #[pin]
92    inner: S,
93    item: Option<S::Item>,
94    #[pin]
95    delay: Option<Sleep>,
96    state: StreamState,
97    delay_time: Arc<AtomicU64>,
98}
99
100impl<S: futures::Stream> RateLimitedStream<S> {
101    /// Creates a stream with rate limit controllable using the given controller.
102    pub fn new_with_controller(stream: S, controller: &RateController) -> Self {
103        Self {
104            inner: stream,
105            item: None,
106            delay: None,
107            state: if controller.0.load(Ordering::Relaxed) > 0 {
108                StreamState::Read
109            } else {
110                StreamState::NoRate
111            },
112            delay_time: controller.0.clone(),
113        }
114    }
115
116    /// Creates a stream with some initial rate limit of elements per a time unit.
117    pub fn new_with_rate_per_unit(stream: S, initial_rate_per_unit: usize, unit: Duration) -> (Self, RateController) {
118        let rc = RateController::new(initial_rate_per_unit, unit);
119        (Self::new_with_controller(stream, &rc), rc)
120    }
121}
122
123impl<S, T> futures::Stream for RateLimitedStream<S>
124where
125    S: futures::Stream<Item = T> + Unpin,
126{
127    type Item = T;
128
129    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
130        let mut this = self.project();
131
132        loop {
133            match this.state {
134                StreamState::Read => {
135                    let yield_start = Instant::now();
136                    if let Some(item) = futures::ready!(this.inner.as_mut().poll_next(cx)) {
137                        *this.item = Some(item);
138                        let delay_time = this.delay_time.load(Ordering::Relaxed);
139                        if delay_time > 0 {
140                            let wait = Duration::from_micros(delay_time)
141                                .saturating_sub(yield_start.elapsed())
142                                .max(RateController::MIN_DELAY);
143                            *this.delay = Some(futures_time::task::sleep(wait.into()));
144                            *this.state = StreamState::Wait;
145                        } else {
146                            *this.delay = Some(futures_time::task::sleep(Duration::from_millis(100).into()));
147                            *this.state = StreamState::NoRate;
148                        }
149                    } else {
150                        return Poll::Ready(None);
151                    }
152                }
153                StreamState::NoRate => {
154                    if let Some(mut delay) = this.delay.as_mut().as_pin_mut() {
155                        let _ = futures::ready!(delay.as_mut().poll(cx));
156                    }
157                    let delay_time = this.delay_time.load(Ordering::Relaxed);
158                    if delay_time > 0 {
159                        *this.delay = Some(futures_time::task::sleep(Duration::from_micros(delay_time).into()));
160                        if this.item.is_some() {
161                            *this.state = StreamState::Wait;
162                        } else {
163                            *this.state = StreamState::Read;
164                        }
165                    } else {
166                        *this.delay = Some(futures_time::task::sleep(Duration::from_millis(100).into()));
167                        *this.state = StreamState::NoRate;
168                    }
169                }
170                StreamState::Wait => {
171                    if let Some(mut delay) = this.delay.as_mut().as_pin_mut() {
172                        let _ = futures::ready!(delay.as_mut().poll(cx));
173                        *this.state = StreamState::Read;
174                        return Poll::Ready(this.item.take());
175                    }
176                }
177            }
178        }
179    }
180}
181
182/// Extension trait to add rate limiting to any stream
183pub trait RateLimitStreamExt: futures::Stream + Sized {
184    /// Creates a rate-limited stream that yields elements at the given rate.
185    ///
186    /// The rate can be controlled dynamically during the lifetime of the stream by using
187    /// the returned [`RateController`].
188    ///
189    /// If `elements_per_unit` is 0, the stream will not yield until the limit is changed
190    /// using the [`RateController`] to a non-zero value.
191    fn rate_limit_per_unit(
192        self,
193        elements_per_unit: usize,
194        unit: Duration,
195    ) -> (RateLimitedStream<Self>, RateController) {
196        RateLimitedStream::new_with_rate_per_unit(self, elements_per_unit, unit)
197    }
198
199    /// Creates a rate-limited stream that yields elements at the given rate.
200    ///
201    /// The rate can be controlled dynamically during the lifetime of the stream by using
202    /// the given [`RateController`].
203    ///
204    /// If the `controller` has [zero rate](RateController::has_zero_rate),
205    /// the stream will not yield until the limit is
206    /// changed using the [`RateController`] to a non-zero value.
207    fn rate_limit_with_controller(self, controller: &RateController) -> RateLimitedStream<Self> {
208        RateLimitedStream::new_with_controller(self, controller)
209    }
210}
211
212impl<S: futures::Stream + Sized> RateLimitStreamExt for S {}
213
214/// A sink adapter that allows ingesting items at a controlled rate, with dynamic rate adjustment.
215///
216/// If the underlying Sink is cloneable, this object will be cloneable too, each clone having
217/// the same rate-limit. Therefore, the total rate-limit is multiple of the number of clones
218/// of this sink.
219///
220/// See [`RateLimitSinkExt::rate_limit_per_unit`].
221#[must_use = "sinks do nothing unless polled"]
222#[pin_project]
223pub struct RateLimitedSink<S> {
224    #[pin]
225    inner: S,
226    delay_micros: Arc<AtomicU64>,
227    tokens: u64,
228    last_check: Option<Instant>,
229    #[pin]
230    sleep: Option<Sleep>,
231    state: SinkState,
232}
233
/// Internal state of the `poll_ready` state machine of [`RateLimitedSink`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum SinkState {
    /// Tokens may be available; the inner sink is polled for readiness.
    Ready,
    /// No tokens are available; waiting on a timer for replenishment.
    Waiting,
}
238
239impl<S> RateLimitedSink<S> {
240    /// Creates a sink with ingestion rate controllable using the given controller.
241    pub fn new_with_controller(inner: S, controller: &RateController) -> Self {
242        Self {
243            inner,
244            delay_micros: controller.0.clone(),
245            tokens: 0,
246            last_check: None,
247            sleep: None,
248            state: SinkState::Ready,
249        }
250    }
251
252    /// Creates a sink with some initial rate limit of elements per a time unit.
253    pub fn new_with_rate_per_unit(inner: S, initial_rate_per_unit: usize, unit: Duration) -> (Self, RateController) {
254        let rc = RateController::new(initial_rate_per_unit, unit);
255        (Self::new_with_controller(inner, &rc), rc)
256    }
257}
258
259impl<S, Item> futures::Sink<Item> for RateLimitedSink<S>
260where
261    S: futures::Sink<Item> + Unpin,
262{
263    type Error = S::Error;
264
265    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), S::Error>> {
266        let mut this = self.project();
267
268        loop {
269            let current_delay = this.delay_micros.load(Ordering::Relaxed);
270            let current_rate_limit = rate_from_delay(current_delay);
271
272            if let Some(last_check) = this.last_check.as_mut() {
273                *this.tokens += (current_rate_limit * last_check.elapsed().as_secs_f64()).round() as u64;
274                *last_check = Instant::now();
275            } else {
276                // This happens only on the first poll_ready
277                *this.last_check = Some(Instant::now());
278            }
279
280            match this.state {
281                SinkState::Ready => {
282                    if *this.tokens > 0 {
283                        futures::ready!(this.inner.as_mut().poll_ready(cx))?;
284
285                        tracing::trace!(tokens = *this.tokens, "tokens left");
286                        return Poll::Ready(Ok(()));
287                    } else {
288                        tracing::trace!("no tokens left");
289                        *this.state = SinkState::Waiting;
290                    }
291                }
292                SinkState::Waiting => {
293                    if let Some(sleep) = this.sleep.as_mut().as_pin_mut() {
294                        futures::ready!(sleep.poll(cx));
295                        this.sleep.set(None);
296
297                        tracing::trace!("waiting done");
298                        *this.state = SinkState::Ready;
299                    } else if current_delay > 0 {
300                        // Sleep the minimum amount of time to replenish at least one token
301                        *this.sleep = Some(futures_time::task::sleep(futures_time::time::Duration::from_micros(
302                            current_delay,
303                        )));
304                    } else {
305                        // Sleep for some fixed duration if the rate is 0
306                        *this.sleep = Some(futures_time::task::sleep(futures_time::time::Duration::from_millis(50)));
307                    }
308                }
309            }
310        }
311    }
312
313    fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), S::Error> {
314        let this = self.project();
315        if *this.tokens > 0 {
316            *this.tokens -= 1;
317            tracing::trace!("token consumed");
318            this.inner.start_send(item)
319        } else {
320            panic!("start_send called without poll_ready");
321        }
322    }
323
324    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), S::Error>> {
325        self.project().inner.poll_flush(cx)
326    }
327
328    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), S::Error>> {
329        self.project().inner.poll_close(cx)
330    }
331}
332
333impl<S: Clone> Clone for RateLimitedSink<S> {
334    fn clone(&self) -> Self {
335        Self {
336            inner: self.inner.clone(),
337            delay_micros: self.delay_micros.clone(),
338            tokens: 0,
339            last_check: None,
340            sleep: None,
341            state: SinkState::Ready,
342        }
343    }
344}
345
346/// Extension trait to add rate limiting to any sink.
347pub trait RateLimitSinkExt<T>: futures::Sink<T> + Sized {
348    /// Creates a rate-limited sink that allows ingesting items at the given rate.
349    ///
350    /// The rate can be controlled dynamically during the lifetime of the sink by using
351    /// the returned [`RateController`].
352    ///
353    /// If `elements_per_unit` is 0, the sink will not ingest items until the limit is changed
354    /// using the [`RateController`] to a non-zero value.
355    fn rate_limit_per_unit(self, elements_per_unit: usize, unit: Duration) -> (RateLimitedSink<Self>, RateController) {
356        RateLimitedSink::new_with_rate_per_unit(self, elements_per_unit, unit)
357    }
358
359    /// Creates a rate-limited sink that allows ingesting items at the given rate.
360    ///
361    /// The rate can be controlled dynamically during the lifetime of the sink by using
362    /// the given [`RateController`].
363    ///
364    /// If the `controller` has [zero rate](RateController::has_zero_rate),
365    /// the sink will not ingest items until the limit is
366    /// changed using the [`RateController`] to a non-zero value.
367    fn rate_limit_with_controller(self, controller: &RateController) -> RateLimitedSink<Self> {
368        RateLimitedSink::new_with_controller(self, controller)
369    }
370}
371
372impl<T, S: futures::Sink<T> + Sized> RateLimitSinkExt<T> for S {}
373
#[cfg(test)]
mod tests {
    use std::time::{Duration, Instant};

    use futures::{
        SinkExt, pin_mut,
        stream::{self, StreamExt},
    };
    use futures_time::future::FutureExt;

    use super::*;

    #[test]
    fn test_rate_controller_set_rate_per_unit() {
        let rc = RateController(Arc::new(AtomicU64::new(0)));
        rc.set_rate_per_unit(2500, 2 * Duration::from_secs(1));
        assert_eq!(rc.get_rate_per_sec(), 1250.0);
    }

    #[test]
    fn test_rate_controller_zero_rate() {
        let rc = RateController::new(0, Duration::from_secs(1));
        assert!(rc.has_zero_rate());
        assert_eq!(rc.get_delay_per_element(), Duration::ZERO);
        assert_eq!(rc.get_rate_per_sec(), 0.0);
    }

    #[test]
    fn test_rate_controller_delay_is_truncated_to_micros() {
        // 3 elements per second = 333 333.33.. µs per element
        let rc = RateController::new(3, Duration::from_secs(1));
        assert_eq!(rc.get_delay_per_element(), Duration::from_micros(333_333));
    }

    #[test]
    fn test_rate_controller_clamps_to_min_delay() {
        // 1 000 000 elements per second would require 1 µs per element
        let rc = RateController::new(1_000_000, Duration::from_secs(1));
        assert_eq!(rc.get_delay_per_element(), RateController::MIN_DELAY);
    }

    #[test]
    fn test_rate_controller_clones_share_rate() {
        let rc = RateController::new(1, Duration::from_secs(1));
        let clone = rc.clone();
        clone.set_rate_per_unit(10, Duration::from_secs(1));
        assert_eq!(rc.get_rate_per_sec(), 10.0);
    }

    #[tokio::test]
    async fn test_rate_limited_stream_respects_rate() {
        // Create a stream with 5 elements
        let stream = stream::iter(1..=5);

        // Set a rate of 10 elements per second (100ms per element)
        let (rate_limited, controller) = stream.rate_limit_per_unit(10, Duration::from_secs(1));

        assert_eq!(controller.get_rate_per_sec(), 10.0);

        let start = Instant::now();

        // Collect all elements from the stream
        let items: Vec<i32> = rate_limited.collect().await;

        let elapsed = start.elapsed();

        // Verify all elements were yielded
        assert_eq!(items, vec![1, 2, 3, 4, 5]);

        // Verify the rate limiting worked
        // With 5 elements at 10 per second; we expect ~400ms (4 delays of 100ms)
        // We use 300ms as a lower bound to account for processing time
        assert!(
            elapsed >= Duration::from_millis(300),
            "Stream completed too quickly: {elapsed:?}"
        );

        // We use 600ms as an upper bound to allow for some overhead
        assert!(
            elapsed <= Duration::from_millis(600),
            "Stream completed too slowly: {elapsed:?}"
        );
    }

    #[tokio::test]
    async fn test_rate_limited_first_item_should_be_delayed() {
        // Create a stream with a single element
        let stream = stream::iter(1..=1);

        // Set a rate of 1 element per 100ms
        let (rate_limited, _) = stream.rate_limit_per_unit(1, Duration::from_millis(100));

        pin_mut!(rate_limited);

        let start = Instant::now();
        assert_eq!(Some(1), rate_limited.next().await);
        assert!(start.elapsed() >= Duration::from_millis(100));
    }

    #[tokio::test]
    async fn test_rate_limited_stream_empty() {
        // Create an empty stream
        let stream = stream::iter::<Vec<i32>>(vec![]);

        // Apply rate limiting
        let (mut rate_limited, _) = stream.rate_limit_per_unit(10, Duration::from_secs(1));

        // Verify we get None right away
        assert_eq!(rate_limited.next().await, None);
    }

    #[tokio::test]
    async fn test_rate_limited_stream_zero_rate() -> anyhow::Result<()> {
        // Create a stream with 3 elements
        let stream = stream::iter(1..=3);

        // Set a rate of 0 elements per second (= will not yield)
        let (mut rate_limited, _) = stream.rate_limit_per_unit(0, Duration::from_millis(50));

        assert!(
            rate_limited
                .next()
                .timeout(futures_time::time::Duration::from_millis(100))
                .await
                .is_err(),
            "zero rate should not yield anything"
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_rate_limited_stream_should_pause_on_zero_rate() -> anyhow::Result<()> {
        // Create a stream with 3 elements
        let stream = stream::iter(1..=3);

        // Set a rate of 1 element per 100ms
        let (mut rate_limited, controller) = stream.rate_limit_per_unit(1, Duration::from_millis(100));

        let start = Instant::now();
        assert_eq!(Some(1), rate_limited.next().await);
        let elapsed = start.elapsed();
        assert!(
            elapsed >= Duration::from_millis(100),
            "first element too fast {elapsed:?}"
        );

        let start = Instant::now();
        assert_eq!(Some(2), rate_limited.next().await);
        let elapsed = start.elapsed();
        assert!(
            elapsed >= Duration::from_millis(100),
            "second element too fast {elapsed:?}"
        );

        controller.set_rate_per_unit(0, Duration::from_millis(100));

        assert!(
            rate_limited
                .next()
                .timeout(futures_time::time::Duration::from_millis(200))
                .await
                .is_err(),
            "zero rate should not yield anything"
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_rate_limited_stream_zero_rate_should_restart_when_increased() -> anyhow::Result<()> {
        // Create a stream with 3 elements
        let stream = stream::iter(1..=3);

        // Set a rate of 0 elements per second (= will not yield)
        let (mut rate_limited, controller) = stream.rate_limit_per_unit(0, Duration::from_secs(1));

        assert!(
            rate_limited
                .next()
                .timeout(futures_time::time::Duration::from_millis(100))
                .await
                .is_err(),
            "zero rate should not yield anything"
        );

        controller.set_rate_per_unit(1, Duration::from_millis(100));

        let start = Instant::now();
        let all_items = rate_limited.collect::<Vec<_>>().await;
        let all_items_elapsed = start.elapsed();

        assert_eq!(all_items, vec![1, 2, 3]);
        assert!(
            all_items_elapsed >= Duration::from_millis(300),
            "all items should have been yielded in at least 300ms instead {all_items_elapsed:?}"
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_rate_changing_during_stream() {
        // Create a stream with 6 elements
        let stream = stream::iter(1..=6);

        // Start with 5 elements per second (200ms per element)
        let (mut rate_limited, controller) = stream.rate_limit_per_unit(5, Duration::from_secs(1));

        // Consume first 3 elements
        let start = Instant::now();
        for i in 1..=3 {
            assert_eq!(Some(i), rate_limited.next().await);
        }

        // Measure time for the first 3 elements
        let first_half_elapsed = start.elapsed();

        // Change rate to 2 elements per second (500ms per element)
        controller.set_rate_per_unit(2, Duration::from_secs(1));

        // Consume last 3 elements
        for i in 4..=6 {
            assert_eq!(Some(i), rate_limited.next().await);
        }

        // Measure total time
        let total_elapsed = start.elapsed();
        let second_half_elapsed = total_elapsed - first_half_elapsed;

        // The first 3 elements at 5 per second should take ~600ms (200 ms each)
        assert!(
            first_half_elapsed >= Duration::from_millis(600),
            "First half too fast: {first_half_elapsed:?}"
        );
        assert!(
            first_half_elapsed <= Duration::from_millis(700),
            "First half too slow: {first_half_elapsed:?}"
        );

        // The last 3 elements at 2 per second should take ~1500ms (500 ms each)
        assert!(
            second_half_elapsed >= Duration::from_millis(1500),
            "Second half too fast: {second_half_elapsed:?}"
        );
        assert!(
            second_half_elapsed <= Duration::from_millis(1600),
            "Second half too slow: {second_half_elapsed:?}"
        );
    }

    #[tokio::test]
    async fn test_very_high_rate() {
        // Create a stream with 100 elements
        let stream = stream::iter(1..=100);

        // Set a very high rate (1000 per second)
        let (rate_limited, _) = stream.rate_limit_per_unit(1000, Duration::from_secs(1));

        let start = Instant::now();

        // Collect all elements
        let items: Vec<i32> = rate_limited.collect().await;

        let elapsed = start.elapsed();

        // Verify all elements were yielded
        assert_eq!(items.len(), 100);
        assert_eq!(items.first(), Some(&1));
        assert_eq!(items.last(), Some(&100));

        // Even at a high rate, every element is delayed by ~1ms, so processing
        // 100 elements should take at least ~100ms but less than 150ms
        assert!(
            elapsed >= Duration::from_millis(100),
            "Very high rate stream completed too quickly: {elapsed:?}"
        );

        assert!(
            elapsed < Duration::from_millis(150),
            "Very high rate stream took too long: {elapsed:?}"
        );
    }

    #[tokio::test]
    async fn test_concurrent_rate_change() {
        use futures::future::join;

        // Create a stream with 10 elements
        let stream = stream::iter(1..=10);

        // Start with 2 elements per second
        let (mut rate_limited, controller) = stream.rate_limit_per_unit(2, Duration::from_secs(1));

        // Set up a task to process the stream
        let stream_task = async {
            let mut count = 0;
            let mut items = Vec::new();

            while let Some(item) = rate_limited.next().await {
                items.push(item);
                count += 1;

                // After 3 elements, wait briefly to allow rate change to take effect
                if count == 3 {
                    futures_time::task::sleep(Duration::from_millis(50).into()).await;
                }
            }

            items
        };

        // Set up a task to change the rate after a delay
        let rate_change_task = async move {
            // Wait a bit for the stream to start processing
            futures_time::task::sleep(Duration::from_millis(100).into()).await;

            // Change the rate to 20 per second
            controller.set_rate_per_unit(20, Duration::from_secs(1));

            // Return the new rate
            20
        };

        // Run both tasks concurrently
        let (items, new_rate) = join(stream_task, rate_change_task).await;

        // Verify results
        assert_eq!(items.len(), 10, "Should have received all 10 items");
        assert_eq!(new_rate, 20, "Rate should have been changed to 20");
    }

    #[test_log::test(tokio::test)]
    async fn rate_limited_sink_should_respect_rate() -> anyhow::Result<()> {
        let (tx, rx) = futures::channel::mpsc::unbounded::<i32>();
        let (tx, _) = tx.rate_limit_per_unit(5, Duration::from_millis(100));

        let input = (0..10).collect::<Vec<_>>();

        let start = Instant::now();
        pin_mut!(tx);
        tx.send_all(&mut futures::stream::iter(input.clone()).map(Ok))
            .timeout(futures_time::time::Duration::from_millis(500))
            .await??;

        let elapsed = start.elapsed();
        assert!(
            elapsed >= Duration::from_millis(200),
            "sending took too little {elapsed:?}"
        );

        tx.close().await?;

        let collected = rx.collect::<Vec<_>>().await;
        assert_eq!(input, collected);

        Ok(())
    }

    #[test_log::test(tokio::test)]
    async fn rate_limited_sink_should_replenish_when_idle() -> anyhow::Result<()> {
        let (tx, rx) = futures::channel::mpsc::unbounded::<i32>();
        let (tx, _) = tx.rate_limit_per_unit(5, Duration::from_millis(100));

        pin_mut!(tx);

        let input = (0..5).collect::<Vec<_>>();

        let start = Instant::now();
        tx.send_all(&mut futures::stream::iter(input.clone()).map(Ok))
            .timeout(futures_time::time::Duration::from_millis(500))
            .await??;

        let elapsed = start.elapsed();
        assert!(
            elapsed < Duration::from_millis(120),
            "sending took too much {elapsed:?}"
        );

        tokio::time::sleep(Duration::from_millis(100)).await;

        let start = Instant::now();
        tx.send_all(&mut futures::stream::iter(input.clone()).map(Ok))
            .timeout(futures_time::time::Duration::from_millis(500))
            .await??;

        let elapsed = start.elapsed();
        assert!(
            elapsed < Duration::from_millis(120),
            "sending took too much {elapsed:?}"
        );

        tx.close().await?;

        let collected = rx.collect::<Vec<_>>().await;
        assert_eq!(input.into_iter().cycle().take(10).collect::<Vec<_>>(), collected);

        Ok(())
    }

    #[test_log::test(tokio::test)]
    async fn rate_limited_sink_should_not_send_when_zero_rate() -> anyhow::Result<()> {
        let (tx, _) = futures::channel::mpsc::unbounded::<i32>();
        let (tx, _) = tx.rate_limit_per_unit(0, Duration::from_millis(100));

        pin_mut!(tx);

        assert!(
            tx.send(1i32)
                .timeout(futures_time::time::Duration::from_millis(100))
                .await
                .is_err()
        );
        Ok(())
    }

    #[test_log::test(tokio::test)]
    async fn rate_limited_sink_should_recover_after_rate_is_increased() -> anyhow::Result<()> {
        let (tx, rx) = futures::channel::mpsc::unbounded::<i32>();
        let (tx, ctl) = tx.rate_limit_per_unit(0, Duration::from_millis(100));

        pin_mut!(tx);

        assert!(
            tx.send(1i32)
                .timeout(futures_time::time::Duration::from_millis(100))
                .await
                .is_err()
        );

        ctl.set_rate_per_unit(10, Duration::from_millis(10));

        tx.send(2i32)
            .timeout(futures_time::time::Duration::from_millis(100))
            .await??;

        pin_mut!(rx);
        assert_eq!(
            Some(2),
            rx.next()
                .timeout(futures_time::time::Duration::from_millis(100))
                .await?
        );

        Ok(())
    }
}