1use std::{
2 future::Future,
3 pin::Pin,
4 sync::{
5 Arc,
6 atomic::{AtomicU64, Ordering},
7 },
8 task::{Context, Poll},
9 time::{Duration, Instant},
10};
11
12use futures_time::task::Sleep;
13use pin_project::pin_project;
14
/// Shared handle holding the target throughput of one or more rate-limited
/// streams or sinks.
///
/// The rate is stored as the delay between two consecutive elements, in whole
/// microseconds; a stored value of `0` means "zero rate" (nothing passes).
/// Clones share the same atomic, so a change made through any clone is seen by
/// every adapter built from it.
#[derive(Clone, Debug)]
pub struct RateController(Arc<AtomicU64>);

impl Default for RateController {
    /// Creates a controller at zero rate.
    fn default() -> Self {
        Self(Arc::new(AtomicU64::new(0)))
    }
}

/// Converts a per-element delay in microseconds into elements per second.
///
/// A delay of `0` encodes "zero rate" and maps to `0.0`.
fn rate_from_delay(delay_micros: u64) -> f64 {
    if delay_micros > 0 {
        1.0 / Duration::from_micros(delay_micros).as_secs_f64()
    } else {
        0.0
    }
}

// Some of these accessors are only used by code outside this module (or by tests).
#[allow(unused)]
impl RateController {
    /// Smallest non-zero per-element delay that is ever stored; this caps the
    /// effective rate at 10 000 elements per second.
    const MIN_DELAY: Duration = Duration::from_micros(100);

    /// Creates a controller configured for `elements_per_unit` elements per `unit`.
    ///
    /// # Panics
    ///
    /// Panics if `unit` is zero.
    pub fn new(elements_per_unit: usize, unit: Duration) -> Self {
        let rc = RateController::default();
        rc.set_rate_per_unit(elements_per_unit, unit);
        rc
    }

    /// Sets the rate to `elements_per_unit` elements per `unit`.
    ///
    /// `elements_per_unit == 0` selects zero rate. Any other value is converted
    /// into a per-element delay which is floored at `MIN_DELAY` and truncated to
    /// whole microseconds; delays too long to represent saturate instead of
    /// panicking.
    ///
    /// # Panics
    ///
    /// Panics if `unit` is zero.
    pub fn set_rate_per_unit(&self, elements_per_unit: usize, unit: Duration) {
        assert!(unit > Duration::ZERO, "unit must be greater than zero");

        let delay_micros = if elements_per_unit > 0 {
            let rate_per_sec = elements_per_unit as f64 / unit.as_secs_f64();
            // Extremely slow rates (e.g. one element per `Duration::MAX`) do not fit
            // in a `Duration`; saturate rather than panic like `from_secs_f64` would.
            let delay = Duration::try_from_secs_f64(1.0 / rate_per_sec).unwrap_or(Duration::MAX);
            u64::try_from(delay.max(Self::MIN_DELAY).as_micros()).unwrap_or(u64::MAX)
        } else {
            0
        };

        self.0.store(delay_micros, Ordering::Relaxed);
    }

    /// Returns `true` when the controller is at zero rate.
    pub fn has_zero_rate(&self) -> bool {
        self.0.load(Ordering::Relaxed) == 0
    }

    /// Returns the delay between two consecutive elements (`Duration::ZERO` at zero rate).
    pub fn get_delay_per_element(&self) -> Duration {
        Duration::from_micros(self.0.load(Ordering::Relaxed))
    }

    /// Returns the configured rate in elements per second (`0.0` at zero rate).
    pub fn get_rate_per_sec(&self) -> f64 {
        rate_from_delay(self.0.load(Ordering::Relaxed))
    }
}
78
/// Phase of the `RateLimitedStream` polling state machine.
enum StreamState {
    /// Pull the next element from the inner stream.
    Read,
    /// The rate is zero: keep re-checking the controller on a timer
    /// (an already-read element may be held back meanwhile).
    NoRate,
    /// An element is buffered and is released once the armed timer completes.
    Wait,
}
84
/// Stream adapter that yields the elements of `inner` no faster than the rate
/// held in the shared `delay_time` value (see [`RateController`]).
#[must_use = "streams do nothing unless polled"]
#[pin_project]
pub struct RateLimitedStream<S: futures::Stream> {
    /// The wrapped stream (structurally pinned).
    #[pin]
    inner: S,
    /// Element already read from `inner` that has not been yielded yet.
    item: Option<S::Item>,
    /// Timer for the current wait: either the per-element delay or the
    /// zero-rate re-check interval (structurally pinned).
    #[pin]
    delay: Option<Sleep>,
    /// Current phase of the polling state machine.
    state: StreamState,
    /// Per-element delay in microseconds, shared with the `RateController`
    /// this stream was built from; `0` means zero rate.
    delay_time: Arc<AtomicU64>,
}
99
100impl<S: futures::Stream> RateLimitedStream<S> {
101 pub fn new_with_controller(stream: S, controller: &RateController) -> Self {
103 Self {
104 inner: stream,
105 item: None,
106 delay: None,
107 state: if controller.0.load(Ordering::Relaxed) > 0 {
108 StreamState::Read
109 } else {
110 StreamState::NoRate
111 },
112 delay_time: controller.0.clone(),
113 }
114 }
115
116 pub fn new_with_rate_per_unit(stream: S, initial_rate_per_unit: usize, unit: Duration) -> (Self, RateController) {
118 let rc = RateController::new(initial_rate_per_unit, unit);
119 (Self::new_with_controller(stream, &rc), rc)
120 }
121}
122
123impl<S, T> futures::Stream for RateLimitedStream<S>
124where
125 S: futures::Stream<Item = T> + Unpin,
126{
127 type Item = T;
128
129 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
130 let mut this = self.project();
131
132 loop {
133 match this.state {
134 StreamState::Read => {
135 let yield_start = Instant::now();
136 if let Some(item) = futures::ready!(this.inner.as_mut().poll_next(cx)) {
137 *this.item = Some(item);
138 let delay_time = this.delay_time.load(Ordering::Relaxed);
139 if delay_time > 0 {
140 let wait = Duration::from_micros(delay_time)
141 .saturating_sub(yield_start.elapsed())
142 .max(RateController::MIN_DELAY);
143 *this.delay = Some(futures_time::task::sleep(wait.into()));
144 *this.state = StreamState::Wait;
145 } else {
146 *this.delay = Some(futures_time::task::sleep(Duration::from_millis(100).into()));
147 *this.state = StreamState::NoRate;
148 }
149 } else {
150 return Poll::Ready(None);
151 }
152 }
153 StreamState::NoRate => {
154 if let Some(mut delay) = this.delay.as_mut().as_pin_mut() {
155 let _ = futures::ready!(delay.as_mut().poll(cx));
156 }
157 let delay_time = this.delay_time.load(Ordering::Relaxed);
158 if delay_time > 0 {
159 *this.delay = Some(futures_time::task::sleep(Duration::from_micros(delay_time).into()));
160 if this.item.is_some() {
161 *this.state = StreamState::Wait;
162 } else {
163 *this.state = StreamState::Read;
164 }
165 } else {
166 *this.delay = Some(futures_time::task::sleep(Duration::from_millis(100).into()));
167 *this.state = StreamState::NoRate;
168 }
169 }
170 StreamState::Wait => {
171 if let Some(mut delay) = this.delay.as_mut().as_pin_mut() {
172 let _ = futures::ready!(delay.as_mut().poll(cx));
173 *this.state = StreamState::Read;
174 return Poll::Ready(this.item.take());
175 }
176 }
177 }
178 }
179 }
180}
181
182pub trait RateLimitStreamExt: futures::Stream + Sized {
184 fn rate_limit_per_unit(
192 self,
193 elements_per_unit: usize,
194 unit: Duration,
195 ) -> (RateLimitedStream<Self>, RateController) {
196 RateLimitedStream::new_with_rate_per_unit(self, elements_per_unit, unit)
197 }
198
199 fn rate_limit_with_controller(self, controller: &RateController) -> RateLimitedStream<Self> {
208 RateLimitedStream::new_with_controller(self, controller)
209 }
210}
211
212impl<S: futures::Stream + Sized> RateLimitStreamExt for S {}
213
/// Sink adapter that admits items no faster than the rate held in the shared
/// `delay_micros` value (see [`RateController`]), using send tokens: each
/// `start_send` consumes one token and `poll_ready` waits until one is available.
#[must_use = "sinks do nothing unless polled"]
#[pin_project]
pub struct RateLimitedSink<S> {
    /// The wrapped sink (structurally pinned).
    #[pin]
    inner: S,
    /// Per-element delay in microseconds, shared with the `RateController`
    /// this sink was built from; `0` means zero rate.
    delay_micros: Arc<AtomicU64>,
    /// Number of items that may currently be sent without waiting.
    tokens: u64,
    /// Reference instant from which new tokens are accrued; `None` until the
    /// first `poll_ready`.
    last_check: Option<Instant>,
    /// Timer armed while waiting for tokens or for a non-zero rate
    /// (structurally pinned).
    #[pin]
    sleep: Option<Sleep>,
    /// Current phase of `poll_ready`.
    state: SinkState,
}
233
/// Phase of `RateLimitedSink::poll_ready`.
enum SinkState {
    /// Check for a token; if one is available, poll the inner sink.
    Ready,
    /// No token is available: arm or poll a timer before checking again.
    Waiting,
}
238
239impl<S> RateLimitedSink<S> {
240 pub fn new_with_controller(inner: S, controller: &RateController) -> Self {
242 Self {
243 inner,
244 delay_micros: controller.0.clone(),
245 tokens: 0,
246 last_check: None,
247 sleep: None,
248 state: SinkState::Ready,
249 }
250 }
251
252 pub fn new_with_rate_per_unit(inner: S, initial_rate_per_unit: usize, unit: Duration) -> (Self, RateController) {
254 let rc = RateController::new(initial_rate_per_unit, unit);
255 (Self::new_with_controller(inner, &rc), rc)
256 }
257}
258
259impl<S, Item> futures::Sink<Item> for RateLimitedSink<S>
260where
261 S: futures::Sink<Item> + Unpin,
262{
263 type Error = S::Error;
264
265 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), S::Error>> {
266 let mut this = self.project();
267
268 loop {
269 let current_delay = this.delay_micros.load(Ordering::Relaxed);
270 let current_rate_limit = rate_from_delay(current_delay);
271
272 if let Some(last_check) = this.last_check.as_mut() {
273 *this.tokens += (current_rate_limit * last_check.elapsed().as_secs_f64()).round() as u64;
274 *last_check = Instant::now();
275 } else {
276 *this.last_check = Some(Instant::now());
278 }
279
280 match this.state {
281 SinkState::Ready => {
282 if *this.tokens > 0 {
283 futures::ready!(this.inner.as_mut().poll_ready(cx))?;
284
285 tracing::trace!(tokens = *this.tokens, "tokens left");
286 return Poll::Ready(Ok(()));
287 } else {
288 tracing::trace!("no tokens left");
289 *this.state = SinkState::Waiting;
290 }
291 }
292 SinkState::Waiting => {
293 if let Some(sleep) = this.sleep.as_mut().as_pin_mut() {
294 futures::ready!(sleep.poll(cx));
295 this.sleep.set(None);
296
297 tracing::trace!("waiting done");
298 *this.state = SinkState::Ready;
299 } else if current_delay > 0 {
300 *this.sleep = Some(futures_time::task::sleep(futures_time::time::Duration::from_micros(
302 current_delay,
303 )));
304 } else {
305 *this.sleep = Some(futures_time::task::sleep(futures_time::time::Duration::from_millis(50)));
307 }
308 }
309 }
310 }
311 }
312
313 fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), S::Error> {
314 let this = self.project();
315 if *this.tokens > 0 {
316 *this.tokens -= 1;
317 tracing::trace!("token consumed");
318 this.inner.start_send(item)
319 } else {
320 panic!("start_send called without poll_ready");
321 }
322 }
323
324 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), S::Error>> {
325 self.project().inner.poll_flush(cx)
326 }
327
328 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), S::Error>> {
329 self.project().inner.poll_close(cx)
330 }
331}
332
333impl<S: Clone> Clone for RateLimitedSink<S> {
334 fn clone(&self) -> Self {
335 Self {
336 inner: self.inner.clone(),
337 delay_micros: self.delay_micros.clone(),
338 tokens: 0,
339 last_check: None,
340 sleep: None,
341 state: SinkState::Ready,
342 }
343 }
344}
345
346pub trait RateLimitSinkExt<T>: futures::Sink<T> + Sized {
348 fn rate_limit_per_unit(self, elements_per_unit: usize, unit: Duration) -> (RateLimitedSink<Self>, RateController) {
356 RateLimitedSink::new_with_rate_per_unit(self, elements_per_unit, unit)
357 }
358
359 fn rate_limit_with_controller(self, controller: &RateController) -> RateLimitedSink<Self> {
368 RateLimitedSink::new_with_controller(self, controller)
369 }
370}
371
372impl<T, S: futures::Sink<T> + Sized> RateLimitSinkExt<T> for S {}
373
#[cfg(test)]
mod tests {
    use std::time::{Duration, Instant};

    use futures::{
        SinkExt, pin_mut,
        stream::{self, StreamExt},
    };
    use futures_time::future::FutureExt;

    use super::*;

    #[test]
    fn test_rate_controller_set_rate_per_unit() {
        let rc = RateController(Arc::new(AtomicU64::new(0)));
        rc.set_rate_per_unit(2500, 2 * Duration::from_secs(1));
        assert_eq!(rc.get_rate_per_sec(), 1250.0);
    }

    #[test]
    fn test_rate_from_delay() {
        // A zero delay encodes "zero rate".
        assert_eq!(rate_from_delay(0), 0.0);
        assert_eq!(rate_from_delay(1_000_000), 1.0);
        assert_eq!(rate_from_delay(500_000), 2.0);
    }

    #[test]
    fn test_rate_controller_clamps_to_min_delay_and_reports_zero_rate() {
        let rc = RateController::new(1_000_000, Duration::from_secs(1));
        assert_eq!(rc.get_delay_per_element(), RateController::MIN_DELAY);
        assert!(!rc.has_zero_rate());

        rc.set_rate_per_unit(0, Duration::from_secs(1));
        assert!(rc.has_zero_rate());
        assert_eq!(rc.get_rate_per_sec(), 0.0);
    }

    #[tokio::test]
    async fn test_rate_limited_stream_respects_rate() {
        let stream = stream::iter(1..=5);

        let (rate_limited, controller) = stream.rate_limit_per_unit(10, Duration::from_secs(1));

        assert_eq!(controller.get_rate_per_sec(), 10.0);

        let start = Instant::now();

        let items: Vec<i32> = rate_limited.collect().await;

        let elapsed = start.elapsed();

        assert_eq!(items, vec![1, 2, 3, 4, 5]);

        assert!(
            elapsed >= Duration::from_millis(300),
            "Stream completed too quickly: {elapsed:?}"
        );

        assert!(
            elapsed <= Duration::from_millis(600),
            "Stream completed too slowly: {elapsed:?}"
        );
    }

    #[tokio::test]
    async fn test_rate_limited_first_item_should_be_delayed() {
        let stream = stream::iter(1..=1);

        let (rate_limited, _) = stream.rate_limit_per_unit(1, Duration::from_millis(100));

        pin_mut!(rate_limited);

        let start = Instant::now();
        assert_eq!(Some(1), rate_limited.next().await);
        assert!(start.elapsed() >= Duration::from_millis(100));
    }

    #[tokio::test]
    async fn test_rate_limited_stream_empty() {
        let stream = stream::iter::<Vec<i32>>(vec![]);

        let (mut rate_limited, _) = stream.rate_limit_per_unit(10, Duration::from_secs(1));

        assert_eq!(rate_limited.next().await, None);
    }

    #[tokio::test]
    async fn test_rate_limited_stream_zero_rate() -> anyhow::Result<()> {
        let stream = stream::iter(1..=3);

        let (mut rate_limited, _) = stream.rate_limit_per_unit(0, Duration::from_millis(50));

        assert!(
            rate_limited
                .next()
                .timeout(futures_time::time::Duration::from_millis(100))
                .await
                .is_err(),
            "zero rate should not yield anything"
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_rate_limited_stream_should_pause_on_zero_rate() -> anyhow::Result<()> {
        let stream = stream::iter(1..=3);

        let (mut rate_limited, controller) = stream.rate_limit_per_unit(1, Duration::from_millis(100));

        let start = Instant::now();
        assert_eq!(Some(1), rate_limited.next().await);
        let elapsed = start.elapsed();
        assert!(
            elapsed >= Duration::from_millis(100),
            "first element too fast {elapsed:?}"
        );

        let start = Instant::now();
        assert_eq!(Some(2), rate_limited.next().await);
        let elapsed = start.elapsed();
        assert!(
            elapsed >= Duration::from_millis(100),
            "second element too fast {elapsed:?}"
        );

        controller.set_rate_per_unit(0, Duration::from_millis(100));

        assert!(
            rate_limited
                .next()
                .timeout(futures_time::time::Duration::from_millis(200))
                .await
                .is_err(),
            "zero rate should not yield anything"
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_rate_limited_stream_zero_rate_should_restart_when_increased() -> anyhow::Result<()> {
        let stream = stream::iter(1..=3);

        let (mut rate_limited, controller) = stream.rate_limit_per_unit(0, Duration::from_secs(1));

        assert!(
            rate_limited
                .next()
                .timeout(futures_time::time::Duration::from_millis(100))
                .await
                .is_err(),
            "zero rate should not yield anything"
        );

        controller.set_rate_per_unit(1, Duration::from_millis(100));

        let start = Instant::now();
        let all_items = rate_limited.collect::<Vec<_>>().await;
        let all_items_elapsed = start.elapsed();

        assert_eq!(all_items, vec![1, 2, 3]);
        assert!(
            all_items_elapsed >= Duration::from_millis(300),
            "all items should have been yielded in at least 300ms instead {all_items_elapsed:?}"
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_rate_changing_during_stream() {
        let stream = stream::iter(1..=6);

        let (mut rate_limited, controller) = stream.rate_limit_per_unit(5, Duration::from_secs(1));

        let start = Instant::now();
        for i in 1..=3 {
            assert_eq!(Some(i), rate_limited.next().await);
        }

        let first_half_elapsed = start.elapsed();

        controller.set_rate_per_unit(2, Duration::from_secs(1));

        for i in 4..=6 {
            assert_eq!(Some(i), rate_limited.next().await);
        }

        let total_elapsed = start.elapsed();
        let second_half_elapsed = total_elapsed - first_half_elapsed;

        assert!(
            first_half_elapsed >= Duration::from_millis(600),
            "First half too fast: {first_half_elapsed:?}"
        );
        assert!(
            first_half_elapsed <= Duration::from_millis(700),
            "First half too slow: {first_half_elapsed:?}"
        );

        assert!(
            second_half_elapsed >= Duration::from_millis(1500),
            "Second half too fast: {second_half_elapsed:?}"
        );
        assert!(
            second_half_elapsed <= Duration::from_millis(1600),
            "Second half too slow: {second_half_elapsed:?}"
        );
    }

    #[tokio::test]
    async fn test_very_high_rate() {
        let stream = stream::iter(1..=100);

        let (rate_limited, _) = stream.rate_limit_per_unit(1000, Duration::from_secs(1));

        let start = Instant::now();

        let items: Vec<i32> = rate_limited.collect().await;

        let elapsed = start.elapsed();

        assert_eq!(items.len(), 100);
        assert_eq!(items.first(), Some(&1));
        assert_eq!(items.last(), Some(&100));

        assert!(
            elapsed >= Duration::from_millis(90),
            "Very high rate stream finished too quickly: {elapsed:?}"
        );

        assert!(
            elapsed < Duration::from_millis(500),
            "Very high rate stream took too long: {elapsed:?}"
        );
    }

    #[tokio::test]
    async fn test_concurrent_rate_change() {
        use futures::future::join;

        let stream = stream::iter(1..=10);

        let (mut rate_limited, controller) = stream.rate_limit_per_unit(2, Duration::from_secs(1));

        let stream_task = async {
            let mut count = 0;
            let mut items = Vec::new();

            while let Some(item) = rate_limited.next().await {
                items.push(item);
                count += 1;

                if count == 3 {
                    futures_time::task::sleep(Duration::from_millis(50).into()).await;
                }
            }

            items
        };

        let rate_change_task = async move {
            futures_time::task::sleep(Duration::from_millis(100).into()).await;

            controller.set_rate_per_unit(20, Duration::from_secs(1));

            // Report what the controller actually holds, rather than a constant.
            controller.get_rate_per_sec()
        };

        let (items, new_rate) = join(stream_task, rate_change_task).await;

        assert_eq!(items.len(), 10, "Should have received all 10 items");
        assert_eq!(new_rate, 20.0, "Rate should have been changed to 20");
    }

    #[test_log::test(tokio::test)]
    async fn rate_limited_sink_should_respect_rate() -> anyhow::Result<()> {
        let (tx, rx) = futures::channel::mpsc::unbounded::<i32>();
        let (tx, _) = tx.rate_limit_per_unit(5, Duration::from_millis(100));

        let input = (0..10).collect::<Vec<_>>();

        let start = Instant::now();
        pin_mut!(tx);
        tx.send_all(&mut futures::stream::iter(input.clone()).map(Ok))
            .timeout(futures_time::time::Duration::from_millis(500))
            .await??;

        let elapsed = start.elapsed();
        assert!(
            elapsed >= Duration::from_millis(200),
            "sending took too little {elapsed:?}"
        );

        tx.close().await?;

        let collected = rx.collect::<Vec<_>>().await;
        assert_eq!(input, collected);

        Ok(())
    }

    #[test_log::test(tokio::test)]
    async fn rate_limited_sink_should_replenish_when_idle() -> anyhow::Result<()> {
        let (tx, rx) = futures::channel::mpsc::unbounded::<i32>();
        let (tx, _) = tx.rate_limit_per_unit(5, Duration::from_millis(100));

        pin_mut!(tx);

        let input = (0..5).collect::<Vec<_>>();

        let start = Instant::now();
        tx.send_all(&mut futures::stream::iter(input.clone()).map(Ok))
            .timeout(futures_time::time::Duration::from_millis(500))
            .await??;

        let elapsed = start.elapsed();
        assert!(
            elapsed < Duration::from_millis(120),
            "sending took too much {elapsed:?}"
        );

        tokio::time::sleep(Duration::from_millis(100)).await;

        let start = Instant::now();
        tx.send_all(&mut futures::stream::iter(input.clone()).map(Ok))
            .timeout(futures_time::time::Duration::from_millis(500))
            .await??;

        let elapsed = start.elapsed();
        assert!(
            elapsed < Duration::from_millis(120),
            "sending took too much {elapsed:?}"
        );

        tx.close().await?;

        let collected = rx.collect::<Vec<_>>().await;
        assert_eq!(input.into_iter().cycle().take(10).collect::<Vec<_>>(), collected);

        Ok(())
    }

    #[test_log::test(tokio::test)]
    async fn rate_limited_sink_should_not_send_when_zero_rate() -> anyhow::Result<()> {
        let (tx, _) = futures::channel::mpsc::unbounded::<i32>();
        let (tx, _) = tx.rate_limit_per_unit(0, Duration::from_millis(100));

        pin_mut!(tx);

        assert!(
            tx.send(1i32)
                .timeout(futures_time::time::Duration::from_millis(100))
                .await
                .is_err()
        );
        Ok(())
    }

    #[test_log::test(tokio::test)]
    async fn rate_limited_sink_should_recover_after_rate_is_increased() -> anyhow::Result<()> {
        let (tx, rx) = futures::channel::mpsc::unbounded::<i32>();
        let (tx, ctl) = tx.rate_limit_per_unit(0, Duration::from_millis(100));

        pin_mut!(tx);

        assert!(
            tx.send(1i32)
                .timeout(futures_time::time::Duration::from_millis(100))
                .await
                .is_err()
        );

        ctl.set_rate_per_unit(10, Duration::from_millis(10));

        tx.send(2i32)
            .timeout(futures_time::time::Duration::from_millis(100))
            .await??;

        pin_mut!(rx);
        assert_eq!(
            Some(2),
            rx.next()
                .timeout(futures_time::time::Duration::from_millis(100))
                .await?
        );

        Ok(())
    }
}