1use std::{
2 future::Future,
3 pin::Pin,
4 sync::{
5 Arc,
6 atomic::{AtomicU64, Ordering},
7 },
8 task::{Context, Poll},
9 time::{Duration, Instant},
10};
11
12use futures_time::task::Sleep;
13use pin_project::pin_project;
14
/// Shared, thread-safe handle to a rate limit.
///
/// The limit is stored as the delay between consecutive elements, in microseconds, inside an
/// `Arc<AtomicU64>`; a stored value of `0` means a zero rate (nothing is let through).
/// Cloning the controller shares the same underlying value, so a rate change made through any
/// clone is observed by every stream and sink created from it.
#[derive(Clone, Debug)]
pub struct RateController(Arc<AtomicU64>);
18
19impl Default for RateController {
20 fn default() -> Self {
21 Self(Arc::new(AtomicU64::new(0)))
22 }
23}
24
/// Converts a per-element delay in microseconds into a rate in elements per second.
///
/// A delay of `0` encodes a zero rate and maps to `0.0`.
fn rate_from_delay(delay_micros: u64) -> f64 {
    if delay_micros == 0 {
        return 0.0;
    }
    Duration::from_micros(delay_micros).as_secs_f64().recip()
}
32
33#[allow(unused)]
34impl RateController {
35 const MIN_DELAY: Duration = Duration::from_micros(100);
36
37 pub fn new(elements_per_unit: usize, unit: Duration) -> Self {
38 let rc = RateController::default();
39 rc.set_rate_per_unit(elements_per_unit, unit);
40 rc
41 }
42
43 pub fn set_rate_per_unit(&self, elements_per_unit: usize, unit: Duration) {
45 assert!(unit > Duration::ZERO, "unit must be greater than zero");
46
47 if elements_per_unit > 0 {
48 let rate_per_sec = (elements_per_unit as f64 / unit.as_secs_f64());
50
51 let new_rate = Duration::from_secs_f64(1.0 / rate_per_sec)
53 .max(Self::MIN_DELAY)
54 .as_micros()
55 .min(u64::MAX as u128) as u64; self.0.store(new_rate, Ordering::Relaxed);
58 } else {
59 self.0.store(0, Ordering::Relaxed);
60 }
61 }
62
63 pub fn has_zero_rate(&self) -> bool {
65 self.0.load(Ordering::Relaxed) == 0
66 }
67
68 pub fn get_delay_per_element(&self) -> Duration {
70 Duration::from_micros(self.0.load(Ordering::Relaxed))
71 }
72
73 pub fn get_rate_per_sec(&self) -> f64 {
75 rate_from_delay(self.0.load(Ordering::Relaxed))
76 }
77}
78
/// Internal state machine of `RateLimitedStream`.
enum StreamState {
    /// Poll the inner stream for the next item.
    Read,
    /// The controller reports a zero rate: wait on the timer, then re-check the controller.
    /// An item that was already read stays buffered meanwhile.
    NoRate,
    /// An item is buffered; yield it once the delay timer fires.
    Wait,
}
84
/// Stream adapter that delays each item of the wrapped stream according to a shared
/// [`RateController`].
///
/// Every item, including the first, is read from the inner stream and then held back for
/// roughly the controller's per-element delay before it is yielded. The delay is read when an
/// item is buffered, so a rate change affects the next delay that starts, not one already
/// running. While the controller reports a zero rate the stream yields nothing and re-checks the
/// rate every 100 ms.
#[must_use = "streams do nothing unless polled"]
#[pin_project]
pub struct RateLimitedStream<S: futures::Stream> {
    /// Wrapped stream.
    #[pin]
    inner: S,
    /// Item read from `inner` that is waiting for its delay to elapse.
    item: Option<S::Item>,
    /// Timer for the current per-item delay or zero-rate re-check.
    #[pin]
    delay: Option<Sleep>,
    /// Current step of the polling state machine.
    state: StreamState,
    /// Per-element delay in microseconds, shared with the controller (`0` = zero rate).
    delay_time: Arc<AtomicU64>,
}
99
100impl<S: futures::Stream> RateLimitedStream<S> {
101 pub fn new_with_controller(stream: S, controller: &RateController) -> Self {
103 Self {
104 inner: stream,
105 item: None,
106 delay: None,
107 state: if controller.0.load(Ordering::Relaxed) > 0 {
108 StreamState::Read
109 } else {
110 StreamState::NoRate
111 },
112 delay_time: controller.0.clone(),
113 }
114 }
115
116 pub fn new_with_rate_per_unit(stream: S, initial_rate_per_unit: usize, unit: Duration) -> (Self, RateController) {
118 let rc = RateController::new(initial_rate_per_unit, unit);
119 (Self::new_with_controller(stream, &rc), rc)
120 }
121}
122
123impl<S, T> futures::Stream for RateLimitedStream<S>
124where
125 S: futures::Stream<Item = T> + Unpin,
126{
127 type Item = T;
128
129 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
130 let mut this = self.project();
131
132 loop {
133 match this.state {
134 StreamState::Read => {
135 let yield_start = Instant::now();
136 if let Some(item) = futures::ready!(this.inner.as_mut().poll_next(cx)) {
137 *this.item = Some(item);
138 let delay_time = this.delay_time.load(Ordering::Relaxed);
139 if delay_time > 0 {
140 let wait = Duration::from_micros(delay_time)
141 .saturating_sub(yield_start.elapsed())
142 .max(RateController::MIN_DELAY);
143 *this.delay = Some(futures_time::task::sleep(wait.into()));
144 *this.state = StreamState::Wait;
145 } else {
146 *this.delay = Some(futures_time::task::sleep(Duration::from_millis(100).into()));
147 *this.state = StreamState::NoRate;
148 }
149 } else {
150 return Poll::Ready(None);
151 }
152 }
153 StreamState::NoRate => {
154 if let Some(mut delay) = this.delay.as_mut().as_pin_mut() {
155 let _ = futures::ready!(delay.as_mut().poll(cx));
156 }
157 let delay_time = this.delay_time.load(Ordering::Relaxed);
158 if delay_time > 0 {
159 *this.delay = Some(futures_time::task::sleep(Duration::from_micros(delay_time).into()));
160 if this.item.is_some() {
161 *this.state = StreamState::Wait;
162 } else {
163 *this.state = StreamState::Read;
164 }
165 } else {
166 *this.delay = Some(futures_time::task::sleep(Duration::from_millis(100).into()));
167 *this.state = StreamState::NoRate;
168 }
169 }
170 StreamState::Wait => {
171 if let Some(mut delay) = this.delay.as_mut().as_pin_mut() {
172 let _ = futures::ready!(delay.as_mut().poll(cx));
173 *this.state = StreamState::Read;
174 return Poll::Ready(this.item.take());
175 }
176 }
177 }
178 }
179 }
180}
181
182pub trait RateLimitStreamExt: futures::Stream + Sized {
184 fn rate_limit_per_unit(
192 self,
193 elements_per_unit: usize,
194 unit: Duration,
195 ) -> (RateLimitedStream<Self>, RateController) {
196 RateLimitedStream::new_with_rate_per_unit(self, elements_per_unit, unit)
197 }
198
199 fn rate_limit_with_controller(self, controller: &RateController) -> RateLimitedStream<Self> {
208 RateLimitedStream::new_with_controller(self, controller)
209 }
210}
211
212impl<S: futures::Stream + Sized> RateLimitStreamExt for S {}
213
/// Sink adapter that admits items according to a shared [`RateController`] using a token
/// bucket.
///
/// Tokens are credited lazily in `poll_ready` from the time elapsed since the previous check,
/// and each `start_send` consumes one. The bucket starts empty, so even the first item waits
/// for a token. There is no upper bound on stored tokens, so tokens gathered while idle can be
/// spent in a burst. No tokens accrue while the controller holds a zero rate.
#[must_use = "sinks do nothing unless polled"]
#[pin_project]
pub struct RateLimitedSink<S> {
    /// Wrapped sink.
    #[pin]
    inner: S,
    /// Per-element delay in microseconds, shared with the controller (`0` = zero rate).
    delay_micros: Arc<AtomicU64>,
    /// Items that may currently be sent; `start_send` consumes one.
    tokens: u64,
    /// When tokens were last credited; `None` until the first `poll_ready`.
    last_check: Option<Instant>,
    /// Timer used while waiting for tokens.
    #[pin]
    sleep: Option<Sleep>,
    /// Current step of the `poll_ready` state machine.
    state: SinkState,
}
233
/// Internal state machine of `RateLimitedSink::poll_ready`.
enum SinkState {
    /// Admit an item if a token is available (once the inner sink is ready).
    Ready,
    /// No token available: wait on the sleep timer before re-checking.
    Waiting,
}
238
239impl<S> RateLimitedSink<S> {
240 pub fn new_with_controller(inner: S, controller: &RateController) -> Self {
242 Self {
243 inner,
244 delay_micros: controller.0.clone(),
245 tokens: 0,
246 last_check: None,
247 sleep: None,
248 state: SinkState::Ready,
249 }
250 }
251
252 pub fn new_with_rate_per_unit(inner: S, initial_rate_per_unit: usize, unit: Duration) -> (Self, RateController) {
254 let rc = RateController::new(initial_rate_per_unit, unit);
255 (Self::new_with_controller(inner, &rc), rc)
256 }
257}
258
259impl<S, Item> futures::Sink<Item> for RateLimitedSink<S>
260where
261 S: futures::Sink<Item> + Unpin,
262{
263 type Error = S::Error;
264
265 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), S::Error>> {
266 let mut this = self.project();
267
268 loop {
269 let current_delay = this.delay_micros.load(Ordering::Relaxed);
270 let current_rate_limit = rate_from_delay(current_delay);
271
272 if let Some(last_check) = this.last_check.as_mut() {
273 *this.tokens += (current_rate_limit * last_check.elapsed().as_secs_f64()).round() as u64;
274 *last_check = Instant::now();
275 } else {
276 *this.last_check = Some(Instant::now());
278 }
279
280 match this.state {
281 SinkState::Ready => {
282 if *this.tokens > 0 {
283 futures::ready!(this.inner.as_mut().poll_ready(cx))?;
284
285 tracing::trace!(tokens = *this.tokens, "tokens left");
286 return Poll::Ready(Ok(()));
287 } else {
288 tracing::trace!("no tokens left");
289 *this.state = SinkState::Waiting;
290 }
291 }
292 SinkState::Waiting => {
293 if let Some(sleep) = this.sleep.as_mut().as_pin_mut() {
294 futures::ready!(sleep.poll(cx));
295 this.sleep.set(None);
296
297 tracing::trace!("waiting done");
298 *this.state = SinkState::Ready;
299 } else if current_delay > 0 {
300 *this.sleep = Some(futures_time::task::sleep(futures_time::time::Duration::from_micros(
302 current_delay,
303 )));
304 } else {
305 *this.sleep = Some(futures_time::task::sleep(futures_time::time::Duration::from_millis(50)));
307 }
308 }
309 }
310 }
311 }
312
313 fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), S::Error> {
314 let this = self.project();
315 if *this.tokens > 0 {
316 *this.tokens -= 1;
317 tracing::trace!("token consumed");
318 this.inner.start_send(item)
319 } else {
320 panic!("start_send called without poll_ready");
321 }
322 }
323
324 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), S::Error>> {
325 self.project().inner.poll_flush(cx)
326 }
327
328 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), S::Error>> {
329 self.project().inner.poll_close(cx)
330 }
331}
332
333impl<S: Clone> Clone for RateLimitedSink<S> {
334 fn clone(&self) -> Self {
335 Self {
336 inner: self.inner.clone(),
337 delay_micros: self.delay_micros.clone(),
338 tokens: 0,
339 last_check: None,
340 sleep: None,
341 state: SinkState::Ready,
342 }
343 }
344}
345
346pub trait RateLimitSinkExt<T>: futures::Sink<T> + Sized {
348 fn rate_limit_per_unit(self, elements_per_unit: usize, unit: Duration) -> (RateLimitedSink<Self>, RateController) {
356 RateLimitedSink::new_with_rate_per_unit(self, elements_per_unit, unit)
357 }
358
359 fn rate_limit_with_controller(self, controller: &RateController) -> RateLimitedSink<Self> {
368 RateLimitedSink::new_with_controller(self, controller)
369 }
370}
371
372impl<T, S: futures::Sink<T> + Sized> RateLimitSinkExt<T> for S {}
373
#[cfg(test)]
mod tests {
    use std::time::{Duration, Instant};

    use futures::{
        SinkExt, pin_mut,
        stream::{self, StreamExt},
    };
    use futures_time::future::FutureExt;

    use super::*;

    #[test]
    fn test_rate_controller_set_rate_per_unit() {
        let rc = RateController(Arc::new(AtomicU64::new(0)));
        rc.set_rate_per_unit(2500, 2 * Duration::from_secs(1));
        assert_eq!(rc.get_rate_per_sec(), 1250.0);
    }

    #[tokio::test]
    async fn test_rate_limited_stream_respects_rate() {
        let stream = stream::iter(1..=5);

        let (rate_limited, controller) = stream.rate_limit_per_unit(10, Duration::from_secs(1));

        assert_eq!(controller.get_rate_per_sec(), 10.0);

        let start = Instant::now();

        let items: Vec<i32> = rate_limited.collect().await;

        let elapsed = start.elapsed();

        assert_eq!(items, vec![1, 2, 3, 4, 5]);

        assert!(
            elapsed >= Duration::from_millis(300),
            "Stream completed too quickly: {elapsed:?}"
        );

        assert!(
            elapsed <= Duration::from_millis(600),
            "Stream completed too slowly: {elapsed:?}"
        );
    }

    #[tokio::test]
    async fn test_rate_limited_first_item_should_be_delayed() {
        let stream = stream::iter(1..=1);

        let (rate_limited, _) = stream.rate_limit_per_unit(1, Duration::from_millis(100));

        pin_mut!(rate_limited);

        let start = Instant::now();
        assert_eq!(Some(1), rate_limited.next().await);
        assert!(start.elapsed() >= Duration::from_millis(100));
    }

    #[tokio::test]
    async fn test_rate_limited_stream_empty() {
        let stream = stream::iter::<Vec<i32>>(vec![]);

        let (mut rate_limited, _) = stream.rate_limit_per_unit(10, Duration::from_secs(1));

        assert_eq!(rate_limited.next().await, None);
    }

    #[tokio::test]
    async fn test_rate_limited_stream_zero_rate() -> anyhow::Result<()> {
        let stream = stream::iter(1..=3);

        let (mut rate_limited, _) = stream.rate_limit_per_unit(0, Duration::from_millis(50));

        assert!(
            rate_limited
                .next()
                .timeout(futures_time::time::Duration::from_millis(100))
                .await
                .is_err(),
            "zero rate should not yield anything"
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_rate_limited_stream_should_pause_on_zero_rate() -> anyhow::Result<()> {
        let stream = stream::iter(1..=3);

        let (mut rate_limited, controller) = stream.rate_limit_per_unit(1, Duration::from_millis(100));

        let start = Instant::now();
        assert_eq!(Some(1), rate_limited.next().await);
        let elapsed = start.elapsed();
        assert!(
            elapsed >= Duration::from_millis(100),
            "first element too fast {elapsed:?}"
        );

        let start = Instant::now();
        assert_eq!(Some(2), rate_limited.next().await);
        let elapsed = start.elapsed();
        assert!(
            elapsed >= Duration::from_millis(100),
            "second element too fast {elapsed:?}"
        );

        controller.set_rate_per_unit(0, Duration::from_millis(100));

        assert!(
            rate_limited
                .next()
                .timeout(futures_time::time::Duration::from_millis(200))
                .await
                .is_err(),
            "zero rate should not yield anything"
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_rate_limited_stream_zero_rate_should_restart_when_increased() -> anyhow::Result<()> {
        let stream = stream::iter(1..=3);

        let (mut rate_limited, controller) = stream.rate_limit_per_unit(0, Duration::from_secs(1));

        assert!(
            rate_limited
                .next()
                .timeout(futures_time::time::Duration::from_millis(100))
                .await
                .is_err(),
            "zero rate should not yield anything"
        );

        controller.set_rate_per_unit(1, Duration::from_millis(100));

        let start = Instant::now();
        let all_items = rate_limited.collect::<Vec<_>>().await;
        let all_items_elapsed = start.elapsed();

        assert_eq!(all_items, vec![1, 2, 3]);
        assert!(
            all_items_elapsed >= Duration::from_millis(300),
            "all items should have been yielded in at least 300ms instead {all_items_elapsed:?}"
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_rate_changing_during_stream() {
        let stream = stream::iter(1..=6);

        let (mut rate_limited, controller) = stream.rate_limit_per_unit(5, Duration::from_secs(1));

        let start = Instant::now();
        for i in 1..=3 {
            assert_eq!(Some(i), rate_limited.next().await);
        }

        let first_half_elapsed = start.elapsed();

        controller.set_rate_per_unit(2, Duration::from_secs(1));

        for i in 4..=6 {
            assert_eq!(Some(i), rate_limited.next().await);
        }

        let total_elapsed = start.elapsed();
        let second_half_elapsed = total_elapsed - first_half_elapsed;

        assert!(
            first_half_elapsed >= Duration::from_millis(600),
            "First half too fast: {first_half_elapsed:?}"
        );
        assert!(
            first_half_elapsed <= Duration::from_millis(700),
            "First half too slow: {first_half_elapsed:?}"
        );

        assert!(
            second_half_elapsed >= Duration::from_millis(1500),
            "Second half too fast: {second_half_elapsed:?}"
        );
        assert!(
            second_half_elapsed <= Duration::from_millis(1600),
            "Second half too slow: {second_half_elapsed:?}"
        );
    }

    #[tokio::test]
    async fn test_very_high_rate() {
        let stream = stream::iter(1..=100);

        let (rate_limited, _) = stream.rate_limit_per_unit(1000, Duration::from_secs(1));

        let start = Instant::now();

        let items: Vec<i32> = rate_limited.collect().await;

        let elapsed = start.elapsed();

        assert_eq!(items.len(), 100);
        assert_eq!(items.first(), Some(&1));
        assert_eq!(items.last(), Some(&100));

        assert!(
            elapsed >= Duration::from_millis(100),
            "Very high rate stream completed too quickly: {elapsed:?}"
        );

        assert!(
            elapsed < Duration::from_millis(150),
            "Very high rate stream took too long: {elapsed:?}"
        );
    }

    #[tokio::test]
    async fn test_concurrent_rate_change() {
        use futures::future::join;

        let stream = stream::iter(1..=10);

        let (mut rate_limited, controller) = stream.rate_limit_per_unit(2, Duration::from_secs(1));

        let stream_task = async {
            let mut count = 0;
            let mut items = Vec::new();

            while let Some(item) = rate_limited.next().await {
                items.push(item);
                count += 1;

                if count == 3 {
                    futures_time::task::sleep(Duration::from_millis(50).into()).await;
                }
            }

            items
        };

        let rate_change_task = async move {
            futures_time::task::sleep(Duration::from_millis(100).into()).await;

            controller.set_rate_per_unit(20, Duration::from_secs(1));

            // Report the rate the controller actually holds rather than a constant, so the
            // assertion below checks something.
            controller.get_rate_per_sec()
        };

        let (items, new_rate) = join(stream_task, rate_change_task).await;

        assert_eq!(items.len(), 10, "Should have received all 10 items");
        assert_eq!(new_rate, 20.0, "Rate should have been changed to 20");
    }

    #[test_log::test(tokio::test)]
    async fn rate_limited_sink_should_respect_rate() -> anyhow::Result<()> {
        let (tx, rx) = futures::channel::mpsc::unbounded::<i32>();
        let (tx, _) = tx.rate_limit_per_unit(5, Duration::from_millis(100));

        let input = (0..10).collect::<Vec<_>>();

        let start = Instant::now();
        pin_mut!(tx);
        tx.send_all(&mut futures::stream::iter(input.clone()).map(Ok))
            .timeout(futures_time::time::Duration::from_millis(500))
            .await??;

        let elapsed = start.elapsed();
        assert!(
            elapsed >= Duration::from_millis(200),
            "sending took too little {elapsed:?}"
        );

        tx.close().await?;

        let collected = rx.collect::<Vec<_>>().await;
        assert_eq!(input, collected);

        Ok(())
    }

    #[test_log::test(tokio::test)]
    async fn rate_limited_sink_should_replenish_when_idle() -> anyhow::Result<()> {
        let (tx, rx) = futures::channel::mpsc::unbounded::<i32>();
        let (tx, _) = tx.rate_limit_per_unit(5, Duration::from_millis(100));

        pin_mut!(tx);

        let input = (0..5).collect::<Vec<_>>();

        let start = Instant::now();
        tx.send_all(&mut futures::stream::iter(input.clone()).map(Ok))
            .timeout(futures_time::time::Duration::from_millis(500))
            .await??;

        let elapsed = start.elapsed();
        assert!(
            elapsed < Duration::from_millis(120),
            "sending took too much {elapsed:?}"
        );

        tokio::time::sleep(Duration::from_millis(100)).await;

        let start = Instant::now();
        tx.send_all(&mut futures::stream::iter(input.clone()).map(Ok))
            .timeout(futures_time::time::Duration::from_millis(500))
            .await??;

        let elapsed = start.elapsed();
        assert!(
            elapsed < Duration::from_millis(120),
            "sending took too much {elapsed:?}"
        );

        tx.close().await?;

        let collected = rx.collect::<Vec<_>>().await;
        assert_eq!(input.into_iter().cycle().take(10).collect::<Vec<_>>(), collected);

        Ok(())
    }

    #[test_log::test(tokio::test)]
    async fn rate_limited_sink_should_not_send_when_zero_rate() -> anyhow::Result<()> {
        // Keep the receiver alive so `send` can only finish by the limiter admitting the
        // item, never by failing on a closed channel.
        let (tx, _rx) = futures::channel::mpsc::unbounded::<i32>();
        let (tx, _) = tx.rate_limit_per_unit(0, Duration::from_millis(100));

        pin_mut!(tx);

        assert!(
            tx.send(1i32)
                .timeout(futures_time::time::Duration::from_millis(100))
                .await
                .is_err()
        );
        Ok(())
    }

    #[test_log::test(tokio::test)]
    async fn rate_limited_sink_should_recover_after_rate_is_increased() -> anyhow::Result<()> {
        let (tx, rx) = futures::channel::mpsc::unbounded::<i32>();
        let (tx, ctl) = tx.rate_limit_per_unit(0, Duration::from_millis(100));

        pin_mut!(tx);

        assert!(
            tx.send(1i32)
                .timeout(futures_time::time::Duration::from_millis(100))
                .await
                .is_err()
        );

        ctl.set_rate_per_unit(10, Duration::from_millis(10));

        tx.send(2i32)
            .timeout(futures_time::time::Duration::from_millis(100))
            .await??;

        pin_mut!(rx);
        assert_eq!(
            Some(2),
            rx.next()
                .timeout(futures_time::time::Duration::from_millis(100))
                .await?
        );

        Ok(())
    }
}