hopr_network_types/session/
utils.rs

1use std::pin::Pin;
2use std::sync::atomic::{AtomicUsize, Ordering};
3use std::sync::Arc;
4use std::task::{Context, Poll};
5use std::time::{Duration, Instant};
6
7use futures::channel::mpsc::UnboundedSender;
8use futures::stream::BoxStream;
9use futures::{AsyncRead, AsyncWrite, StreamExt};
10use rand::distributions::Bernoulli;
11use rand::prelude::{thread_rng, Distribution, Rng, SeedableRng, StdRng};
12
/// Bookkeeping for a retry schedule with exponential backoff.
///
/// The next retry is due `jitter * base * backoff_base ^ num_retry` after `started_at`;
/// see [`RetryToken::check`], which also decides when the schedule expires.
#[derive(Debug, Clone, Copy, PartialEq)]
pub(crate) struct RetryToken {
    /// Number of retries granted since `started_at` (incremented on every `RetryResult::RetryNow`).
    pub num_retry: usize,
    /// Reference instant from which retry delays are measured; reset by `replenish`.
    pub started_at: Instant,
    /// Base of the exponential backoff (raised to the power of `num_retry`).
    backoff_base: f64,
    /// Instant at which the token was first created; carried over by `replenish` and `check`.
    created_at: Instant,
}
20
/// Outcome of [`RetryToken::check`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub(crate) enum RetryResult {
    /// The next retry is not due yet; the payload is the remaining time until it is.
    Wait(Duration),
    /// A retry is due now; the payload is the token with its retry counter incremented.
    RetryNow(RetryToken),
    /// The next retry delay is not below the configured maximum; no further retries.
    Expired,
}
27
28impl RetryToken {
29    pub fn new(now: Instant, backoff_base: f64) -> Self {
30        Self {
31            num_retry: 0,
32            started_at: now,
33            created_at: Instant::now(),
34            backoff_base,
35        }
36    }
37
38    pub fn replenish(self, now: Instant, backoff_base: f64) -> Self {
39        Self {
40            num_retry: 0,
41            started_at: now,
42            created_at: self.created_at,
43            backoff_base,
44        }
45    }
46
47    fn retry_in(&self, base: Duration, max_duration: Duration, jitter_dev: f64) -> Option<Duration> {
48        let jitter_coeff = if jitter_dev > 0.0 {
49            // Should not use jitter with sigma > 0.25
50            rand_distr::Normal::new(1.0, jitter_dev.min(0.25))
51                .unwrap()
52                .sample(&mut thread_rng())
53                .abs()
54        } else {
55            1.0
56        };
57
58        // jitter * base * backoff_base ^ num_retry
59        let duration = base.mul_f64(jitter_coeff * self.backoff_base.powi(self.num_retry as i32));
60        (duration < max_duration).then_some(duration)
61    }
62
63    pub fn check(&self, now: Instant, base: Duration, max: Duration, jitter_dev: f64) -> RetryResult {
64        match self.retry_in(base, max, jitter_dev) {
65            None => RetryResult::Expired,
66            Some(retry_in) if self.started_at + retry_in >= now => RetryResult::Wait(self.started_at + retry_in - now),
67            _ => RetryResult::RetryNow(Self {
68                num_retry: self.num_retry + 1,
69                started_at: self.started_at,
70                backoff_base: self.backoff_base,
71                created_at: self.created_at,
72            }),
73        }
74    }
75
76    pub fn time_since_creation(&self) -> Duration {
77        self.created_at.elapsed()
78    }
79}
80
/// Configuration of the faults injected by [`FaultyNetwork`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct FaultyNetworkConfig {
    /// Probability that a packet is dropped; `0.0` keeps every packet.
    /// Must lie within `[0, 1]`, otherwise `FaultyNetwork::new` panics.
    pub fault_prob: f64,
    /// Number of packets that are delayed concurrently (by a random 0-19 µs each) and released
    /// in completion order, which may reorder them; `0` disables mixing.
    pub mixing_factor: usize,
    /// Seed for the `StdRng` instance(s) that make the pseudo-random decisions.
    pub rng_seed: [u8; 32],
}
87
/// Traffic counters of a [`FaultyNetwork`].
///
/// Every counter sits behind an `Arc`, so clones of this struct share the same counters:
/// keep one clone and pass another to `FaultyNetwork::new` to observe the traffic.
#[derive(Clone, Debug, Default)]
pub struct NetworkStats {
    /// Number of packets successfully written into the network.
    pub packets_sent: Arc<AtomicUsize>,
    /// Number of packets handed out to readers (i.e. those that were not dropped).
    pub packets_received: Arc<AtomicUsize>,
    /// Total number of bytes successfully written.
    pub bytes_sent: Arc<AtomicUsize>,
    /// Total number of bytes copied into read buffers (after truncation to the buffer size).
    pub bytes_received: Arc<AtomicUsize>,
}
95
96impl Default for FaultyNetworkConfig {
97    fn default() -> Self {
98        Self {
99            fault_prob: 0.0,
100            mixing_factor: 0,
101            rng_seed: [
102                0xd8, 0xa4, 0x71, 0xf1, 0xc2, 0x04, 0x90, 0xa3, 0x44, 0x2b, 0x96, 0xfd, 0xde, 0x9d, 0x18, 0x07, 0x42,
103                0x80, 0x96, 0xe1, 0x60, 0x1b, 0x0c, 0xef, 0x0e, 0xea, 0x7e, 0x6d, 0x44, 0xa2, 0x4c, 0x01,
104            ],
105        }
106    }
107}
108
/// Network simulator used for testing.
///
/// Each write via [`AsyncWrite`] becomes one packet on an in-memory channel; packets come out
/// via [`AsyncRead`], possibly dropped or reordered as configured by [`FaultyNetworkConfig`].
///
/// `C` is the maximum number of bytes a single write may carry.
pub struct FaultyNetwork<'a, const C: usize> {
    /// Sending half of the channel; every successful write pushes one packet.
    ingress: UnboundedSender<Box<[u8]>>,
    /// Receiving side with packet dropping and optional mixing applied.
    egress: BoxStream<'a, Box<[u8]>>,
    /// Optional shared counters, updated on successful writes and reads.
    stats: Option<NetworkStats>,
}
115
116impl<const C: usize> AsyncWrite for FaultyNetwork<'_, C> {
117    fn poll_write(self: Pin<&mut Self>, _cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
118        if buf.len() > C {
119            return Poll::Ready(Err(std::io::Error::new(
120                std::io::ErrorKind::InvalidInput,
121                format!("data length passed to downstream must be less or equal to {C}"),
122            )));
123        }
124
125        if let Err(e) = self.ingress.unbounded_send(buf.into()) {
126            return Poll::Ready(Err(std::io::Error::new(
127                std::io::ErrorKind::BrokenPipe,
128                format!("failed to send data: {e}"),
129            )));
130        }
131
132        if let Some(stats) = &self.stats {
133            stats.bytes_sent.fetch_add(buf.len(), Ordering::Relaxed);
134            stats.packets_sent.fetch_add(1, Ordering::Relaxed);
135        }
136
137        tracing::trace!("FaultyNetwork::poll_write {} bytes", buf.len());
138        Poll::Ready(Ok(buf.len()))
139    }
140
141    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
142        Poll::Ready(Ok(()))
143    }
144
145    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
146        self.ingress.close_channel();
147        Poll::Ready(Ok(()))
148    }
149}
150
151impl<const C: usize> AsyncRead for FaultyNetwork<'_, C> {
152    fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<std::io::Result<usize>> {
153        match self.egress.poll_next_unpin(cx) {
154            Poll::Ready(Some(item)) => {
155                let len = buf.len().min(item.len());
156                buf[..len].copy_from_slice(&item.as_ref()[..len]);
157
158                if let Some(stats) = &self.stats {
159                    stats.bytes_received.fetch_add(len, Ordering::Relaxed);
160                    stats.packets_received.fetch_add(1, Ordering::Relaxed);
161                }
162
163                tracing::trace!("FaultyNetwork::poll_read: {len} bytes ready");
164                Poll::Ready(Ok(len))
165            }
166            Poll::Ready(None) => Poll::Ready(Ok(0)),
167            Poll::Pending => Poll::Pending,
168        }
169    }
170}
171
172impl<const C: usize> FaultyNetwork<'_, C> {
173    #[allow(dead_code)]
174    pub fn new(cfg: FaultyNetworkConfig, stats: Option<NetworkStats>) -> Self {
175        let (ingress, egress) = futures::channel::mpsc::unbounded::<Box<[u8]>>();
176
177        let mut rng = StdRng::from_seed(cfg.rng_seed);
178        let bernoulli = Bernoulli::new(1.0 - cfg.fault_prob).unwrap();
179        let egress = egress.filter(move |_| futures::future::ready(bernoulli.sample(&mut rng)));
180
181        let egress = if cfg.mixing_factor > 0 {
182            let mut rng = StdRng::from_seed(cfg.rng_seed);
183            egress
184                .map(move |e| {
185                    let wait = rng.gen_range(0..20);
186                    async move {
187                        hopr_async_runtime::prelude::sleep(Duration::from_micros(wait)).await;
188                        e
189                    }
190                })
191                .buffer_unordered(cfg.mixing_factor)
192                .boxed()
193        } else {
194            egress.boxed()
195        };
196
197        Self { ingress, egress, stats }
198    }
199}
200
#[cfg(test)]
mod tests {
    use super::*;
    use futures::io::{AsyncReadExt, AsyncWriteExt};
    use std::future::Future;

    /// Splits `channel` and spawns two tasks.
    ///
    /// - The writer sends `data` one byte per write, then closes the write half.
    /// - The reader reads single bytes until `data.len()` bytes arrived or EOF is hit.
    ///
    /// Returns the spawned tasks as futures of the (read, written) byte sequences.
    fn spawn_single_byte_read_write<C>(
        channel: C,
        data: Vec<u8>,
    ) -> (impl Future<Output = Vec<u8>>, impl Future<Output = Vec<u8>>)
    where
        C: AsyncRead + AsyncWrite + Unpin + Send + 'static,
    {
        let (mut recv, mut send) = channel.split();

        let len = data.len();
        let read = async_std::task::spawn(async move {
            let mut out = Vec::with_capacity(len);
            for _ in 0..len {
                let mut bytes = [0u8; 1];
                if recv.read(&mut bytes).await.unwrap() > 0 {
                    out.push(bytes[0]);
                } else {
                    break;
                }
            }
            out
        });

        let written = async_std::task::spawn(async move {
            let mut out = Vec::with_capacity(len);
            for byte in data {
                // `write_all` rather than `write`: the latter may report a short write, and its
                // returned byte count was silently ignored (clippy::unused_io_amount).
                send.write_all(&[byte]).await.unwrap();
                out.push(byte);
            }
            send.close().await.unwrap();
            out
        });

        (read, written)
    }

    #[async_std::test]
    async fn faulty_network_mixing() {
        const MIX_FACTOR: usize = 2;
        const COUNT: usize = 20;

        let net = FaultyNetwork::<466>::new(
            FaultyNetworkConfig {
                mixing_factor: MIX_FACTOR,
                ..Default::default()
            },
            None,
        );

        let (read, written) = spawn_single_byte_read_write(net, (0..COUNT as u8).collect());
        let (read, _) = futures::future::join(read, written).await;

        for (pos, value) in read.into_iter().enumerate() {
            assert!(
                pos.abs_diff(value as usize) <= MIX_FACTOR,
                "packet must not be off from its position by more than the mixing factor"
            );
        }
    }

    #[async_std::test]
    async fn faulty_network_packet_drop() {
        const DROP: f64 = 0.3333;
        const COUNT: usize = 20;

        let net = FaultyNetwork::<466>::new(
            FaultyNetworkConfig {
                fault_prob: DROP,
                ..Default::default()
            },
            None,
        );

        let (read, written) = spawn_single_byte_read_write(net, (0..COUNT as u8).collect());
        let (read, written) = futures::future::join(read, written).await;

        // Roughly `(1 - DROP)` of the packets should arrive; allow a slack of 2 packets.
        let min_received = (written.len() as f64 * (1.0 - DROP) - 2.0).floor() as usize;
        assert!(
            read.len() >= min_received,
            "received fewer than {min_received} packets: {}",
            read.len()
        );
    }

    #[async_std::test]
    async fn faulty_network_reliable() {
        const COUNT: usize = 20;

        let net = FaultyNetwork::<466>::new(Default::default(), None);

        let (read, written) = spawn_single_byte_read_write(net, (0..COUNT as u8).collect());
        let (read, written) = futures::future::join(read, written).await;

        assert_eq!(read, written);
    }
}