Skip to main content

hopr_statistics/
weighted.rs

1use rand::RngExt;
2
/// A collection of items with associated weights for probabilistic selection.
///
/// Weights are expected to be positive (`> 0.0`); items with non-positive
/// weights (NaN included) are treated as having zero probability for
/// [`pick_one`](Self::pick_one) / [`pick_index`](Self::pick_index) and are
/// placed at the end of the shuffled output in [`into_shuffled`](Self::into_shuffled).
///
/// # Examples
///
/// ```rust
/// use hopr_statistics::WeightedCollection;
///
/// let wc = WeightedCollection::new(vec![("rare", 0.1), ("common", 10.0)]);
/// let picked = wc.pick_one().expect("non-empty collection");
/// assert!(picked == "rare" || picked == "common");
/// ```
#[derive(Debug, Clone)]
pub struct WeightedCollection<T> {
    /// Items paired with their weights, stored exactly as supplied by the caller.
    items: Vec<(T, f64)>,
    /// Pre-computed sum of positive weights (cached to avoid recomputing on every pick).
    total_weight: f64,
}
24
25impl<T> WeightedCollection<T> {
26    /// Create a new weighted collection from items paired with their weights.
27    pub fn new(items: Vec<(T, f64)>) -> Self {
28        let total_weight: f64 = items.iter().map(|(_, w)| w.max(0.0)).sum();
29        Self { items, total_weight }
30    }
31
32    /// Returns `true` if the collection contains no items.
33    pub fn is_empty(&self) -> bool {
34        self.items.is_empty()
35    }
36
37    /// Returns the number of items in the collection.
38    pub fn len(&self) -> usize {
39        self.items.len()
40    }
41
42    /// Iterates over `(item, weight)` pairs.
43    pub fn iter(&self) -> impl Iterator<Item = &(T, f64)> {
44        self.items.iter()
45    }
46
47    /// Returns the index of a randomly selected item, weighted by probability
48    /// proportional to its weight.
49    ///
50    /// Returns `None` if the collection is empty or all weights are non-positive.
51    pub fn pick_index(&self) -> Option<usize> {
52        if self.total_weight <= 0.0 {
53            return None;
54        }
55
56        if self.items.len() == 1 {
57            return Some(0);
58        }
59
60        let mut rng = rand::rng();
61        let r = rng.random_range(0.0..self.total_weight);
62        let mut cumulative = 0.0;
63        for (i, (_, weight)) in self.items.iter().enumerate() {
64            cumulative += weight.max(0.0);
65            if r < cumulative {
66                return Some(i);
67            }
68        }
69
70        // Floating-point edge case: return the last positive-weight item.
71        self.items.iter().rposition(|(_, weight)| *weight > 0.0)
72    }
73}
74
75impl<T> WeightedCollection<T> {
76    /// Pick a reference to one item at random, with probability proportional
77    /// to its weight.
78    ///
79    /// Returns `None` if the collection is empty or all weights are non-positive.
80    pub fn pick_ref(&self) -> Option<&T> {
81        self.pick_index().map(|i| &self.items[i].0)
82    }
83}
84
85impl<T: Clone> WeightedCollection<T> {
86    /// Pick one item at random, with probability proportional to its weight.
87    ///
88    /// Returns `None` if the collection is empty or all weights are non-positive.
89    pub fn pick_one(&self) -> Option<T> {
90        self.pick_ref().cloned()
91    }
92}
93
94impl<T> WeightedCollection<T> {
95    /// Consume the collection and return items in a weighted random permutation.
96    ///
97    /// Uses the Efraimidis–Spirakis algorithm: each item is assigned a key
98    /// `random()^(1/weight)` and the items are sorted by descending key.
99    /// Higher-weight items appear earlier with higher probability, but all
100    /// items retain a nonzero chance of appearing at any position.
101    pub fn into_shuffled(self) -> Vec<T> {
102        let mut rng = rand::rng();
103
104        // Reuse the existing Vec — replace weights with Efraimidis–Spirakis keys in-place.
105        let mut keyed = self.items;
106        for (_, weight) in keyed.iter_mut() {
107            *weight = if *weight > 0.0 {
108                let u: f64 = rng.random_range(f64::EPSILON..1.0);
109                u.powf(1.0 / *weight)
110            } else {
111                0.0
112            };
113        }
114
115        keyed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
116        keyed.into_iter().map(|(item, _)| item).collect()
117    }
118}
119
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn pick_one_returns_none_for_empty_collection() {
        let wc: WeightedCollection<&str> = WeightedCollection::new(vec![]);
        assert!(wc.pick_one().is_none());
    }

    #[test]
    fn pick_one_returns_sole_item_with_positive_weight() {
        let wc = WeightedCollection::new(vec![("only", 1.0)]);
        assert_eq!(wc.pick_one(), Some("only"));
    }

    #[test]
    fn pick_one_returns_none_for_sole_item_with_non_positive_weight() {
        let wc = WeightedCollection::new(vec![("only", 0.0)]);
        assert!(wc.pick_one().is_none());
        let wc = WeightedCollection::new(vec![("only", -1.0)]);
        assert!(wc.pick_one().is_none());
    }

    #[test]
    fn pick_one_favors_higher_weight() {
        let wc = WeightedCollection::new(vec![("low", 0.01), ("high", 100.0)]);
        let mut high_count = 0;
        let trials = 1000;
        for _ in 0..trials {
            if wc.pick_one() == Some("high") {
                high_count += 1;
            }
        }
        assert!(
            high_count > trials * 9 / 10,
            "high-weight item should be picked >90% of the time, was {high_count}/{trials}"
        );
    }

    #[test]
    fn pick_one_returns_none_for_all_non_positive_weights() {
        let wc = WeightedCollection::new(vec![("a", 0.0), ("b", -1.0)]);
        assert!(wc.pick_one().is_none());
    }

    #[test]
    fn pick_index_returns_valid_index() {
        let wc = WeightedCollection::new(vec![("a", 1.0), ("b", 2.0), ("c", 3.0)]);
        for _ in 0..100 {
            let idx = wc.pick_index().expect("should pick an index");
            assert!(idx < 3);
        }
    }

    #[test]
    fn pick_index_returns_none_for_non_positive_weights() {
        let wc = WeightedCollection::new(vec![("a", 0.0), ("b", -5.0)]);
        assert!(wc.pick_index().is_none());
    }

    #[test]
    fn pick_index_never_selects_non_positive_or_nan_weights() {
        // Only index 1 carries positive weight, so it must be picked every time.
        let wc = WeightedCollection::new(vec![
            ("zero", 0.0),
            ("pos", 1.0),
            ("neg", -2.0),
            ("nan", f64::NAN),
        ]);
        for _ in 0..200 {
            assert_eq!(wc.pick_index(), Some(1));
        }
    }

    #[test]
    fn iter_len_and_is_empty_reflect_raw_items() {
        let wc = WeightedCollection::new(vec![("neg", -1.0), ("pos", 2.0)]);
        assert_eq!(wc.len(), 2);
        assert!(!wc.is_empty());
        let pairs: Vec<(&str, f64)> = wc.iter().copied().collect();
        assert_eq!(pairs, vec![("neg", -1.0), ("pos", 2.0)]);
    }

    #[test]
    fn shuffled_preserves_all_items() {
        let items: Vec<(u32, f64)> = (0..10).map(|i| (i, (i as f64 + 1.0) * 0.1)).collect();
        let shuffled = WeightedCollection::new(items).into_shuffled();
        assert_eq!(shuffled.len(), 10);
        let mut sorted = shuffled.clone();
        sorted.sort();
        assert_eq!(sorted, (0..10).collect::<Vec<_>>());
    }

    #[test]
    fn shuffled_favors_higher_weight_items() {
        let items = vec![("low", 0.1), ("high", 10.0)];
        let mut high_first_count = 0;
        let trials = 1000;
        for _ in 0..trials {
            let shuffled = WeightedCollection::new(items.clone()).into_shuffled();
            if shuffled[0] == "high" {
                high_first_count += 1;
            }
        }
        assert!(
            high_first_count > trials * 8 / 10,
            "high-weight item should appear first >80% of the time, was {high_first_count}/{trials}"
        );
    }

    #[test]
    fn shuffled_places_non_positive_items_last_in_original_order() {
        let items = vec![("neg", -1.0), ("a", 1.0), ("zero", 0.0), ("b", 2.0), ("nan", f64::NAN)];
        for _ in 0..100 {
            let shuffled = WeightedCollection::new(items.clone()).into_shuffled();
            let (head, tail) = shuffled.split_at(2);
            assert!(head == ["a", "b"] || head == ["b", "a"], "unexpected head {head:?}");
            assert_eq!(tail, ["neg", "zero", "nan"]);
        }
    }

    #[test]
    fn shuffled_empty_collection_returns_empty() {
        let wc: WeightedCollection<&str> = WeightedCollection::new(vec![]);
        assert!(wc.into_shuffled().is_empty());
    }
}