1use std::{
29 collections::HashMap,
30 fmt::{Debug, Display, Formatter},
31 sync::Arc,
32};
33
34use async_lock::{RwLock, RwLockUpgradableReadGuardArc};
35use async_trait::async_trait;
36use hopr_async_runtime::prelude::{JoinHandle, spawn};
37use hopr_crypto_types::prelude::Hash;
38use hopr_db_sql::{
39 api::tickets::{AggregationPrerequisites, HoprDbTicketOperations},
40 channels::HoprDbChannelOperations,
41};
42use hopr_internal_types::prelude::*;
43#[cfg(all(feature = "prometheus", not(test)))]
44use hopr_metrics::metrics::SimpleCounter;
45use hopr_transport_ticket_aggregation::TicketAggregatorTrait;
46use serde::{Deserialize, Serialize, Serializer};
47use serde_with::serde_as;
48use tracing::{debug, error, info, warn};
49use validator::Validate;
50
51use crate::{Strategy, strategy::SingularStrategy};
52
// Prometheus counter for automatic aggregations. It is only compiled with the `prometheus` feature outside of tests.
// NOTE(review): despite its "initiated" description, the counter is only incremented after `aggregate_tickets`
// returned `Ok` inside the spawned aggregation task; failed aggregations are not counted.
#[cfg(all(feature = "prometheus", not(test)))]
lazy_static::lazy_static! {
    static ref METRIC_COUNT_AGGREGATIONS: SimpleCounter =
        SimpleCounter::new("hopr_strategy_aggregating_aggregation_count", "Count of initiated automatic aggregations").unwrap();
}
58
59use hopr_platform::time::native::current_time;
60
/// Upper bound accepted for `AggregatingStrategyConfig::aggregation_threshold`, derived from the DB layer's
/// `MAX_TICKETS_TO_AGGREGATE_BATCH`.
// NOTE(review): `as u32` would silently truncate if the DB constant ever exceeded `u32::MAX`; presumably it is a
// small batch size — confirm against `hopr_db_sql::tickets`.
const MAX_AGGREGATABLE_TICKET_COUNT: u32 = hopr_db_sql::tickets::MAX_TICKETS_TO_AGGREGATE_BATCH as u32;
62
/// Default value of `AggregatingStrategyConfig::aggregation_threshold` (250 tickets).
#[inline]
fn default_aggregation_threshold() -> Option<u32> {
    const DEFAULT_THRESHOLD: u32 = 250;
    Some(DEFAULT_THRESHOLD)
}
67
/// Always returns `true`; used as the default value of `AggregatingStrategyConfig::aggregate_on_channel_close`.
#[inline]
fn just_true() -> bool {
    true
}
72
/// Default value of `AggregatingStrategyConfig::unrealized_balance_ratio` (0.9).
#[inline]
fn default_unrealized_balance_ratio() -> Option<f64> {
    const DEFAULT_RATIO: f64 = 0.9;
    Some(DEFAULT_RATIO)
}
77
78fn serialize_optional_f64<S>(x: &Option<f64>, s: S) -> Result<S::Ok, S::Error>
79where
80 S: Serializer,
81{
82 if let Some(v) = x {
83 s.serialize_f64(*v)
84 } else {
85 s.serialize_none()
86 }
87}
88
89#[serde_as]
91#[derive(Debug, Clone, Copy, PartialEq, smart_default::SmartDefault, Validate, Serialize, Deserialize)]
92pub struct AggregatingStrategyConfig {
93 #[validate(range(min = 2, max = MAX_AGGREGATABLE_TICKET_COUNT))]
100 #[serde(default = "default_aggregation_threshold")]
101 #[default(default_aggregation_threshold())]
102 pub aggregation_threshold: Option<u32>,
103
104 #[validate(range(min = 0_f64, max = 1.0_f64))]
112 #[default(default_unrealized_balance_ratio())]
113 #[serde(serialize_with = "serialize_optional_f64")]
114 pub unrealized_balance_ratio: Option<f64>,
115
116 #[default(just_true())]
124 pub aggregate_on_channel_close: bool,
125}
126
127impl From<AggregatingStrategyConfig> for AggregationPrerequisites {
128 fn from(value: AggregatingStrategyConfig) -> Self {
129 AggregationPrerequisites {
130 min_ticket_count: value.aggregation_threshold.map(|x| x as usize),
131 min_unaggregated_ratio: value.unrealized_balance_ratio,
132 }
133 }
134}
135
/// Strategy that starts ticket aggregation in incoming channels, either periodically (`on_tick`) or when a
/// channel transitions from `Open` to `PendingToClose` (if enabled in the configuration).
pub struct AggregatingStrategy<Db>
where
    Db: HoprDbTicketOperations + Send + Sync + Clone + std::fmt::Debug,
{
    /// Database handle; `on_tick` queries it for the incoming channels.
    db: Db,
    /// Performs the actual aggregation in a channel; a clone of it is moved into each spawned aggregation task.
    ticket_aggregator: Arc<dyn TicketAggregatorTrait + Send + Sync + 'static>,
    /// Strategy configuration.
    cfg: AggregatingStrategyConfig,
    /// Aggregation tasks started by this strategy, keyed by channel ID.
    ///
    /// The `bool` is set to `true` by the task itself once it has finished. Finished entries are removed and joined
    /// lazily, on the next attempt to aggregate in the same channel.
    // The nested lock/map/tuple type trips clippy's `type_complexity` lint.
    #[allow(clippy::type_complexity)]
    agg_tasks: Arc<RwLock<HashMap<Hash, (bool, JoinHandle<()>)>>>,
}
151
152impl<Db> Debug for AggregatingStrategy<Db>
153where
154 Db: HoprDbTicketOperations + Send + Sync + Clone + std::fmt::Debug,
155{
156 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
157 write!(f, "{:?}", Strategy::Aggregating(self.cfg))
158 }
159}
160
161impl<Db> Display for AggregatingStrategy<Db>
162where
163 Db: HoprDbTicketOperations + Send + Sync + Clone + std::fmt::Debug,
164{
165 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
166 write!(f, "{}", Strategy::Aggregating(self.cfg))
167 }
168}
169
170impl<Db> AggregatingStrategy<Db>
171where
172 Db: HoprDbTicketOperations + Send + Sync + Clone + std::fmt::Debug,
173{
174 pub fn new(
175 cfg: AggregatingStrategyConfig,
176 db: Db,
177 ticket_aggregator: Arc<dyn TicketAggregatorTrait + Send + Sync + 'static>,
178 ) -> Self {
179 Self {
180 db,
181 cfg,
182 ticket_aggregator,
183 agg_tasks: Arc::new(RwLock::new(HashMap::new())),
184 }
185 }
186}
187
188impl<Db> AggregatingStrategy<Db>
189where
190 Db: HoprDbChannelOperations + HoprDbTicketOperations + Send + Sync + Clone + std::fmt::Debug + 'static,
191{
192 async fn try_start_aggregation(
193 &self,
194 channel_id: Hash,
195 criteria: AggregationPrerequisites,
196 ) -> crate::errors::Result<()> {
197 if !self.is_strategy_aggregating_in_channel(channel_id).await {
198 debug!("checking aggregation in {channel_id} with criteria {criteria:?}...");
199
200 let agg_tasks_clone = self.agg_tasks.clone();
201 let aggregator_clone = self.ticket_aggregator.clone();
202 let (can_remove_tx, can_remove_rx) = futures::channel::oneshot::channel();
203 let task = spawn(async move {
204 match aggregator_clone.aggregate_tickets(&channel_id, criteria).await {
205 Ok(_) => {
206 debug!(%channel_id, "aggregation attempted without issues for a channel");
207
208 #[cfg(all(feature = "prometheus", not(test)))]
209 METRIC_COUNT_AGGREGATIONS.increment();
210 }
211 Err(error) => {
212 error!(%channel_id, %error, "aggregation failed to complete for a channel");
213 }
214 }
215
216 let _ = can_remove_rx.await;
218 if let Some((done, _)) = agg_tasks_clone.write_arc().await.get_mut(&channel_id) {
219 *done = true;
220 }
221 });
222
223 self.agg_tasks.write_arc().await.insert(channel_id, (false, task));
224 let _ = can_remove_tx.send(()); } else {
226 warn!(channel = %channel_id, "this strategy already aggregates in channel");
227 }
228
229 Ok(())
230 }
231
232 async fn is_strategy_aggregating_in_channel(&self, channel_id: Hash) -> bool {
233 let tasks_read_locked = self.agg_tasks.upgradable_read_arc().await;
234 let existing = tasks_read_locked.get(&channel_id).map(|(done, _)| *done);
235 if let Some(done) = existing {
236 if done {
238 let mut tasks_write_locked = RwLockUpgradableReadGuardArc::upgrade(tasks_read_locked).await;
239
240 if let Some((_, task)) = tasks_write_locked.remove(&channel_id) {
241 let _ = task.await;
243 false
244 } else {
245 false
247 }
248 } else {
249 true
251 }
252 } else {
253 false
255 }
256 }
257}
258
259#[async_trait]
260impl<Db> SingularStrategy for AggregatingStrategy<Db>
261where
262 Db: HoprDbChannelOperations + HoprDbTicketOperations + Clone + Send + Sync + std::fmt::Debug + 'static,
263{
264 async fn on_tick(&self) -> crate::errors::Result<()> {
265 let incoming = self
266 .db
267 .get_incoming_channels(None)
268 .await
269 .map_err(hopr_db_sql::api::errors::DbError::from)?
270 .into_iter()
271 .filter(|c| !c.closure_time_passed(current_time()))
272 .map(|c| c.get_id());
273
274 for channel_id in incoming {
275 if let Err(e) = self.try_start_aggregation(channel_id, self.cfg.into()).await {
276 debug!("skipped aggregation in channel {channel_id}: {e}");
277 }
278 }
279
280 Ok(())
281 }
282
283 async fn on_own_channel_changed(
284 &self,
285 channel: &ChannelEntry,
286 direction: ChannelDirection,
287 change: ChannelChange,
288 ) -> crate::errors::Result<()> {
289 if !self.cfg.aggregate_on_channel_close || direction != ChannelDirection::Incoming {
290 return Ok(());
291 }
292
293 if let ChannelChange::Status { left: old, right: new } = change {
294 if old != ChannelStatus::Open || !matches!(new, ChannelStatus::PendingToClose(_)) {
295 debug!("ignoring channel {channel} state change that's not in PendingToClose");
296 return Ok(());
297 }
298
299 info!(%channel, "going to aggregate tickets in channel because it transitioned to PendingToClose");
300
301 let on_close_agg_prerequisites = AggregationPrerequisites {
303 min_ticket_count: Some(2),
304 min_unaggregated_ratio: None,
305 };
306
307 Ok(self
308 .try_start_aggregation(channel.get_id(), on_close_agg_prerequisites)
309 .await?)
310 } else {
311 Ok(())
312 }
313 }
314}
315
#[cfg(test)]
mod tests {
    use std::{sync::Arc, time::Duration};

    use anyhow::Context;
    use futures::{StreamExt, pin_mut};
    use hex_literal::hex;
    use hopr_crypto_types::prelude::*;
    use hopr_db_sql::{
        HoprDbGeneralModelOperations, TargetDb,
        accounts::HoprDbAccountOperations,
        api::{info::DomainSeparator, tickets::HoprDbTicketOperations},
        channels::HoprDbChannelOperations,
        db::HoprDb,
        errors::DbSqlError,
        info::HoprDbInfoOperations,
    };
    use hopr_internal_types::prelude::*;
    use hopr_primitive_types::prelude::*;
    use hopr_transport_ticket_aggregation::{
        AwaitingAggregator, TicketAggregationInteraction, TicketAggregationProcessed,
    };
    use lazy_static::lazy_static;
    use tokio::time::timeout;
    use tracing::{debug, error};

    use crate::{
        aggregating::{MAX_AGGREGATABLE_TICKET_COUNT, default_aggregation_threshold},
        strategy::SingularStrategy,
    };

    #[test]
    fn default_ticket_aggregation_count_is_lower_than_maximum_allowed_ticket_count() -> anyhow::Result<()> {
        assert!(default_aggregation_threshold().unwrap() < MAX_AGGREGATABLE_TICKET_COUNT);

        Ok(())
    }

    lazy_static! {
        // Offchain keys of the two test peers: index 0 is Alice (ticket issuer), index 1 is Bob (ticket holder).
        static ref PEERS: Vec<OffchainKeypair> = [
            hex!("b91a28ff9840e9c93e5fafd581131f0b9f33f3e61b02bf5dd83458aa0221f572"),
            hex!("82283757872f99541ce33a47b90c2ce9f64875abf08b5119a8a434b2fa83ea98")
        ]
        .iter()
        .map(|private| OffchainKeypair::from_secret(private).expect("lazy static keypair should be valid"))
        .collect();
        // Chain keys of the same two peers, in the same order as `PEERS`.
        static ref PEERS_CHAIN: Vec<ChainKeypair> = [
            hex!("51d3003d908045a4d76d0bfc0d84f6ff946b5934b7ea6a2958faf02fead4567a"),
            hex!("e1f89073a01831d0eed9fe2c67e7d65c144b9d9945320f6d325b1cccc2d124e9")
        ]
        .iter()
        .map(|private| ChainKeypair::from_secret(private).expect("lazy static keypair should be valid"))
        .collect();
    }

    /// Builds a signed, acknowledged ticket from `signer` to `destination` with win probability 1, channel epoch 1,
    /// and the given `index`/`index_offset`.
    fn mock_acknowledged_ticket(
        signer: &ChainKeypair,
        destination: &ChainKeypair,
        index: u64,
        index_offset: u32,
    ) -> anyhow::Result<AcknowledgedTicket> {
        let price_per_packet: U256 = 20_u32.into();
        let ticket_win_prob = 1.0f64;

        let channel_id = generate_channel_id(&signer.into(), &destination.into());

        let channel_epoch = 1u64;
        let domain_separator = Hash::default();

        // Deterministic response derived from the channel, epoch and ticket index.
        let response = Response::try_from(
            Hash::create(&[channel_id.as_ref(), &channel_epoch.to_be_bytes(), &index.to_be_bytes()]).as_ref(),
        )?;

        Ok(TicketBuilder::default()
            .addresses(signer, destination)
            .amount(price_per_packet.div_f64(ticket_win_prob)?)
            .index(index)
            .index_offset(index_offset)
            .win_prob(ticket_win_prob.try_into()?)
            .channel_epoch(1)
            .challenge(response.to_challenge()?)
            .build_signed(signer, &domain_separator)?
            .into_acknowledged(response))
    }

    /// Inserts `amount` acknowledged tickets (Alice -> Bob, indices `0..amount`) into `db` within one transaction.
    ///
    /// Returns the tickets and an `Open` channel entry whose balance equals the total value of the tickets.
    async fn populate_db_with_ack_tickets(
        db: HoprDb,
        amount: usize,
    ) -> anyhow::Result<(Vec<AcknowledgedTicket>, ChannelEntry)> {
        let db_clone = db.clone();
        let (acked_tickets, total_value) = db
            .begin_transaction_in_db(TargetDb::Tickets)
            .await?
            .perform(|tx| {
                Box::pin(async move {
                    let mut acked_tickets = Vec::new();
                    let mut total_value = HoprBalance::zero();

                    for i in 0..amount {
                        let acked_ticket = mock_acknowledged_ticket(&PEERS_CHAIN[0], &PEERS_CHAIN[1], i as u64, 1)
                            .expect("should be able to create an ack ticket");
                        debug!("inserting {acked_ticket}");

                        db_clone.upsert_ticket(Some(tx), acked_ticket.clone()).await?;

                        total_value += acked_ticket.verified_ticket().amount;
                        acked_tickets.push(acked_ticket);
                    }

                    Ok::<_, DbSqlError>((acked_tickets, total_value))
                })
            })
            .await?;

        let channel = ChannelEntry::new(
            (&PEERS_CHAIN[0]).into(),
            (&PEERS_CHAIN[1]).into(),
            total_value,
            0_u32.into(),
            ChannelStatus::Open,
            1u32.into(),
        );

        Ok((acked_tickets, channel))
    }

    /// Sets the channel domain separator and registers both test peers' accounts in `db`.
    async fn init_db(db: HoprDb) -> anyhow::Result<()> {
        let db_clone = db.clone();
        db.begin_transaction()
            .await?
            .perform(|tx| {
                Box::pin(async move {
                    db_clone
                        .set_domain_separator(Some(tx), DomainSeparator::Channel, Hash::default())
                        .await?;
                    for i in 0..PEERS_CHAIN.len() {
                        debug!(
                            "linking {} with {}",
                            PEERS[i].public(),
                            PEERS_CHAIN[i].public().to_address()
                        );
                        db_clone
                            .insert_account(
                                Some(tx),
                                AccountEntry {
                                    public_key: *PEERS[i].public(),
                                    chain_addr: PEERS_CHAIN[i].public().to_address(),
                                    entry_type: AccountType::NotAnnounced,
                                    published_at: 1,
                                },
                            )
                            .await?;
                    }
                    Ok::<_, DbSqlError>(())
                })
            })
            .await?;

        Ok(())
    }

    /// Spawns a task that drives one complete aggregation exchange between Bob (requester) and Alice (aggregator).
    ///
    /// Returns Bob's aggregator and a receiver that resolves once the whole exchange has completed. The spawned
    /// task panics on any unexpected protocol step.
    fn spawn_aggregation_interaction(
        db_alice: HoprDb,
        db_bob: HoprDb,
        key_alice: &ChainKeypair,
        key_bob: &ChainKeypair,
    ) -> anyhow::Result<(
        AwaitingAggregator<(), (), HoprDb>,
        futures::channel::oneshot::Receiver<()>,
    )> {
        let mut alice = TicketAggregationInteraction::<(), ()>::new(db_alice, key_alice);
        let mut bob = TicketAggregationInteraction::<(), ()>::new(db_bob.clone(), key_bob);

        let (tx, awaiter) = futures::channel::oneshot::channel::<()>();
        let bob_aggregator = bob.writer();

        tokio::task::spawn(async move {
            let mut finalizer = None;

            // 1. Bob emits an aggregation request, which is handed over to Alice.
            match bob.next().await {
                Some(TicketAggregationProcessed::Send(_, acked_tickets, request_finalizer)) => {
                    let _ = finalizer.insert(request_finalizer);
                    match alice.writer().receive_aggregation_request(
                        PEERS[1].public().into(),
                        acked_tickets.into_iter().collect(),
                        (),
                    ) {
                        Ok(_) => {}
                        Err(e) => error!(error = %e, "Failed to receive aggregation ticket"),
                    }
                }
                _ => panic!("unexpected action happened"),
            };

            // 2. Alice replies with the aggregated ticket, which is handed back to Bob.
            match alice.next().await {
                Some(TicketAggregationProcessed::Reply(_, aggregated_ticket, ())) => {
                    match bob
                        .writer()
                        .receive_ticket(PEERS[0].public().into(), aggregated_ticket, ())
                    {
                        Ok(_) => {}
                        Err(e) => error!(error = %e, "Failed to receive a ticket"),
                    }
                }

                _ => panic!("unexpected action happened"),
            };

            // 3. Bob processes the aggregated ticket.
            match bob.next().await {
                Some(TicketAggregationProcessed::Receive(_destination, _ticket, ())) => (),
                _ => panic!("unexpected action happened"),
            };

            finalizer.expect("should have a value present").finalize();
            let _ = tx.send(());
        });

        Ok((
            AwaitingAggregator::new(db_bob, bob_aggregator, Duration::from_secs(5)),
            awaiter,
        ))
    }

    #[tokio::test]
    async fn test_strategy_aggregation_on_tick() -> anyhow::Result<()> {
        let db_alice = HoprDb::new_in_memory(PEERS_CHAIN[0].clone()).await?;
        let db_bob = HoprDb::new_in_memory(PEERS_CHAIN[1].clone()).await?;

        init_db(db_alice.clone()).await?;
        init_db(db_bob.clone()).await?;

        let (bob_notify_tx, bob_notify_rx) = futures::channel::mpsc::unbounded();
        db_bob.start_ticket_processing(bob_notify_tx.into())?;

        let (_, channel) = populate_db_with_ack_tickets(db_bob.clone(), 5).await?;

        db_alice.upsert_channel(None, channel).await?;
        db_bob.upsert_channel(None, channel).await?;

        let (bob_aggregator, awaiter) =
            spawn_aggregation_interaction(db_alice.clone(), db_bob.clone(), &PEERS_CHAIN[0], &PEERS_CHAIN[1])?;

        let cfg = super::AggregatingStrategyConfig {
            aggregation_threshold: Some(5),
            unrealized_balance_ratio: None,
            aggregate_on_channel_close: false,
        };

        let aggregation_strategy = super::AggregatingStrategy::new(cfg, db_bob.clone(), Arc::new(bob_aggregator));

        aggregation_strategy.on_tick().await?;

        // Fail fast if the aggregation exchange does not complete, instead of hanging on the notification below.
        timeout(Duration::from_secs(5), awaiter).await.context("Timeout")??;

        pin_mut!(bob_notify_rx);
        let notified_ticket = bob_notify_rx.next().await.expect("should have a ticket");

        let tickets = db_bob.get_tickets((&channel).into()).await?;
        assert_eq!(tickets.len(), 1, "there should be a single aggregated ticket");
        assert_eq!(notified_ticket, tickets[0]);

        Ok(())
    }

    #[tokio::test]
    async fn test_strategy_aggregation_on_tick_when_unrealized_balance_exceeded() -> anyhow::Result<()> {
        let db_alice = HoprDb::new_in_memory(PEERS_CHAIN[0].clone()).await?;
        let db_bob = HoprDb::new_in_memory(PEERS_CHAIN[1].clone()).await?;

        init_db(db_alice.clone()).await?;
        init_db(db_bob.clone()).await?;

        let (bob_notify_tx, bob_notify_rx) = futures::channel::mpsc::unbounded();
        db_bob.start_ticket_processing(bob_notify_tx.into())?;

        let (_, channel) = populate_db_with_ack_tickets(db_bob.clone(), 4).await?;

        db_alice.upsert_channel(None, channel).await?;
        db_bob.upsert_channel(None, channel).await?;

        let (bob_aggregator, awaiter) =
            spawn_aggregation_interaction(db_alice.clone(), db_bob.clone(), &PEERS_CHAIN[0], &PEERS_CHAIN[1])?;

        let cfg = super::AggregatingStrategyConfig {
            aggregation_threshold: None,
            unrealized_balance_ratio: Some(0.75),
            aggregate_on_channel_close: false,
        };

        let aggregation_strategy = super::AggregatingStrategy::new(cfg, db_bob.clone(), Arc::new(bob_aggregator));

        aggregation_strategy.on_tick().await?;

        // Fail fast if the aggregation exchange does not complete, instead of hanging on the notification below.
        timeout(Duration::from_secs(5), awaiter).await.context("Timeout")??;

        pin_mut!(bob_notify_rx);
        let notified_ticket = bob_notify_rx.next().await.expect("should have a ticket");

        let tickets = db_bob.get_tickets((&channel).into()).await?;
        assert_eq!(tickets.len(), 1, "there should be a single aggregated ticket");
        assert_eq!(notified_ticket, tickets[0]);

        Ok(())
    }

    #[tokio::test]
    async fn test_strategy_aggregation_on_tick_should_not_agg_when_unrealized_balance_exceeded_via_aggregated_tickets()
    -> anyhow::Result<()> {
        let db_alice = HoprDb::new_in_memory(PEERS_CHAIN[0].clone()).await?;
        let db_bob = HoprDb::new_in_memory(PEERS_CHAIN[1].clone()).await?;

        init_db(db_alice.clone()).await?;
        init_db(db_bob.clone()).await?;

        db_bob.start_ticket_processing(None)?;

        const NUM_TICKETS: usize = 4;
        let (mut acked_tickets, mut channel) = populate_db_with_ack_tickets(db_bob.clone(), NUM_TICKETS).await?;

        let (bob_aggregator, awaiter) =
            spawn_aggregation_interaction(db_alice.clone(), db_bob.clone(), &PEERS_CHAIN[0], &PEERS_CHAIN[1])?;

        // Replace the first ticket with one whose index offset is 2, i.e. a ticket that looks already aggregated.
        acked_tickets[0] = mock_acknowledged_ticket(&PEERS_CHAIN[0], &PEERS_CHAIN[1], 0, 2)?;

        debug!("upserting {}", acked_tickets[0]);
        db_bob.upsert_ticket(None, acked_tickets[0].clone()).await?;

        let tickets = db_bob.get_tickets((&channel).into()).await?;
        assert_eq!(tickets.len(), NUM_TICKETS, "nothing should be aggregated");

        channel.balance = HoprBalance::from(100_u32);

        db_alice.upsert_channel(None, channel).await?;
        db_bob.upsert_channel(None, channel).await?;

        let cfg = super::AggregatingStrategyConfig {
            aggregation_threshold: None,
            unrealized_balance_ratio: Some(0.75),
            aggregate_on_channel_close: false,
        };

        let aggregation_strategy = super::AggregatingStrategy::new(cfg, db_bob.clone(), Arc::new(bob_aggregator));

        aggregation_strategy.on_tick().await?;

        let tickets = db_bob.get_tickets((&channel).into()).await?;
        assert_eq!(tickets.len(), NUM_TICKETS, "nothing should be aggregated");
        std::mem::drop(awaiter);

        Ok(())
    }

    #[tokio::test]
    async fn test_strategy_aggregation_on_channel_close() -> anyhow::Result<()> {
        let db_alice = HoprDb::new_in_memory(PEERS_CHAIN[0].clone()).await?;
        let db_bob = HoprDb::new_in_memory(PEERS_CHAIN[1].clone()).await?;

        init_db(db_alice.clone()).await?;
        init_db(db_bob.clone()).await?;

        let (bob_notify_tx, bob_notify_rx) = futures::channel::mpsc::unbounded();
        db_bob.start_ticket_processing(bob_notify_tx.into())?;

        let (_, mut channel) = populate_db_with_ack_tickets(db_bob.clone(), 5).await?;

        // The high threshold ensures only the on-close path can trigger the aggregation.
        let cfg = super::AggregatingStrategyConfig {
            aggregation_threshold: Some(100),
            unrealized_balance_ratio: None,
            aggregate_on_channel_close: true,
        };

        channel.status = ChannelStatus::PendingToClose(std::time::SystemTime::now());

        db_alice.upsert_channel(None, channel).await?;
        db_bob.upsert_channel(None, channel).await?;

        let (bob_aggregator, awaiter) =
            spawn_aggregation_interaction(db_alice.clone(), db_bob.clone(), &PEERS_CHAIN[0], &PEERS_CHAIN[1])?;

        // The strategy runs on Bob's side, like in the other tests.
        let aggregation_strategy = super::AggregatingStrategy::new(cfg, db_bob.clone(), Arc::new(bob_aggregator));

        aggregation_strategy
            .on_own_channel_changed(
                &channel,
                ChannelDirection::Incoming,
                ChannelChange::Status {
                    left: ChannelStatus::Open,
                    right: ChannelStatus::PendingToClose(std::time::SystemTime::now()),
                },
            )
            .await?;

        timeout(Duration::from_secs(5), awaiter).await.context("Timeout")??;

        pin_mut!(bob_notify_rx);
        let notified_ticket = bob_notify_rx.next().await.expect("should have a ticket");

        let tickets = db_bob.get_tickets((&channel).into()).await?;
        assert_eq!(tickets.len(), 1, "there should be a single aggregated ticket");
        assert_eq!(notified_ticket, tickets[0]);

        Ok(())
    }

    #[tokio::test]
    async fn test_strategy_aggregation_on_tick_should_not_agg_on_channel_close_if_only_single_ticket()
    -> anyhow::Result<()> {
        let db_alice = HoprDb::new_in_memory(PEERS_CHAIN[0].clone()).await?;
        let db_bob = HoprDb::new_in_memory(PEERS_CHAIN[1].clone()).await?;

        init_db(db_alice.clone()).await?;
        init_db(db_bob.clone()).await?;

        db_bob.start_ticket_processing(None)?;

        const NUM_TICKETS: usize = 1;
        let (_, channel) = populate_db_with_ack_tickets(db_bob.clone(), NUM_TICKETS).await?;

        let (bob_aggregator, awaiter) =
            spawn_aggregation_interaction(db_alice.clone(), db_bob.clone(), &PEERS_CHAIN[0], &PEERS_CHAIN[1])?;

        let tickets = db_bob.get_tickets((&channel).into()).await?;
        assert_eq!(tickets.len(), NUM_TICKETS, "should have a single ticket");

        db_alice.upsert_channel(None, channel).await?;
        db_bob.upsert_channel(None, channel).await?;

        let cfg = super::AggregatingStrategyConfig {
            aggregation_threshold: Some(100),
            unrealized_balance_ratio: Some(0.75),
            aggregate_on_channel_close: true,
        };

        let aggregation_strategy = super::AggregatingStrategy::new(cfg, db_bob.clone(), Arc::new(bob_aggregator));

        aggregation_strategy
            .on_own_channel_changed(
                &channel,
                ChannelDirection::Incoming,
                ChannelChange::Status {
                    left: ChannelStatus::Open,
                    right: ChannelStatus::PendingToClose(std::time::SystemTime::now()),
                },
            )
            .await?;

        // With a single ticket, no aggregation exchange should ever complete.
        timeout(Duration::from_millis(500), awaiter)
            .await
            .expect_err("should timeout");

        let tickets = db_bob.get_tickets((&channel).into()).await?;
        assert_eq!(tickets.len(), NUM_TICKETS, "nothing should be aggregated");
        Ok(())
    }
}