Skip to main content

hopr_ticket_manager/
utils.rs

1use std::sync::atomic::{AtomicBool, AtomicU64};
2
3use hopr_api::{
4    chain::{ChannelId, HoprBalance},
5    types::internal::prelude::TicketBuilder,
6};
7
8use crate::{OutgoingIndexStore, TicketQueue};
9
/// Counter of outgoing ticket indices for a single channel, paired with a flag
/// recording whether the counter changed since it was last marked clean.
///
/// The [`Default`] value starts counting at 0.
#[derive(Debug)]
struct OutgoingIndexEntry {
    /// Raw counter value; the accessors clamp it to `TicketBuilder::MAX_TICKET_INDEX` on read.
    index: AtomicU64,
    /// `true` while the counter holds a value that has not been marked clean yet.
    is_dirty: AtomicBool,
}
16
17impl Default for OutgoingIndexEntry {
18    fn default() -> Self {
19        Self::new(0)
20    }
21}
22
23impl OutgoingIndexEntry {
24    /// Creates a new index entry and marks it as dirty.
25    fn new(index: u64) -> Self {
26        OutgoingIndexEntry {
27            index: AtomicU64::new(index),
28            is_dirty: AtomicBool::new(true),
29        }
30    }
31
32    /// Increments the index and marks it as dirty if within bounds.
33    ///
34    /// The value returned is the value before the increment, saturating at [`TicketBuilder::MAX_TICKET_INDEX`].
35    fn increment(&self) -> u64 {
36        let v = self.index.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
37        if v <= TicketBuilder::MAX_TICKET_INDEX {
38            self.is_dirty.store(true, std::sync::atomic::Ordering::Release);
39        }
40        v.min(TicketBuilder::MAX_TICKET_INDEX)
41    }
42
43    /// Sets the index to the maximum of `new_value` and the current index value.
44    ///
45    /// Marks the index as dirty if the value was increased.
46    ///
47    /// Returns the new value of the index, saturating at [`TicketBuilder::MAX_TICKET_INDEX`].
48    fn set(&self, new_value: u64) -> u64 {
49        let current = self.index.fetch_max(new_value, std::sync::atomic::Ordering::Relaxed);
50        if current < new_value {
51            self.is_dirty.store(true, std::sync::atomic::Ordering::Release);
52        }
53        new_value.max(current).min(TicketBuilder::MAX_TICKET_INDEX)
54    }
55
56    /// Checks if the index is marked as dirty.
57    fn is_dirty(&self) -> bool {
58        self.is_dirty.load(std::sync::atomic::Ordering::Acquire)
59    }
60
61    /// Marks the index as clean.
62    fn mark_clean(&self) {
63        self.is_dirty.store(false, std::sync::atomic::Ordering::Release);
64    }
65
66    /// Gets the index.
67    ///
68    /// The returned value will always be less than [`TicketBuilder::MAX_TICKET_INDEX`].
69    fn get(&self) -> u64 {
70        self.index
71            .load(std::sync::atomic::Ordering::Relaxed)
72            .min(TicketBuilder::MAX_TICKET_INDEX)
73    }
74}
75
76#[derive(Debug, Default)]
77pub struct OutgoingIndexCache {
78    cache: dashmap::DashMap<(ChannelId, u32), std::sync::Arc<OutgoingIndexEntry>>,
79    removed: dashmap::DashSet<(ChannelId, u32)>,
80}
81
82impl OutgoingIndexCache {
83    /// Returns the next outgoing index for the given channel and epoch.
84    pub fn next(&self, channel_id: &ChannelId, epoch: u32) -> u64 {
85        self.cache.entry((*channel_id, epoch)).or_default().increment()
86    }
87
88    /// Inserts the index for the given channel and `epoch`, or updates
89    /// the existing value if it is less than the provided `index`.
90    ///
91    /// Returns the index value that is either:
92    ///  - equal to `index` if no index for the given channel and epoch existed and the value was inserted, or
93    ///  - equal to the existing index value, if the provided `index` is less than the existing index value, or
94    ///  - equal to the provided `index` value if it is greater than the existing index value and the value is updated.
95    pub fn upsert(&self, channel_id: &ChannelId, epoch: u32, index: u64) -> u64 {
96        self.cache
97            .entry((*channel_id, epoch))
98            .or_insert_with(|| std::sync::Arc::new(OutgoingIndexEntry::new(index)))
99            .set(index)
100    }
101
102    /// Removes the index for the given channel and `epoch` if it exists.
103    ///
104    /// Returns `true` if the index was removed, `false` otherwise.
105    pub fn remove(&self, channel_id: &ChannelId, epoch: u32) -> bool {
106        if let Some(((id, ep), _)) = self.cache.remove(&(*channel_id, epoch)) {
107            self.removed.insert((id, ep));
108            true
109        } else {
110            false
111        }
112    }
113
114    /// Synchronizes the current state with the provided store.
115    ///
116    /// Saves only those values that were changed since the last save operation.
117    pub fn save<S: OutgoingIndexStore + Send + Sync + 'static>(
118        &self,
119        store: std::sync::Arc<parking_lot::RwLock<S>>,
120    ) -> Result<(), anyhow::Error> {
121        // Clone entries so that we do not hold any locks
122        let cache = self.cache.clone();
123        let removed = self.removed.clone();
124        let mut failed = 0;
125
126        for entry in cache.iter().filter(|e| e.value().is_dirty()) {
127            let (channel_id, epoch) = entry.key();
128            let index = entry.value().get();
129            if let Err(error) = store.write().save_outgoing_index(channel_id, *epoch, index) {
130                tracing::error!(%error, %channel_id, epoch, "failed to save outgoing index");
131                failed += 1;
132            } else {
133                tracing::trace!(%channel_id, epoch, index, "saved outgoing index");
134                entry.value().mark_clean();
135            }
136        }
137
138        for (channel_id, channel_epoch) in removed.iter().map(|e| (e.0, e.1)) {
139            if let Err(error) = store.write().delete_outgoing_index(&channel_id, channel_epoch) {
140                tracing::error!(%error, %channel_id, %channel_epoch, "failed to remove outgoing index");
141                failed += 1;
142            } else {
143                tracing::trace!(%channel_id, %channel_epoch, "removed outgoing index");
144                self.removed.remove(&(channel_id, channel_epoch));
145            }
146        }
147
148        if failed > 0 {
149            anyhow::bail!("failed to save {} outgoing index entries", failed);
150        }
151        Ok(())
152    }
153}
154
155#[derive(Debug)]
156pub struct ChannelTicketQueue<Q> {
157    pub(crate) queue: std::sync::Arc<parking_lot::RwLock<(Q, ChannelTicketStats)>>,
158    pub(crate) redeem_lock: std::sync::Arc<AtomicBool>,
159}
160
161impl<Q: TicketQueue> From<Q> for ChannelTicketQueue<Q> {
162    fn from(queue: Q) -> Self {
163        let stats = ChannelTicketStats {
164            winning_tickets: queue.len().unwrap_or(0) as u128,
165            ..Default::default()
166        };
167        Self {
168            queue: std::sync::Arc::new(parking_lot::RwLock::new((queue, stats))),
169            redeem_lock: std::sync::Arc::new(AtomicBool::new(false)),
170        }
171    }
172}
173
174#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
175pub(crate) struct ChannelTicketStats {
176    pub winning_tickets: u128,
177    pub redeemed_value: HoprBalance,
178    pub neglected_value: HoprBalance,
179    pub rejected_value: HoprBalance,
180}
181
#[cfg(test)]
mod tests {
    use std::{sync::Arc, thread};

    use super::*;
    use crate::MemoryStore;

    /// Shorthand for the saturation bound of ticket indices.
    const MAX: u64 = TicketBuilder::MAX_TICKET_INDEX;

    /// Creates an empty in-memory store wrapped the way `OutgoingIndexCache::save` expects it.
    fn store() -> Arc<parking_lot::RwLock<MemoryStore>> {
        let backing = MemoryStore::default();
        Arc::new(parking_lot::RwLock::new(backing))
    }

    /// A channel id different from the default one.
    fn other_channel() -> ChannelId {
        ChannelId::create(&[b"other"])
    }

    // --- OutgoingIndexEntry ---

    #[test]
    fn default_initializes_to_zero_and_dirty() {
        let entry = OutgoingIndexEntry::default();

        assert_eq!(entry.get(), 0);
        assert!(entry.is_dirty());
    }

    #[test]
    fn increment_saturates_return_value_at_max() {
        let entry = OutgoingIndexEntry::new(MAX);

        assert_eq!(entry.increment(), MAX);
        assert_eq!(entry.get(), MAX);
    }

    #[test]
    fn set_does_not_decrease_value_when_new_value_is_lower() {
        let entry = OutgoingIndexEntry::new(20);
        entry.mark_clean();

        assert_eq!(entry.set(10), 20);
        assert_eq!(entry.get(), 20);
        assert!(!entry.is_dirty());
    }

    #[test]
    fn concurrent_set_uses_max_semantics() {
        let entry = OutgoingIndexEntry::new(0);
        entry.mark_clean();

        // Scoped threads are joined when the scope ends; a panicking worker fails the test.
        thread::scope(|scope| {
            for candidate in [1u64, 7, 3, 42, 9] {
                let entry = &entry;
                scope.spawn(move || {
                    entry.set(candidate);
                });
            }
        });

        assert_eq!(entry.get(), 42.min(MAX));
        assert!(entry.is_dirty());
    }

    // --- OutgoingIndexCache: next / upsert / remove ---

    #[test]
    fn next_creates_entry_with_zero_and_increments_sequentially() {
        let cache = OutgoingIndexCache::default();
        let channel_id = ChannelId::default();
        let epoch = 1;

        for expected in 0..3u64 {
            assert_eq!(cache.next(&channel_id, epoch), expected);
        }
    }

    #[test]
    fn next_is_scoped_by_channel_and_epoch() {
        let cache = OutgoingIndexCache::default();
        let channel_a = ChannelId::default();
        let channel_b = other_channel();

        assert_eq!(cache.next(&channel_a, 1), 0);
        assert_eq!(cache.next(&channel_a, 2), 0);
        assert_eq!(cache.next(&channel_b, 1), 0);
        assert_eq!(cache.next(&channel_a, 1), 1);
    }

    #[test]
    fn set_inserts_when_key_is_missing() {
        let cache = OutgoingIndexCache::default();
        let channel_id = ChannelId::default();

        assert_eq!(cache.upsert(&channel_id, 1, 17), 17);
        assert_eq!(cache.next(&channel_id, 1), 17);
        assert_eq!(cache.next(&channel_id, 1), 18);
    }

    #[test]
    fn set_does_not_decrease_existing_value() {
        let cache = OutgoingIndexCache::default();
        let channel_id = ChannelId::default();

        assert_eq!(cache.upsert(&channel_id, 1, 10), 10);
        assert_eq!(cache.upsert(&channel_id, 1, 7), 10);
        assert_eq!(cache.next(&channel_id, 1), 10);
        assert_eq!(cache.next(&channel_id, 1), 11);
    }

    #[test]
    fn next_saturates_at_max() {
        let cache = OutgoingIndexCache::default();
        let channel_id = ChannelId::default();

        assert_eq!(cache.upsert(&channel_id, 1, MAX), MAX);
        assert_eq!(cache.next(&channel_id, 1), MAX);
        assert_eq!(cache.next(&channel_id, 1), MAX);
    }

    #[test]
    fn remove_existing_entry_returns_true_and_persists_deletion_on_save() -> anyhow::Result<()> {
        let cache = OutgoingIndexCache::default();
        let channel_id = ChannelId::default();
        let epoch = 1;
        let backend = store();

        assert_eq!(cache.upsert(&channel_id, epoch, 5), 5);
        cache.save(backend.clone())?;

        assert_eq!(backend.read().load_outgoing_index(&channel_id, epoch)?, Some(5));

        assert!(cache.remove(&channel_id, epoch));
        assert!(cache.save(backend.clone()).is_ok());

        assert_eq!(backend.read().load_outgoing_index(&channel_id, epoch)?, None);
        assert!(!cache.remove(&channel_id, epoch));

        Ok(())
    }

    #[test]
    fn remove_missing_entry_returns_false() {
        let cache = OutgoingIndexCache::default();
        let channel_id = ChannelId::default();

        assert!(!cache.remove(&channel_id, 1));
    }

    // --- OutgoingIndexCache: save ---

    #[test]
    fn save_persists_only_dirty_entries() -> anyhow::Result<()> {
        let cache = OutgoingIndexCache::default();
        let channel_a = ChannelId::default();
        let channel_b = other_channel();
        let backend = store();

        cache.upsert(&channel_a, 1, 10);
        cache.upsert(&channel_b, 2, 20);

        // First save writes both entries.
        cache.save(backend.clone())?;
        assert_eq!(backend.read().load_outgoing_index(&channel_a, 1)?, Some(10));
        assert_eq!(backend.read().load_outgoing_index(&channel_b, 2)?, Some(20));

        // Saving again without changes leaves the stored values as they are.
        cache.save(backend.clone())?;
        assert_eq!(backend.read().load_outgoing_index(&channel_a, 1)?, Some(10));
        assert_eq!(backend.read().load_outgoing_index(&channel_b, 2)?, Some(20));

        // Only the advanced entry changes in the store.
        cache.next(&channel_a, 1);
        cache.save(backend.clone())?;
        assert_eq!(backend.read().load_outgoing_index(&channel_a, 1)?, Some(11));
        assert_eq!(backend.read().load_outgoing_index(&channel_b, 2)?, Some(20));

        Ok(())
    }

    #[test]
    fn save_persists_removed_entries_only_once() -> anyhow::Result<()> {
        let cache = OutgoingIndexCache::default();
        let channel_id = ChannelId::default();
        let backend = store();

        cache.upsert(&channel_id, 1, 3);
        cache.save(backend.clone())?;
        assert_eq!(backend.read().load_outgoing_index(&channel_id, 1)?, Some(3));

        assert!(cache.remove(&channel_id, 1));
        cache.save(backend.clone())?;
        assert_eq!(backend.read().load_outgoing_index(&channel_id, 1)?, None);

        cache.save(backend.clone())?;
        assert_eq!(backend.read().load_outgoing_index(&channel_id, 1)?, None);

        Ok(())
    }

    #[test]
    fn save_is_idempotent_after_success() -> anyhow::Result<()> {
        let cache = OutgoingIndexCache::default();
        let channel_id = ChannelId::default();
        let backend = store();

        cache.upsert(&channel_id, 1, 9);
        cache.save(backend.clone())?;
        assert_eq!(backend.read().load_outgoing_index(&channel_id, 1)?, Some(9));

        cache.save(backend.clone())?;
        assert_eq!(backend.read().load_outgoing_index(&channel_id, 1)?, Some(9));

        cache.next(&channel_id, 1);
        cache.save(backend.clone())?;
        assert_eq!(backend.read().load_outgoing_index(&channel_id, 1)?, Some(10));

        Ok(())
    }

    #[test]
    fn concurrent_next_on_same_key_is_monotonic() -> anyhow::Result<()> {
        let cache = OutgoingIndexCache::default();
        let channel_id = ChannelId::default();
        let epoch = 1;

        let mut values = thread::scope(|scope| {
            // Spawn all workers before joining any of them so that they really race.
            let mut workers = Vec::with_capacity(8);
            for _ in 0..8 {
                workers.push(scope.spawn(|| cache.next(&channel_id, epoch)));
            }
            workers
                .into_iter()
                .map(|worker| worker.join())
                .collect::<Result<Vec<_>, _>>()
        })
        .map_err(|_| anyhow::anyhow!("join error"))?;
        values.sort_unstable();

        assert_eq!(values, (0..8).collect::<Vec<_>>());
        assert_eq!(cache.next(&channel_id, epoch), 8);

        Ok(())
    }

    #[test]
    fn save_handles_multiple_keys_and_removals_together() -> anyhow::Result<()> {
        let cache = OutgoingIndexCache::default();
        let channel_a = ChannelId::default();
        let channel_b = other_channel();
        let backend = store();

        cache.upsert(&channel_a, 1, 4);
        cache.upsert(&channel_b, 2, 7);
        cache.save(backend.clone())?;

        assert!(cache.remove(&channel_a, 1));
        cache.next(&channel_b, 2);
        cache.save(backend.clone())?;

        assert_eq!(backend.read().load_outgoing_index(&channel_a, 1)?, None);
        assert_eq!(backend.read().load_outgoing_index(&channel_b, 2)?, Some(8));

        Ok(())
    }
}