1use crate::utils::SocketAddrStr;
2use futures::{pin_mut, ready, FutureExt, Sink, SinkExt};
3use std::fmt::Debug;
4use std::io::ErrorKind;
5use std::num::NonZeroUsize;
6use std::pin::Pin;
7use std::sync::{Arc, OnceLock};
8use std::task::{Context, Poll};
9use tokio::net::UdpSocket;
10use tracing::{debug, error, trace, warn};
11
12type BoxIoSink<T> = Box<dyn Sink<T, Error = std::io::Error> + Send + Unpin>;
13
/// Histogram bucket boundaries (bytes) shared by the ingress and egress packet-length metrics.
#[cfg(all(feature = "prometheus", not(test)))]
const UDP_PACKET_LEN_BUCKETS: [f64; 9] = [20.0, 40.0, 80.0, 160.0, 320.0, 640.0, 1280.0, 2560.0, 5120.0];

#[cfg(all(feature = "prometheus", not(test)))]
lazy_static::lazy_static! {
    /// Sizes of received UDP datagrams, labelled by counterparty.
    static ref METRIC_UDP_INGRESS_LEN: hopr_metrics::MultiHistogram = hopr_metrics::MultiHistogram::new(
        "hopr_udp_ingress_packet_len",
        "UDP packet lengths on ingress per counterparty",
        UDP_PACKET_LEN_BUCKETS.to_vec(),
        &["counterparty"],
    )
    .unwrap();

    /// Sizes of sent UDP datagrams, labelled by counterparty.
    static ref METRIC_UDP_EGRESS_LEN: hopr_metrics::MultiHistogram = hopr_metrics::MultiHistogram::new(
        "hopr_udp_egress_packet_len",
        "UDP packet lengths on egress per counterparty",
        UDP_PACKET_LEN_BUCKETS.to_vec(),
        &["counterparty"],
    )
    .unwrap();
}
31
/// UDP-backed byte stream implementing `AsyncRead`/`AsyncWrite` for one counterparty address.
///
/// Background tasks (one set per bound socket) push received datagrams into the ingress
/// channel read through `ingress_rx`, and send whatever is written into `egress_tx`.
pub struct ConnectedUdpStream {
    /// Handles of the spawned receive/send socket tasks; all are aborted on drop.
    socket_handles: Vec<tokio::task::JoinHandle<()>>,
    /// Reader over the data forwarded by the receive tasks.
    ingress_rx: Box<dyn tokio::io::AsyncRead + Send + Unpin>,
    /// Sink feeding the send tasks; `None` after the stream has been shut down.
    egress_tx: Option<BoxIoSink<Box<[u8]>>>,
    /// Counterparty address: set by the builder, or by the source of the first received datagram.
    counterparty: Arc<OnceLock<SocketAddrStr>>,
    /// Local address shared by all bound sockets.
    bound_to: std::net::SocketAddr,
}
59
/// Policy applied to datagrams whose source address is not the counterparty.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum ForeignDataMode {
    /// Drop the datagram and log a warning.
    Discard,
    /// Forward the datagram to the reader just like counterparty data.
    Accept,
    /// Forward a `ConnectionRefused` error to the reader and stop the receive task that got it.
    #[default]
    Error,
}
73
/// Requested number of socket tasks for one direction (receive or send) of a UDP stream.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum UdpStreamParallelism {
    /// Derive the count from the parallelism available on this machine.
    Auto,
    /// Use the given count, capped by the available parallelism.
    Specific(NonZeroUsize),
}

impl Default for UdpStreamParallelism {
    /// A single task.
    fn default() -> Self {
        UdpStreamParallelism::Specific(NonZeroUsize::MIN)
    }
}
101
102impl UdpStreamParallelism {
103 fn avail_parallelism() -> usize {
104 std::thread::available_parallelism()
107 .map(|n| {
108 if cfg!(target_os = "linux") {
109 n
110 } else {
111 NonZeroUsize::MIN
112 }
113 })
114 .unwrap_or_else(|e| {
115 warn!(error = %e, "failed to determine available parallelism, defaulting to 1.");
116 NonZeroUsize::MIN
117 })
118 .into()
119 }
120
121 pub fn split_evenly_with(self, other: UdpStreamParallelism) -> (usize, usize) {
124 let cpu_half = (Self::avail_parallelism() / 2).max(1);
125
126 match (self, other) {
127 (UdpStreamParallelism::Auto, UdpStreamParallelism::Auto) => (cpu_half, cpu_half),
128 (UdpStreamParallelism::Specific(a), UdpStreamParallelism::Auto) => {
129 let a = cpu_half.min(a.into());
130 (a, cpu_half * 2 - a)
131 }
132 (UdpStreamParallelism::Auto, UdpStreamParallelism::Specific(b)) => {
133 let b = cpu_half.min(b.into());
134 (cpu_half * 2 - b, b)
135 }
136 (UdpStreamParallelism::Specific(a), UdpStreamParallelism::Specific(b)) => {
137 (cpu_half.min(a.into()), cpu_half.min(b.into()))
138 }
139 }
140 }
141
142 pub fn into_num_tasks(self) -> usize {
146 let avail_parallelism = Self::avail_parallelism();
147 match self {
148 UdpStreamParallelism::Auto => avail_parallelism,
149 UdpStreamParallelism::Specific(n) => usize::from(n).min(avail_parallelism),
150 }
151 }
152}
153
154impl From<usize> for UdpStreamParallelism {
155 fn from(value: usize) -> Self {
156 NonZeroUsize::new(value).map(Self::Specific).unwrap_or_default()
157 }
158}
159
160impl From<Option<usize>> for UdpStreamParallelism {
161 fn from(value: Option<usize>) -> Self {
162 value.map(UdpStreamParallelism::from).unwrap_or_default()
163 }
164}
165
/// Configures and creates a [`ConnectedUdpStream`].
#[derive(Debug, Clone)]
pub struct UdpStreamBuilder {
    /// Handling of datagrams whose source is not the counterparty.
    foreign_data_mode: ForeignDataMode,
    /// Size of the buffer each receive task reads datagrams into.
    buffer_size: usize,
    /// Capacity of both the ingress and egress queues; `None` means unbounded.
    queue_size: Option<usize>,
    /// Requested number of receive tasks.
    receiver_parallelism: UdpStreamParallelism,
    /// Requested number of send tasks.
    sender_parallelism: UdpStreamParallelism,
    /// Counterparty fixed at build time, if any.
    counterparty: Option<std::net::SocketAddr>,
}
178
179impl Default for UdpStreamBuilder {
180 fn default() -> Self {
181 Self {
182 buffer_size: 2048,
183 foreign_data_mode: Default::default(),
184 queue_size: None,
185 receiver_parallelism: Default::default(),
186 sender_parallelism: Default::default(),
187 counterparty: None,
188 }
189 }
190}
191
192impl UdpStreamBuilder {
193 pub fn with_foreign_data_mode(mut self, mode: ForeignDataMode) -> Self {
198 self.foreign_data_mode = mode;
199 self
200 }
201
202 pub fn with_buffer_size(mut self, buffer_size: usize) -> Self {
209 self.buffer_size = buffer_size;
210 self
211 }
212
213 pub fn with_queue_size(mut self, queue_size: usize) -> Self {
224 self.queue_size = Some(queue_size);
225 self
226 }
227
228 pub fn with_receiver_parallelism<T: Into<UdpStreamParallelism>>(mut self, parallelism: T) -> Self {
234 self.receiver_parallelism = parallelism.into();
235 self
236 }
237
238 pub fn with_sender_parallelism<T: Into<UdpStreamParallelism>>(mut self, parallelism: T) -> Self {
244 self.sender_parallelism = parallelism.into();
245 self
246 }
247
248 pub fn with_counterparty(mut self, counterparty: std::net::SocketAddr) -> Self {
259 self.counterparty = Some(counterparty);
260 self
261 }
262
263 pub fn build<A: std::net::ToSocketAddrs>(self, bind_addr: A) -> std::io::Result<ConnectedUdpStream> {
280 let (num_rx_socks, num_tx_socks) = self.receiver_parallelism.split_evenly_with(self.sender_parallelism);
281
282 let counterparty = Arc::new(
283 self.counterparty
284 .map(|s| OnceLock::from(SocketAddrStr::from(s)))
285 .unwrap_or_default(),
286 );
287 let ((ingress_tx, ingress_rx), (egress_tx, egress_rx)) = if let Some(q) = self.queue_size {
288 (flume::bounded(q), flume::bounded(q))
289 } else {
290 (flume::unbounded(), flume::unbounded())
291 };
292
293 let num_socks_to_bind = num_rx_socks.max(num_tx_socks);
294 let mut socket_handles = Vec::with_capacity(num_socks_to_bind);
295 let mut bound_addr: Option<std::net::SocketAddr> = None;
296
297 for binding_to in bind_addr.to_socket_addrs()? {
299 debug!(
300 %binding_to,
301 num_socks_to_bind, num_rx_socks, num_tx_socks, "binding UDP stream"
302 );
303
304 (0..num_socks_to_bind)
308 .map(|sock_id| {
309 let domain = match &binding_to {
310 std::net::SocketAddr::V4(_) => socket2::Domain::IPV4,
311 std::net::SocketAddr::V6(_) => socket2::Domain::IPV6,
312 };
313
314 let sock = socket2::Socket::new(domain, socket2::Type::DGRAM, None)?;
316 if num_socks_to_bind > 1 {
317 sock.set_reuse_address(true)?; sock.set_reuse_port(true)?; }
320 sock.set_nonblocking(true)?;
321 sock.bind(&bound_addr.unwrap_or(binding_to).into())?;
322
323 let socket_bound_addr = sock
325 .local_addr()?
326 .as_socket()
327 .ok_or(std::io::Error::other("invalid socket type"))?;
328
329 match bound_addr {
330 None => bound_addr = Some(socket_bound_addr),
331 Some(addr) if addr != socket_bound_addr => {
332 return Err(std::io::Error::other(format!(
333 "inconsistent binding address {addr} != {socket_bound_addr} on socket id {sock_id}"
334 )))
335 }
336 _ => {}
337 }
338
339 let sock = Arc::new(UdpSocket::from_std(sock.into())?);
340 debug!(
341 socket_id = sock_id,
342 addr = %socket_bound_addr,
343 "bound UDP socket"
344 );
345
346 Ok((sock_id, sock))
347 })
348 .filter_map(|result| match result {
349 Ok(bound) => Some(bound),
350 Err(e) => {
351 error!(
352 %binding_to,
353 "failed to bind udp socket: {e}"
354 );
355 None
356 }
357 })
358 .for_each(|(sock_id, sock)| {
359 if sock_id < num_tx_socks {
360 socket_handles.push(tokio::task::spawn(ConnectedUdpStream::setup_tx_queue(
361 sock_id,
362 sock.clone(),
363 egress_rx.clone(),
364 counterparty.clone(),
365 )));
366 }
367 if sock_id < num_rx_socks {
368 socket_handles.push(tokio::task::spawn(ConnectedUdpStream::setup_rx_queue(
369 sock_id,
370 sock.clone(),
371 ingress_tx.clone(),
372 counterparty.clone(),
373 self.foreign_data_mode,
374 self.buffer_size,
375 )));
376 }
377 });
378 }
379
380 Ok(ConnectedUdpStream {
381 ingress_rx: Box::new(tokio_util::io::StreamReader::new(ingress_rx.into_stream())),
382 egress_tx: Some(Box::new(
383 egress_tx
384 .into_sink()
385 .sink_map_err(|e| std::io::Error::other(e.to_string())),
386 )),
387 socket_handles,
388 counterparty,
389 bound_to: bound_addr.ok_or(ErrorKind::AddrNotAvailable)?,
390 })
391 }
392}
393
394impl ConnectedUdpStream {
395 fn setup_rx_queue(
397 socket_id: usize,
398 sock_rx: Arc<UdpSocket>,
399 ingress_tx: flume::Sender<std::io::Result<tokio_util::bytes::Bytes>>,
400 counterparty: Arc<OnceLock<SocketAddrStr>>,
401 foreign_data_mode: ForeignDataMode,
402 buf_size: usize,
403 ) -> futures::future::BoxFuture<'static, ()> {
404 let counterparty_rx = counterparty.clone();
405 async move {
406 let mut buffer = vec![0u8; buf_size];
407 let mut done = false;
408 loop {
409 let out_res = match sock_rx.recv_from(&mut buffer).await {
411 Ok((read, read_addr)) if read > 0 => {
412 trace!(
413 socket_id,
414 udp_bound_addr = ?sock_rx.local_addr(),
415 bytes = read,
416 from = %read_addr,
417 "received data from"
418 );
419
420 let addr = counterparty_rx.get_or_init(|| read_addr.into());
421
422 #[cfg(all(feature = "prometheus", not(test)))]
423 METRIC_UDP_INGRESS_LEN.observe(&[addr.as_str()], read as f64);
424
425 if read_addr.eq(addr) || foreign_data_mode == ForeignDataMode::Accept {
427 let out_buffer = tokio_util::bytes::Bytes::copy_from_slice(&buffer[..read]);
428 Some(Ok(out_buffer))
429 } else {
430 match foreign_data_mode {
431 ForeignDataMode::Discard => {
432 warn!(
434 socket_id,
435 udp_bound_addr = ?sock_rx.local_addr(),
436 ?read_addr,
437 expected_addr = ?addr,
438 "discarded data, which didn't come from the expected address"
439 );
440 None
441 }
442 ForeignDataMode::Error => {
443 done = true;
445 Some(Err(std::io::Error::new(
446 ErrorKind::ConnectionRefused,
447 "data from foreign client not allowed",
448 )))
449 }
450 _ => unreachable!(),
452 }
453 }
454 }
455 Ok(_) => {
456 trace!(
458 socket_id,
459 udp_bound_addr = ?sock_rx.local_addr(),
460 "read EOF on socket"
461 );
462 done = true;
463 None
464 }
465 Err(e) => {
466 debug!(
468 socket_id,
469 udp_bound_addr = ?sock_rx.local_addr(),
470 error = %e,
471 "forwarded error from socket"
472 );
473 done = true;
474 Some(Err(e))
475 }
476 };
477
478 if let Some(out_res) = out_res {
481 if let Err(err) = ingress_tx.send_async(out_res).await {
482 error!(
483 socket_id,
484 udp_bound_addr = ?sock_rx.local_addr(),
485 error = %err,
486 "failed to dispatch received data"
487 );
488 done = true;
489 }
490 }
491
492 if done {
493 trace!(
494 socket_id,
495 udp_bound_addr = ?sock_rx.local_addr(),
496 "rx queue done"
497 );
498 break;
499 }
500 }
501 }
502 .boxed()
503 }
504
505 fn setup_tx_queue(
507 socket_id: usize,
508 sock_tx: Arc<UdpSocket>,
509 egress_rx: flume::Receiver<Box<[u8]>>,
510 counterparty: Arc<OnceLock<SocketAddrStr>>,
511 ) -> futures::future::BoxFuture<'static, ()> {
512 let counterparty_tx = counterparty.clone();
513 async move {
514 loop {
515 match egress_rx.recv_async().await {
516 Ok(data) => {
517 if let Some(target) = counterparty_tx.get() {
518 if let Err(e) = sock_tx.send_to(&data, target.as_ref()).await {
519 error!(
520 ?socket_id,
521 udp_bound_addr = ?sock_tx.local_addr(),
522 ?target,
523 error = %e,
524 "failed to send data"
525 );
526 }
527 trace!(socket_id, bytes = data.len(), ?target, "sent bytes to");
528
529 #[cfg(all(feature = "prometheus", not(test)))]
530 METRIC_UDP_EGRESS_LEN.observe(&[target.as_str()], data.len() as f64);
531 } else {
532 error!(
533 ?socket_id,
534 udp_bound_addr = ?sock_tx.local_addr(),
535 "cannot send data, counterparty not set"
536 );
537 break;
538 }
539 }
540 Err(e) => {
541 error!(
542 ?socket_id,
543 udp_bound_addr = ?sock_tx.local_addr(),
544 error = %e,
545 "cannot receive more data from egress channel"
546 );
547 break;
548 }
549 }
550 trace!(
551 ?socket_id,
552 udp_bound_addr = tracing::field::debug(sock_tx.local_addr()),
553 "tx queue done"
554 );
555 }
556 }
557 .boxed()
558 }
559
560 pub fn bound_address(&self) -> &std::net::SocketAddr {
562 &self.bound_to
563 }
564
565 pub fn builder() -> UdpStreamBuilder {
567 UdpStreamBuilder::default()
568 }
569}
570
571impl Drop for ConnectedUdpStream {
572 fn drop(&mut self) {
573 self.socket_handles.iter().for_each(|handle| {
574 handle.abort();
575 })
576 }
577}
578
579impl tokio::io::AsyncRead for ConnectedUdpStream {
580 fn poll_read(
581 mut self: Pin<&mut Self>,
582 cx: &mut Context<'_>,
583 buf: &mut tokio::io::ReadBuf<'_>,
584 ) -> Poll<std::io::Result<()>> {
585 trace!(
586 remaining_bytes = buf.remaining(),
587 counterparty = ?self.counterparty.get(),
588 "polling read of from udp stream",
589 );
590 match Pin::new(&mut self.ingress_rx).poll_read(cx, buf) {
591 Poll::Ready(Ok(())) => {
592 let read = buf.filled().len();
593 trace!(bytes = read, "read bytes");
594 Poll::Ready(Ok(()))
595 }
596 Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
597 Poll::Pending => Poll::Pending,
598 }
599 }
600}
601
602impl tokio::io::AsyncWrite for ConnectedUdpStream {
603 fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
604 trace!(
605 bytes = buf.len(),
606 counterparty = ?self.counterparty.get(),
607 "polling write to udp stream",
608 );
609 if let Some(sender) = &mut self.egress_tx {
610 if let Err(e) = ready!(sender.poll_ready_unpin(cx)) {
611 return Poll::Ready(Err(e));
612 }
613
614 let len = buf.iter().len();
615 if let Err(e) = sender.start_send_unpin(Box::from(buf)) {
616 return Poll::Ready(Err(e));
617 }
618
619 pin_mut!(sender);
621 sender.poll_flush(cx).map_ok(|_| len)
622 } else {
623 Poll::Ready(Err(std::io::Error::new(
624 ErrorKind::NotConnected,
625 "udp stream is closed",
626 )))
627 }
628 }
629
630 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
631 trace!(
632 counterparty = ?self.counterparty.get(),
633 "polling flush to udp stream"
634 );
635 if let Some(sender) = &mut self.egress_tx {
636 pin_mut!(sender);
637 sender
638 .poll_flush(cx)
639 .map_err(|err| std::io::Error::other(err.to_string()))
640 } else {
641 Poll::Ready(Err(std::io::Error::new(
642 ErrorKind::NotConnected,
643 "udp stream is closed",
644 )))
645 }
646 }
647
648 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
649 trace!(
650 counterparty = ?self.counterparty.get(),
651 "polling close on udp stream"
652 );
653 let mut taken_sender = self.egress_tx.take();
655 if let Some(sender) = &mut taken_sender {
656 pin_mut!(sender);
657 sender
658 .poll_close(cx)
659 .map_err(|err| std::io::Error::other(err.to_string()))
660 } else {
661 Poll::Ready(Err(std::io::Error::new(
662 ErrorKind::NotConnected,
663 "udp stream is closed",
664 )))
665 }
666 }
667}
668
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Context;
    use futures::future::Either;
    use parameterized::parameterized;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::UdpSocket;

    // Round-trips random payloads through a plain-UDP echo server. `parallelism` is applied as
    // the receiver parallelism: `None` keeps the builder default, and `Some(0)` converts to the
    // default as well.
    #[parameterized(parallelism = {None, Some(2), Some(0)})]
    #[parameterized_macro(tokio::test)]
    async fn basic_udp_stream_tests(parallelism: Option<usize>) -> anyhow::Result<()> {
        const DATA_SIZE: usize = 200;

        let listener = UdpSocket::bind("127.0.0.1:0").await.context("bind listener")?;
        let listen_addr = listener.local_addr()?;

        // Echo server: every non-empty datagram must be exactly DATA_SIZE bytes long and is sent
        // back to its source address.
        tokio::task::spawn(async move {
            loop {
                let mut buf = [0u8; DATA_SIZE];
                let (read, addr) = listener.recv_from(&mut buf).await.expect("recv must not fail");
                if read > 0 {
                    assert_eq!(DATA_SIZE, read, "read size must be exactly {DATA_SIZE}");
                    listener.send_to(&buf, addr).await.expect("send must not fail");
                }
            }
        });

        // The echo server is the preset counterparty, so writes can go out before anything
        // has been received.
        let mut builder = ConnectedUdpStream::builder()
            .with_buffer_size(1024)
            .with_queue_size(512)
            .with_counterparty(listen_addr);

        if let Some(parallelism) = parallelism {
            builder = builder.with_receiver_parallelism(parallelism);
        }

        let mut stream = builder.build(("127.0.0.1", 0)).context("connection")?;

        // Each iteration writes one datagram and expects the identical bytes echoed back.
        for _ in 1..1000 {
            let mut w_buf = [0u8; DATA_SIZE];
            hopr_crypto_random::random_fill(&mut w_buf);
            let written = stream.write(&w_buf).await?;
            assert_eq!(written, DATA_SIZE);

            let mut r_buf = [0u8; DATA_SIZE];
            let read = stream.read_exact(&mut r_buf).await?;
            assert_eq!(read, DATA_SIZE);

            assert_eq!(w_buf, r_buf);
        }

        stream.shutdown().await?;

        Ok(())
    }

    // A stream without a preset counterparty adopts the first sender. Two datagrams from that
    // sender must then be readable in order and in full before a 1 s timeout.
    #[tokio::test]
    async fn udp_stream_should_process_sequential_writes() -> anyhow::Result<()> {
        const BUF_SIZE: usize = 1024;
        const EXPECTED_DATA_LEN: usize = BUF_SIZE + 500;

        let mut listener = ConnectedUdpStream::builder()
            .with_buffer_size(BUF_SIZE)
            .with_queue_size(512)
            .build(("127.0.0.1", 0))
            .context("bind listener")?;

        let bound_addr = *listener.bound_address();

        // Reader task: collects data in chunks smaller than a datagram until EXPECTED_DATA_LEN
        // bytes have arrived or EOF is reached.
        let jh = tokio::task::spawn(async move {
            let mut buf = [0u8; BUF_SIZE / 4];
            let mut vec = Vec::<u8>::new();
            loop {
                let sz = listener.read(&mut buf).await.unwrap();
                if sz > 0 {
                    vec.extend_from_slice(&buf[..sz]);
                    if vec.len() >= EXPECTED_DATA_LEN {
                        return vec;
                    }
                } else {
                    return vec;
                }
            }
        });

        let msg = [1u8; EXPECTED_DATA_LEN];
        let sender = UdpSocket::bind(("127.0.0.1", 0)).await.context("bind")?;

        // Both datagrams fit into the listener's BUF_SIZE receive buffer.
        sender.send_to(&msg[..BUF_SIZE], bound_addr).await?;
        sender.send_to(&msg[BUF_SIZE..], bound_addr).await?;

        let timeout = tokio::time::sleep(std::time::Duration::from_millis(1000));
        pin_mut!(timeout);
        pin_mut!(jh);

        match futures::future::select(jh, timeout).await {
            Either::Left((Ok(v), _)) => {
                assert_eq!(v.len(), EXPECTED_DATA_LEN);
                assert_eq!(v.as_slice(), &msg);
                Ok(())
            }
            _ => Err(anyhow::anyhow!("timeout or invalid data")),
        }
    }
}