hopr_network_types/
udp.rs

1use std::{
2    fmt::Debug,
3    io::ErrorKind,
4    num::NonZeroUsize,
5    pin::Pin,
6    sync::{Arc, OnceLock},
7    task::{Context, Poll},
8};
9
10use futures::{FutureExt, Sink, SinkExt, pin_mut, ready};
11use tokio::net::UdpSocket;
12use tracing::{debug, error, trace, warn};
13
14use crate::utils::SocketAddrStr;
15
/// Boxed, type-erased [`Sink`] of `T` items that reports its errors as [`std::io::Error`].
type BoxIoSink<T> = Box<dyn Sink<T, Error = std::io::Error> + Send + Unpin>;
17
#[cfg(all(feature = "prometheus", not(test)))]
lazy_static::lazy_static! {
    // Histogram of the sizes of received UDP datagrams, labeled by the counterparty address.
    static ref METRIC_UDP_INGRESS_LEN: hopr_metrics::MultiHistogram =
        hopr_metrics::MultiHistogram::new(
            "hopr_udp_ingress_packet_len",
            "UDP packet lengths on ingress per counterparty",
            vec![20.0, 40.0, 80.0, 160.0, 320.0, 640.0, 1280.0, 2560.0, 5120.0],
            &["counterparty"]
    ).unwrap();
    // Histogram of the sizes of sent UDP datagrams, labeled by the counterparty address.
    static ref METRIC_UDP_EGRESS_LEN: hopr_metrics::MultiHistogram =
        hopr_metrics::MultiHistogram::new(
            "hopr_udp_egress_packet_len",
            "UDP packet lengths on egress per counterparty",
            vec![20.0, 40.0, 80.0, 160.0, 320.0, 640.0, 1280.0, 2560.0, 5120.0],
            &["counterparty"]
    ).unwrap();
}
35
/// Mimics TCP-like stream functionality on a UDP socket by restricting it to a single
/// counterparty and implements [`tokio::io::AsyncRead`] and [`tokio::io::AsyncWrite`].
/// The instance is always constructed using a [`UdpStreamBuilder`].
///
/// To set a counterparty, one of the following must happen:
/// 1) setting it during build via [`UdpStreamBuilder::with_counterparty`]
/// 2) receiving some data from the other side.
///
/// Whichever of the above happens first sets the counterparty.
/// Once the counterparty is set, all data are sent only to this counterparty address,
/// and received data are filtered by it.
///
/// If data from another party is received, an error is raised, unless the object has been constructed
/// with [`ForeignDataMode::Discard`] or [`ForeignDataMode::Accept`] setting.
///
/// This object is also capable of parallel processing on a UDP socket.
/// If [parallelism](UdpStreamBuilder::with_receiver_parallelism) is set, the instance will create
/// multiple sockets with `SO_REUSEADDR` and `SO_REUSEPORT` and spawn parallel tasks that receive and
/// transmit data using these sockets. This is driven by RX/TX MPMC queues, which are
/// per-default unbounded (see [queue size](UdpStreamBuilder::with_queue_size) for details).
pub struct ConnectedUdpStream {
    /// Handles of the spawned RX/TX socket tasks; they are aborted when the stream is dropped.
    socket_handles: Vec<tokio::task::JoinHandle<()>>,
    /// Byte-stream view over the RX queue that the receiving tasks feed.
    ingress_rx: Box<dyn tokio::io::AsyncRead + Send + Unpin>,
    /// Sink into the TX queue drained by the sending tasks; `None` once the stream has been shut down.
    egress_tx: Option<BoxIoSink<Box<[u8]>>>,
    /// Counterparty address shared with the socket tasks; it can be set only once.
    counterparty: Arc<OnceLock<SocketAddrStr>>,
    /// Local address that all sockets of this instance are bound to.
    bound_to: std::net::SocketAddr,
}
63
/// Defines what happens when data from another [`SocketAddr`](std::net::SocketAddr) arrives
/// into the [`ConnectedUdpStream`] (other than the one that is considered a counterparty for that
/// instance).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub enum ForeignDataMode {
    /// Foreign data are simply discarded (a warning is logged).
    Discard,
    /// Foreign data are accepted as if they arrived from the set counterparty.
    Accept,
    /// Error is raised on the `poll_read` attempt, and the receiving task that got the foreign
    /// data stops reading from its socket.
    #[default]
    Error,
}
77
/// Determines how many parallel reader or writer sockets should be bound in [`ConnectedUdpStream`].
///
/// When more than one socket is bound, each UDP socket is bound with `SO_REUSEADDR` and
/// `SO_REUSEPORT` to facilitate parallel processing of send and/or receive operations.
///
/// **NOTE**: This is a Linux-specific optimization, and it will have no effect on other systems.
///
/// - If some [`Specific`](UdpStreamParallelism::Specific) value `n` > 0 is given, the [`ConnectedUdpStream`] will bind
///   `n` sockets.
/// - If [`Auto`](UdpStreamParallelism::Auto) is given, the number of sockets bound by [`ConnectedUdpStream`] is
///   determined by [`std::thread::available_parallelism`].
///
/// The default is `Specific(1)`.
///
/// Always use [`into_num_tasks`](UdpStreamParallelism::into_num_tasks) or
/// [`split_evenly_with`](UdpStreamParallelism::split_evenly_with) to determine the correct number of sockets to spawn.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum UdpStreamParallelism {
    /// Bind as many sender or receiver sockets as given by [`std::thread::available_parallelism`].
    Auto,
    /// Bind specific number of sender or receiver sockets.
    Specific(NonZeroUsize),
}
101
102impl Default for UdpStreamParallelism {
103    fn default() -> Self {
104        Self::Specific(NonZeroUsize::MIN)
105    }
106}
107
108impl UdpStreamParallelism {
109    fn avail_parallelism() -> usize {
110        // On non-Linux system this will always default to 1, since the
111        // multiple UDP socket optimization is not possible for those platforms.
112        std::thread::available_parallelism()
113            .map(|n| {
114                if cfg!(target_os = "linux") {
115                    n
116                } else {
117                    NonZeroUsize::MIN
118                }
119            })
120            .unwrap_or_else(|e| {
121                warn!(error = %e, "failed to determine available parallelism, defaulting to 1.");
122                NonZeroUsize::MIN
123            })
124            .into()
125    }
126
127    /// Returns the number of sockets for this and the `other` instance
128    /// when they evenly split the available CPU parallelism.
129    pub fn split_evenly_with(self, other: UdpStreamParallelism) -> (usize, usize) {
130        let cpu_half = (Self::avail_parallelism() / 2).max(1);
131
132        match (self, other) {
133            (UdpStreamParallelism::Auto, UdpStreamParallelism::Auto) => (cpu_half, cpu_half),
134            (UdpStreamParallelism::Specific(a), UdpStreamParallelism::Auto) => {
135                let a = cpu_half.min(a.into());
136                (a, cpu_half * 2 - a)
137            }
138            (UdpStreamParallelism::Auto, UdpStreamParallelism::Specific(b)) => {
139                let b = cpu_half.min(b.into());
140                (cpu_half * 2 - b, b)
141            }
142            (UdpStreamParallelism::Specific(a), UdpStreamParallelism::Specific(b)) => {
143                (cpu_half.min(a.into()), cpu_half.min(b.into()))
144            }
145        }
146    }
147
148    /// Calculates the actual number of tasks for this instance.
149    ///
150    /// The returned value is never more than the maximum available CPU parallelism.
151    pub fn into_num_tasks(self) -> usize {
152        let avail_parallelism = Self::avail_parallelism();
153        match self {
154            UdpStreamParallelism::Auto => avail_parallelism,
155            UdpStreamParallelism::Specific(n) => usize::from(n).min(avail_parallelism),
156        }
157    }
158}
159
160impl From<usize> for UdpStreamParallelism {
161    fn from(value: usize) -> Self {
162        NonZeroUsize::new(value).map(Self::Specific).unwrap_or_default()
163    }
164}
165
166impl From<Option<usize>> for UdpStreamParallelism {
167    fn from(value: Option<usize>) -> Self {
168        value.map(UdpStreamParallelism::from).unwrap_or_default()
169    }
170}
171
/// Builder object for the [`ConnectedUdpStream`].
///
/// If you wish to use defaults, do `UdpStreamBuilder::default().build(addr)`.
#[derive(Debug, Clone)]
pub struct UdpStreamBuilder {
    /// What to do with data arriving from addresses other than the counterparty.
    foreign_data_mode: ForeignDataMode,
    /// Size of the per-socket receive buffer in bytes.
    buffer_size: usize,
    /// Capacity of the RX and TX queues; `None` means unbounded.
    queue_size: Option<usize>,
    /// Requested number of receiving sockets.
    receiver_parallelism: UdpStreamParallelism,
    /// Requested number of sending sockets.
    sender_parallelism: UdpStreamParallelism,
    /// Counterparty preset at build time, if any.
    counterparty: Option<std::net::SocketAddr>,
}
184
185impl Default for UdpStreamBuilder {
186    fn default() -> Self {
187        Self {
188            buffer_size: 2048,
189            foreign_data_mode: Default::default(),
190            queue_size: None,
191            receiver_parallelism: Default::default(),
192            sender_parallelism: Default::default(),
193            counterparty: None,
194        }
195    }
196}
197
198impl UdpStreamBuilder {
199    /// Defines the behavior when data from an unexpected source arrive into the socket.
200    /// See [`ForeignDataMode`] for details.
201    ///
202    /// Default is [`ForeignDataMode::Error`].
203    pub fn with_foreign_data_mode(mut self, mode: ForeignDataMode) -> Self {
204        self.foreign_data_mode = mode;
205        self
206    }
207
208    /// The size of the UDP receive buffer.
209    ///
210    /// This size must be at least the size of the MTU, otherwise the unread UDP data that
211    /// does not fit this buffer will be discarded.
212    ///
213    /// Default is 2048.
214    pub fn with_buffer_size(mut self, buffer_size: usize) -> Self {
215        self.buffer_size = buffer_size;
216        self
217    }
218
219    /// Size of the TX/RX queue that dispatches data of reads from/writings to
220    /// the sockets.
221    ///
222    /// This an important back-pressure mechanism when dispatching received data from
223    /// fast senders.
224    /// Reduces the maximum memory consumed by the object, which is given by:
225    /// [`buffer_size`](UdpStreamBuilder::with_buffer_size) *
226    /// [`queue_size`](UdpStreamBuilder::with_queue_size)
227    ///
228    /// Default is unbounded.
229    pub fn with_queue_size(mut self, queue_size: usize) -> Self {
230        self.queue_size = Some(queue_size);
231        self
232    }
233
234    /// Sets how many parallel receiving sockets should be bound.
235    ///
236    /// Has no effect on non-Linux machines. See [`UdpStreamParallelism`] for details.
237    ///
238    /// Default is `1`.
239    pub fn with_receiver_parallelism<T: Into<UdpStreamParallelism>>(mut self, parallelism: T) -> Self {
240        self.receiver_parallelism = parallelism.into();
241        self
242    }
243
244    /// Sets how many parallel sending sockets should be bound.
245    ///
246    /// Has no effect on non-Linux machines. See [`UdpStreamParallelism`] for details.
247    ///
248    /// Default is `1`.
249    pub fn with_sender_parallelism<T: Into<UdpStreamParallelism>>(mut self, parallelism: T) -> Self {
250        self.sender_parallelism = parallelism.into();
251        self
252    }
253
254    /// Sets the expected counterparty for data sent/received by the UDP sockets.
255    ///
256    /// If not specified, the counterparty is determined from the first packet received.
257    /// However, no data can be sent up until this point.
258    /// Therefore, the value must be set if data are sent first rather than received.
259    /// If data is expected to be received first, the value does not need to be set.
260    ///
261    /// See [`ConnectedUdpStream`] and [`ForeignDataMode`] for details.
262    ///
263    /// Default is none.
264    pub fn with_counterparty(mut self, counterparty: std::net::SocketAddr) -> Self {
265        self.counterparty = Some(counterparty);
266        self
267    }
268
269    /// Builds the [`ConnectedUdpStream`] with UDP socket(s) bound to `bind_addr`.
270    ///
271    /// The number of RX sockets bound is determined by [receiver
272    /// parallelism](UdpStreamBuilder::with_receiver_parallelism), and similarly, the number of TX sockets bound is
273    /// determined by [sender parallelism](UdpStreamBuilder::with_sender_parallelism). On non-Linux platforms, only
274    /// a single receiver and sender will be bound, regardless of the above.
275    ///
276    /// The returned instance is always ready to receive data.
277    /// It is also ready to send data
278    /// if the [counterparty](UdpStreamBuilder::with_counterparty) has been set.
279    ///
280    /// If `bind_addr` yields multiple addresses, binding will be attempted with each of the addresses
281    /// until one succeeds. If none of the addresses succeed in binding the socket(s),
282    /// the `AddrNotAvailable` error is returned.
283    ///
284    /// Note that wildcard addresses (such as `0.0.0.0`) are *not* considered as multiple addresses,
285    /// and such socket(s) will bind to all available interfaces at the system level.
286    pub fn build<A: std::net::ToSocketAddrs>(self, bind_addr: A) -> std::io::Result<ConnectedUdpStream> {
287        let (num_rx_socks, num_tx_socks) = self.receiver_parallelism.split_evenly_with(self.sender_parallelism);
288
289        let counterparty = Arc::new(
290            self.counterparty
291                .map(|s| OnceLock::from(SocketAddrStr::from(s)))
292                .unwrap_or_default(),
293        );
294        let ((ingress_tx, ingress_rx), (egress_tx, egress_rx)) = if let Some(q) = self.queue_size {
295            (flume::bounded(q), flume::bounded(q))
296        } else {
297            (flume::unbounded(), flume::unbounded())
298        };
299
300        let num_socks_to_bind = num_rx_socks.max(num_tx_socks);
301        let mut socket_handles = Vec::with_capacity(num_socks_to_bind);
302        let mut bound_addr: Option<std::net::SocketAddr> = None;
303
304        // Try binding on all network addresses in `bind_addr`
305        for binding_to in bind_addr.to_socket_addrs()? {
306            debug!(
307                %binding_to,
308                num_socks_to_bind, num_rx_socks, num_tx_socks, "binding UDP stream"
309            );
310
311            // TODO: split bound sockets into a separate cloneable object
312
313            // Try to bind sockets on the current network interface address
314            (0..num_socks_to_bind)
315                .map(|sock_id| {
316                    let domain = match &binding_to {
317                        std::net::SocketAddr::V4(_) => socket2::Domain::IPV4,
318                        std::net::SocketAddr::V6(_) => socket2::Domain::IPV6,
319                    };
320
321                    // Bind a new non-blocking UDP socket
322                    let sock = socket2::Socket::new(domain, socket2::Type::DGRAM, None)?;
323                    if num_socks_to_bind > 1 {
324                        sock.set_reuse_address(true)?; // Needed for every next socket with non-wildcard IP
325                        sock.set_reuse_port(true)?; // Needed on Linux to evenly distribute datagrams
326                    }
327                    sock.set_nonblocking(true)?;
328                    sock.bind(&bound_addr.unwrap_or(binding_to).into())?;
329
330                    // Determine the address we bound this socket to, so we can also bind the others
331                    let socket_bound_addr = sock
332                        .local_addr()?
333                        .as_socket()
334                        .ok_or(std::io::Error::other("invalid socket type"))?;
335
336                    match bound_addr {
337                        None => bound_addr = Some(socket_bound_addr),
338                        Some(addr) if addr != socket_bound_addr => {
339                            return Err(std::io::Error::other(format!(
340                                "inconsistent binding address {addr} != {socket_bound_addr} on socket id {sock_id}"
341                            )));
342                        }
343                        _ => {}
344                    }
345
346                    let sock = Arc::new(UdpSocket::from_std(sock.into())?);
347                    debug!(
348                        socket_id = sock_id,
349                        addr = %socket_bound_addr,
350                        "bound UDP socket"
351                    );
352
353                    Ok((sock_id, sock))
354                })
355                .filter_map(|result| match result {
356                    Ok(bound) => Some(bound),
357                    Err(e) => {
358                        error!(
359                            %binding_to,
360                            "failed to bind udp socket: {e}"
361                        );
362                        None
363                    }
364                })
365                .for_each(|(sock_id, sock)| {
366                    if sock_id < num_tx_socks {
367                        socket_handles.push(tokio::task::spawn(ConnectedUdpStream::setup_tx_queue(
368                            sock_id,
369                            sock.clone(),
370                            egress_rx.clone(),
371                            counterparty.clone(),
372                        )));
373                    }
374                    if sock_id < num_rx_socks {
375                        socket_handles.push(tokio::task::spawn(ConnectedUdpStream::setup_rx_queue(
376                            sock_id,
377                            sock.clone(),
378                            ingress_tx.clone(),
379                            counterparty.clone(),
380                            self.foreign_data_mode,
381                            self.buffer_size,
382                        )));
383                    }
384                });
385        }
386
387        Ok(ConnectedUdpStream {
388            ingress_rx: Box::new(tokio_util::io::StreamReader::new(ingress_rx.into_stream())),
389            egress_tx: Some(Box::new(
390                egress_tx
391                    .into_sink()
392                    .sink_map_err(|e| std::io::Error::other(e.to_string())),
393            )),
394            socket_handles,
395            counterparty,
396            bound_to: bound_addr.ok_or(ErrorKind::AddrNotAvailable)?,
397        })
398    }
399}
400
401impl ConnectedUdpStream {
402    /// Creates a receiver queue for the UDP stream.
403    fn setup_rx_queue(
404        socket_id: usize,
405        sock_rx: Arc<UdpSocket>,
406        ingress_tx: flume::Sender<std::io::Result<tokio_util::bytes::Bytes>>,
407        counterparty: Arc<OnceLock<SocketAddrStr>>,
408        foreign_data_mode: ForeignDataMode,
409        buf_size: usize,
410    ) -> futures::future::BoxFuture<'static, ()> {
411        let counterparty_rx = counterparty.clone();
412        async move {
413            let mut buffer = vec![0u8; buf_size];
414            let mut done = false;
415            loop {
416                // Read data from the socket
417                let out_res = match sock_rx.recv_from(&mut buffer).await {
418                    Ok((read, read_addr)) if read > 0 => {
419                        trace!(
420                            socket_id,
421                            udp_bound_addr = ?sock_rx.local_addr(),
422                            bytes = read,
423                            from = %read_addr,
424                            "received data from"
425                        );
426
427                        let addr = counterparty_rx.get_or_init(|| read_addr.into());
428
429                        #[cfg(all(feature = "prometheus", not(test)))]
430                        METRIC_UDP_INGRESS_LEN.observe(&[addr.as_str()], read as f64);
431
432                        // If the data is from a counterparty, or we accept anything, pass it
433                        if read_addr.eq(addr) || foreign_data_mode == ForeignDataMode::Accept {
434                            let out_buffer = tokio_util::bytes::Bytes::copy_from_slice(&buffer[..read]);
435                            Some(Ok(out_buffer))
436                        } else {
437                            match foreign_data_mode {
438                                ForeignDataMode::Discard => {
439                                    // Don't even bother sending an error about discarded stuff
440                                    warn!(
441                                        socket_id,
442                                        udp_bound_addr = ?sock_rx.local_addr(),
443                                        ?read_addr,
444                                        expected_addr = ?addr,
445                                        "discarded data, which didn't come from the expected address"
446                                    );
447                                    None
448                                }
449                                ForeignDataMode::Error => {
450                                    // Terminate here, the ingress_tx gets dropped
451                                    done = true;
452                                    Some(Err(std::io::Error::new(
453                                        ErrorKind::ConnectionRefused,
454                                        "data from foreign client not allowed",
455                                    )))
456                                }
457                                // ForeignDataMode::Accept has been handled above
458                                _ => unreachable!(),
459                            }
460                        }
461                    }
462                    Ok(_) => {
463                        // Read EOF, terminate here, the ingress_tx gets dropped
464                        trace!(
465                            socket_id,
466                            udp_bound_addr = ?sock_rx.local_addr(),
467                            "read EOF on socket"
468                        );
469                        done = true;
470                        None
471                    }
472                    Err(e) => {
473                        // Forward the error
474                        debug!(
475                            socket_id,
476                            udp_bound_addr = ?sock_rx.local_addr(),
477                            error = %e,
478                            "forwarded error from socket"
479                        );
480                        done = true;
481                        Some(Err(e))
482                    }
483                };
484
485                // Dispatch the received data to the queue.
486                // If the underlying queue is bounded, it will wait until there is space.
487                if let Some(out_res) = out_res {
488                    if let Err(err) = ingress_tx.send_async(out_res).await {
489                        error!(
490                            socket_id,
491                            udp_bound_addr = ?sock_rx.local_addr(),
492                            error = %err,
493                            "failed to dispatch received data"
494                        );
495                        done = true;
496                    }
497                }
498
499                if done {
500                    trace!(
501                        socket_id,
502                        udp_bound_addr = ?sock_rx.local_addr(),
503                        "rx queue done"
504                    );
505                    break;
506                }
507            }
508        }
509        .boxed()
510    }
511
512    /// Creates a transmission queue for the UDP stream.
513    fn setup_tx_queue(
514        socket_id: usize,
515        sock_tx: Arc<UdpSocket>,
516        egress_rx: flume::Receiver<Box<[u8]>>,
517        counterparty: Arc<OnceLock<SocketAddrStr>>,
518    ) -> futures::future::BoxFuture<'static, ()> {
519        let counterparty_tx = counterparty.clone();
520        async move {
521            loop {
522                match egress_rx.recv_async().await {
523                    Ok(data) => {
524                        if let Some(target) = counterparty_tx.get() {
525                            if let Err(e) = sock_tx.send_to(&data, target.as_ref()).await {
526                                error!(
527                                    ?socket_id,
528                                    udp_bound_addr = ?sock_tx.local_addr(),
529                                    ?target,
530                                    error = %e,
531                                    "failed to send data"
532                                );
533                            }
534                            trace!(socket_id, bytes = data.len(), ?target, "sent bytes to");
535
536                            #[cfg(all(feature = "prometheus", not(test)))]
537                            METRIC_UDP_EGRESS_LEN.observe(&[target.as_str()], data.len() as f64);
538                        } else {
539                            error!(
540                                ?socket_id,
541                                udp_bound_addr = ?sock_tx.local_addr(),
542                                "cannot send data, counterparty not set"
543                            );
544                            break;
545                        }
546                    }
547                    Err(e) => {
548                        error!(
549                            ?socket_id,
550                            udp_bound_addr = ?sock_tx.local_addr(),
551                            error = %e,
552                            "cannot receive more data from egress channel"
553                        );
554                        break;
555                    }
556                }
557                trace!(
558                    ?socket_id,
559                    udp_bound_addr = tracing::field::debug(sock_tx.local_addr()),
560                    "tx queue done"
561                );
562            }
563        }
564        .boxed()
565    }
566
567    /// Local address that all UDP sockets in this instance are bound to.
568    pub fn bound_address(&self) -> &std::net::SocketAddr {
569        &self.bound_to
570    }
571
572    /// Creates a new [builder](UdpStreamBuilder).
573    pub fn builder() -> UdpStreamBuilder {
574        UdpStreamBuilder::default()
575    }
576}
577
578impl Drop for ConnectedUdpStream {
579    fn drop(&mut self) {
580        self.socket_handles.iter().for_each(|handle| {
581            handle.abort();
582        })
583    }
584}
585
586impl tokio::io::AsyncRead for ConnectedUdpStream {
587    fn poll_read(
588        mut self: Pin<&mut Self>,
589        cx: &mut Context<'_>,
590        buf: &mut tokio::io::ReadBuf<'_>,
591    ) -> Poll<std::io::Result<()>> {
592        trace!(
593            remaining_bytes = buf.remaining(),
594            counterparty = ?self.counterparty.get(),
595            "polling read of from udp stream",
596        );
597        match Pin::new(&mut self.ingress_rx).poll_read(cx, buf) {
598            Poll::Ready(Ok(())) => {
599                let read = buf.filled().len();
600                trace!(bytes = read, "read bytes");
601                Poll::Ready(Ok(()))
602            }
603            Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
604            Poll::Pending => Poll::Pending,
605        }
606    }
607}
608
609impl tokio::io::AsyncWrite for ConnectedUdpStream {
610    fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
611        trace!(
612            bytes = buf.len(),
613            counterparty = ?self.counterparty.get(),
614            "polling write to udp stream",
615        );
616        if let Some(sender) = &mut self.egress_tx {
617            if let Err(e) = ready!(sender.poll_ready_unpin(cx)) {
618                return Poll::Ready(Err(e));
619            }
620
621            let len = buf.iter().len();
622            if let Err(e) = sender.start_send_unpin(Box::from(buf)) {
623                return Poll::Ready(Err(e));
624            }
625
626            // Explicitly flush after each data sent
627            pin_mut!(sender);
628            sender.poll_flush(cx).map_ok(|_| len)
629        } else {
630            Poll::Ready(Err(std::io::Error::new(
631                ErrorKind::NotConnected,
632                "udp stream is closed",
633            )))
634        }
635    }
636
637    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
638        trace!(
639            counterparty = ?self.counterparty.get(),
640            "polling flush to udp stream"
641        );
642        if let Some(sender) = &mut self.egress_tx {
643            pin_mut!(sender);
644            sender
645                .poll_flush(cx)
646                .map_err(|err| std::io::Error::other(err.to_string()))
647        } else {
648            Poll::Ready(Err(std::io::Error::new(
649                ErrorKind::NotConnected,
650                "udp stream is closed",
651            )))
652        }
653    }
654
655    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
656        trace!(
657            counterparty = ?self.counterparty.get(),
658            "polling close on udp stream"
659        );
660        // Take the sender to make sure it gets dropped
661        let mut taken_sender = self.egress_tx.take();
662        if let Some(sender) = &mut taken_sender {
663            pin_mut!(sender);
664            sender
665                .poll_close(cx)
666                .map_err(|err| std::io::Error::other(err.to_string()))
667        } else {
668            Poll::Ready(Err(std::io::Error::new(
669                ErrorKind::NotConnected,
670                "udp stream is closed",
671            )))
672        }
673    }
674}
675
#[cfg(test)]
mod tests {
    use anyhow::Context;
    use futures::future::Either;
    use parameterized::parameterized;
    use tokio::{
        io::{AsyncReadExt, AsyncWriteExt},
        net::UdpSocket,
    };

    use super::*;

    // Round-trips random DATA_SIZE-byte payloads through a ConnectedUdpStream and a plain UDP echo
    // server, once per receiver parallelism setting (`None` keeps the builder default).
    #[parameterized(parallelism = {None, Some(2), Some(0)})]
    #[parameterized_macro(tokio::test)]
    // Swap the attribute above for the one below to get log output from the test:
    //#[parameterized_macro(test_log::test(tokio::test))]
    async fn basic_udp_stream_tests(parallelism: Option<usize>) -> anyhow::Result<()> {
        const DATA_SIZE: usize = 200;

        let listener = UdpSocket::bind("127.0.0.1:0").await.context("bind listener")?;
        let listen_addr = listener.local_addr()?;

        // Simple echo UDP server
        tokio::task::spawn(async move {
            loop {
                let mut buf = [0u8; DATA_SIZE];
                let (read, addr) = listener.recv_from(&mut buf).await.expect("recv must not fail");
                if read > 0 {
                    assert_eq!(DATA_SIZE, read, "read size must be exactly {DATA_SIZE}");
                    listener.send_to(&buf, addr).await.expect("send must not fail");
                }
            }
        });

        // The counterparty must be preset, because this side sends first
        let mut builder = ConnectedUdpStream::builder()
            .with_buffer_size(1024)
            .with_queue_size(512)
            .with_counterparty(listen_addr);

        if let Some(parallelism) = parallelism {
            builder = builder.with_receiver_parallelism(parallelism);
        }

        let mut stream = builder.build(("127.0.0.1", 0)).context("connection")?;

        // 999 write/echo/read round-trips
        for _ in 1..1000 {
            let mut w_buf = [0u8; DATA_SIZE];
            hopr_crypto_random::random_fill(&mut w_buf);
            let written = stream.write(&w_buf).await?;
            assert_eq!(written, DATA_SIZE);

            let mut r_buf = [0u8; DATA_SIZE];
            let read = stream.read_exact(&mut r_buf).await?;
            assert_eq!(read, DATA_SIZE);

            assert_eq!(w_buf, r_buf);
        }

        stream.shutdown().await?;

        Ok(())
    }

    // Two datagrams sent by a plain UDP socket must be readable from a ConnectedUdpStream (whose
    // counterparty is learned from the first datagram) as one contiguous byte sequence, using
    // reads smaller than the datagrams.
    #[tokio::test]
    async fn udp_stream_should_process_sequential_writes() -> anyhow::Result<()> {
        const BUF_SIZE: usize = 1024;
        const EXPECTED_DATA_LEN: usize = BUF_SIZE + 500;

        let mut listener = ConnectedUdpStream::builder()
            .with_buffer_size(BUF_SIZE)
            .with_queue_size(512)
            .build(("127.0.0.1", 0))
            .context("bind listener")?;

        let bound_addr = *listener.bound_address();

        // Collect everything read until the expected length is reached (or the stream ends)
        let jh = tokio::task::spawn(async move {
            let mut buf = [0u8; BUF_SIZE / 4];
            let mut vec = Vec::<u8>::new();
            loop {
                let sz = listener.read(&mut buf).await.unwrap();
                if sz > 0 {
                    vec.extend_from_slice(&buf[..sz]);
                    if vec.len() >= EXPECTED_DATA_LEN {
                        return vec;
                    }
                } else {
                    return vec;
                }
            }
        });

        let msg = [1u8; EXPECTED_DATA_LEN];
        let sender = UdpSocket::bind(("127.0.0.1", 0)).await.context("bind")?;

        sender.send_to(&msg[..BUF_SIZE], bound_addr).await?;
        sender.send_to(&msg[BUF_SIZE..], bound_addr).await?;

        // Fail the test if the reader does not finish within one second
        let timeout = tokio::time::sleep(std::time::Duration::from_millis(1000));
        pin_mut!(timeout);
        pin_mut!(jh);

        match futures::future::select(jh, timeout).await {
            Either::Left((Ok(v), _)) => {
                assert_eq!(v.len(), EXPECTED_DATA_LEN);
                assert_eq!(v.as_slice(), &msg);
                Ok(())
            }
            _ => Err(anyhow::anyhow!("timeout or invalid data")),
        }
    }
}