hopr_network_types/
utils.rs

1use futures::io::{AsyncRead, AsyncWrite};
2use std::fmt::{Debug, Display, Formatter};
3use std::hash::{Hash, Hasher};
4use std::net::SocketAddr;
5use std::pin::Pin;
6use std::task::{Context, Poll};
7
/// Joins a reader and a writer into a single duplex I/O object.
///
/// Reads are forwarded to the first field (`R`, a [`futures::AsyncRead`]).
/// Writes, flushes and closes are forwarded to the second field (`W`, a [`futures::AsyncWrite`]).
#[derive(Debug, Clone)]
pub struct DuplexIO<R, W>(pub R, pub W);
10
11impl<R, W> From<(R, W)> for DuplexIO<R, W>
12where
13    R: AsyncRead,
14    W: AsyncWrite,
15{
16    fn from(value: (R, W)) -> Self {
17        Self(value.0, value.1)
18    }
19}
20
21impl<R, W> AsyncRead for DuplexIO<R, W>
22where
23    R: AsyncRead + Unpin,
24    W: AsyncWrite + Unpin,
25{
26    fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<std::io::Result<usize>> {
27        let this = self.get_mut();
28        Pin::new(&mut this.0).poll_read(cx, buf)
29    }
30}
31
32impl<R, W> AsyncWrite for DuplexIO<R, W>
33where
34    R: AsyncRead + Unpin,
35    W: AsyncWrite + Unpin,
36{
37    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
38        let this = self.get_mut();
39        Pin::new(&mut this.1).poll_write(cx, buf)
40    }
41
42    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
43        let this = self.get_mut();
44        Pin::new(&mut this.1).poll_flush(cx)
45    }
46
47    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
48        let this = self.get_mut();
49        Pin::new(&mut this.1).poll_close(cx)
50    }
51}
52
// Length of the longest `SocketAddr` text rendering: an IPv6 socket address with a scope ID,
// e.g. `[ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff%4294967295]:65535`.
// '[' + 39 (address) + '%' + 10 (u32 scope ID) + ']' + ':' + 5 (port) = 58.
// Longer text would be truncated when cached.
const SOCKET_ADDRESS_MAX_LEN: usize = 58;
55
56/// Caches the string representation of a SocketAddr for fast conversion to `&str`
57#[derive(Copy, Clone)]
58pub(crate) struct SocketAddrStr(SocketAddr, arrayvec::ArrayString<SOCKET_ADDRESS_MAX_LEN>);
59
60impl SocketAddrStr {
61    #[allow(dead_code)]
62    pub fn as_str(&self) -> &str {
63        self.1.as_str()
64    }
65}
66
67impl AsRef<SocketAddr> for SocketAddrStr {
68    fn as_ref(&self) -> &SocketAddr {
69        &self.0
70    }
71}
72
73impl From<SocketAddr> for SocketAddrStr {
74    fn from(value: SocketAddr) -> Self {
75        let mut cached = value.to_string();
76        cached.truncate(SOCKET_ADDRESS_MAX_LEN);
77        Self(value, cached.parse().expect("cannot fail due to truncation"))
78    }
79}
80
81impl PartialEq for SocketAddrStr {
82    fn eq(&self, other: &Self) -> bool {
83        self.0 == other.0
84    }
85}
86
87impl Eq for SocketAddrStr {}
88
89impl Debug for SocketAddrStr {
90    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
91        write!(f, "{}", self.1)
92    }
93}
94
95impl Display for SocketAddrStr {
96    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
97        write!(f, "{}", self.1)
98    }
99}
100
101impl PartialEq<SocketAddrStr> for SocketAddr {
102    fn eq(&self, other: &SocketAddrStr) -> bool {
103        self.eq(&other.0)
104    }
105}
106
107impl Hash for SocketAddrStr {
108    fn hash<H: Hasher>(&self, state: &mut H) {
109        self.0.hash(state);
110    }
111}
112
113#[cfg(feature = "runtime-tokio")]
114pub use tokio_utils::copy_duplex;
115
#[cfg(feature = "runtime-tokio")]
mod tokio_utils {
    use super::*;
    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

    /// Progress of one copy direction inside [`copy_duplex`].
    #[derive(Debug)]
    enum TransferState {
        /// Data is being copied from the reader to the writer through the buffer.
        Running(CopyBuffer),
        /// Copying stopped after the given number of bytes; the writer is being shut down.
        ShuttingDown(u64),
        /// The writer has been shut down; holds the final number of bytes copied.
        Done(u64),
    }

    /// Advances one copy direction (`r` -> `w`) as far as possible without blocking.
    ///
    /// Resolves to the number of bytes copied once the direction has reached
    /// [`TransferState::Done`], i.e. after copying stopped and `w` was shut down.
    fn transfer_one_direction<A, B>(
        cx: &mut Context<'_>,
        state: &mut TransferState,
        r: &mut A,
        w: &mut B,
    ) -> Poll<std::io::Result<u64>>
    where
        A: AsyncRead + AsyncWrite + Unpin + ?Sized,
        B: AsyncRead + AsyncWrite + Unpin + ?Sized,
    {
        let mut r = Pin::new(r);
        let mut w = Pin::new(w);
        loop {
            match state {
                TransferState::Running(buf) => {
                    let count = std::task::ready!(buf.poll_copy(cx, r.as_mut(), w.as_mut()))?;
                    *state = TransferState::ShuttingDown(count);
                }
                TransferState::ShuttingDown(count) => {
                    std::task::ready!(w.as_mut().poll_shutdown(cx))?;
                    *state = TransferState::Done(*count);
                }
                TransferState::Done(count) => return Poll::Ready(Ok(*count)),
            }
        }
    }

    /// Copies data between `a` and `b` in both directions until both directions are done.
    ///
    /// This re-implements Tokio's
    /// [`copy_bidirectional_with_sizes`](tokio::io::copy_bidirectional_with_sizes), but never
    /// leaves the streams half-open. A direction completes once three things have happened:
    /// its reader returned an empty read (EOF), the copied data was flushed, and its writer was
    /// shut down. As soon as one direction completes, the other one is terminated as well: any
    /// data still sitting in its buffer is dropped and its writer is shut down.
    ///
    /// `a_to_b_buffer_size` and `b_to_a_buffer_size` set the copy buffer size of each direction.
    ///
    /// Returns the number of bytes written `a -> b` and `b -> a`, in that order.
    ///
    /// # Errors
    ///
    /// - [`std::io::ErrorKind::InvalidInput`] if either buffer size is zero. A zero-sized
    ///   buffer can never be filled, so every successful read on that side would look like
    ///   EOF and both directions would be torn down without copying anything.
    /// - [`std::io::ErrorKind::BrokenPipe`] (wrapping the original error) if reading from
    ///   either side fails.
    /// - [`std::io::ErrorKind::WriteZero`] if a writer accepts zero bytes.
    /// - Any error from writing to, flushing or shutting down either side.
    pub async fn copy_duplex<A, B>(
        a: &mut A,
        b: &mut B,
        a_to_b_buffer_size: usize,
        b_to_a_buffer_size: usize,
    ) -> std::io::Result<(u64, u64)>
    where
        A: AsyncRead + AsyncWrite + Unpin + ?Sized,
        B: AsyncRead + AsyncWrite + Unpin + ?Sized,
    {
        if a_to_b_buffer_size == 0 || b_to_a_buffer_size == 0 {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "copy_duplex buffer sizes must be non-zero",
            ));
        }

        let mut a_to_b = TransferState::Running(CopyBuffer::new(a_to_b_buffer_size));
        let mut b_to_a = TransferState::Running(CopyBuffer::new(b_to_a_buffer_size));

        std::future::poll_fn(|cx| {
            let mut a_to_b_result = transfer_one_direction(cx, &mut a_to_b, a, b)?;
            let mut b_to_a_result = transfer_one_direction(cx, &mut b_to_a, b, a)?;

            // B -> A has completed: stop copying A -> B and shut down B's write side.
            if let TransferState::Done(_) = b_to_a {
                if let TransferState::Running(buf) = &a_to_b {
                    tracing::trace!("B-side has completed, terminating A-side.");
                    a_to_b = TransferState::ShuttingDown(buf.amt);
                    a_to_b_result = transfer_one_direction(cx, &mut a_to_b, a, b)?;
                }
            }

            // A -> B has completed: stop copying B -> A and shut down A's write side.
            if let TransferState::Done(_) = a_to_b {
                if let TransferState::Running(buf) = &b_to_a {
                    tracing::trace!("A-side has completed, terminate B-side.");
                    b_to_a = TransferState::ShuttingDown(buf.amt);
                    b_to_a_result = transfer_one_direction(cx, &mut b_to_a, b, a)?;
                }
            }

            // Returning early on `Pending` is fine: each direction keeps its progress in its
            // `TransferState`, so the next poll resumes where this one stopped.
            let a_to_b_bytes_transferred = std::task::ready!(a_to_b_result);
            let b_to_a_bytes_transferred = std::task::ready!(b_to_a_result);

            Poll::Ready(Ok((a_to_b_bytes_transferred, b_to_a_bytes_transferred)))
        })
        .await
    }

    /// Reports a failure on the reading side of a copy as [`std::io::ErrorKind::BrokenPipe`],
    /// wrapping the original error.
    fn read_error(err: std::io::Error) -> std::io::Error {
        std::io::Error::new(std::io::ErrorKind::BrokenPipe, err)
    }

    /// Fixed-size buffer that carries data from a reader to a writer.
    #[derive(Debug)]
    struct CopyBuffer {
        /// The reader returned an empty read (EOF).
        read_done: bool,
        /// Data has been written since the writer was last flushed.
        need_flush: bool,
        /// Start of the not-yet-written data in `buf`.
        pos: usize,
        /// End of the valid data in `buf`.
        cap: usize,
        /// Total number of bytes written to the writer so far.
        amt: u64,
        buf: Box<[u8]>,
    }

    impl CopyBuffer {
        fn new(buf_size: usize) -> Self {
            Self {
                read_done: false,
                need_flush: false,
                pos: 0,
                cap: 0,
                amt: 0,
                buf: vec![0; buf_size].into_boxed_slice(),
            }
        }

        /// Reads more data into `buf[cap..]`. A successful read that adds no bytes marks EOF.
        fn poll_fill_buf<R>(&mut self, cx: &mut Context<'_>, reader: Pin<&mut R>) -> Poll<std::io::Result<()>>
        where
            R: AsyncRead + ?Sized,
        {
            let me = &mut *self;
            let mut buf = ReadBuf::new(&mut me.buf);
            buf.set_filled(me.cap);

            let res = reader.poll_read(cx, &mut buf);
            if let Poll::Ready(Ok(())) = res {
                let filled_len = buf.filled().len();
                me.read_done = me.cap == filled_len;
                me.cap = filled_len;
            }
            res
        }

        /// Writes `buf[pos..cap]`. While the writer is not ready, tops up the buffer from the reader.
        fn poll_write_buf<R, W>(
            &mut self,
            cx: &mut Context<'_>,
            mut reader: Pin<&mut R>,
            mut writer: Pin<&mut W>,
        ) -> Poll<std::io::Result<usize>>
        where
            R: AsyncRead + ?Sized,
            W: AsyncWrite + ?Sized,
        {
            let this = &mut *self;
            match writer.as_mut().poll_write(cx, &this.buf[this.pos..this.cap]) {
                Poll::Pending => {
                    // Top up the buffer towards full if we can read a bit more
                    // data - this should improve the chances of a large write.
                    if !this.read_done && this.cap < this.buf.len() {
                        std::task::ready!(this.poll_fill_buf(cx, reader.as_mut()))
                            .map_err(read_error)?;
                    }
                    Poll::Pending
                }
                res @ Poll::Ready(_) => res,
            }
        }

        /// Copies from `reader` to `writer` until the reader hits EOF and everything
        /// read has been written and flushed. Resolves to the total number of bytes written.
        pub(super) fn poll_copy<R, W>(
            &mut self,
            cx: &mut Context<'_>,
            mut reader: Pin<&mut R>,
            mut writer: Pin<&mut W>,
        ) -> Poll<std::io::Result<u64>>
        where
            R: AsyncRead + ?Sized,
            W: AsyncWrite + ?Sized,
        {
            loop {
                // If our buffer is empty, then we need to read some data to continue.
                if self.pos == self.cap && !self.read_done {
                    self.pos = 0;
                    self.cap = 0;

                    match self.poll_fill_buf(cx, reader.as_mut()) {
                        Poll::Ready(Ok(())) => (),
                        Poll::Ready(Err(err)) => return Poll::Ready(Err(read_error(err))),
                        Poll::Pending => {
                            // Try flushing when the reader has no progress to avoid deadlock
                            // when the reader depends on a buffered writer.
                            if self.need_flush {
                                std::task::ready!(writer.as_mut().poll_flush(cx))?;
                                self.need_flush = false;
                            }

                            return Poll::Pending;
                        }
                    }
                }

                // If our buffer has some data, let's write it out.
                while self.pos < self.cap {
                    let written = std::task::ready!(self.poll_write_buf(cx, reader.as_mut(), writer.as_mut()))?;
                    if written == 0 {
                        return Poll::Ready(Err(std::io::Error::new(
                            std::io::ErrorKind::WriteZero,
                            "write zero byte",
                        )));
                    }
                    if written > self.cap - self.pos {
                        // A misbehaving writer reported more bytes than it was given. Letting
                        // `pos` overshoot `cap` would make this loop spin forever, so fail instead.
                        return Poll::Ready(Err(std::io::Error::new(
                            std::io::ErrorKind::Other,
                            "writer returned length larger than input slice",
                        )));
                    }
                    self.pos += written;
                    self.amt += written as u64;
                    self.need_flush = true;
                }

                // If we've written all the data, and we've seen EOF, flush out the
                // data and finish the transfer.
                if self.pos == self.cap && self.read_done {
                    std::task::ready!(writer.as_mut().poll_flush(cx))?;
                    return Poll::Ready(Ok(self.amt));
                }
            }
        }
    }
}
328
329/// Converts a [`AsyncRead`] into [`Stream`] by reading at most `S` bytes
330/// in each call to [`Stream::poll_next`].
331pub struct AsyncReadStreamer<const S: usize, R>(pub R);
332
333impl<const S: usize, R: AsyncRead + Unpin> futures::Stream for AsyncReadStreamer<S, R> {
334    type Item = std::io::Result<Box<[u8]>>;
335
336    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
337        let mut buffer = vec![0u8; S];
338
339        match futures::ready!(Pin::new(&mut self.0).poll_read(cx, &mut buffer)) {
340            Ok(0) => Poll::Ready(None),
341            Ok(size) => {
342                buffer.truncate(size);
343                Poll::Ready(Some(Ok(buffer.into_boxed_slice())))
344            }
345            Err(err) => Poll::Ready(Some(Err(err))),
346        }
347    }
348}
349
#[cfg(all(feature = "runtime-tokio", test))]
mod tests {
    use super::*;
    use futures::TryStreamExt;
    use tokio::io::AsyncWriteExt;

    /// Runs [`copy_duplex`] between two in-memory peers that send `alice_tx` and `bob_tx`,
    /// respectively. Checks that each peer received exactly what the other one sent.
    async fn assert_copy_duplex(alice_tx: &[u8], bob_tx: &[u8], buffer_size: usize) -> anyhow::Result<()> {
        // Each receive buffer is exactly as large as what the opposite peer sends.
        let mut alice_rx = vec![0u8; bob_tx.len()];
        let mut bob_rx = vec![0u8; alice_tx.len()];

        let alice = DuplexIO(alice_tx, futures::io::Cursor::new(alice_rx.as_mut_slice()));
        let bob = DuplexIO(bob_tx, futures::io::Cursor::new(bob_rx.as_mut_slice()));

        let (a_to_b, b_to_a) = copy_duplex(
            &mut tokio_util::compat::FuturesAsyncReadCompatExt::compat(alice),
            &mut tokio_util::compat::FuturesAsyncReadCompatExt::compat(bob),
            buffer_size,
            buffer_size,
        )
        .await?;

        assert_eq!(alice_tx.len(), a_to_b as usize);
        assert_eq!(bob_tx.len(), b_to_a as usize);

        assert_eq!(alice_tx, bob_rx.as_slice());
        assert_eq!(bob_tx, alice_rx.as_slice());

        Ok(())
    }

    #[tokio::test]
    async fn test_copy_duplex() -> anyhow::Result<()> {
        const DATA_LEN: usize = 2000;

        let alice_tx = hopr_crypto_random::random_bytes::<DATA_LEN>();
        let bob_tx = hopr_crypto_random::random_bytes::<DATA_LEN>();

        // The payload needs many refills of the 128-byte buffers.
        assert_copy_duplex(&alice_tx, &bob_tx, 128).await
    }

    #[tokio::test]
    async fn test_copy_duplex_small() -> anyhow::Result<()> {
        const DATA_LEN: usize = 100;

        let alice_tx = hopr_crypto_random::random_bytes::<DATA_LEN>();
        let bob_tx = hopr_crypto_random::random_bytes::<DATA_LEN>();

        // The payload fits into a single 128-byte buffer.
        assert_copy_duplex(&alice_tx, &bob_tx, 128).await
    }

    #[tokio::test]
    async fn test_client_to_server() -> anyhow::Result<()> {
        let (mut client_tx, mut client_rx) = tokio::io::duplex(8); // Create a mock duplex stream
        let (mut server_rx, mut server_tx) = tokio::io::duplex(32); // Create a mock duplex stream

        // Both peers send their complete payload and close their write side before copying starts.
        client_tx.write_all(b"hello").await?;
        client_tx.shutdown().await?;

        server_tx.write_all(b"data").await?;
        server_tx.shutdown().await?;

        let (client_to_server_count, server_to_client_count) =
            copy_duplex(&mut client_rx, &mut server_rx, 2, 2).await?;

        // Both payloads are transferred completely.
        assert_eq!(client_to_server_count, 5); // 'hello'
        assert_eq!(server_to_client_count, 4); // 'data'

        Ok(())
    }

    #[tokio::test]
    async fn test_server_to_client() -> anyhow::Result<()> {
        let (mut client_tx, mut client_rx) = tokio::io::duplex(32); // Create a mock duplex stream
        let (mut server_rx, mut server_tx) = tokio::io::duplex(8); // Create a mock duplex stream

        // The server sends its payload and closes. The client still has data in flight,
        // which is never read on the server side.
        server_tx.write_all(b"hello").await?;
        server_tx.shutdown().await?;

        client_tx.write_all(b"some longer data to transfer").await?;

        let (client_to_server_count, server_to_client_count) =
            copy_duplex(&mut client_rx, &mut server_rx, 2, 2).await?;

        assert_eq!(server_to_client_count, 5); // 'hello' was transferred
        // The client -> server direction is cut short once the server side has completed.
        // At most the 8-byte capacity of the server pipe can have been accepted.
        assert!(client_to_server_count <= 8);

        Ok(())
    }

    #[async_std::test]
    async fn test_async_read_streamer_complete_chunk() {
        let data = b"Hello, World!!";
        let mut streamer = AsyncReadStreamer::<14, _>(&data[..]);
        let mut results = Vec::new();

        while let Some(res) = streamer.try_next().await.unwrap() {
            results.push(res);
        }

        assert_eq!(results, vec![Box::from(*data)]);
    }

    #[async_std::test]
    async fn test_async_read_streamer_complete_more_chunks() {
        let data = b"Hello, World and do it twice";
        let mut streamer = AsyncReadStreamer::<14, _>(&data[..]);
        let mut results = Vec::new();

        while let Some(res) = streamer.try_next().await.unwrap() {
            results.push(res);
        }

        let (data1, data2) = data.split_at(14);
        assert_eq!(results, vec![Box::from(data1), Box::from(data2)]);
    }

    #[async_std::test]
    async fn test_async_read_streamer_complete_more_chunks_with_incomplete() -> anyhow::Result<()> {
        let data = b"Hello, World and do it twice, ...";
        let streamer = AsyncReadStreamer::<14, _>(&data[..]);

        let results = streamer.try_collect::<Vec<_>>().await?;

        let (data1, rest) = data.split_at(14);
        let (data2, data3) = rest.split_at(14);
        assert_eq!(results, vec![Box::from(data1), Box::from(data2), Box::from(data3)]);

        Ok(())
    }

    #[async_std::test]
    async fn test_async_read_streamer_incomplete_chunk() -> anyhow::Result<()> {
        let data = b"Hello, World!!";
        let reader = &data[0..8]; // An incomplete chunk
        let mut streamer = AsyncReadStreamer::<14, _>(reader);

        assert_eq!(Some(Box::from(reader)), streamer.try_next().await?);

        Ok(())
    }
}