hopr_transport_protocol/
stream.rs

//! Infrastructure for turning a collection of individual peer-to-peer [`libp2p::swarm::Stream`]s,
//! managed by [`libp2p_stream`] and distinguished by [`libp2p::PeerId`], into a single pair of
//! peer-tagged message channels.
3
4use std::sync::Arc;
5
6use futures::{
7    AsyncRead, AsyncReadExt, AsyncWrite, SinkExt as _, Stream, StreamExt,
8    channel::mpsc::{Receiver, Sender, channel},
9};
10use libp2p::PeerId;
11use tokio_util::{
12    codec::{Decoder, Encoder, FramedRead, FramedWrite},
13    compat::{FuturesAsyncReadCompatExt, FuturesAsyncWriteCompatExt},
14};
15
/// Control handle for a peer-to-peer stream protocol: accepts incoming streams and opens
/// outgoing streams to a given peer.
///
/// Both methods take `self` by value, so a caller that needs the handle more than once has to
/// clone it first (`process_stream_protocol` requires `Clone` for exactly this reason).
#[async_trait::async_trait]
pub trait BidirectionalStreamControl: std::fmt::Debug {
    /// Starts accepting incoming streams.
    ///
    /// On success, returns a stream that yields one `(peer, stream)` pair per accepted
    /// bidirectional byte stream. On failure, returns an error describing why incoming streams
    /// cannot be accepted.
    fn accept(
        self,
    ) -> Result<impl Stream<Item = (PeerId, impl AsyncRead + AsyncWrite + Send)> + Send, impl std::error::Error>;

    /// Opens a new outgoing bidirectional byte stream to `peer`, or returns an error if the
    /// stream could not be opened.
    async fn open(self, peer: PeerId) -> Result<impl AsyncRead + AsyncWrite + Send, impl std::error::Error>;
}
24
25pub async fn process_stream_protocol<C, V>(
26    codec: C,
27    control: V,
28) -> crate::errors::Result<(
29    Sender<(PeerId, <C as Decoder>::Item)>, // impl Sink<(PeerId, <C as Decoder>::Item)>,
30    Receiver<(PeerId, <C as Decoder>::Item)>, // impl Stream<Item = (PeerId, <C as Decoder>::Item)>,
31)>
32where
33    C: Encoder<<C as Decoder>::Item> + Decoder + Send + Sync + Clone + 'static,
34    <C as Encoder<<C as Decoder>::Item>>::Error: std::fmt::Debug + std::fmt::Display + Send + Sync + 'static,
35    <C as Decoder>::Error: std::fmt::Debug + std::fmt::Display + Send + Sync + 'static,
36    <C as Decoder>::Item: Clone + Send + 'static,
37    V: BidirectionalStreamControl + Clone + Send + Sync + 'static,
38{
39    let (tx_out, rx_out) = channel::<(PeerId, <C as Decoder>::Item)>(10_000);
40    let (tx_in, rx_in) = channel::<(PeerId, <C as Decoder>::Item)>(10_000);
41
42    let cache_out = moka::future::Cache::new(2000);
43
44    let incoming = control
45        .clone()
46        .accept()
47        .map_err(|e| crate::errors::ProtocolError::Logic(format!("failed to listen on protocol: {e}")))?;
48
49    let cache_ingress = cache_out.clone();
50    let codec_ingress = codec.clone();
51    let tx_in_ingress = tx_in.clone();
52
53    // terminated when the incoming is dropped
54    let _ingress_process = hopr_async_runtime::prelude::spawn(incoming.for_each(move |(peer_id, stream)| {
55        let codec = codec_ingress.clone();
56        let cache = cache_ingress.clone();
57        let tx_in = tx_in_ingress.clone();
58
59        tracing::debug!(peer = %peer_id, "Received incoming peer-to-peer stream");
60
61        async move {
62            let (stream_rx, stream_tx) = stream.split();
63            let (send, recv) = channel::<<C as Decoder>::Item>(1000);
64            let cache_internal = cache.clone();
65
66            hopr_async_runtime::prelude::spawn(recv.map(Ok).forward({
67                let mut fw = FramedWrite::new(stream_tx.compat_write(), codec.clone());
68                fw.set_backpressure_boundary(1); // Low backpressure boundary to make sure each message is flushed after writing to buffer
69                fw
70            }));
71            hopr_async_runtime::prelude::spawn(async move {
72                if let Err(error) = FramedRead::new(stream_rx.compat(), codec)
73                    .filter_map(move |v| async move {
74                        match v {
75                            Ok(v) => Some((peer_id, v)),
76                            Err(e) => {
77                                tracing::error!(error = %e, "Error decoding object from the underlying stream");
78                                None
79                            }
80                        }
81                    })
82                    .map(Ok)
83                    .forward(tx_in)
84                    .await
85                {
86                    tracing::error!(peer = %peer_id, %error, "Incoming stream failed on reading");
87                }
88                cache_internal.invalidate(&peer_id).await;
89            });
90            cache.insert(peer_id, send).await;
91        }
92    }));
93
94    // terminated when the rx_in is dropped
95    let _egress_process = hopr_async_runtime::prelude::spawn(rx_out.for_each(move |(peer_id, msg)| {
96        let cache = cache_out.clone();
97        let control = control.clone();
98        let codec = codec.clone();
99        let tx_in = tx_in.clone();
100
101        async move {
102            let cache = cache.clone();
103
104            if let Some(mut cached) = cache.get(&peer_id).await {
105                if let Err(error) = cached.send(msg.clone()).await {
106                    tracing::error!(peer = %peer_id, %error, "Error sending message to peer from the cached connection");
107                    cache.invalidate(&peer_id).await;
108                } else {
109                    tracing::trace!(peer = %peer_id, "Message sent over an existing transport stream");
110                    return;
111                }
112            }
113
114            let cached: std::result::Result<Sender<<C as Decoder>::Item>, Arc<anyhow::Error>> = cache
115                .try_get_with(peer_id, async move {
116                    let stream = control.open(peer_id).await.map_err(|e| anyhow::anyhow!("{e}"))?;
117                    tracing::debug!(peer = %peer_id, "Opening outgoing peer-to-peer stream");
118
119                    let (stream_rx, stream_tx) = stream.split();
120                    let (send, recv) = channel::<<C as Decoder>::Item>(1000);
121
122                    hopr_async_runtime::prelude::spawn(
123                        recv.map(Ok)
124                            .forward(FramedWrite::new(stream_tx.compat_write(), codec.clone())),
125                    );
126                    hopr_async_runtime::prelude::spawn(
127                        FramedRead::new(stream_rx.compat(), codec)
128                            .filter_map(move |v| async move {
129                                match v {
130                                    Ok(v) => Some((peer_id, v)),
131                                    Err(e) => {
132                                        tracing::error!(error = %e, "Error decoding object from the underlying stream");
133                                        None
134                                    }
135                                }
136                            })
137                            .map(Ok)
138                            .forward(tx_in),
139                    );
140
141                    Ok(send)
142                })
143                .await;
144
145            match cached {
146                Ok(mut cached) => {
147                    if let Err(error) = cached.send(msg).await {
148                        tracing::error!(peer = %peer_id, %error, "Error sending message to peer");
149                        cache.invalidate(&peer_id).await;
150                    }
151                },
152                Err(error) => {
153                    tracing::error!(peer = %peer_id, %error, "Failed to open a stream to peer");
154                },
155            }
156        }
157    }));
158
159    Ok((tx_out, rx_in))
160}
161
#[cfg(test)]
mod tests {
    use anyhow::Context;
    use futures::SinkExt;

    use super::*;

    /// Loopback byte stream for tests: writes go into one end of an `async_channel_io` pipe and
    /// reads are served from the other end of the same pipe.
    struct AsyncBinaryStreamChannel {
        read: async_channel_io::ChannelReader,
        write: async_channel_io::ChannelWriter,
    }

    impl AsyncBinaryStreamChannel {
        pub fn new() -> Self {
            let (write, read) = async_channel_io::pipe();
            Self { read, write }
        }
    }

    impl AsyncRead for AsyncBinaryStreamChannel {
        fn poll_read(
            self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
            buf: &mut [u8],
        ) -> std::task::Poll<std::io::Result<usize>> {
            // Delegate to the reader half of the pipe.
            let this = self.get_mut();
            std::pin::Pin::new(&mut this.read).poll_read(cx, buf)
        }
    }

    impl AsyncWrite for AsyncBinaryStreamChannel {
        fn poll_write(
            self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
            buf: &[u8],
        ) -> std::task::Poll<std::io::Result<usize>> {
            // Delegate to the writer half of the pipe.
            let this = self.get_mut();
            std::pin::Pin::new(&mut this.write).poll_write(cx, buf)
        }

        fn poll_flush(
            self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<std::io::Result<()>> {
            let this = self.get_mut();
            std::pin::Pin::new(&mut this.write).poll_flush(cx)
        }

        fn poll_close(
            self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<std::io::Result<()>> {
            let this = self.get_mut();
            std::pin::Pin::new(&mut this.write).poll_close(cx)
        }
    }

    #[tokio::test]
    async fn split_codec_should_always_produce_correct_data() -> anyhow::Result<()> {
        const EXPECTED: [u8; 6] = [0u8, 1u8, 2u8, 3u8, 4u8, 5u8];
        let codec = tokio_util::codec::BytesCodec::new();

        // Write through one half of the split loopback stream and read back through the other.
        let (reader, writer) = AsyncBinaryStreamChannel::new().split();
        let mut sink = FramedWrite::new(writer.compat_write(), codec);
        let source = FramedRead::new(reader.compat(), codec);

        sink.send(tokio_util::bytes::BytesMut::from(EXPECTED.as_ref()))
            .await
            .map_err(|_| anyhow::anyhow!("should not fail on send"))?;

        futures::pin_mut!(source);

        let received = source.next().await.context("Value must be present")??;
        assert_eq!(received, tokio_util::bytes::BytesMut::from(EXPECTED.as_ref()));

        Ok(())
    }
}