1use std::{
2 fmt::Debug,
3 io::ErrorKind,
4 num::NonZeroUsize,
5 pin::Pin,
6 sync::{Arc, OnceLock},
7 task::{Context, Poll},
8 time::Duration,
9};
10
11use futures::{FutureExt, Sink, SinkExt, ready};
12use tokio::net::UdpSocket;
13use tracing::{debug, error, instrument, trace, warn};
14
15use crate::utils::SocketAddrStr;
16
17type BoxIoSink<T> = Box<dyn Sink<T, Error = std::io::Error> + Send + Unpin>;
18
// Bucket boundaries shared by both UDP packet-length histograms.
#[cfg(all(feature = "prometheus", not(test)))]
const UDP_PACKET_LEN_BUCKETS: [f64; 9] = [20.0, 40.0, 80.0, 160.0, 320.0, 640.0, 1280.0, 2560.0, 5120.0];

#[cfg(all(feature = "prometheus", not(test)))]
lazy_static::lazy_static! {
    /// Histogram of the lengths of datagrams received by the RX tasks.
    static ref METRIC_UDP_INGRESS_LEN: hopr_metrics::SimpleHistogram =
        hopr_metrics::SimpleHistogram::new(
            "hopr_udp_ingress_packet_len",
            "UDP packet lengths on ingress",
            UDP_PACKET_LEN_BUCKETS.to_vec(),
        )
        .unwrap();

    /// Histogram of the lengths of datagrams sent by the TX tasks.
    static ref METRIC_UDP_EGRESS_LEN: hopr_metrics::SimpleHistogram =
        hopr_metrics::SimpleHistogram::new(
            "hopr_udp_egress_packet_len",
            "UDP packet lengths on egress",
            UDP_PACKET_LEN_BUCKETS.to_vec(),
        )
        .unwrap();
}
34
35#[pin_project::pin_project(PinnedDrop)]
56pub struct ConnectedUdpStream {
57 socket_handles: Vec<tokio::task::JoinHandle<()>>,
58 #[pin]
59 ingress_rx: Box<dyn tokio::io::AsyncRead + Send + Unpin>,
60 #[pin]
61 egress_tx: Option<BoxIoSink<Box<[u8]>>>,
62 counterparty: Arc<OnceLock<SocketAddrStr>>,
63 bound_to: std::net::SocketAddr,
64 state: State,
65}
66
/// Phase of the write state machine driven by `ConnectedUdpStream::poll_write`.
#[derive(Clone, Copy)]
enum State {
    /// The next datagram can be handed over to the egress sink.
    Writing,
    /// A datagram of the given length was handed over and the sink is being flushed.
    Flushing(usize),
}
72
/// Defines what happens with a datagram whose source address differs from the
/// counterparty's address.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum ForeignDataMode {
    /// The foreign datagram is dropped and a warning is logged.
    Discard,
    /// The foreign datagram is passed on as if it came from the counterparty.
    Accept,
    /// The foreign datagram results in an error being passed on and the receiving task
    /// terminating. This is the default.
    #[default]
    Error,
}
86
/// Requested number of parallel socket tasks for sending or receiving.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum UdpStreamParallelism {
    /// Let the number of tasks be derived from the detected available parallelism.
    Auto,
    /// Use the given number of tasks, capped by the detected available parallelism.
    Specific(NonZeroUsize),
}
110
111impl Default for UdpStreamParallelism {
112 fn default() -> Self {
113 Self::Specific(NonZeroUsize::MIN)
114 }
115}
116
117impl UdpStreamParallelism {
118 fn avail_parallelism() -> usize {
119 std::thread::available_parallelism()
122 .map(|n| {
123 if cfg!(target_os = "linux") {
124 n
125 } else {
126 NonZeroUsize::MIN
127 }
128 })
129 .unwrap_or_else(|e| {
130 warn!(error = %e, "failed to determine available parallelism, defaulting to 1.");
131 NonZeroUsize::MIN
132 })
133 .into()
134 }
135
136 pub fn split_evenly_with(self, other: UdpStreamParallelism) -> (usize, usize) {
139 let cpu_half = (Self::avail_parallelism() / 2).max(1);
140
141 match (self, other) {
142 (UdpStreamParallelism::Auto, UdpStreamParallelism::Auto) => (cpu_half, cpu_half),
143 (UdpStreamParallelism::Specific(a), UdpStreamParallelism::Auto) => {
144 let a = cpu_half.min(a.into());
145 (a, cpu_half * 2 - a)
146 }
147 (UdpStreamParallelism::Auto, UdpStreamParallelism::Specific(b)) => {
148 let b = cpu_half.min(b.into());
149 (cpu_half * 2 - b, b)
150 }
151 (UdpStreamParallelism::Specific(a), UdpStreamParallelism::Specific(b)) => {
152 (cpu_half.min(a.into()), cpu_half.min(b.into()))
153 }
154 }
155 }
156
157 pub fn into_num_tasks(self) -> usize {
161 let avail_parallelism = Self::avail_parallelism();
162 match self {
163 UdpStreamParallelism::Auto => avail_parallelism,
164 UdpStreamParallelism::Specific(n) => usize::from(n).min(avail_parallelism),
165 }
166 }
167}
168
169impl From<usize> for UdpStreamParallelism {
170 fn from(value: usize) -> Self {
171 NonZeroUsize::new(value).map(Self::Specific).unwrap_or_default()
172 }
173}
174
175impl From<Option<usize>> for UdpStreamParallelism {
176 fn from(value: Option<usize>) -> Self {
177 value.map(UdpStreamParallelism::from).unwrap_or_default()
178 }
179}
180
181#[derive(Debug, Clone)]
185pub struct UdpStreamBuilder {
186 foreign_data_mode: ForeignDataMode,
187 buffer_size: usize,
188 queue_size: Option<usize>,
189 receiver_parallelism: UdpStreamParallelism,
190 sender_parallelism: UdpStreamParallelism,
191 counterparty: Option<std::net::SocketAddr>,
192}
193
194impl Default for UdpStreamBuilder {
195 fn default() -> Self {
196 Self {
197 buffer_size: 2048,
198 foreign_data_mode: Default::default(),
199 queue_size: None,
200 receiver_parallelism: Default::default(),
201 sender_parallelism: Default::default(),
202 counterparty: None,
203 }
204 }
205}
206
207impl UdpStreamBuilder {
208 pub fn with_foreign_data_mode(mut self, mode: ForeignDataMode) -> Self {
213 self.foreign_data_mode = mode;
214 self
215 }
216
217 pub fn with_buffer_size(mut self, buffer_size: usize) -> Self {
224 self.buffer_size = buffer_size;
225 self
226 }
227
228 pub fn with_queue_size(mut self, queue_size: usize) -> Self {
239 self.queue_size = Some(queue_size);
240 self
241 }
242
243 pub fn with_receiver_parallelism<T: Into<UdpStreamParallelism>>(mut self, parallelism: T) -> Self {
249 self.receiver_parallelism = parallelism.into();
250 self
251 }
252
253 pub fn with_sender_parallelism<T: Into<UdpStreamParallelism>>(mut self, parallelism: T) -> Self {
259 self.sender_parallelism = parallelism.into();
260 self
261 }
262
263 pub fn with_counterparty(mut self, counterparty: std::net::SocketAddr) -> Self {
274 self.counterparty = Some(counterparty);
275 self
276 }
277
278 pub fn build<A: std::net::ToSocketAddrs>(self, bind_addr: A) -> std::io::Result<ConnectedUdpStream> {
296 let (num_rx_socks, num_tx_socks) = self.receiver_parallelism.split_evenly_with(self.sender_parallelism);
297
298 let counterparty = Arc::new(
299 self.counterparty
300 .map(|s| OnceLock::from(SocketAddrStr::from(s)))
301 .unwrap_or_default(),
302 );
303 let ((ingress_tx, ingress_rx), (egress_tx, egress_rx)) = if let Some(q) = self.queue_size {
304 (flume::bounded(q), flume::bounded(q))
305 } else {
306 (flume::unbounded(), flume::unbounded())
307 };
308
309 let num_socks_to_bind = num_rx_socks.max(num_tx_socks);
310 let mut socket_handles = Vec::with_capacity(num_socks_to_bind);
311 let mut bound_addr: Option<std::net::SocketAddr> = None;
312
313 for binding_to in bind_addr.to_socket_addrs()? {
315 debug!(
316 %binding_to,
317 num_socks_to_bind, num_rx_socks, num_tx_socks, "binding UDP stream"
318 );
319
320 (0..num_socks_to_bind)
324 .map(|sock_id| {
325 let domain = match &binding_to {
326 std::net::SocketAddr::V4(_) => socket2::Domain::IPV4,
327 std::net::SocketAddr::V6(_) => socket2::Domain::IPV6,
328 };
329
330 let sock = socket2::Socket::new(domain, socket2::Type::DGRAM, None)?;
332 if num_socks_to_bind > 1 {
333 sock.set_reuse_address(true)?; sock.set_reuse_port(true)?; }
336 sock.set_nonblocking(true)?;
337 sock.bind(&bound_addr.unwrap_or(binding_to).into())?;
338
339 let socket_bound_addr = sock
341 .local_addr()?
342 .as_socket()
343 .ok_or(std::io::Error::other("invalid socket type"))?;
344
345 match bound_addr {
346 None => bound_addr = Some(socket_bound_addr),
347 Some(addr) if addr != socket_bound_addr => {
348 return Err(std::io::Error::other(format!(
349 "inconsistent binding address {addr} != {socket_bound_addr} on socket id {sock_id}"
350 )));
351 }
352 _ => {}
353 }
354
355 let sock = Arc::new(UdpSocket::from_std(sock.into())?);
356 debug!(
357 socket_id = sock_id,
358 addr = %socket_bound_addr,
359 "bound UDP socket"
360 );
361
362 Ok((sock_id, sock))
363 })
364 .filter_map(|result| match result {
365 Ok(bound) => Some(bound),
366 Err(error) => {
367 error!(
368 %binding_to,
369 %error,
370 "failed to bind udp socket"
371 );
372 None
373 }
374 })
375 .for_each(|(sock_id, sock)| {
376 if sock_id < num_tx_socks {
377 socket_handles.push(tokio::task::spawn(ConnectedUdpStream::setup_tx_queue(
378 sock_id,
379 sock.clone(),
380 egress_rx.clone(),
381 counterparty.clone(),
382 )));
383 }
384 if sock_id < num_rx_socks {
385 socket_handles.push(tokio::task::spawn(ConnectedUdpStream::setup_rx_queue(
386 sock_id,
387 sock.clone(),
388 ingress_tx.clone(),
389 counterparty.clone(),
390 self.foreign_data_mode,
391 self.buffer_size,
392 )));
393 }
394 });
395 }
396
397 Ok(ConnectedUdpStream {
398 ingress_rx: Box::new(tokio_util::io::StreamReader::new(ingress_rx.into_stream())),
399 egress_tx: Some(Box::new(
400 egress_tx
401 .into_sink()
402 .sink_map_err(|e| std::io::Error::other(e.to_string())),
403 )),
404 socket_handles,
405 counterparty,
406 bound_to: bound_addr.ok_or(ErrorKind::AddrNotAvailable)?,
407 state: State::Writing,
408 })
409 }
410}
411
/// Time after which handing a received item over to the ingress queue is reported (via a
/// warning) as having taken too long.
const QUEUE_DISPATCH_THRESHOLD: Duration = Duration::from_millis(150);
413
414impl ConnectedUdpStream {
415 fn setup_rx_queue(
417 socket_id: usize,
418 sock_rx: Arc<UdpSocket>,
419 ingress_tx: flume::Sender<std::io::Result<tokio_util::bytes::Bytes>>,
420 counterparty: Arc<OnceLock<SocketAddrStr>>,
421 foreign_data_mode: ForeignDataMode,
422 buf_size: usize,
423 ) -> futures::future::BoxFuture<'static, ()> {
424 let counterparty_rx = counterparty.clone();
425 async move {
426 let mut buffer = vec![0u8; buf_size];
427 let mut done = false;
428 loop {
429 let out_res = match sock_rx.recv_from(&mut buffer).await {
431 Ok((read, read_addr)) if read > 0 => {
432 trace!(
433 socket_id,
434 udp_bound_addr = ?sock_rx.local_addr(),
435 bytes = read,
436 from = %read_addr,
437 "received data from"
438 );
439
440 let addr = counterparty_rx.get_or_init(|| read_addr.into());
441
442 #[cfg(all(feature = "prometheus", not(test)))]
443 METRIC_UDP_INGRESS_LEN.observe(read as f64);
444
445 if read_addr.eq(addr) || foreign_data_mode == ForeignDataMode::Accept {
447 let out_buffer = tokio_util::bytes::Bytes::copy_from_slice(&buffer[..read]);
448 Some(Ok(out_buffer))
449 } else {
450 match foreign_data_mode {
451 ForeignDataMode::Discard => {
452 warn!(
454 socket_id,
455 udp_bound_addr = ?sock_rx.local_addr(),
456 ?read_addr,
457 expected_addr = ?addr,
458 "discarded data, which didn't come from the expected address"
459 );
460 None
461 }
462 ForeignDataMode::Error => {
463 done = true;
465 Some(Err(std::io::Error::new(
466 ErrorKind::ConnectionRefused,
467 "data from foreign client not allowed",
468 )))
469 }
470 _ => unreachable!(),
472 }
473 }
474 }
475 Ok(_) => {
476 trace!(
478 socket_id,
479 udp_bound_addr = ?sock_rx.local_addr(),
480 "read EOF on socket"
481 );
482 done = true;
483 None
484 }
485 Err(error) => {
486 debug!(
488 socket_id,
489 udp_bound_addr = ?sock_rx.local_addr(),
490 %error,
491 "forwarded error from socket"
492 );
493 done = true;
494 Some(Err(error))
495 }
496 };
497
498 if let Some(out_res) = out_res {
501 let start = std::time::Instant::now();
502 if let Err(error) = ingress_tx.send_async(out_res).await {
503 error!(
504 socket_id,
505 udp_bound_addr = ?sock_rx.local_addr(),
506 %error,
507 "failed to dispatch received data"
508 );
509 done = true;
510 }
511 let elapsed = start.elapsed();
512 if elapsed > QUEUE_DISPATCH_THRESHOLD {
513 warn!(
514 ?elapsed,
515 "udp queue dispatch took too long, consider increasing the queue size"
516 );
517 }
518 }
519
520 if done {
521 trace!(
522 socket_id,
523 udp_bound_addr = ?sock_rx.local_addr(),
524 "rx queue done"
525 );
526 break;
527 }
528 }
529 }
530 .boxed()
531 }
532
533 fn setup_tx_queue(
535 socket_id: usize,
536 sock_tx: Arc<UdpSocket>,
537 egress_rx: flume::Receiver<Box<[u8]>>,
538 counterparty: Arc<OnceLock<SocketAddrStr>>,
539 ) -> futures::future::BoxFuture<'static, ()> {
540 let counterparty_tx = counterparty.clone();
541 async move {
542 loop {
543 match egress_rx.recv_async().await {
544 Ok(data) => {
545 if let Some(target) = counterparty_tx.get() {
546 if let Err(error) = sock_tx.send_to(&data, target.as_ref()).await {
547 error!(
548 ?socket_id,
549 udp_bound_addr = ?sock_tx.local_addr(),
550 ?target,
551 %error,
552 "failed to send data"
553 );
554 }
555 trace!(socket_id, bytes = data.len(), ?target, "sent bytes to");
556
557 #[cfg(all(feature = "prometheus", not(test)))]
558 METRIC_UDP_EGRESS_LEN.observe(data.len() as f64);
559 } else {
560 error!(
561 ?socket_id,
562 udp_bound_addr = ?sock_tx.local_addr(),
563 "cannot send data, counterparty not set"
564 );
565 break;
566 }
567 }
568 Err(error) => {
569 error!(
570 ?socket_id,
571 udp_bound_addr = ?sock_tx.local_addr(),
572 %error,
573 "cannot receive more data from egress channel"
574 );
575 break;
576 }
577 }
578 trace!(
579 ?socket_id,
580 udp_bound_addr = tracing::field::debug(sock_tx.local_addr()),
581 "tx queue done"
582 );
583 }
584 }
585 .boxed()
586 }
587
588 pub fn bound_address(&self) -> &std::net::SocketAddr {
590 &self.bound_to
591 }
592
593 pub fn builder() -> UdpStreamBuilder {
595 UdpStreamBuilder::default()
596 }
597}
598
599impl tokio::io::AsyncRead for ConnectedUdpStream {
600 #[instrument(name = "ConnectedUdpStream::poll_read", level = "trace", skip(self, cx), fields(counterparty = ?self.counterparty.get(), rem = buf.remaining()) , ret)]
601 fn poll_read(
602 self: Pin<&mut Self>,
603 cx: &mut Context<'_>,
604 buf: &mut tokio::io::ReadBuf<'_>,
605 ) -> Poll<std::io::Result<()>> {
606 ready!(self.project().ingress_rx.poll_read(cx, buf))?;
607 trace!(bytes = buf.filled().len(), "read bytes");
608 Poll::Ready(Ok(()))
609 }
610}
611
612impl tokio::io::AsyncWrite for ConnectedUdpStream {
613 #[instrument(name = "ConnectedUdpStream::poll_write", level = "trace", skip(self, cx), fields(counterparty = ?self.counterparty.get(), len = buf.len()) , ret)]
614 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
615 let this = self.project();
616 if let Some(sender) = this.egress_tx.get_mut() {
617 loop {
618 match *this.state {
619 State::Writing => {
620 ready!(sender.poll_ready_unpin(cx))?;
621
622 let len = buf.iter().len();
623 sender.start_send_unpin(Box::from(buf))?;
624 *this.state = State::Flushing(len);
625 }
626 State::Flushing(len) => {
627 let res = ready!(sender.poll_flush_unpin(cx)).map(|_| len);
629 *this.state = State::Writing;
630
631 return Poll::Ready(res);
632 }
633 }
634 }
635 } else {
636 Poll::Ready(Err(std::io::Error::new(
637 ErrorKind::NotConnected,
638 "udp stream is closed",
639 )))
640 }
641 }
642
643 #[instrument(name = "ConnectedUdpStream::poll_flush", level = "trace", skip(self, cx), fields(counterparty = ?self.counterparty.get()) , ret)]
644 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
645 let this = self.project();
646 if let Some(sender) = this.egress_tx.as_pin_mut() {
647 sender.poll_flush(cx).map_err(std::io::Error::other)
648 } else {
649 Poll::Ready(Err(std::io::Error::new(
650 ErrorKind::NotConnected,
651 "udp stream is closed",
652 )))
653 }
654 }
655
656 #[instrument(name = "ConnectedUdpStream::poll_shutdown", level = "trace", skip(self, cx), fields(counterparty = ?self.counterparty.get()) , ret)]
657 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
658 let this = self.project();
659 if let Some(sender) = this.egress_tx.as_pin_mut() {
660 let ret = ready!(sender.poll_close(cx));
661
662 this.socket_handles.iter().for_each(|handle| {
663 handle.abort();
664 });
665
666 Poll::Ready(ret)
667 } else {
668 Poll::Ready(Err(std::io::Error::new(
669 ErrorKind::NotConnected,
670 "udp stream is closed",
671 )))
672 }
673 }
674}
675
676#[pin_project::pinned_drop]
677impl PinnedDrop for ConnectedUdpStream {
678 fn drop(self: Pin<&mut Self>) {
679 debug!(binding = ?self.bound_to,"dropping ConnectedUdpStream");
680 self.project().socket_handles.iter().for_each(|handle| {
681 handle.abort();
682 })
683 }
684}
685
686#[cfg(test)]
687mod tests {
688 use anyhow::Context;
689 use futures::{future::Either, pin_mut};
690 use parameterized::parameterized;
691 use tokio::{
692 io::{AsyncReadExt, AsyncWriteExt},
693 net::UdpSocket,
694 };
695
696 use super::*;
697
698 #[parameterized(parallelism = {None, Some(2), Some(0)})]
699 #[parameterized_macro(tokio::test)]
700 async fn basic_udp_stream_tests(parallelism: Option<usize>) -> anyhow::Result<()> {
702 const DATA_SIZE: usize = 200;
703
704 let listener = UdpSocket::bind("127.0.0.1:0").await.context("bind listener")?;
705 let listen_addr = listener.local_addr()?;
706
707 tokio::task::spawn(async move {
709 loop {
710 let mut buf = [0u8; DATA_SIZE];
711 let (read, addr) = listener.recv_from(&mut buf).await.expect("recv must not fail");
712 if read > 0 {
713 assert_eq!(DATA_SIZE, read, "read size must be exactly {DATA_SIZE}");
714 listener.send_to(&buf, addr).await.expect("send must not fail");
715 }
716 }
717 });
718
719 let mut builder = ConnectedUdpStream::builder()
720 .with_buffer_size(1024)
721 .with_queue_size(512)
722 .with_counterparty(listen_addr);
723
724 if let Some(parallelism) = parallelism {
725 builder = builder.with_receiver_parallelism(parallelism);
726 }
727
728 let mut stream = builder.build(("127.0.0.1", 0)).context("connection")?;
729
730 for _ in 1..1000 {
731 let mut w_buf = [0u8; DATA_SIZE];
732 hopr_crypto_random::random_fill(&mut w_buf);
733 let written = stream.write(&w_buf).await?;
734 assert_eq!(written, DATA_SIZE);
735
736 let mut r_buf = [0u8; DATA_SIZE];
737 let read = stream.read_exact(&mut r_buf).await?;
738 assert_eq!(read, DATA_SIZE);
739
740 assert_eq!(w_buf, r_buf);
741 }
742
743 stream.shutdown().await?;
744
745 Ok(())
746 }
747
748 #[tokio::test]
749 async fn udp_stream_should_process_sequential_writes() -> anyhow::Result<()> {
750 const BUF_SIZE: usize = 1024;
751 const EXPECTED_DATA_LEN: usize = BUF_SIZE + 500;
752
753 let mut listener = ConnectedUdpStream::builder()
754 .with_buffer_size(BUF_SIZE)
755 .with_queue_size(512)
756 .build(("127.0.0.1", 0))
757 .context("bind listener")?;
758
759 let bound_addr = *listener.bound_address();
760
761 let jh = tokio::task::spawn(async move {
762 let mut buf = [0u8; BUF_SIZE / 4];
763 let mut vec = Vec::<u8>::new();
764 loop {
765 let sz = listener.read(&mut buf).await.unwrap();
766 if sz > 0 {
767 vec.extend_from_slice(&buf[..sz]);
768 if vec.len() >= EXPECTED_DATA_LEN {
769 return vec;
770 }
771 } else {
772 return vec;
773 }
774 }
775 });
776
777 let msg = [1u8; EXPECTED_DATA_LEN];
778 let sender = UdpSocket::bind(("127.0.0.1", 0)).await.context("bind")?;
779
780 sender.send_to(&msg[..BUF_SIZE], bound_addr).await?;
781 sender.send_to(&msg[BUF_SIZE..], bound_addr).await?;
782
783 let timeout = tokio::time::sleep(std::time::Duration::from_millis(1000));
784 pin_mut!(timeout);
785 pin_mut!(jh);
786
787 match futures::future::select(jh, timeout).await {
788 Either::Left((Ok(v), _)) => {
789 assert_eq!(v.len(), EXPECTED_DATA_LEN);
790 assert_eq!(v.as_slice(), &msg);
791 Ok(())
792 }
793 _ => Err(anyhow::anyhow!("timeout or invalid data")),
794 }
795 }
796}