1use std::{
2 fmt::{Debug, Display, Formatter},
3 hash::{Hash, Hasher},
4 net::SocketAddr,
5 pin::Pin,
6 task::{Context, Poll},
7};
8
9use futures::io::{AsyncRead, AsyncWrite};
10
/// Joins a writer `W` (field `0`) and a reader `R` (field `1`) into one bidirectional I/O object.
///
/// Reads are forwarded to the reader half; writes, flushes and closes go to the writer half.
#[pin_project::pin_project]
pub struct DuplexIO<W, R>(#[pin] pub W, #[pin] pub R);
14
15impl<R, W> From<(W, R)> for DuplexIO<W, R>
16where
17 R: AsyncRead,
18 W: AsyncWrite,
19{
20 fn from(value: (W, R)) -> Self {
21 Self(value.0, value.1)
22 }
23}
24
25impl<R, W> AsyncRead for DuplexIO<W, R>
26where
27 R: AsyncRead,
28 W: AsyncWrite,
29{
30 fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<std::io::Result<usize>> {
31 self.project().1.poll_read(cx, buf)
32 }
33}
34
35impl<R, W> AsyncWrite for DuplexIO<W, R>
36where
37 R: AsyncRead,
38 W: AsyncWrite,
39{
40 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
41 self.project().0.poll_write(cx, buf)
42 }
43
44 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
45 self.project().0.poll_flush(cx)
46 }
47
48 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
49 self.project().0.poll_close(cx)
50 }
51}
52
/// Capacity, in bytes, of the cached textual form of a `SocketAddr` kept by `SocketAddrStr`.
///
/// This is the length of the longest possible rendering: an IPv6 socket address with all eight
/// groups spelled out, a 10-digit scope id and a 5-digit port, i.e.
/// `[ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff%4294967295]:65535` (58 bytes). Anything shorter
/// would silently cut the `]:port` tail off such addresses.
const SOCKET_ADDRESS_MAX_LEN: usize = 58;
55
/// A `SocketAddr` bundled with its pre-rendered textual form.
///
/// The text is produced once at construction (cut to at most `SOCKET_ADDRESS_MAX_LEN` bytes) and
/// stored inline in a fixed-capacity string, so the type stays `Copy` and printing it does not
/// re-format the address. Equality and hashing look only at the address, not the text.
#[derive(Copy, Clone)]
pub(crate) struct SocketAddrStr(SocketAddr, arrayvec::ArrayString<SOCKET_ADDRESS_MAX_LEN>);
59
60impl SocketAddrStr {
61 #[allow(dead_code)]
62 pub fn as_str(&self) -> &str {
63 self.1.as_str()
64 }
65}
66
67impl AsRef<SocketAddr> for SocketAddrStr {
68 fn as_ref(&self) -> &SocketAddr {
69 &self.0
70 }
71}
72
73impl From<SocketAddr> for SocketAddrStr {
74 fn from(value: SocketAddr) -> Self {
75 let mut cached = value.to_string();
76 cached.truncate(SOCKET_ADDRESS_MAX_LEN);
77 Self(value, cached.parse().expect("cannot fail due to truncation"))
78 }
79}
80
81impl PartialEq for SocketAddrStr {
82 fn eq(&self, other: &Self) -> bool {
83 self.0 == other.0
84 }
85}
86
87impl Eq for SocketAddrStr {}
88
89impl Debug for SocketAddrStr {
90 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
91 write!(f, "{}", self.1)
92 }
93}
94
95impl Display for SocketAddrStr {
96 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
97 write!(f, "{}", self.1)
98 }
99}
100
101impl PartialEq<SocketAddrStr> for SocketAddr {
102 fn eq(&self, other: &SocketAddrStr) -> bool {
103 self.eq(&other.0)
104 }
105}
106
107impl Hash for SocketAddrStr {
108 fn hash<H: Hasher>(&self, state: &mut H) {
109 self.0.hash(state);
110 }
111}
112
// Re-export the tokio-based duplex copy helpers when the `runtime-tokio` feature is enabled.
#[cfg(feature = "runtime-tokio")]
pub use tokio_utils::{copy_duplex, copy_duplex_abortable};
115
#[cfg(feature = "runtime-tokio")]
mod tokio_utils {
    use futures::{
        FutureExt,
        future::{AbortHandle, Abortable},
    };
    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

    use super::*;

    /// Progress of a single copy direction (reader -> writer).
    #[derive(Debug)]
    enum TransferState {
        /// Data is still being copied through the buffer.
        Running(CopyBuffer),
        /// Copying has stopped and the writer is being shut down; holds the bytes copied.
        ShuttingDown(u64),
        /// The writer has been shut down; holds the final number of bytes copied.
        Done(u64),
    }

    /// Advances one copy direction from `r` to `w` as far as possible without blocking.
    ///
    /// While `Running`, data is copied until the reader reports EOF. The writer is then shut
    /// down and the state becomes `Done`. Returns `Ready(Ok(bytes_copied))` once `Done`,
    /// `Pending` while the copy or the shutdown cannot progress, or the first I/O error.
    fn transfer_one_direction<A, B>(
        cx: &mut Context<'_>,
        state: &mut TransferState,
        r: &mut A,
        w: &mut B,
    ) -> Poll<std::io::Result<u64>>
    where
        A: AsyncRead + AsyncWrite + Unpin + ?Sized,
        B: AsyncRead + AsyncWrite + Unpin + ?Sized,
    {
        let mut r = Pin::new(r);
        let mut w = Pin::new(w);
        loop {
            match state {
                TransferState::Running(buf) => {
                    let count = std::task::ready!(buf.poll_copy(cx, r.as_mut(), w.as_mut()))?;
                    tracing::trace!(processed = count, "direction copy complete");
                    *state = TransferState::ShuttingDown(count);
                }
                TransferState::ShuttingDown(count) => {
                    std::task::ready!(w.as_mut().poll_shutdown(cx))?;
                    tracing::trace!(processed = *count, "direction shutdown complete");
                    *state = TransferState::Done(*count);
                }
                TransferState::Done(count) => return Poll::Ready(Ok(*count)),
            }
        }
    }

    /// Copies data between `a` and `b` in both directions.
    ///
    /// `a_to_b_buffer_size` / `b_to_a_buffer_size` set the intermediate buffer size of each
    /// direction. When one direction reaches EOF, its writer is shut down and the opposite
    /// direction is then stopped and shut down as well (see [`copy_duplex_abortable`]).
    ///
    /// Returns `(bytes copied a -> b, bytes copied b -> a)`.
    ///
    /// # Errors
    /// Returns the first I/O error encountered in either direction.
    pub async fn copy_duplex<A, B>(
        a: &mut A,
        b: &mut B,
        (a_to_b_buffer_size, b_to_a_buffer_size): (usize, usize),
    ) -> std::io::Result<(u64, u64)>
    where
        A: AsyncRead + AsyncWrite + Unpin + ?Sized,
        B: AsyncRead + AsyncWrite + Unpin + ?Sized,
    {
        // The abort handles are discarded immediately and `abort()` is never called on them,
        // so these registrations never fire and the copy runs until completion.
        let (_, ar_a) = AbortHandle::new_pair();
        let (_, ar_b) = AbortHandle::new_pair();

        copy_duplex_abortable(a, b, (a_to_b_buffer_size, b_to_a_buffer_size), (ar_a, ar_b)).await
    }

    /// Like [`copy_duplex`], but each direction can be cut short from the outside.
    ///
    /// Aborting the handle paired with `a_abort` stops the `a -> b` direction. Data still held
    /// in its buffer is discarded, and `b` is shut down for writing. `b_abort` does the same for
    /// the `b -> a` direction. An abort only takes effect while its direction is still copying.
    ///
    /// As soon as one direction is `Done` (by EOF or abort), the other one is stopped the same
    /// way. The future resolves once both writers have been shut down.
    ///
    /// Returns `(bytes copied a -> b, bytes copied b -> a)`. Bytes discarded on abort are not
    /// counted.
    ///
    /// # Errors
    /// Returns the first I/O error from either direction. Read failures are reported with kind
    /// `BrokenPipe`, with the original error kept as the inner error.
    pub async fn copy_duplex_abortable<A, B>(
        a: &mut A,
        b: &mut B,
        (a_to_b_buffer_size, b_to_a_buffer_size): (usize, usize),
        (a_abort, b_abort): (futures::future::AbortRegistration, futures::future::AbortRegistration),
    ) -> std::io::Result<(u64, u64)>
    where
        A: AsyncRead + AsyncWrite + Unpin + ?Sized,
        B: AsyncRead + AsyncWrite + Unpin + ?Sized,
    {
        let mut a_to_b = TransferState::Running(CopyBuffer::new(a_to_b_buffer_size));
        let mut b_to_a = TransferState::Running(CopyBuffer::new(b_to_a_buffer_size));

        // Never-completing futures that only resolve (with `Err(Aborted)`) once aborted.
        let (mut abort_a, mut abort_b) = (
            Abortable::new(futures::future::pending::<()>(), a_abort),
            Abortable::new(futures::future::pending::<()>(), b_abort),
        );

        std::future::poll_fn(|cx| {
            let mut a_to_b_result = transfer_one_direction(cx, &mut a_to_b, a, b)?;
            let mut b_to_a_result = transfer_one_direction(cx, &mut b_to_a, b, a)?;

            if let (Poll::Ready(Err(_)), TransferState::Running(buf)) = (abort_a.poll_unpin(cx), &a_to_b) {
                tracing::trace!("A-side has been aborted.");
                a_to_b = TransferState::ShuttingDown(buf.amt);
                // Re-poll so the new state is driven (the shutdown has not been attempted yet).
                cx.waker().wake_by_ref();
            }

            if let (Poll::Ready(Err(_)), TransferState::Running(buf)) = (abort_b.poll_unpin(cx), &b_to_a) {
                tracing::trace!("B-side has been aborted.");
                b_to_a = TransferState::ShuttingDown(buf.amt);
                cx.waker().wake_by_ref();
            }

            // Once one direction is done, stop the other one instead of waiting for its EOF.
            if let TransferState::Done(_) = b_to_a {
                if let TransferState::Running(buf) = &a_to_b {
                    tracing::trace!("B-side has completed, terminating A-side.");
                    a_to_b = TransferState::ShuttingDown(buf.amt);
                    a_to_b_result = transfer_one_direction(cx, &mut a_to_b, a, b)?;
                }
            }

            if let TransferState::Done(_) = a_to_b {
                if let TransferState::Running(buf) = &b_to_a {
                    tracing::trace!("A-side has completed, terminate B-side.");
                    b_to_a = TransferState::ShuttingDown(buf.amt);
                    b_to_a_result = transfer_one_direction(cx, &mut b_to_a, b, a)?;
                }
            }

            let a_to_b_bytes_transferred = std::task::ready!(a_to_b_result);
            let b_to_a_bytes_transferred = std::task::ready!(b_to_a_result);

            tracing::trace!(
                a_to_b = a_to_b_bytes_transferred,
                b_to_a = b_to_a_bytes_transferred,
                "copy completed"
            );
            Poll::Ready(Ok((a_to_b_bytes_transferred, b_to_a_bytes_transferred)))
        })
        .await
    }

    /// Reports a reader failure as `BrokenPipe`, keeping the original error as the inner error.
    ///
    /// Used on every read path of [`CopyBuffer`] so callers see read failures consistently.
    fn reader_error(err: std::io::Error) -> std::io::Error {
        std::io::Error::new(std::io::ErrorKind::BrokenPipe, err)
    }

    /// Fixed-size buffer that moves bytes from a reader to a writer.
    #[derive(Debug)]
    struct CopyBuffer {
        /// The reader signalled EOF (a read that added no bytes).
        read_done: bool,
        /// Bytes were written since the last flush.
        need_flush: bool,
        /// Index of the next buffered byte to write.
        pos: usize,
        /// Number of valid bytes in `buf`.
        cap: usize,
        /// Total number of bytes written so far.
        amt: u64,
        buf: Box<[u8]>,
    }

    impl CopyBuffer {
        /// Creates a buffer of `buf_size` bytes (at least one byte).
        fn new(buf_size: usize) -> Self {
            Self {
                read_done: false,
                need_flush: false,
                pos: 0,
                cap: 0,
                amt: 0,
                // With a zero-length buffer every read would add 0 bytes, which `poll_fill_buf`
                // treats as EOF, silently ending the transfer before anything is copied.
                buf: vec![0; buf_size.max(1)].into_boxed_slice(),
            }
        }

        /// Reads from `reader` into the free space after `cap`.
        ///
        /// A successful read that adds no bytes marks the reader as done (EOF).
        fn poll_fill_buf<R>(&mut self, cx: &mut Context<'_>, reader: Pin<&mut R>) -> Poll<std::io::Result<()>>
        where
            R: AsyncRead + ?Sized,
        {
            let me = &mut *self;
            let mut buf = ReadBuf::new(&mut me.buf);
            buf.set_filled(me.cap);

            let res = reader.poll_read(cx, &mut buf);
            if let Poll::Ready(Ok(())) = res {
                let filled_len = buf.filled().len();
                me.read_done = me.cap == filled_len;
                me.cap = filled_len;
            }
            res
        }

        /// Writes the pending bytes `buf[pos..cap]` to `writer`.
        ///
        /// While the writer is not ready, tries to top the buffer up from `reader` so the next
        /// write can be larger.
        fn poll_write_buf<R, W>(
            &mut self,
            cx: &mut Context<'_>,
            mut reader: Pin<&mut R>,
            mut writer: Pin<&mut W>,
        ) -> Poll<std::io::Result<usize>>
        where
            R: AsyncRead + ?Sized,
            W: AsyncWrite + ?Sized,
        {
            let this = &mut *self;
            match writer.as_mut().poll_write(cx, &this.buf[this.pos..this.cap]) {
                Poll::Pending => {
                    if !this.read_done && this.cap < this.buf.len() {
                        std::task::ready!(this.poll_fill_buf(cx, reader.as_mut())).map_err(reader_error)?;
                    }
                    Poll::Pending
                }
                res @ Poll::Ready(_) => res,
            }
        }

        /// Copies from `reader` to `writer` until the reader hits EOF and all data is flushed.
        ///
        /// Returns the total number of bytes written.
        pub(super) fn poll_copy<R, W>(
            &mut self,
            cx: &mut Context<'_>,
            mut reader: Pin<&mut R>,
            mut writer: Pin<&mut W>,
        ) -> Poll<std::io::Result<u64>>
        where
            R: AsyncRead + ?Sized,
            W: AsyncWrite + ?Sized,
        {
            loop {
                // Everything buffered has been written: refill the buffer from the start.
                if self.pos == self.cap && !self.read_done {
                    self.pos = 0;
                    self.cap = 0;

                    match self.poll_fill_buf(cx, reader.as_mut()) {
                        Poll::Ready(Ok(())) => (),
                        Poll::Ready(Err(err)) => return Poll::Ready(Err(reader_error(err))),
                        Poll::Pending => {
                            // No input available right now: flush what was written so far, so a
                            // peer waiting for that data (possibly the one feeding our reader)
                            // is not left stalled.
                            if self.need_flush {
                                std::task::ready!(writer.as_mut().poll_flush(cx))?;
                                self.need_flush = false;
                            }

                            return Poll::Pending;
                        }
                    }
                }

                // Write out everything currently buffered.
                while self.pos < self.cap {
                    let i = std::task::ready!(self.poll_write_buf(cx, reader.as_mut(), writer.as_mut()))?;
                    if i == 0 {
                        return Poll::Ready(Err(std::io::Error::new(
                            std::io::ErrorKind::WriteZero,
                            "write zero byte",
                        )));
                    }
                    self.pos += i;
                    self.amt += i as u64;
                    self.need_flush = true;
                }

                // A writer reporting more bytes than it was given would break the loop's invariant.
                debug_assert!(self.pos <= self.cap, "writer returned length larger than input slice");

                // All read data has been written and the reader hit EOF: flush and finish.
                if self.pos == self.cap && self.read_done {
                    std::task::ready!(writer.as_mut().poll_flush(cx))?;
                    return Poll::Ready(Ok(self.amt));
                }
            }
        }
    }
}
387}
388
/// Adapts an `AsyncRead` into a `futures::Stream` of byte chunks of at most `S` bytes.
///
/// Each item is the result of one `poll_read` into an `S`-byte buffer. A read of zero bytes
/// ends the stream, and read errors are yielded as `Err` items.
#[pin_project::pin_project]
pub struct AsyncReadStreamer<const S: usize, R>(#[pin] pub R);
393
394impl<const S: usize, R: AsyncRead> futures::Stream for AsyncReadStreamer<S, R> {
395 type Item = std::io::Result<Box<[u8]>>;
396
397 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
398 let mut buffer = vec![0u8; S];
399 let mut this = self.project();
400
401 match futures::ready!(this.0.as_mut().poll_read(cx, &mut buffer)) {
402 Ok(0) => Poll::Ready(None),
403 Ok(size) => {
404 buffer.truncate(size);
405 Poll::Ready(Some(Ok(buffer.into_boxed_slice())))
406 }
407 Err(err) => Poll::Ready(Some(Err(err))),
408 }
409 }
410}
411
/// Adapts a `futures::Sink` of `Box<[u8]>` items into an `AsyncWrite`.
///
/// Each write forwards at most `C` bytes as a single sink item. Flush and close are passed
/// through to the sink.
#[pin_project::pin_project]
pub struct AsyncWriteSink<const C: usize, S>(#[pin] pub S);
416
417impl<const C: usize, S> AsyncWrite for AsyncWriteSink<C, S>
418where
419 S: futures::Sink<Box<[u8]>>,
420 S::Error: Into<std::io::Error>,
421{
422 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
423 let mut this = self.project();
424
425 futures::ready!(this.0.as_mut().poll_ready(cx).map_err(Into::into))?;
426 let len = buf.len().min(C);
427
428 match this.0.as_mut().start_send(Box::from(&buf[..len])) {
429 Ok(()) => Poll::Ready(Ok(len)),
430 Err(e) => Poll::Ready(Err(e.into())),
431 }
432 }
433
434 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
435 self.project().0.poll_flush(cx).map_err(Into::into)
436 }
437
438 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
439 self.project().0.poll_close(cx).map_err(Into::into)
440 }
441}
442
#[cfg(test)]
mod tests {
    use futures::{SinkExt, StreamExt, TryStreamExt};
    use tokio::io::AsyncWriteExt;

    use super::*;

    /// Round-trips 2000 random bytes in each direction through 128-byte copy buffers and checks
    /// both byte counts and payloads.
    #[tokio::test]
    async fn test_copy_duplex() -> anyhow::Result<()> {
        const DATA_LEN: usize = 2000;

        let alice_tx = hopr_crypto_random::random_bytes::<DATA_LEN>();
        let mut alice_rx = [0u8; DATA_LEN];

        let bob_tx = hopr_crypto_random::random_bytes::<DATA_LEN>();
        let mut bob_rx = [0u8; DATA_LEN];

        // Each side writes into its own `*_rx` array and reads from its own `*_tx` bytes.
        let alice = DuplexIO(futures::io::Cursor::new(alice_rx.as_mut()), alice_tx.as_ref());
        let bob = DuplexIO(futures::io::Cursor::new(bob_rx.as_mut()), bob_tx.as_ref());

        let (a_to_b, b_to_a) = copy_duplex(
            &mut tokio_util::compat::FuturesAsyncReadCompatExt::compat(alice),
            &mut tokio_util::compat::FuturesAsyncReadCompatExt::compat(bob),
            (128, 128),
        )
        .await?;

        assert_eq!(DATA_LEN, a_to_b as usize);
        assert_eq!(DATA_LEN, b_to_a as usize);

        assert_eq!(alice_tx, bob_rx);
        assert_eq!(bob_tx, alice_rx);

        Ok(())
    }

    /// Aborting the client (A) side after the pre-written data has been copied must still
    /// shut down both directions and report the bytes moved so far.
    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn test_copy_duplex_with_abort_from_client() -> anyhow::Result<()> {
        let (mut client_tx, mut client_rx) = tokio::io::duplex(10);
        let (mut server_rx, mut server_tx) = tokio::io::duplex(10);
        client_tx.write_all(b"hello").await?;
        server_tx.write_all(b"data").await?;

        let (handle_a, reg_a) = futures::future::AbortHandle::new_pair();
        let (_, reg_b) = futures::future::AbortHandle::new_pair();

        let result = tokio::task::spawn(async move {
            crate::utils::copy_duplex_abortable(&mut client_rx, &mut server_rx, (2, 2), (reg_a, reg_b)).await
        });

        // Give the copy task time to move the pre-written bytes before aborting.
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;

        handle_a.abort();

        let (a, b) = tokio::time::timeout(std::time::Duration::from_millis(100), result).await???;
        assert_eq!(a, 5);
        assert_eq!(b, 4);

        Ok(())
    }

    /// Same as the client-abort test, but the server (B) side is aborted.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_copy_duplex_with_abort_from_server() -> anyhow::Result<()> {
        let (mut client_tx, mut client_rx) = tokio::io::duplex(10);
        let (mut server_rx, mut server_tx) = tokio::io::duplex(10);
        client_tx.write_all(b"hello").await?;
        server_tx.write_all(b"data").await?;

        let (_, reg_a) = futures::future::AbortHandle::new_pair();
        let (handle_b, reg_b) = futures::future::AbortHandle::new_pair();

        let result = tokio::task::spawn(async move {
            crate::utils::copy_duplex_abortable(&mut client_rx, &mut server_rx, (2, 2), (reg_a, reg_b)).await
        });

        // Give the copy task time to move the pre-written bytes before aborting.
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;

        handle_b.abort();

        let (a, b) = tokio::time::timeout(std::time::Duration::from_millis(100), result).await???;
        assert_eq!(a, 5);
        assert_eq!(b, 4);

        Ok(())
    }

    /// Payload smaller than the copy buffer: each direction completes in a single fill.
    #[tokio::test]
    async fn test_copy_duplex_small() -> anyhow::Result<()> {
        const DATA_LEN: usize = 100;

        let alice_tx = hopr_crypto_random::random_bytes::<DATA_LEN>();
        let mut alice_rx = [0u8; DATA_LEN];

        let bob_tx = hopr_crypto_random::random_bytes::<DATA_LEN>();
        let mut bob_rx = [0u8; DATA_LEN];

        let alice = DuplexIO(futures::io::Cursor::new(alice_rx.as_mut()), alice_tx.as_ref());
        let bob = DuplexIO(futures::io::Cursor::new(bob_rx.as_mut()), bob_tx.as_ref());

        let (a_to_b, b_to_a) = copy_duplex(
            &mut tokio_util::compat::FuturesAsyncReadCompatExt::compat(alice),
            &mut tokio_util::compat::FuturesAsyncReadCompatExt::compat(bob),
            (128, 128),
        )
        .await?;

        assert_eq!(DATA_LEN, a_to_b as usize);
        assert_eq!(DATA_LEN, b_to_a as usize);

        assert_eq!(alice_tx, bob_rx);
        assert_eq!(bob_tx, alice_rx);

        Ok(())
    }

    /// Both peers send data and then shut down; both byte counts must be reported in full.
    #[tokio::test]
    async fn test_client_to_server() -> anyhow::Result<()> {
        let (mut client_tx, mut client_rx) = tokio::io::duplex(8);
        let (mut server_rx, mut server_tx) = tokio::io::duplex(32);
        client_tx.write_all(b"hello").await?;
        client_tx.shutdown().await?;

        server_tx.write_all(b"data").await?;
        server_tx.shutdown().await?;

        let result = crate::utils::copy_duplex(&mut client_rx, &mut server_rx, (2, 2)).await?;

        let (client_to_server_count, server_to_client_count) = result;
        assert_eq!(client_to_server_count, 5);
        assert_eq!(server_to_client_count, 4);
        Ok(())
    }

    /// When the server side finishes first, the still-open client direction is terminated.
    /// At most the server pipe's 8-byte capacity can have been copied by then.
    #[tokio::test]
    async fn test_server_to_client() -> anyhow::Result<()> {
        let (mut client_tx, mut client_rx) = tokio::io::duplex(32);
        let (mut server_rx, mut server_tx) = tokio::io::duplex(8);
        server_tx.write_all(b"hello").await?;
        server_tx.shutdown().await?;

        client_tx.write_all(b"some longer data to transfer").await?;

        let result = crate::utils::copy_duplex(&mut client_rx, &mut server_rx, (2, 2)).await?;

        let (client_to_server_count, server_to_client_count) = result;
        assert_eq!(server_to_client_count, 5);
        assert!(client_to_server_count <= 8);
        Ok(())
    }

    /// Input exactly one chunk long yields one item.
    #[tokio::test]
    async fn test_async_read_streamer_complete_chunk() {
        let data = b"Hello, World!!";
        let mut streamer = AsyncReadStreamer::<14, _>(&data[..]);
        let mut results = Vec::new();

        while let Some(res) = streamer.try_next().await.unwrap() {
            results.push(res);
        }

        assert_eq!(results, vec![Box::from(*data)]);
    }

    /// Input spanning two chunks yields a full chunk followed by the remainder.
    #[tokio::test]
    async fn test_async_read_streamer_complete_more_chunks() {
        let data = b"Hello, World and do it twice";
        let mut streamer = AsyncReadStreamer::<14, _>(&data[..]);
        let mut results = Vec::new();

        while let Some(res) = streamer.try_next().await.unwrap() {
            results.push(res);
        }

        let (data1, data2) = data.split_at(14);
        assert_eq!(results, vec![Box::from(data1), Box::from(data2)]);
    }

    /// Trailing partial chunk is yielded as a shorter final item.
    #[tokio::test]
    async fn test_async_read_streamer_complete_more_chunks_with_incomplete() -> anyhow::Result<()> {
        let data = b"Hello, World and do it twice, ...";
        let streamer = AsyncReadStreamer::<14, _>(&data[..]);

        let results = streamer.try_collect::<Vec<_>>().await?;

        let (data1, rest) = data.split_at(14);
        let (data2, data3) = rest.split_at(14);
        assert_eq!(results, vec![Box::from(data1), Box::from(data2), Box::from(data3)]);

        Ok(())
    }

    /// Input shorter than one chunk is yielded as a single short item.
    #[tokio::test]
    async fn test_async_read_streamer_incomplete_chunk() -> anyhow::Result<()> {
        let data = b"Hello, World!!";
        let reader = &data[0..8];
        let mut streamer = AsyncReadStreamer::<14, _>(reader);

        assert_eq!(Some(Box::from(reader)), streamer.try_next().await?);

        Ok(())
    }

    /// A 14-byte write through a 7-byte sink adapter arrives as two 7-byte items.
    #[tokio::test]
    async fn test_async_write_sink_should_perform_write_in_chunks() -> anyhow::Result<()> {
        let data = b"Hello, World!!";
        let (tx, rx) = futures::channel::mpsc::unbounded::<Box<[u8]>>();

        // Shadows tokio's `AsyncWriteExt` from the module scope: `AsyncWriteSink` implements
        // the futures `AsyncWrite` trait.
        use futures::AsyncWriteExt;

        let mut writer = AsyncWriteSink::<7, _>(tx.sink_map_err(std::io::Error::other));

        AsyncWriteExt::write_all(&mut writer, data).await?;
        AsyncWriteExt::flush(&mut writer).await?;
        AsyncWriteExt::close(&mut writer).await?;

        let rx_data = rx.collect::<Vec<_>>().await;
        assert_eq!(2, rx_data.len());
        assert_eq!(rx_data[0], (&data[0..7]).into());
        assert_eq!(rx_data[1], (&data[7..]).into());

        Ok(())
    }
}
677}