hopr_network_types/
udp.rs

1use std::{
2    fmt::Debug,
3    io::ErrorKind,
4    num::NonZeroUsize,
5    pin::Pin,
6    sync::{Arc, OnceLock},
7    task::{Context, Poll},
8    time::Duration,
9};
10
11use futures::{FutureExt, Sink, SinkExt, ready};
12use tokio::net::UdpSocket;
13use tracing::{debug, error, instrument, trace, warn};
14
15use crate::utils::SocketAddrStr;
16
17type BoxIoSink<T> = Box<dyn Sink<T, Error = std::io::Error> + Send + Unpin>;
18
// Per-counterparty histograms of UDP datagram lengths in bytes.
// Only compiled when the `prometheus` feature is enabled, and never in test builds.
// Bucket bounds double from 20 up to 5120 bytes.
#[cfg(all(feature = "prometheus", not(test)))]
lazy_static::lazy_static! {
    // Length of each received datagram, labelled with the stream's counterparty address.
    static ref METRIC_UDP_INGRESS_LEN: hopr_metrics::MultiHistogram =
        hopr_metrics::MultiHistogram::new(
            "hopr_udp_ingress_packet_len",
            "UDP packet lengths on ingress per counterparty",
            vec![20.0, 40.0, 80.0, 160.0, 320.0, 640.0, 1280.0, 2560.0, 5120.0],
            &["counterparty"]
    ).unwrap();
    // Length of each transmitted datagram, labelled with the target counterparty address.
    static ref METRIC_UDP_EGRESS_LEN: hopr_metrics::MultiHistogram =
        hopr_metrics::MultiHistogram::new(
            "hopr_udp_egress_packet_len",
            "UDP packet lengths on egress per counterparty",
            vec![20.0, 40.0, 80.0, 160.0, 320.0, 640.0, 1280.0, 2560.0, 5120.0],
            &["counterparty"]
    ).unwrap();
}
36
/// Mimics TCP-like stream functionality on a UDP socket by restricting it to a single
/// counterparty, and implements [`tokio::io::AsyncRead`] and [`tokio::io::AsyncWrite`].
/// The instance must always be constructed using a [`UdpStreamBuilder`].
///
/// The counterparty is set by whichever of the following happens first:
/// 1) setting it during build via [`UdpStreamBuilder::with_counterparty`]
/// 2) receiving some data from the other side.
///
/// Once the counterparty is set, all data is sent only to this counterparty address,
/// and received data is filtered by it.
///
/// If data from another party is received, an error is raised, unless the object has been constructed
/// with [`ForeignDataMode::Discard`] or [`ForeignDataMode::Accept`] setting.
///
/// This object is also capable of parallel processing on a UDP socket.
/// If [parallelism](UdpStreamBuilder::with_receiver_parallelism) is set, the instance will bind
/// multiple sockets with `SO_REUSEADDR` and `SO_REUSEPORT` and spawn parallel tasks that use these
/// sockets to receive and transmit data. The tasks exchange data with this object through RX/TX
/// MPMC queues, which are unbounded by default (see [queue size](UdpStreamBuilder::with_queue_size)
/// for details).
#[pin_project::pin_project(PinnedDrop)]
pub struct ConnectedUdpStream {
    // Join handles of the spawned per-socket RX/TX tasks; aborted on shutdown and on drop.
    socket_handles: Vec<tokio::task::JoinHandle<()>>,
    // Byte stream reader fed by the RX tasks through the ingress queue.
    #[pin]
    ingress_rx: Box<dyn tokio::io::AsyncRead + Send + Unpin>,
    // Sink feeding datagrams to the TX tasks through the egress queue.
    // When `None`, write/flush/shutdown fail with `NotConnected` ("udp stream is closed").
    #[pin]
    egress_tx: Option<BoxIoSink<Box<[u8]>>>,
    // Counterparty address shared with all RX/TX tasks; can be set at most once.
    counterparty: Arc<OnceLock<SocketAddrStr>>,
    // Local address that all sockets of this instance are bound to.
    bound_to: std::net::SocketAddr,
    // Progress of the current `poll_write` call (enqueue, then flush).
    state: State,
}
68
/// Two-phase write state machine used by the `AsyncWrite` implementation of
/// [`ConnectedUdpStream`].
#[derive(Clone, Copy)]
enum State {
    /// Idle: the next `poll_write` enqueues a new buffer.
    Writing,
    /// A buffer of the contained length was enqueued and still awaits flushing.
    Flushing(usize),
}
74
/// Policy for datagrams that reach a [`ConnectedUdpStream`] from a
/// [`SocketAddr`](std::net::SocketAddr) other than the instance's counterparty.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum ForeignDataMode {
    /// Foreign datagrams are dropped.
    Discard,
    /// Foreign datagrams are delivered as if the counterparty had sent them.
    Accept,
    /// The next `poll_read` attempt fails with an error. This is the default.
    #[default]
    Error,
}
88
/// Number of parallel receiver or sender sockets a [`ConnectedUdpStream`] binds.
///
/// When more than one socket is bound, each one gets `SO_REUSEADDR` and `SO_REUSEPORT`,
/// so that send and/or receive operations can be processed in parallel.
///
/// **NOTE**: This optimization only works on Linux; on other systems it has no effect.
///
/// - [`Specific`](UdpStreamParallelism::Specific)`(n)` makes the [`ConnectedUdpStream`] bind
///   `n` sockets.
/// - [`Auto`](UdpStreamParallelism::Auto) lets [`std::thread::available_parallelism`] decide
///   how many sockets the [`ConnectedUdpStream`] binds.
///
/// Defaults to `Specific(1)`.
///
/// The effective number of sockets must always be obtained through
/// [`into_num_tasks`](UdpStreamParallelism::into_num_tasks) or
/// [`split_evenly_with`](UdpStreamParallelism::split_evenly_with).
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum UdpStreamParallelism {
    /// As many sender or receiver sockets as [`std::thread::available_parallelism`] reports.
    Auto,
    /// Exactly this many sender or receiver sockets, capped by the available CPU parallelism.
    Specific(NonZeroUsize),
}
112
113impl Default for UdpStreamParallelism {
114    fn default() -> Self {
115        Self::Specific(NonZeroUsize::MIN)
116    }
117}
118
119impl UdpStreamParallelism {
120    fn avail_parallelism() -> usize {
121        // On non-Linux system this will always default to 1, since the
122        // multiple UDP socket optimization is not possible for those platforms.
123        std::thread::available_parallelism()
124            .map(|n| {
125                if cfg!(target_os = "linux") {
126                    n
127                } else {
128                    NonZeroUsize::MIN
129                }
130            })
131            .unwrap_or_else(|e| {
132                warn!(error = %e, "failed to determine available parallelism, defaulting to 1.");
133                NonZeroUsize::MIN
134            })
135            .into()
136    }
137
138    /// Returns the number of sockets for this and the `other` instance
139    /// when they evenly split the available CPU parallelism.
140    pub fn split_evenly_with(self, other: UdpStreamParallelism) -> (usize, usize) {
141        let cpu_half = (Self::avail_parallelism() / 2).max(1);
142
143        match (self, other) {
144            (UdpStreamParallelism::Auto, UdpStreamParallelism::Auto) => (cpu_half, cpu_half),
145            (UdpStreamParallelism::Specific(a), UdpStreamParallelism::Auto) => {
146                let a = cpu_half.min(a.into());
147                (a, cpu_half * 2 - a)
148            }
149            (UdpStreamParallelism::Auto, UdpStreamParallelism::Specific(b)) => {
150                let b = cpu_half.min(b.into());
151                (cpu_half * 2 - b, b)
152            }
153            (UdpStreamParallelism::Specific(a), UdpStreamParallelism::Specific(b)) => {
154                (cpu_half.min(a.into()), cpu_half.min(b.into()))
155            }
156        }
157    }
158
159    /// Calculates the actual number of tasks for this instance.
160    ///
161    /// The returned value is never more than the maximum available CPU parallelism.
162    pub fn into_num_tasks(self) -> usize {
163        let avail_parallelism = Self::avail_parallelism();
164        match self {
165            UdpStreamParallelism::Auto => avail_parallelism,
166            UdpStreamParallelism::Specific(n) => usize::from(n).min(avail_parallelism),
167        }
168    }
169}
170
171impl From<usize> for UdpStreamParallelism {
172    fn from(value: usize) -> Self {
173        NonZeroUsize::new(value).map(Self::Specific).unwrap_or_default()
174    }
175}
176
177impl From<Option<usize>> for UdpStreamParallelism {
178    fn from(value: Option<usize>) -> Self {
179        value.map(UdpStreamParallelism::from).unwrap_or_default()
180    }
181}
182
/// Builder object for the [`ConnectedUdpStream`].
///
/// If you wish to use defaults, do `UdpStreamBuilder::default().build(addr)`.
#[derive(Debug, Clone)]
pub struct UdpStreamBuilder {
    // What to do with datagrams from addresses other than the counterparty.
    foreign_data_mode: ForeignDataMode,
    // Size in bytes of the buffer each receiving task reads a datagram into.
    buffer_size: usize,
    // Capacity of both the RX and TX queues; `None` means unbounded.
    queue_size: Option<usize>,
    // Requested number of receiving sockets.
    receiver_parallelism: UdpStreamParallelism,
    // Requested number of sending sockets.
    sender_parallelism: UdpStreamParallelism,
    // Counterparty fixed at build time, if any.
    counterparty: Option<std::net::SocketAddr>,
}
195
196impl Default for UdpStreamBuilder {
197    fn default() -> Self {
198        Self {
199            buffer_size: 2048,
200            foreign_data_mode: Default::default(),
201            queue_size: None,
202            receiver_parallelism: Default::default(),
203            sender_parallelism: Default::default(),
204            counterparty: None,
205        }
206    }
207}
208
209impl UdpStreamBuilder {
210    /// Defines the behavior when data from an unexpected source arrive into the socket.
211    /// See [`ForeignDataMode`] for details.
212    ///
213    /// Default is [`ForeignDataMode::Error`].
214    pub fn with_foreign_data_mode(mut self, mode: ForeignDataMode) -> Self {
215        self.foreign_data_mode = mode;
216        self
217    }
218
219    /// The size of the UDP receive buffer.
220    ///
221    /// This size must be at least the size of the MTU, otherwise the unread UDP data that
222    /// does not fit this buffer will be discarded.
223    ///
224    /// Default is 2048.
225    pub fn with_buffer_size(mut self, buffer_size: usize) -> Self {
226        self.buffer_size = buffer_size;
227        self
228    }
229
230    /// Size of the TX/RX queue that dispatches data of reads from/writings to
231    /// the sockets.
232    ///
233    /// This an important back-pressure mechanism when dispatching received data from
234    /// fast senders.
235    /// Reduces the maximum memory consumed by the object, which is given by:
236    /// [`buffer_size`](UdpStreamBuilder::with_buffer_size) *
237    /// [`queue_size`](UdpStreamBuilder::with_queue_size)
238    ///
239    /// Default is unbounded.
240    pub fn with_queue_size(mut self, queue_size: usize) -> Self {
241        self.queue_size = Some(queue_size);
242        self
243    }
244
245    /// Sets how many parallel receiving sockets should be bound.
246    ///
247    /// Has no effect on non-Linux machines. See [`UdpStreamParallelism`] for details.
248    ///
249    /// Default is `1`.
250    pub fn with_receiver_parallelism<T: Into<UdpStreamParallelism>>(mut self, parallelism: T) -> Self {
251        self.receiver_parallelism = parallelism.into();
252        self
253    }
254
255    /// Sets how many parallel sending sockets should be bound.
256    ///
257    /// Has no effect on non-Linux machines. See [`UdpStreamParallelism`] for details.
258    ///
259    /// Default is `1`.
260    pub fn with_sender_parallelism<T: Into<UdpStreamParallelism>>(mut self, parallelism: T) -> Self {
261        self.sender_parallelism = parallelism.into();
262        self
263    }
264
265    /// Sets the expected counterparty for data sent/received by the UDP sockets.
266    ///
267    /// If not specified, the counterparty is determined from the first packet received.
268    /// However, no data can be sent up until this point.
269    /// Therefore, the value must be set if data are sent first rather than received.
270    /// If data is expected to be received first, the value does not need to be set.
271    ///
272    /// See [`ConnectedUdpStream`] and [`ForeignDataMode`] for details.
273    ///
274    /// Default is none.
275    pub fn with_counterparty(mut self, counterparty: std::net::SocketAddr) -> Self {
276        self.counterparty = Some(counterparty);
277        self
278    }
279
280    /// Builds the [`ConnectedUdpStream`] with UDP socket(s) bound to `bind_addr`.
281    ///
282    /// The number of RX sockets bound is determined by [receiver
283    /// parallelism](UdpStreamBuilder::with_receiver_parallelism), and similarly, the number of TX sockets bound is
284    /// determined by [sender parallelism](UdpStreamBuilder::with_sender_parallelism). On non-Linux platforms, only
285    /// a single receiver and sender will be bound, regardless of the above.
286    ///
287    /// The returned instance is always ready to receive data.
288    /// It is also ready to send data
289    /// if the [counterparty](UdpStreamBuilder::with_counterparty) has been set.
290    ///
291    /// If `bind_addr` yields multiple addresses, binding will be attempted with each of the addresses
292    /// until one succeeds. If none of the addresses succeed in binding the socket(s),
293    /// the `AddrNotAvailable` error is returned.
294    ///
295    /// Note that wildcard addresses (such as `0.0.0.0`) are *not* considered as multiple addresses,
296    /// and such socket(s) will bind to all available interfaces at the system level.
297    pub fn build<A: std::net::ToSocketAddrs>(self, bind_addr: A) -> std::io::Result<ConnectedUdpStream> {
298        let (num_rx_socks, num_tx_socks) = self.receiver_parallelism.split_evenly_with(self.sender_parallelism);
299
300        let counterparty = Arc::new(
301            self.counterparty
302                .map(|s| OnceLock::from(SocketAddrStr::from(s)))
303                .unwrap_or_default(),
304        );
305        let ((ingress_tx, ingress_rx), (egress_tx, egress_rx)) = if let Some(q) = self.queue_size {
306            (flume::bounded(q), flume::bounded(q))
307        } else {
308            (flume::unbounded(), flume::unbounded())
309        };
310
311        let num_socks_to_bind = num_rx_socks.max(num_tx_socks);
312        let mut socket_handles = Vec::with_capacity(num_socks_to_bind);
313        let mut bound_addr: Option<std::net::SocketAddr> = None;
314
315        // Try binding on all network addresses in `bind_addr`
316        for binding_to in bind_addr.to_socket_addrs()? {
317            debug!(
318                %binding_to,
319                num_socks_to_bind, num_rx_socks, num_tx_socks, "binding UDP stream"
320            );
321
322            // TODO: split bound sockets into a separate cloneable object
323
324            // Try to bind sockets on the current network interface address
325            (0..num_socks_to_bind)
326                .map(|sock_id| {
327                    let domain = match &binding_to {
328                        std::net::SocketAddr::V4(_) => socket2::Domain::IPV4,
329                        std::net::SocketAddr::V6(_) => socket2::Domain::IPV6,
330                    };
331
332                    // Bind a new non-blocking UDP socket
333                    let sock = socket2::Socket::new(domain, socket2::Type::DGRAM, None)?;
334                    if num_socks_to_bind > 1 {
335                        sock.set_reuse_address(true)?; // Needed for every next socket with non-wildcard IP
336                        sock.set_reuse_port(true)?; // Needed on Linux to evenly distribute datagrams
337                    }
338                    sock.set_nonblocking(true)?;
339                    sock.bind(&bound_addr.unwrap_or(binding_to).into())?;
340
341                    // Determine the address we bound this socket to, so we can also bind the others
342                    let socket_bound_addr = sock
343                        .local_addr()?
344                        .as_socket()
345                        .ok_or(std::io::Error::other("invalid socket type"))?;
346
347                    match bound_addr {
348                        None => bound_addr = Some(socket_bound_addr),
349                        Some(addr) if addr != socket_bound_addr => {
350                            return Err(std::io::Error::other(format!(
351                                "inconsistent binding address {addr} != {socket_bound_addr} on socket id {sock_id}"
352                            )));
353                        }
354                        _ => {}
355                    }
356
357                    let sock = Arc::new(UdpSocket::from_std(sock.into())?);
358                    debug!(
359                        socket_id = sock_id,
360                        addr = %socket_bound_addr,
361                        "bound UDP socket"
362                    );
363
364                    Ok((sock_id, sock))
365                })
366                .filter_map(|result| match result {
367                    Ok(bound) => Some(bound),
368                    Err(error) => {
369                        error!(
370                            %binding_to,
371                            %error,
372                            "failed to bind udp socket"
373                        );
374                        None
375                    }
376                })
377                .for_each(|(sock_id, sock)| {
378                    if sock_id < num_tx_socks {
379                        socket_handles.push(tokio::task::spawn(ConnectedUdpStream::setup_tx_queue(
380                            sock_id,
381                            sock.clone(),
382                            egress_rx.clone(),
383                            counterparty.clone(),
384                        )));
385                    }
386                    if sock_id < num_rx_socks {
387                        socket_handles.push(tokio::task::spawn(ConnectedUdpStream::setup_rx_queue(
388                            sock_id,
389                            sock.clone(),
390                            ingress_tx.clone(),
391                            counterparty.clone(),
392                            self.foreign_data_mode,
393                            self.buffer_size,
394                        )));
395                    }
396                });
397        }
398
399        Ok(ConnectedUdpStream {
400            ingress_rx: Box::new(tokio_util::io::StreamReader::new(ingress_rx.into_stream())),
401            egress_tx: Some(Box::new(
402                egress_tx
403                    .into_sink()
404                    .sink_map_err(|e| std::io::Error::other(e.to_string())),
405            )),
406            socket_handles,
407            counterparty,
408            bound_to: bound_addr.ok_or(ErrorKind::AddrNotAvailable)?,
409            state: State::Writing,
410        })
411    }
412}
413
/// Upper bound on the time a receiving task may wait to push a datagram into the ingress
/// queue before it logs a warning suggesting a larger queue size.
const QUEUE_DISPATCH_THRESHOLD: Duration = Duration::from_micros(150_000);
415
416impl ConnectedUdpStream {
417    /// Creates a receiver queue for the UDP stream.
418    fn setup_rx_queue(
419        socket_id: usize,
420        sock_rx: Arc<UdpSocket>,
421        ingress_tx: flume::Sender<std::io::Result<tokio_util::bytes::Bytes>>,
422        counterparty: Arc<OnceLock<SocketAddrStr>>,
423        foreign_data_mode: ForeignDataMode,
424        buf_size: usize,
425    ) -> futures::future::BoxFuture<'static, ()> {
426        let counterparty_rx = counterparty.clone();
427        async move {
428            let mut buffer = vec![0u8; buf_size];
429            let mut done = false;
430            loop {
431                // Read data from the socket
432                let out_res = match sock_rx.recv_from(&mut buffer).await {
433                    Ok((read, read_addr)) if read > 0 => {
434                        trace!(
435                            socket_id,
436                            udp_bound_addr = ?sock_rx.local_addr(),
437                            bytes = read,
438                            from = %read_addr,
439                            "received data from"
440                        );
441
442                        let addr = counterparty_rx.get_or_init(|| read_addr.into());
443
444                        #[cfg(all(feature = "prometheus", not(test)))]
445                        METRIC_UDP_INGRESS_LEN.observe(&[addr.as_str()], read as f64);
446
447                        // If the data is from a counterparty, or we accept anything, pass it
448                        if read_addr.eq(addr) || foreign_data_mode == ForeignDataMode::Accept {
449                            let out_buffer = tokio_util::bytes::Bytes::copy_from_slice(&buffer[..read]);
450                            Some(Ok(out_buffer))
451                        } else {
452                            match foreign_data_mode {
453                                ForeignDataMode::Discard => {
454                                    // Don't even bother sending an error about discarded stuff
455                                    warn!(
456                                        socket_id,
457                                        udp_bound_addr = ?sock_rx.local_addr(),
458                                        ?read_addr,
459                                        expected_addr = ?addr,
460                                        "discarded data, which didn't come from the expected address"
461                                    );
462                                    None
463                                }
464                                ForeignDataMode::Error => {
465                                    // Terminate here, the ingress_tx gets dropped
466                                    done = true;
467                                    Some(Err(std::io::Error::new(
468                                        ErrorKind::ConnectionRefused,
469                                        "data from foreign client not allowed",
470                                    )))
471                                }
472                                // ForeignDataMode::Accept has been handled above
473                                _ => unreachable!(),
474                            }
475                        }
476                    }
477                    Ok(_) => {
478                        // Read EOF, terminate here, the ingress_tx gets dropped
479                        trace!(
480                            socket_id,
481                            udp_bound_addr = ?sock_rx.local_addr(),
482                            "read EOF on socket"
483                        );
484                        done = true;
485                        None
486                    }
487                    Err(error) => {
488                        // Forward the error
489                        debug!(
490                            socket_id,
491                            udp_bound_addr = ?sock_rx.local_addr(),
492                            %error,
493                            "forwarded error from socket"
494                        );
495                        done = true;
496                        Some(Err(error))
497                    }
498                };
499
500                // Dispatch the received data to the queue.
501                // If the underlying queue is bounded, it will wait until there is space.
502                if let Some(out_res) = out_res {
503                    let start = std::time::Instant::now();
504                    if let Err(error) = ingress_tx.send_async(out_res).await {
505                        error!(
506                            socket_id,
507                            udp_bound_addr = ?sock_rx.local_addr(),
508                            %error,
509                            "failed to dispatch received data"
510                        );
511                        done = true;
512                    }
513                    let elapsed = start.elapsed();
514                    if elapsed > QUEUE_DISPATCH_THRESHOLD {
515                        warn!(
516                            ?elapsed,
517                            "udp queue dispatch took too long, consider increasing the queue size"
518                        );
519                    }
520                }
521
522                if done {
523                    trace!(
524                        socket_id,
525                        udp_bound_addr = ?sock_rx.local_addr(),
526                        "rx queue done"
527                    );
528                    break;
529                }
530            }
531        }
532        .boxed()
533    }
534
535    /// Creates a transmission queue for the UDP stream.
536    fn setup_tx_queue(
537        socket_id: usize,
538        sock_tx: Arc<UdpSocket>,
539        egress_rx: flume::Receiver<Box<[u8]>>,
540        counterparty: Arc<OnceLock<SocketAddrStr>>,
541    ) -> futures::future::BoxFuture<'static, ()> {
542        let counterparty_tx = counterparty.clone();
543        async move {
544            loop {
545                match egress_rx.recv_async().await {
546                    Ok(data) => {
547                        if let Some(target) = counterparty_tx.get() {
548                            if let Err(error) = sock_tx.send_to(&data, target.as_ref()).await {
549                                error!(
550                                    ?socket_id,
551                                    udp_bound_addr = ?sock_tx.local_addr(),
552                                    ?target,
553                                    %error,
554                                    "failed to send data"
555                                );
556                            }
557                            trace!(socket_id, bytes = data.len(), ?target, "sent bytes to");
558
559                            #[cfg(all(feature = "prometheus", not(test)))]
560                            METRIC_UDP_EGRESS_LEN.observe(&[target.as_str()], data.len() as f64);
561                        } else {
562                            error!(
563                                ?socket_id,
564                                udp_bound_addr = ?sock_tx.local_addr(),
565                                "cannot send data, counterparty not set"
566                            );
567                            break;
568                        }
569                    }
570                    Err(error) => {
571                        error!(
572                            ?socket_id,
573                            udp_bound_addr = ?sock_tx.local_addr(),
574                            %error,
575                            "cannot receive more data from egress channel"
576                        );
577                        break;
578                    }
579                }
580                trace!(
581                    ?socket_id,
582                    udp_bound_addr = tracing::field::debug(sock_tx.local_addr()),
583                    "tx queue done"
584                );
585            }
586        }
587        .boxed()
588    }
589
590    /// Local address that all UDP sockets in this instance are bound to.
591    pub fn bound_address(&self) -> &std::net::SocketAddr {
592        &self.bound_to
593    }
594
595    /// Creates a new [builder](UdpStreamBuilder).
596    pub fn builder() -> UdpStreamBuilder {
597        UdpStreamBuilder::default()
598    }
599}
600
601impl tokio::io::AsyncRead for ConnectedUdpStream {
602    #[instrument(name = "ConnectedUdpStream::poll_read", level = "trace", skip(self, cx), fields(counterparty = ?self.counterparty.get(), rem = buf.remaining()) , ret)]
603    fn poll_read(
604        self: Pin<&mut Self>,
605        cx: &mut Context<'_>,
606        buf: &mut tokio::io::ReadBuf<'_>,
607    ) -> Poll<std::io::Result<()>> {
608        ready!(self.project().ingress_rx.poll_read(cx, buf))?;
609        trace!(bytes = buf.filled().len(), "read bytes");
610        Poll::Ready(Ok(()))
611    }
612}
613
614impl tokio::io::AsyncWrite for ConnectedUdpStream {
615    #[instrument(name = "ConnectedUdpStream::poll_write", level = "trace", skip(self, cx), fields(counterparty = ?self.counterparty.get(), len = buf.len()) , ret)]
616    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
617        let this = self.project();
618        if let Some(sender) = this.egress_tx.get_mut() {
619            loop {
620                match *this.state {
621                    State::Writing => {
622                        ready!(sender.poll_ready_unpin(cx))?;
623
624                        let len = buf.iter().len();
625                        sender.start_send_unpin(Box::from(buf))?;
626                        *this.state = State::Flushing(len);
627                    }
628                    State::Flushing(len) => {
629                        // Explicitly flush after each data sent
630                        let res = ready!(sender.poll_flush_unpin(cx)).map(|_| len);
631                        *this.state = State::Writing;
632
633                        return Poll::Ready(res);
634                    }
635                }
636            }
637        } else {
638            Poll::Ready(Err(std::io::Error::new(
639                ErrorKind::NotConnected,
640                "udp stream is closed",
641            )))
642        }
643    }
644
645    #[instrument(name = "ConnectedUdpStream::poll_flush", level = "trace", skip(self, cx), fields(counterparty = ?self.counterparty.get()) , ret)]
646    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
647        let this = self.project();
648        if let Some(sender) = this.egress_tx.as_pin_mut() {
649            sender.poll_flush(cx).map_err(std::io::Error::other)
650        } else {
651            Poll::Ready(Err(std::io::Error::new(
652                ErrorKind::NotConnected,
653                "udp stream is closed",
654            )))
655        }
656    }
657
658    #[instrument(name = "ConnectedUdpStream::poll_shutdown", level = "trace", skip(self, cx), fields(counterparty = ?self.counterparty.get()) , ret)]
659    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
660        let this = self.project();
661        if let Some(sender) = this.egress_tx.as_pin_mut() {
662            let ret = ready!(sender.poll_close(cx));
663
664            this.socket_handles.iter().for_each(|handle| {
665                handle.abort();
666            });
667
668            Poll::Ready(ret)
669        } else {
670            Poll::Ready(Err(std::io::Error::new(
671                ErrorKind::NotConnected,
672                "udp stream is closed",
673            )))
674        }
675    }
676}
677
678#[pin_project::pinned_drop]
679impl PinnedDrop for ConnectedUdpStream {
680    fn drop(self: Pin<&mut Self>) {
681        debug!(binding = ?self.bound_to,"dropping ConnectedUdpStream");
682        self.project().socket_handles.iter().for_each(|handle| {
683            handle.abort();
684        })
685    }
686}
687
#[cfg(test)]
mod tests {
    use anyhow::Context;
    use parameterized::parameterized;
    use tokio::{
        io::{AsyncReadExt, AsyncWriteExt},
        net::UdpSocket,
    };

    use super::*;

    #[parameterized(parallelism = {None, Some(2), Some(0)})]
    #[parameterized_macro(tokio::test)]
    //#[parameterized_macro(test_log::test(tokio::test))]
    async fn basic_udp_stream_tests(parallelism: Option<usize>) -> anyhow::Result<()> {
        const DATA_SIZE: usize = 200;

        let echo_socket = UdpSocket::bind("127.0.0.1:0").await.context("bind listener")?;
        let echo_addr = echo_socket.local_addr()?;

        // Minimal UDP echo server: every non-empty datagram is sent back to its sender
        tokio::task::spawn(async move {
            loop {
                let mut datagram = [0u8; DATA_SIZE];
                let (len, peer) = echo_socket.recv_from(&mut datagram).await.expect("recv must not fail");
                if len == 0 {
                    continue;
                }
                assert_eq!(DATA_SIZE, len, "read size must be exactly {DATA_SIZE}");
                echo_socket.send_to(&datagram, peer).await.expect("send must not fail");
            }
        });

        let builder = ConnectedUdpStream::builder()
            .with_buffer_size(1024)
            .with_queue_size(512)
            .with_counterparty(echo_addr);

        let builder = match parallelism {
            Some(p) => builder.with_receiver_parallelism(p),
            None => builder,
        };

        let mut stream = builder.build(("127.0.0.1", 0)).context("connection")?;

        for _round in 1..1000 {
            let mut outgoing = [0u8; DATA_SIZE];
            hopr_crypto_random::random_fill(&mut outgoing);
            assert_eq!(stream.write(&outgoing).await?, DATA_SIZE);

            let mut incoming = [0u8; DATA_SIZE];
            assert_eq!(stream.read_exact(&mut incoming).await?, DATA_SIZE);

            assert_eq!(outgoing, incoming);
        }

        stream.shutdown().await?;

        Ok(())
    }

    #[tokio::test]
    async fn udp_stream_should_process_sequential_writes() -> anyhow::Result<()> {
        const BUF_SIZE: usize = 1024;
        const EXPECTED_DATA_LEN: usize = BUF_SIZE + 500;

        let mut listener = ConnectedUdpStream::builder()
            .with_buffer_size(BUF_SIZE)
            .with_queue_size(512)
            .build(("127.0.0.1", 0))
            .context("bind listener")?;

        let bound_addr = *listener.bound_address();

        // Collect everything read until the expected amount arrives or EOF is hit
        let reader = tokio::task::spawn(async move {
            let mut chunk = [0u8; BUF_SIZE / 4];
            let mut collected = Vec::<u8>::new();
            while collected.len() < EXPECTED_DATA_LEN {
                let n = listener.read(&mut chunk).await.unwrap();
                if n == 0 {
                    break;
                }
                collected.extend_from_slice(&chunk[..n]);
            }
            collected
        });

        let msg = [1u8; EXPECTED_DATA_LEN];
        let sender = UdpSocket::bind(("127.0.0.1", 0)).await.context("bind")?;

        sender.send_to(&msg[..BUF_SIZE], bound_addr).await?;
        sender.send_to(&msg[BUF_SIZE..], bound_addr).await?;

        match tokio::time::timeout(std::time::Duration::from_millis(1000), reader).await {
            Ok(Ok(received)) => {
                assert_eq!(received.len(), EXPECTED_DATA_LEN);
                assert_eq!(received.as_slice(), &msg);
                Ok(())
            }
            _ => Err(anyhow::anyhow!("timeout or invalid data")),
        }
    }
}