Skip to main content

hopr_protocol_session/socket/
state.rs

1use futures::channel::mpsc::Sender;
2
3use crate::{
4    errors::SessionError,
5    processing::types::FrameInspector,
6    protocol::{FrameAcknowledgements, FrameId, Segment, SegmentId, SegmentRequest, SeqIndicator, SessionMessage},
7};
8
9/// Components the `SessionSocket` exposes to a [`SocketState`].
10///
11/// This is the primary communication interface between the state and the socket.
12pub struct SocketComponents<const C: usize> {
13    /// Allows inspecting incomplete frames that are currently held by the socket.
14    ///
15    /// Some states might strictly require a frame inspector and may therefore
16    /// return an error in [`SocketState::run`] if not present.
17    pub inspector: Option<FrameInspector>,
18    /// Allows emitting control messages to the socket.
19    ///
20    /// It is a regular `SessionMessage` injected into the downstream.
21    pub ctl_tx: Sender<SessionMessage<C>>,
22}
23
24/// Abstraction of the `SessionSocket` state.
25pub trait SocketState<const C: usize>: Send {
26    /// Gets ID of this Session.
27    fn session_id(&self) -> &str;
28
29    /// Starts the necessary processes inside the state.
30    /// Should be idempotent if called multiple times.
31    fn run(&mut self, components: SocketComponents<C>) -> Result<(), SessionError>;
32
33    /// Stops processes inside the state for the given direction.
34    fn stop(&mut self) -> Result<(), SessionError>;
35
36    /// Called when the Socket receives a new segment from Downstream.
37    /// When the error is returned, the incoming segment is not passed Upstream.
38    fn incoming_segment(&mut self, id: &SegmentId, ind: SeqIndicator) -> Result<(), SessionError>;
39
40    /// Called when [segment retransmission request](SegmentRequest) is received from Downstream.
41    fn incoming_retransmission_request(&mut self, request: SegmentRequest<C>) -> Result<(), SessionError>;
42
43    /// Called when an [acknowledgement of frames](FrameAcknowledgements) is received from Downstream.
44    fn incoming_acknowledged_frames(&mut self, ack: FrameAcknowledgements<C>) -> Result<(), SessionError>;
45
46    /// Called when a complete Frame has been finalized from segments received from Downstream.
47    fn frame_complete(&mut self, id: FrameId) -> Result<(), SessionError>;
48
49    /// Called when a complete Frame emitted to Upstream in-sequence.
50    fn frame_emitted(&mut self, id: FrameId) -> Result<(), SessionError>;
51
52    /// Called when a frame could not be completed from the segments received from Downstream.
53    fn frame_discarded(&mut self, id: FrameId) -> Result<(), SessionError>;
54
55    /// Called when a segment of a Frame was sent to the Downstream.
56    fn segment_sent(&mut self, segment: &Segment) -> Result<(), SessionError>;
57
58    /// Convenience method to dispatch a `SessionMessage` to one of the available handlers.
59    fn incoming_message(&mut self, message: &SessionMessage<C>) -> Result<(), SessionError> {
60        match &message {
61            SessionMessage::Segment(s) => self.incoming_segment(&s.id(), s.seq_flags),
62            SessionMessage::Request(r) => self.incoming_retransmission_request(r.clone()),
63            SessionMessage::Acknowledge(a) => self.incoming_acknowledged_frames(a.clone()),
64        }
65    }
66}
67
/// Represents a stateless Session socket.
///
/// Every [`SocketState`] hook implemented for this type is a no-op that
/// always succeeds. The only data it holds is the Session ID, which is
/// returned from [`SocketState::session_id`].
#[derive(Clone)]
pub struct Stateless<const C: usize>(String);

impl<const C: usize> Stateless<C> {
    /// Creates a stateless socket state for the Session with the given ID.
    ///
    /// The ID is stored in its [`Display`](std::fmt::Display) representation.
    pub(crate) fn new<I: std::fmt::Display>(session_id: I) -> Self {
        Self(session_id.to_string())
    }
}
79
80impl<const C: usize> SocketState<C> for Stateless<C> {
81    fn session_id(&self) -> &str {
82        &self.0
83    }
84
85    fn run(&mut self, _: SocketComponents<C>) -> Result<(), SessionError> {
86        Ok(())
87    }
88
89    fn stop(&mut self) -> Result<(), SessionError> {
90        Ok(())
91    }
92
93    fn incoming_segment(&mut self, _: &SegmentId, _: SeqIndicator) -> Result<(), SessionError> {
94        Ok(())
95    }
96
97    fn incoming_retransmission_request(&mut self, _: SegmentRequest<C>) -> Result<(), SessionError> {
98        Ok(())
99    }
100
101    fn incoming_acknowledged_frames(&mut self, _: FrameAcknowledgements<C>) -> Result<(), SessionError> {
102        Ok(())
103    }
104
105    fn frame_complete(&mut self, _: FrameId) -> Result<(), SessionError> {
106        Ok(())
107    }
108
109    fn frame_emitted(&mut self, _: FrameId) -> Result<(), SessionError> {
110        Ok(())
111    }
112
113    fn frame_discarded(&mut self, _: FrameId) -> Result<(), SessionError> {
114        Ok(())
115    }
116
117    fn segment_sent(&mut self, _: &Segment) -> Result<(), SessionError> {
118        Ok(())
119    }
120}
121
#[cfg(test)]
mod tests {
    use std::{collections::HashSet, time::Duration};

    use anyhow::Context;
    use futures::{AsyncReadExt, AsyncWriteExt};
    use futures_time::future::FutureExt;

    use super::*;
    #[cfg(feature = "telemetry")]
    use crate::socket::telemetry::NoopTracker;
    use crate::{
        SessionSocket, SessionSocketConfig,
        utils::test::{FaultyNetworkConfig, setup_alice_bob},
    };

    /// Size of each frame written into the socket by the tests.
    const FRAME_SIZE: usize = 1500;

    /// MTU the test sockets are parameterized with.
    const MTU: usize = 1000;

    // Mock of `SocketState` used to set ordered call expectations on the state hooks.
    mockall::mock! {
        SockState {}
        impl SocketState<MTU> for SockState {
            fn session_id(&self) -> &str;
            fn run(&mut self, components: SocketComponents<MTU>) -> Result<(), SessionError>;
            fn stop(&mut self) -> Result<(), SessionError>;
            fn incoming_segment(&mut self, id: &SegmentId, ind: SeqIndicator) -> Result<(), SessionError>;
            fn incoming_retransmission_request(&mut self, request: SegmentRequest<MTU>) -> Result<(), SessionError>;
            fn incoming_acknowledged_frames(&mut self, ack: FrameAcknowledgements<MTU>) -> Result<(), SessionError>;
            fn frame_complete(&mut self, id: FrameId) -> Result<(), SessionError>;
            fn frame_emitted(&mut self, id: FrameId) -> Result<(), SessionError>;
            fn frame_discarded(&mut self, id: FrameId) -> Result<(), SessionError>;
            fn segment_sent(&mut self, segment: &Segment) -> Result<(), SessionError>;
        }
    }

    /// Cloneable wrapper around [`MockSockState`].
    ///
    /// All clones share one mock through `Arc<Mutex<_>>`, so expectations are
    /// checked against every call regardless of which clone receives it. The
    /// second field is the Session ID, used for logging and returned from
    /// `session_id`.
    #[derive(Clone)]
    struct CloneableMockState<'a>(std::sync::Arc<std::sync::Mutex<MockSockState>>, &'a str);

    impl<'a> CloneableMockState<'a> {
        pub fn new(state: MockSockState, id: &'a str) -> Self {
            Self(std::sync::Arc::new(std::sync::Mutex::new(state)), id)
        }
    }

    impl SocketState<MTU> for CloneableMockState<'_> {
        fn session_id(&self) -> &str {
            // Call the mock so its `session_id` expectation is exercised, but return the
            // stored ID: the mock's `&str` borrows from the lock guard and cannot outlive it.
            let _ = self.0.lock().unwrap().session_id();
            self.1
        }

        fn run(&mut self, components: SocketComponents<MTU>) -> Result<(), SessionError> {
            tracing::debug!(id = self.1, "run called");
            self.0.lock().unwrap().run(components)
        }

        fn stop(&mut self) -> Result<(), SessionError> {
            tracing::debug!(id = self.1, "stop called");
            self.0.lock().unwrap().stop()
        }

        fn incoming_segment(&mut self, id: &SegmentId, ind: SeqIndicator) -> Result<(), SessionError> {
            tracing::debug!(id = self.1, "incoming_segment called");
            self.0.lock().unwrap().incoming_segment(id, ind)
        }

        fn incoming_retransmission_request(&mut self, request: SegmentRequest<MTU>) -> Result<(), SessionError> {
            tracing::debug!(id = self.1, "incoming_retransmission_request called");
            self.0.lock().unwrap().incoming_retransmission_request(request)
        }

        fn incoming_acknowledged_frames(&mut self, ack: FrameAcknowledgements<MTU>) -> Result<(), SessionError> {
            tracing::debug!(id = self.1, "incoming_acknowledged_frames called");
            self.0.lock().unwrap().incoming_acknowledged_frames(ack)
        }

        fn frame_complete(&mut self, id: FrameId) -> Result<(), SessionError> {
            tracing::debug!(id = self.1, "frame_complete called");
            self.0.lock().unwrap().frame_complete(id)
        }

        fn frame_emitted(&mut self, id: FrameId) -> Result<(), SessionError> {
            tracing::debug!(id = self.1, "frame_emitted called");
            self.0.lock().unwrap().frame_emitted(id)
        }

        fn frame_discarded(&mut self, id: FrameId) -> Result<(), SessionError> {
            tracing::debug!(id = self.1, "frame_discarded called");
            self.0.lock().unwrap().frame_discarded(id)
        }

        fn segment_sent(&mut self, segment: &Segment) -> Result<(), SessionError> {
            tracing::debug!(id = self.1, "segment_sent called");
            self.0.lock().unwrap().segment_sent(segment)
        }
    }

    /// Alice writes `NUM_FRAMES` frames to Bob over a faulty network configured
    /// with `ids_to_drop = {0}`, so exactly one of Alice's segments never reaches Bob.
    ///
    /// The ordered mock expectations verify that:
    /// - Alice's state sees `run`, one `segment_sent` per data segment, `stop`,
    ///   the terminating segment, and a final `stop` when the socket is dropped.
    /// - Bob's state sees `run`, the surviving segments, frame 2 completed,
    ///   frame 1 discarded, frame 2 emitted, `stop`, the terminating segment,
    ///   and a final `stop` when the socket is dropped.
    #[test_log::test(tokio::test)]
    async fn session_socket_must_correctly_dispatch_segment_and_frame_state_events() -> anyhow::Result<()> {
        const NUM_FRAMES: usize = 2;

        // Data segments Alice emits for all frames (evaluates to 4 with the constants above).
        const NUM_SEGMENTS: usize = NUM_FRAMES * FRAME_SIZE / MTU + 1;

        let mut alice_seq = mockall::Sequence::new();
        let mut alice_state = MockSockState::new();
        alice_state.expect_session_id().return_const("alice".into());

        alice_state
            .expect_run()
            .once()
            .in_sequence(&mut alice_seq)
            .return_once(|_| Ok::<_, SessionError>(()));
        alice_state
            .expect_segment_sent()
            .times(NUM_SEGMENTS)
            .in_sequence(&mut alice_seq)
            .returning(|_| Ok::<_, SessionError>(()));
        alice_state
            .expect_stop()
            .once()
            .in_sequence(&mut alice_seq)
            .return_once(|| Ok::<_, SessionError>(()));
        alice_state
            .expect_segment_sent() // terminating segment
            .once()
            .in_sequence(&mut alice_seq)
            .return_once(|_| Ok::<_, SessionError>(()));
        // PinnedDrop on SessionSocket calls state.stop() again on drop.
        alice_state
            .expect_stop()
            .once()
            .in_sequence(&mut alice_seq)
            .return_once(|| Ok::<_, SessionError>(()));

        let mut bob_seq = mockall::Sequence::new();
        let mut bob_state = MockSockState::new();
        bob_state.expect_session_id().return_const("bob".into());

        bob_state
            .expect_run()
            .once()
            .in_sequence(&mut bob_seq)
            .return_once(|_| Ok::<_, SessionError>(()));
        // One segment is dropped by the faulty network.
        bob_state
            .expect_incoming_segment()
            .times(NUM_SEGMENTS - 1)
            .in_sequence(&mut bob_seq)
            .returning(|_, _| Ok::<_, SessionError>(()));
        bob_state
            .expect_frame_complete()
            .once()
            .in_sequence(&mut bob_seq)
            .with(mockall::predicate::eq(2))
            .returning(|_| Ok::<_, SessionError>(()));
        bob_state
            .expect_frame_discarded()
            .once()
            .in_sequence(&mut bob_seq)
            .with(mockall::predicate::eq(1))
            .returning(|_| Ok::<_, SessionError>(()));
        bob_state
            .expect_frame_emitted()
            .once()
            .in_sequence(&mut bob_seq)
            .with(mockall::predicate::eq(2))
            .returning(|_| Ok::<_, SessionError>(()));
        bob_state
            .expect_stop()
            .once()
            .in_sequence(&mut bob_seq)
            .return_once(|| Ok::<_, SessionError>(()));
        bob_state
            .expect_segment_sent() // terminating segment
            .once()
            .in_sequence(&mut bob_seq)
            .return_once(|_| Ok::<_, SessionError>(()));
        // PinnedDrop on SessionSocket calls state.stop() again on drop.
        bob_state
            .expect_stop()
            .once()
            .in_sequence(&mut bob_seq)
            .return_once(|| Ok::<_, SessionError>(()));

        let (alice, bob) = setup_alice_bob::<MTU>(
            FaultyNetworkConfig {
                avg_delay: Duration::from_millis(10),
                ids_to_drop: HashSet::from_iter([0_usize]),
                ..Default::default()
            },
            None,
            None,
        );

        let cfg = SessionSocketConfig {
            frame_size: FRAME_SIZE,
            frame_timeout: Duration::from_millis(55),
            ..Default::default()
        };

        let mut alice_socket = SessionSocket::new(
            alice,
            CloneableMockState::new(alice_state, "alice"),
            cfg,
            #[cfg(feature = "telemetry")]
            NoopTracker,
        )?;
        let mut bob_socket = SessionSocket::new(
            bob,
            CloneableMockState::new(bob_state, "bob"),
            cfg,
            #[cfg(feature = "telemetry")]
            NoopTracker,
        )?;

        let alice_sent_data = hopr_types::crypto_random::random_bytes::<{ NUM_FRAMES * FRAME_SIZE }>();
        alice_socket
            .write_all(&alice_sent_data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await
            .context("write_all timeout")??;
        alice_socket.flush().await?;

        // One entire frame is discarded
        let mut bob_recv_data = [0u8; (NUM_FRAMES - 1) * FRAME_SIZE];
        bob_socket
            .read_exact(&mut bob_recv_data)
            .timeout(futures_time::time::Duration::from_secs(2))
            .await
            .context("read_exact timeout")??;

        tracing::debug!("stopping");
        alice_socket.close().await?;
        bob_socket.close().await?;

        Ok(())
    }
}