1use std::{
2 fmt::{Debug, Display, Formatter},
3 hash::{Hash, Hasher},
4 net::SocketAddr,
5 pin::Pin,
6 task::{Context, Poll},
7};
8
9use futures::io::{AsyncRead, AsyncWrite};
10
/// Joins a separate reader (`.0`) and writer (`.1`) into one bidirectional I/O object.
///
/// Reads are forwarded to the reader half; writes, flushes and closes go to the writer half.
#[derive(Debug)]
pub struct DuplexIO<R, W>(pub R, pub W);
13
14impl<R, W> From<(R, W)> for DuplexIO<R, W>
15where
16 R: AsyncRead,
17 W: AsyncWrite,
18{
19 fn from(value: (R, W)) -> Self {
20 Self(value.0, value.1)
21 }
22}
23
24impl<R, W> AsyncRead for DuplexIO<R, W>
25where
26 R: AsyncRead + Unpin,
27 W: AsyncWrite + Unpin,
28{
29 fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<std::io::Result<usize>> {
30 let this = self.get_mut();
31 Pin::new(&mut this.0).poll_read(cx, buf)
32 }
33}
34
35impl<R, W> AsyncWrite for DuplexIO<R, W>
36where
37 R: AsyncRead + Unpin,
38 W: AsyncWrite + Unpin,
39{
40 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
41 let this = self.get_mut();
42 Pin::new(&mut this.1).poll_write(cx, buf)
43 }
44
45 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
46 let this = self.get_mut();
47 Pin::new(&mut this.1).poll_flush(cx)
48 }
49
50 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
51 let this = self.get_mut();
52 Pin::new(&mut this.1).poll_close(cx)
53 }
54}
55
/// Upper bound, in bytes, on the `Display` text of any [`SocketAddr`].
///
/// The longest rendering is a fully expanded IPv6 address with a 10-digit scope id and a
/// 5-digit port, `[ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff%4294967295]:65535`, which is 58
/// bytes. Anything smaller makes `SocketAddrStr` truncate such addresses (52 only covers the
/// bracketed part and cuts off the port).
const SOCKET_ADDRESS_MAX_LEN: usize = 58;
58
59#[derive(Copy, Clone)]
61pub(crate) struct SocketAddrStr(SocketAddr, arrayvec::ArrayString<SOCKET_ADDRESS_MAX_LEN>);
62
63impl SocketAddrStr {
64 #[allow(dead_code)]
65 pub fn as_str(&self) -> &str {
66 self.1.as_str()
67 }
68}
69
70impl AsRef<SocketAddr> for SocketAddrStr {
71 fn as_ref(&self) -> &SocketAddr {
72 &self.0
73 }
74}
75
76impl From<SocketAddr> for SocketAddrStr {
77 fn from(value: SocketAddr) -> Self {
78 let mut cached = value.to_string();
79 cached.truncate(SOCKET_ADDRESS_MAX_LEN);
80 Self(value, cached.parse().expect("cannot fail due to truncation"))
81 }
82}
83
84impl PartialEq for SocketAddrStr {
85 fn eq(&self, other: &Self) -> bool {
86 self.0 == other.0
87 }
88}
89
90impl Eq for SocketAddrStr {}
91
92impl Debug for SocketAddrStr {
93 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
94 write!(f, "{}", self.1)
95 }
96}
97
98impl Display for SocketAddrStr {
99 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
100 write!(f, "{}", self.1)
101 }
102}
103
104impl PartialEq<SocketAddrStr> for SocketAddr {
105 fn eq(&self, other: &SocketAddrStr) -> bool {
106 self.eq(&other.0)
107 }
108}
109
110impl Hash for SocketAddrStr {
111 fn hash<H: Hasher>(&self, state: &mut H) {
112 self.0.hash(state);
113 }
114}
115
116#[cfg(feature = "runtime-tokio")]
117pub use tokio_utils::copy_duplex;
118
// Tokio-based implementation of `copy_duplex`, re-exported by the parent module when the
// `runtime-tokio` feature is enabled.
#[cfg(feature = "runtime-tokio")]
mod tokio_utils {
    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

    use super::*;

    /// Progress of one copy direction.
    #[derive(Debug)]
    enum TransferState {
        /// Still copying from the reader to the writer through this buffer.
        Running(CopyBuffer),
        /// Copying has stopped and the writer is being shut down; holds the bytes copied.
        ShuttingDown(u64),
        /// The writer has been shut down; holds the final number of bytes copied.
        Done(u64),
    }

    /// Advances one direction's state machine as far as possible within a single poll.
    ///
    /// `Running` copies from `r` to `w` until the copy completes, then moves to `ShuttingDown`.
    /// `ShuttingDown` shuts `w` down and moves to `Done`. Returns `Pending` whenever a step is
    /// not ready. Once `Done`, it returns `Ready(Ok(count))`, and it does so again on every
    /// later call. Errors from copying or shutting down are propagated.
    fn transfer_one_direction<A, B>(
        cx: &mut Context<'_>,
        state: &mut TransferState,
        r: &mut A,
        w: &mut B,
    ) -> Poll<std::io::Result<u64>>
    where
        A: AsyncRead + AsyncWrite + Unpin + ?Sized,
        B: AsyncRead + AsyncWrite + Unpin + ?Sized,
    {
        let mut r = Pin::new(r);
        let mut w = Pin::new(w);
        loop {
            match state {
                TransferState::Running(buf) => {
                    let count = std::task::ready!(buf.poll_copy(cx, r.as_mut(), w.as_mut()))?;
                    *state = TransferState::ShuttingDown(count);
                }
                TransferState::ShuttingDown(count) => {
                    std::task::ready!(w.as_mut().poll_shutdown(cx))?;
                    *state = TransferState::Done(*count);
                }
                TransferState::Done(count) => return Poll::Ready(Ok(*count)),
            }
        }
    }

    /// Copies data between `a` and `b` in both directions concurrently.
    ///
    /// Data read from `a` is written to `b` through a buffer of `a_to_b_buffer_size` bytes.
    /// Data read from `b` is written to `a` through a buffer of `b_to_a_buffer_size` bytes.
    /// A direction finishes when its reader reaches EOF: the remaining buffered data is
    /// written and flushed, and then that direction's writer is shut down.
    ///
    /// Once one direction has finished, the other one is terminated too. If it is still
    /// copying, its writer is shut down right away and any bytes left in its buffer are
    /// dropped.
    ///
    /// Returns `(bytes written a -> b, bytes written b -> a)`. For a terminated direction, the
    /// count is the number of bytes written before termination.
    ///
    /// # Errors
    /// The first I/O error from either direction aborts the whole operation. A read error hit
    /// while refilling an empty copy buffer is reported as `ErrorKind::BrokenPipe` wrapping the
    /// original error.
    pub async fn copy_duplex<A, B>(
        a: &mut A,
        b: &mut B,
        a_to_b_buffer_size: usize,
        b_to_a_buffer_size: usize,
    ) -> std::io::Result<(u64, u64)>
    where
        A: AsyncRead + AsyncWrite + Unpin + ?Sized,
        B: AsyncRead + AsyncWrite + Unpin + ?Sized,
    {
        let mut a_to_b = TransferState::Running(CopyBuffer::new(a_to_b_buffer_size));
        let mut b_to_a = TransferState::Running(CopyBuffer::new(b_to_a_buffer_size));

        std::future::poll_fn(|cx| {
            // Drive both directions first; `?` aborts the whole copy on the first error.
            let mut a_to_b_result = transfer_one_direction(cx, &mut a_to_b, a, b)?;
            let mut b_to_a_result = transfer_one_direction(cx, &mut b_to_a, b, a)?;

            // B -> A is done: stop A -> B as well. Its unwritten buffered bytes are dropped,
            // and shutting down its writer (`b`) starts within this same poll.
            if let TransferState::Done(_) = b_to_a {
                if let TransferState::Running(buf) = &a_to_b {
                    tracing::trace!("B-side has completed, terminating A-side.");
                    a_to_b = TransferState::ShuttingDown(buf.amt);
                    a_to_b_result = transfer_one_direction(cx, &mut a_to_b, a, b)?;
                }
            }

            // Symmetric case: A -> B is done, so stop B -> A and start shutting down `a`.
            if let TransferState::Done(_) = a_to_b {
                if let TransferState::Running(buf) = &b_to_a {
                    tracing::trace!("A-side has completed, terminate B-side.");
                    b_to_a = TransferState::ShuttingDown(buf.amt);
                    b_to_a_result = transfer_one_direction(cx, &mut b_to_a, b, a)?;
                }
            }

            // Returning early through `ready!` is fine: a direction that is `Done` keeps
            // reporting the same count on subsequent polls.
            let a_to_b_bytes_transferred = std::task::ready!(a_to_b_result);
            let b_to_a_bytes_transferred = std::task::ready!(b_to_a_result);

            Poll::Ready(Ok((a_to_b_bytes_transferred, b_to_a_bytes_transferred)))
        })
        .await
    }

    /// Buffer and bookkeeping for copying one direction.
    #[derive(Debug)]
    struct CopyBuffer {
        /// Set once a read into the buffer completed without adding any bytes (EOF).
        read_done: bool,
        /// Set after bytes are written; cleared once the writer is flushed while waiting
        /// for the reader.
        need_flush: bool,
        /// Index of the next buffered byte to write.
        pos: usize,
        /// End (exclusive) of the buffered data; `buf[pos..cap]` is still to be written.
        cap: usize,
        /// Total number of bytes written to the writer so far.
        amt: u64,
        /// Backing storage, allocated once with the requested size.
        buf: Box<[u8]>,
    }

    impl CopyBuffer {
        /// Creates an empty buffer backed by `buf_size` zeroed bytes.
        ///
        /// NOTE(review): with `buf_size == 0` a read cannot add bytes, which `poll_fill_buf`
        /// treats as EOF; callers presumably pass a non-zero size — confirm.
        fn new(buf_size: usize) -> Self {
            Self {
                read_done: false,
                need_flush: false,
                pos: 0,
                cap: 0,
                amt: 0,
                buf: vec![0; buf_size].into_boxed_slice(),
            }
        }

        /// Reads more data from `reader` into the free tail of the buffer (after `cap`).
        ///
        /// On a successful read, `cap` is advanced to the new end of data, and `read_done` is
        /// set if the read added nothing. An error or `Pending` leaves the bookkeeping
        /// untouched.
        fn poll_fill_buf<R>(&mut self, cx: &mut Context<'_>, reader: Pin<&mut R>) -> Poll<std::io::Result<()>>
        where
            R: AsyncRead + ?Sized,
        {
            let me = &mut *self;
            let mut buf = ReadBuf::new(&mut me.buf);
            // Bytes before `cap` are already data; have the reader append after them.
            buf.set_filled(me.cap);

            let res = reader.poll_read(cx, &mut buf);
            if let Poll::Ready(Ok(())) = res {
                let filled_len = buf.filled().len();
                me.read_done = me.cap == filled_len;
                me.cap = filled_len;
            }
            res
        }

        /// Tries to write the buffered bytes `buf[pos..cap]` to `writer`.
        ///
        /// If the writer is not ready and the buffer has room (and EOF has not been seen), the
        /// wait is used to top the buffer up from `reader`, which can make the next write larger.
        /// In that case it still returns `Pending`. The exception is a failing top-up read,
        /// whose error is returned unchanged (it is not wrapped as `BrokenPipe`).
        fn poll_write_buf<R, W>(
            &mut self,
            cx: &mut Context<'_>,
            mut reader: Pin<&mut R>,
            mut writer: Pin<&mut W>,
        ) -> Poll<std::io::Result<usize>>
        where
            R: AsyncRead + ?Sized,
            W: AsyncWrite + ?Sized,
        {
            let this = &mut *self;
            match writer.as_mut().poll_write(cx, &this.buf[this.pos..this.cap]) {
                Poll::Pending => {
                    // Opportunistically read more while the writer is busy.
                    if !this.read_done && this.cap < this.buf.len() {
                        std::task::ready!(this.poll_fill_buf(cx, reader.as_mut()))?;
                    }
                    Poll::Pending
                }
                res @ Poll::Ready(_) => res,
            }
        }

        /// Copies from `reader` to `writer`. The copy is complete when the reader has reached
        /// EOF and every byte has been written and flushed; it then returns the total number of
        /// bytes written.
        ///
        /// A read error while refilling the empty buffer is returned as `BrokenPipe` wrapping
        /// the original error. A writer accepting zero bytes yields `WriteZero`.
        pub(super) fn poll_copy<R, W>(
            &mut self,
            cx: &mut Context<'_>,
            mut reader: Pin<&mut R>,
            mut writer: Pin<&mut W>,
        ) -> Poll<std::io::Result<u64>>
        where
            R: AsyncRead + ?Sized,
            W: AsyncWrite + ?Sized,
        {
            loop {
                // Buffer drained and no EOF yet: restart at the front and read more.
                if self.pos == self.cap && !self.read_done {
                    self.pos = 0;
                    self.cap = 0;

                    match self.poll_fill_buf(cx, reader.as_mut()) {
                        Poll::Ready(Ok(())) => (),
                        Poll::Ready(Err(err)) => {
                            return Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::BrokenPipe, err)));
                        }
                        Poll::Pending => {
                            // The reader has nothing for us right now. Flush what was written
                            // so far, presumably so that a peer whose reply depends on that
                            // data is not deadlocked behind a buffering writer.
                            if self.need_flush {
                                std::task::ready!(writer.as_mut().poll_flush(cx))?;
                                self.need_flush = false;
                            }

                            return Poll::Pending;
                        }
                    }
                }

                // Write out everything currently buffered.
                while self.pos < self.cap {
                    let i = std::task::ready!(self.poll_write_buf(cx, reader.as_mut(), writer.as_mut()))?;
                    if i == 0 {
                        return Poll::Ready(Err(std::io::Error::new(
                            std::io::ErrorKind::WriteZero,
                            "write zero byte",
                        )));
                    }
                    self.pos += i;
                    self.amt += i as u64;
                    self.need_flush = true;
                }

                // NOTE(review): a writer reporting more bytes than it was given is only caught
                // in debug builds. In release builds `pos > cap` would make this loop spin
                // without ever progressing.
                debug_assert!(self.pos <= self.cap, "writer returned length larger than input slice");

                // Everything written and EOF seen: flush and report the total.
                if self.pos == self.cap && self.read_done {
                    std::task::ready!(writer.as_mut().poll_flush(cx))?;
                    return Poll::Ready(Ok(self.amt));
                }
            }
        }
    }
}
331}
332
333pub struct AsyncReadStreamer<const S: usize, R>(pub R);
336
337impl<const S: usize, R: AsyncRead + Unpin> futures::Stream for AsyncReadStreamer<S, R> {
338 type Item = std::io::Result<Box<[u8]>>;
339
340 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
341 let mut buffer = vec![0u8; S];
342
343 match futures::ready!(Pin::new(&mut self.0).poll_read(cx, &mut buffer)) {
344 Ok(0) => Poll::Ready(None),
345 Ok(size) => {
346 buffer.truncate(size);
347 Poll::Ready(Some(Ok(buffer.into_boxed_slice())))
348 }
349 Err(err) => Poll::Ready(Some(Err(err))),
350 }
351 }
352}
353
#[cfg(all(feature = "runtime-tokio", test))]
mod tests {
    use futures::TryStreamExt;
    use tokio::io::AsyncWriteExt;

    use super::*;
    use crate::utils::DuplexIO;

    /// Sends `N` random bytes each way between two in-memory [`DuplexIO`] peers through
    /// [`copy_duplex`], with both copy buffers sized `buffer_size`. Checks that every byte
    /// arrived unchanged and was counted.
    async fn assert_duplex_exchange<const N: usize>(buffer_size: usize) -> anyhow::Result<()> {
        let alice_tx = hopr_crypto_random::random_bytes::<N>();
        let mut alice_rx = [0u8; N];

        let bob_tx = hopr_crypto_random::random_bytes::<N>();
        let mut bob_rx = [0u8; N];

        // Each peer reads from its own `*_tx` slice and collects what it receives in `*_rx`.
        let alice = DuplexIO(alice_tx.as_ref(), futures::io::Cursor::new(alice_rx.as_mut()));
        let bob = DuplexIO(bob_tx.as_ref(), futures::io::Cursor::new(bob_rx.as_mut()));

        let (a_to_b, b_to_a) = copy_duplex(
            &mut tokio_util::compat::FuturesAsyncReadCompatExt::compat(alice),
            &mut tokio_util::compat::FuturesAsyncReadCompatExt::compat(bob),
            buffer_size,
            buffer_size,
        )
        .await?;

        assert_eq!(N, a_to_b as usize);
        assert_eq!(N, b_to_a as usize);

        assert_eq!(alice_tx, bob_rx);
        assert_eq!(bob_tx, alice_rx);

        Ok(())
    }

    #[tokio::test]
    async fn test_copy_duplex() -> anyhow::Result<()> {
        // The payload spans many 128-byte copy buffers.
        assert_duplex_exchange::<2000>(128).await
    }

    #[tokio::test]
    async fn test_copy_duplex_small() -> anyhow::Result<()> {
        // The payload fits in a single copy buffer.
        assert_duplex_exchange::<100>(128).await
    }

    #[tokio::test]
    async fn test_client_to_server() -> anyhow::Result<()> {
        let (mut client_tx, mut client_rx) = tokio::io::duplex(8);
        let (mut server_rx, mut server_tx) = tokio::io::duplex(32);

        // Both peers send a short message and close their write side, so each direction
        // reaches EOF on its own and is copied in full.
        client_tx.write_all(b"hello").await?;
        client_tx.shutdown().await?;

        server_tx.write_all(b"data").await?;
        server_tx.shutdown().await?;

        let (client_to_server_count, server_to_client_count) =
            crate::utils::copy_duplex(&mut client_rx, &mut server_rx, 2, 2).await?;

        assert_eq!(client_to_server_count, 5);
        assert_eq!(server_to_client_count, 4);

        Ok(())
    }

    #[tokio::test]
    async fn test_server_to_client() -> anyhow::Result<()> {
        let (mut client_tx, mut client_rx) = tokio::io::duplex(32);
        let (mut server_rx, mut server_tx) = tokio::io::duplex(8);

        // Only the server closes its side. The client keeps its side open and sends more than
        // the 8-byte server pipe can hold (nobody drains it). The client -> server direction
        // can therefore only end because `copy_duplex` terminates it once server -> client is
        // done.
        server_tx.write_all(b"hello").await?;
        server_tx.shutdown().await?;

        client_tx.write_all(b"some longer data to transfer").await?;

        let (client_to_server_count, server_to_client_count) =
            crate::utils::copy_duplex(&mut client_rx, &mut server_rx, 2, 2).await?;

        assert_eq!(server_to_client_count, 5);
        // At most the capacity of the undrained server pipe can have been written.
        assert!(client_to_server_count <= 8);

        Ok(())
    }

    #[tokio::test]
    async fn test_async_read_streamer_complete_chunk() {
        let data = b"Hello, World!!";
        let mut streamer = AsyncReadStreamer::<14, _>(&data[..]);
        let mut results = Vec::new();

        while let Some(res) = streamer.try_next().await.unwrap() {
            results.push(res);
        }

        assert_eq!(results, vec![Box::from(*data)]);
    }

    #[tokio::test]
    async fn test_async_read_streamer_complete_more_chunks() {
        let data = b"Hello, World and do it twice";
        let mut streamer = AsyncReadStreamer::<14, _>(&data[..]);
        let mut results = Vec::new();

        while let Some(res) = streamer.try_next().await.unwrap() {
            results.push(res);
        }

        let (data1, data2) = data.split_at(14);
        assert_eq!(results, vec![Box::from(data1), Box::from(data2)]);
    }

    #[tokio::test]
    async fn test_async_read_streamer_complete_more_chunks_with_incomplete() -> anyhow::Result<()> {
        let data = b"Hello, World and do it twice, ...";
        let streamer = AsyncReadStreamer::<14, _>(&data[..]);

        let results = streamer.try_collect::<Vec<_>>().await?;

        let (data1, rest) = data.split_at(14);
        let (data2, data3) = rest.split_at(14);
        assert_eq!(results, vec![Box::from(data1), Box::from(data2), Box::from(data3)]);

        Ok(())
    }

    #[tokio::test]
    async fn test_async_read_streamer_incomplete_chunk() -> anyhow::Result<()> {
        let data = b"Hello, World!!";
        let reader = &data[0..8];
        let mut streamer = AsyncReadStreamer::<14, _>(reader);

        assert_eq!(Some(Box::from(reader)), streamer.try_next().await?);

        Ok(())
    }
}