3use std::collections::BinaryHeap;
/// An item from the input paired with the random `f64` weight that
/// decides its position in the reservoir heap.
6struct WeightedItem<T> {
11// Two items are only equal if they are identical -- that is, they're
12// the same underlying object in memory.
14// [I suppose it's theoretically possible that there could be duplicate
15// reservoir entries, if the RNG was bugged and the input has repeated
16// values -- seems unlikely in practice, but this protects against it
18impl<T> PartialEq for WeightedItem<T> {
19 fn eq(&self, other: &Self) -> bool {
// `Eq` has no methods of its own; it is implemented because `Ord` (which
// `BinaryHeap` needs for its elements) requires `Eq` as a supertrait.
24impl<T> Eq for WeightedItem<T> {}
26// Rust doesn't implement a total ordering (`Ord`) for f64 because it includes
27// NaN, which makes everything a mess. In particular, NaN isn't comparable
28// with other floating-point numbers (or even with itself).
30// We're generating all the f64 weights we'll be dealing with, so we
31// know we'll never have NaN in the mix -- we can do a partial comparison
32// and assert the two values are comparable when we unwrap.
33impl<T> PartialOrd for WeightedItem<T> {
34 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
// Items are ordered purely by `weight`. `BinaryHeap` is a max-heap, so the
// entry with the largest weight is the one returned by `peek()`/`pop()`.
39impl<T> Ord for WeightedItem<T> {
40 fn cmp(&self, other: &Self) -> Ordering {
 // `partial_cmp` on `f64` returns `None` when either side is NaN, so this
 // `unwrap()` panics if a NaN weight ever gets in.
41 self.weight.partial_cmp(&other.weight).unwrap()
45/// Choose a sample of `k` items from the iterator `items`.
47/// Each item has an equal chance of being picked -- that is, there's
48/// a k/N chance of choosing an item, where N is the length of the iterator.
/// If the iterator yields fewer than `k` items, all of them are returned.
50/// This implements "Algorithm L" for reservoir sampling, as described
51/// on the Wikipedia page:
52/// https://en.wikipedia.org/wiki/Reservoir_sampling#Optimal:_Algorithm_L
///
/// NOTE(review): the body draws a weight for every item and keeps the `k`
/// items with the smallest weights, which looks closer to the "random sort"
/// approach on that page than to Algorithm L's skip-ahead -- confirm.
54pub fn reservoir_sample<T>(mut items: impl Iterator<Item = T>, k: usize) -> Vec<T> {
55 // Taking a sample with k=0 doesn't make much sense in practice,
56 // but we include this to avoid problems downstream.
61 // Create an empty reservoir.
62 let mut reservoir: BinaryHeap<WeightedItem<T>> = BinaryHeap::with_capacity(k);
64 // Fill the reservoir with the first k items. If there are fewer
65 // than k items, we can exit immediately.
68 Some(this_item) => reservoir.push(WeightedItem {
70 weight: pick_weight(),
72 None => return reservoir.into_vec().into_iter().map(|r| r.item).collect(),
76 // What's the largest weight seen so far?
78 // Note: we're okay to `unwrap()` here because we know that `reservoir`
79 // contains at least one item. Either `items` was non-empty, or if it
80 // was empty, then we'd already have returned when trying to fill the
81 // reservoir with the first k items.
82 let mut max_weight: f64 = reservoir.peek().unwrap().weight;
84 // Now go through the remaining items.
85 for this_item in items {
86 // Choose a weight for this item.
87 let this_weight = pick_weight();
89 // If this is greater than the largest weight currently in the
90 // reservoir, we can ignore this item and move on to the next one.
91 if this_weight > max_weight {
95 // Otherwise, this item's weight is no greater than that of the current
96 // item with max weight -- so we'll replace that item.
97 assert!(reservoir.pop().is_some());
98 reservoir.push(WeightedItem {
103 // Recalculate the max weight for the new sample.
104 max_weight = reservoir.peek().unwrap().weight;
 // Drop the weights and hand back just the sampled items (in heap order,
 // not input order).
107 let sample: Vec<T> = reservoir.into_vec().into_iter().map(|r| r.item).collect();
108 assert!(sample.len() == k);
112/// Create a random weight u_i ~ U[0,1)
///
/// The sampled range `0.0..1.0` is half-open: it includes 0.0 and excludes 1.0.
113fn pick_weight() -> f64 {
114 rand::rng().random_range(0.0..1.0)
118mod reservoir_sample_tests {
120 use std::collections::HashMap;
122 // If there are no items, then the sample is empty.
124 fn it_returns_an_empty_sample_for_an_empty_input() {
125 let items: Vec<usize> = vec![];
 // Ask for more items (5) than the empty input can supply.
126 let sample = reservoir_sample(items.into_iter(), 5);
128 assert_eq!(sample.len(), 0);
131 // If there are fewer items than the sample size, then the sample is
134 fn it_returns_complete_sample_if_less_items_than_sample_size() {
135 let items = vec!["a", "b", "c"];
 // Ask for 5 items when only 3 exist.
136 let sample = reservoir_sample(items.into_iter(), 5);
 // The sample's order isn't guaranteed, so compare ignoring order.
138 assert!(equivalent_items(sample, vec!["a", "b", "c"]));
141 // If there's an equal number of items to the sample size, then the
142 // sample is the complete set.
144 fn it_returns_complete_sample_if_item_count_equal_to_sample_size() {
145 let items = vec!["a", "b", "c"];
 // k matches the number of items exactly.
146 let sample = reservoir_sample(items.into_iter(), 3);
 // The sample's order isn't guaranteed, so compare ignoring order.
148 assert!(equivalent_items(sample, vec!["a", "b", "c"]));
151 // If k=0, then it returns an empty sample.
153 fn it_returns_an_empty_sample_if_k_zero() {
 // The input is non-empty, so an empty result here is due to k alone.
154 let items = vec!["a", "b", "c"];
155 let sample = reservoir_sample(items.into_iter(), 0);
157 assert_eq!(sample.len(), 0);
160 // It chooses items with a uniform distribution -- every item has
161 // an equal chance of being picked.
163 // We take a large number of samples of the integers 0..n, and check
164 // that each integer is picked about as many times as we expect.
166 fn test_distribution() {
169 let iterations = 10000;
171 // How often was each integer picked?
172 let mut counts: HashMap<i32, usize> = HashMap::new();
174 // Run many iterations, create a sample, and record how many
175 // times each integer was picked.
176 for _ in 0..iterations {
178 let sample = reservoir_sample(items, k);
180 for s in sample.into_iter() {
181 *counts.entry(s).or_insert(0) += 1;
185 // Now check that each number appears roughly as many times
186 // as we'd expect (within reasonable bounds).
 // Every iteration contributes k picks; under a uniform distribution
 // each of the n integers should receive a 1/n share of them.
187 let total_samples = iterations * k;
188 let expected = total_samples as f64 / n as f64;
 // An integer that was never picked counts as zero.
191 let item_count = *counts.get(&item).unwrap_or(&0);
193 let ratio = (item_count as f64) / expected;
 // Tolerate up to 20% deviation either side of the expected count.
195 ratio > 0.8 && ratio < 1.2,
196 "Distribution appears skewed: count={}, expected={}",
203 /// Returns true if two vectors contain the same items (but potentially
204 /// in a different order), false otherwise.
206 /// equivalent_items(vec![1, 3, 2], vec![3, 2, 1])
209 /// equivalent_items(vec![4, 5, 6], vec![3, 2, 1])
212 fn equivalent_items<T: std::cmp::PartialEq + std::cmp::Ord>(