1use std::{
2 fmt::Debug,
3 io::ErrorKind,
4 num::NonZeroUsize,
5 pin::Pin,
6 sync::{Arc, OnceLock},
7 task::{Context, Poll},
8};
9
10use futures::{FutureExt, Sink, SinkExt, pin_mut, ready};
11use tokio::net::UdpSocket;
12use tracing::{debug, error, trace, warn};
13
14use crate::utils::SocketAddrStr;
15
16type BoxIoSink<T> = Box<dyn Sink<T, Error = std::io::Error> + Send + Unpin>;
17
/// Histogram bucket boundaries (datagram length in bytes) shared by the ingress and
/// egress packet-length metrics.
#[cfg(all(feature = "prometheus", not(test)))]
const UDP_PACKET_LEN_BUCKETS: [f64; 9] = [20.0, 40.0, 80.0, 160.0, 320.0, 640.0, 1280.0, 2560.0, 5120.0];

#[cfg(all(feature = "prometheus", not(test)))]
lazy_static::lazy_static! {
    /// Lengths of datagrams received by the rx tasks, labelled by counterparty address.
    static ref METRIC_UDP_INGRESS_LEN: hopr_metrics::MultiHistogram =
        hopr_metrics::MultiHistogram::new(
            "hopr_udp_ingress_packet_len",
            "UDP packet lengths on ingress per counterparty",
            UDP_PACKET_LEN_BUCKETS.to_vec(),
            &["counterparty"]
        ).unwrap();

    /// Lengths of datagrams sent by the tx tasks, labelled by counterparty address.
    static ref METRIC_UDP_EGRESS_LEN: hopr_metrics::MultiHistogram =
        hopr_metrics::MultiHistogram::new(
            "hopr_udp_egress_packet_len",
            "UDP packet lengths on egress per counterparty",
            UDP_PACKET_LEN_BUCKETS.to_vec(),
            &["counterparty"]
        ).unwrap();
}
35
36pub struct ConnectedUdpStream {
57 socket_handles: Vec<tokio::task::JoinHandle<()>>,
58 ingress_rx: Box<dyn tokio::io::AsyncRead + Send + Unpin>,
59 egress_tx: Option<BoxIoSink<Box<[u8]>>>,
60 counterparty: Arc<OnceLock<SocketAddrStr>>,
61 bound_to: std::net::SocketAddr,
62}
63
/// Policy for datagrams whose source address differs from the stream's counterparty.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub enum ForeignDataMode {
    /// Drop the foreign datagram (a warning is logged) and keep receiving.
    Discard,
    /// Pass the foreign datagram on to the reader as if it came from the counterparty.
    Accept,
    /// Report an I/O error to the reader and stop the receiving task that saw the datagram.
    #[default]
    Error,
}
77
/// Requested number of parallel socket tasks for one direction (receive or send) of a UDP stream.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum UdpStreamParallelism {
    /// Derive the number of tasks from the parallelism available on the host.
    Auto,
    /// Use the given number of tasks; the effective value is capped by the available parallelism.
    Specific(NonZeroUsize),
}
101
102impl Default for UdpStreamParallelism {
103 fn default() -> Self {
104 Self::Specific(NonZeroUsize::MIN)
105 }
106}
107
108impl UdpStreamParallelism {
109 fn avail_parallelism() -> usize {
110 std::thread::available_parallelism()
113 .map(|n| {
114 if cfg!(target_os = "linux") {
115 n
116 } else {
117 NonZeroUsize::MIN
118 }
119 })
120 .unwrap_or_else(|e| {
121 warn!(error = %e, "failed to determine available parallelism, defaulting to 1.");
122 NonZeroUsize::MIN
123 })
124 .into()
125 }
126
127 pub fn split_evenly_with(self, other: UdpStreamParallelism) -> (usize, usize) {
130 let cpu_half = (Self::avail_parallelism() / 2).max(1);
131
132 match (self, other) {
133 (UdpStreamParallelism::Auto, UdpStreamParallelism::Auto) => (cpu_half, cpu_half),
134 (UdpStreamParallelism::Specific(a), UdpStreamParallelism::Auto) => {
135 let a = cpu_half.min(a.into());
136 (a, cpu_half * 2 - a)
137 }
138 (UdpStreamParallelism::Auto, UdpStreamParallelism::Specific(b)) => {
139 let b = cpu_half.min(b.into());
140 (cpu_half * 2 - b, b)
141 }
142 (UdpStreamParallelism::Specific(a), UdpStreamParallelism::Specific(b)) => {
143 (cpu_half.min(a.into()), cpu_half.min(b.into()))
144 }
145 }
146 }
147
148 pub fn into_num_tasks(self) -> usize {
152 let avail_parallelism = Self::avail_parallelism();
153 match self {
154 UdpStreamParallelism::Auto => avail_parallelism,
155 UdpStreamParallelism::Specific(n) => usize::from(n).min(avail_parallelism),
156 }
157 }
158}
159
160impl From<usize> for UdpStreamParallelism {
161 fn from(value: usize) -> Self {
162 NonZeroUsize::new(value).map(Self::Specific).unwrap_or_default()
163 }
164}
165
166impl From<Option<usize>> for UdpStreamParallelism {
167 fn from(value: Option<usize>) -> Self {
168 value.map(UdpStreamParallelism::from).unwrap_or_default()
169 }
170}
171
172#[derive(Debug, Clone)]
176pub struct UdpStreamBuilder {
177 foreign_data_mode: ForeignDataMode,
178 buffer_size: usize,
179 queue_size: Option<usize>,
180 receiver_parallelism: UdpStreamParallelism,
181 sender_parallelism: UdpStreamParallelism,
182 counterparty: Option<std::net::SocketAddr>,
183}
184
185impl Default for UdpStreamBuilder {
186 fn default() -> Self {
187 Self {
188 buffer_size: 2048,
189 foreign_data_mode: Default::default(),
190 queue_size: None,
191 receiver_parallelism: Default::default(),
192 sender_parallelism: Default::default(),
193 counterparty: None,
194 }
195 }
196}
197
198impl UdpStreamBuilder {
199 pub fn with_foreign_data_mode(mut self, mode: ForeignDataMode) -> Self {
204 self.foreign_data_mode = mode;
205 self
206 }
207
208 pub fn with_buffer_size(mut self, buffer_size: usize) -> Self {
215 self.buffer_size = buffer_size;
216 self
217 }
218
219 pub fn with_queue_size(mut self, queue_size: usize) -> Self {
230 self.queue_size = Some(queue_size);
231 self
232 }
233
234 pub fn with_receiver_parallelism<T: Into<UdpStreamParallelism>>(mut self, parallelism: T) -> Self {
240 self.receiver_parallelism = parallelism.into();
241 self
242 }
243
244 pub fn with_sender_parallelism<T: Into<UdpStreamParallelism>>(mut self, parallelism: T) -> Self {
250 self.sender_parallelism = parallelism.into();
251 self
252 }
253
254 pub fn with_counterparty(mut self, counterparty: std::net::SocketAddr) -> Self {
265 self.counterparty = Some(counterparty);
266 self
267 }
268
269 pub fn build<A: std::net::ToSocketAddrs>(self, bind_addr: A) -> std::io::Result<ConnectedUdpStream> {
287 let (num_rx_socks, num_tx_socks) = self.receiver_parallelism.split_evenly_with(self.sender_parallelism);
288
289 let counterparty = Arc::new(
290 self.counterparty
291 .map(|s| OnceLock::from(SocketAddrStr::from(s)))
292 .unwrap_or_default(),
293 );
294 let ((ingress_tx, ingress_rx), (egress_tx, egress_rx)) = if let Some(q) = self.queue_size {
295 (flume::bounded(q), flume::bounded(q))
296 } else {
297 (flume::unbounded(), flume::unbounded())
298 };
299
300 let num_socks_to_bind = num_rx_socks.max(num_tx_socks);
301 let mut socket_handles = Vec::with_capacity(num_socks_to_bind);
302 let mut bound_addr: Option<std::net::SocketAddr> = None;
303
304 for binding_to in bind_addr.to_socket_addrs()? {
306 debug!(
307 %binding_to,
308 num_socks_to_bind, num_rx_socks, num_tx_socks, "binding UDP stream"
309 );
310
311 (0..num_socks_to_bind)
315 .map(|sock_id| {
316 let domain = match &binding_to {
317 std::net::SocketAddr::V4(_) => socket2::Domain::IPV4,
318 std::net::SocketAddr::V6(_) => socket2::Domain::IPV6,
319 };
320
321 let sock = socket2::Socket::new(domain, socket2::Type::DGRAM, None)?;
323 if num_socks_to_bind > 1 {
324 sock.set_reuse_address(true)?; sock.set_reuse_port(true)?; }
327 sock.set_nonblocking(true)?;
328 sock.bind(&bound_addr.unwrap_or(binding_to).into())?;
329
330 let socket_bound_addr = sock
332 .local_addr()?
333 .as_socket()
334 .ok_or(std::io::Error::other("invalid socket type"))?;
335
336 match bound_addr {
337 None => bound_addr = Some(socket_bound_addr),
338 Some(addr) if addr != socket_bound_addr => {
339 return Err(std::io::Error::other(format!(
340 "inconsistent binding address {addr} != {socket_bound_addr} on socket id {sock_id}"
341 )));
342 }
343 _ => {}
344 }
345
346 let sock = Arc::new(UdpSocket::from_std(sock.into())?);
347 debug!(
348 socket_id = sock_id,
349 addr = %socket_bound_addr,
350 "bound UDP socket"
351 );
352
353 Ok((sock_id, sock))
354 })
355 .filter_map(|result| match result {
356 Ok(bound) => Some(bound),
357 Err(e) => {
358 error!(
359 %binding_to,
360 "failed to bind udp socket: {e}"
361 );
362 None
363 }
364 })
365 .for_each(|(sock_id, sock)| {
366 if sock_id < num_tx_socks {
367 socket_handles.push(tokio::task::spawn(ConnectedUdpStream::setup_tx_queue(
368 sock_id,
369 sock.clone(),
370 egress_rx.clone(),
371 counterparty.clone(),
372 )));
373 }
374 if sock_id < num_rx_socks {
375 socket_handles.push(tokio::task::spawn(ConnectedUdpStream::setup_rx_queue(
376 sock_id,
377 sock.clone(),
378 ingress_tx.clone(),
379 counterparty.clone(),
380 self.foreign_data_mode,
381 self.buffer_size,
382 )));
383 }
384 });
385 }
386
387 Ok(ConnectedUdpStream {
388 ingress_rx: Box::new(tokio_util::io::StreamReader::new(ingress_rx.into_stream())),
389 egress_tx: Some(Box::new(
390 egress_tx
391 .into_sink()
392 .sink_map_err(|e| std::io::Error::other(e.to_string())),
393 )),
394 socket_handles,
395 counterparty,
396 bound_to: bound_addr.ok_or(ErrorKind::AddrNotAvailable)?,
397 })
398 }
399}
400
401impl ConnectedUdpStream {
402 fn setup_rx_queue(
404 socket_id: usize,
405 sock_rx: Arc<UdpSocket>,
406 ingress_tx: flume::Sender<std::io::Result<tokio_util::bytes::Bytes>>,
407 counterparty: Arc<OnceLock<SocketAddrStr>>,
408 foreign_data_mode: ForeignDataMode,
409 buf_size: usize,
410 ) -> futures::future::BoxFuture<'static, ()> {
411 let counterparty_rx = counterparty.clone();
412 async move {
413 let mut buffer = vec![0u8; buf_size];
414 let mut done = false;
415 loop {
416 let out_res = match sock_rx.recv_from(&mut buffer).await {
418 Ok((read, read_addr)) if read > 0 => {
419 trace!(
420 socket_id,
421 udp_bound_addr = ?sock_rx.local_addr(),
422 bytes = read,
423 from = %read_addr,
424 "received data from"
425 );
426
427 let addr = counterparty_rx.get_or_init(|| read_addr.into());
428
429 #[cfg(all(feature = "prometheus", not(test)))]
430 METRIC_UDP_INGRESS_LEN.observe(&[addr.as_str()], read as f64);
431
432 if read_addr.eq(addr) || foreign_data_mode == ForeignDataMode::Accept {
434 let out_buffer = tokio_util::bytes::Bytes::copy_from_slice(&buffer[..read]);
435 Some(Ok(out_buffer))
436 } else {
437 match foreign_data_mode {
438 ForeignDataMode::Discard => {
439 warn!(
441 socket_id,
442 udp_bound_addr = ?sock_rx.local_addr(),
443 ?read_addr,
444 expected_addr = ?addr,
445 "discarded data, which didn't come from the expected address"
446 );
447 None
448 }
449 ForeignDataMode::Error => {
450 done = true;
452 Some(Err(std::io::Error::new(
453 ErrorKind::ConnectionRefused,
454 "data from foreign client not allowed",
455 )))
456 }
457 _ => unreachable!(),
459 }
460 }
461 }
462 Ok(_) => {
463 trace!(
465 socket_id,
466 udp_bound_addr = ?sock_rx.local_addr(),
467 "read EOF on socket"
468 );
469 done = true;
470 None
471 }
472 Err(e) => {
473 debug!(
475 socket_id,
476 udp_bound_addr = ?sock_rx.local_addr(),
477 error = %e,
478 "forwarded error from socket"
479 );
480 done = true;
481 Some(Err(e))
482 }
483 };
484
485 if let Some(out_res) = out_res {
488 if let Err(err) = ingress_tx.send_async(out_res).await {
489 error!(
490 socket_id,
491 udp_bound_addr = ?sock_rx.local_addr(),
492 error = %err,
493 "failed to dispatch received data"
494 );
495 done = true;
496 }
497 }
498
499 if done {
500 trace!(
501 socket_id,
502 udp_bound_addr = ?sock_rx.local_addr(),
503 "rx queue done"
504 );
505 break;
506 }
507 }
508 }
509 .boxed()
510 }
511
512 fn setup_tx_queue(
514 socket_id: usize,
515 sock_tx: Arc<UdpSocket>,
516 egress_rx: flume::Receiver<Box<[u8]>>,
517 counterparty: Arc<OnceLock<SocketAddrStr>>,
518 ) -> futures::future::BoxFuture<'static, ()> {
519 let counterparty_tx = counterparty.clone();
520 async move {
521 loop {
522 match egress_rx.recv_async().await {
523 Ok(data) => {
524 if let Some(target) = counterparty_tx.get() {
525 if let Err(e) = sock_tx.send_to(&data, target.as_ref()).await {
526 error!(
527 ?socket_id,
528 udp_bound_addr = ?sock_tx.local_addr(),
529 ?target,
530 error = %e,
531 "failed to send data"
532 );
533 }
534 trace!(socket_id, bytes = data.len(), ?target, "sent bytes to");
535
536 #[cfg(all(feature = "prometheus", not(test)))]
537 METRIC_UDP_EGRESS_LEN.observe(&[target.as_str()], data.len() as f64);
538 } else {
539 error!(
540 ?socket_id,
541 udp_bound_addr = ?sock_tx.local_addr(),
542 "cannot send data, counterparty not set"
543 );
544 break;
545 }
546 }
547 Err(e) => {
548 error!(
549 ?socket_id,
550 udp_bound_addr = ?sock_tx.local_addr(),
551 error = %e,
552 "cannot receive more data from egress channel"
553 );
554 break;
555 }
556 }
557 trace!(
558 ?socket_id,
559 udp_bound_addr = tracing::field::debug(sock_tx.local_addr()),
560 "tx queue done"
561 );
562 }
563 }
564 .boxed()
565 }
566
567 pub fn bound_address(&self) -> &std::net::SocketAddr {
569 &self.bound_to
570 }
571
572 pub fn builder() -> UdpStreamBuilder {
574 UdpStreamBuilder::default()
575 }
576}
577
578impl Drop for ConnectedUdpStream {
579 fn drop(&mut self) {
580 self.socket_handles.iter().for_each(|handle| {
581 handle.abort();
582 })
583 }
584}
585
586impl tokio::io::AsyncRead for ConnectedUdpStream {
587 fn poll_read(
588 mut self: Pin<&mut Self>,
589 cx: &mut Context<'_>,
590 buf: &mut tokio::io::ReadBuf<'_>,
591 ) -> Poll<std::io::Result<()>> {
592 trace!(
593 remaining_bytes = buf.remaining(),
594 counterparty = ?self.counterparty.get(),
595 "polling read of from udp stream",
596 );
597 match Pin::new(&mut self.ingress_rx).poll_read(cx, buf) {
598 Poll::Ready(Ok(())) => {
599 let read = buf.filled().len();
600 trace!(bytes = read, "read bytes");
601 Poll::Ready(Ok(()))
602 }
603 Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
604 Poll::Pending => Poll::Pending,
605 }
606 }
607}
608
609impl tokio::io::AsyncWrite for ConnectedUdpStream {
610 fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
611 trace!(
612 bytes = buf.len(),
613 counterparty = ?self.counterparty.get(),
614 "polling write to udp stream",
615 );
616 if let Some(sender) = &mut self.egress_tx {
617 if let Err(e) = ready!(sender.poll_ready_unpin(cx)) {
618 return Poll::Ready(Err(e));
619 }
620
621 let len = buf.iter().len();
622 if let Err(e) = sender.start_send_unpin(Box::from(buf)) {
623 return Poll::Ready(Err(e));
624 }
625
626 pin_mut!(sender);
628 sender.poll_flush(cx).map_ok(|_| len)
629 } else {
630 Poll::Ready(Err(std::io::Error::new(
631 ErrorKind::NotConnected,
632 "udp stream is closed",
633 )))
634 }
635 }
636
637 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
638 trace!(
639 counterparty = ?self.counterparty.get(),
640 "polling flush to udp stream"
641 );
642 if let Some(sender) = &mut self.egress_tx {
643 pin_mut!(sender);
644 sender
645 .poll_flush(cx)
646 .map_err(|err| std::io::Error::other(err.to_string()))
647 } else {
648 Poll::Ready(Err(std::io::Error::new(
649 ErrorKind::NotConnected,
650 "udp stream is closed",
651 )))
652 }
653 }
654
655 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
656 trace!(
657 counterparty = ?self.counterparty.get(),
658 "polling close on udp stream"
659 );
660 let mut taken_sender = self.egress_tx.take();
662 if let Some(sender) = &mut taken_sender {
663 pin_mut!(sender);
664 sender
665 .poll_close(cx)
666 .map_err(|err| std::io::Error::other(err.to_string()))
667 } else {
668 Poll::Ready(Err(std::io::Error::new(
669 ErrorKind::NotConnected,
670 "udp stream is closed",
671 )))
672 }
673 }
674}
675
#[cfg(test)]
mod tests {
    use anyhow::Context;
    use futures::future::Either;
    use parameterized::parameterized;
    use tokio::{
        io::{AsyncReadExt, AsyncWriteExt},
        net::UdpSocket,
    };

    use super::*;

    // Sends random datagrams to a plain UDP echo socket through the stream and expects every
    // one of them to come back unchanged, for several receiver parallelism settings.
    #[parameterized(parallelism = {None, Some(2), Some(0)})]
    #[parameterized_macro(tokio::test)]
    async fn basic_udp_stream_tests(parallelism: Option<usize>) -> anyhow::Result<()> {
        const DATA_SIZE: usize = 200;

        let echo_socket = UdpSocket::bind("127.0.0.1:0").await.context("bind listener")?;
        let echo_addr = echo_socket.local_addr()?;

        // Echo server: returns every non-empty datagram to its sender.
        tokio::task::spawn(async move {
            loop {
                let mut datagram = [0u8; DATA_SIZE];
                let (received, from) = echo_socket
                    .recv_from(&mut datagram)
                    .await
                    .expect("recv must not fail");
                if received > 0 {
                    assert_eq!(DATA_SIZE, received, "read size must be exactly {DATA_SIZE}");
                    echo_socket.send_to(&datagram, from).await.expect("send must not fail");
                }
            }
        });

        let base_builder = ConnectedUdpStream::builder()
            .with_buffer_size(1024)
            .with_queue_size(512)
            .with_counterparty(echo_addr);

        let builder = match parallelism {
            Some(requested) => base_builder.with_receiver_parallelism(requested),
            None => base_builder,
        };

        let mut stream = builder.build(("127.0.0.1", 0)).context("connection")?;

        for _ in 0..999 {
            let mut outgoing = [0u8; DATA_SIZE];
            hopr_crypto_random::random_fill(&mut outgoing);
            let written = stream.write(&outgoing).await?;
            assert_eq!(written, DATA_SIZE);

            let mut incoming = [0u8; DATA_SIZE];
            let read = stream.read_exact(&mut incoming).await?;
            assert_eq!(read, DATA_SIZE);

            assert_eq!(outgoing, incoming);
        }

        stream.shutdown().await?;

        Ok(())
    }

    // Two datagrams sent back to back must be readable from the stream in order, even when
    // the reader uses a buffer smaller than a datagram.
    #[tokio::test]
    async fn udp_stream_should_process_sequential_writes() -> anyhow::Result<()> {
        const BUF_SIZE: usize = 1024;
        const EXPECTED_DATA_LEN: usize = BUF_SIZE + 500;

        let mut listener = ConnectedUdpStream::builder()
            .with_buffer_size(BUF_SIZE)
            .with_queue_size(512)
            .build(("127.0.0.1", 0))
            .context("bind listener")?;

        let bound_addr = *listener.bound_address();

        // Collects everything the stream delivers until the expected amount arrived.
        let jh = tokio::task::spawn(async move {
            let mut chunk = [0u8; BUF_SIZE / 4];
            let mut collected = Vec::<u8>::new();
            loop {
                let sz = listener.read(&mut chunk).await.unwrap();
                collected.extend_from_slice(&chunk[..sz]);
                if sz == 0 || collected.len() >= EXPECTED_DATA_LEN {
                    return collected;
                }
            }
        });

        let msg = [1u8; EXPECTED_DATA_LEN];
        let sender = UdpSocket::bind(("127.0.0.1", 0)).await.context("bind")?;

        let (first_part, second_part) = msg.split_at(BUF_SIZE);
        sender.send_to(first_part, bound_addr).await?;
        sender.send_to(second_part, bound_addr).await?;

        let timeout = tokio::time::sleep(std::time::Duration::from_millis(1000));
        pin_mut!(timeout);
        pin_mut!(jh);

        match futures::future::select(jh, timeout).await {
            Either::Left((Ok(v), _)) => {
                assert_eq!(v.len(), EXPECTED_DATA_LEN);
                assert_eq!(v.as_slice(), &msg);
                Ok(())
            }
            _ => Err(anyhow::anyhow!("timeout or invalid data")),
        }
    }
}