Skip to main content

hopr_network_graph/petgraph/
algorithm.rs

//! Adaptation of `petgraph::algo::simple_paths::all_simple_paths_multi` that additionally folds a
//! caller-supplied cost function over the edge weights of each path and can prune branches by cost.
3
4use std::{
5    collections::HashSet,
6    hash::{BuildHasher, Hash},
7    iter::from_fn,
8};
9
10use indexmap::IndexSet;
11use petgraph::{
12    Direction::Outgoing,
13    visit::{EdgeRef, IntoEdgeReferences, IntoEdgesDirected, NodeCount},
14};
15
16/// Calculate all simple paths from a source node to any of several target nodes.
17///
18/// This function is a variant of `all_simple_paths` that accepts a `HashSet` of
19/// target nodes instead of a single one. A path is yielded as soon as it reaches any
20/// node in the `to` set.
21///
22/// # Performance Considerations
23///
24/// The efficiency of this function hinges on the graph's structure. It provides significant
25/// performance gains on graphs where paths share long initial segments (e.g., trees and DAGs),
26/// as the benefit of a single traversal outweighs the `HashSet` lookup overhead.
27///
28/// Conversely, in dense graphs where paths diverge quickly or for targets very close
29/// to the source, the lookup overhead could make repeated calls to `all_simple_paths`
30/// a faster alternative.
31///
32/// **Note**: If security is not a concern, a faster hasher (e.g., `FxBuildHasher`)
33/// can be specified to minimize the `HashSet` lookup overhead.
34///
35/// # Arguments
36/// * `graph`: an input graph.
37/// * `from`: an initial node of desired paths.
38/// * `to`: a `HashSet` of target nodes. A path is yielded as soon as it reaches any node in this set.
39/// * `min_intermediate_nodes`: the minimum number of nodes in the desired paths.
40/// * `max_intermediate_nodes`: the maximum number of nodes in the desired paths (optional).
41/// * `initial_cost`: the starting cost value before any edges are traversed.
42/// * `min_cost`: an optional threshold. If the accumulated cost drops below this value (via `PartialOrd`), the branch
43///   is pruned — it is neither yielded nor explored further.
44/// * `cost_fn`: an accumulator function `(accumulated_cost, &edge_weight, edge_count) -> new_cost` applied at each
45///   edge. The `edge_count` is the 0-based hop number from the source (i.e., 0 for the first edge, 1 for the second,
46///   etc.).
47///
48/// # Returns
49/// Returns an iterator that produces `(path, cost)` tuples for all simple paths from `from` node to any node in the
50/// `to` set, which contains at least `min_intermediate_nodes` and at most `max_intermediate_nodes` intermediate nodes,
51/// if given, or limited by the graph's order otherwise. The cost is the result of folding `cost_fn` over the edge
52/// weights along the path. Paths whose accumulated cost falls below `min_cost` at any point are excluded.
53///
54/// # Complexity
55/// * Time complexity: for computing the first **k** paths, the running time will be **O(k|V| + k|E|)**.
56/// * Auxiliary space: **O(|V|)**.
57///
58/// where **|V|** is the number of nodes and **|E|** is the number of edges.
59///
60/// # Example
61/// ```rust,ignore
62/// use petgraph::prelude::*;
63/// use std::collections::HashSet;
64/// use std::collections::hash_map::RandomState;
65///
66/// let mut graph = DiGraph::<&str, i32>::new();
67///
68/// let a = graph.add_node("a");
69/// let b = graph.add_node("b");
70/// let c = graph.add_node("c");
71/// let d = graph.add_node("d");
72/// graph.extend_with_edges(&[(a, b, 1), (b, c, 1), (b, d, 1)]);
73///
74/// // Find paths from "a" to either "c" or "d", accumulating edge costs.
75/// let targets = HashSet::from_iter([c, d]);
76/// let mut paths = all_simple_paths_multi::<Vec<_>, _, RandomState, _, _>(
77///     &graph, a, &targets, 0, None, 0i32, None, |cost, weight, _| cost + weight,
78/// )
79///     .collect::<Vec<_>>();
80///
81/// paths.sort_by_key(|(p, _)| p.clone());
82/// let expected_paths = vec![
83///     (vec![a, b, c], 2),
84///     (vec![a, b, d], 2),
85/// ];
86///
87/// assert_eq!(paths, expected_paths);
88/// ```
89#[allow(clippy::too_many_arguments)]
90pub fn all_simple_paths_multi<'a, TargetColl, G, S, F, C>(
91    graph: G,
92    from: G::NodeId,
93    to: &'a HashSet<G::NodeId, S>,
94    min_intermediate_nodes: usize,
95    max_intermediate_nodes: Option<usize>,
96    initial_cost: C,
97    min_cost: Option<C>,
98    cost_fn: F,
99) -> impl Iterator<Item = (TargetColl, C)> + 'a
100where
101    G: NodeCount + IntoEdgesDirected + 'a,
102    <G as IntoEdgesDirected>::EdgesDirected: 'a,
103    G::NodeId: Eq + Hash,
104    TargetColl: FromIterator<G::NodeId>,
105    S: BuildHasher + Default,
106    C: Clone + PartialOrd + 'a,
107    F: Fn(C, &<<G as IntoEdgeReferences>::EdgeRef as EdgeRef>::Weight, usize) -> C + 'a,
108{
109    let max_nodes = if let Some(l) = max_intermediate_nodes {
110        l + 2
111    } else {
112        graph.node_count()
113    };
114
115    let min_nodes = min_intermediate_nodes + 2;
116
117    // list of visited nodes
118    let mut visited: IndexSet<G::NodeId, S> = IndexSet::from_iter(Some(from));
119    // list of edges from currently exploring path nodes,
120    // last elem is list of edges of last visited node
121    let mut stack = vec![graph.edges_directed(from, Outgoing)];
122    // accumulated cost at each depth level, parallel to visited
123    let mut costs: Vec<C> = vec![initial_cost];
124
125    from_fn(move || {
126        while let Some(edges) = stack.last_mut() {
127            if let Some(edge) = edges.next() {
128                let child = edge.target();
129
130                if visited.contains(&child) {
131                    continue;
132                }
133
134                // initialized by `from` so is always at least len 1
135                let current_nodes = visited.len();
136                let new_cost = cost_fn(costs.last().unwrap().clone(), edge.weight(), current_nodes - 1);
137
138                // Prune branch if cost drops below threshold
139                if let Some(ref min) = min_cost
140                    && new_cost < *min
141                {
142                    continue;
143                }
144
145                let mut valid_path: Option<(TargetColl, C)> = None;
146
147                // Check if we've reached a target node
148                if to.contains(&child) && (current_nodes + 1) >= min_nodes {
149                    valid_path = Some((
150                        visited.iter().cloned().chain(Some(child)).collect::<TargetColl>(),
151                        new_cost.clone(),
152                    ));
153                }
154
155                // Expand the search only if within max length and unexplored target nodes remain
156                if (current_nodes < (max_nodes - 1)) && to.iter().any(|n| *n != child && !visited.contains(n)) {
157                    visited.insert(child);
158                    stack.push(graph.edges_directed(child, Outgoing));
159                    costs.push(new_cost);
160                }
161
162                // yield the valid path if found
163                if valid_path.is_some() {
164                    return valid_path;
165                }
166            } else {
167                // All edges of the current node have been explored
168                stack.pop();
169                visited.pop();
170                costs.pop();
171            }
172        }
173        None
174    })
175}
176
#[cfg(test)]
mod test {
    use std::collections::{HashSet, hash_map::RandomState};

    use petgraph::prelude::{DiGraph, UnGraph};

    use super::all_simple_paths_multi;

    /// Collect paths as sorted Vec<Vec<usize>> for deterministic snapshots.
    ///
    /// The accumulated cost of each path is discarded.
    fn sorted_paths<T, I: Iterator<Item = (Vec<petgraph::graph::NodeIndex>, T)>>(iter: I) -> Vec<Vec<usize>> {
        let mut paths: Vec<Vec<usize>> = iter.map(|(v, _)| v.into_iter().map(|i| i.index()).collect()).collect();
        paths.sort();
        paths
    }

    #[test]
    fn undirected_graph_should_find_all_paths_to_multiple_targets() {
        let graph = UnGraph::<i32, i32>::from_edges([(0, 1), (1, 2), (2, 3), (2, 4)]);
        let targets = HashSet::from_iter([3.into(), 4.into()]);
        let paths = sorted_paths(all_simple_paths_multi::<_, _, RandomState, _, _>(
            &graph,
            0.into(),
            &targets,
            0,
            None,
            0,
            None,
            |c, _, _| c,
        ));
        insta::assert_yaml_snapshot!(paths);
    }

    #[test]
    fn directed_graph_should_find_all_paths_to_multiple_targets() {
        let graph = DiGraph::<i32, ()>::from_edges([(0, 1), (1, 2), (2, 3), (2, 4)]);
        let targets = HashSet::from_iter([3.into(), 4.into()]);
        let paths = sorted_paths(all_simple_paths_multi::<_, _, RandomState, _, _>(
            &graph,
            0.into(),
            &targets,
            0,
            None,
            0,
            None,
            |c, _, _| c,
        ));
        insta::assert_yaml_snapshot!(paths);
    }

    #[test]
    fn undirected_graph_should_respect_max_intermediate_nodes() {
        let graph = UnGraph::<i32, ()>::from_edges([(0, 1), (1, 2), (2, 3), (2, 4)]);
        let targets = HashSet::from_iter([3.into(), 4.into()]);
        let paths = sorted_paths(all_simple_paths_multi::<_, _, RandomState, _, _>(
            &graph,
            0.into(),
            &targets,
            0,
            Some(2),
            0,
            None,
            |c, _, _| c,
        ));
        insta::assert_yaml_snapshot!(paths);
    }

    #[test]
    fn max_intermediate_nodes_should_not_be_exceeded_when_target_connects_to_target() {
        // Chain: 0->1->2->3, targets={2,3}, max_intermediate_nodes=1
        // Only [0,1,2] is valid (1 intermediate node).
        // Regression test: target 2 must not be expanded towards target 3, because the resulting
        // path [0,1,2,3] (2 intermediate nodes) would exceed the maximum length.
        let graph = DiGraph::<i32, ()>::from_edges([(0, 1), (1, 2), (2, 3)]);
        let targets = HashSet::from_iter([2.into(), 3.into()]);
        let paths = sorted_paths(all_simple_paths_multi::<_, _, RandomState, _, _>(
            &graph,
            0.into(),
            &targets,
            0,
            Some(1),
            0,
            None,
            |c, _, _| c,
        ));
        insta::assert_yaml_snapshot!(paths);
    }

    #[test]
    fn directed_graph_should_respect_max_intermediate_nodes() {
        let graph = DiGraph::<i32, ()>::from_edges([(0, 1), (1, 2), (2, 3), (2, 4)]);
        let targets = HashSet::from_iter([3.into(), 4.into()]);
        let paths = sorted_paths(all_simple_paths_multi::<_, _, RandomState, _, _>(
            &graph,
            0.into(),
            &targets,
            0,
            Some(2),
            0,
            None,
            |c, _, _| c,
        ));
        insta::assert_yaml_snapshot!(paths);
    }

    #[test]
    fn inline_targets_should_yield_both_short_and_long_paths() {
        let graph = UnGraph::<i32, ()>::from_edges([(0, 1), (1, 2), (2, 3)]);
        let targets = HashSet::from_iter([2.into(), 3.into()]);
        let paths = sorted_paths(all_simple_paths_multi::<_, _, RandomState, _, _>(
            &graph,
            0.into(),
            &targets,
            0,
            None,
            0,
            None,
            |c, _, _| c,
        ));
        insta::assert_yaml_snapshot!(paths);
    }

    #[test]
    fn cyclic_graph_should_yield_only_simple_paths() {
        let graph = DiGraph::<i32, ()>::from_edges([(0, 1), (1, 2), (2, 0), (1, 3)]);
        let targets = HashSet::from_iter([2.into(), 3.into()]);
        let paths = sorted_paths(all_simple_paths_multi::<_, _, RandomState, _, _>(
            &graph,
            0.into(),
            &targets,
            0,
            None,
            0,
            None,
            |c, _, _| c,
        ));
        insta::assert_yaml_snapshot!(paths);
    }

    #[test]
    fn source_in_target_set_should_not_yield_zero_length_path() {
        let graph = UnGraph::<i32, ()>::from_edges([(0, 1), (1, 2)]);
        let targets = HashSet::from_iter([0.into(), 1.into(), 2.into()]);
        let paths = sorted_paths(all_simple_paths_multi::<_, _, RandomState, _, _>(
            &graph,
            0.into(),
            &targets,
            0,
            None,
            0,
            None,
            |c, _, _| c,
        ));
        insta::assert_yaml_snapshot!(paths);
    }

    #[test]
    fn non_trivial_graph_should_find_all_simple_paths() {
        let graph = DiGraph::<i32, ()>::from_edges([
            (0, 1),
            (1, 2),
            (2, 3),
            (3, 4),
            (0, 5),
            (1, 5),
            (1, 3),
            (5, 4),
            (4, 2),
            (4, 3),
        ]);
        let targets = HashSet::from_iter([2.into(), 3.into()]);
        let paths = sorted_paths(all_simple_paths_multi::<_, _, RandomState, _, _>(
            &graph,
            1.into(),
            &targets,
            0,
            None,
            0,
            None,
            |c, _, _| c,
        ));
        insta::assert_yaml_snapshot!(paths);
    }

    #[test]
    fn min_intermediate_nodes_should_exclude_shorter_paths() {
        let graph = UnGraph::<i32, ()>::from_edges([(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]);
        let targets = HashSet::from_iter([1.into(), 3.into()]);
        let paths = sorted_paths(all_simple_paths_multi::<_, _, RandomState, _, _>(
            &graph,
            0.into(),
            &targets,
            2,
            None,
            0,
            None,
            |c, _, _| c,
        ));
        insta::assert_yaml_snapshot!(paths);
    }

    #[test]
    fn multiplicative_cost_should_accumulate_along_path() {
        // 0 --0.9--> 1 --0.8--> 2 --0.7--> 3
        //                  \--0.6--> 4
        let mut graph = DiGraph::<(), f64>::new();
        let n: Vec<_> = (0..5).map(|_| graph.add_node(())).collect();
        graph.extend_with_edges([
            (n[0], n[1], 0.9),
            (n[1], n[2], 0.8),
            (n[2], n[3], 0.7),
            (n[1], n[4], 0.6),
        ]);

        let targets = HashSet::from_iter([n[3], n[4]]);
        let results: Vec<(Vec<_>, f64)> =
            all_simple_paths_multi::<_, _, RandomState, _, _>(&graph, n[0], &targets, 0, None, 1.0, None, |c, w, _| {
                c * w
            })
            .map(|(v, cost): (Vec<_>, f64)| (v.into_iter().map(|i| i.index()).collect(), cost))
            .collect();

        // Path 0->1->2->3: cost = 1.0 * 0.9 * 0.8 * 0.7 = 0.504
        // Path 0->1->4:     cost = 1.0 * 0.9 * 0.6 = 0.54
        assert_eq!(results.len(), 2);
        for (path, cost) in &results {
            match path.as_slice() {
                [0, 1, 2, 3] => assert!((cost - 0.504).abs() < 1e-9),
                [0, 1, 4] => assert!((cost - 0.54).abs() < 1e-9),
                other => panic!("unexpected path: {other:?}"),
            }
        }
    }

    #[test]
    fn min_cost_should_prune_path_falling_below_threshold() {
        // 0 --0.9--> 1 --0.8--> 2 --0.7--> 3
        //                  \--0.6--> 4
        // With min_cost = 0.51, path 0->1->2->3 (cost 0.504) is pruned at the 2->3 edge,
        // but 0->1->4 (cost 0.54) survives.
        let mut graph = DiGraph::<(), f64>::new();
        let n: Vec<_> = (0..5).map(|_| graph.add_node(())).collect();
        graph.extend_with_edges([
            (n[0], n[1], 0.9),
            (n[1], n[2], 0.8),
            (n[2], n[3], 0.7),
            (n[1], n[4], 0.6),
        ]);

        let targets = HashSet::from_iter([n[3], n[4]]);
        let results: Vec<(Vec<_>, f64)> = all_simple_paths_multi::<_, _, RandomState, _, _>(
            &graph,
            n[0],
            &targets,
            0,
            None,
            1.0,
            Some(0.51),
            |c, w, _| c * w,
        )
        .map(|(v, cost): (Vec<_>, f64)| (v.into_iter().map(|i| i.index()).collect(), cost))
        .collect();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, vec![0, 1, 4]);
        assert!((results[0].1 - 0.54).abs() < 1e-9);
    }

    #[test]
    fn min_cost_should_prune_entire_branch_on_low_first_edge() {
        // 0 --0.1--> 1 --0.9--> 2
        //      \--0.9--> 3 --0.9--> 2
        // With min_cost = 0.5, the 0->1 edge (cost 0.1) is pruned immediately,
        // so only 0->3->2 (cost 0.81) is found.
        let mut graph = DiGraph::<(), f64>::new();
        let n: Vec<_> = (0..4).map(|_| graph.add_node(())).collect();
        graph.extend_with_edges([
            (n[0], n[1], 0.1),
            (n[1], n[2], 0.9),
            (n[0], n[3], 0.9),
            (n[3], n[2], 0.9),
        ]);

        let targets = HashSet::from_iter([n[2]]);
        let results: Vec<(Vec<_>, f64)> = all_simple_paths_multi::<_, _, RandomState, _, _>(
            &graph,
            n[0],
            &targets,
            0,
            None,
            1.0,
            Some(0.5),
            |c, w, _| c * w,
        )
        .map(|(v, cost): (Vec<_>, f64)| (v.into_iter().map(|i| i.index()).collect(), cost))
        .collect();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, vec![0, 3, 2]);
        assert!((results[0].1 - 0.81).abs() < 1e-9);
    }

    #[test]
    fn min_cost_should_yield_empty_when_all_paths_below_threshold() {
        // 0 --0.3--> 1 --0.3--> 2
        // With min_cost = 0.5, the 0->1 edge (cost 0.3) is pruned immediately,
        // so no paths are found.
        let mut graph = DiGraph::<(), f64>::new();
        let n: Vec<_> = (0..3).map(|_| graph.add_node(())).collect();
        graph.extend_with_edges([(n[0], n[1], 0.3), (n[1], n[2], 0.3)]);

        let targets = HashSet::from_iter([n[2]]);
        let results: Vec<(Vec<_>, f64)> = all_simple_paths_multi::<_, _, RandomState, _, _>(
            &graph,
            n[0],
            &targets,
            0,
            None,
            1.0,
            Some(0.5),
            |c, w, _| c * w,
        )
        .map(|(v, cost): (Vec<_>, f64)| (v.into_iter().map(|i| i.index()).collect(), cost))
        .collect();

        assert!(results.is_empty());
    }

    #[test]
    fn cost_fn_should_receive_zero_based_hop_index() {
        // Chain 0->1->2->3 with target {3}: the accumulator records the hop index it is given for
        // each traversed edge, which must be 0 for the first edge, 1 for the second, and so on.
        let graph = DiGraph::<i32, ()>::from_edges([(0, 1), (1, 2), (2, 3)]);
        let targets = HashSet::from_iter([3.into()]);
        let results: Vec<(Vec<petgraph::graph::NodeIndex>, Vec<usize>)> =
            all_simple_paths_multi::<_, _, RandomState, _, _>(
                &graph,
                0.into(),
                &targets,
                0,
                None,
                Vec::<usize>::new(),
                None,
                |mut hops: Vec<usize>, _, hop| {
                    hops.push(hop);
                    hops
                },
            )
            .collect();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0.iter().map(|i| i.index()).collect::<Vec<_>>(), vec![0, 1, 2, 3]);
        assert_eq!(results[0].1, vec![0, 1, 2]);
    }

    #[test]
    fn additive_cost_should_match_documented_example() {
        // Mirrors the `rust,ignore` example in the function docs so it cannot silently rot.
        let mut graph = DiGraph::<&str, i32>::new();
        let a = graph.add_node("a");
        let b = graph.add_node("b");
        let c = graph.add_node("c");
        let d = graph.add_node("d");
        graph.extend_with_edges([(a, b, 1), (b, c, 1), (b, d, 1)]);

        let targets = HashSet::from_iter([c, d]);
        let mut paths = all_simple_paths_multi::<Vec<_>, _, RandomState, _, _>(
            &graph,
            a,
            &targets,
            0,
            None,
            0i32,
            None,
            |cost, weight, _| cost + weight,
        )
        .collect::<Vec<_>>();

        paths.sort_by_key(|(p, _)| p.clone());
        assert_eq!(paths, vec![(vec![a, b, c], 2), (vec![a, b, d], 2)]);
    }
}