use hopr_crypto_random::random_float;
use hopr_internal_types::prelude::*;
use hopr_primitive_types::prelude::*;
use std::cmp::{max, Ordering};
use std::collections::BinaryHeap;
use std::marker::PhantomData;
use tracing::trace;
use crate::channel_graph::{ChannelEdge, ChannelGraph, Node};
use crate::errors::{PathError, Result};
use crate::path::ChannelPath;
use crate::selectors::{EdgeWeighting, PathSelector};
/// A candidate (possibly partial) path considered by the DFS selector, together with
/// the accumulated weight of the channels it traverses.
#[derive(Clone, Debug, PartialEq, Eq)]
struct WeightedChannelPath {
    /// Hops collected so far; the search source itself is never stored here.
    path: Vec<Address>,
    /// Sum of the weights of all edges appended via [`WeightedChannelPath::extend`].
    weight: U256,
    /// Once set, `extend` becomes a no-op and the path ranks below all unexplored paths.
    fully_explored: bool,
}

impl WeightedChannelPath {
    /// Returns this path with the destination of `edge` appended and the edge weight
    /// (as computed by `CW`) added. A path already marked `fully_explored` is returned
    /// unchanged.
    pub fn extend<CW: EdgeWeighting<U256>>(mut self, edge: &ChannelEdge) -> Self {
        if self.fully_explored {
            return self;
        }
        let edge_weight = CW::calculate_weight(edge);
        self.path.push(edge.channel.destination);
        self.weight += edge_weight;
        self
    }
}

impl Default for WeightedChannelPath {
    /// An empty, zero-weight, not yet explored path with room for `INTERMEDIATE_HOPS` hops.
    fn default() -> Self {
        let path = Vec::with_capacity(INTERMEDIATE_HOPS);
        Self {
            path,
            weight: U256::zero(),
            fully_explored: false,
        }
    }
}

impl PartialOrd for WeightedChannelPath {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}

impl Ord for WeightedChannelPath {
    /// Ranking used by the selector's max-heap (`BinaryHeap`). Paths that are not fully
    /// explored rank above fully explored ones. Ties are broken by hop count (longer ranks
    /// higher), then by accumulated weight (heavier ranks higher).
    ///
    /// NOTE(review): the hop addresses are not compared, so two distinct paths can be
    /// `Ordering::Equal` while the derived `==` reports them as different.
    fn cmp(&self, other: &Self) -> Ordering {
        // `false < true` for bool, so comparing `other` against `self` puts unexplored first.
        other
            .fully_explored
            .cmp(&self.fully_explored)
            .then_with(|| self.path.len().cmp(&other.path.len()))
            .then_with(|| self.weight.cmp(&other.weight))
    }
}
/// Edge weighting that scales a channel's balance by a random factor obtained from
/// `random_float()`. The resulting weight is never below one.
#[derive(Clone, Copy, Debug, Default)]
pub struct RandomizedEdgeWeighting;

impl EdgeWeighting<U256> for RandomizedEdgeWeighting {
    fn calculate_weight(edge: &ChannelEdge) -> U256 {
        let balance = edge.channel.balance.amount();
        let factor = random_float();
        let scaled = balance
            .mul_f64(factor)
            .expect("Could not multiply edge weight with float");
        // Clamp so every edge has a strictly positive weight (e.g. a zero balance yields 1).
        max(U256::one(), scaled)
    }
}
/// Tuning parameters for [`DfsPathSelector`].
#[derive(Clone, Copy, Debug, PartialEq, smart_default::SmartDefault)]
pub struct DfsPathSelectorConfig {
    /// Iteration budget of the search loop. Once the iteration counter exceeds this value,
    /// the search stops with the next candidate popped from the queue. Because the check
    /// is `iters > max_iterations`, up to `max_iterations + 1` loop iterations run.
    /// Defaults to 100.
    #[default(100)]
    pub max_iterations: usize,
    /// Nodes whose quality is strictly below this value are never used as a hop.
    /// Defaults to 0.2.
    #[default(0.2)]
    pub quality_threshold: f64,
    /// Channels that carry a score strictly below this value are never traversed.
    /// Channels without a score are not filtered by this check. Defaults to 0.0.
    #[default(0.0)]
    pub score_threshold: f64,
    /// When `false`, channels whose balance is zero are never traversed. Defaults to `false`.
    #[default(false)]
    pub allow_zero_edge_weight: bool,
}
/// Path selector that explores candidate paths longest-first (depth-first style).
/// Equally long candidates are ranked by edge weights computed with `CW`.
#[derive(Clone, Debug, Default)]
pub struct DfsPathSelector<CW> {
    /// Search parameters.
    cfg: DfsPathSelectorConfig,
    /// Ties the edge-weighting strategy to the type; no value of `CW` is stored.
    _cw: PhantomData<CW>,
}
impl<CW: EdgeWeighting<U256>> DfsPathSelector<CW> {
    /// Creates a selector driven by the given configuration.
    pub fn new(cfg: DfsPathSelectorConfig) -> Self {
        Self { cfg, _cw: PhantomData }
    }

    /// Decides whether `next_hop`, reached through `edge`, may be appended to `current_path`.
    ///
    /// The hop is rejected, in this order, when it:
    /// - is the initial source or the final destination,
    /// - has a quality below the configured threshold,
    /// - is reached through a channel whose score is below the configured threshold,
    /// - is already contained in `current_path`,
    /// - is reached through a zero-balance channel while zero-weight edges are disallowed.
    ///
    /// Every decision is reported at `trace` level.
    fn is_next_hop_usable(
        &self,
        next_hop: &Node,
        edge: &ChannelEdge,
        initial_source: &Address,
        final_destination: &Address,
        current_path: &[Address],
    ) -> bool {
        debug_assert_eq!(next_hop.address, edge.channel.destination);

        let candidate = &next_hop.address;

        if candidate == initial_source {
            trace!(%next_hop, "source loopback not allowed");
            return false;
        }

        if candidate == final_destination {
            trace!(%next_hop, "destination loopback not allowed");
            return false;
        }

        if next_hop.quality < self.cfg.quality_threshold {
            trace!(%next_hop, "node quality threshold not satisfied");
            return false;
        }

        // A channel without a score is not filtered by the score threshold.
        if matches!(edge.score, Some(score) if score < self.cfg.score_threshold) {
            trace!(%next_hop, "channel score threshold not satisfied");
            return false;
        }

        if current_path.contains(candidate) {
            trace!(%next_hop, "circles not allowed");
            return false;
        }

        if edge.channel.balance.is_zero() && !self.cfg.allow_zero_edge_weight {
            trace!(%next_hop, "zero stake channels not allowed");
            return false;
        }

        trace!(%next_hop, ?current_path, "usable node");
        true
    }
}
impl<CW> PathSelector<CW> for DfsPathSelector<CW>
where
    CW: EdgeWeighting<U256>,
{
    /// Searches for a path of `min_hops..=max_hops` intermediate hops leading from `source`
    /// towards `destination`.
    ///
    /// The search keeps a max-heap of candidate paths ranked by `WeightedChannelPath`'s
    /// ordering (unexplored first, then longer, then heavier). It repeatedly pops the best
    /// candidate:
    /// - The search stops and returns this candidate if it already has `max_hops` hops, if
    ///   it is fully explored, or if the iteration budget is exhausted. The candidate is
    ///   accepted only when its length lies in `min_hops..=max_hops`.
    /// - Otherwise the candidate is extended by every usable outgoing channel of its last
    ///   hop. If there is none, the candidate is marked fully explored and re-queued.
    ///
    /// Neither `source` nor `destination` ever appear in the returned path.
    ///
    /// # Errors
    /// - `GeneralError::InvalidInput` (converted) if `max_hops` is not in
    ///   `1..=INTERMEDIATE_HOPS` or `min_hops` is not in `1..=max_hops`.
    /// - `PathError::PathNotFound` if no candidate of acceptable length is found.
    fn select_path(
        &self,
        graph: &ChannelGraph,
        source: Address,
        destination: Address,
        min_hops: usize,
        max_hops: usize,
    ) -> Result<ChannelPath> {
        if !(1..=INTERMEDIATE_HOPS).contains(&max_hops) || !(1..=max_hops).contains(&min_hops) {
            return Err(GeneralError::InvalidInput.into());
        }

        // Every way of failing to find a path is reported with the same error.
        let not_found = || PathError::PathNotFound(max_hops, source.to_string(), destination.to_string());

        // Seed the queue with all usable first hops out of `source`.
        let mut queue = graph
            .open_channels_from(source)
            .filter(|(node, edge)| self.is_next_hop_usable(node, edge, &source, &destination, &[]))
            .map(|(_, edge)| WeightedChannelPath::default().extend::<CW>(edge))
            .collect::<BinaryHeap<_>>();

        trace!(last_peer = %source, queue_len = queue.len(), "got next possible steps");

        let mut iters = 0;
        while let Some(mut current) = queue.pop() {
            let current_len = current.path.len();

            trace!(
                ?current,
                ?queue,
                queue_len = queue.len(),
                iters,
                min_hops,
                max_hops,
                "testing next path in queue"
            );

            // Stop at the best candidate once it cannot or should not be extended further.
            if current_len == max_hops || current.fully_explored || iters > self.cfg.max_iterations {
                return if (min_hops..=max_hops).contains(&current_len) {
                    Ok(ChannelPath::new_valid(current.path))
                } else {
                    trace!(current_len, min_hops, max_hops, iters, "path not found");
                    Err(not_found())
                };
            }

            // Every queued path was created by extending an unexplored (empty) path at least
            // once, so it always has a last hop.
            let last_peer = *current
                .path
                .last()
                .expect("every queued path contains at least one hop");

            let mut new_channels = graph
                .open_channels_from(last_peer)
                .filter(|(next_hop, edge)| {
                    self.is_next_hop_usable(next_hop, edge, &source, &destination, &current.path)
                })
                .peekable();

            if new_channels.peek().is_some() {
                queue.extend(new_channels.map(|(_, edge)| current.clone().extend::<CW>(edge)));
                trace!(%last_peer, queue_len = queue.len(), "got next possible steps");
            } else {
                // Dead end: keep the path as a fallback, ranked below all unexplored paths.
                current.fully_explored = true;
                trace!(path = ?current, "fully explored");
                queue.push(current);
            }

            iters += 1;
        }

        Err(not_found())
    }
}
/// Default path selector: [`DfsPathSelector`] using [`RandomizedEdgeWeighting`] for edge weights.
pub type DefaultPathSelector = DfsPathSelector<RandomizedEdgeWeighting>;
#[cfg(test)]
mod tests {
    use super::*;
    use regex::Regex;
    use std::str::FromStr;

    use crate::path::Path;

    lazy_static::lazy_static! {
        // Fixed node addresses; graph definitions refer to them by index (0..=5).
        static ref ADDRESSES: [Address; 6] = [
            Address::from_str("0x0000c178cf70d966be0a798e666ce2782c7b2288").unwrap(),
            Address::from_str("0x1000d5786d9e6799b3297da1ad55605b91e2c882").unwrap(),
            Address::from_str("0x200060ddced1e33c9647a71f4fc2cf4ed33e4a9d").unwrap(),
            Address::from_str("0x30004105095c8c10f804109b4d1199a9ac40ed46").unwrap(),
            Address::from_str("0x4000a288c38fa8a0f4b79127747257af4a03a623").unwrap(),
            Address::from_str("0x50002f462ec709cf181bbe44a7e952487bd4591d").unwrap(),
        ];
    }

    /// Builds a channel entry from `src` to `dst` with the given status and stake.
    fn create_channel(src: Address, dst: Address, status: ChannelStatus, stake: Balance) -> ChannelEntry {
        ChannelEntry::new(src, dst, stake, U256::zero(), status, U256::zero())
    }

    /// Asserts that `path` is valid in `graph`, acyclic, and does not contain `dst`.
    fn check_path(path: &ChannelPath, graph: &ChannelGraph, dst: Address) -> anyhow::Result<()> {
        let other = ChannelPath::new(path.hops().into(), graph)?;
        assert_eq!(other, path.clone(), "valid paths must be equal");
        assert!(!path.contains_cycle(), "path must not be cyclic");
        assert!(!path.hops().contains(&dst), "path must not contain destination");
        Ok(())
    }

    /// Builds a channel graph owned by `me` from a comma-separated edge list.
    ///
    /// Each edge has the form `A [s] -> B`, `A <- [s] B` or `A [s] <-> [t] B`, where `A`/`B`
    /// are indices into `ADDRESSES` and the bracketed numbers are the stakes of the channels
    /// opened by the node on that side. Node qualities and channel scores are taken from the
    /// `quality` and `score` callbacks. Malformed definitions panic.
    fn define_graph<Q, S>(def: &str, me: Address, quality: Q, score: S) -> ChannelGraph
    where
        Q: Fn(Address) -> f64,
        S: Fn(Address, Address) -> f64,
    {
        let mut graph = ChannelGraph::new(me);
        if def.is_empty() {
            return graph;
        }

        let re = Regex::new(r"^\s*(\d+)\s*(\[\d+\])?\s*(<?->?)\s*(\[\d+\])?\s*(\d+)\s*$").unwrap();
        let re_stake = Regex::new(r"^\[(\d+)\]$").unwrap();

        // Opens an `src -> dest` channel with the stake parsed from `stake_str` (e.g. "[5]")
        // and records the qualities/score supplied by the callbacks.
        let mut match_stake_and_update_channel = |src: Address, dest: Address, stake_str: &str| {
            let stake_caps = re_stake
                .captures(stake_str)
                .unwrap_or_else(|| panic!("no matching stake. got {}", stake_str));
            graph.update_channel(create_channel(
                src,
                dest,
                ChannelStatus::Open,
                Balance::new(
                    U256::from_str(&stake_caps[1]).expect("failed to create U256 from given stake"),
                    BalanceType::HOPR,
                ),
            ));
            graph.update_node_quality(&src, quality(src));
            graph.update_node_quality(&dest, quality(dest));
            graph.update_channel_score(&src, &dest, score(src, dest));
        };

        for edge in def.split(',') {
            let caps = re
                .captures(edge)
                .unwrap_or_else(|| panic!("no matching edge. got `{edge}`"));

            // Groups 1, 3 and 5 are mandatory in the pattern, so indexing cannot fail here.
            let addr_a = ADDRESSES[usize::from_str(&caps[1]).unwrap()];
            let addr_b = ADDRESSES[usize::from_str(&caps[5]).unwrap()];

            match &caps[3] {
                "->" => {
                    if let Some(stake_b) = caps.get(4) {
                        panic!(
                            "Cannot assign stake for counterparty because channel is unidirectional. Got `{}`",
                            stake_b.as_str()
                        );
                    }
                    let stake_opt_a = caps.get(2).expect("missing stake for initiator");
                    match_stake_and_update_channel(addr_a, addr_b, stake_opt_a.as_str());
                }
                "<-" => {
                    if let Some(stake_a) = caps.get(2) {
                        panic!(
                            "Cannot assign stake for counterparty because channel is unidirectional. Got `{}`",
                            stake_a.as_str()
                        );
                    }
                    let stake_opt_b = caps.get(4).expect("missing stake for counterparty");
                    match_stake_and_update_channel(addr_b, addr_a, stake_opt_b.as_str());
                }
                "<->" => {
                    let stake_opt_a = caps.get(2).expect("missing stake for initiator");
                    match_stake_and_update_channel(addr_a, addr_b, stake_opt_a.as_str());
                    let stake_opt_b = caps.get(4).expect("missing stake for counterparty");
                    match_stake_and_update_channel(addr_b, addr_a, stake_opt_b.as_str());
                }
                _ => panic!("unknown direction infix"),
            };
        }

        graph
    }

    /// Deterministic weighting for tests: channel balance plus one.
    #[derive(Default)]
    pub struct TestWeights;
    impl EdgeWeighting<U256> for TestWeights {
        fn calculate_weight(edge: &ChannelEdge) -> U256 {
            edge.channel.balance.amount() + 1u32
        }
    }

    #[test]
    fn test_should_not_find_path_if_isolated() {
        let isolated = define_graph("", ADDRESSES[0], |_| 1.0, |_, _| 0.0);

        let selector = DfsPathSelector::<TestWeights>::default();

        selector
            .select_path(&isolated, ADDRESSES[0], ADDRESSES[5], 1, 2)
            .expect_err("should not find a path");
    }

    #[test]
    fn test_should_not_find_zero_weight_path() {
        let isolated = define_graph("0 [0] -> 1", ADDRESSES[0], |_| 1.0, |_, _| 0.0);

        let selector = DfsPathSelector::<TestWeights>::default();

        selector
            .select_path(&isolated, ADDRESSES[0], ADDRESSES[5], 1, 1)
            .expect_err("should not find a path");
    }

    #[test]
    fn test_should_not_find_one_hop_path_when_unrelated_channels_are_in_the_graph() {
        let isolated = define_graph("1 [1] -> 2", ADDRESSES[0], |_| 1.0, |_, _| 0.0);

        let selector = DfsPathSelector::<TestWeights>::default();

        selector
            .select_path(&isolated, ADDRESSES[0], ADDRESSES[5], 1, 1)
            .expect_err("should not find a path");
    }

    #[test]
    fn test_should_not_find_one_hop_path_in_empty_graph() {
        let isolated = define_graph("", ADDRESSES[0], |_| 1.0, |_, _| 0.0);

        let selector = DfsPathSelector::<TestWeights>::default();

        selector
            .select_path(&isolated, ADDRESSES[0], ADDRESSES[5], 1, 1)
            .expect_err("should not find a path");
    }

    #[test]
    fn test_should_not_find_path_with_unreliable_node() {
        let isolated = define_graph("0 [1] -> 1", ADDRESSES[0], |_| 0_f64, |_, _| 0.0);

        let selector = DfsPathSelector::<TestWeights>::default();

        selector
            .select_path(&isolated, ADDRESSES[0], ADDRESSES[5], 1, 1)
            .expect_err("should not find a path");
    }

    #[test]
    fn test_should_not_find_loopback_path() {
        let isolated = define_graph("0 [1] <-> [1] 1", ADDRESSES[0], |_| 1_f64, |_, _| 0.0);

        let selector = DfsPathSelector::<TestWeights>::default();

        selector
            .select_path(&isolated, ADDRESSES[0], ADDRESSES[5], 2, 2)
            .expect_err("should not find a path");
    }

    #[test]
    fn test_should_not_include_destination_in_path() {
        let isolated = define_graph("0 [1] -> 1", ADDRESSES[0], |_| 1_f64, |_, _| 0.0);

        let selector = DfsPathSelector::<TestWeights>::default();

        selector
            .select_path(&isolated, ADDRESSES[0], ADDRESSES[1], 1, 1)
            .expect_err("should not find a path");
    }

    #[test]
    fn test_should_find_path_in_reliable_star() -> anyhow::Result<()> {
        let star = define_graph(
            "0 [1] <-> [2] 1, 0 [1] <-> [3] 2, 0 [1] <-> [4] 3, 0 [1] <-> [5] 4",
            ADDRESSES[1],
            |_| 1_f64,
            |_, _| 0.0,
        );

        let selector = DfsPathSelector::<TestWeights>::default();
        let path = selector.select_path(&star, ADDRESSES[1], ADDRESSES[5], 1, 2)?;
        check_path(&path, &star, ADDRESSES[5])?;
        assert_eq!(2, path.length(), "should have 2 hops");

        Ok(())
    }

    #[test]
    fn test_should_find_path_in_reliable_arrow_with_lower_weight() -> anyhow::Result<()> {
        let arrow = define_graph(
            "0 [1] -> 1, 1 [1] -> 2, 2 [1] -> 3, 1 [1] -> 3",
            ADDRESSES[0],
            |_| 1_f64,
            |_, _| 0.0,
        );
        let selector = DfsPathSelector::<TestWeights>::default();

        let path = selector.select_path(&arrow, ADDRESSES[0], ADDRESSES[5], 3, 3)?;
        check_path(&path, &arrow, ADDRESSES[5])?;
        assert_eq!(3, path.length(), "should have 3 hops");

        Ok(())
    }

    #[test]
    fn test_should_find_path_in_reliable_arrow_with_higher_weight() -> anyhow::Result<()> {
        let arrow = define_graph(
            "0 [1] -> 1, 1 [2] -> 2, 2 [3] -> 3, 1 [2] -> 3",
            ADDRESSES[0],
            |_| 1_f64,
            |_, _| 0.0,
        );
        let selector = DfsPathSelector::<TestWeights>::default();

        let path = selector.select_path(&arrow, ADDRESSES[0], ADDRESSES[5], 3, 3)?;
        check_path(&path, &arrow, ADDRESSES[5])?;
        assert_eq!(3, path.length(), "should have 3 hops");

        Ok(())
    }

    #[test]
    fn test_should_find_path_in_reliable_arrow_with_random_weight() -> anyhow::Result<()> {
        let arrow = define_graph(
            "0 [29] -> 1, 1 [5] -> 2, 2 [31] -> 3, 1 [2] -> 3",
            ADDRESSES[0],
            |_| 1_f64,
            |_, _| 0.0,
        );
        let selector = DfsPathSelector::<RandomizedEdgeWeighting>::default();

        let path = selector.select_path(&arrow, ADDRESSES[0], ADDRESSES[5], 3, 3)?;
        check_path(&path, &arrow, ADDRESSES[5])?;
        assert_eq!(3, path.length(), "should have 3 hops");

        Ok(())
    }
}