1use std::sync::Arc;
5
6use futures::{
7 AsyncRead, AsyncReadExt, AsyncWrite, FutureExt, SinkExt as _, Stream, StreamExt,
8 channel::mpsc::{Receiver, Sender, channel},
9};
10use libp2p::PeerId;
11use tokio_util::{
12 codec::{Decoder, Encoder, FramedRead, FramedWrite},
13 compat::{FuturesAsyncReadCompatExt, FuturesAsyncWriteCompatExt},
14};
15
/// Default upper bound on how long opening an outgoing stream to a peer may take.
///
/// Can be overridden at runtime through the `HOPR_TRANSPORT_STREAM_OPEN_TIMEOUT_MS`
/// environment variable (read by `process_stream_protocol`).
const GLOBAL_STREAM_OPEN_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(2_000);
/// Default number of outgoing messages that are delivered concurrently.
///
/// Can be overridden at runtime through the `HOPR_TRANSPORT_MAX_CONCURRENT_PACKETS`
/// environment variable (read by `process_stream_protocol`).
const MAX_CONCURRENT_PACKETS: usize = 30;
21
/// Control handle for a single stream protocol: accepts incoming and opens outgoing
/// bidirectional byte streams with peers.
///
/// Both methods consume `self`; callers such as `process_stream_protocol` therefore clone
/// the controller before each call (it requires `Clone` for that reason).
#[async_trait::async_trait]
pub trait BidirectionalStreamControl: std::fmt::Debug {
    /// Starts accepting incoming streams for this protocol.
    ///
    /// On success returns a stream yielding one `(peer, stream)` pair per accepted
    /// peer-to-peer stream; otherwise returns the error that prevented accepting.
    fn accept(
        self,
    ) -> Result<impl Stream<Item = (PeerId, impl AsyncRead + AsyncWrite + Send)> + Send, impl std::error::Error>;

    /// Opens a new outgoing bidirectional stream to `peer`.
    ///
    /// Returns the opened stream, or the error that prevented opening it. No timeout is
    /// applied here; callers bound the call themselves (NOTE(review): `process_stream_protocol`
    /// wraps it in a timeout — implementors should not rely on that).
    async fn open(self, peer: PeerId) -> Result<impl AsyncRead + AsyncWrite + Send, impl std::error::Error>;
}
30
31fn build_peer_stream_io<S, C>(
32 peer: PeerId,
33 stream: S,
34 cache: moka::future::Cache<PeerId, Sender<<C as Decoder>::Item>>,
35 codec: C,
36 ingress_from_peers: Sender<(PeerId, <C as Decoder>::Item)>,
37) -> Sender<<C as Decoder>::Item>
38where
39 S: AsyncRead + AsyncWrite + Send + 'static,
40 C: Encoder<<C as Decoder>::Item> + Decoder + Send + Sync + Clone + 'static,
41 <C as Encoder<<C as Decoder>::Item>>::Error: std::fmt::Debug + std::fmt::Display + Send + Sync + 'static,
42 <C as Decoder>::Error: std::fmt::Debug + std::fmt::Display + Send + Sync + 'static,
43 <C as Decoder>::Item: Clone + Send + 'static,
44{
45 let (stream_rx, stream_tx) = stream.split();
46 let (send, recv) = channel::<<C as Decoder>::Item>(1000);
47 let cache_internal = cache.clone();
48
49 let mut frame_writer = FramedWrite::new(stream_tx.compat_write(), codec.clone());
50
51 frame_writer.set_backpressure_boundary(1);
53
54 hopr_async_runtime::prelude::spawn(
56 recv.inspect(move |_| tracing::trace!(%peer, "writing message to peer stream"))
57 .map(Ok)
58 .forward(frame_writer)
59 .inspect(move |res| {
60 tracing::debug!(%peer, ?res, component = "stream", "writing stream with peer finished");
61 }),
62 );
63
64 hopr_async_runtime::prelude::spawn(
66 FramedRead::new(stream_rx.compat(), codec)
67 .filter_map(move |v| async move {
68 match v {
69 Ok(v) => {
70 tracing::trace!(%peer, "read message from peer stream");
71 Some((peer, v))
72 }
73 Err(error) => {
74 tracing::error!(%error, "Error decoding object from the underlying stream");
75 None
76 }
77 }
78 })
79 .map(Ok)
80 .forward(ingress_from_peers)
81 .inspect(move |res| match res {
82 Ok(_) => tracing::debug!(%peer, component = "stream", "incoming stream done reading"),
83 Err(error) => {
84 tracing::error!(%peer, %error, component = "stream", "incoming stream failed on reading")
85 }
86 })
87 .then(move |_| {
88 let peer = peer;
90 async move {
91 cache_internal.invalidate(&peer).await;
92 }
93 }),
94 );
95
96 tracing::trace!(%peer, "created new io for peer");
97 send
98}
99
100pub async fn process_stream_protocol<C, V>(
101 codec: C,
102 control: V,
103) -> crate::errors::Result<(
104 Sender<(PeerId, <C as Decoder>::Item)>, Receiver<(PeerId, <C as Decoder>::Item)>, )>
107where
108 C: Encoder<<C as Decoder>::Item> + Decoder + Send + Sync + Clone + 'static,
109 <C as Encoder<<C as Decoder>::Item>>::Error: std::fmt::Debug + std::fmt::Display + Send + Sync + 'static,
110 <C as Decoder>::Error: std::fmt::Debug + std::fmt::Display + Send + Sync + 'static,
111 <C as Decoder>::Item: Clone + Send + 'static,
112 V: BidirectionalStreamControl + Clone + Send + Sync + 'static,
113{
114 let (tx_out, rx_out) = channel::<(PeerId, <C as Decoder>::Item)>(10_000);
115 let (tx_in, rx_in) = channel::<(PeerId, <C as Decoder>::Item)>(10_000);
116
117 let cache_out = moka::future::Cache::builder()
118 .max_capacity(2000)
119 .eviction_listener(|key: Arc<PeerId>, _, cause| {
120 tracing::trace!(peer = %key.as_ref(), ?cause, "evicting stream for peer");
121 })
122 .build();
123
124 let incoming = control
125 .clone()
126 .accept()
127 .map_err(|e| crate::errors::ProtocolError::Logic(format!("failed to listen on protocol: {e}")))?;
128
129 let cache_ingress = cache_out.clone();
130 let codec_ingress = codec.clone();
131 let tx_in_ingress = tx_in.clone();
132
133 let _ingress_process = hopr_async_runtime::prelude::spawn(
135 incoming
136 .for_each(move |(peer, stream)| {
137 let codec = codec_ingress.clone();
138 let cache = cache_ingress.clone();
139 let tx_in = tx_in_ingress.clone();
140
141 tracing::debug!(%peer, "received incoming peer-to-peer stream");
142
143 let send = build_peer_stream_io(peer, stream, cache.clone(), codec.clone(), tx_in.clone());
144
145 async move {
146 cache.insert(peer, send).await;
147 }
148 })
149 .inspect(|_| {
150 tracing::info!(
151 task = "ingress stream processing",
152 "long-running background task finished"
153 )
154 }),
155 );
156
157 let max_concurrent_packets = std::env::var("HOPR_TRANSPORT_MAX_CONCURRENT_PACKETS")
158 .ok()
159 .and_then(|v| v.parse().ok())
160 .unwrap_or(MAX_CONCURRENT_PACKETS);
161
162 let global_stream_open_timeout = std::env::var("HOPR_TRANSPORT_STREAM_OPEN_TIMEOUT_MS")
163 .ok()
164 .and_then(|v| v.parse().ok())
165 .map(std::time::Duration::from_millis)
166 .unwrap_or(GLOBAL_STREAM_OPEN_TIMEOUT);
167
168 let _egress_process = hopr_async_runtime::prelude::spawn(
170 rx_out
171 .inspect(|(peer, _)| tracing::trace!(%peer, "proceeding to deliver message to peer"))
172 .for_each_concurrent(max_concurrent_packets, move |(peer, msg)| {
173 let cache = cache_out.clone();
174 let control = control.clone();
175 let codec = codec.clone();
176 let tx_in = tx_in.clone();
177
178 async move {
179 let cache_clone = cache.clone();
180 tracing::trace!(%peer, "trying to deliver message to peer");
181
182 let cached: Result<Sender<<C as Decoder>::Item>, Arc<anyhow::Error>> = cache
183 .try_get_with(peer, async move {
184 tracing::trace!(%peer, "peer is not in cache, opening new stream");
185
186 use futures_time::future::FutureExt as TimeExt;
191 let stream = control
192 .open(peer)
193 .timeout(futures_time::time::Duration::from(global_stream_open_timeout))
194 .await
195 .map_err(|_| anyhow::anyhow!("timeout trying to open stream to {peer}"))?
196 .map_err(|e| anyhow::anyhow!("could not open outgoing peer-to-peer stream: {e}"))?;
197
198 tracing::debug!(%peer, "opening outgoing peer-to-peer stream");
199
200 Ok(build_peer_stream_io(
201 peer,
202 stream,
203 cache_clone.clone(),
204 codec.clone(),
205 tx_in.clone(),
206 ))
207 })
208 .await;
209
210 match cached {
211 Ok(mut cached) => {
212 if let Err(error) = cached.send(msg).await {
213 tracing::error!(%peer, %error, "error sending message to peer");
214 cache.invalidate(&peer).await;
215 } else {
216 tracing::trace!(%peer, "message sent to peer");
217 }
218 }
219 Err(error) => {
220 tracing::error!(%peer, %error, "failed to open a stream to peer");
221 }
222 }
223 }
224 })
225 .inspect(|_| {
226 tracing::info!(
227 task = "egress stream processing",
228 "long-running background task finished"
229 )
230 }),
231 );
232
233 Ok((tx_out, rx_in))
234}
235
#[cfg(test)]
mod tests {
    use anyhow::Context;
    use futures::SinkExt;

    use super::*;

    /// In-memory duplex byte stream for tests, backed by an `async_channel_io` pipe.
    ///
    /// Bytes written to it are read back from the same value (the test below relies on this).
    struct AsyncBinaryStreamChannel {
        read: async_channel_io::ChannelReader,
        write: async_channel_io::ChannelWriter,
    }

    impl AsyncBinaryStreamChannel {
        pub fn new() -> Self {
            let (writer, reader) = async_channel_io::pipe();
            Self {
                read: reader,
                write: writer,
            }
        }
    }

    impl AsyncRead for AsyncBinaryStreamChannel {
        /// Delegates reading to the pipe's reader half.
        fn poll_read(
            self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
            buf: &mut [u8],
        ) -> std::task::Poll<std::io::Result<usize>> {
            let this = self.get_mut();
            std::pin::Pin::new(&mut this.read).poll_read(cx, buf)
        }
    }

    impl AsyncWrite for AsyncBinaryStreamChannel {
        /// Delegates writing to the pipe's writer half.
        fn poll_write(
            self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
            buf: &[u8],
        ) -> std::task::Poll<std::io::Result<usize>> {
            let this = self.get_mut();
            std::pin::Pin::new(&mut this.write).poll_write(cx, buf)
        }

        /// Delegates flushing to the pipe's writer half.
        fn poll_flush(
            self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<std::io::Result<()>> {
            let this = self.get_mut();
            std::pin::Pin::new(&mut this.write).poll_flush(cx)
        }

        /// Delegates closing to the pipe's writer half.
        fn poll_close(
            self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<std::io::Result<()>> {
            let this = self.get_mut();
            std::pin::Pin::new(&mut this.write).poll_close(cx)
        }
    }

    /// A frame written through the split write half must be decoded unchanged from the
    /// split read half of the same stream.
    #[tokio::test]
    async fn split_codec_should_always_produce_correct_data() -> anyhow::Result<()> {
        let expected = [0u8, 1u8, 2u8, 3u8, 4u8, 5u8];
        let codec = tokio_util::codec::BytesCodec::new();

        let (read_half, write_half) = AsyncBinaryStreamChannel::new().split();
        let mut framed_writer = FramedWrite::new(write_half.compat_write(), codec);
        let framed_reader = FramedRead::new(read_half.compat(), codec);

        framed_writer
            .send(tokio_util::bytes::BytesMut::from(expected.as_ref()))
            .await
            .map_err(|_| anyhow::anyhow!("should not fail on send"))?;

        let mut framed_reader = std::pin::pin!(framed_reader);
        let received = framed_reader.next().await.context("Value must be present")??;

        assert_eq!(received, tokio_util::bytes::BytesMut::from(expected.as_ref()));

        Ok(())
    }
}