1use std::sync::atomic::{AtomicBool, AtomicU64};
2
3use hopr_api::{
4 chain::{ChannelId, HoprBalance},
5 types::internal::prelude::TicketBuilder,
6};
7
8use crate::{OutgoingIndexStore, TicketManagerError, TicketQueue, backend::ValueCachedQueue};
9
/// Per-key state of the outgoing ticket index cache.
///
/// Both fields are atomics so that an entry can be shared between threads
/// (it is stored behind an `Arc`) and updated through a shared reference.
#[derive(Debug)]
struct OutgoingIndexEntry {
    /// Raw index counter; readers clamp it to the maximum ticket index.
    index: AtomicU64,
    /// `true` while the counter holds a value that has not been persisted yet.
    is_dirty: AtomicBool,
}
16
17impl Default for OutgoingIndexEntry {
18 fn default() -> Self {
19 Self::new(0)
20 }
21}
22
23impl OutgoingIndexEntry {
24 fn new(index: u64) -> Self {
26 OutgoingIndexEntry {
27 index: AtomicU64::new(index),
28 is_dirty: AtomicBool::new(true),
29 }
30 }
31
32 fn increment(&self) -> u64 {
36 let v = self.index.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
37 if v <= TicketBuilder::MAX_TICKET_INDEX {
38 self.is_dirty.store(true, std::sync::atomic::Ordering::Release);
39 }
40 v.min(TicketBuilder::MAX_TICKET_INDEX)
41 }
42
43 fn set(&self, new_value: u64) -> u64 {
49 let current = self.index.fetch_max(new_value, std::sync::atomic::Ordering::Relaxed);
50 if current < new_value {
51 self.is_dirty.store(true, std::sync::atomic::Ordering::Release);
52 }
53 new_value.max(current).min(TicketBuilder::MAX_TICKET_INDEX)
54 }
55
56 fn is_dirty(&self) -> bool {
58 self.is_dirty.load(std::sync::atomic::Ordering::Acquire)
59 }
60
61 fn mark_clean(&self) {
63 self.is_dirty.store(false, std::sync::atomic::Ordering::Release);
64 }
65
66 fn get(&self) -> u64 {
70 self.index
71 .load(std::sync::atomic::Ordering::Relaxed)
72 .min(TicketBuilder::MAX_TICKET_INDEX)
73 }
74}
75
76#[derive(Debug, Default)]
77pub struct OutgoingIndexCache {
78 cache: dashmap::DashMap<(ChannelId, u32), std::sync::Arc<OutgoingIndexEntry>>,
79 removed: dashmap::DashSet<(ChannelId, u32)>,
80}
81
82impl OutgoingIndexCache {
83 pub fn next(&self, channel_id: &ChannelId, epoch: u32) -> u64 {
85 self.cache.entry((*channel_id, epoch)).or_default().increment()
86 }
87
88 pub fn upsert(&self, channel_id: &ChannelId, epoch: u32, index: u64) -> u64 {
96 self.cache
97 .entry((*channel_id, epoch))
98 .or_insert_with(|| std::sync::Arc::new(OutgoingIndexEntry::new(index)))
99 .set(index)
100 }
101
102 pub fn remove(&self, channel_id: &ChannelId, epoch: u32) -> bool {
106 if let Some(((id, ep), _)) = self.cache.remove(&(*channel_id, epoch)) {
107 self.removed.insert((id, ep));
108 true
109 } else {
110 false
111 }
112 }
113
114 pub fn save<S: OutgoingIndexStore + Send + Sync + 'static>(
118 &self,
119 store: std::sync::Arc<parking_lot::RwLock<S>>,
120 ) -> Result<(), anyhow::Error> {
121 let cache = self.cache.clone();
123 let removed = self.removed.clone();
124 let mut failed = 0;
125
126 for entry in cache.iter().filter(|e| e.value().is_dirty()) {
127 let (channel_id, epoch) = entry.key();
128 let index = entry.value().get();
129 if let Err(error) = store.write().save_outgoing_index(channel_id, *epoch, index) {
130 tracing::error!(%error, %channel_id, epoch, "failed to save outgoing index");
131 failed += 1;
132 } else {
133 tracing::trace!(%channel_id, epoch, index, "saved outgoing index");
134 entry.value().mark_clean();
135 }
136 }
137
138 for (channel_id, channel_epoch) in removed.iter().map(|e| (e.0, e.1)) {
139 if let Err(error) = store.write().delete_outgoing_index(&channel_id, channel_epoch) {
140 tracing::error!(%error, %channel_id, %channel_epoch, "failed to remove outgoing index");
141 failed += 1;
142 } else {
143 tracing::trace!(%channel_id, %channel_epoch, "removed outgoing index");
144 self.removed.remove(&(channel_id, channel_epoch));
145 }
146 }
147
148 if failed > 0 {
149 anyhow::bail!("failed to save {} outgoing index entries", failed);
150 }
151 Ok(())
152 }
153}
154
155#[derive(Debug)]
156pub struct CachedQueueMap<Q>(
157 pub(crate) dashmap::DashMap<ChannelId, ChannelTicketQueue<ValueCachedQueue<Q>>, ahash::RandomState>,
158);
159
160impl<Q> Default for CachedQueueMap<Q> {
161 fn default() -> Self {
162 Self(dashmap::DashMap::with_hasher(ahash::RandomState::default()))
163 }
164}
165
166pub trait UnrealizedValue {
167 fn unrealized_value(
170 &self,
171 _channel_id: &ChannelId,
172 _min_index: Option<u64>,
173 ) -> Result<Option<HoprBalance>, TicketManagerError> {
174 Ok(None)
175 }
176}
177impl UnrealizedValue for () {}
178
179impl<Q: TicketQueue> UnrealizedValue for CachedQueueMap<Q> {
180 fn unrealized_value(
182 &self,
183 channel_id: &ChannelId,
184 min_index: Option<u64>,
185 ) -> Result<Option<HoprBalance>, TicketManagerError> {
186 if let Some(ticket_queue) = self.0.get(channel_id) {
187 let queue = ticket_queue.queue.read();
190
191 if let Some(epoch) = queue
195 .0
196 .peek()
197 .map_err(TicketManagerError::store)?
198 .map(|t| t.verified_ticket().channel_epoch)
199 {
200 Ok(Some(
201 queue
202 .0
203 .total_value(epoch, min_index)
204 .map_err(TicketManagerError::store)?,
205 ))
206 } else {
207 Ok(Some(HoprBalance::zero()))
209 }
210 } else {
211 Ok(None)
212 }
213 }
214}
215
216#[derive(Debug)]
217pub struct ChannelTicketQueue<Q> {
218 pub(crate) queue: std::sync::Arc<parking_lot::RwLock<(Q, ChannelTicketStats)>>,
219 pub(crate) redeem_lock: std::sync::Arc<AtomicBool>,
220}
221
222impl<Q: TicketQueue> From<Q> for ChannelTicketQueue<Q> {
223 fn from(queue: Q) -> Self {
224 let stats = ChannelTicketStats {
225 winning_tickets: queue.len().unwrap_or(0) as u128,
226 ..Default::default()
227 };
228 Self {
229 queue: std::sync::Arc::new(parking_lot::RwLock::new((queue, stats))),
230 redeem_lock: std::sync::Arc::new(AtomicBool::new(false)),
231 }
232 }
233}
234
235#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
236pub(crate) struct ChannelTicketStats {
237 pub winning_tickets: u128,
238 pub redeemed_value: HoprBalance,
239 pub neglected_value: HoprBalance,
240 pub rejected_value: HoprBalance,
241}
242
#[cfg(test)]
mod tests {
    use std::{sync::Arc, thread};

    use super::*;
    use crate::MemoryStore;

    /// Largest index an entry may ever report.
    const MAX: u64 = TicketBuilder::MAX_TICKET_INDEX;

    /// Builds an empty in-memory store, wrapped the way `OutgoingIndexCache::save` expects.
    fn store() -> Arc<parking_lot::RwLock<MemoryStore>> {
        let backend = MemoryStore::default();
        Arc::new(parking_lot::RwLock::new(backend))
    }

    // ---- OutgoingIndexEntry ----

    #[test]
    fn default_initializes_to_zero_and_dirty() {
        let entry = OutgoingIndexEntry::default();

        assert_eq!(entry.get(), 0);
        assert!(entry.is_dirty());
    }

    #[test]
    fn increment_saturates_return_value_at_max() {
        let entry = OutgoingIndexEntry::new(MAX);

        assert_eq!(entry.increment(), MAX);
        assert_eq!(entry.get(), MAX);
    }

    #[test]
    fn set_does_not_decrease_value_when_new_value_is_lower() {
        let entry = OutgoingIndexEntry::new(20);
        entry.mark_clean();

        // A lower value leaves both the counter and the dirty flag untouched.
        assert_eq!(entry.set(10), 20);
        assert_eq!(entry.get(), 20);
        assert!(!entry.is_dirty());
    }

    #[test]
    fn concurrent_set_uses_max_semantics() {
        let entry = Arc::new(OutgoingIndexEntry::new(0));
        entry.mark_clean();

        let workers: Vec<_> = [1u64, 7, 3, 42, 9]
            .into_iter()
            .map(|value| {
                let shared = Arc::clone(&entry);
                thread::spawn(move || {
                    shared.set(value);
                })
            })
            .collect();

        for worker in workers {
            worker.join().unwrap();
        }

        // Regardless of interleaving, the largest value wins.
        assert_eq!(entry.get(), 42.min(MAX));
        assert!(entry.is_dirty());
    }

    // ---- OutgoingIndexCache: next / upsert / remove ----

    #[test]
    fn next_creates_entry_with_zero_and_increments_sequentially() {
        let cache = OutgoingIndexCache::default();
        let channel_id = ChannelId::default();
        let epoch = 1;

        for expected in 0..3 {
            assert_eq!(cache.next(&channel_id, epoch), expected);
        }
    }

    #[test]
    fn next_is_scoped_by_channel_and_epoch() {
        let cache = OutgoingIndexCache::default();
        let channel_a = ChannelId::default();
        let channel_b = ChannelId::create(&[b"other"]);

        // Three distinct keys each start from zero ...
        assert_eq!(cache.next(&channel_a, 1), 0);
        assert_eq!(cache.next(&channel_a, 2), 0);
        assert_eq!(cache.next(&channel_b, 1), 0);
        // ... and only the key that is reused advances.
        assert_eq!(cache.next(&channel_a, 1), 1);
    }

    #[test]
    fn set_inserts_when_key_is_missing() {
        let cache = OutgoingIndexCache::default();
        let channel_id = ChannelId::default();

        assert_eq!(cache.upsert(&channel_id, 1, 17), 17);

        assert_eq!(cache.next(&channel_id, 1), 17);
        assert_eq!(cache.next(&channel_id, 1), 18);
    }

    #[test]
    fn set_does_not_decrease_existing_value() {
        let cache = OutgoingIndexCache::default();
        let channel_id = ChannelId::default();

        assert_eq!(cache.upsert(&channel_id, 1, 10), 10);
        assert_eq!(cache.upsert(&channel_id, 1, 7), 10);

        assert_eq!(cache.next(&channel_id, 1), 10);
        assert_eq!(cache.next(&channel_id, 1), 11);
    }

    #[test]
    fn next_saturates_at_max() {
        let cache = OutgoingIndexCache::default();
        let channel_id = ChannelId::default();

        assert_eq!(cache.upsert(&channel_id, 1, MAX), MAX);

        assert_eq!(cache.next(&channel_id, 1), MAX);
        assert_eq!(cache.next(&channel_id, 1), MAX);
    }

    #[test]
    fn remove_existing_entry_returns_true_and_persists_deletion_on_save() -> anyhow::Result<()> {
        let cache = OutgoingIndexCache::default();
        let channel_id = ChannelId::default();
        let epoch = 1;
        let store = store();

        assert_eq!(cache.upsert(&channel_id, epoch, 5), 5);
        cache.save(store.clone())?;
        assert_eq!(store.read().load_outgoing_index(&channel_id, epoch)?, Some(5));

        assert!(cache.remove(&channel_id, epoch));
        assert!(cache.save(store.clone()).is_ok());
        assert_eq!(store.read().load_outgoing_index(&channel_id, epoch)?, None);

        // The entry is gone, so a second removal has nothing to do.
        assert!(!cache.remove(&channel_id, epoch));

        Ok(())
    }

    #[test]
    fn remove_missing_entry_returns_false() {
        let cache = OutgoingIndexCache::default();
        let channel_id = ChannelId::default();

        assert!(!cache.remove(&channel_id, 1));
    }

    // ---- OutgoingIndexCache: save ----

    #[test]
    fn save_persists_only_dirty_entries() -> anyhow::Result<()> {
        let cache = OutgoingIndexCache::default();
        let channel_a = ChannelId::default();
        let channel_b = ChannelId::create(&[b"other"]);
        let store = store();

        cache.upsert(&channel_a, 1, 10);
        cache.upsert(&channel_b, 2, 20);

        // First save writes both fresh entries.
        cache.save(store.clone())?;
        assert_eq!(store.read().load_outgoing_index(&channel_a, 1)?, Some(10));
        assert_eq!(store.read().load_outgoing_index(&channel_b, 2)?, Some(20));

        // Saving again without changes leaves the store as it was.
        cache.save(store.clone())?;
        assert_eq!(store.read().load_outgoing_index(&channel_a, 1)?, Some(10));
        assert_eq!(store.read().load_outgoing_index(&channel_b, 2)?, Some(20));

        // Only the entry that advanced is updated.
        cache.next(&channel_a, 1);
        cache.save(store.clone())?;
        assert_eq!(store.read().load_outgoing_index(&channel_a, 1)?, Some(11));
        assert_eq!(store.read().load_outgoing_index(&channel_b, 2)?, Some(20));

        Ok(())
    }

    #[test]
    fn save_persists_removed_entries_only_once() -> anyhow::Result<()> {
        let cache = OutgoingIndexCache::default();
        let channel_id = ChannelId::default();
        let store = store();

        cache.upsert(&channel_id, 1, 3);
        cache.save(store.clone())?;
        assert_eq!(store.read().load_outgoing_index(&channel_id, 1)?, Some(3));

        assert!(cache.remove(&channel_id, 1));
        cache.save(store.clone())?;
        assert_eq!(store.read().load_outgoing_index(&channel_id, 1)?, None);

        // A further save must succeed and keep the entry absent.
        cache.save(store.clone())?;
        assert_eq!(store.read().load_outgoing_index(&channel_id, 1)?, None);

        Ok(())
    }

    #[test]
    fn save_is_idempotent_after_success() -> anyhow::Result<()> {
        let cache = OutgoingIndexCache::default();
        let channel_id = ChannelId::default();
        let store = store();

        cache.upsert(&channel_id, 1, 9);
        cache.save(store.clone())?;
        assert_eq!(store.read().load_outgoing_index(&channel_id, 1)?, Some(9));

        cache.save(store.clone())?;
        assert_eq!(store.read().load_outgoing_index(&channel_id, 1)?, Some(9));

        cache.next(&channel_id, 1);
        cache.save(store.clone())?;
        assert_eq!(store.read().load_outgoing_index(&channel_id, 1)?, Some(10));

        Ok(())
    }

    #[test]
    fn concurrent_next_on_same_key_is_monotonic() -> anyhow::Result<()> {
        let cache = Arc::new(OutgoingIndexCache::default());
        let channel_id = ChannelId::default();
        let epoch = 1;

        let workers: Vec<_> = (0..8)
            .map(|_| {
                let cache = Arc::clone(&cache);
                thread::spawn(move || cache.next(&channel_id, epoch))
            })
            .collect();

        let mut issued = Vec::with_capacity(workers.len());
        for worker in workers {
            let index = worker.join().map_err(|_| anyhow::anyhow!("join error"))?;
            issued.push(index);
        }
        issued.sort_unstable();

        // Every index in 0..8 was handed out exactly once; the next one continues from 8.
        assert_eq!(issued, (0..8).collect::<Vec<_>>());
        assert_eq!(cache.next(&channel_id, epoch), 8);

        Ok(())
    }

    #[test]
    fn save_handles_multiple_keys_and_removals_together() -> anyhow::Result<()> {
        let cache = OutgoingIndexCache::default();
        let channel_a = ChannelId::default();
        let channel_b = ChannelId::create(&[b"other"]);
        let store = store();

        cache.upsert(&channel_a, 1, 4);
        cache.upsert(&channel_b, 2, 7);
        cache.save(store.clone())?;

        // One key is removed while the other advances before the next save.
        assert!(cache.remove(&channel_a, 1));
        cache.next(&channel_b, 2);
        cache.save(store.clone())?;

        assert_eq!(store.read().load_outgoing_index(&channel_a, 1)?, None);
        assert_eq!(store.read().load_outgoing_index(&channel_b, 2)?, Some(8));

        Ok(())
    }
}