hopr_protocol_session/socket/
state.rs1use futures::channel::mpsc::UnboundedSender;
2
3use crate::{
4 errors::SessionError,
5 processing::types::FrameInspector,
6 protocol::{FrameAcknowledgements, FrameId, Segment, SegmentId, SegmentRequest, SeqIndicator, SessionMessage},
7};
8
9pub struct SocketComponents<const C: usize> {
13 pub inspector: Option<FrameInspector>,
18 pub ctl_tx: UnboundedSender<SessionMessage<C>>,
22}
23
24pub trait SocketState<const C: usize>: Send {
26 fn session_id(&self) -> &str;
28
29 fn run(&mut self, components: SocketComponents<C>) -> Result<(), SessionError>;
32
33 fn stop(&mut self) -> Result<(), SessionError>;
35
36 fn incoming_segment(&mut self, id: &SegmentId, ind: SeqIndicator) -> Result<(), SessionError>;
39
40 fn incoming_retransmission_request(&mut self, request: SegmentRequest<C>) -> Result<(), SessionError>;
42
43 fn incoming_acknowledged_frames(&mut self, ack: FrameAcknowledgements<C>) -> Result<(), SessionError>;
45
46 fn frame_complete(&mut self, id: FrameId) -> Result<(), SessionError>;
48
49 fn frame_emitted(&mut self, id: FrameId) -> Result<(), SessionError>;
51
52 fn frame_discarded(&mut self, id: FrameId) -> Result<(), SessionError>;
54
55 fn segment_sent(&mut self, segment: &Segment) -> Result<(), SessionError>;
57}
58
/// Socket state that stores nothing but the session ID string.
///
/// The const parameter `C` is not used by the struct itself; it only selects which
/// `SocketState<C>` the type implements.
#[derive(Clone)]
pub struct Stateless<const C: usize>(String);

impl<const C: usize> Stateless<C> {
    /// Creates a new instance whose session ID is `session_id` rendered via its
    /// [`Display`](std::fmt::Display) implementation.
    pub(crate) fn new<I>(session_id: I) -> Self
    where
        I: std::fmt::Display,
    {
        let rendered_id = session_id.to_string();
        Self(rendered_id)
    }
}
70
71impl<const C: usize> SocketState<C> for Stateless<C> {
72 fn session_id(&self) -> &str {
73 &self.0
74 }
75
76 fn run(&mut self, _: SocketComponents<C>) -> Result<(), SessionError> {
77 Ok(())
78 }
79
80 fn stop(&mut self) -> Result<(), SessionError> {
81 Ok(())
82 }
83
84 fn incoming_segment(&mut self, _: &SegmentId, _: SeqIndicator) -> Result<(), SessionError> {
85 Ok(())
86 }
87
88 fn incoming_retransmission_request(&mut self, _: SegmentRequest<C>) -> Result<(), SessionError> {
89 Ok(())
90 }
91
92 fn incoming_acknowledged_frames(&mut self, _: FrameAcknowledgements<C>) -> Result<(), SessionError> {
93 Ok(())
94 }
95
96 fn frame_complete(&mut self, _: FrameId) -> Result<(), SessionError> {
97 Ok(())
98 }
99
100 fn frame_emitted(&mut self, _: FrameId) -> Result<(), SessionError> {
101 Ok(())
102 }
103
104 fn frame_discarded(&mut self, _: FrameId) -> Result<(), SessionError> {
105 Ok(())
106 }
107
108 fn segment_sent(&mut self, _: &Segment) -> Result<(), SessionError> {
109 Ok(())
110 }
111}
112
#[cfg(test)]
mod tests {
    use std::{collections::HashSet, time::Duration};

    use anyhow::Context;
    use futures::{AsyncReadExt, AsyncWriteExt};
    use futures_time::future::FutureExt;

    use super::*;
    use crate::{
        SessionSocket, SessionSocketConfig,
        utils::test::{FaultyNetworkConfig, setup_alice_bob},
    };

    /// Frame size configured on both sockets in the test below.
    const FRAME_SIZE: usize = 1500;

    /// Value of the const generic parameter `C` used for all sockets and states in this module.
    const MTU: usize = 1000;

    // Generates `MockSockState`, a mockall mock of `SocketState<MTU>`.
    mockall::mock! {
        SockState {}
        impl SocketState<MTU> for SockState {
            fn session_id(&self) -> &str;
            fn run(&mut self, components: SocketComponents<MTU>) -> Result<(), SessionError>;
            fn stop(&mut self) -> Result<(), SessionError>;
            fn incoming_segment(&mut self, id: &SegmentId, ind: SeqIndicator) -> Result<(), SessionError>;
            fn incoming_retransmission_request(&mut self, request: SegmentRequest<MTU>) -> Result<(), SessionError>;
            fn incoming_acknowledged_frames(&mut self, ack: FrameAcknowledgements<MTU>) -> Result<(), SessionError>;
            fn frame_complete(&mut self, id: FrameId) -> Result<(), SessionError>;
            fn frame_emitted(&mut self, id: FrameId) -> Result<(), SessionError>;
            fn frame_discarded(&mut self, id: FrameId) -> Result<(), SessionError>;
            fn segment_sent(&mut self, segment: &Segment) -> Result<(), SessionError>;
        }
    }

    /// `Clone`-able wrapper sharing a single [`MockSockState`] behind an `Arc<Mutex<_>>`,
    /// so that all clones record their calls on the same set of expectations.
    ///
    /// The second field is a label used as the session ID and in the debug logs.
    ///
    /// NOTE(review): presumably needed because `SessionSocket` requires a cloneable state — confirm.
    #[derive(Clone)]
    struct CloneableMockState<'a>(std::sync::Arc<std::sync::Mutex<MockSockState>>, &'a str);

    impl<'a> CloneableMockState<'a> {
        /// Wraps `state` for shared use and labels it with `id`.
        pub fn new(state: MockSockState, id: &'a str) -> Self {
            Self(std::sync::Arc::new(std::sync::Mutex::new(state)), id)
        }
    }

    /// Forwards every hook to the shared mock, logging the hook name first.
    ///
    /// A poisoned mutex (e.g. after a failed expectation in another clone) makes the
    /// `unwrap()` panic, which fails the test.
    impl SocketState<MTU> for CloneableMockState<'_> {
        fn session_id(&self) -> &str {
            // The mock is still called so that the invocation is registered with it,
            // but the returned ID is the label, which is not tied to the lock guard.
            let _ = self.0.lock().unwrap().session_id();
            self.1
        }

        fn run(&mut self, components: SocketComponents<MTU>) -> Result<(), SessionError> {
            tracing::debug!(id = self.1, "run called");
            self.0.lock().unwrap().run(components)
        }

        fn stop(&mut self) -> Result<(), SessionError> {
            tracing::debug!(id = self.1, "stop called");
            self.0.lock().unwrap().stop()
        }

        fn incoming_segment(&mut self, id: &SegmentId, ind: SeqIndicator) -> Result<(), SessionError> {
            tracing::debug!(id = self.1, "incoming_segment called");
            self.0.lock().unwrap().incoming_segment(id, ind)
        }

        fn incoming_retransmission_request(&mut self, request: SegmentRequest<MTU>) -> Result<(), SessionError> {
            tracing::debug!(id = self.1, "incoming_retransmission_request called");
            self.0.lock().unwrap().incoming_retransmission_request(request)
        }

        fn incoming_acknowledged_frames(&mut self, ack: FrameAcknowledgements<MTU>) -> Result<(), SessionError> {
            tracing::debug!(id = self.1, "incoming_acknowledged_frames called");
            self.0.lock().unwrap().incoming_acknowledged_frames(ack)
        }

        fn frame_complete(&mut self, id: FrameId) -> Result<(), SessionError> {
            tracing::debug!(id = self.1, "frame_complete called");
            self.0.lock().unwrap().frame_complete(id)
        }

        fn frame_emitted(&mut self, id: FrameId) -> Result<(), SessionError> {
            // Log the actual hook name (was "frame_received"), consistent with the other hooks.
            tracing::debug!(id = self.1, "frame_emitted called");
            self.0.lock().unwrap().frame_emitted(id)
        }

        fn frame_discarded(&mut self, id: FrameId) -> Result<(), SessionError> {
            tracing::debug!(id = self.1, "frame_discarded called");
            self.0.lock().unwrap().frame_discarded(id)
        }

        fn segment_sent(&mut self, segment: &Segment) -> Result<(), SessionError> {
            tracing::debug!(id = self.1, "segment_sent called");
            self.0.lock().unwrap().segment_sent(segment)
        }
    }

    /// Sends two frames from Alice to Bob over a network configured with `ids_to_drop = {0}`
    /// and checks, via ordered mock expectations, which state hooks fire on each side.
    #[test_log::test(tokio::test)]
    async fn session_socket_must_correctly_dispatch_segment_and_frame_state_events() -> anyhow::Result<()> {
        const NUM_FRAMES: usize = 2;

        // Evaluates to 4 for the constants used here.
        const NUM_SEGMENTS: usize = NUM_FRAMES * FRAME_SIZE / MTU + 1;

        // Alice: run -> NUM_SEGMENTS x segment_sent -> stop -> one more segment_sent.
        let mut alice_seq = mockall::Sequence::new();
        let mut alice_state = MockSockState::new();
        alice_state.expect_session_id().return_const("alice".into());

        alice_state
            .expect_run()
            .once()
            .in_sequence(&mut alice_seq)
            .return_once(|_| Ok::<_, SessionError>(()));
        alice_state
            .expect_segment_sent()
            .times(NUM_SEGMENTS)
            .in_sequence(&mut alice_seq)
            .returning(|_| Ok::<_, SessionError>(()));
        alice_state
            .expect_stop()
            .once()
            .in_sequence(&mut alice_seq)
            .return_once(|| Ok::<_, SessionError>(()));
        alice_state
            .expect_segment_sent() // one segment is still sent after `stop` (presumably the terminating one)
            .once()
            .in_sequence(&mut alice_seq)
            .return_once(|_| Ok::<_, SessionError>(()));

        // Bob: run -> (NUM_SEGMENTS - 1) x incoming_segment -> frame 2 complete
        //      -> frame 1 discarded -> frame 2 emitted -> stop -> one segment_sent.
        let mut bob_seq = mockall::Sequence::new();
        let mut bob_state = MockSockState::new();
        bob_state.expect_session_id().return_const("bob".into());

        bob_state
            .expect_run()
            .once()
            .in_sequence(&mut bob_seq)
            .return_once(|_| Ok::<_, SessionError>(()));
        bob_state
            .expect_incoming_segment()
            .times(NUM_SEGMENTS - 1)
            .in_sequence(&mut bob_seq)
            .returning(|_, _| Ok::<_, SessionError>(()));
        bob_state
            .expect_frame_complete()
            .once()
            .in_sequence(&mut bob_seq)
            .with(mockall::predicate::eq(2))
            .returning(|_| Ok::<_, SessionError>(()));
        bob_state
            .expect_frame_discarded()
            .once()
            .in_sequence(&mut bob_seq)
            .with(mockall::predicate::eq(1))
            .returning(|_| Ok::<_, SessionError>(()));
        bob_state
            .expect_frame_emitted()
            .once()
            .in_sequence(&mut bob_seq)
            .with(mockall::predicate::eq(2))
            .returning(|_| Ok::<_, SessionError>(()));
        bob_state
            .expect_stop()
            .once()
            .in_sequence(&mut bob_seq)
            .return_once(|| Ok::<_, SessionError>(()));
        bob_state
            .expect_segment_sent() // one segment is still sent after `stop` (presumably the terminating one)
            .once()
            .in_sequence(&mut bob_seq)
            .return_once(|_| Ok::<_, SessionError>(()));

        // NOTE(review): `ids_to_drop = {0}` presumably drops the very first packet, which is
        // why Bob expects one segment fewer and frame 1 to be discarded — confirm in `utils::test`.
        let (alice, bob) = setup_alice_bob::<MTU>(
            FaultyNetworkConfig {
                avg_delay: Duration::from_millis(10),
                ids_to_drop: HashSet::from_iter([0_usize]),
                ..Default::default()
            },
            None,
            None,
        );

        let cfg = SessionSocketConfig {
            frame_size: FRAME_SIZE,
            frame_timeout: Duration::from_millis(55),
            ..Default::default()
        };

        let mut alice_socket = SessionSocket::new(alice, CloneableMockState::new(alice_state, "alice"), cfg)?;
        let mut bob_socket = SessionSocket::new(bob, CloneableMockState::new(bob_state, "bob"), cfg)?;

        let alice_sent_data = hopr_crypto_random::random_bytes::<{ NUM_FRAMES * FRAME_SIZE }>();
        alice_socket
            .write_all(&alice_sent_data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await
            .context("write_all timeout")??;
        alice_socket.flush().await?;

        // Only one frame's worth of data is read: Bob is expected to emit a single frame.
        let mut bob_recv_data = [0u8; (NUM_FRAMES - 1) * FRAME_SIZE];
        bob_socket
            .read_exact(&mut bob_recv_data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await
            .context("read_exact timeout")??;

        tracing::debug!("stopping");
        alice_socket.close().await?;
        bob_socket.close().await?;

        Ok(())
    }
}