hopr_statistics/
weighted.rs1use rand::RngExt;
2
/// A collection of items paired with `f64` weights, supporting weighted random
/// selection (`pick_index` / `pick_ref` / `pick_one`) and weighted shuffling
/// (`into_shuffled`).
#[derive(Debug, Clone)]
pub struct WeightedCollection<T> {
    /// The `(item, weight)` pairs exactly as supplied to `new`; weights are not
    /// clamped here, so they may be zero, negative or `NaN`.
    items: Vec<(T, f64)>,
    /// Sum of the weights computed by `new`, with each negative or `NaN` weight
    /// counted as zero.
    total_weight: f64,
}
24
25impl<T> WeightedCollection<T> {
26 pub fn new(items: Vec<(T, f64)>) -> Self {
28 let total_weight: f64 = items.iter().map(|(_, w)| w.max(0.0)).sum();
29 Self { items, total_weight }
30 }
31
32 pub fn is_empty(&self) -> bool {
34 self.items.is_empty()
35 }
36
37 pub fn len(&self) -> usize {
39 self.items.len()
40 }
41
42 pub fn iter(&self) -> impl Iterator<Item = &(T, f64)> {
44 self.items.iter()
45 }
46
47 pub fn pick_index(&self) -> Option<usize> {
52 if self.total_weight <= 0.0 {
53 return None;
54 }
55
56 if self.items.len() == 1 {
57 return Some(0);
58 }
59
60 let mut rng = rand::rng();
61 let r = rng.random_range(0.0..self.total_weight);
62 let mut cumulative = 0.0;
63 for (i, (_, weight)) in self.items.iter().enumerate() {
64 cumulative += weight.max(0.0);
65 if r < cumulative {
66 return Some(i);
67 }
68 }
69
70 self.items.iter().rposition(|(_, weight)| *weight > 0.0)
72 }
73}
74
75impl<T> WeightedCollection<T> {
76 pub fn pick_ref(&self) -> Option<&T> {
81 self.pick_index().map(|i| &self.items[i].0)
82 }
83}
84
85impl<T: Clone> WeightedCollection<T> {
86 pub fn pick_one(&self) -> Option<T> {
90 self.pick_ref().cloned()
91 }
92}
93
94impl<T> WeightedCollection<T> {
95 pub fn into_shuffled(self) -> Vec<T> {
102 let mut rng = rand::rng();
103
104 let mut keyed = self.items;
106 for (_, weight) in keyed.iter_mut() {
107 *weight = if *weight > 0.0 {
108 let u: f64 = rng.random_range(f64::EPSILON..1.0);
109 u.powf(1.0 / *weight)
110 } else {
111 0.0
112 };
113 }
114
115 keyed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
116 keyed.into_iter().map(|(item, _)| item).collect()
117 }
118}
119
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn pick_one_returns_none_for_empty_collection() {
        let wc: WeightedCollection<&str> = WeightedCollection::new(vec![]);
        assert!(wc.pick_one().is_none());
    }

    #[test]
    fn pick_one_returns_sole_item_with_positive_weight() {
        let wc = WeightedCollection::new(vec![("only", 1.0)]);
        assert_eq!(wc.pick_one(), Some("only"));
    }

    #[test]
    fn pick_one_returns_none_for_sole_item_with_non_positive_weight() {
        let wc = WeightedCollection::new(vec![("only", 0.0)]);
        assert!(wc.pick_one().is_none());
        let wc = WeightedCollection::new(vec![("only", -1.0)]);
        assert!(wc.pick_one().is_none());
    }

    #[test]
    fn pick_one_favors_higher_weight() {
        let wc = WeightedCollection::new(vec![("low", 0.01), ("high", 100.0)]);
        let mut high_count = 0;
        let trials = 1000;
        for _ in 0..trials {
            if wc.pick_one() == Some("high") {
                high_count += 1;
            }
        }
        assert!(
            high_count > trials * 9 / 10,
            "high-weight item should be picked >90% of the time, was {high_count}/{trials}"
        );
    }

    #[test]
    fn pick_one_returns_none_for_all_non_positive_weights() {
        let wc = WeightedCollection::new(vec![("a", 0.0), ("b", -1.0)]);
        assert!(wc.pick_one().is_none());
    }

    #[test]
    fn pick_ref_borrows_item_from_collection() {
        let wc = WeightedCollection::new(vec![(String::from("only"), 1.0)]);
        assert_eq!(wc.pick_ref().map(String::as_str), Some("only"));
    }

    #[test]
    fn pick_index_returns_valid_index() {
        let wc = WeightedCollection::new(vec![("a", 1.0), ("b", 2.0), ("c", 3.0)]);
        for _ in 0..100 {
            let idx = wc.pick_index().expect("should pick an index");
            assert!(idx < 3);
        }
    }

    #[test]
    fn pick_index_returns_none_for_non_positive_weights() {
        let wc = WeightedCollection::new(vec![("a", 0.0), ("b", -5.0)]);
        assert!(wc.pick_index().is_none());
    }

    #[test]
    fn pick_index_never_selects_zero_negative_or_nan_weights() {
        let wc = WeightedCollection::new(vec![
            ("zero", 0.0),
            ("pos", 1.0),
            ("neg", -3.0),
            ("nan", f64::NAN),
        ]);
        for _ in 0..100 {
            assert_eq!(wc.pick_index(), Some(1));
        }
    }

    #[test]
    fn len_and_is_empty_count_all_items() {
        let empty: WeightedCollection<&str> = WeightedCollection::new(vec![]);
        assert!(empty.is_empty());
        assert_eq!(empty.len(), 0);

        // Items with non-positive weight still count towards the length.
        let wc = WeightedCollection::new(vec![("a", 1.0), ("b", 0.0), ("c", -2.0)]);
        assert!(!wc.is_empty());
        assert_eq!(wc.len(), 3);
    }

    #[test]
    fn iter_yields_pairs_in_insertion_order_with_original_weights() {
        let wc = WeightedCollection::new(vec![("a", 1.5), ("b", -2.0)]);
        let pairs: Vec<_> = wc.iter().copied().collect();
        assert_eq!(pairs, vec![("a", 1.5), ("b", -2.0)]);
    }

    #[test]
    fn shuffled_preserves_all_items() {
        let items: Vec<(u32, f64)> = (0..10).map(|i| (i, (i as f64 + 1.0) * 0.1)).collect();
        let shuffled = WeightedCollection::new(items).into_shuffled();
        assert_eq!(shuffled.len(), 10);
        let mut sorted = shuffled.clone();
        sorted.sort();
        assert_eq!(sorted, (0..10).collect::<Vec<_>>());
    }

    #[test]
    fn shuffled_favors_higher_weight_items() {
        let items = vec![("low", 0.1), ("high", 10.0)];
        let mut high_first_count = 0;
        let trials = 1000;
        for _ in 0..trials {
            let shuffled = WeightedCollection::new(items.clone()).into_shuffled();
            if shuffled[0] == "high" {
                high_first_count += 1;
            }
        }
        assert!(
            high_first_count > trials * 8 / 10,
            "high-weight item should appear first >80% of the time, was {high_first_count}/{trials}"
        );
    }

    #[test]
    fn shuffled_places_non_positive_weights_last_in_original_order() {
        let wc = WeightedCollection::new(vec![
            ("neg", -1.0),
            ("pos", 1.0),
            ("nan", f64::NAN),
            ("zero", 0.0),
        ]);
        assert_eq!(wc.into_shuffled(), vec!["pos", "neg", "nan", "zero"]);
    }

    #[test]
    fn shuffled_empty_collection_returns_empty() {
        let wc: WeightedCollection<&str> = WeightedCollection::new(vec![]);
        assert!(wc.into_shuffled().is_empty());
    }
}
213}