1use std::sync::Arc;
5
6use futures::{
7 AsyncRead, AsyncReadExt, AsyncWrite, FutureExt, SinkExt as _, Stream, StreamExt,
8 channel::mpsc::{Receiver, Sender, channel},
9};
10use libp2p::PeerId;
11use tokio_util::{
12 codec::{Decoder, Encoder, FramedRead, FramedWrite},
13 compat::{FuturesAsyncReadCompatExt, FuturesAsyncWriteCompatExt},
14};
15
/// Default upper bound on how long opening an outgoing peer-to-peer stream may take.
///
/// Overridable at runtime through `HOPR_TRANSPORT_STREAM_OPEN_TIMEOUT_MS` (read in `process_stream_protocol`).
const GLOBAL_STREAM_OPEN_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(2_000);

/// Default number of outgoing messages the egress task processes concurrently.
///
/// Overridable at runtime through `HOPR_TRANSPORT_MAX_CONCURRENT_PACKETS` (read in `process_stream_protocol`).
const MAX_CONCURRENT_PACKETS: usize = 30;
21
/// Control handle for a transport protocol that can accept incoming, and open outgoing,
/// bidirectional byte streams with remote peers.
///
/// Both methods consume `self`; `process_stream_protocol` clones the handle before every
/// call, which is why it requires the implementor to also be `Clone`.
#[async_trait::async_trait]
pub trait BidirectionalStreamControl: std::fmt::Debug {
    /// Starts listening for incoming streams on the protocol.
    ///
    /// On success returns a stream yielding one `(peer, stream)` pair per accepted incoming
    /// stream, where `peer` identifies the remote side.
    fn accept(
        self,
    ) -> Result<impl Stream<Item = (PeerId, impl AsyncRead + AsyncWrite + Send)> + Send, impl std::error::Error>;

    /// Opens a new outgoing stream to `peer`.
    ///
    /// Callers in this module bound this with their own timeout; implementations are not
    /// required to time out by themselves (NOTE(review): confirm per implementation).
    async fn open(self, peer: PeerId) -> Result<impl AsyncRead + AsyncWrite + Send, impl std::error::Error>;
}
30
31fn build_peer_stream_io<S, C>(
32 peer: PeerId,
33 stream: S,
34 cache: moka::future::Cache<PeerId, Sender<<C as Decoder>::Item>>,
35 codec: C,
36 ingress_from_peers: Sender<(PeerId, <C as Decoder>::Item)>,
37) -> Sender<<C as Decoder>::Item>
38where
39 S: AsyncRead + AsyncWrite + Send + 'static,
40 C: Encoder<<C as Decoder>::Item> + Decoder + Send + Sync + Clone + 'static,
41 <C as Encoder<<C as Decoder>::Item>>::Error: std::fmt::Debug + std::fmt::Display + Send + Sync + 'static,
42 <C as Decoder>::Error: std::fmt::Debug + std::fmt::Display + Send + Sync + 'static,
43 <C as Decoder>::Item: Clone + Send + 'static,
44{
45 let (stream_rx, stream_tx) = stream.split();
46 let (send, recv) = channel::<<C as Decoder>::Item>(1000);
47 let cache_internal = cache.clone();
48
49 let mut frame_writer = FramedWrite::new(stream_tx.compat_write(), codec.clone());
50
51 frame_writer.set_backpressure_boundary(1);
53
54 hopr_async_runtime::prelude::spawn(
56 recv.inspect(move |_| tracing::trace!(%peer, "writing message to peer stream"))
57 .map(Ok)
58 .forward(frame_writer)
59 .inspect(move |res| {
60 tracing::debug!(%peer, ?res, component = "stream", "writing stream with peer finished");
61 }),
62 );
63
64 hopr_async_runtime::prelude::spawn(
66 FramedRead::new(stream_rx.compat(), codec)
67 .filter_map(move |v| async move {
68 match v {
69 Ok(v) => {
70 tracing::trace!(%peer, "read message from peer stream");
71 Some((peer, v))
72 }
73 Err(error) => {
74 tracing::error!(%error, "Error decoding object from the underlying stream");
75 None
76 }
77 }
78 })
79 .map(Ok)
80 .forward(ingress_from_peers)
81 .inspect(move |res| match res {
82 Ok(_) => tracing::debug!(%peer, component = "stream", "incoming stream done reading"),
83 Err(error) => {
84 tracing::error!(%peer, %error, component = "stream", "incoming stream failed on reading")
85 }
86 })
87 .then(move |_| {
88 let peer = peer;
90 async move {
91 cache_internal.invalidate(&peer).await;
92 }
93 }),
94 );
95
96 tracing::trace!(%peer, "created new io for peer");
97 send
98}
99
100pub async fn process_stream_protocol<C, V>(
101 codec: C,
102 control: V,
103) -> crate::errors::Result<(
104 Sender<(PeerId, <C as Decoder>::Item)>, Receiver<(PeerId, <C as Decoder>::Item)>, )>
107where
108 C: Encoder<<C as Decoder>::Item> + Decoder + Send + Sync + Clone + 'static,
109 <C as Encoder<<C as Decoder>::Item>>::Error: std::fmt::Debug + std::fmt::Display + Send + Sync + 'static,
110 <C as Decoder>::Error: std::fmt::Debug + std::fmt::Display + Send + Sync + 'static,
111 <C as Decoder>::Item: Clone + Send + 'static,
112 V: BidirectionalStreamControl + Clone + Send + Sync + 'static,
113{
114 let capacity = std::env::var("HOPR_INTERNAL_PROCESS_STREAM_CHANNEL_CAPACITY")
115 .ok()
116 .and_then(|s| s.trim().parse::<usize>().ok())
117 .filter(|&c| c >= 1024)
118 .unwrap_or(1_000_000);
119
120 let (tx_out, rx_out) = channel::<(PeerId, <C as Decoder>::Item)>(capacity);
121 let (tx_in, rx_in) = channel::<(PeerId, <C as Decoder>::Item)>(capacity);
122
123 let cache_out = moka::future::Cache::builder()
124 .max_capacity(2000)
125 .eviction_listener(|key: Arc<PeerId>, _, cause| {
126 tracing::trace!(peer = %key.as_ref(), ?cause, "evicting stream for peer");
127 })
128 .build();
129
130 let incoming = control
131 .clone()
132 .accept()
133 .map_err(|e| crate::errors::ProtocolError::Logic(format!("failed to listen on protocol: {e}")))?;
134
135 let cache_ingress = cache_out.clone();
136 let codec_ingress = codec.clone();
137 let tx_in_ingress = tx_in.clone();
138
139 let _ingress_process = hopr_async_runtime::prelude::spawn(
141 incoming
142 .for_each(move |(peer, stream)| {
143 let codec = codec_ingress.clone();
144 let cache = cache_ingress.clone();
145 let tx_in = tx_in_ingress.clone();
146
147 tracing::debug!(%peer, "received incoming peer-to-peer stream");
148
149 let send = build_peer_stream_io(peer, stream, cache.clone(), codec.clone(), tx_in.clone());
150
151 async move {
152 cache.insert(peer, send).await;
153 }
154 })
155 .inspect(|_| {
156 tracing::info!(
157 task = "ingress stream processing",
158 "long-running background task finished"
159 )
160 }),
161 );
162
163 let max_concurrent_packets = std::env::var("HOPR_TRANSPORT_MAX_CONCURRENT_PACKETS")
164 .ok()
165 .and_then(|v| v.parse().ok())
166 .unwrap_or(MAX_CONCURRENT_PACKETS);
167
168 let global_stream_open_timeout = std::env::var("HOPR_TRANSPORT_STREAM_OPEN_TIMEOUT_MS")
169 .ok()
170 .and_then(|v| v.parse().ok())
171 .map(std::time::Duration::from_millis)
172 .unwrap_or(GLOBAL_STREAM_OPEN_TIMEOUT);
173
174 let _egress_process = hopr_async_runtime::prelude::spawn(
176 rx_out
177 .inspect(|(peer, _)| tracing::trace!(%peer, "proceeding to deliver message to peer"))
178 .for_each_concurrent(max_concurrent_packets, move |(peer, msg)| {
179 let cache = cache_out.clone();
180 let control = control.clone();
181 let codec = codec.clone();
182 let tx_in = tx_in.clone();
183
184 async move {
185 let cache_clone = cache.clone();
186 tracing::trace!(%peer, "trying to deliver message to peer");
187
188 let cached: Result<Sender<<C as Decoder>::Item>, Arc<anyhow::Error>> = cache
189 .try_get_with(peer, async move {
190 tracing::trace!(%peer, "peer is not in cache, opening new stream");
191
192 use futures_time::future::FutureExt as TimeExt;
197 let stream = control
198 .open(peer)
199 .timeout(futures_time::time::Duration::from(global_stream_open_timeout))
200 .await
201 .map_err(|_| anyhow::anyhow!("timeout trying to open stream to {peer}"))?
202 .map_err(|e| anyhow::anyhow!("could not open outgoing peer-to-peer stream: {e}"))?;
203
204 tracing::debug!(%peer, "opening outgoing peer-to-peer stream");
205
206 Ok(build_peer_stream_io(
207 peer,
208 stream,
209 cache_clone.clone(),
210 codec.clone(),
211 tx_in.clone(),
212 ))
213 })
214 .await;
215
216 match cached {
217 Ok(mut cached) => {
218 if let Err(error) = cached.send(msg).await {
219 tracing::error!(%peer, %error, "error sending message to peer");
220 cache.invalidate(&peer).await;
221 } else {
222 tracing::trace!(%peer, "message sent to peer");
223 }
224 }
225 Err(error) => {
226 tracing::error!(%peer, %error, "failed to open a stream to peer");
227 }
228 }
229 }
230 })
231 .inspect(|_| {
232 tracing::info!(
233 task = "egress stream processing",
234 "long-running background task finished"
235 )
236 }),
237 );
238
239 Ok((tx_out, rx_in))
240}
241
#[cfg(test)]
mod tests {
    use anyhow::Context;
    use futures::SinkExt;

    use super::*;

    /// In-memory duplex byte stream backed by an `async_channel_io` pipe: bytes written to it
    /// become readable from the same value.
    struct AsyncBinaryStreamChannel {
        read: async_channel_io::ChannelReader,
        write: async_channel_io::ChannelWriter,
    }

    impl AsyncBinaryStreamChannel {
        pub fn new() -> Self {
            let (writer, reader) = async_channel_io::pipe();
            Self {
                read: reader,
                write: writer,
            }
        }
    }

    impl AsyncRead for AsyncBinaryStreamChannel {
        fn poll_read(
            self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
            buf: &mut [u8],
        ) -> std::task::Poll<std::io::Result<usize>> {
            // The struct is `Unpin` (required by `get_mut`), so its fields can be re-pinned directly.
            std::pin::Pin::new(&mut self.get_mut().read).poll_read(cx, buf)
        }
    }

    impl AsyncWrite for AsyncBinaryStreamChannel {
        fn poll_write(
            self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
            buf: &[u8],
        ) -> std::task::Poll<std::io::Result<usize>> {
            std::pin::Pin::new(&mut self.get_mut().write).poll_write(cx, buf)
        }

        fn poll_flush(
            self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<std::io::Result<()>> {
            std::pin::Pin::new(&mut self.get_mut().write).poll_flush(cx)
        }

        fn poll_close(
            self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<std::io::Result<()>> {
            std::pin::Pin::new(&mut self.get_mut().write).poll_close(cx)
        }
    }

    /// A frame written through the write half must be decoded unchanged from the read half.
    #[tokio::test]
    async fn split_codec_should_always_produce_correct_data() -> anyhow::Result<()> {
        let payload = [0u8, 1u8, 2u8, 3u8, 4u8, 5u8];
        let codec = tokio_util::codec::BytesCodec::new();

        let (read_half, write_half) = AsyncBinaryStreamChannel::new().split();
        let mut writer = FramedWrite::new(write_half.compat_write(), codec);
        let reader = FramedRead::new(read_half.compat(), codec);

        writer
            .send(tokio_util::bytes::BytesMut::from(payload.as_ref()))
            .await
            .map_err(|_| anyhow::anyhow!("should not fail on send"))?;

        futures::pin_mut!(reader);
        let received = reader.next().await.context("Value must be present")??;

        assert_eq!(received, tokio_util::bytes::BytesMut::from(payload.as_ref()));

        Ok(())
    }
}