hopr_protocol_session/socket/
state.rs1use futures::channel::mpsc::Sender;
2
3use crate::{
4 errors::SessionError,
5 processing::types::FrameInspector,
6 protocol::{FrameAcknowledgements, FrameId, Segment, SegmentId, SegmentRequest, SeqIndicator, SessionMessage},
7};
8
9pub struct SocketComponents<const C: usize> {
13 pub inspector: Option<FrameInspector>,
18 pub ctl_tx: Sender<SessionMessage<C>>,
22}
23
24pub trait SocketState<const C: usize>: Send {
26 fn session_id(&self) -> &str;
28
29 fn run(&mut self, components: SocketComponents<C>) -> Result<(), SessionError>;
32
33 fn stop(&mut self) -> Result<(), SessionError>;
35
36 fn incoming_segment(&mut self, id: &SegmentId, ind: SeqIndicator) -> Result<(), SessionError>;
39
40 fn incoming_retransmission_request(&mut self, request: SegmentRequest<C>) -> Result<(), SessionError>;
42
43 fn incoming_acknowledged_frames(&mut self, ack: FrameAcknowledgements<C>) -> Result<(), SessionError>;
45
46 fn frame_complete(&mut self, id: FrameId) -> Result<(), SessionError>;
48
49 fn frame_emitted(&mut self, id: FrameId) -> Result<(), SessionError>;
51
52 fn frame_discarded(&mut self, id: FrameId) -> Result<(), SessionError>;
54
55 fn segment_sent(&mut self, segment: &Segment) -> Result<(), SessionError>;
57
58 fn incoming_message(&mut self, message: &SessionMessage<C>) -> Result<(), SessionError> {
60 match &message {
61 SessionMessage::Segment(s) => self.incoming_segment(&s.id(), s.seq_flags),
62 SessionMessage::Request(r) => self.incoming_retransmission_request(r.clone()),
63 SessionMessage::Acknowledge(a) => self.incoming_acknowledged_frames(a.clone()),
64 }
65 }
66}
67
/// A socket state that remembers nothing except the session ID.
///
/// `C` is the const size parameter of the [`SocketState`] it implements.
#[derive(Clone)]
pub struct Stateless<const C: usize>(String);

impl<const C: usize> Stateless<C> {
    /// Creates a state for the session identified by `session_id`.
    ///
    /// The ID is stored in its [`Display`](std::fmt::Display) form.
    pub(crate) fn new<I>(session_id: I) -> Self
    where
        I: std::fmt::Display,
    {
        let rendered_id = session_id.to_string();
        Self(rendered_id)
    }
}
79
80impl<const C: usize> SocketState<C> for Stateless<C> {
81 fn session_id(&self) -> &str {
82 &self.0
83 }
84
85 fn run(&mut self, _: SocketComponents<C>) -> Result<(), SessionError> {
86 Ok(())
87 }
88
89 fn stop(&mut self) -> Result<(), SessionError> {
90 Ok(())
91 }
92
93 fn incoming_segment(&mut self, _: &SegmentId, _: SeqIndicator) -> Result<(), SessionError> {
94 Ok(())
95 }
96
97 fn incoming_retransmission_request(&mut self, _: SegmentRequest<C>) -> Result<(), SessionError> {
98 Ok(())
99 }
100
101 fn incoming_acknowledged_frames(&mut self, _: FrameAcknowledgements<C>) -> Result<(), SessionError> {
102 Ok(())
103 }
104
105 fn frame_complete(&mut self, _: FrameId) -> Result<(), SessionError> {
106 Ok(())
107 }
108
109 fn frame_emitted(&mut self, _: FrameId) -> Result<(), SessionError> {
110 Ok(())
111 }
112
113 fn frame_discarded(&mut self, _: FrameId) -> Result<(), SessionError> {
114 Ok(())
115 }
116
117 fn segment_sent(&mut self, _: &Segment) -> Result<(), SessionError> {
118 Ok(())
119 }
120}
121
#[cfg(test)]
mod tests {
    use std::{
        collections::HashSet,
        sync::{Arc, Mutex},
        time::Duration,
    };

    use anyhow::Context;
    use futures::{AsyncReadExt, AsyncWriteExt};
    use futures_time::future::FutureExt;

    use super::*;
    #[cfg(feature = "telemetry")]
    use crate::socket::telemetry::NoopTracker;
    use crate::{
        SessionSocket, SessionSocketConfig,
        utils::test::{FaultyNetworkConfig, setup_alice_bob},
    };

    const FRAME_SIZE: usize = 1500;

    const MTU: usize = 1000;

    mockall::mock! {
        SockState {}
        impl SocketState<MTU> for SockState {
            fn session_id(&self) -> &str;
            fn run(&mut self, components: SocketComponents<MTU>) -> Result<(), SessionError>;
            fn stop(&mut self) -> Result<(), SessionError>;
            fn incoming_segment(&mut self, id: &SegmentId, ind: SeqIndicator) -> Result<(), SessionError>;
            fn incoming_retransmission_request(&mut self, request: SegmentRequest<MTU>) -> Result<(), SessionError>;
            fn incoming_acknowledged_frames(&mut self, ack: FrameAcknowledgements<MTU>) -> Result<(), SessionError>;
            fn frame_complete(&mut self, id: FrameId) -> Result<(), SessionError>;
            fn frame_emitted(&mut self, id: FrameId) -> Result<(), SessionError>;
            fn frame_discarded(&mut self, id: FrameId) -> Result<(), SessionError>;
            fn segment_sent(&mut self, segment: &Segment) -> Result<(), SessionError>;
        }
    }

    /// Cloneable wrapper around a shared [`MockSockState`], tagged with a human-readable ID.
    ///
    /// All clones forward to the same mock, so expectations set on it are checked no matter
    /// which clone receives the call.
    #[derive(Clone)]
    struct CloneableMockState<'a>(Arc<Mutex<MockSockState>>, &'a str);

    impl<'a> CloneableMockState<'a> {
        fn new(state: MockSockState, id: &'a str) -> Self {
            Self(Arc::new(Mutex::new(state)), id)
        }

        /// Logs `"<method> called"` and runs `call` against the locked inner mock.
        ///
        /// Centralizing this keeps the logged method name in one place per call site
        /// (previously each method repeated the log line by hand and one was mislabelled).
        fn forward<R>(&self, method: &str, call: impl FnOnce(&mut MockSockState) -> R) -> R {
            tracing::debug!(id = self.1, "{} called", method);
            let mut inner = self.0.lock().unwrap();
            call(&mut *inner)
        }
    }

    impl SocketState<MTU> for CloneableMockState<'_> {
        fn session_id(&self) -> &str {
            // Exercise the mock's `session_id` expectation, but answer with the locally stored
            // ID: a borrow from the mock cannot outlive the mutex guard.
            let _ = self.0.lock().unwrap().session_id();
            self.1
        }

        fn run(&mut self, components: SocketComponents<MTU>) -> Result<(), SessionError> {
            self.forward("run", |mock| mock.run(components))
        }

        fn stop(&mut self) -> Result<(), SessionError> {
            self.forward("stop", |mock| mock.stop())
        }

        fn incoming_segment(&mut self, id: &SegmentId, ind: SeqIndicator) -> Result<(), SessionError> {
            self.forward("incoming_segment", |mock| mock.incoming_segment(id, ind))
        }

        fn incoming_retransmission_request(&mut self, request: SegmentRequest<MTU>) -> Result<(), SessionError> {
            self.forward("incoming_retransmission_request", |mock| {
                mock.incoming_retransmission_request(request)
            })
        }

        fn incoming_acknowledged_frames(&mut self, ack: FrameAcknowledgements<MTU>) -> Result<(), SessionError> {
            self.forward("incoming_acknowledged_frames", |mock| mock.incoming_acknowledged_frames(ack))
        }

        fn frame_complete(&mut self, id: FrameId) -> Result<(), SessionError> {
            self.forward("frame_complete", |mock| mock.frame_complete(id))
        }

        fn frame_emitted(&mut self, id: FrameId) -> Result<(), SessionError> {
            self.forward("frame_emitted", |mock| mock.frame_emitted(id))
        }

        fn frame_discarded(&mut self, id: FrameId) -> Result<(), SessionError> {
            self.forward("frame_discarded", |mock| mock.frame_discarded(id))
        }

        fn segment_sent(&mut self, segment: &Segment) -> Result<(), SessionError> {
            self.forward("segment_sent", |mock| mock.segment_sent(segment))
        }
    }

    #[test_log::test(tokio::test)]
    async fn session_socket_must_correctly_dispatch_segment_and_frame_state_events() -> anyhow::Result<()> {
        const NUM_FRAMES: usize = 2;

        // Number of segments alice is expected to send for `NUM_FRAMES` frames
        // (evaluates to 4 with the constants above).
        const NUM_SEGMENTS: usize = NUM_FRAMES * FRAME_SIZE / MTU + 1;

        // Alice: run -> segment_sent x NUM_SEGMENTS -> stop -> one more segment_sent
        // (after `stop`, i.e. while the socket is being closed).
        let mut alice_seq = mockall::Sequence::new();
        let mut alice_state = MockSockState::new();
        alice_state.expect_session_id().return_const("alice".into());

        alice_state
            .expect_run()
            .once()
            .in_sequence(&mut alice_seq)
            .return_once(|_| Ok::<_, SessionError>(()));
        alice_state
            .expect_segment_sent()
            .times(NUM_SEGMENTS)
            .in_sequence(&mut alice_seq)
            .returning(|_| Ok::<_, SessionError>(()));
        alice_state
            .expect_stop()
            .once()
            .in_sequence(&mut alice_seq)
            .return_once(|| Ok::<_, SessionError>(()));
        alice_state
            .expect_segment_sent()
            .once()
            .in_sequence(&mut alice_seq)
            .return_once(|_| Ok::<_, SessionError>(()));

        // Bob: run -> incoming_segment x (NUM_SEGMENTS - 1) (the network is configured below to
        // drop id 0) -> frame 2 complete -> frame 1 discarded -> frame 2 emitted -> stop
        // -> one segment_sent.
        let mut bob_seq = mockall::Sequence::new();
        let mut bob_state = MockSockState::new();
        bob_state.expect_session_id().return_const("bob".into());

        bob_state
            .expect_run()
            .once()
            .in_sequence(&mut bob_seq)
            .return_once(|_| Ok::<_, SessionError>(()));
        bob_state
            .expect_incoming_segment()
            .times(NUM_SEGMENTS - 1)
            .in_sequence(&mut bob_seq)
            .returning(|_, _| Ok::<_, SessionError>(()));
        bob_state
            .expect_frame_complete()
            .once()
            .in_sequence(&mut bob_seq)
            .with(mockall::predicate::eq(2))
            .returning(|_| Ok::<_, SessionError>(()));
        bob_state
            .expect_frame_discarded()
            .once()
            .in_sequence(&mut bob_seq)
            .with(mockall::predicate::eq(1))
            .returning(|_| Ok::<_, SessionError>(()));
        bob_state
            .expect_frame_emitted()
            .once()
            .in_sequence(&mut bob_seq)
            .with(mockall::predicate::eq(2))
            .returning(|_| Ok::<_, SessionError>(()));
        bob_state
            .expect_stop()
            .once()
            .in_sequence(&mut bob_seq)
            .return_once(|| Ok::<_, SessionError>(()));
        bob_state
            .expect_segment_sent()
            .once()
            .in_sequence(&mut bob_seq)
            .return_once(|_| Ok::<_, SessionError>(()));

        let (alice, bob) = setup_alice_bob::<MTU>(
            FaultyNetworkConfig {
                avg_delay: Duration::from_millis(10),
                ids_to_drop: HashSet::from_iter([0_usize]),
                ..Default::default()
            },
            None,
            None,
        );

        let cfg = SessionSocketConfig {
            frame_size: FRAME_SIZE,
            frame_timeout: Duration::from_millis(55),
            ..Default::default()
        };

        let mut alice_socket = SessionSocket::new(
            alice,
            CloneableMockState::new(alice_state, "alice"),
            cfg,
            #[cfg(feature = "telemetry")]
            NoopTracker,
        )?;
        let mut bob_socket = SessionSocket::new(
            bob,
            CloneableMockState::new(bob_state, "bob"),
            cfg,
            #[cfg(feature = "telemetry")]
            NoopTracker,
        )?;

        let alice_sent_data = hopr_crypto_random::random_bytes::<{ NUM_FRAMES * FRAME_SIZE }>();
        alice_socket
            .write_all(&alice_sent_data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await
            .context("write_all timeout")??;
        // Bounded like the other I/O awaits so a stuck flush fails the test instead of hanging it.
        alice_socket
            .flush()
            .timeout(futures_time::time::Duration::from_secs(2))
            .await
            .context("flush timeout")??;

        // Only one frame's worth of data is expected to reach bob (frame 1 is discarded).
        let mut bob_recv_data = [0u8; (NUM_FRAMES - 1) * FRAME_SIZE];
        bob_socket
            .read_exact(&mut bob_recv_data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await
            .context("read_exact timeout")??;

        tracing::debug!("stopping");
        alice_socket.close().await?;
        bob_socket.close().await?;

        Ok(())
    }
}