hopr_network_types/
udp.rs

1use crate::utils::SocketAddrStr;
2use futures::{pin_mut, ready, FutureExt, Sink, SinkExt};
3use std::fmt::Debug;
4use std::io::ErrorKind;
5use std::num::NonZeroUsize;
6use std::pin::Pin;
7use std::sync::{Arc, OnceLock};
8use std::task::{Context, Poll};
9use tokio::net::UdpSocket;
10use tracing::{debug, error, trace, warn};
11
12type BoxIoSink<T> = Box<dyn Sink<T, Error = std::io::Error> + Send + Unpin>;
13
/// Histogram buckets (packet length in bytes) shared by the UDP ingress and egress metrics.
#[cfg(all(feature = "prometheus", not(test)))]
const UDP_PACKET_LEN_BUCKETS: [f64; 9] = [20.0, 40.0, 80.0, 160.0, 320.0, 640.0, 1280.0, 2560.0, 5120.0];

#[cfg(all(feature = "prometheus", not(test)))]
lazy_static::lazy_static! {
    /// Lengths of received UDP packets, labeled by the stream's counterparty address.
    static ref METRIC_UDP_INGRESS_LEN: hopr_metrics::MultiHistogram = hopr_metrics::MultiHistogram::new(
        "hopr_udp_ingress_packet_len",
        "UDP packet lengths on ingress per counterparty",
        UDP_PACKET_LEN_BUCKETS.to_vec(),
        &["counterparty"],
    )
    .unwrap();
    /// Lengths of sent UDP packets, labeled by the stream's counterparty address.
    static ref METRIC_UDP_EGRESS_LEN: hopr_metrics::MultiHistogram = hopr_metrics::MultiHistogram::new(
        "hopr_udp_egress_packet_len",
        "UDP packet lengths on egress per counterparty",
        UDP_PACKET_LEN_BUCKETS.to_vec(),
        &["counterparty"],
    )
    .unwrap();
}

32/// Mimics TCP-like stream functionality on a UDP socket by restricting it to a single
33/// counterparty and implements [`tokio::io::AsyncRead`] and [`tokio::io::AsyncWrite`].
34/// The instance is always constructed using a [`UdpStreamBuilder`].
35///
36/// To set a counterparty, one of the following must happen:
37/// 1) setting it during build via [`UdpStreamBuilder::with_counterparty`]
38/// 2) receiving some data from the other side.
39///
40/// Whatever of the above happens first, sets the counterparty.
41/// Once the counterparty is set, all data sent and received will be sent or filtered by this
42/// counterparty address.
43///
44/// If data from another party is received, an error is raised, unless the object has been constructed
45/// with [`ForeignDataMode::Discard`] or [`ForeignDataMode::Accept`] setting.
46///
47/// This object is also capable of parallel processing on a UDP socket.
48/// If [parallelism](UdpStreamBuilder::with_receiver_parallelism) is set, the instance will create
49/// multiple sockets with `SO_REUSEADDR` and spin parallel tasks that will coordinate data and
50/// transmission retrieval using these sockets. This is driven by RX/TX MPMC queues, which are
51/// per-default unbounded (see [queue size](UdpStreamBuilder::with_queue_size) for details).
52pub struct ConnectedUdpStream {
53    socket_handles: Vec<tokio::task::JoinHandle<()>>,
54    ingress_rx: Box<dyn tokio::io::AsyncRead + Send + Unpin>,
55    egress_tx: Option<BoxIoSink<Box<[u8]>>>,
56    counterparty: Arc<OnceLock<SocketAddrStr>>,
57    bound_to: std::net::SocketAddr,
58}
59
/// Policy applied in a [`ConnectedUdpStream`] to data arriving from any
/// [`SocketAddr`](std::net::SocketAddr) other than the stream's counterparty.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum ForeignDataMode {
    /// Foreign data are dropped (a warning is logged).
    Discard,
    /// Foreign data are passed on as though the counterparty had sent them.
    Accept,
    /// The next `poll_read` attempt fails with an error.
    #[default]
    Error,
}

/// Number of parallel receiver or sender sockets bound by a [`ConnectedUdpStream`].
///
/// When more than one socket is bound, each UDP socket is bound with `SO_REUSEADDR` and
/// `SO_REUSEPORT`, so that send and/or receive operations can be processed in parallel.
///
/// **NOTE**: This is a Linux-specific optimization, and it has no effect on other systems.
///
/// - [`Specific`](UdpStreamParallelism::Specific) with value `n` > 0 makes the [`ConnectedUdpStream`] bind `n` sockets.
/// - [`Auto`](UdpStreamParallelism::Auto) derives the number of sockets from [`std::thread::available_parallelism`].
///
/// Defaults to `Specific(1)`.
///
/// The actual number of sockets to spawn must always be obtained via
/// [`into_num_tasks`](UdpStreamParallelism::into_num_tasks) or
/// [`split_evenly_with`](UdpStreamParallelism::split_evenly_with).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UdpStreamParallelism {
    /// As many sockets as [`std::thread::available_parallelism`] reports.
    Auto,
    /// Exactly this many sockets (still capped by the available parallelism).
    Specific(NonZeroUsize),
}

impl Default for UdpStreamParallelism {
    /// A single socket, i.e. `Specific(1)`.
    fn default() -> Self {
        UdpStreamParallelism::Specific(NonZeroUsize::MIN)
    }
}

102impl UdpStreamParallelism {
103    fn avail_parallelism() -> usize {
104        // On non-Linux system this will always default to 1, since the
105        // multiple UDP socket optimization is not possible for those platforms.
106        std::thread::available_parallelism()
107            .map(|n| {
108                if cfg!(target_os = "linux") {
109                    n
110                } else {
111                    NonZeroUsize::MIN
112                }
113            })
114            .unwrap_or_else(|e| {
115                warn!(error = %e, "failed to determine available parallelism, defaulting to 1.");
116                NonZeroUsize::MIN
117            })
118            .into()
119    }
120
121    /// Returns the number of sockets for this and the `other` instance
122    /// when they evenly split the available CPU parallelism.
123    pub fn split_evenly_with(self, other: UdpStreamParallelism) -> (usize, usize) {
124        let cpu_half = (Self::avail_parallelism() / 2).max(1);
125
126        match (self, other) {
127            (UdpStreamParallelism::Auto, UdpStreamParallelism::Auto) => (cpu_half, cpu_half),
128            (UdpStreamParallelism::Specific(a), UdpStreamParallelism::Auto) => {
129                let a = cpu_half.min(a.into());
130                (a, cpu_half * 2 - a)
131            }
132            (UdpStreamParallelism::Auto, UdpStreamParallelism::Specific(b)) => {
133                let b = cpu_half.min(b.into());
134                (cpu_half * 2 - b, b)
135            }
136            (UdpStreamParallelism::Specific(a), UdpStreamParallelism::Specific(b)) => {
137                (cpu_half.min(a.into()), cpu_half.min(b.into()))
138            }
139        }
140    }
141
142    /// Calculates the actual number of tasks for this instance.
143    ///
144    /// The returned value is never more than the maximum available CPU parallelism.
145    pub fn into_num_tasks(self) -> usize {
146        let avail_parallelism = Self::avail_parallelism();
147        match self {
148            UdpStreamParallelism::Auto => avail_parallelism,
149            UdpStreamParallelism::Specific(n) => usize::from(n).min(avail_parallelism),
150        }
151    }
152}
153
154impl From<usize> for UdpStreamParallelism {
155    fn from(value: usize) -> Self {
156        NonZeroUsize::new(value).map(Self::Specific).unwrap_or_default()
157    }
158}
159
160impl From<Option<usize>> for UdpStreamParallelism {
161    fn from(value: Option<usize>) -> Self {
162        value.map(UdpStreamParallelism::from).unwrap_or_default()
163    }
164}
165
166/// Builder object for the [`ConnectedUdpStream`].
167///
168/// If you wish to use defaults, do `UdpStreamBuilder::default().build(addr)`.
169#[derive(Debug, Clone)]
170pub struct UdpStreamBuilder {
171    foreign_data_mode: ForeignDataMode,
172    buffer_size: usize,
173    queue_size: Option<usize>,
174    receiver_parallelism: UdpStreamParallelism,
175    sender_parallelism: UdpStreamParallelism,
176    counterparty: Option<std::net::SocketAddr>,
177}
178
179impl Default for UdpStreamBuilder {
180    fn default() -> Self {
181        Self {
182            buffer_size: 2048,
183            foreign_data_mode: Default::default(),
184            queue_size: None,
185            receiver_parallelism: Default::default(),
186            sender_parallelism: Default::default(),
187            counterparty: None,
188        }
189    }
190}
191
192impl UdpStreamBuilder {
193    /// Defines the behavior when data from an unexpected source arrive into the socket.
194    /// See [`ForeignDataMode`] for details.
195    ///
196    /// Default is [`ForeignDataMode::Error`].
197    pub fn with_foreign_data_mode(mut self, mode: ForeignDataMode) -> Self {
198        self.foreign_data_mode = mode;
199        self
200    }
201
202    /// The size of the UDP receive buffer.
203    ///
204    /// This size must be at least the size of the MTU, otherwise the unread UDP data that
205    /// does not fit this buffer will be discarded.
206    ///
207    /// Default is 2048.
208    pub fn with_buffer_size(mut self, buffer_size: usize) -> Self {
209        self.buffer_size = buffer_size;
210        self
211    }
212
213    /// Size of the TX/RX queue that dispatches data of reads from/writings to
214    /// the sockets.
215    ///
216    /// This an important back-pressure mechanism when dispatching received data from
217    /// fast senders.
218    /// Reduces the maximum memory consumed by the object, which is given by:
219    /// [`buffer_size`](UdpStreamBuilder::with_buffer_size) *
220    /// [`queue_size`](UdpStreamBuilder::with_queue_size)
221    ///
222    /// Default is unbounded.
223    pub fn with_queue_size(mut self, queue_size: usize) -> Self {
224        self.queue_size = Some(queue_size);
225        self
226    }
227
228    /// Sets how many parallel receiving sockets should be bound.
229    ///
230    /// Has no effect on non-Linux machines. See [`UdpStreamParallelism`] for details.
231    ///
232    /// Default is `1`.
233    pub fn with_receiver_parallelism<T: Into<UdpStreamParallelism>>(mut self, parallelism: T) -> Self {
234        self.receiver_parallelism = parallelism.into();
235        self
236    }
237
238    /// Sets how many parallel sending sockets should be bound.
239    ///
240    /// Has no effect on non-Linux machines. See [`UdpStreamParallelism`] for details.
241    ///
242    /// Default is `1`.
243    pub fn with_sender_parallelism<T: Into<UdpStreamParallelism>>(mut self, parallelism: T) -> Self {
244        self.sender_parallelism = parallelism.into();
245        self
246    }
247
248    /// Sets the expected counterparty for data sent/received by the UDP sockets.
249    ///
250    /// If not specified, the counterparty is determined from the first packet received.
251    /// However, no data can be sent up until this point.
252    /// Therefore, the value must be set if data are sent first rather than received.
253    /// If data is expected to be received first, the value does not need to be set.
254    ///
255    /// See [`ConnectedUdpStream`] and [`ForeignDataMode`] for details.
256    ///
257    /// Default is none.
258    pub fn with_counterparty(mut self, counterparty: std::net::SocketAddr) -> Self {
259        self.counterparty = Some(counterparty);
260        self
261    }
262
263    /// Builds the [`ConnectedUdpStream`] with UDP socket(s) bound to `bind_addr`.
264    ///
265    /// The number of RX sockets bound is determined by [receiver parallelism](UdpStreamBuilder::with_receiver_parallelism),
266    /// and similarly, the number of TX sockets bound is determined by [sender parallelism](UdpStreamBuilder::with_sender_parallelism).
267    /// On non-Linux platforms, only a single receiver and sender will be bound, regardless of the above.
268    ///
269    /// The returned instance is always ready to receive data.
270    /// It is also ready to send data
271    /// if the [counterparty](UdpStreamBuilder::with_counterparty) has been set.
272    ///
273    /// If `bind_addr` yields multiple addresses, binding will be attempted with each of the addresses
274    /// until one succeeds. If none of the addresses succeed in binding the socket(s),
275    /// the `AddrNotAvailable` error is returned.
276    ///
277    /// Note that wildcard addresses (such as `0.0.0.0`) are *not* considered as multiple addresses,
278    /// and such socket(s) will bind to all available interfaces at the system level.
279    pub fn build<A: std::net::ToSocketAddrs>(self, bind_addr: A) -> std::io::Result<ConnectedUdpStream> {
280        let (num_rx_socks, num_tx_socks) = self.receiver_parallelism.split_evenly_with(self.sender_parallelism);
281
282        let counterparty = Arc::new(
283            self.counterparty
284                .map(|s| OnceLock::from(SocketAddrStr::from(s)))
285                .unwrap_or_default(),
286        );
287        let ((ingress_tx, ingress_rx), (egress_tx, egress_rx)) = if let Some(q) = self.queue_size {
288            (flume::bounded(q), flume::bounded(q))
289        } else {
290            (flume::unbounded(), flume::unbounded())
291        };
292
293        let num_socks_to_bind = num_rx_socks.max(num_tx_socks);
294        let mut socket_handles = Vec::with_capacity(num_socks_to_bind);
295        let mut bound_addr: Option<std::net::SocketAddr> = None;
296
297        // Try binding on all network addresses in `bind_addr`
298        for binding_to in bind_addr.to_socket_addrs()? {
299            debug!(
300                %binding_to,
301                num_socks_to_bind, num_rx_socks, num_tx_socks, "binding UDP stream"
302            );
303
304            // TODO: split bound sockets into a separate cloneable object
305
306            // Try to bind sockets on the current network interface address
307            (0..num_socks_to_bind)
308                .map(|sock_id| {
309                    let domain = match &binding_to {
310                        std::net::SocketAddr::V4(_) => socket2::Domain::IPV4,
311                        std::net::SocketAddr::V6(_) => socket2::Domain::IPV6,
312                    };
313
314                    // Bind a new non-blocking UDP socket
315                    let sock = socket2::Socket::new(domain, socket2::Type::DGRAM, None)?;
316                    if num_socks_to_bind > 1 {
317                        sock.set_reuse_address(true)?; // Needed for every next socket with non-wildcard IP
318                        sock.set_reuse_port(true)?; // Needed on Linux to evenly distribute datagrams
319                    }
320                    sock.set_nonblocking(true)?;
321                    sock.bind(&bound_addr.unwrap_or(binding_to).into())?;
322
323                    // Determine the address we bound this socket to, so we can also bind the others
324                    let socket_bound_addr = sock
325                        .local_addr()?
326                        .as_socket()
327                        .ok_or(std::io::Error::other("invalid socket type"))?;
328
329                    match bound_addr {
330                        None => bound_addr = Some(socket_bound_addr),
331                        Some(addr) if addr != socket_bound_addr => {
332                            return Err(std::io::Error::other(format!(
333                                "inconsistent binding address {addr} != {socket_bound_addr} on socket id {sock_id}"
334                            )))
335                        }
336                        _ => {}
337                    }
338
339                    let sock = Arc::new(UdpSocket::from_std(sock.into())?);
340                    debug!(
341                        socket_id = sock_id,
342                        addr = %socket_bound_addr,
343                        "bound UDP socket"
344                    );
345
346                    Ok((sock_id, sock))
347                })
348                .filter_map(|result| match result {
349                    Ok(bound) => Some(bound),
350                    Err(e) => {
351                        error!(
352                            %binding_to,
353                            "failed to bind udp socket: {e}"
354                        );
355                        None
356                    }
357                })
358                .for_each(|(sock_id, sock)| {
359                    if sock_id < num_tx_socks {
360                        socket_handles.push(tokio::task::spawn(ConnectedUdpStream::setup_tx_queue(
361                            sock_id,
362                            sock.clone(),
363                            egress_rx.clone(),
364                            counterparty.clone(),
365                        )));
366                    }
367                    if sock_id < num_rx_socks {
368                        socket_handles.push(tokio::task::spawn(ConnectedUdpStream::setup_rx_queue(
369                            sock_id,
370                            sock.clone(),
371                            ingress_tx.clone(),
372                            counterparty.clone(),
373                            self.foreign_data_mode,
374                            self.buffer_size,
375                        )));
376                    }
377                });
378        }
379
380        Ok(ConnectedUdpStream {
381            ingress_rx: Box::new(tokio_util::io::StreamReader::new(ingress_rx.into_stream())),
382            egress_tx: Some(Box::new(
383                egress_tx
384                    .into_sink()
385                    .sink_map_err(|e| std::io::Error::other(e.to_string())),
386            )),
387            socket_handles,
388            counterparty,
389            bound_to: bound_addr.ok_or(ErrorKind::AddrNotAvailable)?,
390        })
391    }
392}
393
394impl ConnectedUdpStream {
395    /// Creates a receiver queue for the UDP stream.
396    fn setup_rx_queue(
397        socket_id: usize,
398        sock_rx: Arc<UdpSocket>,
399        ingress_tx: flume::Sender<std::io::Result<tokio_util::bytes::Bytes>>,
400        counterparty: Arc<OnceLock<SocketAddrStr>>,
401        foreign_data_mode: ForeignDataMode,
402        buf_size: usize,
403    ) -> futures::future::BoxFuture<'static, ()> {
404        let counterparty_rx = counterparty.clone();
405        async move {
406            let mut buffer = vec![0u8; buf_size];
407            let mut done = false;
408            loop {
409                // Read data from the socket
410                let out_res = match sock_rx.recv_from(&mut buffer).await {
411                    Ok((read, read_addr)) if read > 0 => {
412                        trace!(
413                            socket_id,
414                            udp_bound_addr = ?sock_rx.local_addr(),
415                            bytes = read,
416                            from = %read_addr,
417                            "received data from"
418                        );
419
420                        let addr = counterparty_rx.get_or_init(|| read_addr.into());
421
422                        #[cfg(all(feature = "prometheus", not(test)))]
423                        METRIC_UDP_INGRESS_LEN.observe(&[addr.as_str()], read as f64);
424
425                        // If the data is from a counterparty, or we accept anything, pass it
426                        if read_addr.eq(addr) || foreign_data_mode == ForeignDataMode::Accept {
427                            let out_buffer = tokio_util::bytes::Bytes::copy_from_slice(&buffer[..read]);
428                            Some(Ok(out_buffer))
429                        } else {
430                            match foreign_data_mode {
431                                ForeignDataMode::Discard => {
432                                    // Don't even bother sending an error about discarded stuff
433                                    warn!(
434                                        socket_id,
435                                        udp_bound_addr = ?sock_rx.local_addr(),
436                                        ?read_addr,
437                                        expected_addr = ?addr,
438                                        "discarded data, which didn't come from the expected address"
439                                    );
440                                    None
441                                }
442                                ForeignDataMode::Error => {
443                                    // Terminate here, the ingress_tx gets dropped
444                                    done = true;
445                                    Some(Err(std::io::Error::new(
446                                        ErrorKind::ConnectionRefused,
447                                        "data from foreign client not allowed",
448                                    )))
449                                }
450                                // ForeignDataMode::Accept has been handled above
451                                _ => unreachable!(),
452                            }
453                        }
454                    }
455                    Ok(_) => {
456                        // Read EOF, terminate here, the ingress_tx gets dropped
457                        trace!(
458                            socket_id,
459                            udp_bound_addr = ?sock_rx.local_addr(),
460                            "read EOF on socket"
461                        );
462                        done = true;
463                        None
464                    }
465                    Err(e) => {
466                        // Forward the error
467                        debug!(
468                            socket_id,
469                            udp_bound_addr = ?sock_rx.local_addr(),
470                            error = %e,
471                            "forwarded error from socket"
472                        );
473                        done = true;
474                        Some(Err(e))
475                    }
476                };
477
478                // Dispatch the received data to the queue.
479                // If the underlying queue is bounded, it will wait until there is space.
480                if let Some(out_res) = out_res {
481                    if let Err(err) = ingress_tx.send_async(out_res).await {
482                        error!(
483                            socket_id,
484                            udp_bound_addr = ?sock_rx.local_addr(),
485                            error = %err,
486                            "failed to dispatch received data"
487                        );
488                        done = true;
489                    }
490                }
491
492                if done {
493                    trace!(
494                        socket_id,
495                        udp_bound_addr = ?sock_rx.local_addr(),
496                        "rx queue done"
497                    );
498                    break;
499                }
500            }
501        }
502        .boxed()
503    }
504
505    /// Creates a transmission queue for the UDP stream.
506    fn setup_tx_queue(
507        socket_id: usize,
508        sock_tx: Arc<UdpSocket>,
509        egress_rx: flume::Receiver<Box<[u8]>>,
510        counterparty: Arc<OnceLock<SocketAddrStr>>,
511    ) -> futures::future::BoxFuture<'static, ()> {
512        let counterparty_tx = counterparty.clone();
513        async move {
514            loop {
515                match egress_rx.recv_async().await {
516                    Ok(data) => {
517                        if let Some(target) = counterparty_tx.get() {
518                            if let Err(e) = sock_tx.send_to(&data, target.as_ref()).await {
519                                error!(
520                                    ?socket_id,
521                                    udp_bound_addr = ?sock_tx.local_addr(),
522                                    ?target,
523                                    error = %e,
524                                    "failed to send data"
525                                );
526                            }
527                            trace!(socket_id, bytes = data.len(), ?target, "sent bytes to");
528
529                            #[cfg(all(feature = "prometheus", not(test)))]
530                            METRIC_UDP_EGRESS_LEN.observe(&[target.as_str()], data.len() as f64);
531                        } else {
532                            error!(
533                                ?socket_id,
534                                udp_bound_addr = ?sock_tx.local_addr(),
535                                "cannot send data, counterparty not set"
536                            );
537                            break;
538                        }
539                    }
540                    Err(e) => {
541                        error!(
542                            ?socket_id,
543                            udp_bound_addr = ?sock_tx.local_addr(),
544                            error = %e,
545                            "cannot receive more data from egress channel"
546                        );
547                        break;
548                    }
549                }
550                trace!(
551                    ?socket_id,
552                    udp_bound_addr = tracing::field::debug(sock_tx.local_addr()),
553                    "tx queue done"
554                );
555            }
556        }
557        .boxed()
558    }
559
560    /// Local address that all UDP sockets in this instance are bound to.
561    pub fn bound_address(&self) -> &std::net::SocketAddr {
562        &self.bound_to
563    }
564
565    /// Creates a new [builder](UdpStreamBuilder).
566    pub fn builder() -> UdpStreamBuilder {
567        UdpStreamBuilder::default()
568    }
569}
570
571impl Drop for ConnectedUdpStream {
572    fn drop(&mut self) {
573        self.socket_handles.iter().for_each(|handle| {
574            handle.abort();
575        })
576    }
577}
578
579impl tokio::io::AsyncRead for ConnectedUdpStream {
580    fn poll_read(
581        mut self: Pin<&mut Self>,
582        cx: &mut Context<'_>,
583        buf: &mut tokio::io::ReadBuf<'_>,
584    ) -> Poll<std::io::Result<()>> {
585        trace!(
586            remaining_bytes = buf.remaining(),
587            counterparty = ?self.counterparty.get(),
588            "polling read of from udp stream",
589        );
590        match Pin::new(&mut self.ingress_rx).poll_read(cx, buf) {
591            Poll::Ready(Ok(())) => {
592                let read = buf.filled().len();
593                trace!(bytes = read, "read bytes");
594                Poll::Ready(Ok(()))
595            }
596            Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
597            Poll::Pending => Poll::Pending,
598        }
599    }
600}
601
602impl tokio::io::AsyncWrite for ConnectedUdpStream {
603    fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
604        trace!(
605            bytes = buf.len(),
606            counterparty = ?self.counterparty.get(),
607            "polling write to udp stream",
608        );
609        if let Some(sender) = &mut self.egress_tx {
610            if let Err(e) = ready!(sender.poll_ready_unpin(cx)) {
611                return Poll::Ready(Err(e));
612            }
613
614            let len = buf.iter().len();
615            if let Err(e) = sender.start_send_unpin(Box::from(buf)) {
616                return Poll::Ready(Err(e));
617            }
618
619            // Explicitly flush after each data sent
620            pin_mut!(sender);
621            sender.poll_flush(cx).map_ok(|_| len)
622        } else {
623            Poll::Ready(Err(std::io::Error::new(
624                ErrorKind::NotConnected,
625                "udp stream is closed",
626            )))
627        }
628    }
629
630    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
631        trace!(
632            counterparty = ?self.counterparty.get(),
633            "polling flush to udp stream"
634        );
635        if let Some(sender) = &mut self.egress_tx {
636            pin_mut!(sender);
637            sender
638                .poll_flush(cx)
639                .map_err(|err| std::io::Error::other(err.to_string()))
640        } else {
641            Poll::Ready(Err(std::io::Error::new(
642                ErrorKind::NotConnected,
643                "udp stream is closed",
644            )))
645        }
646    }
647
648    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
649        trace!(
650            counterparty = ?self.counterparty.get(),
651            "polling close on udp stream"
652        );
653        // Take the sender to make sure it gets dropped
654        let mut taken_sender = self.egress_tx.take();
655        if let Some(sender) = &mut taken_sender {
656            pin_mut!(sender);
657            sender
658                .poll_close(cx)
659                .map_err(|err| std::io::Error::other(err.to_string()))
660        } else {
661            Poll::Ready(Err(std::io::Error::new(
662                ErrorKind::NotConnected,
663                "udp stream is closed",
664            )))
665        }
666    }
667}
668
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Context;
    use futures::future::Either;
    use parameterized::parameterized;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::UdpSocket;

    // Round-trips random payloads through an echo server for several receiver parallelism settings.
    #[parameterized(parallelism = {None, Some(2), Some(0)})]
    #[parameterized_macro(tokio::test)]
    //#[parameterized_macro(test_log::test(tokio::test))]
    async fn basic_udp_stream_tests(parallelism: Option<usize>) -> anyhow::Result<()> {
        const DATA_SIZE: usize = 200;

        let echo_socket = UdpSocket::bind("127.0.0.1:0").await.context("bind listener")?;
        let echo_addr = echo_socket.local_addr()?;

        // Echo server: bounces every non-empty datagram back to its sender
        tokio::task::spawn(async move {
            loop {
                let mut datagram = [0u8; DATA_SIZE];
                let (received, peer) = echo_socket.recv_from(&mut datagram).await.expect("recv must not fail");
                if received == 0 {
                    continue;
                }
                assert_eq!(DATA_SIZE, received, "read size must be exactly {DATA_SIZE}");
                echo_socket.send_to(&datagram, peer).await.expect("send must not fail");
            }
        });

        let builder = ConnectedUdpStream::builder()
            .with_buffer_size(1024)
            .with_queue_size(512)
            .with_counterparty(echo_addr);
        let builder = match parallelism {
            Some(p) => builder.with_receiver_parallelism(p),
            None => builder,
        };

        let mut stream = builder.build(("127.0.0.1", 0)).context("connection")?;

        for _round in 0..999 {
            let mut sent = [0u8; DATA_SIZE];
            hopr_crypto_random::random_fill(&mut sent);
            assert_eq!(stream.write(&sent).await?, DATA_SIZE);

            let mut echoed = [0u8; DATA_SIZE];
            assert_eq!(stream.read_exact(&mut echoed).await?, DATA_SIZE);

            assert_eq!(sent, echoed);
        }

        stream.shutdown().await?;

        Ok(())
    }

    // Two datagrams sent back-to-back must be readable as one contiguous byte sequence.
    #[tokio::test]
    async fn udp_stream_should_process_sequential_writes() -> anyhow::Result<()> {
        const BUF_SIZE: usize = 1024;
        const EXPECTED_DATA_LEN: usize = BUF_SIZE + 500;

        let mut listener = ConnectedUdpStream::builder()
            .with_buffer_size(BUF_SIZE)
            .with_queue_size(512)
            .build(("127.0.0.1", 0))
            .context("bind listener")?;

        let bound_addr = *listener.bound_address();

        // Reads in small chunks until the expected amount (or EOF) has been collected
        let reader = tokio::task::spawn(async move {
            let mut chunk = [0u8; BUF_SIZE / 4];
            let mut collected = Vec::<u8>::new();
            loop {
                let n = listener.read(&mut chunk).await.unwrap();
                if n == 0 {
                    return collected;
                }
                collected.extend_from_slice(&chunk[..n]);
                if collected.len() >= EXPECTED_DATA_LEN {
                    return collected;
                }
            }
        });

        let payload = [1u8; EXPECTED_DATA_LEN];
        let sender = UdpSocket::bind(("127.0.0.1", 0)).await.context("bind")?;

        let (head, tail) = payload.split_at(BUF_SIZE);
        sender.send_to(head, bound_addr).await?;
        sender.send_to(tail, bound_addr).await?;

        let deadline = tokio::time::sleep(std::time::Duration::from_millis(1000));
        pin_mut!(deadline);
        pin_mut!(reader);

        if let Either::Left((Ok(received), _)) = futures::future::select(reader, deadline).await {
            assert_eq!(received.len(), EXPECTED_DATA_LEN);
            assert_eq!(received.as_slice(), &payload);
            Ok(())
        } else {
            Err(anyhow::anyhow!("timeout or invalid data"))
        }
    }
}