1use std::hash::Hash;
5
6pub use futures::future::AbortHandle;
7
8pub mod prelude {
11 #[cfg(feature = "async-lock")]
12 pub use async_lock::{Mutex, MutexGuard, RwLock, RwLockReadGuard, RwLockWriteGuard};
13 pub use futures::future::{AbortHandle, abortable};
14 #[cfg(all(feature = "runtime-tokio", not(feature = "async-lock")))]
15 pub use tokio::sync::{Mutex, MutexGuard, RwLock, RwLockReadGuard, RwLockWriteGuard};
16 #[cfg(feature = "runtime-tokio")]
17 pub use tokio::{
18 task::{JoinError, JoinHandle, spawn, spawn_blocking, spawn_local},
19 time::{sleep, timeout as timeout_fut},
20 };
21}
22
/// Wraps a future with `prelude::abortable`, spawns the wrapped future with
/// `prelude::spawn`, and returns the abort handle.
///
/// Exactly one future expression is accepted (a trailing comma is allowed),
/// because `abortable` takes a single future. The `JoinHandle` returned by
/// `spawn` is discarded. The returned handle is how the caller stops the task
/// early.
///
/// Requires the `runtime-tokio` feature, since `prelude::spawn` is only
/// re-exported under it.
#[macro_export]
macro_rules! spawn_as_abortable {
    ($future:expr $(,)?) => {{
        let (task, abort_handle) = $crate::prelude::abortable($future);
        let _join_handle = $crate::prelude::spawn(task);
        abort_handle
    }};
}
31
32#[auto_impl::auto_impl(&, Box, Arc)]
34pub trait Abortable {
35 fn abort_task(&self);
40
41 fn was_aborted(&self) -> bool;
46}
47
48impl Abortable for AbortHandle {
49 fn abort_task(&self) {
50 self.abort();
51 }
52
53 fn was_aborted(&self) -> bool {
54 self.is_aborted()
55 }
56}
57
/// A tokio task handle is abortable through `JoinHandle::abort`.
///
/// Implemented for every output type `T` (not only `()`), so tasks that
/// return a value can be tracked in an `AbortableList` as well.
///
/// `was_aborted` reports `JoinHandle::is_finished`, i.e. `true` once the task
/// has stopped for any reason, at which point aborting it has no effect.
/// NOTE(review): per tokio's documentation, `is_finished` may still be
/// `false` right after `abort` until the runtime processes the cancellation.
#[cfg(feature = "runtime-tokio")]
impl<T> Abortable for tokio::task::JoinHandle<T> {
    fn abort_task(&self) {
        self.abort();
    }

    fn was_aborted(&self) -> bool {
        self.is_finished()
    }
}
68
69pub struct AbortableList<T>(indexmap::IndexMap<T, Box<dyn Abortable + Send + Sync>>);
80
81impl<T> Default for AbortableList<T> {
82 fn default() -> Self {
83 Self(indexmap::IndexMap::new())
84 }
85}
86
87impl<T: std::fmt::Debug> std::fmt::Debug for AbortableList<T> {
88 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89 f.debug_list().entries(self.0.keys()).finish()
90 }
91}
92
93impl<T: Hash + Eq> AbortableList<T> {
94 pub fn insert<A: Abortable + Send + Sync + 'static>(&mut self, process: T, task: A) {
96 self.0.insert(process, Box::new(task));
97 }
98
99 pub fn contains(&self, process: &T) -> bool {
101 self.0.contains_key(process)
102 }
103
104 pub fn abort_one(&mut self, process: &T) -> bool {
109 if let Some(item) = self.0.shift_remove(process).filter(|t| !t.was_aborted()) {
110 item.abort_task();
111 true
112 } else {
113 false
114 }
115 }
116
117 pub fn extend_from(&mut self, mut other: AbortableList<T>) {
122 self.0.extend(other.0.drain(..));
123 }
124
125 pub fn flat_map_extend_from<U>(&mut self, mut other: AbortableList<U>, key_map: impl Fn(U) -> T) {
130 self.0.extend(other.0.drain(..).map(|(k, v)| (key_map(k), v)));
131 }
132}
133impl<T> AbortableList<T> {
134 pub fn is_empty(&self) -> bool {
136 self.0.is_empty()
137 }
138
139 pub fn size(&self) -> usize {
141 self.0.len()
142 }
143
144 pub fn iter_names(&self) -> impl Iterator<Item = &T> {
146 self.0.keys()
147 }
148
149 pub fn abort_all(&self) {
153 for (_, task) in self.0.iter().rev().filter(|(_, task)| !task.was_aborted()) {
154 task.abort_task();
155 }
156 }
157}
158
159impl<T> Abortable for AbortableList<T> {
160 fn abort_task(&self) {
161 self.abort_all();
162 }
163
164 fn was_aborted(&self) -> bool {
165 self.0.iter().all(|(_, task)| task.was_aborted())
166 }
167}
168
169impl<T> Drop for AbortableList<T> {
170 fn drop(&mut self) {
171 self.abort_all();
172 self.0.clear();
173 }
174}
175
#[cfg(test)]
mod tests {
    use std::sync::{
        Arc,
        atomic::{AtomicBool, Ordering},
    };

    use super::*;

    /// Test double that only records whether it has been aborted.
    #[derive(Default)]
    struct MockTask {
        aborted: AtomicBool,
    }

    impl Abortable for MockTask {
        fn abort_task(&self) {
            self.aborted.store(true, Ordering::SeqCst);
        }

        fn was_aborted(&self) -> bool {
            self.aborted.load(Ordering::SeqCst)
        }
    }

    /// A fresh, not-yet-aborted mock behind an `Arc`, so the test can keep a
    /// handle to it after inserting a clone into a list.
    fn shared_mock() -> Arc<MockTask> {
        Arc::new(MockTask::default())
    }

    #[test]
    fn test_insert_and_contains() {
        let (first, second) = (shared_mock(), shared_mock());
        let mut list = AbortableList::default();
        list.insert("task1", Arc::clone(&first));
        list.insert("task2", Arc::clone(&second));

        for present in ["task1", "task2"] {
            assert!(list.contains(&present));
        }
        assert!(!list.contains(&"task3"));
        assert_eq!(list.size(), 2);
        assert!(!list.is_empty());
    }

    #[test]
    fn test_abort_one() {
        let task = shared_mock();
        let mut list = AbortableList::default();
        list.insert("task1", Arc::clone(&task));

        assert!(list.abort_one(&"task1"));
        assert!(task.was_aborted());
        assert!(!list.contains(&"task1"));
        assert_eq!(list.size(), 0);

        // A second attempt finds nothing to abort.
        assert!(!list.abort_one(&"task1"));
    }

    #[test]
    fn test_abort_one_already_aborted() {
        let task = shared_mock();
        task.abort_task();

        let mut list = AbortableList::default();
        list.insert("task1", Arc::clone(&task));

        // Nothing left to abort, but the entry is still removed.
        assert!(!list.abort_one(&"task1"));
        assert!(!list.contains(&"task1"));
    }

    #[test]
    fn test_debug_impl() {
        let mut list = AbortableList::default();
        for name in ["task1", "task2"] {
            list.insert(name, MockTask::default());
        }

        let rendered = format!("{list:?}");
        assert!(rendered.contains("task1"));
        assert!(rendered.contains("task2"));
    }

    #[test]
    fn test_abort_all() {
        let tasks = [shared_mock(), shared_mock()];
        let mut list = AbortableList::default();
        for (key, task) in (1..).zip(&tasks) {
            list.insert(key, Arc::clone(task));
        }

        list.abort_all();

        assert!(tasks.iter().all(|task| task.was_aborted()));
        // `abort_all` leaves the entries in place.
        assert_eq!(list.size(), 2);
    }

    #[test]
    fn test_drop_aborts_all() {
        let tasks = [shared_mock(), shared_mock()];

        let mut list = AbortableList::default();
        list.insert(1, Arc::clone(&tasks[0]));
        list.insert(2, Arc::clone(&tasks[1]));
        drop(list);

        assert!(tasks[0].was_aborted());
        assert!(tasks[1].was_aborted());
    }

    #[test]
    fn test_extend_from() {
        let (kept, moved) = (shared_mock(), shared_mock());
        let mut target = AbortableList::default();
        let mut source = AbortableList::default();
        target.insert(1, Arc::clone(&kept));
        source.insert(2, Arc::clone(&moved));

        target.extend_from(source);

        assert_eq!(target.size(), 2);
        assert!(target.contains(&1) && target.contains(&2));
        // Moving a task between lists must not abort it.
        assert!(!moved.was_aborted());
    }

    #[test]
    fn test_flat_map_extend_from() {
        let mut target = AbortableList::default();
        let mut source = AbortableList::default();
        target.insert("a", shared_mock());
        source.insert(1, shared_mock());

        target.flat_map_extend_from(source, |key| match key {
            1 => "b",
            _ => "c",
        });

        assert_eq!(target.size(), 2);
        assert!(target.contains(&"a"));
        assert!(target.contains(&"b"));
    }

    #[test]
    fn test_nested_abortable_list() {
        let leaf = shared_mock();
        let mut inner = AbortableList::default();
        inner.insert(1, Arc::clone(&leaf));

        let mut outer = AbortableList::default();
        outer.insert("inner", inner);

        outer.abort_all();
        assert!(leaf.was_aborted());
    }

    #[test]
    fn test_was_aborted_all() {
        let (first, second) = (shared_mock(), shared_mock());
        let mut list = AbortableList::default();
        list.insert(1, Arc::clone(&first));
        list.insert(2, Arc::clone(&second));

        assert!(!list.was_aborted());

        first.abort_task();
        assert!(!list.was_aborted(), "one live task keeps the list live");

        second.abort_task();
        assert!(list.was_aborted());
    }

    #[test]
    fn test_iter_names() {
        let mut list = AbortableList::default();
        for name in ["a", "b", "c"] {
            list.insert(name, MockTask::default());
        }

        let names: Vec<&str> = list.iter_names().copied().collect();
        assert_eq!(names, ["a", "b", "c"]);
    }

    #[test]
    fn test_reverse_insertion_order_on_abort() {
        use std::sync::Mutex;

        /// Appends its id to a shared log when aborted.
        struct OrderedMockTask {
            id: i32,
            order: Arc<Mutex<Vec<i32>>>,
        }

        impl Abortable for OrderedMockTask {
            fn abort_task(&self) {
                self.order.lock().unwrap().push(self.id);
            }

            fn was_aborted(&self) -> bool {
                self.order.lock().unwrap().contains(&self.id)
            }
        }

        let abort_order = Arc::new(Mutex::new(Vec::new()));
        let mut list = AbortableList::default();
        for id in 1..=3 {
            list.insert(
                id,
                OrderedMockTask {
                    id,
                    order: Arc::clone(&abort_order),
                },
            );
        }

        list.abort_all();

        assert_eq!(*abort_order.lock().unwrap(), vec![3, 2, 1]);
    }
}