hopr_network_types/
utils.rs

1use std::{
2    fmt::{Debug, Display, Formatter},
3    hash::{Hash, Hasher},
4    net::SocketAddr,
5    pin::Pin,
6    task::{Context, Poll},
7};
8
9use futures::io::{AsyncRead, AsyncWrite};
10
/// Combines a reader half and a writer half into one full-duplex I/O object.
///
/// Field `.0` is the read side (used for [`futures::AsyncRead`]) and field `.1` is the write
/// side (used for [`futures::AsyncWrite`]).
pub struct DuplexIO<R, W>(
    /// Read half.
    pub R,
    /// Write half.
    pub W,
);
13
14impl<R, W> From<(R, W)> for DuplexIO<R, W>
15where
16    R: AsyncRead,
17    W: AsyncWrite,
18{
19    fn from(value: (R, W)) -> Self {
20        Self(value.0, value.1)
21    }
22}
23
24impl<R, W> AsyncRead for DuplexIO<R, W>
25where
26    R: AsyncRead + Unpin,
27    W: AsyncWrite + Unpin,
28{
29    fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<std::io::Result<usize>> {
30        let this = self.get_mut();
31        Pin::new(&mut this.0).poll_read(cx, buf)
32    }
33}
34
35impl<R, W> AsyncWrite for DuplexIO<R, W>
36where
37    R: AsyncRead + Unpin,
38    W: AsyncWrite + Unpin,
39{
40    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
41        let this = self.get_mut();
42        Pin::new(&mut this.1).poll_write(cx, buf)
43    }
44
45    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
46        let this = self.get_mut();
47        Pin::new(&mut this.1).poll_flush(cx)
48    }
49
50    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
51        let this = self.get_mut();
52        Pin::new(&mut this.1).poll_close(cx)
53    }
54}
55
/// Upper bound on the length of the `Display` form of any [`SocketAddr`].
///
/// The longest case is a scoped IPv6 socket address such as
/// `[ffff:ffff:ffff:ffff:ffff:ffff:255.255.255.255%4294967295]:65535`. Its length is
/// `'['` (1) + textual IPv6 (<= 45) + `'%'` (1) + `u32` scope id (<= 10) + `']'` (1)
/// + `':'` (1) + port (<= 5) = 64.
/// IPv4 socket addresses are at most 21 characters.
const SOCKET_ADDRESS_MAX_LEN: usize = 64;
58
59/// Caches the string representation of a SocketAddr for fast conversion to `&str`
60#[derive(Copy, Clone)]
61pub(crate) struct SocketAddrStr(SocketAddr, arrayvec::ArrayString<SOCKET_ADDRESS_MAX_LEN>);
62
63impl SocketAddrStr {
64    #[allow(dead_code)]
65    pub fn as_str(&self) -> &str {
66        self.1.as_str()
67    }
68}
69
70impl AsRef<SocketAddr> for SocketAddrStr {
71    fn as_ref(&self) -> &SocketAddr {
72        &self.0
73    }
74}
75
76impl From<SocketAddr> for SocketAddrStr {
77    fn from(value: SocketAddr) -> Self {
78        let mut cached = value.to_string();
79        cached.truncate(SOCKET_ADDRESS_MAX_LEN);
80        Self(value, cached.parse().expect("cannot fail due to truncation"))
81    }
82}
83
84impl PartialEq for SocketAddrStr {
85    fn eq(&self, other: &Self) -> bool {
86        self.0 == other.0
87    }
88}
89
90impl Eq for SocketAddrStr {}
91
92impl Debug for SocketAddrStr {
93    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
94        write!(f, "{}", self.1)
95    }
96}
97
98impl Display for SocketAddrStr {
99    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
100        write!(f, "{}", self.1)
101    }
102}
103
104impl PartialEq<SocketAddrStr> for SocketAddr {
105    fn eq(&self, other: &SocketAddrStr) -> bool {
106        self.eq(&other.0)
107    }
108}
109
110impl Hash for SocketAddrStr {
111    fn hash<H: Hasher>(&self, state: &mut H) {
112        self.0.hash(state);
113    }
114}
115
116#[cfg(feature = "runtime-tokio")]
117pub use tokio_utils::copy_duplex;
118
#[cfg(feature = "runtime-tokio")]
mod tokio_utils {
    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

    use super::*;

    /// State of a single copy direction, from one side's reader into the other side's writer.
    #[derive(Debug)]
    enum TransferState {
        /// Data is still being copied through the buffer.
        Running(CopyBuffer),
        /// Copying has stopped and the writer is being shut down. Holds the number of bytes copied.
        ShuttingDown(u64),
        /// The writer has been shut down. Holds the total number of bytes copied.
        Done(u64),
    }

    /// Advances one copy direction as far as it can without blocking.
    ///
    /// Copies from `r` into `w` until `r` reports EOF, then shuts down `w`. Resolves to the number
    /// of bytes copied once `w` has been shut down, or to the first I/O error encountered.
    fn transfer_one_direction<A, B>(
        cx: &mut Context<'_>,
        state: &mut TransferState,
        r: &mut A,
        w: &mut B,
    ) -> Poll<std::io::Result<u64>>
    where
        A: AsyncRead + AsyncWrite + Unpin + ?Sized,
        B: AsyncRead + AsyncWrite + Unpin + ?Sized,
    {
        let mut r = Pin::new(r);
        let mut w = Pin::new(w);
        loop {
            match state {
                TransferState::Running(buf) => {
                    let count = std::task::ready!(buf.poll_copy(cx, r.as_mut(), w.as_mut()))?;
                    *state = TransferState::ShuttingDown(count);
                }
                TransferState::ShuttingDown(count) => {
                    std::task::ready!(w.as_mut().poll_shutdown(cx))?;
                    *state = TransferState::Done(*count);
                }
                TransferState::Done(count) => return Poll::Ready(Ok(*count)),
            }
        }
    }

    /// Copies data in both directions between `a` and `b` until both directions have finished.
    ///
    /// This is a re-implementation of Tokio's
    /// [`copy_bidirectional_with_sizes`](tokio::io::copy_bidirectional_with_sizes) that does not
    /// leave the streams in a half-open state when one side closes. A direction completes when
    /// its reader returns an empty read (EOF) and the opposite writer has been shut down. As soon
    /// as one direction completes, the other is terminated as well: any data still buffered for
    /// it is discarded and its writer is shut down.
    ///
    /// `a_to_b_buffer_size` and `b_to_a_buffer_size` set the copy buffer size of each direction.
    /// A size of `0` is treated as `1`.
    ///
    /// # Returns
    /// `(a_to_b, b_to_a)`: the number of bytes written into `b` and into `a`, respectively.
    ///
    /// # Errors
    /// Returns the first I/O error reported by either direction. Read errors are reported with
    /// [`std::io::ErrorKind::BrokenPipe`], wrapping the original error.
    pub async fn copy_duplex<A, B>(
        a: &mut A,
        b: &mut B,
        a_to_b_buffer_size: usize,
        b_to_a_buffer_size: usize,
    ) -> std::io::Result<(u64, u64)>
    where
        A: AsyncRead + AsyncWrite + Unpin + ?Sized,
        B: AsyncRead + AsyncWrite + Unpin + ?Sized,
    {
        let mut a_to_b = TransferState::Running(CopyBuffer::new(a_to_b_buffer_size));
        let mut b_to_a = TransferState::Running(CopyBuffer::new(b_to_a_buffer_size));

        std::future::poll_fn(|cx| {
            let mut a_to_b_result = transfer_one_direction(cx, &mut a_to_b, a, b)?;
            let mut b_to_a_result = transfer_one_direction(cx, &mut b_to_a, b, a)?;

            // Once one direction is fully done, stop the other one and shut down its writer.
            // Whatever it still has buffered is dropped, so no half-open connection is left behind.
            if let TransferState::Done(_) = b_to_a {
                if let TransferState::Running(buf) = &a_to_b {
                    tracing::trace!("B-side has completed, terminating A-side.");
                    a_to_b = TransferState::ShuttingDown(buf.amt);
                    a_to_b_result = transfer_one_direction(cx, &mut a_to_b, a, b)?;
                }
            }

            if let TransferState::Done(_) = a_to_b {
                if let TransferState::Running(buf) = &b_to_a {
                    tracing::trace!("A-side has completed, terminate B-side.");
                    b_to_a = TransferState::ShuttingDown(buf.amt);
                    b_to_a_result = transfer_one_direction(cx, &mut b_to_a, b, a)?;
                }
            }

            // Not a problem if ready! returns early: both directions have already been polled
            // in this call, so any direction that is still pending has registered for wake-up.
            let a_to_b_bytes_transferred = std::task::ready!(a_to_b_result);
            let b_to_a_bytes_transferred = std::task::ready!(b_to_a_result);

            Poll::Ready(Ok((a_to_b_bytes_transferred, b_to_a_bytes_transferred)))
        })
        .await
    }

    /// Fixed-size buffer through which one copy direction is driven.
    #[derive(Debug)]
    struct CopyBuffer {
        /// The reader has reported EOF (a read that added no bytes).
        read_done: bool,
        /// Bytes have been written since the last flush.
        need_flush: bool,
        /// Start of the not-yet-written region of `buf`.
        pos: usize,
        /// End of the region of `buf` filled by reads.
        cap: usize,
        /// Total number of bytes written so far.
        amt: u64,
        buf: Box<[u8]>,
    }

    impl CopyBuffer {
        /// Creates an empty copy buffer of `buf_size` bytes (at least 1).
        fn new(buf_size: usize) -> Self {
            Self {
                read_done: false,
                need_flush: false,
                pos: 0,
                cap: 0,
                amt: 0,
                // A zero-length buffer would make every read return zero bytes. `poll_fill_buf`
                // interprets that as EOF, which would silently end the transfer before any data
                // is copied.
                buf: vec![0; buf_size.max(1)].into_boxed_slice(),
            }
        }

        /// Reads more data from `reader` into the free tail `buf[cap..]` of the buffer.
        ///
        /// A successful read that adds no bytes marks the reader as done (EOF). Read errors are
        /// wrapped into a [`std::io::ErrorKind::BrokenPipe`] error, on every read path.
        fn poll_fill_buf<R>(&mut self, cx: &mut Context<'_>, reader: Pin<&mut R>) -> Poll<std::io::Result<()>>
        where
            R: AsyncRead + ?Sized,
        {
            let me = &mut *self;
            let mut buf = ReadBuf::new(&mut me.buf);
            buf.set_filled(me.cap);

            match reader.poll_read(cx, &mut buf) {
                Poll::Ready(Ok(())) => {
                    let filled_len = buf.filled().len();
                    me.read_done = me.cap == filled_len;
                    me.cap = filled_len;
                    Poll::Ready(Ok(()))
                }
                Poll::Ready(Err(err)) => Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::BrokenPipe, err))),
                Poll::Pending => Poll::Pending,
            }
        }

        /// Writes the buffered bytes `buf[pos..cap]` to `writer` and returns how many were written.
        ///
        /// While the writer is not ready, tries to read more data into the free tail of the buffer
        /// so that the next write can be larger.
        fn poll_write_buf<R, W>(
            &mut self,
            cx: &mut Context<'_>,
            mut reader: Pin<&mut R>,
            mut writer: Pin<&mut W>,
        ) -> Poll<std::io::Result<usize>>
        where
            R: AsyncRead + ?Sized,
            W: AsyncWrite + ?Sized,
        {
            let this = &mut *self;
            match writer.as_mut().poll_write(cx, &this.buf[this.pos..this.cap]) {
                Poll::Pending => {
                    // Top up the buffer towards full if we can read a bit more
                    // data - this should improve the chances of a large write
                    if !this.read_done && this.cap < this.buf.len() {
                        std::task::ready!(this.poll_fill_buf(cx, reader.as_mut()))?;
                    }
                    Poll::Pending
                }
                res @ Poll::Ready(_) => res,
            }
        }

        /// Copies from `reader` to `writer` until `reader` reports EOF and all data has been
        /// written and flushed. Resolves to the total number of bytes written.
        pub(super) fn poll_copy<R, W>(
            &mut self,
            cx: &mut Context<'_>,
            mut reader: Pin<&mut R>,
            mut writer: Pin<&mut W>,
        ) -> Poll<std::io::Result<u64>>
        where
            R: AsyncRead + ?Sized,
            W: AsyncWrite + ?Sized,
        {
            loop {
                // If our buffer is empty, then we need to read some data to
                // continue.
                if self.pos == self.cap && !self.read_done {
                    self.pos = 0;
                    self.cap = 0;

                    match self.poll_fill_buf(cx, reader.as_mut()) {
                        Poll::Ready(Ok(())) => (),
                        // `poll_fill_buf` has already wrapped the error as `BrokenPipe`.
                        Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
                        Poll::Pending => {
                            // Try flushing when the reader has no progress to avoid deadlock
                            // when the reader depends on a buffered writer.
                            if self.need_flush {
                                std::task::ready!(writer.as_mut().poll_flush(cx))?;
                                self.need_flush = false;
                            }

                            return Poll::Pending;
                        }
                    }
                }

                // If our buffer has some data, let's write it out
                while self.pos < self.cap {
                    let i = std::task::ready!(self.poll_write_buf(cx, reader.as_mut(), writer.as_mut()))?;
                    if i == 0 {
                        return Poll::Ready(Err(std::io::Error::new(
                            std::io::ErrorKind::WriteZero,
                            "write zero byte",
                        )));
                    }
                    self.pos += i;
                    self.amt += i as u64;
                    self.need_flush = true;
                }

                // If pos larger than cap, this loop will never stop.
                // In particular, a user's wrong poll_write implementation returning
                // incorrect written length may lead to thread blocking.
                debug_assert!(self.pos <= self.cap, "writer returned length larger than input slice");

                // If we've written all the data, and we've seen EOF, flush out the
                // data and finish the transfer.
                if self.pos == self.cap && self.read_done {
                    std::task::ready!(writer.as_mut().poll_flush(cx))?;
                    return Poll::Ready(Ok(self.amt));
                }
            }
        }
    }
}
332
333/// Converts a [`AsyncRead`] into [`Stream`] by reading at most `S` bytes
334/// in each call to [`Stream::poll_next`].
335pub struct AsyncReadStreamer<const S: usize, R>(pub R);
336
337impl<const S: usize, R: AsyncRead + Unpin> futures::Stream for AsyncReadStreamer<S, R> {
338    type Item = std::io::Result<Box<[u8]>>;
339
340    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
341        let mut buffer = vec![0u8; S];
342
343        match futures::ready!(Pin::new(&mut self.0).poll_read(cx, &mut buffer)) {
344            Ok(0) => Poll::Ready(None),
345            Ok(size) => {
346                buffer.truncate(size);
347                Poll::Ready(Some(Ok(buffer.into_boxed_slice())))
348            }
349            Err(err) => Poll::Ready(Some(Err(err))),
350        }
351    }
352}
353
#[cfg(all(feature = "runtime-tokio", test))]
mod tests {
    use futures::TryStreamExt;
    use tokio::io::AsyncWriteExt;

    use super::*;
    use crate::utils::DuplexIO;

    /// Runs [`copy_duplex`] between two [`DuplexIO`] endpoints and checks that every byte in both
    /// directions arrives intact.
    ///
    /// Each endpoint reads `DATA_LEN` random bytes from a slice and writes into a `DATA_LEN`-byte
    /// sink. Both directions use a copy buffer of `buffer_size` bytes.
    async fn assert_copy_duplex_transfers_all<const DATA_LEN: usize>(buffer_size: usize) -> anyhow::Result<()> {
        let alice_tx = hopr_crypto_random::random_bytes::<DATA_LEN>();
        let mut alice_rx = [0u8; DATA_LEN];

        let bob_tx = hopr_crypto_random::random_bytes::<DATA_LEN>();
        let mut bob_rx = [0u8; DATA_LEN];

        let alice = DuplexIO(alice_tx.as_ref(), futures::io::Cursor::new(alice_rx.as_mut()));
        let bob = DuplexIO(bob_tx.as_ref(), futures::io::Cursor::new(bob_rx.as_mut()));

        let (a_to_b, b_to_a) = copy_duplex(
            &mut tokio_util::compat::FuturesAsyncReadCompatExt::compat(alice),
            &mut tokio_util::compat::FuturesAsyncReadCompatExt::compat(bob),
            buffer_size,
            buffer_size,
        )
        .await?;

        assert_eq!(DATA_LEN, a_to_b as usize);
        assert_eq!(DATA_LEN, b_to_a as usize);

        assert_eq!(alice_tx, bob_rx);
        assert_eq!(bob_tx, alice_rx);

        Ok(())
    }

    #[tokio::test]
    async fn test_copy_duplex() -> anyhow::Result<()> {
        // The data needs many buffer fills in each direction.
        assert_copy_duplex_transfers_all::<2000>(128).await
    }

    #[tokio::test]
    async fn test_copy_duplex_small() -> anyhow::Result<()> {
        // The data fits into a single buffer fill in each direction.
        assert_copy_duplex_transfers_all::<100>(128).await
    }

    #[tokio::test]
    async fn test_client_to_server() -> anyhow::Result<()> {
        let (mut client_tx, mut client_rx) = tokio::io::duplex(8); // In-memory client pipe (8-byte buffer)
        let (mut server_rx, mut server_tx) = tokio::io::duplex(32); // In-memory server pipe (32-byte buffer)

        // Both peers send a short message and then close their write side.
        client_tx.write_all(b"hello").await?;
        client_tx.shutdown().await?;

        server_tx.write_all(b"data").await?;
        server_tx.shutdown().await?;

        let result = crate::utils::copy_duplex(&mut client_rx, &mut server_rx, 2, 2).await?;

        let (client_to_server_count, server_to_client_count) = result;
        assert_eq!(client_to_server_count, 5); // 'hello' was transferred in full
        assert_eq!(server_to_client_count, 4); // 'data' was transferred in full

        Ok(())
    }

    #[tokio::test]
    async fn test_server_to_client() -> anyhow::Result<()> {
        let (mut client_tx, mut client_rx) = tokio::io::duplex(32); // In-memory client pipe (32-byte buffer)
        let (mut server_rx, mut server_tx) = tokio::io::duplex(8); // In-memory server pipe (8-byte buffer)

        // The server ('b') finishes while the client ('a') still has data to send.
        server_tx.write_all(b"hello").await?;
        server_tx.shutdown().await?;

        client_tx.write_all(b"some longer data to transfer").await?;

        let result = crate::utils::copy_duplex(&mut client_rx, &mut server_rx, 2, 2).await?;

        let (client_to_server_count, server_to_client_count) = result;
        assert_eq!(server_to_client_count, 5); // 'hello' was transferred
        assert!(client_to_server_count <= 8); // client data only partially transferred or not at all

        Ok(())
    }

    #[tokio::test]
    async fn test_async_read_streamer_complete_chunk() {
        let data = b"Hello, World!!";
        let mut streamer = AsyncReadStreamer::<14, _>(&data[..]);
        let mut results = Vec::new();

        while let Some(res) = streamer.try_next().await.unwrap() {
            results.push(res);
        }

        assert_eq!(results, vec![Box::from(*data)]);
    }

    #[tokio::test]
    async fn test_async_read_streamer_complete_more_chunks() {
        let data = b"Hello, World and do it twice";
        let mut streamer = AsyncReadStreamer::<14, _>(&data[..]);
        let mut results = Vec::new();

        while let Some(res) = streamer.try_next().await.unwrap() {
            results.push(res);
        }

        let (data1, data2) = data.split_at(14);
        assert_eq!(results, vec![Box::from(data1), Box::from(data2)]);
    }

    #[tokio::test]
    async fn test_async_read_streamer_complete_more_chunks_with_incomplete() -> anyhow::Result<()> {
        let data = b"Hello, World and do it twice, ...";
        let streamer = AsyncReadStreamer::<14, _>(&data[..]);

        let results = streamer.try_collect::<Vec<_>>().await?;

        let (data1, rest) = data.split_at(14);
        let (data2, data3) = rest.split_at(14);
        assert_eq!(results, vec![Box::from(data1), Box::from(data2), Box::from(data3)]);

        Ok(())
    }

    #[tokio::test]
    async fn test_async_read_streamer_incomplete_chunk() -> anyhow::Result<()> {
        let data = b"Hello, World!!";
        let reader = &data[0..8]; // An incomplete chunk
        let mut streamer = AsyncReadStreamer::<14, _>(reader);

        assert_eq!(Some(Box::from(reader)), streamer.try_next().await?);

        Ok(())
    }
}