Skip to main content

hopr_async_runtime/
lib.rs

1//! Executor API for HOPR which exposes the necessary async functions depending on the enabled
2//! runtime.
3
4use std::hash::Hash;
5
6pub use futures::future::AbortHandle;
7
/// Commonly used re-exports: lock primitives, abort helpers and, with `runtime-tokio`, task and timer
/// functions.
///
/// Both the `async-lock` and the `runtime-tokio` features could be enabled at the same time (e.g. during
/// testing). The lock types are therefore taken from `tokio` only when `async-lock` is not enabled;
/// whenever `async-lock` is enabled, the `async_lock` types are exported instead.
pub mod prelude {
    // Lock types from `async_lock`; these win whenever the `async-lock` feature is enabled.
    #[cfg(feature = "async-lock")]
    pub use async_lock::{Mutex, MutexGuard, RwLock, RwLockReadGuard, RwLockWriteGuard};
    // Abort helpers are exported regardless of the enabled features.
    pub use futures::future::{AbortHandle, abortable};
    // Tokio lock types are the fallback used only when `async-lock` is disabled.
    #[cfg(all(feature = "runtime-tokio", not(feature = "async-lock")))]
    pub use tokio::sync::{Mutex, MutexGuard, RwLock, RwLockReadGuard, RwLockWriteGuard};
    // Task spawning and timing utilities; note that `timeout` is re-exported under the name `timeout_fut`.
    #[cfg(feature = "runtime-tokio")]
    pub use tokio::{
        task::{JoinError, JoinHandle, spawn, spawn_blocking, spawn_local},
        time::{sleep, timeout as timeout_fut},
    };
}
22
/// Spawns a future as a background task and evaluates to the `AbortHandle` that can stop it.
///
/// The future is wrapped via `prelude::abortable` and the wrapped future is passed to `prelude::spawn`.
/// The join handle returned by `spawn` is intentionally dropped, which leaves the returned abort handle
/// as the only means of controlling the task.
///
/// The macro takes exactly one expression (the future to spawn), optionally followed by a trailing
/// comma. `prelude::abortable` accepts a single future, so any other number of arguments is rejected
/// directly at the macro invocation instead of failing inside the expansion.
///
/// The expansion refers to `prelude::spawn`, which is exported only with the `runtime-tokio` feature.
#[macro_export]
macro_rules! spawn_as_abortable {
    ($future:expr $(,)?) => {{
        let (task, abort_handle) = $crate::prelude::abortable($future);
        // The join handle is not needed; the task is controlled through the abort handle only.
        let _detached = $crate::prelude::spawn(task);
        abort_handle
    }};
}
31
/// Abstraction over tasks that can be aborted (such as join or abort handles).
///
/// Through the `auto_impl` attribute, the trait is also available on `&A`, `Box<A>` and `Arc<A>`
/// for any `A` that implements it.
#[auto_impl::auto_impl(&, Box, Arc)]
pub trait Abortable {
    /// Notifies the task that it should abort.
    ///
    /// Implementations must be idempotent and must not panic when called repeatedly. Callers cannot
    /// reliably tell whether a previous call already happened, because the meaning of
    /// [`Abortable::was_aborted`] is implementation-specific.
    fn abort_task(&self);

    /// Returns `true` if [`abort_task`](Abortable::abort_task) was already called or the task has finished.
    ///
    /// It is implementation-specific whether `true` actually means that the task has finished.
    /// Calling [`Abortable::abort_task`] after `true` was returned is therefore allowed and has no
    /// consequence.
    fn was_aborted(&self) -> bool;
}
47
48impl Abortable for AbortHandle {
49    fn abort_task(&self) {
50        self.abort();
51    }
52
53    fn was_aborted(&self) -> bool {
54        self.is_aborted()
55    }
56}
57
/// Allows the join handle of a Tokio task without a return value to be managed as an [`Abortable`] task.
#[cfg(feature = "runtime-tokio")]
impl Abortable for tokio::task::JoinHandle<()> {
    /// Requests the abort by delegating to the join handle's own `abort`.
    fn abort_task(&self) {
        Self::abort(self);
    }

    /// Reports the join handle's `is_finished` state.
    ///
    /// This implementation does not track abort requests; it only tells whether the task has finished.
    fn was_aborted(&self) -> bool {
        Self::is_finished(self)
    }
}
68
/// List of [`Abortable`] tasks with each task identified by a unique key of type `T`.
///
/// Abortable objects, such as join or abort handles, do not by design abort when dropped.
/// Sometimes this behavior is not desirable, because spawned run-away tasks may continue to live,
/// e.g. after an error is raised.
///
/// This object allows safely managing abortable tasks: once dropped, it aborts all the tasks it still
/// holds in the reverse insertion order.
///
/// Additionally, this object also implements [`Abortable`], allowing it to be arbitrarily nested.
///
/// The tasks are stored as boxed trait objects in an insertion-ordered map keyed by `T`.
pub struct AbortableList<T>(indexmap::IndexMap<T, Box<dyn Abortable + Send + Sync>>);
80
81impl<T> Default for AbortableList<T> {
82    fn default() -> Self {
83        Self(indexmap::IndexMap::new())
84    }
85}
86
87impl<T: std::fmt::Debug> std::fmt::Debug for AbortableList<T> {
88    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89        f.debug_list().entries(self.0.keys()).finish()
90    }
91}
92
93impl<T: Hash + Eq> AbortableList<T> {
94    /// Appends a new [`abortable task`](Abortable) to the end of this list.
95    pub fn insert<A: Abortable + Send + Sync + 'static>(&mut self, process: T, task: A) {
96        self.0.insert(process, Box::new(task));
97    }
98
99    /// Checks if the list contains a task with the given key.
100    pub fn contains(&self, process: &T) -> bool {
101        self.0.contains_key(process)
102    }
103
104    /// Looks up a task by its key, removes it and aborts it.
105    ///
106    /// Returns `true` if the task was aborted and removed.
107    /// Otherwise, returns `false` (including a situation when the task was present but already aborted).
108    pub fn abort_one(&mut self, process: &T) -> bool {
109        if let Some(item) = self.0.shift_remove(process).filter(|t| !t.was_aborted()) {
110            item.abort_task();
111            true
112        } else {
113            false
114        }
115    }
116
117    /// Extends this list by appending `other`.
118    ///
119    /// The tasks from `other` are moved to this list without aborting them.
120    /// Afterward, `other` will be empty.
121    pub fn extend_from(&mut self, mut other: AbortableList<T>) {
122        self.0.extend(other.0.drain(..));
123    }
124
125    /// Extends this list by appending `other` while mapping its keys to the ones in this list.
126    ///
127    /// The tasks from `other` are moved to this list without aborting them.
128    /// Afterward, `other` will be empty.
129    pub fn flat_map_extend_from<U>(&mut self, mut other: AbortableList<U>, key_map: impl Fn(U) -> T) {
130        self.0.extend(other.0.drain(..).map(|(k, v)| (key_map(k), v)));
131    }
132}
133impl<T> AbortableList<T> {
134    /// Checks if the list is empty.
135    pub fn is_empty(&self) -> bool {
136        self.0.is_empty()
137    }
138
139    /// Returns the number of abortable tasks in the list.
140    pub fn size(&self) -> usize {
141        self.0.len()
142    }
143
144    /// Returns an iterator over the task names in the insertion order.
145    pub fn iter_names(&self) -> impl Iterator<Item = &T> {
146        self.0.keys()
147    }
148
149    /// Aborts all tasks in this list in the reverse insertion order.
150    ///
151    /// Skips tasks which were [already aborted](Abortable::was_aborted).
152    pub fn abort_all(&self) {
153        for (_, task) in self.0.iter().rev().filter(|(_, task)| !task.was_aborted()) {
154            task.abort_task();
155        }
156    }
157}
158
159impl<T> Abortable for AbortableList<T> {
160    fn abort_task(&self) {
161        self.abort_all();
162    }
163
164    fn was_aborted(&self) -> bool {
165        self.0.iter().all(|(_, task)| task.was_aborted())
166    }
167}
168
169impl<T> Drop for AbortableList<T> {
170    fn drop(&mut self) {
171        self.abort_all();
172        self.0.clear();
173    }
174}
175
#[cfg(test)]
mod tests {
    use std::sync::{
        Arc,
        atomic::{AtomicBool, Ordering},
    };

    use super::*;

    /// Minimal [`Abortable`] that only remembers whether an abort was requested.
    #[derive(Default)]
    struct MockTask {
        abort_requested: AtomicBool,
    }

    impl Abortable for MockTask {
        fn abort_task(&self) {
            self.abort_requested.store(true, Ordering::SeqCst);
        }

        fn was_aborted(&self) -> bool {
            self.abort_requested.load(Ordering::SeqCst)
        }
    }

    /// Creates a shared mock task that has not been aborted yet.
    fn mock_task() -> Arc<MockTask> {
        Arc::new(MockTask::default())
    }

    /// Inserted keys are reported as contained and counted.
    #[test]
    fn test_insert_and_contains() {
        let mut tasks = AbortableList::default();
        tasks.insert("task1", mock_task());
        tasks.insert("task2", mock_task());

        for present in ["task1", "task2"] {
            assert!(tasks.contains(&present));
        }
        assert!(!tasks.contains(&"task3"));
        assert_eq!(tasks.size(), 2);
        assert!(!tasks.is_empty());
    }

    /// Aborting a single task aborts it and removes it from the list.
    #[test]
    fn test_abort_one() {
        let mut tasks = AbortableList::default();
        let task = mock_task();
        tasks.insert("task1", Arc::clone(&task));

        let aborted_now = tasks.abort_one(&"task1");
        assert!(aborted_now);
        assert!(task.was_aborted());
        assert!(!tasks.contains(&"task1"));
        assert_eq!(tasks.size(), 0);

        // The key is gone by now, so there is nothing left to abort.
        let aborted_again = tasks.abort_one(&"task1");
        assert!(!aborted_again);
    }

    /// A task aborted beforehand is reported as not aborted by `abort_one`, yet it is removed.
    #[test]
    fn test_abort_one_already_aborted() {
        let task = mock_task();
        task.abort_task();

        let mut tasks = AbortableList::default();
        tasks.insert("task1", Arc::clone(&task));

        // `false` is returned because the task had been aborted before the call.
        let aborted_now = tasks.abort_one(&"task1");
        assert!(!aborted_now);
        // The entry is removed from the list nevertheless.
        assert!(!tasks.contains(&"task1"));
    }

    /// The debug output lists the task keys.
    #[test]
    fn test_debug_impl() {
        let mut tasks = AbortableList::default();
        for name in ["task1", "task2"] {
            tasks.insert(name, MockTask::default());
        }

        let rendered = format!("{tasks:?}");
        for name in ["task1", "task2"] {
            assert!(rendered.contains(name));
        }
    }

    /// `abort_all` aborts every task but keeps the entries.
    #[test]
    fn test_abort_all() {
        let (first, second) = (mock_task(), mock_task());
        let mut tasks = AbortableList::default();
        tasks.insert(1, Arc::clone(&first));
        tasks.insert(2, Arc::clone(&second));

        tasks.abort_all();

        assert!(first.was_aborted());
        assert!(second.was_aborted());
        // The entries are still present after aborting.
        assert_eq!(tasks.size(), 2);
    }

    /// Dropping the list aborts all the tasks it holds.
    #[test]
    fn test_drop_aborts_all() {
        let (first, second) = (mock_task(), mock_task());

        let mut tasks = AbortableList::default();
        tasks.insert(1, Arc::clone(&first));
        tasks.insert(2, Arc::clone(&second));
        drop(tasks);

        assert!(first.was_aborted());
        assert!(second.was_aborted());
    }

    /// `extend_from` moves the tasks over without aborting them.
    #[test]
    fn test_extend_from() {
        let moved = mock_task();

        let mut target = AbortableList::default();
        let mut source = AbortableList::default();
        target.insert(1, mock_task());
        source.insert(2, Arc::clone(&moved));

        target.extend_from(source);

        assert_eq!(target.size(), 2);
        for key in [1, 2] {
            assert!(target.contains(&key));
        }

        // Moving the task between the lists must not abort it.
        assert!(!moved.was_aborted());
    }

    /// `flat_map_extend_from` stores the moved tasks under the mapped keys.
    #[test]
    fn test_flat_map_extend_from() {
        let mut target = AbortableList::default();
        let mut source = AbortableList::default();
        target.insert("a", mock_task());
        source.insert(1, mock_task());

        target.flat_map_extend_from(source, |key| match key {
            1 => "b",
            _ => "c",
        });

        assert_eq!(target.size(), 2);
        assert!(target.contains(&"a"));
        assert!(target.contains(&"b"));
    }

    /// Aborting an outer list reaches the tasks of a nested list.
    #[test]
    fn test_nested_abortable_list() {
        let task = mock_task();

        let mut inner = AbortableList::default();
        inner.insert(1, Arc::clone(&task));

        let mut outer = AbortableList::default();
        outer.insert("inner", inner);

        outer.abort_all();
        assert!(task.was_aborted());
    }

    /// The list counts as aborted only once every task in it is aborted.
    #[test]
    fn test_was_aborted_all() {
        let (first, second) = (mock_task(), mock_task());
        let mut tasks = AbortableList::default();
        tasks.insert(1, Arc::clone(&first));
        tasks.insert(2, Arc::clone(&second));

        assert!(!tasks.was_aborted());

        first.abort_task();
        assert!(!tasks.was_aborted());

        second.abort_task();
        assert!(tasks.was_aborted());
    }

    /// Keys are iterated in the insertion order.
    #[test]
    fn test_iter_names() {
        let mut tasks = AbortableList::default();
        for name in ["a", "b", "c"] {
            tasks.insert(name, MockTask::default());
        }

        let names: Vec<&str> = tasks.iter_names().copied().collect();
        assert_eq!(names, ["a", "b", "c"]);
    }

    /// `abort_all` proceeds from the most recently inserted task to the oldest one.
    #[test]
    fn test_reverse_insertion_order_on_abort() {
        use std::sync::Mutex;

        /// Mock that appends its id to a shared log when aborted.
        struct OrderedMockTask {
            id: i32,
            order: Arc<Mutex<Vec<i32>>>,
        }

        impl Abortable for OrderedMockTask {
            fn abort_task(&self) {
                self.order.lock().unwrap().push(self.id);
            }

            fn was_aborted(&self) -> bool {
                self.order.lock().unwrap().contains(&self.id)
            }
        }

        let abort_log = Arc::new(Mutex::new(Vec::new()));

        let mut tasks = AbortableList::default();
        for id in 1..=3 {
            tasks.insert(
                id,
                OrderedMockTask {
                    id,
                    order: Arc::clone(&abort_log),
                },
            );
        }

        tasks.abort_all();

        // The lock guard is a temporary here, so it is released before `tasks` gets dropped.
        assert_eq!(*abort_log.lock().unwrap(), vec![3, 2, 1]);
    }
}
416}