hopr_network_types/session/utils.rs

use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::{Duration, Instant};

use futures::channel::mpsc::UnboundedSender;
use futures::stream::BoxStream;
use futures::{AsyncRead, AsyncWrite, StreamExt};
use rand::distributions::Bernoulli;
use rand::prelude::{thread_rng, Distribution, Rng, SeedableRng, StdRng};

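/// Bookkeeping for a single retransmission sequence: how many retries have been made,
/// when the current sequence was (re)started, the base of the exponential backoff and
/// the instant the token was first created.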
#[derive(Debug, Clone, Copy, PartialEq)]
pub(crate) struct RetryToken {
    pub num_retry: usize,
    pub started_at: Instant,
    backoff_base: f64,
    created_at: Instant,
}

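/// Outcome of [`RetryToken::check`]: keep waiting for the given duration, retry now
/// (carrying the token with an incremented retry counter), or give up because the
/// backoff delay has grown past the allowed maximum.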
#[derive(Debug, Clone, Copy, PartialEq)]
pub(crate) enum RetryResult {
    Wait(Duration),
    RetryNow(RetryToken),
    Expired,
}

impl RetryToken {
    pub fn new(now: Instant, backoff_base: f64) -> Self {
        Self {
            num_retry: 0,
            started_at: now,
            created_at: Instant::now(),
            backoff_base,
        }
    }

    pub fn replenish(self, now: Instant, backoff_base: f64) -> Self {
        Self {
            num_retry: 0,
            started_at: now,
            created_at: self.created_at,
            backoff_base,
        }
    }

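    /// Computes the delay before the next retry as
    /// `base * jitter * backoff_base ^ num_retry`, where `jitter` is sampled from a
    /// normal distribution centered at 1.0 (standard deviation capped at 0.25) when
    /// `jitter_dev > 0.0`. Returns `None` once the delay would reach `max_duration`.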
    fn retry_in(&self, base: Duration, max_duration: Duration, jitter_dev: f64) -> Option<Duration> {
        let jitter_coeff = if jitter_dev > 0.0 {
            rand_distr::Normal::new(1.0, jitter_dev.min(0.25))
                .unwrap()
                .sample(&mut thread_rng())
                .abs()
        } else {
            1.0
        };

        let duration = base.mul_f64(jitter_coeff * self.backoff_base.powi(self.num_retry as i32));
        (duration < max_duration).then_some(duration)
    }

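    /// Checks the token against `now`: returns [`RetryResult::Wait`] with the remaining
    /// time if the next retry is not due yet, [`RetryResult::RetryNow`] with the retry
    /// counter incremented if it is, or [`RetryResult::Expired`] once the backoff delay
    /// has exceeded `max`.
    ///
    /// A rough sketch of the intended polling loop (the surrounding names are illustrative only):
    ///
    /// ```ignore
    /// match token.check(Instant::now(), base, max, 0.05) {
    ///     RetryResult::Wait(d) => sleep(d).await,
    ///     RetryResult::RetryNow(next) => { token = next; resend().await },
    ///     RetryResult::Expired => return Err(timeout_error()),
    /// }
    /// ```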
    pub fn check(&self, now: Instant, base: Duration, max: Duration, jitter_dev: f64) -> RetryResult {
        match self.retry_in(base, max, jitter_dev) {
            None => RetryResult::Expired,
            Some(retry_in) if self.started_at + retry_in >= now => RetryResult::Wait(self.started_at + retry_in - now),
            _ => RetryResult::RetryNow(Self {
                num_retry: self.num_retry + 1,
                started_at: self.started_at,
                backoff_base: self.backoff_base,
                created_at: self.created_at,
            }),
        }
    }

    pub fn time_since_creation(&self) -> Duration {
        self.created_at.elapsed()
    }
}

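/// Configuration of a [`FaultyNetwork`]: the probability of dropping a packet,
/// the amount of packet reordering (up to `mixing_factor` packets may be buffered and
/// delivered out of order), and the RNG seed that makes the simulated faults reproducible.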
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct FaultyNetworkConfig {
    pub fault_prob: f64,
    pub mixing_factor: usize,
    pub rng_seed: [u8; 32],
}

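/// Shared atomic counters of packets and bytes sent to and received from a [`FaultyNetwork`].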
#[derive(Clone, Debug, Default)]
pub struct NetworkStats {
    pub packets_sent: Arc<AtomicUsize>,
    pub packets_received: Arc<AtomicUsize>,
    pub bytes_sent: Arc<AtomicUsize>,
    pub bytes_received: Arc<AtomicUsize>,
}

impl Default for FaultyNetworkConfig {
    fn default() -> Self {
        Self {
            fault_prob: 0.0,
            mixing_factor: 0,
            rng_seed: [
                0xd8, 0xa4, 0x71, 0xf1, 0xc2, 0x04, 0x90, 0xa3, 0x44, 0x2b, 0x96, 0xfd, 0xde, 0x9d, 0x18, 0x07, 0x42,
                0x80, 0x96, 0xe1, 0x60, 0x1b, 0x0c, 0xef, 0x0e, 0xea, 0x7e, 0x6d, 0x44, 0xa2, 0x4c, 0x01,
            ],
        }
    }
}

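/// In-memory loopback "network" used for testing: everything written to it is looped back
/// to the reader, optionally dropping and reordering packets according to
/// [`FaultyNetworkConfig`]. The const parameter `C` bounds the size of a single write.
///
/// A minimal usage sketch (types as defined in this module):
///
/// ```ignore
/// let net = FaultyNetwork::<466>::new(FaultyNetworkConfig::default(), None);
/// // The network implements AsyncRead + AsyncWrite and can be split into halves.
/// let (reader, writer) = net.split();
/// ```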
pub struct FaultyNetwork<'a, const C: usize> {
    ingress: UnboundedSender<Box<[u8]>>,
    egress: BoxStream<'a, Box<[u8]>>,
    stats: Option<NetworkStats>,
}

impl<const C: usize> AsyncWrite for FaultyNetwork<'_, C> {
    fn poll_write(self: Pin<&mut Self>, _cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
        if buf.len() > C {
            return Poll::Ready(Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                format!("data length passed to downstream must be less than or equal to {C}"),
            )));
        }

        if let Err(e) = self.ingress.unbounded_send(buf.into()) {
            return Poll::Ready(Err(std::io::Error::new(
                std::io::ErrorKind::BrokenPipe,
                format!("failed to send data: {e}"),
            )));
        }

        if let Some(stats) = &self.stats {
            stats.bytes_sent.fetch_add(buf.len(), Ordering::Relaxed);
            stats.packets_sent.fetch_add(1, Ordering::Relaxed);
        }

        tracing::trace!("FaultyNetwork::poll_write {} bytes", buf.len());
        Poll::Ready(Ok(buf.len()))
    }

    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        Poll::Ready(Ok(()))
    }

    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        self.ingress.close_channel();
        Poll::Ready(Ok(()))
    }
}

impl<const C: usize> AsyncRead for FaultyNetwork<'_, C> {
    fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<std::io::Result<usize>> {
        match self.egress.poll_next_unpin(cx) {
            Poll::Ready(Some(item)) => {
                let len = buf.len().min(item.len());
                buf[..len].copy_from_slice(&item.as_ref()[..len]);

                if let Some(stats) = &self.stats {
                    stats.bytes_received.fetch_add(len, Ordering::Relaxed);
                    stats.packets_received.fetch_add(1, Ordering::Relaxed);
                }

                tracing::trace!("FaultyNetwork::poll_read: {len} bytes ready");
                Poll::Ready(Ok(len))
            }
            Poll::Ready(None) => Poll::Ready(Ok(0)),
            Poll::Pending => Poll::Pending,
        }
    }
}

impl<const C: usize> FaultyNetwork<'_, C> {
    /// Creates a new instance with the given configuration and optional statistics counters.
    #[allow(dead_code)]
    pub fn new(cfg: FaultyNetworkConfig, stats: Option<NetworkStats>) -> Self {
        let (ingress, egress) = futures::channel::mpsc::unbounded::<Box<[u8]>>();

        // Drop each packet with probability `fault_prob`, using a seeded RNG for reproducibility.
        let mut rng = StdRng::from_seed(cfg.rng_seed);
        let bernoulli = Bernoulli::new(1.0 - cfg.fault_prob).unwrap();
        let egress = egress.filter(move |_| futures::future::ready(bernoulli.sample(&mut rng)));

        // Optionally reorder packets by delaying each one by a few random microseconds
        // and polling up to `mixing_factor` of these delays concurrently.
        let egress = if cfg.mixing_factor > 0 {
            let mut rng = StdRng::from_seed(cfg.rng_seed);
            egress
                .map(move |e| {
                    let wait = rng.gen_range(0..20);
                    async move {
                        hopr_async_runtime::prelude::sleep(Duration::from_micros(wait)).await;
                        e
                    }
                })
                .buffer_unordered(cfg.mixing_factor)
                .boxed()
        } else {
            egress.boxed()
        };

        Self { ingress, egress, stats }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use futures::io::{AsyncReadExt, AsyncWriteExt};
    use std::future::Future;

    fn spawn_single_byte_read_write<C>(
        channel: C,
        data: Vec<u8>,
    ) -> (impl Future<Output = Vec<u8>>, impl Future<Output = Vec<u8>>)
    where
        C: AsyncRead + AsyncWrite + Unpin + Send + 'static,
    {
        let (mut recv, mut send) = channel.split();

        let len = data.len();
        let read = async_std::task::spawn(async move {
            let mut out = Vec::with_capacity(len);
            for _ in 0..len {
                let mut bytes = [0u8; 1];
                if recv.read(&mut bytes).await.unwrap() > 0 {
                    out.push(bytes[0]);
                } else {
                    break;
                }
            }
            out
        });

        let written = async_std::task::spawn(async move {
            let mut out = Vec::with_capacity(len);
            for byte in data {
                send.write(&[byte]).await.unwrap();
                out.push(byte);
            }
            send.close().await.unwrap();
            out
        });

        (read, written)
    }

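    // Verifies that with mixing enabled, no packet ends up further from its original
    // position than the configured mixing factor.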
    #[async_std::test]
    async fn faulty_network_mixing() {
        const MIX_FACTOR: usize = 2;
        const COUNT: usize = 20;

        let net = FaultyNetwork::<466>::new(
            FaultyNetworkConfig {
                mixing_factor: MIX_FACTOR,
                ..Default::default()
            },
            None,
        );

        let (read, written) = spawn_single_byte_read_write(net, (0..COUNT as u8).collect());
        let (read, _) = futures::future::join(read, written).await;

        for (pos, value) in read.into_iter().enumerate() {
            assert!(
                pos.abs_diff(value as usize) <= MIX_FACTOR,
                "packet must not be off from its position by more than the mixing factor"
            );
        }
    }

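    // Verifies that with a non-zero fault probability, the number of packets that make it
    // through stays within the expected bound (with a small tolerance).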
    #[async_std::test]
    async fn faulty_network_packet_drop() {
        const DROP: f64 = 0.3333;
        const COUNT: usize = 20;

        let net = FaultyNetwork::<466>::new(
            FaultyNetworkConfig {
                fault_prob: DROP,
                ..Default::default()
            },
            None,
        );

        let (read, written) = spawn_single_byte_read_write(net, (0..COUNT as u8).collect());
        let (read, written) = futures::future::join(read, written).await;

        let min_kept = (written.len() as f64 * (1.0 - DROP) - 2.0).floor() as usize;
        assert!(read.len() >= min_kept, "received fewer than {min_kept} packets: {}", read.len());
    }

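    // Verifies that with the default (fault-free) configuration, all data arrives intact and in order.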
    #[async_std::test]
    async fn faulty_network_reliable() {
        const COUNT: usize = 20;

        let net = FaultyNetwork::<466>::new(Default::default(), None);

        let (read, written) = spawn_single_byte_read_write(net, (0..COUNT as u8).collect());
        let (read, written) = futures::future::join(read, written).await;

        assert_eq!(read, written);
    }
}