Skip to main content

hopr_protocol_session/socket/
state.rs

1use futures::channel::mpsc::Sender;
2
3use crate::{
4    errors::SessionError,
5    processing::types::FrameInspector,
6    protocol::{FrameAcknowledgements, FrameId, Segment, SegmentId, SegmentRequest, SeqIndicator, SessionMessage},
7};
8
9/// Components the [`SessionSocket`] exposes to a [`SocketState`].
10///
11/// This is the primary communication interface between the state and the socket.
12pub struct SocketComponents<const C: usize> {
13    /// Allows inspecting incomplete frames that are currently held by the socket.
14    ///
15    /// Some states might strictly require a frame inspector and may therefore
16    /// return an error in [`SocketState::run`] if not present.
17    pub inspector: Option<FrameInspector>,
18    /// Allows emitting control messages to the socket.
19    ///
20    /// It is a regular [`SessionMessage`] injected into the downstream.
21    pub ctl_tx: Sender<SessionMessage<C>>,
22}
23
24/// Abstraction of the [`SessionSocket`](super::SessionSocket) state.
25pub trait SocketState<const C: usize>: Send {
26    /// Gets ID of this Session.
27    fn session_id(&self) -> &str;
28
29    /// Starts the necessary processes inside the state.
30    /// Should be idempotent if called multiple times.
31    fn run(&mut self, components: SocketComponents<C>) -> Result<(), SessionError>;
32
33    /// Stops processes inside the state for the given direction.
34    fn stop(&mut self) -> Result<(), SessionError>;
35
36    /// Called when the Socket receives a new segment from Downstream.
37    /// When the error is returned, the incoming segment is not passed Upstream.
38    fn incoming_segment(&mut self, id: &SegmentId, ind: SeqIndicator) -> Result<(), SessionError>;
39
40    /// Called when [segment retransmission request](SegmentRequest) is received from Downstream.
41    fn incoming_retransmission_request(&mut self, request: SegmentRequest<C>) -> Result<(), SessionError>;
42
43    /// Called when an [acknowledgement of frames](FrameAcknowledgements) is received from Downstream.
44    fn incoming_acknowledged_frames(&mut self, ack: FrameAcknowledgements<C>) -> Result<(), SessionError>;
45
46    /// Called when a complete Frame has been finalized from segments received from Downstream.
47    fn frame_complete(&mut self, id: FrameId) -> Result<(), SessionError>;
48
49    /// Called when a complete Frame emitted to Upstream in-sequence.
50    fn frame_emitted(&mut self, id: FrameId) -> Result<(), SessionError>;
51
52    /// Called when a frame could not be completed from the segments received from Downstream.
53    fn frame_discarded(&mut self, id: FrameId) -> Result<(), SessionError>;
54
55    /// Called when a segment of a Frame was sent to the Downstream.
56    fn segment_sent(&mut self, segment: &Segment) -> Result<(), SessionError>;
57
58    /// Convenience method to dispatch a [`message`](SessionMessage) to one of the available handlers.
59    fn incoming_message(&mut self, message: &SessionMessage<C>) -> Result<(), SessionError> {
60        match &message {
61            SessionMessage::Segment(s) => self.incoming_segment(&s.id(), s.seq_flags),
62            SessionMessage::Request(r) => self.incoming_retransmission_request(r.clone()),
63            SessionMessage::Acknowledge(a) => self.incoming_acknowledged_frames(a.clone()),
64        }
65    }
66}
67
/// Represents a stateless Session socket.
///
/// Every [`SocketState`] callback of this state is a no-op that returns `Ok(())`:
/// no frames are tracked, no control messages are emitted and nothing is logged.
#[derive(Clone)]
pub struct Stateless<const C: usize>(String);

impl<const C: usize> Stateless<C> {
    /// Creates a stateless socket state for the Session with the given ID.
    ///
    /// The ID is stored in its [`Display`](std::fmt::Display) form and returned
    /// by [`SocketState::session_id`].
    pub(crate) fn new<I: std::fmt::Display>(session_id: I) -> Self {
        Self(session_id.to_string())
    }
}
79
80impl<const C: usize> SocketState<C> for Stateless<C> {
81    fn session_id(&self) -> &str {
82        &self.0
83    }
84
85    fn run(&mut self, _: SocketComponents<C>) -> Result<(), SessionError> {
86        Ok(())
87    }
88
89    fn stop(&mut self) -> Result<(), SessionError> {
90        Ok(())
91    }
92
93    fn incoming_segment(&mut self, _: &SegmentId, _: SeqIndicator) -> Result<(), SessionError> {
94        Ok(())
95    }
96
97    fn incoming_retransmission_request(&mut self, _: SegmentRequest<C>) -> Result<(), SessionError> {
98        Ok(())
99    }
100
101    fn incoming_acknowledged_frames(&mut self, _: FrameAcknowledgements<C>) -> Result<(), SessionError> {
102        Ok(())
103    }
104
105    fn frame_complete(&mut self, _: FrameId) -> Result<(), SessionError> {
106        Ok(())
107    }
108
109    fn frame_emitted(&mut self, _: FrameId) -> Result<(), SessionError> {
110        Ok(())
111    }
112
113    fn frame_discarded(&mut self, _: FrameId) -> Result<(), SessionError> {
114        Ok(())
115    }
116
117    fn segment_sent(&mut self, _: &Segment) -> Result<(), SessionError> {
118        Ok(())
119    }
120}
121
#[cfg(test)]
mod tests {
    use std::{collections::HashSet, time::Duration};

    use anyhow::Context;
    use futures::{AsyncReadExt, AsyncWriteExt};
    use futures_time::future::FutureExt;

    use super::*;
    #[cfg(feature = "telemetry")]
    use crate::socket::telemetry::NoopTracker;
    use crate::{
        SessionSocket, SessionSocketConfig,
        utils::test::{FaultyNetworkConfig, setup_alice_bob},
    };

    const FRAME_SIZE: usize = 1500;

    const MTU: usize = 1000;

    mockall::mock! {
        SockState {}
        impl SocketState<MTU> for SockState {
            fn session_id(&self) -> &str;
            fn run(&mut self, components: SocketComponents<MTU>) -> Result<(), SessionError>;
            fn stop(&mut self) -> Result<(), SessionError>;
            fn incoming_segment(&mut self, id: &SegmentId, ind: SeqIndicator) -> Result<(), SessionError>;
            fn incoming_retransmission_request(&mut self, request: SegmentRequest<MTU>) -> Result<(), SessionError>;
            fn incoming_acknowledged_frames(&mut self, ack: FrameAcknowledgements<MTU>) -> Result<(), SessionError>;
            fn frame_complete(&mut self, id: FrameId) -> Result<(), SessionError>;
            fn frame_emitted(&mut self, id: FrameId) -> Result<(), SessionError>;
            fn frame_discarded(&mut self, id: FrameId) -> Result<(), SessionError>;
            fn segment_sent(&mut self, segment: &Segment) -> Result<(), SessionError>;
        }
    }

    /// Cloneable wrapper around [`MockSockState`].
    ///
    /// All clones share the same mock through `Arc<Mutex<_>>`, so calls made on any clone
    /// count towards the same expectations. The second field is a label used in debug logs
    /// and returned from `session_id`.
    #[derive(Clone)]
    struct CloneableMockState<'a>(std::sync::Arc<std::sync::Mutex<MockSockState>>, &'a str);

    impl<'a> CloneableMockState<'a> {
        pub fn new(state: MockSockState, id: &'a str) -> Self {
            Self(std::sync::Arc::new(std::sync::Mutex::new(state)), id)
        }
    }

    impl SocketState<MTU> for CloneableMockState<'_> {
        fn session_id(&self) -> &str {
            // Forward the call so the mock's expectation is exercised, but return our own label:
            // a borrow from the mock cannot outlive the mutex guard.
            let _ = self.0.lock().unwrap().session_id();
            self.1
        }

        fn run(&mut self, components: SocketComponents<MTU>) -> Result<(), SessionError> {
            tracing::debug!(id = self.1, "run called");
            self.0.lock().unwrap().run(components)
        }

        fn stop(&mut self) -> Result<(), SessionError> {
            tracing::debug!(id = self.1, "stop called");
            self.0.lock().unwrap().stop()
        }

        fn incoming_segment(&mut self, id: &SegmentId, ind: SeqIndicator) -> Result<(), SessionError> {
            tracing::debug!(id = self.1, "incoming_segment called");
            self.0.lock().unwrap().incoming_segment(id, ind)
        }

        fn incoming_retransmission_request(&mut self, request: SegmentRequest<MTU>) -> Result<(), SessionError> {
            tracing::debug!(id = self.1, "incoming_retransmission_request called");
            self.0.lock().unwrap().incoming_retransmission_request(request)
        }

        fn incoming_acknowledged_frames(&mut self, ack: FrameAcknowledgements<MTU>) -> Result<(), SessionError> {
            tracing::debug!(id = self.1, "incoming_acknowledged_frames called");
            self.0.lock().unwrap().incoming_acknowledged_frames(ack)
        }

        fn frame_complete(&mut self, id: FrameId) -> Result<(), SessionError> {
            tracing::debug!(id = self.1, "frame_complete called");
            self.0.lock().unwrap().frame_complete(id)
        }

        fn frame_emitted(&mut self, id: FrameId) -> Result<(), SessionError> {
            tracing::debug!(id = self.1, "frame_emitted called");
            self.0.lock().unwrap().frame_emitted(id)
        }

        fn frame_discarded(&mut self, id: FrameId) -> Result<(), SessionError> {
            tracing::debug!(id = self.1, "frame_discarded called");
            self.0.lock().unwrap().frame_discarded(id)
        }

        fn segment_sent(&mut self, segment: &Segment) -> Result<(), SessionError> {
            tracing::debug!(id = self.1, "segment_sent called");
            self.0.lock().unwrap().segment_sent(segment)
        }
    }

    #[test_log::test(tokio::test)]
    async fn session_socket_must_correctly_dispatch_segment_and_frame_state_events() -> anyhow::Result<()> {
        const NUM_FRAMES: usize = 2;

        const NUM_SEGMENTS: usize = NUM_FRAMES * FRAME_SIZE / MTU + 1;

        let mut alice_seq = mockall::Sequence::new();
        let mut alice_state = MockSockState::new();
        alice_state.expect_session_id().return_const("alice".into());

        alice_state
            .expect_run()
            .once()
            .in_sequence(&mut alice_seq)
            .return_once(|_| Ok::<_, SessionError>(()));
        alice_state
            .expect_segment_sent()
            .times(NUM_SEGMENTS)
            .in_sequence(&mut alice_seq)
            .returning(|_| Ok::<_, SessionError>(()));
        alice_state
            .expect_stop()
            .once()
            .in_sequence(&mut alice_seq)
            .return_once(|| Ok::<_, SessionError>(()));
        alice_state
            .expect_segment_sent() // terminating segment
            .once()
            .in_sequence(&mut alice_seq)
            .return_once(|_| Ok::<_, SessionError>(()));

        let mut bob_seq = mockall::Sequence::new();
        let mut bob_state = MockSockState::new();
        bob_state.expect_session_id().return_const("bob".into());

        bob_state
            .expect_run()
            .once()
            .in_sequence(&mut bob_seq)
            .return_once(|_| Ok::<_, SessionError>(()));
        bob_state
            .expect_incoming_segment()
            .times(NUM_SEGMENTS - 1)
            .in_sequence(&mut bob_seq)
            .returning(|_, _| Ok::<_, SessionError>(()));
        bob_state
            .expect_frame_complete()
            .once()
            .in_sequence(&mut bob_seq)
            .with(mockall::predicate::eq(2))
            .returning(|_| Ok::<_, SessionError>(()));
        bob_state
            .expect_frame_discarded()
            .once()
            .in_sequence(&mut bob_seq)
            .with(mockall::predicate::eq(1))
            .returning(|_| Ok::<_, SessionError>(()));
        bob_state
            .expect_frame_emitted()
            .once()
            .in_sequence(&mut bob_seq)
            .with(mockall::predicate::eq(2))
            .returning(|_| Ok::<_, SessionError>(()));
        bob_state
            .expect_stop()
            .once()
            .in_sequence(&mut bob_seq)
            .return_once(|| Ok::<_, SessionError>(()));
        bob_state
            .expect_segment_sent() // terminating segment
            .once()
            .in_sequence(&mut bob_seq)
            .return_once(|_| Ok::<_, SessionError>(()));

        let (alice, bob) = setup_alice_bob::<MTU>(
            FaultyNetworkConfig {
                avg_delay: Duration::from_millis(10),
                ids_to_drop: HashSet::from_iter([0_usize]),
                ..Default::default()
            },
            None,
            None,
        );

        let cfg = SessionSocketConfig {
            frame_size: FRAME_SIZE,
            frame_timeout: Duration::from_millis(55),
            ..Default::default()
        };

        let mut alice_socket = SessionSocket::new(
            alice,
            CloneableMockState::new(alice_state, "alice"),
            cfg,
            #[cfg(feature = "telemetry")]
            NoopTracker,
        )?;
        let mut bob_socket = SessionSocket::new(
            bob,
            CloneableMockState::new(bob_state, "bob"),
            cfg,
            #[cfg(feature = "telemetry")]
            NoopTracker,
        )?;

        let alice_sent_data = hopr_crypto_random::random_bytes::<{ NUM_FRAMES * FRAME_SIZE }>();
        alice_socket
            .write_all(&alice_sent_data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await
            .context("write_all timeout")??;
        alice_socket.flush().await?;

        // One entire frame is discarded
        let mut bob_recv_data = [0u8; (NUM_FRAMES - 1) * FRAME_SIZE];
        bob_socket
            .read_exact(&mut bob_recv_data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await
            .context("read_exact timeout")??;

        tracing::debug!("stopping");
        alice_socket.close().await?;
        bob_socket.close().await?;

        Ok(())
    }
}