1use std::{
5 collections::HashSet,
6 hash::{BuildHasher, Hash},
7 iter::from_fn,
8};
9
10use indexmap::IndexSet;
11use petgraph::{
12 Direction::Outgoing,
13 visit::{EdgeRef, IntoEdgeReferences, IntoEdgesDirected, NodeCount},
14};
15
16#[allow(clippy::too_many_arguments)]
90pub fn all_simple_paths_multi<'a, TargetColl, G, S, F, C>(
91 graph: G,
92 from: G::NodeId,
93 to: &'a HashSet<G::NodeId, S>,
94 min_intermediate_nodes: usize,
95 max_intermediate_nodes: Option<usize>,
96 initial_cost: C,
97 min_cost: Option<C>,
98 cost_fn: F,
99) -> impl Iterator<Item = (TargetColl, C)> + 'a
100where
101 G: NodeCount + IntoEdgesDirected + 'a,
102 <G as IntoEdgesDirected>::EdgesDirected: 'a,
103 G::NodeId: Eq + Hash,
104 TargetColl: FromIterator<G::NodeId>,
105 S: BuildHasher + Default,
106 C: Clone + PartialOrd + 'a,
107 F: Fn(C, &<<G as IntoEdgeReferences>::EdgeRef as EdgeRef>::Weight, usize) -> C + 'a,
108{
109 let max_nodes = if let Some(l) = max_intermediate_nodes {
110 l + 2
111 } else {
112 graph.node_count()
113 };
114
115 let min_nodes = min_intermediate_nodes + 2;
116
117 let mut visited: IndexSet<G::NodeId, S> = IndexSet::from_iter(Some(from));
119 let mut stack = vec![graph.edges_directed(from, Outgoing)];
122 let mut costs: Vec<C> = vec![initial_cost];
124
125 from_fn(move || {
126 while let Some(edges) = stack.last_mut() {
127 if let Some(edge) = edges.next() {
128 let child = edge.target();
129
130 if visited.contains(&child) {
131 continue;
132 }
133
134 let current_nodes = visited.len();
136 let new_cost = cost_fn(costs.last().unwrap().clone(), edge.weight(), current_nodes - 1);
137
138 if let Some(ref min) = min_cost
140 && new_cost < *min
141 {
142 continue;
143 }
144
145 let mut valid_path: Option<(TargetColl, C)> = None;
146
147 if to.contains(&child) && (current_nodes + 1) >= min_nodes {
149 valid_path = Some((
150 visited.iter().cloned().chain(Some(child)).collect::<TargetColl>(),
151 new_cost.clone(),
152 ));
153 }
154
155 if (current_nodes < (max_nodes - 1)) && to.iter().any(|n| *n != child && !visited.contains(n)) {
157 visited.insert(child);
158 stack.push(graph.edges_directed(child, Outgoing));
159 costs.push(new_cost);
160 }
161
162 if valid_path.is_some() {
164 return valid_path;
165 }
166 } else {
167 stack.pop();
169 visited.pop();
170 costs.pop();
171 }
172 }
173 None
174 })
175}
176
#[cfg(test)]
mod test {
    use std::collections::{HashSet, hash_map::RandomState};

    use petgraph::{
        EdgeType,
        graph::{DiGraph, Graph, NodeIndex, UnGraph},
    };

    use super::all_simple_paths_multi;

    /// Runs the search with a constant cost that is ignored, and returns every
    /// yielded path as a list of node indices. The paths are sorted so snapshots
    /// do not depend on traversal order.
    fn unit_cost_paths<N, E, Ty: EdgeType>(
        graph: &Graph<N, E, Ty>,
        from: usize,
        targets: &[usize],
        min_intermediate_nodes: usize,
        max_intermediate_nodes: Option<usize>,
    ) -> Vec<Vec<usize>> {
        let targets: HashSet<NodeIndex, RandomState> = targets.iter().map(|&i| NodeIndex::new(i)).collect();
        let mut paths: Vec<Vec<usize>> = all_simple_paths_multi::<Vec<NodeIndex>, _, RandomState, _, _>(
            graph,
            NodeIndex::new(from),
            &targets,
            min_intermediate_nodes,
            max_intermediate_nodes,
            0,
            None,
            |cost, _, _| cost,
        )
        .map(|(path, _)| path.into_iter().map(|node| node.index()).collect())
        .collect();
        paths.sort();
        paths
    }

    /// Builds a directed graph with `node_count` weightless nodes and then adds
    /// the given weighted edges in the order listed.
    fn weighted_digraph(node_count: usize, edges: &[(usize, usize, f64)]) -> DiGraph<(), f64> {
        let mut graph = DiGraph::new();
        let nodes: Vec<NodeIndex> = (0..node_count).map(|_| graph.add_node(())).collect();
        for &(source, target, weight) in edges {
            graph.add_edge(nodes[source], nodes[target], weight);
        }
        graph
    }

    /// Searches `graph` with a multiplicative cost starting at `1.0`. Returns each
    /// yielded path (as node indices) together with its cost, in yield order.
    fn product_cost_paths(
        graph: &DiGraph<(), f64>,
        from: usize,
        targets: &[usize],
        min_cost: Option<f64>,
    ) -> Vec<(Vec<usize>, f64)> {
        let targets: HashSet<NodeIndex, RandomState> = targets.iter().map(|&i| NodeIndex::new(i)).collect();
        all_simple_paths_multi::<Vec<NodeIndex>, _, RandomState, _, _>(
            graph,
            NodeIndex::new(from),
            &targets,
            0,
            None,
            1.0,
            min_cost,
            |cost, weight, _| cost * weight,
        )
        .map(|(path, cost)| (path.into_iter().map(|node| node.index()).collect(), cost))
        .collect()
    }

    #[test]
    fn undirected_graph_should_find_all_paths_to_multiple_targets() {
        let graph = UnGraph::<i32, i32>::from_edges([(0, 1), (1, 2), (2, 3), (2, 4)]);
        let paths = unit_cost_paths(&graph, 0, &[3, 4], 0, None);
        insta::assert_yaml_snapshot!(paths);
    }

    #[test]
    fn directed_graph_should_find_all_paths_to_multiple_targets() {
        let graph = DiGraph::<i32, ()>::from_edges([(0, 1), (1, 2), (2, 3), (2, 4)]);
        let paths = unit_cost_paths(&graph, 0, &[3, 4], 0, None);
        insta::assert_yaml_snapshot!(paths);
    }

    #[test]
    fn undirected_graph_should_respect_max_intermediate_nodes() {
        let graph = UnGraph::<i32, ()>::from_edges([(0, 1), (1, 2), (2, 3), (2, 4)]);
        let paths = unit_cost_paths(&graph, 0, &[3, 4], 0, Some(2));
        insta::assert_yaml_snapshot!(paths);
    }

    #[test]
    fn max_intermediate_nodes_should_not_be_exceeded_when_target_connects_to_target() {
        // Target 2 leads on to target 3; with one intermediate node allowed, the
        // longer path 0-1-2-3 must not be reported.
        let graph = DiGraph::<i32, ()>::from_edges([(0, 1), (1, 2), (2, 3)]);
        let paths = unit_cost_paths(&graph, 0, &[2, 3], 0, Some(1));
        insta::assert_yaml_snapshot!(paths);
    }

    #[test]
    fn directed_graph_should_respect_max_intermediate_nodes() {
        let graph = DiGraph::<i32, ()>::from_edges([(0, 1), (1, 2), (2, 3), (2, 4)]);
        let paths = unit_cost_paths(&graph, 0, &[3, 4], 0, Some(2));
        insta::assert_yaml_snapshot!(paths);
    }

    #[test]
    fn inline_targets_should_yield_both_short_and_long_paths() {
        let graph = UnGraph::<i32, ()>::from_edges([(0, 1), (1, 2), (2, 3)]);
        let paths = unit_cost_paths(&graph, 0, &[2, 3], 0, None);
        insta::assert_yaml_snapshot!(paths);
    }

    #[test]
    fn cyclic_graph_should_yield_only_simple_paths() {
        let graph = DiGraph::<i32, ()>::from_edges([(0, 1), (1, 2), (2, 0), (1, 3)]);
        let paths = unit_cost_paths(&graph, 0, &[2, 3], 0, None);
        insta::assert_yaml_snapshot!(paths);
    }

    #[test]
    fn source_in_target_set_should_not_yield_zero_length_path() {
        let graph = UnGraph::<i32, ()>::from_edges([(0, 1), (1, 2)]);
        let paths = unit_cost_paths(&graph, 0, &[0, 1, 2], 0, None);
        insta::assert_yaml_snapshot!(paths);
    }

    #[test]
    fn non_trivial_graph_should_find_all_simple_paths() {
        let graph = DiGraph::<i32, ()>::from_edges([
            (0, 1),
            (1, 2),
            (2, 3),
            (3, 4),
            (0, 5),
            (1, 5),
            (1, 3),
            (5, 4),
            (4, 2),
            (4, 3),
        ]);
        let paths = unit_cost_paths(&graph, 1, &[2, 3], 0, None);
        insta::assert_yaml_snapshot!(paths);
    }

    #[test]
    fn min_intermediate_nodes_should_exclude_shorter_paths() {
        let graph = UnGraph::<i32, ()>::from_edges([(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]);
        let paths = unit_cost_paths(&graph, 0, &[1, 3], 2, None);
        insta::assert_yaml_snapshot!(paths);
    }

    #[test]
    fn multiplicative_cost_should_accumulate_along_path() {
        let graph = weighted_digraph(5, &[(0, 1, 0.9), (1, 2, 0.8), (2, 3, 0.7), (1, 4, 0.6)]);
        let results = product_cost_paths(&graph, 0, &[3, 4], None);

        assert_eq!(results.len(), 2);
        for (path, cost) in &results {
            match path.as_slice() {
                [0, 1, 2, 3] => assert!((cost - 0.504).abs() < 1e-9),
                [0, 1, 4] => assert!((cost - 0.54).abs() < 1e-9),
                other => panic!("unexpected path: {other:?}"),
            }
        }
    }

    #[test]
    fn min_cost_should_prune_path_falling_below_threshold() {
        // 0-1-2-3 ends at 0.504 < 0.51 and is dropped; 0-1-4 ends at 0.54 and survives.
        let graph = weighted_digraph(5, &[(0, 1, 0.9), (1, 2, 0.8), (2, 3, 0.7), (1, 4, 0.6)]);
        let results = product_cost_paths(&graph, 0, &[3, 4], Some(0.51));

        assert_eq!(results.len(), 1);
        let (path, cost) = &results[0];
        assert_eq!(*path, vec![0, 1, 4]);
        assert!((cost - 0.54).abs() < 1e-9);
    }

    #[test]
    fn min_cost_should_prune_entire_branch_on_low_first_edge() {
        // The edge 0->1 already drops the cost to 0.1, so nothing behind it is explored.
        let graph = weighted_digraph(4, &[(0, 1, 0.1), (1, 2, 0.9), (0, 3, 0.9), (3, 2, 0.9)]);
        let results = product_cost_paths(&graph, 0, &[2], Some(0.5));

        assert_eq!(results.len(), 1);
        let (path, cost) = &results[0];
        assert_eq!(*path, vec![0, 3, 2]);
        assert!((cost - 0.81).abs() < 1e-9);
    }

    #[test]
    fn min_cost_should_yield_empty_when_all_paths_below_threshold() {
        let graph = weighted_digraph(3, &[(0, 1, 0.3), (1, 2, 0.3)]);
        let results = product_cost_paths(&graph, 0, &[2], Some(0.5));

        assert!(results.is_empty());
    }
}