// hopr_protocol_session/socket/state.rs

1use futures::channel::mpsc::UnboundedSender;
2
3use crate::{
4    errors::SessionError,
5    processing::types::FrameInspector,
6    protocol::{FrameAcknowledgements, FrameId, Segment, SegmentId, SegmentRequest, SeqIndicator, SessionMessage},
7};
8
9/// Components the [`SessionSocket`] exposes to a [`SocketState`].
10///
11/// This is the primary communication interface between the state and the socket.
12pub struct SocketComponents<const C: usize> {
13    /// Allows inspecting incomplete frames that are currently held by the socket.
14    ///
15    /// Some states might strictly require a frame inspector and may therefore
16    /// return an error in [`SocketState::run`] if not present.
17    pub inspector: Option<FrameInspector>,
18    /// Allows emitting control messages to the socket.
19    ///
20    /// It is a regular [`SessionMessage`] injected into the downstream.
21    pub ctl_tx: UnboundedSender<SessionMessage<C>>,
22}
23
24/// Abstraction of the [`SessionSocket`](super::SessionSocket) state.
25pub trait SocketState<const C: usize>: Send {
26    /// Gets ID of this Session.
27    fn session_id(&self) -> &str;
28
29    /// Starts the necessary processes inside the state.
30    /// Should be idempotent if called multiple times.
31    fn run(&mut self, components: SocketComponents<C>) -> Result<(), SessionError>;
32
33    /// Stops processes inside the state for the given direction.
34    fn stop(&mut self) -> Result<(), SessionError>;
35
36    /// Called when the Socket receives a new segment from Downstream.
37    /// When the error is returned, the incoming segment is not passed Upstream.
38    fn incoming_segment(&mut self, id: &SegmentId, ind: SeqIndicator) -> Result<(), SessionError>;
39
40    /// Called when [segment retransmission request](SegmentRequest) is received from Downstream.
41    fn incoming_retransmission_request(&mut self, request: SegmentRequest<C>) -> Result<(), SessionError>;
42
43    /// Called when an [acknowledgement of frames](FrameAcknowledgements) is received from Downstream.
44    fn incoming_acknowledged_frames(&mut self, ack: FrameAcknowledgements<C>) -> Result<(), SessionError>;
45
46    /// Called when a complete Frame has been finalized from segments received from Downstream.
47    fn frame_complete(&mut self, id: FrameId) -> Result<(), SessionError>;
48
49    /// Called when a complete Frame emitted to Upstream in-sequence.
50    fn frame_emitted(&mut self, id: FrameId) -> Result<(), SessionError>;
51
52    /// Called when a frame could not be completed from the segments received from Downstream.
53    fn frame_discarded(&mut self, id: FrameId) -> Result<(), SessionError>;
54
55    /// Called when a segment of a Frame was sent to the Downstream.
56    fn segment_sent(&mut self, segment: &Segment) -> Result<(), SessionError>;
57}
58
/// Represents a stateless Session socket.
///
/// Every [`SocketState`] callback is a no-op that succeeds immediately; no events
/// are tracked or logged. Only the Session ID is retained.
#[derive(Clone, Debug)]
pub struct Stateless<const C: usize>(String);

impl<const C: usize> Stateless<C> {
    /// Creates a stateless socket state for the Session with the given ID.
    ///
    /// The ID is stored in its [`Display`](std::fmt::Display) representation.
    pub(crate) fn new<I: std::fmt::Display>(session_id: I) -> Self {
        Self(session_id.to_string())
    }
}
70
71impl<const C: usize> SocketState<C> for Stateless<C> {
72    fn session_id(&self) -> &str {
73        &self.0
74    }
75
76    fn run(&mut self, _: SocketComponents<C>) -> Result<(), SessionError> {
77        Ok(())
78    }
79
80    fn stop(&mut self) -> Result<(), SessionError> {
81        Ok(())
82    }
83
84    fn incoming_segment(&mut self, _: &SegmentId, _: SeqIndicator) -> Result<(), SessionError> {
85        Ok(())
86    }
87
88    fn incoming_retransmission_request(&mut self, _: SegmentRequest<C>) -> Result<(), SessionError> {
89        Ok(())
90    }
91
92    fn incoming_acknowledged_frames(&mut self, _: FrameAcknowledgements<C>) -> Result<(), SessionError> {
93        Ok(())
94    }
95
96    fn frame_complete(&mut self, _: FrameId) -> Result<(), SessionError> {
97        Ok(())
98    }
99
100    fn frame_emitted(&mut self, _: FrameId) -> Result<(), SessionError> {
101        Ok(())
102    }
103
104    fn frame_discarded(&mut self, _: FrameId) -> Result<(), SessionError> {
105        Ok(())
106    }
107
108    fn segment_sent(&mut self, _: &Segment) -> Result<(), SessionError> {
109        Ok(())
110    }
111}
112
#[cfg(test)]
mod tests {
    use std::{collections::HashSet, time::Duration};

    use anyhow::Context;
    use futures::{AsyncReadExt, AsyncWriteExt};
    use futures_time::future::FutureExt;

    use super::*;
    use crate::{
        SessionSocket, SessionSocketConfig,
        utils::test::{FaultyNetworkConfig, setup_alice_bob},
    };

    const FRAME_SIZE: usize = 1500;

    const MTU: usize = 1000;

    mockall::mock! {
        SockState {}
        impl SocketState<MTU> for SockState {
            fn session_id(&self) -> &str;
            fn run(&mut self, components: SocketComponents<MTU>) -> Result<(), SessionError>;
            fn stop(&mut self) -> Result<(), SessionError>;
            fn incoming_segment(&mut self, id: &SegmentId, ind: SeqIndicator) -> Result<(), SessionError>;
            fn incoming_retransmission_request(&mut self, request: SegmentRequest<MTU>) -> Result<(), SessionError>;
            fn incoming_acknowledged_frames(&mut self, ack: FrameAcknowledgements<MTU>) -> Result<(), SessionError>;
            fn frame_complete(&mut self, id: FrameId) -> Result<(), SessionError>;
            fn frame_emitted(&mut self, id: FrameId) -> Result<(), SessionError>;
            fn frame_discarded(&mut self, id: FrameId) -> Result<(), SessionError>;
            fn segment_sent(&mut self, segment: &Segment) -> Result<(), SessionError>;
        }
    }

    /// Cloneable wrapper that shares a single [`MockSockState`] behind an `Arc<Mutex<_>>`.
    ///
    /// Every call is traced with the wrapper's ID (the second field) and then forwarded
    /// to the shared mock.
    #[derive(Clone)]
    struct CloneableMockState<'a>(std::sync::Arc<std::sync::Mutex<MockSockState>>, &'a str);

    impl<'a> CloneableMockState<'a> {
        pub fn new(state: MockSockState, id: &'a str) -> Self {
            Self(std::sync::Arc::new(std::sync::Mutex::new(state)), id)
        }

        /// Locks the shared mock for the duration of a single forwarded call.
        fn inner(&self) -> std::sync::MutexGuard<'_, MockSockState> {
            self.0.lock().unwrap()
        }
    }

    impl SocketState<MTU> for CloneableMockState<'_> {
        fn session_id(&self) -> &str {
            // Forward to the mock so the call is checked against its expectations;
            // the returned ID itself comes from the wrapper.
            let _ = self.inner().session_id();
            self.1
        }

        fn run(&mut self, components: SocketComponents<MTU>) -> Result<(), SessionError> {
            tracing::debug!(id = self.1, "run called");
            self.inner().run(components)
        }

        fn stop(&mut self) -> Result<(), SessionError> {
            tracing::debug!(id = self.1, "stop called");
            self.inner().stop()
        }

        fn incoming_segment(&mut self, id: &SegmentId, ind: SeqIndicator) -> Result<(), SessionError> {
            tracing::debug!(id = self.1, "incoming_segment called");
            self.inner().incoming_segment(id, ind)
        }

        fn incoming_retransmission_request(&mut self, request: SegmentRequest<MTU>) -> Result<(), SessionError> {
            tracing::debug!(id = self.1, "incoming_retransmission_request called");
            self.inner().incoming_retransmission_request(request)
        }

        fn incoming_acknowledged_frames(&mut self, ack: FrameAcknowledgements<MTU>) -> Result<(), SessionError> {
            tracing::debug!(id = self.1, "incoming_acknowledged_frames called");
            self.inner().incoming_acknowledged_frames(ack)
        }

        fn frame_complete(&mut self, id: FrameId) -> Result<(), SessionError> {
            tracing::debug!(id = self.1, "frame_complete called");
            self.inner().frame_complete(id)
        }

        fn frame_emitted(&mut self, id: FrameId) -> Result<(), SessionError> {
            tracing::debug!(id = self.1, "frame_emitted called");
            self.inner().frame_emitted(id)
        }

        fn frame_discarded(&mut self, id: FrameId) -> Result<(), SessionError> {
            tracing::debug!(id = self.1, "frame_discarded called");
            self.inner().frame_discarded(id)
        }

        fn segment_sent(&mut self, segment: &Segment) -> Result<(), SessionError> {
            tracing::debug!(id = self.1, "segment_sent called");
            self.inner().segment_sent(segment)
        }
    }

    #[test_log::test(tokio::test)]
    async fn session_socket_must_correctly_dispatch_segment_and_frame_state_events() -> anyhow::Result<()> {
        const NUM_FRAMES: usize = 2;

        // Expected number of data segments Alice sends for NUM_FRAMES frames at this MTU.
        const NUM_SEGMENTS: usize = NUM_FRAMES * FRAME_SIZE / MTU + 1;

        let mut alice_seq = mockall::Sequence::new();
        let mut alice_state = MockSockState::new();
        alice_state.expect_session_id().return_const("alice".into());

        alice_state
            .expect_run()
            .once()
            .in_sequence(&mut alice_seq)
            .return_once(|_| Ok::<_, SessionError>(()));
        alice_state
            .expect_segment_sent()
            .times(NUM_SEGMENTS)
            .in_sequence(&mut alice_seq)
            .returning(|_| Ok::<_, SessionError>(()));
        alice_state
            .expect_stop()
            .once()
            .in_sequence(&mut alice_seq)
            .return_once(|| Ok::<_, SessionError>(()));
        alice_state
            .expect_segment_sent() // terminating segment
            .once()
            .in_sequence(&mut alice_seq)
            .return_once(|_| Ok::<_, SessionError>(()));

        let mut bob_seq = mockall::Sequence::new();
        let mut bob_state = MockSockState::new();
        bob_state.expect_session_id().return_const("bob".into());

        bob_state
            .expect_run()
            .once()
            .in_sequence(&mut bob_seq)
            .return_once(|_| Ok::<_, SessionError>(()));
        bob_state
            .expect_incoming_segment()
            .times(NUM_SEGMENTS - 1)
            .in_sequence(&mut bob_seq)
            .returning(|_, _| Ok::<_, SessionError>(()));
        bob_state
            .expect_frame_complete()
            .once()
            .in_sequence(&mut bob_seq)
            .with(mockall::predicate::eq(2))
            .returning(|_| Ok::<_, SessionError>(()));
        bob_state
            .expect_frame_discarded()
            .once()
            .in_sequence(&mut bob_seq)
            .with(mockall::predicate::eq(1))
            .returning(|_| Ok::<_, SessionError>(()));
        bob_state
            .expect_frame_emitted()
            .once()
            .in_sequence(&mut bob_seq)
            .with(mockall::predicate::eq(2))
            .returning(|_| Ok::<_, SessionError>(()));
        bob_state
            .expect_stop()
            .once()
            .in_sequence(&mut bob_seq)
            .return_once(|| Ok::<_, SessionError>(()));
        bob_state
            .expect_segment_sent() // terminating segment
            .once()
            .in_sequence(&mut bob_seq)
            .return_once(|_| Ok::<_, SessionError>(()));

        let (alice, bob) = setup_alice_bob::<MTU>(
            FaultyNetworkConfig {
                avg_delay: Duration::from_millis(10),
                ids_to_drop: HashSet::from_iter([0_usize]),
                ..Default::default()
            },
            None,
            None,
        );

        let cfg = SessionSocketConfig {
            frame_size: FRAME_SIZE,
            frame_timeout: Duration::from_millis(55),
            ..Default::default()
        };

        let mut alice_socket = SessionSocket::new(alice, CloneableMockState::new(alice_state, "alice"), cfg)?;
        let mut bob_socket = SessionSocket::new(bob, CloneableMockState::new(bob_state, "bob"), cfg)?;

        let alice_sent_data = hopr_crypto_random::random_bytes::<{ NUM_FRAMES * FRAME_SIZE }>();
        alice_socket
            .write_all(&alice_sent_data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await
            .context("write_all timeout")??;
        alice_socket.flush().await?;

        // One entire frame is discarded
        let mut bob_recv_data = [0u8; (NUM_FRAMES - 1) * FRAME_SIZE];
        bob_socket
            .read_exact(&mut bob_recv_data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await
            .context("read_exact timeout")??;

        tracing::debug!("stopping");
        alice_socket.close().await?;
        bob_socket.close().await?;

        Ok(())
    }
}
323}