hopr_transport_protocol/
stream.rs

//! Infrastructure for converting a collection of individual peer-to-peer [`libp2p::swarm::Stream`]s,
//! managed by [`libp2p_stream`] and keyed by [`libp2p::PeerId`], into a single pair of message channels.
3
4use futures::{AsyncRead, AsyncReadExt, AsyncWrite, SinkExt as _, Stream, StreamExt};
5use libp2p::PeerId;
6use tokio_util::{
7    codec::{Decoder, Encoder, FramedRead, FramedWrite},
8    compat::{FuturesAsyncReadCompatExt, FuturesAsyncWriteCompatExt},
9};
10
/// Abstraction over a component able to accept incoming and open outgoing
/// bidirectional byte streams to peers identified by [`PeerId`].
///
/// Both methods take `self` by value, so a caller that needs to accept and open
/// streams (possibly many times) has to clone the implementor first, which is what
/// `process_stream_protocol` does.
#[async_trait::async_trait]
pub trait BidirectionalStreamControl: std::fmt::Debug {
    /// Starts accepting incoming streams.
    ///
    /// On success returns a [`Stream`] whose items pair the [`PeerId`] with a
    /// readable and writable byte stream.
    /// NOTE(review): presumably the `PeerId` is the remote peer that opened the stream —
    /// confirm against the implementors.
    ///
    /// # Errors
    ///
    /// Returns an implementation-defined error when accepting cannot be set up.
    fn accept(
        self,
    ) -> Result<impl Stream<Item = (PeerId, impl AsyncRead + AsyncWrite + Send)> + Send, impl std::error::Error>;

    /// Opens a new outgoing stream to `peer`.
    ///
    /// On success returns a readable and writable byte stream.
    ///
    /// # Errors
    ///
    /// Returns an implementation-defined error when the stream cannot be opened.
    async fn open(self, peer: PeerId) -> Result<impl AsyncRead + AsyncWrite + Send, impl std::error::Error>;
}
19
20pub async fn process_stream_protocol<C, V>(
21    codec: C,
22    control: V,
23) -> crate::errors::Result<(
24    futures::channel::mpsc::Sender<(PeerId, <C as Decoder>::Item)>, // impl Sink<(PeerId, <C as Decoder>::Item)>,
25    futures::channel::mpsc::Receiver<(PeerId, <C as Decoder>::Item)>, // impl Stream<Item = (PeerId, <C as Decoder>::Item)>,
26)>
27where
28    C: Encoder<<C as Decoder>::Item> + Decoder + Send + Sync + Clone + 'static,
29    <C as Encoder<<C as Decoder>::Item>>::Error: std::fmt::Debug + std::fmt::Display + Send + Sync + 'static,
30    <C as Decoder>::Error: std::fmt::Debug + std::fmt::Display + Send + Sync + 'static,
31    <C as Decoder>::Item: Send + 'static,
32    V: BidirectionalStreamControl + Clone + Send + Sync + 'static,
33{
34    let (tx_out, rx_out) = futures::channel::mpsc::channel::<(PeerId, <C as Decoder>::Item)>(10_000);
35    let (tx_in, rx_in) = futures::channel::mpsc::channel::<(PeerId, <C as Decoder>::Item)>(10_000);
36
37    let cache_out = moka::future::Cache::new(2000);
38
39    let incoming = control
40        .clone()
41        .accept()
42        .map_err(|e| crate::errors::ProtocolError::Logic(format!("failed to listen on protocol: {e}")))?;
43
44    let cache_ingress = cache_out.clone();
45    let codec_ingress = codec.clone();
46    let tx_in_ingress = tx_in.clone();
47
48    // terminated when the incoming is dropped
49    let _ingress_process = hopr_async_runtime::prelude::spawn(incoming.for_each(move |(peer_id, stream)| {
50        let codec = codec_ingress.clone();
51        let cache = cache_ingress.clone();
52        let tx_in = tx_in_ingress.clone();
53
54        tracing::debug!(peer = %peer_id, "Received incoming peer-to-peer stream");
55
56        async move {
57            let (stream_rx, stream_tx) = stream.split();
58            let (send, recv) = futures::channel::mpsc::channel::<<C as Decoder>::Item>(1000);
59
60            hopr_async_runtime::prelude::spawn(
61                recv.map(Ok)
62                    .forward(FramedWrite::new(stream_tx.compat_write(), codec.clone())),
63            );
64            hopr_async_runtime::prelude::spawn(
65                FramedRead::new(stream_rx.compat(), codec)
66                    .filter_map(move |v| async move {
67                        match v {
68                            Ok(v) => Some((peer_id, v)),
69                            Err(e) => {
70                                tracing::error!(error = %e, "Error decoding object from the underlying stream");
71                                None
72                            }
73                        }
74                    })
75                    .map(Ok)
76                    .forward(tx_in),
77            );
78            cache.insert(peer_id, send).await;
79        }
80    }));
81
82    // terminated when the rx_in is dropped
83    let _egress_process = hopr_async_runtime::prelude::spawn(rx_out.for_each(move |(peer_id, msg)| {
84        let cache = cache_out.clone();
85        let control = control.clone();
86        let codec = codec.clone();
87        let tx_in = tx_in.clone();
88
89        async move {
90            let cache = cache.clone();
91
92            let cached = cache
93                .optionally_get_with(peer_id, async move {
94                    let r = control.open(peer_id).await;
95                    match r {
96                        Ok(stream) => {
97                            tracing::debug!(peer = %peer_id, "Opening outgoing peer-to-peer stream");
98
99                            let (stream_rx, stream_tx) = stream.split();
100                            let (send, recv) = futures::channel::mpsc::channel::<<C as Decoder>::Item>(1000);
101
102                            hopr_async_runtime::prelude::spawn(
103                                recv.map(Ok)
104                                    .forward(FramedWrite::new(stream_tx.compat_write(), codec.clone())),
105                            );
106                            hopr_async_runtime::prelude::spawn(
107                                FramedRead::new(stream_rx.compat(), codec)
108                                    .filter_map(move |v| async move {
109                                        match v {
110                                            Ok(v) => Some((peer_id, v)),
111                                            Err(e) => {
112                                                tracing::error!(error = %e, "Error decoding object from the underlying stream");
113                                                None
114                                            }
115                                        }
116                                    })
117                                    .map(Ok)
118                                    .forward(tx_in),
119                            );
120
121                            Some(send)
122                        }
123                        Err(error) => {
124                            tracing::error!(peer = %peer_id, %error, "Error opening stream to peer");
125                            None
126                        }
127                    }
128                })
129                .await;
130
131            if let Some(mut cached) = cached {
132                cached.send(msg).await.unwrap_or_else(|e| {
133                    tracing::error!(peer = %peer_id, error = %e, "Error sending message to peer");
134                });
135            } else {
136                tracing::error!(peer = %peer_id, "Error sending message to peer: the stream failed to be created and cached");
137            }
138        }
139    }));
140
141    Ok((tx_out, rx_in))
142}
143
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Context;
    use futures::SinkExt;

    /// Test-only duplex byte stream built from the two ends of an `async_channel_io::pipe()`:
    /// writes go to the pipe writer, reads come from the pipe reader.
    struct AsyncBinaryStreamChannel {
        read: async_channel_io::ChannelReader,
        write: async_channel_io::ChannelWriter,
    }

    impl AsyncBinaryStreamChannel {
        /// Creates a new pipe and wraps both of its ends.
        pub fn new() -> Self {
            let (writer, reader) = async_channel_io::pipe();
            Self {
                read: reader,
                write: writer,
            }
        }
    }

    // Reading is delegated to the pipe reader.
    impl AsyncRead for AsyncBinaryStreamChannel {
        fn poll_read(
            self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
            buf: &mut [u8],
        ) -> std::task::Poll<std::io::Result<usize>> {
            let this = self.get_mut();
            std::pin::Pin::new(&mut this.read).poll_read(cx, buf)
        }
    }

    // Writing, flushing and closing are delegated to the pipe writer.
    impl AsyncWrite for AsyncBinaryStreamChannel {
        fn poll_write(
            self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
            buf: &[u8],
        ) -> std::task::Poll<std::io::Result<usize>> {
            let this = self.get_mut();
            std::pin::Pin::new(&mut this.write).poll_write(cx, buf)
        }

        fn poll_flush(
            self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<std::io::Result<()>> {
            let this = self.get_mut();
            std::pin::Pin::new(&mut this.write).poll_flush(cx)
        }

        fn poll_close(
            self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<std::io::Result<()>> {
            let this = self.get_mut();
            std::pin::Pin::new(&mut this.write).poll_close(cx)
        }
    }

    /// A frame written through the framed write half must come back unchanged
    /// from the framed read half of the same split stream.
    #[async_std::test]
    async fn split_codec_should_always_produce_correct_data() -> anyhow::Result<()> {
        let channel = AsyncBinaryStreamChannel::new();
        let codec = tokio_util::codec::BytesCodec::new();

        let payload = [0u8, 1u8, 2u8, 3u8, 4u8, 5u8];
        let frame = tokio_util::bytes::BytesMut::from(payload.as_ref());

        let (read_half, write_half) = channel.split();
        let mut framed_tx = FramedWrite::new(write_half.compat_write(), codec.clone());
        let framed_rx = FramedRead::new(read_half.compat(), codec);

        framed_tx
            .send(frame)
            .await
            .map_err(|_| anyhow::anyhow!("should not fail on send"))?;

        futures::pin_mut!(framed_rx);

        let received = framed_rx.next().await.context("Value must be present")??;
        assert_eq!(received, tokio_util::bytes::BytesMut::from(payload.as_ref()));

        Ok(())
    }
}
226}