1use futures::{AsyncRead, AsyncReadExt, AsyncWrite, SinkExt as _, Stream, StreamExt};
5use libp2p::PeerId;
6use tokio_util::{
7 codec::{Decoder, Encoder, FramedRead, FramedWrite},
8 compat::{FuturesAsyncReadCompatExt, FuturesAsyncWriteCompatExt},
9};
10
11#[async_trait::async_trait]
12pub trait BidirectionalStreamControl: std::fmt::Debug {
13 fn accept(
14 self,
15 ) -> Result<impl Stream<Item = (PeerId, impl AsyncRead + AsyncWrite + Send)> + Send, impl std::error::Error>;
16
17 async fn open(self, peer: PeerId) -> Result<impl AsyncRead + AsyncWrite + Send, impl std::error::Error>;
18}
19
20pub async fn process_stream_protocol<C, V>(
21 codec: C,
22 control: V,
23) -> crate::errors::Result<(
24 futures::channel::mpsc::Sender<(PeerId, <C as Decoder>::Item)>, futures::channel::mpsc::Receiver<(PeerId, <C as Decoder>::Item)>, )>
27where
28 C: Encoder<<C as Decoder>::Item> + Decoder + Send + Sync + Clone + 'static,
29 <C as Encoder<<C as Decoder>::Item>>::Error: std::fmt::Debug + std::fmt::Display + Send + Sync + 'static,
30 <C as Decoder>::Error: std::fmt::Debug + std::fmt::Display + Send + Sync + 'static,
31 <C as Decoder>::Item: Send + 'static,
32 V: BidirectionalStreamControl + Clone + Send + Sync + 'static,
33{
34 let (tx_out, rx_out) = futures::channel::mpsc::channel::<(PeerId, <C as Decoder>::Item)>(10_000);
35 let (tx_in, rx_in) = futures::channel::mpsc::channel::<(PeerId, <C as Decoder>::Item)>(10_000);
36
37 let cache_out = moka::future::Cache::new(2000);
38
39 let incoming = control
40 .clone()
41 .accept()
42 .map_err(|e| crate::errors::ProtocolError::Logic(format!("failed to listen on protocol: {e}")))?;
43
44 let cache_ingress = cache_out.clone();
45 let codec_ingress = codec.clone();
46 let tx_in_ingress = tx_in.clone();
47
48 let _ingress_process = hopr_async_runtime::prelude::spawn(incoming.for_each(move |(peer_id, stream)| {
50 let codec = codec_ingress.clone();
51 let cache = cache_ingress.clone();
52 let tx_in = tx_in_ingress.clone();
53
54 tracing::debug!(peer = %peer_id, "Received incoming peer-to-peer stream");
55
56 async move {
57 let (stream_rx, stream_tx) = stream.split();
58 let (send, recv) = futures::channel::mpsc::channel::<<C as Decoder>::Item>(1000);
59
60 hopr_async_runtime::prelude::spawn(
61 recv.map(Ok)
62 .forward(FramedWrite::new(stream_tx.compat_write(), codec.clone())),
63 );
64 hopr_async_runtime::prelude::spawn(
65 FramedRead::new(stream_rx.compat(), codec)
66 .filter_map(move |v| async move {
67 match v {
68 Ok(v) => Some((peer_id, v)),
69 Err(e) => {
70 tracing::error!(error = %e, "Error decoding object from the underlying stream");
71 None
72 }
73 }
74 })
75 .map(Ok)
76 .forward(tx_in),
77 );
78 cache.insert(peer_id, send).await;
79 }
80 }));
81
82 let _egress_process = hopr_async_runtime::prelude::spawn(rx_out.for_each(move |(peer_id, msg)| {
84 let cache = cache_out.clone();
85 let control = control.clone();
86 let codec = codec.clone();
87 let tx_in = tx_in.clone();
88
89 async move {
90 let cache = cache.clone();
91
92 let cached = cache
93 .optionally_get_with(peer_id, async move {
94 let r = control.open(peer_id).await;
95 match r {
96 Ok(stream) => {
97 tracing::debug!(peer = %peer_id, "Opening outgoing peer-to-peer stream");
98
99 let (stream_rx, stream_tx) = stream.split();
100 let (send, recv) = futures::channel::mpsc::channel::<<C as Decoder>::Item>(1000);
101
102 hopr_async_runtime::prelude::spawn(
103 recv.map(Ok)
104 .forward(FramedWrite::new(stream_tx.compat_write(), codec.clone())),
105 );
106 hopr_async_runtime::prelude::spawn(
107 FramedRead::new(stream_rx.compat(), codec)
108 .filter_map(move |v| async move {
109 match v {
110 Ok(v) => Some((peer_id, v)),
111 Err(e) => {
112 tracing::error!(error = %e, "Error decoding object from the underlying stream");
113 None
114 }
115 }
116 })
117 .map(Ok)
118 .forward(tx_in),
119 );
120
121 Some(send)
122 }
123 Err(error) => {
124 tracing::error!(peer = %peer_id, %error, "Error opening stream to peer");
125 None
126 }
127 }
128 })
129 .await;
130
131 if let Some(mut cached) = cached {
132 cached.send(msg).await.unwrap_or_else(|e| {
133 tracing::error!(peer = %peer_id, error = %e, "Error sending message to peer");
134 });
135 } else {
136 tracing::error!(peer = %peer_id, "Error sending message to peer: the stream failed to be created and cached");
137 }
138 }
139 }));
140
141 Ok((tx_out, rx_in))
142}
143
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Context;
    use futures::SinkExt;

    /// In-memory duplex byte stream backed by an `async_channel_io` pipe: reads come from
    /// the pipe's reader end and writes go to its writer end.
    struct AsyncBinaryStreamChannel {
        read: async_channel_io::ChannelReader,
        write: async_channel_io::ChannelWriter,
    }

    impl AsyncBinaryStreamChannel {
        pub fn new() -> Self {
            let (writer_end, reader_end) = async_channel_io::pipe();
            Self {
                read: reader_end,
                write: writer_end,
            }
        }
    }

    impl AsyncRead for AsyncBinaryStreamChannel {
        fn poll_read(
            self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
            buf: &mut [u8],
        ) -> std::task::Poll<std::io::Result<usize>> {
            // Delegate to the reader end of the pipe.
            let this = self.get_mut();
            std::pin::Pin::new(&mut this.read).poll_read(cx, buf)
        }
    }

    impl AsyncWrite for AsyncBinaryStreamChannel {
        fn poll_write(
            self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
            buf: &[u8],
        ) -> std::task::Poll<std::io::Result<usize>> {
            // Delegate to the writer end of the pipe.
            let this = self.get_mut();
            std::pin::Pin::new(&mut this.write).poll_write(cx, buf)
        }

        fn poll_flush(
            self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<std::io::Result<()>> {
            let this = self.get_mut();
            std::pin::Pin::new(&mut this.write).poll_flush(cx)
        }

        fn poll_close(
            self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<std::io::Result<()>> {
            let this = self.get_mut();
            std::pin::Pin::new(&mut this.write).poll_close(cx)
        }
    }

    /// A frame written through the split write half must be decoded unchanged from the
    /// split read half.
    #[async_std::test]
    async fn split_codec_should_always_produce_correct_data() -> anyhow::Result<()> {
        let expected = [0u8, 1u8, 2u8, 3u8, 4u8, 5u8];
        let codec = tokio_util::codec::BytesCodec::new();

        let (read_half, write_half) = AsyncBinaryStreamChannel::new().split();
        let mut sink = FramedWrite::new(write_half.compat_write(), codec.clone());
        let source = FramedRead::new(read_half.compat(), codec);

        let payload = tokio_util::bytes::BytesMut::from(expected.as_ref());
        sink.send(payload)
            .await
            .map_err(|_| anyhow::anyhow!("should not fail on send"))?;

        let mut source = std::pin::pin!(source);
        let received = source.next().await.context("Value must be present")??;

        assert_eq!(received, tokio_util::bytes::BytesMut::from(expected.as_ref()));

        Ok(())
    }
}