1use std::{cmp::Ordering, collections::BinaryHeap, marker::PhantomData, time::Duration};
2
3use async_trait::async_trait;
4use hopr_crypto_random::random_float;
5use hopr_internal_types::prelude::*;
6use hopr_primitive_types::prelude::*;
7use tracing::trace;
8
9use crate::{
10 ChannelPath,
11 channel_graph::{ChannelEdge, ChannelGraph, Node},
12 errors::{PathError, Result},
13 selectors::{EdgeWeighting, PathSelector},
14};
15
16#[derive(Clone, Debug, PartialEq, Eq)]
18struct WeightedChannelPath {
19 path: Vec<Address>,
20 weight: U256,
21 fully_explored: bool,
22}
23
24impl WeightedChannelPath {
25 fn extend<CW: EdgeWeighting<U256>>(mut self, edge: &ChannelEdge) -> Self {
27 if !self.fully_explored {
28 self.path.push(edge.channel.destination);
29 self.weight += CW::calculate_weight(edge);
30 }
31 self
32 }
33}
34
35impl Default for WeightedChannelPath {
36 fn default() -> Self {
37 Self {
38 path: Vec::with_capacity(INTERMEDIATE_HOPS),
39 weight: U256::zero(),
40 fully_explored: false,
41 }
42 }
43}
44
45impl PartialOrd for WeightedChannelPath {
46 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
47 Some(self.cmp(other))
48 }
49}
50
51impl Ord for WeightedChannelPath {
52 fn cmp(&self, other: &Self) -> Ordering {
61 if other.fully_explored == self.fully_explored {
62 match self.path.len().cmp(&other.path.len()) {
63 Ordering::Equal => self.weight.cmp(&other.weight),
64 o => o,
65 }
66 } else if other.fully_explored {
67 Ordering::Greater
68 } else {
69 Ordering::Less
70 }
71 }
72}
73
/// Edge weighting that scales each channel's balance by a random factor (see its
/// `EdgeWeighting<U256>` implementation), never producing a weight below 1.
#[derive(Debug, Default, Clone, Copy)]
pub struct RandomizedEdgeWeighting;
81
82impl EdgeWeighting<U256> for RandomizedEdgeWeighting {
83 fn calculate_weight(edge: &ChannelEdge) -> U256 {
90 edge.channel
91 .balance
92 .amount()
93 .mul_f64(random_float())
94 .expect("Could not multiply edge weight with float")
95 .max(1.into())
96 }
97}
98
99#[derive(Clone, Copy, Debug, PartialEq, smart_default::SmartDefault)]
100pub struct DfsPathSelectorConfig {
101 #[default(100)]
104 pub max_iterations: usize,
105 #[default(0.5)]
108 pub node_score_threshold: f64,
109 #[default(0.0)]
112 pub edge_score_threshold: f64,
113 #[default(Some(Duration::from_millis(100)))]
116 pub max_first_hop_latency: Option<Duration>,
117 #[default(false)]
122 pub allow_zero_edge_weight: bool,
123}
124
125#[derive(Clone, Debug)]
127pub struct DfsPathSelector<CW> {
128 graph: std::sync::Arc<async_lock::RwLock<ChannelGraph>>,
129 cfg: DfsPathSelectorConfig,
130 _cw: PhantomData<CW>,
131}
132
133impl<CW: EdgeWeighting<U256>> DfsPathSelector<CW> {
134 pub fn new(graph: std::sync::Arc<async_lock::RwLock<ChannelGraph>>, cfg: DfsPathSelectorConfig) -> Self {
137 Self {
138 graph,
139 cfg,
140 _cw: PhantomData,
141 }
142 }
143
144 #[tracing::instrument(level = "trace", skip(self))]
159 fn is_next_hop_usable(
160 &self,
161 next_hop: &Node,
162 edge: &ChannelEdge,
163 initial_source: &Address,
164 final_destination: &Address,
165 current_path: &[Address],
166 ) -> bool {
167 debug_assert_eq!(next_hop.address, edge.channel.destination);
168
169 if next_hop.address.eq(initial_source) {
171 trace!("source loopback not allowed");
172 return false;
173 }
174
175 if next_hop.address.eq(final_destination) {
178 trace!("destination loopback not allowed");
179 return false;
180 }
181
182 if next_hop.node_score < self.cfg.node_score_threshold {
184 trace!("node quality threshold not satisfied");
185 return false;
186 }
187
188 if edge
190 .edge_score
191 .is_some_and(|score| score < self.cfg.edge_score_threshold)
192 {
193 trace!("channel score threshold not satisfied");
194 return false;
195 }
196
197 if current_path.is_empty()
199 && self
200 .cfg
201 .max_first_hop_latency
202 .is_some_and(|limit| next_hop.latency.average().is_none_or(|avg_latency| avg_latency > limit))
203 {
204 trace!("first hop latency too high");
205 return false;
206 }
207
208 if current_path.contains(&next_hop.address) {
211 trace!("circles not allowed");
212 return false;
213 }
214
215 if !self.cfg.allow_zero_edge_weight && edge.channel.balance.is_zero() {
217 trace!(%next_hop, "zero stake channels not allowed");
218 return false;
219 }
220
221 trace!("usable node");
222 true
223 }
224}
225
226#[async_trait]
227impl<CW> PathSelector for DfsPathSelector<CW>
228where
229 CW: EdgeWeighting<U256> + Send + Sync,
230{
231 async fn select_path(
239 &self,
240 source: Address,
241 destination: Address,
242 min_hops: usize,
243 max_hops: usize,
244 ) -> Result<ChannelPath> {
245 if !(1..=INTERMEDIATE_HOPS).contains(&max_hops) || !(1..=max_hops).contains(&min_hops) {
248 return Err(GeneralError::InvalidInput.into());
249 }
250
251 let graph = self.graph.read().await;
252
253 let mut queue = graph
255 .open_channels_from(source)
256 .filter(|(node, edge)| self.is_next_hop_usable(node, edge, &source, &destination, &[]))
257 .map(|(_, edge)| WeightedChannelPath::default().extend::<CW>(edge))
258 .collect::<BinaryHeap<_>>();
259
260 trace!(last_peer = %source, queue_len = queue.len(), "got next possible steps");
261
262 let mut iters = 0;
263 while let Some(mut current) = queue.pop() {
264 let current_len = current.path.len();
265
266 trace!(
267 ?current,
268 ?queue,
269 queue_len = queue.len(),
270 iters,
271 min_hops,
272 max_hops,
273 "testing next path in queue"
274 );
275 if current_len == max_hops || current.fully_explored || iters > self.cfg.max_iterations {
276 return if current_len >= min_hops && current_len <= max_hops {
277 Ok(ChannelPath::from_iter(current.path))
278 } else {
279 trace!(current_len, min_hops, max_hops, iters, "path not found");
280 Err(PathError::PathNotFound(
281 max_hops,
282 source.to_string(),
283 destination.to_string(),
284 ))
285 };
286 }
287
288 let last_peer = *current.path.last().unwrap();
290 let mut new_channels = graph
291 .open_channels_from(last_peer)
292 .filter(|(next_hop, edge)| {
293 self.is_next_hop_usable(next_hop, edge, &source, &destination, ¤t.path)
294 })
295 .peekable();
296
297 if new_channels.peek().is_some() {
299 queue.extend(new_channels.map(|(_, edge)| current.clone().extend::<CW>(edge)));
300 trace!(%last_peer, queue_len = queue.len(), "got next possible steps");
301 } else {
302 current.fully_explored = true;
306 trace!(path = ?current, "fully explored");
307 queue.push(current);
308 }
309
310 iters += 1;
311 }
312
313 Err(PathError::PathNotFound(
314 max_hops,
315 source.to_string(),
316 destination.to_string(),
317 ))
318 }
319}
320
#[cfg(test)]
mod tests {
    use std::{ops::Deref, str::FromStr, sync::Arc};

    use async_lock::RwLock;
    use regex::Regex;

    use super::*;
    use crate::{
        ChainPath, Path, ValidatedPath,
        channel_graph::NodeScoreUpdate,
        tests::{ADDRESSES, PATH_ADDRS},
    };

    fn create_channel(src: Address, dst: Address, status: ChannelStatus, stake: HoprBalance) -> ChannelEntry {
        ChannelEntry::new(src, dst, stake, U256::zero(), status, U256::zero())
    }

    /// Validates `path` against `graph` and asserts it is acyclic and excludes `dst`.
    async fn check_path(path: &ChannelPath, graph: &ChannelGraph, dst: Address) -> anyhow::Result<()> {
        let _ = ValidatedPath::new(
            graph.my_address(),
            ChainPath::from_channel_path(path.clone(), dst),
            graph,
            PATH_ADDRS.deref(),
        )
        .await?;

        assert!(!path.contains_cycle(), "path must not be cyclic");
        assert!(!path.hops().contains(&dst), "path must not contain destination");

        Ok(())
    }

    /// Builds a channel graph owned by `me` from a comma-separated edge description.
    ///
    /// Each edge is `A [s] -> B`, `A <- [s] B` or `A [s1] <-> [s2] B`, where `A`/`B` are indices
    /// into `ADDRESSES` and `s` is the stake of the channel opened in the arrow's direction.
    /// Every node touched gets `quality(node)` as its score (with a 10 ms latency sample), and
    /// every channel gets `score(src, dest)` as its channel score.
    ///
    /// # Panics
    /// On malformed edge descriptions.
    fn define_graph<Q, S>(def: &str, me: Address, quality: Q, score: S) -> ChannelGraph
    where
        Q: Fn(Address) -> f64,
        S: Fn(Address, Address) -> f64,
    {
        let mut graph = ChannelGraph::new(me, Default::default());

        if def.is_empty() {
            return graph;
        }

        let re = Regex::new(r"^\s*(\d+)\s*(\[\d+\])?\s*(<?->?)\s*(\[\d+\])?\s*(\d+)\s*$").unwrap();
        let re_stake = Regex::new(r"^\[(\d+)\]$").unwrap();

        let mut match_stake_and_update_channel = |src: Address, dest: Address, stake_str: &str| {
            // `captures` yields `None` on mismatch; report the offending input rather than a bare unwrap.
            let stake_caps = re_stake
                .captures(stake_str)
                .unwrap_or_else(|| panic!("no matching stake. got {}", stake_str));

            graph.update_channel(create_channel(
                src,
                dest,
                ChannelStatus::Open,
                U256::from_str(stake_caps.get(1).unwrap().as_str())
                    .expect("failed to create U256 from given stake")
                    .into(),
            ));

            graph.update_node_score(
                &src,
                NodeScoreUpdate::Initialize(Duration::from_millis(10), quality(src)),
            );
            graph.update_node_score(
                &dest,
                NodeScoreUpdate::Initialize(Duration::from_millis(10), quality(dest)),
            );

            graph.update_channel_score(&src, &dest, score(src, dest));
        };

        for edge in def.split(",") {
            // `captures` yields `None` on mismatch; report the offending input rather than a bare unwrap.
            let caps = re
                .captures(edge)
                .unwrap_or_else(|| panic!("no matching edge. got `{edge}`"));

            let addr_a = ADDRESSES[usize::from_str(caps.get(1).unwrap().as_str()).unwrap()];
            let addr_b = ADDRESSES[usize::from_str(caps.get(5).unwrap().as_str()).unwrap()];

            let dir = caps.get(3).unwrap().as_str();

            match dir {
                "->" => {
                    if let Some(stake_b) = caps.get(4) {
                        panic!(
                            "Cannot assign stake for counterparty because channel is unidirectional. Got `{}`",
                            stake_b.as_str()
                        );
                    }

                    let stake_opt_a = caps.get(2).expect("missing stake for initiator");

                    match_stake_and_update_channel(addr_a, addr_b, stake_opt_a.as_str());
                }
                "<-" => {
                    if let Some(stake_a) = caps.get(2) {
                        panic!(
                            "Cannot assign stake for counterparty because channel is unidirectional. Got `{}`",
                            stake_a.as_str()
                        );
                    }

                    let stake_opt_b = caps.get(4).expect("missing stake for counterparty");

                    match_stake_and_update_channel(addr_b, addr_a, stake_opt_b.as_str());
                }
                "<->" => {
                    let stake_opt_a = caps.get(2).expect("missing stake for initiator");

                    match_stake_and_update_channel(addr_a, addr_b, stake_opt_a.as_str());

                    let stake_opt_b = caps.get(4).expect("missing stake for counterparty");

                    match_stake_and_update_channel(addr_b, addr_a, stake_opt_b.as_str());
                }
                _ => panic!("unknown direction infix"),
            };
        }

        graph
    }

    /// Deterministic weighting for tests: the channel balance plus one.
    #[derive(Default)]
    pub struct TestWeights;
    impl EdgeWeighting<U256> for TestWeights {
        fn calculate_weight(edge: &ChannelEdge) -> U256 {
            edge.channel.balance.amount() + 1u32
        }
    }

    #[tokio::test]
    async fn test_should_not_find_path_if_isolated() {
        let graph = Arc::new(RwLock::new(define_graph("", ADDRESSES[0], |_| 1.0, |_, _| 0.0)));

        let selector = DfsPathSelector::<TestWeights>::new(graph.clone(), Default::default());

        selector
            .select_path(ADDRESSES[0], ADDRESSES[5], 1, 2)
            .await
            .expect_err("should not find a path");
    }

    #[tokio::test]
    async fn test_should_not_find_zero_weight_path() {
        let graph = Arc::new(RwLock::new(define_graph(
            "0 [0] -> 1",
            ADDRESSES[0],
            |_| 1.0,
            |_, _| 0.0,
        )));

        let selector = DfsPathSelector::<TestWeights>::new(graph.clone(), Default::default());

        selector
            .select_path(ADDRESSES[0], ADDRESSES[5], 1, 1)
            .await
            .expect_err("should not find a path");
    }

    #[tokio::test]
    async fn test_should_not_find_one_hop_path_when_unrelated_channels_are_in_the_graph() {
        let graph = Arc::new(RwLock::new(define_graph(
            "1 [1] -> 2",
            ADDRESSES[0],
            |_| 1.0,
            |_, _| 0.0,
        )));

        let selector = DfsPathSelector::<TestWeights>::new(graph.clone(), Default::default());

        selector
            .select_path(ADDRESSES[0], ADDRESSES[5], 1, 1)
            .await
            .expect_err("should not find a path");
    }

    #[tokio::test]
    async fn test_should_not_find_one_hop_path_in_empty_graph() {
        let graph = Arc::new(RwLock::new(define_graph("", ADDRESSES[0], |_| 1.0, |_, _| 0.0)));

        let selector = DfsPathSelector::<TestWeights>::new(graph.clone(), Default::default());

        selector
            .select_path(ADDRESSES[0], ADDRESSES[5], 1, 1)
            .await
            .expect_err("should not find a path");
    }

    #[tokio::test]
    async fn test_should_not_find_path_with_unreliable_node() {
        let graph = Arc::new(RwLock::new(define_graph(
            "0 [1] -> 1",
            ADDRESSES[0],
            |_| 0_f64,
            |_, _| 0.0,
        )));

        let selector = DfsPathSelector::<TestWeights>::new(graph.clone(), Default::default());

        selector
            .select_path(ADDRESSES[0], ADDRESSES[5], 1, 1)
            .await
            .expect_err("should not find a path");
    }

    #[tokio::test]
    async fn test_should_not_find_loopback_path() {
        let graph = Arc::new(RwLock::new(define_graph(
            "0 [1] <-> [1] 1",
            ADDRESSES[0],
            |_| 1_f64,
            |_, _| 0.0,
        )));

        let selector = DfsPathSelector::<TestWeights>::new(graph.clone(), Default::default());

        selector
            .select_path(ADDRESSES[0], ADDRESSES[5], 2, 2)
            .await
            .expect_err("should not find a path");
    }

    #[tokio::test]
    async fn test_should_not_include_destination_in_path() {
        let graph = Arc::new(RwLock::new(define_graph(
            "0 [1] -> 1",
            ADDRESSES[0],
            |_| 1_f64,
            |_, _| 0.0,
        )));

        let selector = DfsPathSelector::<TestWeights>::new(graph.clone(), Default::default());

        selector
            .select_path(ADDRESSES[0], ADDRESSES[1], 1, 1)
            .await
            .expect_err("should not find a path");
    }

    #[tokio::test]
    async fn test_should_find_path_in_reliable_star() -> anyhow::Result<()> {
        let graph = Arc::new(RwLock::new(define_graph(
            "0 [1] <-> [2] 1, 0 [1] <-> [3] 2, 0 [1] <-> [4] 3, 0 [1] <-> [5] 4",
            ADDRESSES[1],
            |_| 1_f64,
            |_, _| 0.0,
        )));

        let selector = DfsPathSelector::<TestWeights>::new(graph.clone(), Default::default());
        let path = selector.select_path(ADDRESSES[1], ADDRESSES[5], 1, 2).await?;

        check_path(&path, graph.read().await.deref(), ADDRESSES[5]).await?;
        assert_eq!(2, path.num_hops(), "should have 2 hops");

        Ok(())
    }

    #[tokio::test]
    async fn test_should_find_path_in_reliable_arrow_with_lower_weight() -> anyhow::Result<()> {
        let graph = Arc::new(RwLock::new(define_graph(
            "0 [1] -> 1, 1 [1] -> 2, 2 [1] -> 3, 1 [1] -> 3",
            ADDRESSES[0],
            |_| 1_f64,
            |_, _| 0.0,
        )));
        let selector = DfsPathSelector::<TestWeights>::new(graph.clone(), Default::default());

        let path = selector.select_path(ADDRESSES[0], ADDRESSES[5], 3, 3).await?;
        check_path(&path, graph.read().await.deref(), ADDRESSES[5]).await?;
        assert_eq!(3, path.num_hops(), "should have 3 hops");

        Ok(())
    }

    #[tokio::test]
    async fn test_should_find_path_in_reliable_arrow_with_higher_weight() -> anyhow::Result<()> {
        let graph = Arc::new(RwLock::new(define_graph(
            "0 [1] -> 1, 1 [2] -> 2, 2 [3] -> 3, 1 [2] -> 3",
            ADDRESSES[0],
            |_| 1_f64,
            |_, _| 0.0,
        )));
        let selector = DfsPathSelector::<TestWeights>::new(graph.clone(), Default::default());

        let path = selector.select_path(ADDRESSES[0], ADDRESSES[5], 3, 3).await?;
        check_path(&path, graph.read().await.deref(), ADDRESSES[5]).await?;
        assert_eq!(3, path.num_hops(), "should have 3 hops");

        Ok(())
    }

    #[tokio::test]
    async fn test_should_find_path_in_reliable_arrow_with_random_weight() -> anyhow::Result<()> {
        let graph = Arc::new(RwLock::new(define_graph(
            "0 [29] -> 1, 1 [5] -> 2, 2 [31] -> 3, 1 [2] -> 3",
            ADDRESSES[0],
            |_| 1_f64,
            |_, _| 0.0,
        )));
        let selector = DfsPathSelector::<RandomizedEdgeWeighting>::new(graph.clone(), Default::default());

        let path = selector.select_path(ADDRESSES[0], ADDRESSES[5], 3, 3).await?;
        check_path(&path, graph.read().await.deref(), ADDRESSES[5]).await?;
        assert_eq!(3, path.num_hops(), "should have 3 hops");

        Ok(())
    }
}