1use std::{
2 fmt::Debug,
3 io::ErrorKind,
4 num::NonZeroUsize,
5 pin::Pin,
6 sync::{Arc, OnceLock},
7 task::{Context, Poll},
8 time::Duration,
9};
10
11use futures::{FutureExt, Sink, SinkExt, ready};
12use tokio::net::UdpSocket;
13use tracing::{debug, error, instrument, trace, warn};
14
15use crate::utils::SocketAddrStr;
16
17type BoxIoSink<T> = Box<dyn Sink<T, Error = std::io::Error> + Send + Unpin>;
18
/// Histogram bucket upper bounds (in bytes) shared by the ingress and egress packet-length metrics.
#[cfg(all(feature = "prometheus", not(test)))]
const UDP_PACKET_LEN_BUCKETS: [f64; 9] = [20.0, 40.0, 80.0, 160.0, 320.0, 640.0, 1280.0, 2560.0, 5120.0];

#[cfg(all(feature = "prometheus", not(test)))]
lazy_static::lazy_static! {
    /// Sizes of received datagrams, labelled with the stream's counterparty.
    static ref METRIC_UDP_INGRESS_LEN: hopr_metrics::MultiHistogram = hopr_metrics::MultiHistogram::new(
        "hopr_udp_ingress_packet_len",
        "UDP packet lengths on ingress per counterparty",
        UDP_PACKET_LEN_BUCKETS.to_vec(),
        &["counterparty"],
    )
    .unwrap();
    /// Sizes of sent datagrams, labelled with the stream's counterparty.
    static ref METRIC_UDP_EGRESS_LEN: hopr_metrics::MultiHistogram = hopr_metrics::MultiHistogram::new(
        "hopr_udp_egress_packet_len",
        "UDP packet lengths on egress per counterparty",
        UDP_PACKET_LEN_BUCKETS.to_vec(),
        &["counterparty"],
    )
    .unwrap();
}
36
/// UDP "connection" to a single counterparty, exposed as `AsyncRead` + `AsyncWrite`.
///
/// One or more UDP sockets bound to the same local address are served by background tasks.
/// Receiving tasks push datagrams into an ingress queue that `poll_read` consumes. `poll_write`
/// pushes buffers into an egress queue that sending tasks transmit to the counterparty.
/// Construct it with [`UdpStreamBuilder`].
#[pin_project::pin_project(PinnedDrop)]
pub struct ConnectedUdpStream {
    /// Handles of the spawned receiving/sending tasks; aborted in `poll_shutdown` and on drop.
    socket_handles: Vec<tokio::task::JoinHandle<()>>,
    /// Byte reader over the ingress queue.
    #[pin]
    ingress_rx: Box<dyn tokio::io::AsyncRead + Send + Unpin>,
    /// Sink into the egress queue; when `None`, writes, flushes and shutdowns fail with `NotConnected`.
    #[pin]
    egress_tx: Option<BoxIoSink<Box<[u8]>>>,
    /// Peer address shared with all tasks: preset by the builder or set by the first datagram received.
    counterparty: Arc<OnceLock<SocketAddrStr>>,
    /// Local address every socket of this stream is bound to.
    bound_to: std::net::SocketAddr,
    /// Progress of the current `poll_write`.
    state: State,
}
68
/// Where `poll_write` stands with the datagram it is currently handling.
#[derive(Clone, Copy)]
enum State {
    /// Nothing pending: the next write hands a new datagram to the egress sink.
    Writing,
    /// A datagram of the recorded length was handed to the sink and still needs a flush.
    Flushing(usize),
}
74
/// Policy for datagrams that arrive from an address other than the stream's counterparty.
#[derive(Debug, Default, Clone, Copy, Eq, PartialEq)]
pub enum ForeignDataMode {
    /// Log a warning, drop the datagram and keep receiving.
    Discard,
    /// Deliver the datagram to the reader like any other.
    Accept,
    /// Deliver a `ConnectionRefused` error to the reader and stop that socket's receiving task.
    /// This is the default.
    #[default]
    Error,
}
88
/// Requested number of sockets/tasks for one direction (receiving or sending) of a UDP stream.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UdpStreamParallelism {
    /// Derive the number from the available CPU parallelism.
    Auto,
    /// Use the given number, capped by the available CPU parallelism.
    Specific(NonZeroUsize),
}

impl Default for UdpStreamParallelism {
    /// A single socket/task, i.e. `Specific(1)`.
    fn default() -> Self {
        let single = NonZeroUsize::MIN;
        UdpStreamParallelism::Specific(single)
    }
}
118
119impl UdpStreamParallelism {
120 fn avail_parallelism() -> usize {
121 std::thread::available_parallelism()
124 .map(|n| {
125 if cfg!(target_os = "linux") {
126 n
127 } else {
128 NonZeroUsize::MIN
129 }
130 })
131 .unwrap_or_else(|e| {
132 warn!(error = %e, "failed to determine available parallelism, defaulting to 1.");
133 NonZeroUsize::MIN
134 })
135 .into()
136 }
137
138 pub fn split_evenly_with(self, other: UdpStreamParallelism) -> (usize, usize) {
141 let cpu_half = (Self::avail_parallelism() / 2).max(1);
142
143 match (self, other) {
144 (UdpStreamParallelism::Auto, UdpStreamParallelism::Auto) => (cpu_half, cpu_half),
145 (UdpStreamParallelism::Specific(a), UdpStreamParallelism::Auto) => {
146 let a = cpu_half.min(a.into());
147 (a, cpu_half * 2 - a)
148 }
149 (UdpStreamParallelism::Auto, UdpStreamParallelism::Specific(b)) => {
150 let b = cpu_half.min(b.into());
151 (cpu_half * 2 - b, b)
152 }
153 (UdpStreamParallelism::Specific(a), UdpStreamParallelism::Specific(b)) => {
154 (cpu_half.min(a.into()), cpu_half.min(b.into()))
155 }
156 }
157 }
158
159 pub fn into_num_tasks(self) -> usize {
163 let avail_parallelism = Self::avail_parallelism();
164 match self {
165 UdpStreamParallelism::Auto => avail_parallelism,
166 UdpStreamParallelism::Specific(n) => usize::from(n).min(avail_parallelism),
167 }
168 }
169}
170
171impl From<usize> for UdpStreamParallelism {
172 fn from(value: usize) -> Self {
173 NonZeroUsize::new(value).map(Self::Specific).unwrap_or_default()
174 }
175}
176
177impl From<Option<usize>> for UdpStreamParallelism {
178 fn from(value: Option<usize>) -> Self {
179 value.map(UdpStreamParallelism::from).unwrap_or_default()
180 }
181}
182
/// Builder for [`ConnectedUdpStream`]; obtain one via [`ConnectedUdpStream::builder`].
#[derive(Debug, Clone)]
pub struct UdpStreamBuilder {
    /// Handling of datagrams from addresses other than the counterparty.
    foreign_data_mode: ForeignDataMode,
    /// Size in bytes of each receiving task's datagram buffer.
    buffer_size: usize,
    /// Capacity of the ingress and of the egress queue; `None` means both are unbounded.
    queue_size: Option<usize>,
    /// Requested number of receiving socket tasks.
    receiver_parallelism: UdpStreamParallelism,
    /// Requested number of sending socket tasks.
    sender_parallelism: UdpStreamParallelism,
    /// Preset peer address; when `None`, the sender of the first received datagram becomes the counterparty.
    counterparty: Option<std::net::SocketAddr>,
}
195
196impl Default for UdpStreamBuilder {
197 fn default() -> Self {
198 Self {
199 buffer_size: 2048,
200 foreign_data_mode: Default::default(),
201 queue_size: None,
202 receiver_parallelism: Default::default(),
203 sender_parallelism: Default::default(),
204 counterparty: None,
205 }
206 }
207}
208
209impl UdpStreamBuilder {
210 pub fn with_foreign_data_mode(mut self, mode: ForeignDataMode) -> Self {
215 self.foreign_data_mode = mode;
216 self
217 }
218
219 pub fn with_buffer_size(mut self, buffer_size: usize) -> Self {
226 self.buffer_size = buffer_size;
227 self
228 }
229
230 pub fn with_queue_size(mut self, queue_size: usize) -> Self {
241 self.queue_size = Some(queue_size);
242 self
243 }
244
245 pub fn with_receiver_parallelism<T: Into<UdpStreamParallelism>>(mut self, parallelism: T) -> Self {
251 self.receiver_parallelism = parallelism.into();
252 self
253 }
254
255 pub fn with_sender_parallelism<T: Into<UdpStreamParallelism>>(mut self, parallelism: T) -> Self {
261 self.sender_parallelism = parallelism.into();
262 self
263 }
264
265 pub fn with_counterparty(mut self, counterparty: std::net::SocketAddr) -> Self {
276 self.counterparty = Some(counterparty);
277 self
278 }
279
280 pub fn build<A: std::net::ToSocketAddrs>(self, bind_addr: A) -> std::io::Result<ConnectedUdpStream> {
298 let (num_rx_socks, num_tx_socks) = self.receiver_parallelism.split_evenly_with(self.sender_parallelism);
299
300 let counterparty = Arc::new(
301 self.counterparty
302 .map(|s| OnceLock::from(SocketAddrStr::from(s)))
303 .unwrap_or_default(),
304 );
305 let ((ingress_tx, ingress_rx), (egress_tx, egress_rx)) = if let Some(q) = self.queue_size {
306 (flume::bounded(q), flume::bounded(q))
307 } else {
308 (flume::unbounded(), flume::unbounded())
309 };
310
311 let num_socks_to_bind = num_rx_socks.max(num_tx_socks);
312 let mut socket_handles = Vec::with_capacity(num_socks_to_bind);
313 let mut bound_addr: Option<std::net::SocketAddr> = None;
314
315 for binding_to in bind_addr.to_socket_addrs()? {
317 debug!(
318 %binding_to,
319 num_socks_to_bind, num_rx_socks, num_tx_socks, "binding UDP stream"
320 );
321
322 (0..num_socks_to_bind)
326 .map(|sock_id| {
327 let domain = match &binding_to {
328 std::net::SocketAddr::V4(_) => socket2::Domain::IPV4,
329 std::net::SocketAddr::V6(_) => socket2::Domain::IPV6,
330 };
331
332 let sock = socket2::Socket::new(domain, socket2::Type::DGRAM, None)?;
334 if num_socks_to_bind > 1 {
335 sock.set_reuse_address(true)?; sock.set_reuse_port(true)?; }
338 sock.set_nonblocking(true)?;
339 sock.bind(&bound_addr.unwrap_or(binding_to).into())?;
340
341 let socket_bound_addr = sock
343 .local_addr()?
344 .as_socket()
345 .ok_or(std::io::Error::other("invalid socket type"))?;
346
347 match bound_addr {
348 None => bound_addr = Some(socket_bound_addr),
349 Some(addr) if addr != socket_bound_addr => {
350 return Err(std::io::Error::other(format!(
351 "inconsistent binding address {addr} != {socket_bound_addr} on socket id {sock_id}"
352 )));
353 }
354 _ => {}
355 }
356
357 let sock = Arc::new(UdpSocket::from_std(sock.into())?);
358 debug!(
359 socket_id = sock_id,
360 addr = %socket_bound_addr,
361 "bound UDP socket"
362 );
363
364 Ok((sock_id, sock))
365 })
366 .filter_map(|result| match result {
367 Ok(bound) => Some(bound),
368 Err(error) => {
369 error!(
370 %binding_to,
371 %error,
372 "failed to bind udp socket"
373 );
374 None
375 }
376 })
377 .for_each(|(sock_id, sock)| {
378 if sock_id < num_tx_socks {
379 socket_handles.push(tokio::task::spawn(ConnectedUdpStream::setup_tx_queue(
380 sock_id,
381 sock.clone(),
382 egress_rx.clone(),
383 counterparty.clone(),
384 )));
385 }
386 if sock_id < num_rx_socks {
387 socket_handles.push(tokio::task::spawn(ConnectedUdpStream::setup_rx_queue(
388 sock_id,
389 sock.clone(),
390 ingress_tx.clone(),
391 counterparty.clone(),
392 self.foreign_data_mode,
393 self.buffer_size,
394 )));
395 }
396 });
397 }
398
399 Ok(ConnectedUdpStream {
400 ingress_rx: Box::new(tokio_util::io::StreamReader::new(ingress_rx.into_stream())),
401 egress_tx: Some(Box::new(
402 egress_tx
403 .into_sink()
404 .sink_map_err(|e| std::io::Error::other(e.to_string())),
405 )),
406 socket_handles,
407 counterparty,
408 bound_to: bound_addr.ok_or(ErrorKind::AddrNotAvailable)?,
409 state: State::Writing,
410 })
411 }
412}
413
/// Handing a received datagram to the ingress queue for longer than this (150 ms) is logged as a warning.
const QUEUE_DISPATCH_THRESHOLD: Duration = Duration::from_micros(150_000);
415
416impl ConnectedUdpStream {
417 fn setup_rx_queue(
419 socket_id: usize,
420 sock_rx: Arc<UdpSocket>,
421 ingress_tx: flume::Sender<std::io::Result<tokio_util::bytes::Bytes>>,
422 counterparty: Arc<OnceLock<SocketAddrStr>>,
423 foreign_data_mode: ForeignDataMode,
424 buf_size: usize,
425 ) -> futures::future::BoxFuture<'static, ()> {
426 let counterparty_rx = counterparty.clone();
427 async move {
428 let mut buffer = vec![0u8; buf_size];
429 let mut done = false;
430 loop {
431 let out_res = match sock_rx.recv_from(&mut buffer).await {
433 Ok((read, read_addr)) if read > 0 => {
434 trace!(
435 socket_id,
436 udp_bound_addr = ?sock_rx.local_addr(),
437 bytes = read,
438 from = %read_addr,
439 "received data from"
440 );
441
442 let addr = counterparty_rx.get_or_init(|| read_addr.into());
443
444 #[cfg(all(feature = "prometheus", not(test)))]
445 METRIC_UDP_INGRESS_LEN.observe(&[addr.as_str()], read as f64);
446
447 if read_addr.eq(addr) || foreign_data_mode == ForeignDataMode::Accept {
449 let out_buffer = tokio_util::bytes::Bytes::copy_from_slice(&buffer[..read]);
450 Some(Ok(out_buffer))
451 } else {
452 match foreign_data_mode {
453 ForeignDataMode::Discard => {
454 warn!(
456 socket_id,
457 udp_bound_addr = ?sock_rx.local_addr(),
458 ?read_addr,
459 expected_addr = ?addr,
460 "discarded data, which didn't come from the expected address"
461 );
462 None
463 }
464 ForeignDataMode::Error => {
465 done = true;
467 Some(Err(std::io::Error::new(
468 ErrorKind::ConnectionRefused,
469 "data from foreign client not allowed",
470 )))
471 }
472 _ => unreachable!(),
474 }
475 }
476 }
477 Ok(_) => {
478 trace!(
480 socket_id,
481 udp_bound_addr = ?sock_rx.local_addr(),
482 "read EOF on socket"
483 );
484 done = true;
485 None
486 }
487 Err(error) => {
488 debug!(
490 socket_id,
491 udp_bound_addr = ?sock_rx.local_addr(),
492 %error,
493 "forwarded error from socket"
494 );
495 done = true;
496 Some(Err(error))
497 }
498 };
499
500 if let Some(out_res) = out_res {
503 let start = std::time::Instant::now();
504 if let Err(error) = ingress_tx.send_async(out_res).await {
505 error!(
506 socket_id,
507 udp_bound_addr = ?sock_rx.local_addr(),
508 %error,
509 "failed to dispatch received data"
510 );
511 done = true;
512 }
513 let elapsed = start.elapsed();
514 if elapsed > QUEUE_DISPATCH_THRESHOLD {
515 warn!(
516 ?elapsed,
517 "udp queue dispatch took too long, consider increasing the queue size"
518 );
519 }
520 }
521
522 if done {
523 trace!(
524 socket_id,
525 udp_bound_addr = ?sock_rx.local_addr(),
526 "rx queue done"
527 );
528 break;
529 }
530 }
531 }
532 .boxed()
533 }
534
535 fn setup_tx_queue(
537 socket_id: usize,
538 sock_tx: Arc<UdpSocket>,
539 egress_rx: flume::Receiver<Box<[u8]>>,
540 counterparty: Arc<OnceLock<SocketAddrStr>>,
541 ) -> futures::future::BoxFuture<'static, ()> {
542 let counterparty_tx = counterparty.clone();
543 async move {
544 loop {
545 match egress_rx.recv_async().await {
546 Ok(data) => {
547 if let Some(target) = counterparty_tx.get() {
548 if let Err(error) = sock_tx.send_to(&data, target.as_ref()).await {
549 error!(
550 ?socket_id,
551 udp_bound_addr = ?sock_tx.local_addr(),
552 ?target,
553 %error,
554 "failed to send data"
555 );
556 }
557 trace!(socket_id, bytes = data.len(), ?target, "sent bytes to");
558
559 #[cfg(all(feature = "prometheus", not(test)))]
560 METRIC_UDP_EGRESS_LEN.observe(&[target.as_str()], data.len() as f64);
561 } else {
562 error!(
563 ?socket_id,
564 udp_bound_addr = ?sock_tx.local_addr(),
565 "cannot send data, counterparty not set"
566 );
567 break;
568 }
569 }
570 Err(error) => {
571 error!(
572 ?socket_id,
573 udp_bound_addr = ?sock_tx.local_addr(),
574 %error,
575 "cannot receive more data from egress channel"
576 );
577 break;
578 }
579 }
580 trace!(
581 ?socket_id,
582 udp_bound_addr = tracing::field::debug(sock_tx.local_addr()),
583 "tx queue done"
584 );
585 }
586 }
587 .boxed()
588 }
589
590 pub fn bound_address(&self) -> &std::net::SocketAddr {
592 &self.bound_to
593 }
594
595 pub fn builder() -> UdpStreamBuilder {
597 UdpStreamBuilder::default()
598 }
599}
600
601impl tokio::io::AsyncRead for ConnectedUdpStream {
602 #[instrument(name = "ConnectedUdpStream::poll_read", level = "trace", skip(self, cx), fields(counterparty = ?self.counterparty.get(), rem = buf.remaining()) , ret)]
603 fn poll_read(
604 self: Pin<&mut Self>,
605 cx: &mut Context<'_>,
606 buf: &mut tokio::io::ReadBuf<'_>,
607 ) -> Poll<std::io::Result<()>> {
608 ready!(self.project().ingress_rx.poll_read(cx, buf))?;
609 trace!(bytes = buf.filled().len(), "read bytes");
610 Poll::Ready(Ok(()))
611 }
612}
613
614impl tokio::io::AsyncWrite for ConnectedUdpStream {
615 #[instrument(name = "ConnectedUdpStream::poll_write", level = "trace", skip(self, cx), fields(counterparty = ?self.counterparty.get(), len = buf.len()) , ret)]
616 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
617 let this = self.project();
618 if let Some(sender) = this.egress_tx.get_mut() {
619 loop {
620 match *this.state {
621 State::Writing => {
622 ready!(sender.poll_ready_unpin(cx))?;
623
624 let len = buf.iter().len();
625 sender.start_send_unpin(Box::from(buf))?;
626 *this.state = State::Flushing(len);
627 }
628 State::Flushing(len) => {
629 let res = ready!(sender.poll_flush_unpin(cx)).map(|_| len);
631 *this.state = State::Writing;
632
633 return Poll::Ready(res);
634 }
635 }
636 }
637 } else {
638 Poll::Ready(Err(std::io::Error::new(
639 ErrorKind::NotConnected,
640 "udp stream is closed",
641 )))
642 }
643 }
644
645 #[instrument(name = "ConnectedUdpStream::poll_flush", level = "trace", skip(self, cx), fields(counterparty = ?self.counterparty.get()) , ret)]
646 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
647 let this = self.project();
648 if let Some(sender) = this.egress_tx.as_pin_mut() {
649 sender.poll_flush(cx).map_err(std::io::Error::other)
650 } else {
651 Poll::Ready(Err(std::io::Error::new(
652 ErrorKind::NotConnected,
653 "udp stream is closed",
654 )))
655 }
656 }
657
658 #[instrument(name = "ConnectedUdpStream::poll_shutdown", level = "trace", skip(self, cx), fields(counterparty = ?self.counterparty.get()) , ret)]
659 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
660 let this = self.project();
661 if let Some(sender) = this.egress_tx.as_pin_mut() {
662 let ret = ready!(sender.poll_close(cx));
663
664 this.socket_handles.iter().for_each(|handle| {
665 handle.abort();
666 });
667
668 Poll::Ready(ret)
669 } else {
670 Poll::Ready(Err(std::io::Error::new(
671 ErrorKind::NotConnected,
672 "udp stream is closed",
673 )))
674 }
675 }
676}
677
678#[pin_project::pinned_drop]
679impl PinnedDrop for ConnectedUdpStream {
680 fn drop(self: Pin<&mut Self>) {
681 debug!(binding = ?self.bound_to,"dropping ConnectedUdpStream");
682 self.project().socket_handles.iter().for_each(|handle| {
683 handle.abort();
684 })
685 }
686}
687
#[cfg(test)]
mod tests {
    use anyhow::Context;
    use futures::{future::Either, pin_mut};
    use parameterized::parameterized;
    use tokio::{
        io::{AsyncReadExt, AsyncWriteExt},
        net::UdpSocket,
    };

    use super::*;

    // Round-trips random DATA_SIZE-byte datagrams through a plain Tokio UDP echo server.
    // Runs with the default receiver parallelism, then with `Some(2)` and `Some(0)`;
    // `0` converts to the default parallelism.
    #[parameterized(parallelism = {None, Some(2), Some(0)})]
    #[parameterized_macro(tokio::test)]
    async fn basic_udp_stream_tests(parallelism: Option<usize>) -> anyhow::Result<()> {
        const DATA_SIZE: usize = 200;

        let listener = UdpSocket::bind("127.0.0.1:0").await.context("bind listener")?;
        let listen_addr = listener.local_addr()?;

        // Echo server: every non-empty datagram must be exactly DATA_SIZE bytes and is sent back to its sender.
        tokio::task::spawn(async move {
            loop {
                let mut buf = [0u8; DATA_SIZE];
                let (read, addr) = listener.recv_from(&mut buf).await.expect("recv must not fail");
                if read > 0 {
                    assert_eq!(DATA_SIZE, read, "read size must be exactly {DATA_SIZE}");
                    listener.send_to(&buf, addr).await.expect("send must not fail");
                }
            }
        });

        let mut builder = ConnectedUdpStream::builder()
            .with_buffer_size(1024)
            .with_queue_size(512)
            .with_counterparty(listen_addr);

        if let Some(parallelism) = parallelism {
            builder = builder.with_receiver_parallelism(parallelism);
        }

        let mut stream = builder.build(("127.0.0.1", 0)).context("connection")?;

        // Each write is sent as one datagram; the echoed datagram must come back byte-identical.
        for _ in 1..1000 {
            let mut w_buf = [0u8; DATA_SIZE];
            hopr_crypto_random::random_fill(&mut w_buf);
            let written = stream.write(&w_buf).await?;
            assert_eq!(written, DATA_SIZE);

            let mut r_buf = [0u8; DATA_SIZE];
            let read = stream.read_exact(&mut r_buf).await?;
            assert_eq!(read, DATA_SIZE);

            assert_eq!(w_buf, r_buf);
        }

        stream.shutdown().await?;

        Ok(())
    }

    // A stream without a preset counterparty must adopt the first sender. The two datagrams
    // (BUF_SIZE and 500 bytes) must read back, in BUF_SIZE / 4 chunks, as one contiguous
    // byte sequence within 1 second.
    #[tokio::test]
    async fn udp_stream_should_process_sequential_writes() -> anyhow::Result<()> {
        const BUF_SIZE: usize = 1024;
        const EXPECTED_DATA_LEN: usize = BUF_SIZE + 500;

        let mut listener = ConnectedUdpStream::builder()
            .with_buffer_size(BUF_SIZE)
            .with_queue_size(512)
            .build(("127.0.0.1", 0))
            .context("bind listener")?;

        let bound_addr = *listener.bound_address();

        // Reader: accumulate bytes until all expected data arrived or a zero-byte read (EOF) occurs.
        let jh = tokio::task::spawn(async move {
            let mut buf = [0u8; BUF_SIZE / 4];
            let mut vec = Vec::<u8>::new();
            loop {
                let sz = listener.read(&mut buf).await.unwrap();
                if sz > 0 {
                    vec.extend_from_slice(&buf[..sz]);
                    if vec.len() >= EXPECTED_DATA_LEN {
                        return vec;
                    }
                } else {
                    return vec;
                }
            }
        });

        let msg = [1u8; EXPECTED_DATA_LEN];
        let sender = UdpSocket::bind(("127.0.0.1", 0)).await.context("bind")?;

        sender.send_to(&msg[..BUF_SIZE], bound_addr).await?;
        sender.send_to(&msg[BUF_SIZE..], bound_addr).await?;

        let timeout = tokio::time::sleep(std::time::Duration::from_millis(1000));
        pin_mut!(timeout);
        pin_mut!(jh);

        match futures::future::select(jh, timeout).await {
            Either::Left((Ok(v), _)) => {
                assert_eq!(v.len(), EXPECTED_DATA_LEN);
                assert_eq!(v.as_slice(), &msg);
                Ok(())
            }
            _ => Err(anyhow::anyhow!("timeout or invalid data")),
        }
    }
}