1use std::sync::Arc;
5
6use futures::{
7 AsyncRead, AsyncReadExt, AsyncWrite, SinkExt as _, Stream, StreamExt,
8 channel::mpsc::{Receiver, Sender, channel},
9};
10use libp2p::PeerId;
11use tokio_util::{
12 codec::{Decoder, Encoder, FramedRead, FramedWrite},
13 compat::{FuturesAsyncReadCompatExt, FuturesAsyncWriteCompatExt},
14};
15
/// Control handle over a peer-to-peer protocol that can both accept streams opened by remote
/// peers and open new streams towards a given peer.
///
/// Both methods consume `self`; `process_stream_protocol` clones the control before each call.
#[async_trait::async_trait]
pub trait BidirectionalStreamControl: std::fmt::Debug {
    /// Starts accepting incoming streams.
    ///
    /// On success, returns a stream that yields one `(PeerId, stream)` pair per accepted
    /// bidirectional byte stream, where `PeerId` identifies the remote side. An `Err` here is
    /// reported by `process_stream_protocol` as a failure to listen on the protocol.
    fn accept(
        self,
    ) -> Result<impl Stream<Item = (PeerId, impl AsyncRead + AsyncWrite + Send)> + Send, impl std::error::Error>;

    /// Opens an outgoing bidirectional byte stream to `peer`.
    async fn open(self, peer: PeerId) -> Result<impl AsyncRead + AsyncWrite + Send, impl std::error::Error>;
}
24
25pub async fn process_stream_protocol<C, V>(
26 codec: C,
27 control: V,
28) -> crate::errors::Result<(
29 Sender<(PeerId, <C as Decoder>::Item)>, Receiver<(PeerId, <C as Decoder>::Item)>, )>
32where
33 C: Encoder<<C as Decoder>::Item> + Decoder + Send + Sync + Clone + 'static,
34 <C as Encoder<<C as Decoder>::Item>>::Error: std::fmt::Debug + std::fmt::Display + Send + Sync + 'static,
35 <C as Decoder>::Error: std::fmt::Debug + std::fmt::Display + Send + Sync + 'static,
36 <C as Decoder>::Item: Clone + Send + 'static,
37 V: BidirectionalStreamControl + Clone + Send + Sync + 'static,
38{
39 let (tx_out, rx_out) = channel::<(PeerId, <C as Decoder>::Item)>(10_000);
40 let (tx_in, rx_in) = channel::<(PeerId, <C as Decoder>::Item)>(10_000);
41
42 let cache_out = moka::future::Cache::new(2000);
43
44 let incoming = control
45 .clone()
46 .accept()
47 .map_err(|e| crate::errors::ProtocolError::Logic(format!("failed to listen on protocol: {e}")))?;
48
49 let cache_ingress = cache_out.clone();
50 let codec_ingress = codec.clone();
51 let tx_in_ingress = tx_in.clone();
52
53 let _ingress_process = hopr_async_runtime::prelude::spawn(incoming.for_each(move |(peer_id, stream)| {
55 let codec = codec_ingress.clone();
56 let cache = cache_ingress.clone();
57 let tx_in = tx_in_ingress.clone();
58
59 tracing::debug!(peer = %peer_id, "Received incoming peer-to-peer stream");
60
61 async move {
62 let (stream_rx, stream_tx) = stream.split();
63 let (send, recv) = channel::<<C as Decoder>::Item>(1000);
64 let cache_internal = cache.clone();
65
66 hopr_async_runtime::prelude::spawn(recv.map(Ok).forward({
67 let mut fw = FramedWrite::new(stream_tx.compat_write(), codec.clone());
68 fw.set_backpressure_boundary(1); fw
70 }));
71 hopr_async_runtime::prelude::spawn(async move {
72 if let Err(error) = FramedRead::new(stream_rx.compat(), codec)
73 .filter_map(move |v| async move {
74 match v {
75 Ok(v) => Some((peer_id, v)),
76 Err(e) => {
77 tracing::error!(error = %e, "Error decoding object from the underlying stream");
78 None
79 }
80 }
81 })
82 .map(Ok)
83 .forward(tx_in)
84 .await
85 {
86 tracing::error!(peer = %peer_id, %error, "Incoming stream failed on reading");
87 }
88 cache_internal.invalidate(&peer_id).await;
89 });
90 cache.insert(peer_id, send).await;
91 }
92 }));
93
94 let _egress_process = hopr_async_runtime::prelude::spawn(rx_out.for_each(move |(peer_id, msg)| {
96 let cache = cache_out.clone();
97 let control = control.clone();
98 let codec = codec.clone();
99 let tx_in = tx_in.clone();
100
101 async move {
102 let cache = cache.clone();
103
104 if let Some(mut cached) = cache.get(&peer_id).await {
105 if let Err(error) = cached.send(msg.clone()).await {
106 tracing::error!(peer = %peer_id, %error, "Error sending message to peer from the cached connection");
107 cache.invalidate(&peer_id).await;
108 } else {
109 tracing::trace!(peer = %peer_id, "Message sent over an existing transport stream");
110 return;
111 }
112 }
113
114 let cached: std::result::Result<Sender<<C as Decoder>::Item>, Arc<anyhow::Error>> = cache
115 .try_get_with(peer_id, async move {
116 let stream = control.open(peer_id).await.map_err(|e| anyhow::anyhow!("{e}"))?;
117 tracing::debug!(peer = %peer_id, "Opening outgoing peer-to-peer stream");
118
119 let (stream_rx, stream_tx) = stream.split();
120 let (send, recv) = channel::<<C as Decoder>::Item>(1000);
121
122 hopr_async_runtime::prelude::spawn(
123 recv.map(Ok)
124 .forward(FramedWrite::new(stream_tx.compat_write(), codec.clone())),
125 );
126 hopr_async_runtime::prelude::spawn(
127 FramedRead::new(stream_rx.compat(), codec)
128 .filter_map(move |v| async move {
129 match v {
130 Ok(v) => Some((peer_id, v)),
131 Err(e) => {
132 tracing::error!(error = %e, "Error decoding object from the underlying stream");
133 None
134 }
135 }
136 })
137 .map(Ok)
138 .forward(tx_in),
139 );
140
141 Ok(send)
142 })
143 .await;
144
145 match cached {
146 Ok(mut cached) => {
147 if let Err(error) = cached.send(msg).await {
148 tracing::error!(peer = %peer_id, %error, "Error sending message to peer");
149 cache.invalidate(&peer_id).await;
150 }
151 },
152 Err(error) => {
153 tracing::error!(peer = %peer_id, %error, "Failed to open a stream to peer");
154 },
155 }
156 }
157 }));
158
159 Ok((tx_out, rx_in))
160}
161
#[cfg(test)]
mod tests {
    use std::{
        pin::Pin,
        task::{Context as TaskContext, Poll},
    };

    use anyhow::Context;
    use futures::SinkExt;

    use super::*;

    /// In-memory byte stream for tests, backed by an `async_channel_io` pipe: whatever is written
    /// through the writer end can be read back from the reader end.
    struct AsyncBinaryStreamChannel {
        read: async_channel_io::ChannelReader,
        write: async_channel_io::ChannelWriter,
    }

    impl AsyncBinaryStreamChannel {
        pub fn new() -> Self {
            let (writer_end, reader_end) = async_channel_io::pipe();
            Self {
                read: reader_end,
                write: writer_end,
            }
        }
    }

    impl AsyncRead for AsyncBinaryStreamChannel {
        fn poll_read(self: Pin<&mut Self>, cx: &mut TaskContext<'_>, buf: &mut [u8]) -> Poll<std::io::Result<usize>> {
            // Delegate to the pipe's reader end (requires it to be `Unpin`).
            Pin::new(&mut self.get_mut().read).poll_read(cx, buf)
        }
    }

    impl AsyncWrite for AsyncBinaryStreamChannel {
        fn poll_write(self: Pin<&mut Self>, cx: &mut TaskContext<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
            Pin::new(&mut self.get_mut().write).poll_write(cx, buf)
        }

        fn poll_flush(self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<std::io::Result<()>> {
            Pin::new(&mut self.get_mut().write).poll_flush(cx)
        }

        fn poll_close(self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<std::io::Result<()>> {
            Pin::new(&mut self.get_mut().write).poll_close(cx)
        }
    }

    /// A frame written through the write half of a split stream must be decoded intact from the
    /// read half.
    #[tokio::test]
    async fn split_codec_should_always_produce_correct_data() -> anyhow::Result<()> {
        let expected = [0u8, 1u8, 2u8, 3u8, 4u8, 5u8];
        let codec = tokio_util::codec::BytesCodec::new();

        let (reader_half, writer_half) = AsyncBinaryStreamChannel::new().split();
        let mut framed_tx = FramedWrite::new(writer_half.compat_write(), codec);
        let framed_rx = FramedRead::new(reader_half.compat(), codec);

        framed_tx
            .send(tokio_util::bytes::BytesMut::from(expected.as_ref()))
            .await
            .map_err(|_| anyhow::anyhow!("should not fail on send"))?;

        futures::pin_mut!(framed_rx);

        let received = framed_rx.next().await.context("Value must be present")??;
        assert_eq!(received, tokio_util::bytes::BytesMut::from(expected.as_ref()));

        Ok(())
    }
}