1use std::sync::atomic::{AtomicBool, AtomicU64};
2
3use hopr_api::{
4 chain::{ChannelId, HoprBalance},
5 types::internal::prelude::TicketBuilder,
6};
7
8use crate::{OutgoingIndexStore, TicketQueue};
9
/// In-memory record of the next outgoing ticket index for a single
/// `(channel, epoch)` pair.
///
/// All mutation goes through atomics, so entries can be shared via `Arc`
/// and updated concurrently without external locking.
#[derive(Debug)]
struct OutgoingIndexEntry {
    // Monotonically increasing ticket index; readers clamp it to
    // `TicketBuilder::MAX_TICKET_INDEX` (see `get`).
    index: AtomicU64,
    // True when `index` has changed since it was last persisted
    // (see `OutgoingIndexCache::save`).
    is_dirty: AtomicBool,
}
16
17impl Default for OutgoingIndexEntry {
18 fn default() -> Self {
19 Self::new(0)
20 }
21}
22
23impl OutgoingIndexEntry {
24 fn new(index: u64) -> Self {
26 OutgoingIndexEntry {
27 index: AtomicU64::new(index),
28 is_dirty: AtomicBool::new(true),
29 }
30 }
31
32 fn increment(&self) -> u64 {
36 let v = self.index.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
37 if v <= TicketBuilder::MAX_TICKET_INDEX {
38 self.is_dirty.store(true, std::sync::atomic::Ordering::Release);
39 }
40 v.min(TicketBuilder::MAX_TICKET_INDEX)
41 }
42
43 fn set(&self, new_value: u64) -> u64 {
49 let current = self.index.fetch_max(new_value, std::sync::atomic::Ordering::Relaxed);
50 if current < new_value {
51 self.is_dirty.store(true, std::sync::atomic::Ordering::Release);
52 }
53 new_value.max(current).min(TicketBuilder::MAX_TICKET_INDEX)
54 }
55
56 fn is_dirty(&self) -> bool {
58 self.is_dirty.load(std::sync::atomic::Ordering::Acquire)
59 }
60
61 fn mark_clean(&self) {
63 self.is_dirty.store(false, std::sync::atomic::Ordering::Release);
64 }
65
66 fn get(&self) -> u64 {
70 self.index
71 .load(std::sync::atomic::Ordering::Relaxed)
72 .min(TicketBuilder::MAX_TICKET_INDEX)
73 }
74}
75
/// Concurrent cache of outgoing ticket indices, keyed by channel id and
/// channel epoch.
///
/// Tracks which entries changed since the last persistence pass and which
/// keys were removed, so `save` only writes the necessary deltas to an
/// [`OutgoingIndexStore`].
#[derive(Debug, Default)]
pub struct OutgoingIndexCache {
    // Live per-(channel, epoch) counters; entries are Arc-shared so they
    // can keep being updated while the map is iterated during a save.
    cache: dashmap::DashMap<(ChannelId, u32), std::sync::Arc<OutgoingIndexEntry>>,
    // Keys removed from `cache` whose store-side record still awaits deletion.
    removed: dashmap::DashSet<(ChannelId, u32)>,
}
81
82impl OutgoingIndexCache {
83 pub fn next(&self, channel_id: &ChannelId, epoch: u32) -> u64 {
85 self.cache.entry((*channel_id, epoch)).or_default().increment()
86 }
87
88 pub fn upsert(&self, channel_id: &ChannelId, epoch: u32, index: u64) -> u64 {
96 self.cache
97 .entry((*channel_id, epoch))
98 .or_insert_with(|| std::sync::Arc::new(OutgoingIndexEntry::new(index)))
99 .set(index)
100 }
101
102 pub fn remove(&self, channel_id: &ChannelId, epoch: u32) -> bool {
106 if let Some(((id, ep), _)) = self.cache.remove(&(*channel_id, epoch)) {
107 self.removed.insert((id, ep));
108 true
109 } else {
110 false
111 }
112 }
113
114 pub fn save<S: OutgoingIndexStore + Send + Sync + 'static>(
118 &self,
119 store: std::sync::Arc<parking_lot::RwLock<S>>,
120 ) -> Result<(), anyhow::Error> {
121 let cache = self.cache.clone();
123 let removed = self.removed.clone();
124 let mut failed = 0;
125
126 for entry in cache.iter().filter(|e| e.value().is_dirty()) {
127 let (channel_id, epoch) = entry.key();
128 let index = entry.value().get();
129 if let Err(error) = store.write().save_outgoing_index(channel_id, *epoch, index) {
130 tracing::error!(%error, %channel_id, epoch, "failed to save outgoing index");
131 failed += 1;
132 } else {
133 tracing::trace!(%channel_id, epoch, index, "saved outgoing index");
134 entry.value().mark_clean();
135 }
136 }
137
138 for (channel_id, channel_epoch) in removed.iter().map(|e| (e.0, e.1)) {
139 if let Err(error) = store.write().delete_outgoing_index(&channel_id, channel_epoch) {
140 tracing::error!(%error, %channel_id, %channel_epoch, "failed to remove outgoing index");
141 failed += 1;
142 } else {
143 tracing::trace!(%channel_id, %channel_epoch, "removed outgoing index");
144 self.removed.remove(&(channel_id, channel_epoch));
145 }
146 }
147
148 if failed > 0 {
149 anyhow::bail!("failed to save {} outgoing index entries", failed);
150 }
151 Ok(())
152 }
153}
154
/// A per-channel ticket queue paired with its running statistics.
///
/// The queue and its stats live under a single lock so they are always
/// updated consistently with respect to each other.
#[derive(Debug)]
pub struct ChannelTicketQueue<Q> {
    // Ticket queue plus the stats tracked alongside it, guarded together.
    pub(crate) queue: std::sync::Arc<parking_lot::RwLock<(Q, ChannelTicketStats)>>,
    // Flag presumably set while a redeem operation is in flight for this
    // channel — NOTE(review): semantics inferred from the name; confirm at
    // the call sites that toggle it.
    pub(crate) redeem_lock: std::sync::Arc<AtomicBool>,
}
160
161impl<Q: TicketQueue> From<Q> for ChannelTicketQueue<Q> {
162 fn from(queue: Q) -> Self {
163 let stats = ChannelTicketStats {
164 winning_tickets: queue.len().unwrap_or(0) as u128,
165 ..Default::default()
166 };
167 Self {
168 queue: std::sync::Arc::new(parking_lot::RwLock::new((queue, stats))),
169 redeem_lock: std::sync::Arc::new(AtomicBool::new(false)),
170 }
171 }
172}
173
/// Aggregate ticket counters for a single channel.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub(crate) struct ChannelTicketStats {
    // Number of winning tickets; seeded from the queue length when a
    // `ChannelTicketQueue` is built from an existing queue.
    pub winning_tickets: u128,
    // Accumulated value of redeemed tickets.
    pub redeemed_value: HoprBalance,
    // Accumulated value of neglected tickets.
    pub neglected_value: HoprBalance,
    // Accumulated value of rejected tickets.
    pub rejected_value: HoprBalance,
}
181
#[cfg(test)]
mod tests {
    //! Tests covering entry semantics (saturation, max-update, dirtiness),
    //! cache scoping by (channel, epoch), and persistence round-trips via
    //! the in-memory store.

    use std::{sync::Arc, thread};

    use super::*;
    use crate::MemoryStore;

    // Shorthand for the ticket index saturation bound.
    const MAX: u64 = TicketBuilder::MAX_TICKET_INDEX;

    // Fresh in-memory store, wrapped the way `save` expects it.
    fn store() -> Arc<parking_lot::RwLock<MemoryStore>> {
        Arc::new(parking_lot::RwLock::new(MemoryStore::default()))
    }

    #[test]
    fn default_initializes_to_zero_and_dirty() {
        let e = OutgoingIndexEntry::default();
        assert_eq!(e.get(), 0);
        // New entries start dirty so the initial value gets persisted.
        assert!(e.is_dirty());
    }

    #[test]
    fn increment_saturates_return_value_at_max() {
        let e = OutgoingIndexEntry::new(MAX);
        // At the bound, both the returned and observed values stay at MAX.
        assert_eq!(e.increment(), MAX);
        assert_eq!(e.get(), MAX);
    }

    #[test]
    fn set_does_not_decrease_value_when_new_value_is_lower() {
        let e = OutgoingIndexEntry::new(20);
        e.mark_clean();
        // Max semantics: a lower value is ignored...
        assert_eq!(e.set(10), 20);
        assert_eq!(e.get(), 20);
        // ...and an ignored update must not re-mark the entry dirty.
        assert!(!e.is_dirty());
    }

    #[test]
    fn concurrent_set_uses_max_semantics() {
        let e = Arc::new(OutgoingIndexEntry::new(0));
        e.mark_clean();

        let vals = [1u64, 7, 3, 42, 9];
        let mut handles = vec![];

        for v in vals {
            let e2 = Arc::clone(&e);
            handles.push(thread::spawn(move || {
                e2.set(v);
            }));
        }

        for h in handles {
            h.join().unwrap();
        }

        // Whatever the interleaving, the largest written value wins.
        assert_eq!(e.get(), 42.min(MAX));
        assert!(e.is_dirty());
    }

    #[test]
    fn next_creates_entry_with_zero_and_increments_sequentially() {
        let cache = OutgoingIndexCache::default();
        let channel_id = Default::default();
        let epoch = 1;

        // First use lazily creates a zero-based counter.
        assert_eq!(cache.next(&channel_id, epoch), 0);
        assert_eq!(cache.next(&channel_id, epoch), 1);
        assert_eq!(cache.next(&channel_id, epoch), 2);
    }

    #[test]
    fn next_is_scoped_by_channel_and_epoch() {
        let cache = OutgoingIndexCache::default();
        let channel_a = Default::default();
        let channel_b = ChannelId::create(&[b"other"]);

        // Every (channel, epoch) pair gets an independent counter.
        assert_eq!(cache.next(&channel_a, 1), 0);
        assert_eq!(cache.next(&channel_a, 2), 0);
        assert_eq!(cache.next(&channel_b, 1), 0);
        assert_eq!(cache.next(&channel_a, 1), 1);
    }

    #[test]
    fn set_inserts_when_key_is_missing() {
        let cache = OutgoingIndexCache::default();
        let channel_id = Default::default();

        assert_eq!(cache.upsert(&channel_id, 1, 17), 17);
        // Subsequent increments continue from the upserted value.
        assert_eq!(cache.next(&channel_id, 1), 17);
        assert_eq!(cache.next(&channel_id, 1), 18);
    }

    #[test]
    fn set_does_not_decrease_existing_value() {
        let cache = OutgoingIndexCache::default();
        let channel_id = Default::default();

        assert_eq!(cache.upsert(&channel_id, 1, 10), 10);
        // A lower upsert is ignored; the counter keeps its higher value.
        assert_eq!(cache.upsert(&channel_id, 1, 7), 10);
        assert_eq!(cache.next(&channel_id, 1), 10);
        assert_eq!(cache.next(&channel_id, 1), 11);
    }

    #[test]
    fn next_saturates_at_max() {
        let cache = OutgoingIndexCache::default();
        let channel_id = Default::default();

        assert_eq!(cache.upsert(&channel_id, 1, MAX), MAX);
        // Once at the bound, `next` keeps yielding MAX instead of advancing.
        assert_eq!(cache.next(&channel_id, 1), MAX);
        assert_eq!(cache.next(&channel_id, 1), MAX);
    }

    #[test]
    fn remove_existing_entry_returns_true_and_persists_deletion_on_save() -> anyhow::Result<()> {
        let cache = OutgoingIndexCache::default();
        let channel_id = Default::default();
        let epoch = 1;
        let store = store();

        assert_eq!(cache.upsert(&channel_id, epoch, 5), 5);
        cache.save(store.clone())?;

        assert_eq!(store.read().load_outgoing_index(&channel_id, epoch)?, Some(5));

        // Removal is staged in memory and applied to the store on save.
        assert!(cache.remove(&channel_id, epoch));
        assert!(cache.save(store.clone()).is_ok());

        assert_eq!(store.read().load_outgoing_index(&channel_id, epoch)?, None);
        // Removing again reports that nothing was present.
        assert!(!cache.remove(&channel_id, epoch));

        Ok(())
    }

    #[test]
    fn remove_missing_entry_returns_false() {
        let cache = OutgoingIndexCache::default();
        let channel_id = Default::default();

        assert!(!cache.remove(&channel_id, 1));
    }

    #[test]
    fn save_persists_only_dirty_entries() -> anyhow::Result<()> {
        let cache = OutgoingIndexCache::default();
        let channel_a = Default::default();
        let channel_b = ChannelId::create(&[b"other"]);
        let store = store();

        cache.upsert(&channel_a, 1, 10);
        cache.upsert(&channel_b, 2, 20);

        cache.save(store.clone())?;
        assert_eq!(store.read().load_outgoing_index(&channel_a, 1)?, Some(10));
        assert_eq!(store.read().load_outgoing_index(&channel_b, 2)?, Some(20));

        // A second save with no changes must leave the store untouched.
        cache.save(store.clone())?;
        assert_eq!(store.read().load_outgoing_index(&channel_a, 1)?, Some(10));
        assert_eq!(store.read().load_outgoing_index(&channel_b, 2)?, Some(20));

        // Touching one key re-dirties only that entry.
        cache.next(&channel_a, 1);
        cache.save(store.clone())?;
        assert_eq!(store.read().load_outgoing_index(&channel_a, 1)?, Some(11));
        assert_eq!(store.read().load_outgoing_index(&channel_b, 2)?, Some(20));

        Ok(())
    }

    #[test]
    fn save_persists_removed_entries_only_once() -> anyhow::Result<()> {
        let cache = OutgoingIndexCache::default();
        let channel_id = Default::default();
        let store = store();

        cache.upsert(&channel_id, 1, 3);
        cache.save(store.clone())?;
        assert_eq!(store.read().load_outgoing_index(&channel_id, 1)?, Some(3));

        assert!(cache.remove(&channel_id, 1));
        cache.save(store.clone())?;
        assert_eq!(store.read().load_outgoing_index(&channel_id, 1)?, None);

        // The removal marker is consumed by the first successful save.
        cache.save(store.clone())?;
        assert_eq!(store.read().load_outgoing_index(&channel_id, 1)?, None);

        Ok(())
    }

    #[test]
    fn save_is_idempotent_after_success() -> anyhow::Result<()> {
        let cache = OutgoingIndexCache::default();
        let channel_id = Default::default();
        let store = store();

        cache.upsert(&channel_id, 1, 9);
        cache.save(store.clone())?;
        assert_eq!(store.read().load_outgoing_index(&channel_id, 1)?, Some(9));

        // Saving again without changes must not alter the stored value.
        cache.save(store.clone())?;
        assert_eq!(store.read().load_outgoing_index(&channel_id, 1)?, Some(9));

        cache.next(&channel_id, 1);
        cache.save(store.clone())?;
        assert_eq!(store.read().load_outgoing_index(&channel_id, 1)?, Some(10));

        Ok(())
    }

    #[test]
    fn concurrent_next_on_same_key_is_monotonic() -> anyhow::Result<()> {
        use std::thread;

        let cache = Arc::new(OutgoingIndexCache::default());
        let channel_id = Default::default();
        let epoch = 1;

        let mut handles = vec![];
        for _ in 0..8 {
            let cache = Arc::clone(&cache);
            let channel_id = channel_id;
            handles.push(thread::spawn(move || cache.next(&channel_id, epoch)));
        }

        let mut values = handles
            .into_iter()
            .map(|h| h.join())
            .collect::<Result<Vec<_>, _>>()
            .map_err(|_| anyhow::anyhow!("join error"))?;
        values.sort_unstable();

        // Eight concurrent claims must yield exactly 0..8, each one once.
        assert_eq!(values, (0..8).collect::<Vec<_>>());
        assert_eq!(cache.next(&channel_id, epoch), 8);

        Ok(())
    }

    #[test]
    fn save_handles_multiple_keys_and_removals_together() -> anyhow::Result<()> {
        let cache = OutgoingIndexCache::default();
        let channel_a = Default::default();
        let channel_b = ChannelId::create(&[b"other"]);
        let store = store();

        cache.upsert(&channel_a, 1, 4);
        cache.upsert(&channel_b, 2, 7);
        cache.save(store.clone())?;

        // One key removed, the other advanced — a single save applies both.
        assert!(cache.remove(&channel_a, 1));
        cache.next(&channel_b, 2);
        cache.save(store.clone())?;

        assert_eq!(store.read().load_outgoing_index(&channel_a, 1)?, None);
        assert_eq!(store.read().load_outgoing_index(&channel_b, 2)?, Some(8));

        Ok(())
    }
}