1use futures::io::{AsyncRead, AsyncWrite};
2use std::fmt::{Debug, Display, Formatter};
3use std::hash::{Hash, Hasher};
4use std::net::SocketAddr;
5use std::pin::Pin;
6use std::task::{Context, Poll};
7
/// Pairs an independent reader (field `.0`) with an independent writer (field `.1`)
/// so the two can be used as a single bidirectional I/O object.
///
/// Reading is served by the first field; writing, flushing and closing by the second.
pub struct DuplexIO<R, W>(pub R, pub W);
10
11impl<R, W> From<(R, W)> for DuplexIO<R, W>
12where
13 R: AsyncRead,
14 W: AsyncWrite,
15{
16 fn from(value: (R, W)) -> Self {
17 Self(value.0, value.1)
18 }
19}
20
21impl<R, W> AsyncRead for DuplexIO<R, W>
22where
23 R: AsyncRead + Unpin,
24 W: AsyncWrite + Unpin,
25{
26 fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<std::io::Result<usize>> {
27 let this = self.get_mut();
28 Pin::new(&mut this.0).poll_read(cx, buf)
29 }
30}
31
32impl<R, W> AsyncWrite for DuplexIO<R, W>
33where
34 R: AsyncRead + Unpin,
35 W: AsyncWrite + Unpin,
36{
37 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
38 let this = self.get_mut();
39 Pin::new(&mut this.1).poll_write(cx, buf)
40 }
41
42 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
43 let this = self.get_mut();
44 Pin::new(&mut this.1).poll_flush(cx)
45 }
46
47 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
48 let this = self.get_mut();
49 Pin::new(&mut this.1).poll_close(cx)
50 }
51}
52
/// Capacity, in bytes, of the cached textual form stored in `SocketAddrStr`.
///
/// This is the longest string `SocketAddr`'s `Display` can produce: a fully expanded IPv6
/// address with a scope id and port, e.g.
/// `[ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff%4294967295]:65535`
/// (1 + 39 + 1 + 10 + 1 + 1 + 5 = 58 bytes). Anything smaller would silently cut such
/// addresses short (including their port) when they are cached.
const SOCKET_ADDRESS_MAX_LEN: usize = 58;
55
56#[derive(Copy, Clone)]
58pub(crate) struct SocketAddrStr(SocketAddr, arrayvec::ArrayString<SOCKET_ADDRESS_MAX_LEN>);
59
60impl SocketAddrStr {
61 #[allow(dead_code)]
62 pub fn as_str(&self) -> &str {
63 self.1.as_str()
64 }
65}
66
67impl AsRef<SocketAddr> for SocketAddrStr {
68 fn as_ref(&self) -> &SocketAddr {
69 &self.0
70 }
71}
72
73impl From<SocketAddr> for SocketAddrStr {
74 fn from(value: SocketAddr) -> Self {
75 let mut cached = value.to_string();
76 cached.truncate(SOCKET_ADDRESS_MAX_LEN);
77 Self(value, cached.parse().expect("cannot fail due to truncation"))
78 }
79}
80
81impl PartialEq for SocketAddrStr {
82 fn eq(&self, other: &Self) -> bool {
83 self.0 == other.0
84 }
85}
86
87impl Eq for SocketAddrStr {}
88
89impl Debug for SocketAddrStr {
90 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
91 write!(f, "{}", self.1)
92 }
93}
94
95impl Display for SocketAddrStr {
96 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
97 write!(f, "{}", self.1)
98 }
99}
100
101impl PartialEq<SocketAddrStr> for SocketAddr {
102 fn eq(&self, other: &SocketAddrStr) -> bool {
103 self.eq(&other.0)
104 }
105}
106
107impl Hash for SocketAddrStr {
108 fn hash<H: Hasher>(&self, state: &mut H) {
109 self.0.hash(state);
110 }
111}
112
113#[cfg(feature = "runtime-tokio")]
114pub use tokio_utils::copy_duplex;
115
#[cfg(feature = "runtime-tokio")]
mod tokio_utils {
    use super::*;
    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

    /// Progress of one direction of a [`copy_duplex`] transfer.
    #[derive(Debug)]
    enum TransferState {
        /// Bytes are still being copied through the contained buffer.
        Running(CopyBuffer),
        /// Copying has stopped and the writer is being shut down; holds the bytes written so far.
        ShuttingDown(u64),
        /// The writer has been shut down; holds the total number of bytes written.
        Done(u64),
    }

    /// Wraps an error raised by the reading side of a copy, so every reader failure is
    /// reported the same way regardless of where in the copy loop it surfaced.
    fn reader_error(err: std::io::Error) -> std::io::Error {
        std::io::Error::new(std::io::ErrorKind::BrokenPipe, err)
    }

    /// Advances one direction of the transfer (`r` -> `w`) as far as the I/O currently allows.
    ///
    /// Moves `state` from `Running` to `ShuttingDown` once the copy loop completes, and from
    /// `ShuttingDown` to `Done` once `w` has been shut down. Yields `Ready(Ok(count))` in the
    /// `Done` state, where `count` is the number of bytes written to `w`.
    fn transfer_one_direction<A, B>(
        cx: &mut Context<'_>,
        state: &mut TransferState,
        r: &mut A,
        w: &mut B,
    ) -> Poll<std::io::Result<u64>>
    where
        A: AsyncRead + AsyncWrite + Unpin + ?Sized,
        B: AsyncRead + AsyncWrite + Unpin + ?Sized,
    {
        let mut r = Pin::new(r);
        let mut w = Pin::new(w);
        loop {
            match state {
                TransferState::Running(buf) => {
                    let count = std::task::ready!(buf.poll_copy(cx, r.as_mut(), w.as_mut()))?;
                    *state = TransferState::ShuttingDown(count);
                }
                TransferState::ShuttingDown(count) => {
                    std::task::ready!(w.as_mut().poll_shutdown(cx))?;
                    *state = TransferState::Done(*count);
                }
                TransferState::Done(count) => return Poll::Ready(Ok(*count)),
            }
        }
    }

    /// Copies data between `a` and `b` in both directions until both directions are finished.
    ///
    /// Bytes read from `a` are written to `b` through a buffer of `a_to_b_buffer_size` bytes,
    /// and bytes read from `b` are written to `a` through a buffer of `b_to_a_buffer_size`
    /// bytes. When one direction finishes (its reader reported EOF and its writer was shut
    /// down), the other direction is stopped early: its writer is shut down and any bytes still
    /// sitting in its buffer are discarded.
    ///
    /// Returns the number of bytes written `a -> b` and `b -> a`, in that order.
    ///
    /// # Errors
    ///
    /// - `InvalidInput` if either buffer size is zero.
    /// - `BrokenPipe` (wrapping the original error) if reading from either side fails.
    /// - `WriteZero` if a writer accepts zero bytes.
    /// - Any error returned while writing to, flushing or shutting down either side.
    pub async fn copy_duplex<A, B>(
        a: &mut A,
        b: &mut B,
        a_to_b_buffer_size: usize,
        b_to_a_buffer_size: usize,
    ) -> std::io::Result<(u64, u64)>
    where
        A: AsyncRead + AsyncWrite + Unpin + ?Sized,
        B: AsyncRead + AsyncWrite + Unpin + ?Sized,
    {
        // With an empty buffer every read completes with zero new bytes, which `CopyBuffer`
        // treats as EOF: nothing would be copied and both writers would be shut down.
        if a_to_b_buffer_size == 0 || b_to_a_buffer_size == 0 {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "copy_duplex buffer sizes must be non-zero",
            ));
        }

        let mut a_to_b = TransferState::Running(CopyBuffer::new(a_to_b_buffer_size));
        let mut b_to_a = TransferState::Running(CopyBuffer::new(b_to_a_buffer_size));

        std::future::poll_fn(|cx| {
            let mut a_to_b_result = transfer_one_direction(cx, &mut a_to_b, a, b)?;
            let mut b_to_a_result = transfer_one_direction(cx, &mut b_to_a, b, a)?;

            // Once B -> A is done, stop A -> B too, discarding whatever is still buffered.
            if let TransferState::Done(_) = b_to_a {
                if let TransferState::Running(buf) = &a_to_b {
                    tracing::trace!("B-side has completed, terminating A-side.");
                    a_to_b = TransferState::ShuttingDown(buf.amt);
                    a_to_b_result = transfer_one_direction(cx, &mut a_to_b, a, b)?;
                }
            }

            // Symmetrically, once A -> B is done, stop B -> A.
            if let TransferState::Done(_) = a_to_b {
                if let TransferState::Running(buf) = &b_to_a {
                    tracing::trace!("A-side has completed, terminate B-side.");
                    b_to_a = TransferState::ShuttingDown(buf.amt);
                    b_to_a_result = transfer_one_direction(cx, &mut b_to_a, b, a)?;
                }
            }

            // Returning early on `Pending` here is fine: a direction in `Done` keeps reporting
            // its count on later polls.
            let a_to_b_bytes_transferred = std::task::ready!(a_to_b_result);
            let b_to_a_bytes_transferred = std::task::ready!(b_to_a_result);

            Poll::Ready(Ok((a_to_b_bytes_transferred, b_to_a_bytes_transferred)))
        })
        .await
    }

    /// Fixed-size buffer that moves bytes from a reader to a writer.
    #[derive(Debug)]
    struct CopyBuffer {
        /// Set once a read adds no bytes to the buffer, which is treated as EOF.
        read_done: bool,
        /// Bytes have been written since the writer was last flushed.
        need_flush: bool,
        /// Start of the bytes in `buf` that are read but not yet written.
        pos: usize,
        /// End of the bytes read into `buf`.
        cap: usize,
        /// Total number of bytes written so far.
        amt: u64,
        /// Backing storage; `buf[pos..cap]` is the pending data.
        buf: Box<[u8]>,
    }

    impl CopyBuffer {
        /// Creates an empty buffer able to hold `buf_size` bytes.
        fn new(buf_size: usize) -> Self {
            Self {
                read_done: false,
                need_flush: false,
                pos: 0,
                cap: 0,
                amt: 0,
                buf: vec![0; buf_size].into_boxed_slice(),
            }
        }

        /// Reads into the free space after `cap`, extending `cap` by the bytes read and setting
        /// `read_done` when the read added nothing.
        fn poll_fill_buf<R>(&mut self, cx: &mut Context<'_>, reader: Pin<&mut R>) -> Poll<std::io::Result<()>>
        where
            R: AsyncRead + ?Sized,
        {
            let me = &mut *self;
            let mut buf = ReadBuf::new(&mut me.buf);
            buf.set_filled(me.cap);

            let res = reader.poll_read(cx, &mut buf);
            if let Poll::Ready(Ok(())) = res {
                let filled_len = buf.filled().len();
                me.read_done = me.cap == filled_len;
                me.cap = filled_len;
            }
            res
        }

        /// Writes `buf[pos..cap]` to `writer`.
        ///
        /// If the writer is not ready, uses the wait to read more data into any free space at
        /// the end of the buffer before returning `Pending`.
        fn poll_write_buf<R, W>(
            &mut self,
            cx: &mut Context<'_>,
            mut reader: Pin<&mut R>,
            mut writer: Pin<&mut W>,
        ) -> Poll<std::io::Result<usize>>
        where
            R: AsyncRead + ?Sized,
            W: AsyncWrite + ?Sized,
        {
            let this = &mut *self;
            match writer.as_mut().poll_write(cx, &this.buf[this.pos..this.cap]) {
                Poll::Pending => {
                    if !this.read_done && this.cap < this.buf.len() {
                        std::task::ready!(this.poll_fill_buf(cx, reader.as_mut())).map_err(reader_error)?;
                    }
                    Poll::Pending
                }
                res @ Poll::Ready(_) => res,
            }
        }

        /// Copies from `reader` to `writer` until the reader reports EOF and everything read
        /// has been written and flushed. Yields the total number of bytes written.
        pub(super) fn poll_copy<R, W>(
            &mut self,
            cx: &mut Context<'_>,
            mut reader: Pin<&mut R>,
            mut writer: Pin<&mut W>,
        ) -> Poll<std::io::Result<u64>>
        where
            R: AsyncRead + ?Sized,
            W: AsyncWrite + ?Sized,
        {
            loop {
                // Everything read so far has been written: refill from the start of the buffer.
                if self.pos == self.cap && !self.read_done {
                    self.pos = 0;
                    self.cap = 0;

                    match self.poll_fill_buf(cx, reader.as_mut()) {
                        Poll::Ready(Ok(())) => (),
                        Poll::Ready(Err(err)) => return Poll::Ready(Err(reader_error(err))),
                        Poll::Pending => {
                            // Nothing to read right now: flush what was already written so it
                            // does not sit in the writer while we wait.
                            if self.need_flush {
                                std::task::ready!(writer.as_mut().poll_flush(cx))?;
                                self.need_flush = false;
                            }

                            return Poll::Pending;
                        }
                    }
                }

                while self.pos < self.cap {
                    let i = std::task::ready!(self.poll_write_buf(cx, reader.as_mut(), writer.as_mut()))?;
                    if i == 0 {
                        return Poll::Ready(Err(std::io::Error::new(
                            std::io::ErrorKind::WriteZero,
                            "write zero byte",
                        )));
                    }
                    // A writer claiming more bytes than it was given would push `pos` past `cap`,
                    // after which this loop would spin forever; report it instead.
                    if i > self.cap - self.pos {
                        return Poll::Ready(Err(std::io::Error::new(
                            std::io::ErrorKind::Other,
                            "writer returned length larger than input slice",
                        )));
                    }
                    self.pos += i;
                    self.amt += i as u64;
                    self.need_flush = true;
                }

                if self.pos == self.cap && self.read_done {
                    std::task::ready!(writer.as_mut().poll_flush(cx))?;
                    return Poll::Ready(Ok(self.amt));
                }
            }
        }
    }
}
328
/// Turns a reader into a `Stream` of byte chunks, each holding at most `S` bytes.
///
/// The wrapped reader is the public field `.0`.
pub struct AsyncReadStreamer<const S: usize, R>(pub R);
332
333impl<const S: usize, R: AsyncRead + Unpin> futures::Stream for AsyncReadStreamer<S, R> {
334 type Item = std::io::Result<Box<[u8]>>;
335
336 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
337 let mut buffer = vec![0u8; S];
338
339 match futures::ready!(Pin::new(&mut self.0).poll_read(cx, &mut buffer)) {
340 Ok(0) => Poll::Ready(None),
341 Ok(size) => {
342 buffer.truncate(size);
343 Poll::Ready(Some(Ok(buffer.into_boxed_slice())))
344 }
345 Err(err) => Poll::Ready(Some(Err(err))),
346 }
347 }
348}
349
#[cfg(all(feature = "runtime-tokio", test))]
mod tests {
    use super::*;
    use crate::utils::DuplexIO;
    use futures::TryStreamExt;
    use tokio::io::AsyncWriteExt;
    use tokio_util::compat::FuturesAsyncReadCompatExt;

    /// Sends `N` random bytes each way through `copy_duplex` (128-byte buffers) and checks
    /// both byte counts and that each side received exactly what the other one sent.
    async fn assert_random_exchange<const N: usize>() -> anyhow::Result<()> {
        let alice_tx = hopr_crypto_random::random_bytes::<N>();
        let mut alice_rx = [0u8; N];

        let bob_tx = hopr_crypto_random::random_bytes::<N>();
        let mut bob_rx = [0u8; N];

        let alice = DuplexIO(alice_tx.as_ref(), futures::io::Cursor::new(alice_rx.as_mut()));
        let bob = DuplexIO(bob_tx.as_ref(), futures::io::Cursor::new(bob_rx.as_mut()));

        let (a_to_b, b_to_a) = copy_duplex(
            &mut FuturesAsyncReadCompatExt::compat(alice),
            &mut FuturesAsyncReadCompatExt::compat(bob),
            128,
            128,
        )
        .await?;

        assert_eq!(N, a_to_b as usize);
        assert_eq!(N, b_to_a as usize);

        assert_eq!(alice_tx, bob_rx);
        assert_eq!(bob_tx, alice_rx);

        Ok(())
    }

    #[tokio::test]
    async fn test_copy_duplex() -> anyhow::Result<()> {
        assert_random_exchange::<2000>().await
    }

    #[tokio::test]
    async fn test_copy_duplex_small() -> anyhow::Result<()> {
        assert_random_exchange::<100>().await
    }

    #[tokio::test]
    async fn test_client_to_server() -> anyhow::Result<()> {
        let (mut client_tx, mut client_rx) = tokio::io::duplex(8);
        let (mut server_rx, mut server_tx) = tokio::io::duplex(32);

        client_tx.write_all(b"hello").await?;
        client_tx.shutdown().await?;

        server_tx.write_all(b"data").await?;
        server_tx.shutdown().await?;

        let (client_to_server_count, server_to_client_count) =
            crate::utils::copy_duplex(&mut client_rx, &mut server_rx, 2, 2).await?;

        assert_eq!(client_to_server_count, 5);
        assert_eq!(server_to_client_count, 4);
        Ok(())
    }

    #[tokio::test]
    async fn test_server_to_client() -> anyhow::Result<()> {
        let (mut client_tx, mut client_rx) = tokio::io::duplex(32);
        let (mut server_rx, mut server_tx) = tokio::io::duplex(8);

        server_tx.write_all(b"hello").await?;
        server_tx.shutdown().await?;

        // The client side is never shut down: the transfer must end because the server side did.
        client_tx.write_all(b"some longer data to transfer").await?;

        let (client_to_server_count, server_to_client_count) =
            crate::utils::copy_duplex(&mut client_rx, &mut server_rx, 2, 2).await?;

        assert_eq!(server_to_client_count, 5);
        assert!(client_to_server_count <= 8);
        Ok(())
    }

    #[async_std::test]
    async fn test_async_read_streamer_complete_chunk() {
        let data = b"Hello, World!!";
        let chunks: Vec<Box<[u8]>> = AsyncReadStreamer::<14, _>(&data[..]).try_collect().await.unwrap();

        assert_eq!(chunks, vec![Box::from(*data)]);
    }

    #[async_std::test]
    async fn test_async_read_streamer_complete_more_chunks() {
        let data = b"Hello, World and do it twice";
        let mut streamer = AsyncReadStreamer::<14, _>(&data[..]);

        let mut chunks = Vec::new();
        while let Some(chunk) = streamer.try_next().await.unwrap() {
            chunks.push(chunk);
        }

        let expected: Vec<Box<[u8]>> = data.chunks(14).map(Box::from).collect();
        assert_eq!(chunks, expected);
    }

    #[async_std::test]
    async fn test_async_read_streamer_complete_more_chunks_with_incomplete() -> anyhow::Result<()> {
        let data = b"Hello, World and do it twice, ...";
        let streamer = AsyncReadStreamer::<14, _>(&data[..]);

        let chunks = streamer.try_collect::<Vec<_>>().await?;

        // Two full 14-byte chunks followed by the 5-byte remainder.
        let expected: Vec<Box<[u8]>> = data.chunks(14).map(Box::from).collect();
        assert_eq!(chunks, expected);

        Ok(())
    }

    #[async_std::test]
    async fn test_async_read_streamer_incomplete_chunk() -> anyhow::Result<()> {
        let data = b"Hello, World!!";
        let prefix = &data[..8];
        let mut streamer = AsyncReadStreamer::<14, _>(prefix);

        let first = streamer.try_next().await?;
        assert_eq!(first, Some(Box::from(prefix)));

        Ok(())
    }
}