hopr_network_types/
udp.rs

1use std::{
2    fmt::Debug,
3    io::ErrorKind,
4    num::NonZeroUsize,
5    pin::Pin,
6    sync::{Arc, OnceLock},
7    task::{Context, Poll},
8    time::Duration,
9};
10
11use futures::{FutureExt, Sink, SinkExt, ready};
12use tokio::net::UdpSocket;
13use tracing::{debug, error, instrument, trace, warn};
14
15use crate::utils::SocketAddrStr;
16
17type BoxIoSink<T> = Box<dyn Sink<T, Error = std::io::Error> + Send + Unpin>;
18
/// Histogram bucket boundaries (in bytes) shared by both UDP packet length metrics.
#[cfg(all(feature = "prometheus", not(test)))]
const UDP_PACKET_LEN_BUCKETS: [f64; 9] = [20.0, 40.0, 80.0, 160.0, 320.0, 640.0, 1280.0, 2560.0, 5120.0];

#[cfg(all(feature = "prometheus", not(test)))]
lazy_static::lazy_static! {
    /// Sizes of the datagrams received by the RX tasks.
    static ref METRIC_UDP_INGRESS_LEN: hopr_metrics::SimpleHistogram = hopr_metrics::SimpleHistogram::new(
        "hopr_udp_ingress_packet_len",
        "UDP packet lengths on ingress",
        UDP_PACKET_LEN_BUCKETS.to_vec(),
    )
    .unwrap();
    /// Sizes of the datagrams handed to the TX sockets.
    static ref METRIC_UDP_EGRESS_LEN: hopr_metrics::SimpleHistogram = hopr_metrics::SimpleHistogram::new(
        "hopr_udp_egress_packet_len",
        "UDP packet lengths on egress",
        UDP_PACKET_LEN_BUCKETS.to_vec(),
    )
    .unwrap();
}
34
35/// Mimics TCP-like stream functionality on a UDP socket by restricting it to a single
36/// counterparty and implements [`tokio::io::AsyncRead`] and [`tokio::io::AsyncWrite`].
37/// The instance must always be constructed using a [`UdpStreamBuilder`].
38///
39/// To set a counterparty, one of the following must happen:
40/// 1) setting it during build via [`UdpStreamBuilder::with_counterparty`]
41/// 2) receiving some data from the other side.
42///
43/// Whatever of the above happens first, sets the counterparty.
44/// Once the counterparty is set, all data sent and received will be sent or filtered by this
45/// counterparty address.
46///
47/// If data from another party is received, an error is raised, unless the object has been constructed
48/// with [`ForeignDataMode::Discard`] or [`ForeignDataMode::Accept`] setting.
49///
50/// This object is also capable of parallel processing on a UDP socket.
51/// If [parallelism](UdpStreamBuilder::with_receiver_parallelism) is set, the instance will create
52/// multiple sockets with `SO_REUSEADDR` and spin parallel tasks that will coordinate data and
53/// transmission retrieval using these sockets. This is driven by RX/TX MPMC queues, which are
54/// per-default unbounded (see [queue size](UdpStreamBuilder::with_queue_size) for details).
55#[pin_project::pin_project(PinnedDrop)]
56pub struct ConnectedUdpStream {
57    socket_handles: Vec<tokio::task::JoinHandle<()>>,
58    #[pin]
59    ingress_rx: Box<dyn tokio::io::AsyncRead + Send + Unpin>,
60    #[pin]
61    egress_tx: Option<BoxIoSink<Box<[u8]>>>,
62    counterparty: Arc<OnceLock<SocketAddrStr>>,
63    bound_to: std::net::SocketAddr,
64    state: State,
65}
66
/// Progress of a single `poll_write` call on [`ConnectedUdpStream`].
///
/// A write first enqueues the datagram, then flushes it. The state persists across `Poll::Pending`
/// returns, so a repeated call resumes where the previous one stopped.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum State {
    /// Ready to enqueue a new datagram.
    Writing,
    /// A datagram of the given length has been enqueued and is waiting to be flushed.
    Flushing(usize),
}
72
/// Policy for datagrams arriving into a [`ConnectedUdpStream`] from a
/// [`SocketAddr`](std::net::SocketAddr) other than the instance's counterparty.
///
/// The default is [`ForeignDataMode::Error`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ForeignDataMode {
    /// Drop foreign datagrams silently (they are only logged).
    Discard,
    /// Treat foreign datagrams exactly like data from the counterparty.
    Accept,
    /// Surface an error on the next `poll_read` attempt.
    Error,
}

impl Default for ForeignDataMode {
    fn default() -> Self {
        ForeignDataMode::Error
    }
}
86
/// Determines how many parallel reader or writer sockets should be bound in [`ConnectedUdpStream`].
///
/// Whenever more than one socket is bound, each UDP socket is bound with `SO_REUSEADDR` and
/// `SO_REUSEPORT` to facilitate parallel processing of send and/or receive operations.
///
/// **NOTE**: This is a Linux-specific optimization, and it has no effect on other systems,
/// where a single socket is always used.
///
/// - If a [`Specific`](UdpStreamParallelism::Specific) value `n` is given, the [`ConnectedUdpStream`] binds
///   at most `n` sockets; the number is capped by the available CPU parallelism.
/// - If [`Auto`](UdpStreamParallelism::Auto) is given, the number of sockets bound by [`ConnectedUdpStream`] is
///   determined by [`std::thread::available_parallelism`].
///
/// The default is `Specific(1)`.
///
/// Always use [`into_num_tasks`](UdpStreamParallelism::into_num_tasks) or
/// [`split_evenly_with`](UdpStreamParallelism::split_evenly_with) to determine the correct number of sockets to spawn.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum UdpStreamParallelism {
    /// Bind as many sender or receiver sockets as given by [`std::thread::available_parallelism`].
    Auto,
    /// Bind up to the given number of sender or receiver sockets.
    Specific(NonZeroUsize),
}
110
111impl Default for UdpStreamParallelism {
112    fn default() -> Self {
113        Self::Specific(NonZeroUsize::MIN)
114    }
115}
116
117impl UdpStreamParallelism {
118    fn avail_parallelism() -> usize {
119        // On non-Linux system this will always default to 1, since the
120        // multiple UDP socket optimization is not possible for those platforms.
121        std::thread::available_parallelism()
122            .map(|n| {
123                if cfg!(target_os = "linux") {
124                    n
125                } else {
126                    NonZeroUsize::MIN
127                }
128            })
129            .unwrap_or_else(|e| {
130                warn!(error = %e, "failed to determine available parallelism, defaulting to 1.");
131                NonZeroUsize::MIN
132            })
133            .into()
134    }
135
136    /// Returns the number of sockets for this and the `other` instance
137    /// when they evenly split the available CPU parallelism.
138    pub fn split_evenly_with(self, other: UdpStreamParallelism) -> (usize, usize) {
139        let cpu_half = (Self::avail_parallelism() / 2).max(1);
140
141        match (self, other) {
142            (UdpStreamParallelism::Auto, UdpStreamParallelism::Auto) => (cpu_half, cpu_half),
143            (UdpStreamParallelism::Specific(a), UdpStreamParallelism::Auto) => {
144                let a = cpu_half.min(a.into());
145                (a, cpu_half * 2 - a)
146            }
147            (UdpStreamParallelism::Auto, UdpStreamParallelism::Specific(b)) => {
148                let b = cpu_half.min(b.into());
149                (cpu_half * 2 - b, b)
150            }
151            (UdpStreamParallelism::Specific(a), UdpStreamParallelism::Specific(b)) => {
152                (cpu_half.min(a.into()), cpu_half.min(b.into()))
153            }
154        }
155    }
156
157    /// Calculates the actual number of tasks for this instance.
158    ///
159    /// The returned value is never more than the maximum available CPU parallelism.
160    pub fn into_num_tasks(self) -> usize {
161        let avail_parallelism = Self::avail_parallelism();
162        match self {
163            UdpStreamParallelism::Auto => avail_parallelism,
164            UdpStreamParallelism::Specific(n) => usize::from(n).min(avail_parallelism),
165        }
166    }
167}
168
169impl From<usize> for UdpStreamParallelism {
170    fn from(value: usize) -> Self {
171        NonZeroUsize::new(value).map(Self::Specific).unwrap_or_default()
172    }
173}
174
175impl From<Option<usize>> for UdpStreamParallelism {
176    fn from(value: Option<usize>) -> Self {
177        value.map(UdpStreamParallelism::from).unwrap_or_default()
178    }
179}
180
181/// Builder object for the [`ConnectedUdpStream`].
182///
183/// If you wish to use defaults, do `UdpStreamBuilder::default().build(addr)`.
184#[derive(Debug, Clone)]
185pub struct UdpStreamBuilder {
186    foreign_data_mode: ForeignDataMode,
187    buffer_size: usize,
188    queue_size: Option<usize>,
189    receiver_parallelism: UdpStreamParallelism,
190    sender_parallelism: UdpStreamParallelism,
191    counterparty: Option<std::net::SocketAddr>,
192}
193
194impl Default for UdpStreamBuilder {
195    fn default() -> Self {
196        Self {
197            buffer_size: 2048,
198            foreign_data_mode: Default::default(),
199            queue_size: None,
200            receiver_parallelism: Default::default(),
201            sender_parallelism: Default::default(),
202            counterparty: None,
203        }
204    }
205}
206
207impl UdpStreamBuilder {
208    /// Defines the behavior when data from an unexpected source arrive into the socket.
209    /// See [`ForeignDataMode`] for details.
210    ///
211    /// Default is [`ForeignDataMode::Error`].
212    pub fn with_foreign_data_mode(mut self, mode: ForeignDataMode) -> Self {
213        self.foreign_data_mode = mode;
214        self
215    }
216
217    /// The size of the UDP receive buffer.
218    ///
219    /// This size must be at least the size of the MTU, otherwise the unread UDP data that
220    /// does not fit this buffer will be discarded.
221    ///
222    /// Default is 2048.
223    pub fn with_buffer_size(mut self, buffer_size: usize) -> Self {
224        self.buffer_size = buffer_size;
225        self
226    }
227
228    /// Size of the TX/RX queue that dispatches data of reads from/writings to
229    /// the sockets.
230    ///
231    /// This an important back-pressure mechanism when dispatching received data from
232    /// fast senders.
233    /// Reduces the maximum memory consumed by the object, which is given by:
234    /// [`buffer_size`](UdpStreamBuilder::with_buffer_size) *
235    /// [`queue_size`](UdpStreamBuilder::with_queue_size)
236    ///
237    /// Default is unbounded.
238    pub fn with_queue_size(mut self, queue_size: usize) -> Self {
239        self.queue_size = Some(queue_size);
240        self
241    }
242
243    /// Sets how many parallel receiving sockets should be bound.
244    ///
245    /// Has no effect on non-Linux machines. See [`UdpStreamParallelism`] for details.
246    ///
247    /// Default is `1`.
248    pub fn with_receiver_parallelism<T: Into<UdpStreamParallelism>>(mut self, parallelism: T) -> Self {
249        self.receiver_parallelism = parallelism.into();
250        self
251    }
252
253    /// Sets how many parallel sending sockets should be bound.
254    ///
255    /// Has no effect on non-Linux machines. See [`UdpStreamParallelism`] for details.
256    ///
257    /// Default is `1`.
258    pub fn with_sender_parallelism<T: Into<UdpStreamParallelism>>(mut self, parallelism: T) -> Self {
259        self.sender_parallelism = parallelism.into();
260        self
261    }
262
263    /// Sets the expected counterparty for data sent/received by the UDP sockets.
264    ///
265    /// If not specified, the counterparty is determined from the first packet received.
266    /// However, no data can be sent up until this point.
267    /// Therefore, the value must be set if data are sent first rather than received.
268    /// If data is expected to be received first, the value does not need to be set.
269    ///
270    /// See [`ConnectedUdpStream`] and [`ForeignDataMode`] for details.
271    ///
272    /// Default is none.
273    pub fn with_counterparty(mut self, counterparty: std::net::SocketAddr) -> Self {
274        self.counterparty = Some(counterparty);
275        self
276    }
277
278    /// Builds the [`ConnectedUdpStream`] with UDP socket(s) bound to `bind_addr`.
279    ///
280    /// The number of RX sockets bound is determined by [receiver
281    /// parallelism](UdpStreamBuilder::with_receiver_parallelism), and similarly, the number of TX sockets bound is
282    /// determined by [sender parallelism](UdpStreamBuilder::with_sender_parallelism). On non-Linux platforms, only
283    /// a single receiver and sender will be bound, regardless of the above.
284    ///
285    /// The returned instance is always ready to receive data.
286    /// It is also ready to send data
287    /// if the [counterparty](UdpStreamBuilder::with_counterparty) has been set.
288    ///
289    /// If `bind_addr` yields multiple addresses, binding will be attempted with each of the addresses
290    /// until one succeeds. If none of the addresses succeed in binding the socket(s),
291    /// the `AddrNotAvailable` error is returned.
292    ///
293    /// Note that wildcard addresses (such as `0.0.0.0`) are *not* considered as multiple addresses,
294    /// and such socket(s) will bind to all available interfaces at the system level.
295    pub fn build<A: std::net::ToSocketAddrs>(self, bind_addr: A) -> std::io::Result<ConnectedUdpStream> {
296        let (num_rx_socks, num_tx_socks) = self.receiver_parallelism.split_evenly_with(self.sender_parallelism);
297
298        let counterparty = Arc::new(
299            self.counterparty
300                .map(|s| OnceLock::from(SocketAddrStr::from(s)))
301                .unwrap_or_default(),
302        );
303        let ((ingress_tx, ingress_rx), (egress_tx, egress_rx)) = if let Some(q) = self.queue_size {
304            (flume::bounded(q), flume::bounded(q))
305        } else {
306            (flume::unbounded(), flume::unbounded())
307        };
308
309        let num_socks_to_bind = num_rx_socks.max(num_tx_socks);
310        let mut socket_handles = Vec::with_capacity(num_socks_to_bind);
311        let mut bound_addr: Option<std::net::SocketAddr> = None;
312
313        // Try binding on all network addresses in `bind_addr`
314        for binding_to in bind_addr.to_socket_addrs()? {
315            debug!(
316                %binding_to,
317                num_socks_to_bind, num_rx_socks, num_tx_socks, "binding UDP stream"
318            );
319
320            // TODO: split bound sockets into a separate cloneable object
321
322            // Try to bind sockets on the current network interface address
323            (0..num_socks_to_bind)
324                .map(|sock_id| {
325                    let domain = match &binding_to {
326                        std::net::SocketAddr::V4(_) => socket2::Domain::IPV4,
327                        std::net::SocketAddr::V6(_) => socket2::Domain::IPV6,
328                    };
329
330                    // Bind a new non-blocking UDP socket
331                    let sock = socket2::Socket::new(domain, socket2::Type::DGRAM, None)?;
332                    if num_socks_to_bind > 1 {
333                        sock.set_reuse_address(true)?; // Needed for every next socket with non-wildcard IP
334                        sock.set_reuse_port(true)?; // Needed on Linux to evenly distribute datagrams
335                    }
336                    sock.set_nonblocking(true)?;
337                    sock.bind(&bound_addr.unwrap_or(binding_to).into())?;
338
339                    // Determine the address we bound this socket to, so we can also bind the others
340                    let socket_bound_addr = sock
341                        .local_addr()?
342                        .as_socket()
343                        .ok_or(std::io::Error::other("invalid socket type"))?;
344
345                    match bound_addr {
346                        None => bound_addr = Some(socket_bound_addr),
347                        Some(addr) if addr != socket_bound_addr => {
348                            return Err(std::io::Error::other(format!(
349                                "inconsistent binding address {addr} != {socket_bound_addr} on socket id {sock_id}"
350                            )));
351                        }
352                        _ => {}
353                    }
354
355                    let sock = Arc::new(UdpSocket::from_std(sock.into())?);
356                    debug!(
357                        socket_id = sock_id,
358                        addr = %socket_bound_addr,
359                        "bound UDP socket"
360                    );
361
362                    Ok((sock_id, sock))
363                })
364                .filter_map(|result| match result {
365                    Ok(bound) => Some(bound),
366                    Err(error) => {
367                        error!(
368                            %binding_to,
369                            %error,
370                            "failed to bind udp socket"
371                        );
372                        None
373                    }
374                })
375                .for_each(|(sock_id, sock)| {
376                    if sock_id < num_tx_socks {
377                        socket_handles.push(tokio::task::spawn(ConnectedUdpStream::setup_tx_queue(
378                            sock_id,
379                            sock.clone(),
380                            egress_rx.clone(),
381                            counterparty.clone(),
382                        )));
383                    }
384                    if sock_id < num_rx_socks {
385                        socket_handles.push(tokio::task::spawn(ConnectedUdpStream::setup_rx_queue(
386                            sock_id,
387                            sock.clone(),
388                            ingress_tx.clone(),
389                            counterparty.clone(),
390                            self.foreign_data_mode,
391                            self.buffer_size,
392                        )));
393                    }
394                });
395        }
396
397        Ok(ConnectedUdpStream {
398            ingress_rx: Box::new(tokio_util::io::StreamReader::new(ingress_rx.into_stream())),
399            egress_tx: Some(Box::new(
400                egress_tx
401                    .into_sink()
402                    .sink_map_err(|e| std::io::Error::other(e.to_string())),
403            )),
404            socket_handles,
405            counterparty,
406            bound_to: bound_addr.ok_or(ErrorKind::AddrNotAvailable)?,
407            state: State::Writing,
408        })
409    }
410}
411
/// Dispatching one received datagram into the ingress queue may block (bounded queue).
/// If it takes longer than this, a warning suggesting a larger queue is logged.
const QUEUE_DISPATCH_THRESHOLD: Duration = Duration::from_millis(150);
413
414impl ConnectedUdpStream {
415    /// Creates a receiver queue for the UDP stream.
416    fn setup_rx_queue(
417        socket_id: usize,
418        sock_rx: Arc<UdpSocket>,
419        ingress_tx: flume::Sender<std::io::Result<tokio_util::bytes::Bytes>>,
420        counterparty: Arc<OnceLock<SocketAddrStr>>,
421        foreign_data_mode: ForeignDataMode,
422        buf_size: usize,
423    ) -> futures::future::BoxFuture<'static, ()> {
424        let counterparty_rx = counterparty.clone();
425        async move {
426            let mut buffer = vec![0u8; buf_size];
427            let mut done = false;
428            loop {
429                // Read data from the socket
430                let out_res = match sock_rx.recv_from(&mut buffer).await {
431                    Ok((read, read_addr)) if read > 0 => {
432                        trace!(
433                            socket_id,
434                            udp_bound_addr = ?sock_rx.local_addr(),
435                            bytes = read,
436                            from = %read_addr,
437                            "received data from"
438                        );
439
440                        let addr = counterparty_rx.get_or_init(|| read_addr.into());
441
442                        #[cfg(all(feature = "prometheus", not(test)))]
443                        METRIC_UDP_INGRESS_LEN.observe(read as f64);
444
445                        // If the data is from a counterparty, or we accept anything, pass it
446                        if read_addr.eq(addr) || foreign_data_mode == ForeignDataMode::Accept {
447                            let out_buffer = tokio_util::bytes::Bytes::copy_from_slice(&buffer[..read]);
448                            Some(Ok(out_buffer))
449                        } else {
450                            match foreign_data_mode {
451                                ForeignDataMode::Discard => {
452                                    // Don't even bother sending an error about discarded stuff
453                                    warn!(
454                                        socket_id,
455                                        udp_bound_addr = ?sock_rx.local_addr(),
456                                        ?read_addr,
457                                        expected_addr = ?addr,
458                                        "discarded data, which didn't come from the expected address"
459                                    );
460                                    None
461                                }
462                                ForeignDataMode::Error => {
463                                    // Terminate here, the ingress_tx gets dropped
464                                    done = true;
465                                    Some(Err(std::io::Error::new(
466                                        ErrorKind::ConnectionRefused,
467                                        "data from foreign client not allowed",
468                                    )))
469                                }
470                                // ForeignDataMode::Accept has been handled above
471                                _ => unreachable!(),
472                            }
473                        }
474                    }
475                    Ok(_) => {
476                        // Read EOF, terminate here, the ingress_tx gets dropped
477                        trace!(
478                            socket_id,
479                            udp_bound_addr = ?sock_rx.local_addr(),
480                            "read EOF on socket"
481                        );
482                        done = true;
483                        None
484                    }
485                    Err(error) => {
486                        // Forward the error
487                        debug!(
488                            socket_id,
489                            udp_bound_addr = ?sock_rx.local_addr(),
490                            %error,
491                            "forwarded error from socket"
492                        );
493                        done = true;
494                        Some(Err(error))
495                    }
496                };
497
498                // Dispatch the received data to the queue.
499                // If the underlying queue is bounded, it will wait until there is space.
500                if let Some(out_res) = out_res {
501                    let start = std::time::Instant::now();
502                    if let Err(error) = ingress_tx.send_async(out_res).await {
503                        error!(
504                            socket_id,
505                            udp_bound_addr = ?sock_rx.local_addr(),
506                            %error,
507                            "failed to dispatch received data"
508                        );
509                        done = true;
510                    }
511                    let elapsed = start.elapsed();
512                    if elapsed > QUEUE_DISPATCH_THRESHOLD {
513                        warn!(
514                            ?elapsed,
515                            "udp queue dispatch took too long, consider increasing the queue size"
516                        );
517                    }
518                }
519
520                if done {
521                    trace!(
522                        socket_id,
523                        udp_bound_addr = ?sock_rx.local_addr(),
524                        "rx queue done"
525                    );
526                    break;
527                }
528            }
529        }
530        .boxed()
531    }
532
533    /// Creates a transmission queue for the UDP stream.
534    fn setup_tx_queue(
535        socket_id: usize,
536        sock_tx: Arc<UdpSocket>,
537        egress_rx: flume::Receiver<Box<[u8]>>,
538        counterparty: Arc<OnceLock<SocketAddrStr>>,
539    ) -> futures::future::BoxFuture<'static, ()> {
540        let counterparty_tx = counterparty.clone();
541        async move {
542            loop {
543                match egress_rx.recv_async().await {
544                    Ok(data) => {
545                        if let Some(target) = counterparty_tx.get() {
546                            if let Err(error) = sock_tx.send_to(&data, target.as_ref()).await {
547                                error!(
548                                    ?socket_id,
549                                    udp_bound_addr = ?sock_tx.local_addr(),
550                                    ?target,
551                                    %error,
552                                    "failed to send data"
553                                );
554                            }
555                            trace!(socket_id, bytes = data.len(), ?target, "sent bytes to");
556
557                            #[cfg(all(feature = "prometheus", not(test)))]
558                            METRIC_UDP_EGRESS_LEN.observe(data.len() as f64);
559                        } else {
560                            error!(
561                                ?socket_id,
562                                udp_bound_addr = ?sock_tx.local_addr(),
563                                "cannot send data, counterparty not set"
564                            );
565                            break;
566                        }
567                    }
568                    Err(error) => {
569                        error!(
570                            ?socket_id,
571                            udp_bound_addr = ?sock_tx.local_addr(),
572                            %error,
573                            "cannot receive more data from egress channel"
574                        );
575                        break;
576                    }
577                }
578                trace!(
579                    ?socket_id,
580                    udp_bound_addr = tracing::field::debug(sock_tx.local_addr()),
581                    "tx queue done"
582                );
583            }
584        }
585        .boxed()
586    }
587
588    /// Local address that all UDP sockets in this instance are bound to.
589    pub fn bound_address(&self) -> &std::net::SocketAddr {
590        &self.bound_to
591    }
592
593    /// Creates a new [builder](UdpStreamBuilder).
594    pub fn builder() -> UdpStreamBuilder {
595        UdpStreamBuilder::default()
596    }
597}
598
599impl tokio::io::AsyncRead for ConnectedUdpStream {
600    #[instrument(name = "ConnectedUdpStream::poll_read", level = "trace", skip(self, cx), fields(counterparty = ?self.counterparty.get(), rem = buf.remaining()) , ret)]
601    fn poll_read(
602        self: Pin<&mut Self>,
603        cx: &mut Context<'_>,
604        buf: &mut tokio::io::ReadBuf<'_>,
605    ) -> Poll<std::io::Result<()>> {
606        ready!(self.project().ingress_rx.poll_read(cx, buf))?;
607        trace!(bytes = buf.filled().len(), "read bytes");
608        Poll::Ready(Ok(()))
609    }
610}
611
612impl tokio::io::AsyncWrite for ConnectedUdpStream {
613    #[instrument(name = "ConnectedUdpStream::poll_write", level = "trace", skip(self, cx), fields(counterparty = ?self.counterparty.get(), len = buf.len()) , ret)]
614    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
615        let this = self.project();
616        if let Some(sender) = this.egress_tx.get_mut() {
617            loop {
618                match *this.state {
619                    State::Writing => {
620                        ready!(sender.poll_ready_unpin(cx))?;
621
622                        let len = buf.iter().len();
623                        sender.start_send_unpin(Box::from(buf))?;
624                        *this.state = State::Flushing(len);
625                    }
626                    State::Flushing(len) => {
627                        // Explicitly flush after each data sent
628                        let res = ready!(sender.poll_flush_unpin(cx)).map(|_| len);
629                        *this.state = State::Writing;
630
631                        return Poll::Ready(res);
632                    }
633                }
634            }
635        } else {
636            Poll::Ready(Err(std::io::Error::new(
637                ErrorKind::NotConnected,
638                "udp stream is closed",
639            )))
640        }
641    }
642
643    #[instrument(name = "ConnectedUdpStream::poll_flush", level = "trace", skip(self, cx), fields(counterparty = ?self.counterparty.get()) , ret)]
644    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
645        let this = self.project();
646        if let Some(sender) = this.egress_tx.as_pin_mut() {
647            sender.poll_flush(cx).map_err(std::io::Error::other)
648        } else {
649            Poll::Ready(Err(std::io::Error::new(
650                ErrorKind::NotConnected,
651                "udp stream is closed",
652            )))
653        }
654    }
655
656    #[instrument(name = "ConnectedUdpStream::poll_shutdown", level = "trace", skip(self, cx), fields(counterparty = ?self.counterparty.get()) , ret)]
657    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
658        let this = self.project();
659        if let Some(sender) = this.egress_tx.as_pin_mut() {
660            let ret = ready!(sender.poll_close(cx));
661
662            this.socket_handles.iter().for_each(|handle| {
663                handle.abort();
664            });
665
666            Poll::Ready(ret)
667        } else {
668            Poll::Ready(Err(std::io::Error::new(
669                ErrorKind::NotConnected,
670                "udp stream is closed",
671            )))
672        }
673    }
674}
675
676#[pin_project::pinned_drop]
677impl PinnedDrop for ConnectedUdpStream {
678    fn drop(self: Pin<&mut Self>) {
679        debug!(binding = ?self.bound_to,"dropping ConnectedUdpStream");
680        self.project().socket_handles.iter().for_each(|handle| {
681            handle.abort();
682        })
683    }
684}
685
#[cfg(test)]
mod tests {
    use anyhow::Context;
    use parameterized::parameterized;
    use tokio::{
        io::{AsyncReadExt, AsyncWriteExt},
        net::UdpSocket,
    };

    use super::*;

    #[parameterized(parallelism = {None, Some(2), Some(0)})]
    #[parameterized_macro(tokio::test)]
    async fn basic_udp_stream_tests(parallelism: Option<usize>) -> anyhow::Result<()> {
        const DATA_SIZE: usize = 200;

        let echo_socket = UdpSocket::bind("127.0.0.1:0").await.context("bind listener")?;
        let echo_addr = echo_socket.local_addr()?;

        // Echo server: every non-empty datagram is sent straight back to its sender
        tokio::task::spawn(async move {
            loop {
                let mut datagram = [0u8; DATA_SIZE];
                let (received, peer) = echo_socket.recv_from(&mut datagram).await.expect("recv must not fail");
                if received == 0 {
                    continue;
                }
                assert_eq!(DATA_SIZE, received, "read size must be exactly {DATA_SIZE}");
                echo_socket.send_to(&datagram, peer).await.expect("send must not fail");
            }
        });

        let builder = ConnectedUdpStream::builder()
            .with_buffer_size(1024)
            .with_queue_size(512)
            .with_counterparty(echo_addr);
        let builder = match parallelism {
            Some(p) => builder.with_receiver_parallelism(p),
            None => builder,
        };

        let mut stream = builder.build(("127.0.0.1", 0)).context("connection")?;

        for _round in 1..1000 {
            let mut outgoing = [0u8; DATA_SIZE];
            hopr_crypto_random::random_fill(&mut outgoing);
            assert_eq!(stream.write(&outgoing).await?, DATA_SIZE);

            let mut incoming = [0u8; DATA_SIZE];
            assert_eq!(stream.read_exact(&mut incoming).await?, DATA_SIZE);

            assert_eq!(outgoing, incoming);
        }

        stream.shutdown().await?;

        Ok(())
    }

    #[tokio::test]
    async fn udp_stream_should_process_sequential_writes() -> anyhow::Result<()> {
        const BUF_SIZE: usize = 1024;
        const EXPECTED_DATA_LEN: usize = BUF_SIZE + 500;

        let mut listener = ConnectedUdpStream::builder()
            .with_buffer_size(BUF_SIZE)
            .with_queue_size(512)
            .build(("127.0.0.1", 0))
            .context("bind listener")?;

        let bound_addr = *listener.bound_address();

        // Collect everything until the expected amount has arrived or the stream ends
        let reader = tokio::task::spawn(async move {
            let mut chunk = [0u8; BUF_SIZE / 4];
            let mut collected = Vec::<u8>::new();
            while collected.len() < EXPECTED_DATA_LEN {
                let n = listener.read(&mut chunk).await.unwrap();
                if n == 0 {
                    break;
                }
                collected.extend_from_slice(&chunk[..n]);
            }
            collected
        });

        let msg = [1u8; EXPECTED_DATA_LEN];
        let sender = UdpSocket::bind(("127.0.0.1", 0)).await.context("bind")?;

        let (first, second) = msg.split_at(BUF_SIZE);
        sender.send_to(first, bound_addr).await?;
        sender.send_to(second, bound_addr).await?;

        match tokio::time::timeout(std::time::Duration::from_millis(1000), reader).await {
            Ok(Ok(received)) => {
                assert_eq!(received.len(), EXPECTED_DATA_LEN);
                assert_eq!(received.as_slice(), &msg);
                Ok(())
            }
            _ => Err(anyhow::anyhow!("timeout or invalid data")),
        }
    }
}