1use std::sync::Arc;
5
6use futures::{
7 AsyncRead, AsyncReadExt, AsyncWrite, FutureExt, SinkExt as _, Stream, StreamExt,
8 channel::mpsc::{Receiver, Sender, channel},
9};
10use libp2p::PeerId;
11use tokio_util::{
12 codec::{Decoder, Encoder, FramedRead, FramedWrite},
13 compat::{FuturesAsyncReadCompatExt, FuturesAsyncWriteCompatExt},
14};
15
/// Default upper bound on how long opening an outgoing peer stream may take before the attempt
/// is abandoned. Overridable at runtime via `HOPR_TRANSPORT_STREAM_OPEN_TIMEOUT_MS`.
const GLOBAL_STREAM_OPEN_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(2_000);

/// Default limit on the number of egress messages processed concurrently.
/// Overridable at runtime via `HOPR_TRANSPORT_MAX_CONCURRENT_PACKETS`.
const MAX_CONCURRENT_PACKETS: usize = 30;
21
/// Transport-level control over bidirectional byte streams exchanged with peers.
///
/// Both methods take `self` by value; `process_stream_protocol` in this module therefore clones
/// the control handle before each call (hence its `V: Clone` bound).
#[async_trait::async_trait]
pub trait BidirectionalStreamControl: std::fmt::Debug {
    /// Starts accepting incoming streams.
    ///
    /// On success, returns a stream yielding one `(peer, stream)` pair per incoming
    /// peer-to-peer stream, where `stream` is both readable and writable.
    fn accept(
        self,
    ) -> Result<impl Stream<Item = (PeerId, impl AsyncRead + AsyncWrite + Send)> + Send, impl std::error::Error>;

    /// Opens a new outgoing bidirectional stream to `peer`.
    ///
    /// The trait itself does not bound how long this may take; `process_stream_protocol`
    /// wraps the call in its own stream-open timeout.
    async fn open(self, peer: PeerId) -> Result<impl AsyncRead + AsyncWrite + Send, impl std::error::Error>;
}
30
31fn build_peer_stream_io<S, C>(
32 peer: PeerId,
33 stream: S,
34 cache: moka::future::Cache<PeerId, Sender<<C as Decoder>::Item>>,
35 codec: C,
36 ingress_from_peers: Sender<(PeerId, <C as Decoder>::Item)>,
37) -> Sender<<C as Decoder>::Item>
38where
39 S: AsyncRead + AsyncWrite + Send + 'static,
40 C: Encoder<<C as Decoder>::Item> + Decoder + Send + Sync + Clone + 'static,
41 <C as Encoder<<C as Decoder>::Item>>::Error: std::fmt::Debug + std::fmt::Display + Send + Sync + 'static,
42 <C as Decoder>::Error: std::fmt::Debug + std::fmt::Display + Send + Sync + 'static,
43 <C as Decoder>::Item: AsRef<[u8]> + Clone + Send + 'static,
44{
45 let (stream_rx, stream_tx) = stream.split();
46 let (send, recv) = channel::<<C as Decoder>::Item>(1000);
47 let cache_internal = cache.clone();
48
49 let mut frame_writer = FramedWrite::new(stream_tx.compat_write(), codec.clone());
50
51 frame_writer.set_backpressure_boundary(1);
53
54 hopr_async_runtime::prelude::spawn(
56 recv.inspect(move |_| tracing::trace!(%peer, "writing message to peer stream"))
57 .map(Ok)
58 .forward(frame_writer)
59 .inspect(move |res| {
60 tracing::debug!(%peer, ?res, component = "stream", "writing stream with peer finished");
61 }),
62 );
63
64 hopr_async_runtime::prelude::spawn(
66 FramedRead::new(stream_rx.compat(), codec)
67 .filter_map(move |v| {
68 futures::future::ready(match v {
69 Ok(v) => {
70 tracing::trace!(%peer, "read message from peer stream");
71 Some((peer, v))
72 }
73 Err(error) => {
74 tracing::error!(%error, "Error decoding object from the underlying stream");
75 None
76 }
77 })
78 })
79 .map(Ok)
80 .forward(ingress_from_peers)
81 .inspect(move |res| match res {
82 Ok(_) => tracing::debug!(%peer, component = "stream", "incoming stream done reading"),
83 Err(error) => {
84 tracing::error!(%peer, %error, component = "stream", "incoming stream failed on reading")
85 }
86 })
87 .then(move |_| {
88 let peer = peer;
90 async move {
91 cache_internal.invalidate(&peer).await;
92 }
93 }),
94 );
95
96 tracing::trace!(%peer, "created new io for peer");
97 send
98}
99
100pub async fn process_stream_protocol<C, V>(
101 codec: C,
102 control: V,
103) -> crate::errors::Result<(
104 Sender<(PeerId, <C as Decoder>::Item)>, Receiver<(PeerId, <C as Decoder>::Item)>, )>
107where
108 C: Encoder<<C as Decoder>::Item> + Decoder + Send + Sync + Clone + 'static,
109 <C as Encoder<<C as Decoder>::Item>>::Error: std::fmt::Debug + std::fmt::Display + Send + Sync + 'static,
110 <C as Decoder>::Error: std::fmt::Debug + std::fmt::Display + Send + Sync + 'static,
111 <C as Decoder>::Item: AsRef<[u8]> + Clone + Send + 'static,
112 V: BidirectionalStreamControl + Clone + Send + Sync + 'static,
113{
114 let (tx_out, rx_out) = channel::<(PeerId, <C as Decoder>::Item)>(100_000);
115 let (tx_in, rx_in) = channel::<(PeerId, <C as Decoder>::Item)>(100_000);
116
117 let cache_out = moka::future::Cache::builder()
118 .max_capacity(2000)
119 .eviction_listener(|key: Arc<PeerId>, _, cause| {
120 tracing::trace!(peer = %key.as_ref(), ?cause, "evicting stream for peer");
121 })
122 .build();
123
124 let incoming = control
125 .clone()
126 .accept()
127 .map_err(|e| crate::errors::ProtocolError::Logic(format!("failed to listen on protocol: {e}")))?;
128
129 let cache_ingress = cache_out.clone();
130 let codec_ingress = codec.clone();
131 let tx_in_ingress = tx_in.clone();
132
133 let _ingress_process = hopr_async_runtime::prelude::spawn(
135 incoming
136 .for_each(move |(peer, stream)| {
137 let codec = codec_ingress.clone();
138 let cache = cache_ingress.clone();
139 let tx_in = tx_in_ingress.clone();
140
141 tracing::debug!(%peer, "received incoming peer-to-peer stream");
142 let send = build_peer_stream_io(peer, stream, cache.clone(), codec.clone(), tx_in.clone());
143
144 async move {
145 cache.insert(peer, send).await;
146 }
147 })
148 .inspect(|_| {
149 tracing::info!(
150 task = "ingress stream processing",
151 "long-running background task finished"
152 )
153 }),
154 );
155
156 let max_concurrent_packets = std::env::var("HOPR_TRANSPORT_MAX_CONCURRENT_PACKETS")
157 .ok()
158 .and_then(|v| v.parse().ok())
159 .unwrap_or(MAX_CONCURRENT_PACKETS);
160
161 let global_stream_open_timeout = std::env::var("HOPR_TRANSPORT_STREAM_OPEN_TIMEOUT_MS")
162 .ok()
163 .and_then(|v| v.parse().ok())
164 .map(std::time::Duration::from_millis)
165 .unwrap_or(GLOBAL_STREAM_OPEN_TIMEOUT);
166
167 let _egress_process = hopr_async_runtime::prelude::spawn(
169 rx_out
170 .inspect(|(peer, _)| tracing::trace!(%peer, "proceeding to deliver message to peer"))
171 .for_each_concurrent(max_concurrent_packets, move |(peer, msg)| {
172 let cache = cache_out.clone();
173 let control = control.clone();
174 let codec = codec.clone();
175 let tx_in = tx_in.clone();
176
177 async move {
178 let cache_clone = cache.clone();
179 tracing::trace!(%peer, "trying to deliver message to peer");
180
181 let cached: Result<Sender<<C as Decoder>::Item>, Arc<anyhow::Error>> = cache
182 .try_get_with(peer, async move {
183 tracing::trace!(%peer, "peer is not in cache, opening new stream");
184
185 use futures_time::future::FutureExt as TimeExt;
190 let stream = control
191 .open(peer)
192 .timeout(futures_time::time::Duration::from(global_stream_open_timeout))
193 .await
194 .map_err(|_| anyhow::anyhow!("timeout trying to open stream to {peer}"))?
195 .map_err(|e| anyhow::anyhow!("could not open outgoing peer-to-peer stream: {e}"))?;
196
197 tracing::debug!(%peer, "opening outgoing peer-to-peer stream");
198
199 Ok(build_peer_stream_io(
200 peer,
201 stream,
202 cache_clone.clone(),
203 codec.clone(),
204 tx_in.clone(),
205 ))
206 })
207 .await;
208
209 match cached {
210 Ok(mut cached) => {
211 if let Err(error) = cached.send(msg).await {
212 tracing::error!(%peer, %error, "error sending message to peer");
213 cache.invalidate(&peer).await;
214 } else {
215 tracing::trace!(%peer, "message sent to peer");
216 }
217 }
218 Err(error) => {
219 tracing::debug!(%peer, %error, "failed to open a stream to peer");
220 }
221 }
222 }
223 })
224 .inspect(|_| {
225 tracing::info!(
226 task = "egress stream processing",
227 "long-running background task finished"
228 )
229 }),
230 );
231
232 Ok((tx_out, rx_in))
233}
234
#[cfg(test)]
mod tests {
    use anyhow::Context;
    use futures::SinkExt;

    use super::*;

    /// In-memory duplex byte stream for tests: bytes written to it are read back from it, since
    /// both halves come from the same `async_channel_io::pipe()`.
    struct AsyncBinaryStreamChannel {
        read: async_channel_io::ChannelReader,
        write: async_channel_io::ChannelWriter,
    }

    impl AsyncBinaryStreamChannel {
        /// Creates a channel whose reader and writer are the two ends of one pipe.
        pub fn new() -> Self {
            let (write, read) = async_channel_io::pipe();
            Self { read, write }
        }
    }

    /// Reading is delegated to the pipe's reader end.
    impl AsyncRead for AsyncBinaryStreamChannel {
        fn poll_read(
            self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
            buf: &mut [u8],
        ) -> std::task::Poll<std::io::Result<usize>> {
            let reader = &mut self.get_mut().read;
            std::pin::Pin::new(reader).poll_read(cx, buf)
        }
    }

    /// Writing, flushing and closing are delegated to the pipe's writer end.
    impl AsyncWrite for AsyncBinaryStreamChannel {
        fn poll_write(
            self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
            buf: &[u8],
        ) -> std::task::Poll<std::io::Result<usize>> {
            let writer = &mut self.get_mut().write;
            std::pin::Pin::new(writer).poll_write(cx, buf)
        }

        fn poll_flush(
            self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<std::io::Result<()>> {
            let writer = &mut self.get_mut().write;
            std::pin::Pin::new(writer).poll_flush(cx)
        }

        fn poll_close(
            self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<std::io::Result<()>> {
            let writer = &mut self.get_mut().write;
            std::pin::Pin::new(writer).poll_close(cx)
        }
    }

    /// A frame written through the split write half must be decoded unchanged from the read half.
    #[tokio::test]
    async fn split_codec_should_always_produce_correct_data() -> anyhow::Result<()> {
        let codec = tokio_util::codec::BytesCodec::new();
        let expected = [0u8, 1u8, 2u8, 3u8, 4u8, 5u8];

        let (stream_rx, stream_tx) = AsyncBinaryStreamChannel::new().split();
        let mut writer = FramedWrite::new(stream_tx.compat_write(), codec);
        let reader = FramedRead::new(stream_rx.compat(), codec);

        writer
            .send(tokio_util::bytes::BytesMut::from(expected.as_ref()))
            .await
            .map_err(|_| anyhow::anyhow!("should not fail on send"))?;

        futures::pin_mut!(reader);
        let received = reader.next().await.context("Value must be present")??;

        assert_eq!(received, tokio_util::bytes::BytesMut::from(expected.as_ref()));

        Ok(())
    }
}