Skip to main content

src/sampling.rs

1use rand::Rng;
2use std::cmp::Ordering;
3use std::collections::BinaryHeap;
4use std::ptr;
/// One entry of the sampling reservoir: an input item paired with the random
/// weight that decides whether it survives in the sample.
struct WeightedItem<T> {
    /// Random key produced by `pick_weight`; `reservoir_sample` keeps the
    /// entries with the lowest weights.
    weight: f64,
    /// The value taken from the input iterator.
    item: T,
}
11// Two items are only equal if they are identical -- that is, they're
12// the same underlying object in memory.
13//
14// [I suppose it's theoretically possible that there could be duplicate
15// reservoir entries, if the RNG was bugged and the input has repeated
16// values -- seems unlikely in practice, but this protects against it
17// just in case.]
18impl<T> PartialEq for WeightedItem<T> {
19 fn eq(&self, other: &Self) -> bool {
20 ptr::eq(self, other)
21 }
24impl<T> Eq for WeightedItem<T> {}
26// Rust doesn't implement ordering for f64 because it includes NaN
27// which makes everything a mess. In particular NaN isn't comparable
28// with other floating-point numbers.
29//
30// We're generating all the f64 weights we'll be dealing with, so we
31// know we'll never have NaN in the mix -- we can do a partial comparison
32// and assert the two values are comparable when we unwrap.
33impl<T> PartialOrd for WeightedItem<T> {
34 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
35 Some(self.cmp(other))
36 }
39impl<T> Ord for WeightedItem<T> {
40 fn cmp(&self, other: &Self) -> Ordering {
41 self.weight.partial_cmp(&other.weight).unwrap()
42 }
45/// Choose a sample of `k` items from the iterator `items.
46///
47/// Each item has an equal chance of being picked -- that is, there's
48/// a 1/N chance of choosing an item, where N is the length of the iterator.
49///
50/// This implements "Algorithm L" for reservoir sampling, as described
51/// on the Wikipedia page:
52/// https://en.wikipedia.org/wiki/Reservoir_sampling#Optimal:_Algorithm_L
53///
54pub fn reservoir_sample<T>(mut items: impl Iterator<Item = T>, k: usize) -> Vec<T> {
55 // Taking a sample with k=0 doesn't make much sense in practice,
56 // but we include this to avoid problems downstream.
57 if k == 0 {
58 return vec![];
59 }
61 // Create an empty reservoir.
62 let mut reservoir: BinaryHeap<WeightedItem<T>> = BinaryHeap::with_capacity(k);
64 // Fill the reservoir with the first k items. If there are less
65 // than n items, we can exit immediately.
66 for _ in 1..=k {
67 match items.next() {
68 Some(this_item) => reservoir.push(WeightedItem {
69 item: this_item,
70 weight: pick_weight(),
71 }),
72 None => return reservoir.into_vec().into_iter().map(|r| r.item).collect(),
73 };
74 }
76 // What's the largest weight seen so far?
77 //
78 // Note: we're okay to `unwrap()` here because we know that `reservoir`
79 // contains at least one item. Either `items` was non-empty, or if itwas
80 // was empty, then we'd already have returned when trying to fill the
81 // reservoir with the first k items.
82 let mut max_weight: f64 = reservoir.peek().unwrap().weight;
84 // Now go through the remaining items.
85 for this_item in items {
86 // Choose a weight for this item.
87 let this_weight = pick_weight();
89 // If this is greater than the weights seen so far, we can ignore
90 // this item and move on to the next one.
91 if this_weight > max_weight {
92 continue;
93 }
95 // Otherwise, this item has a lower weight than the current item
96 // with max weight -- so we'll replace that item.
97 assert!(reservoir.pop().is_some());
98 reservoir.push(WeightedItem {
99 item: this_item,
100 weight: this_weight,
101 });
103 // Recalculate the max weight for the new sample.
104 max_weight = reservoir.peek().unwrap().weight;
105 }
107 let sample: Vec<T> = reservoir.into_vec().into_iter().map(|r| r.item).collect();
108 assert!(sample.len() == k);
109 sample
112/// Create a random weight u_i ~ U[0,1]
113fn pick_weight() -> f64 {
114 rand::rng().random_range(0.0..1.0)
#[cfg(test)]
mod reservoir_sample_tests {
    use super::*;
    use std::collections::{HashMap, HashSet};

    // If there are no items, then the sample is empty.
    #[test]
    fn it_returns_an_empty_sample_for_an_empty_input() {
        let items: Vec<usize> = vec![];
        let sample = reservoir_sample(items.into_iter(), 5);

        assert_eq!(sample.len(), 0);
    }

    // If there are fewer items than the sample size, then the sample is
    // the complete set.
    #[test]
    fn it_returns_complete_sample_if_less_items_than_sample_size() {
        let items = vec!["a", "b", "c"];
        let sample = reservoir_sample(items.into_iter(), 5);

        assert!(equivalent_items(sample, vec!["a", "b", "c"]));
    }

    // If there's an equal number of items to the sample size, then the
    // sample is the complete set.
    #[test]
    fn it_returns_complete_sample_if_item_count_equal_to_sample_size() {
        let items = vec!["a", "b", "c"];
        let sample = reservoir_sample(items.into_iter(), 3);

        assert!(equivalent_items(sample, vec!["a", "b", "c"]));
    }

    // If k=0, then it returns an empty sample.
    #[test]
    fn it_returns_an_empty_sample_if_k_zero() {
        let items = vec!["a", "b", "c"];
        let sample = reservoir_sample(items.into_iter(), 0);

        assert_eq!(sample.len(), 0);
    }

    // If there are more items than the sample size, the sample contains
    // exactly `k` entries, with no duplicates, all taken from the input.
    // This exercises the replacement path of the reservoir.
    #[test]
    fn it_returns_k_distinct_input_items_if_more_items_than_sample_size() {
        let n: i32 = 50;
        let k = 7;

        for _ in 0..100 {
            let sample = reservoir_sample(0..n, k);
            assert_eq!(sample.len(), k);

            let distinct: HashSet<i32> = sample.iter().copied().collect();
            assert_eq!(distinct.len(), k, "sample has duplicates: {sample:?}");
            assert!(
                sample.iter().all(|x| (0..n).contains(x)),
                "sample has items not in the input: {sample:?}"
            );
        }
    }

    // It chooses items with a uniform distribution: every item has
    // an equal chance of being picked.
    //
    // We take a large number of samples of the integers 0..n, and check
    // that each integer is picked about as many times as we expect.
    #[test]
    fn test_distribution() {
        let k = 20;
        let n = 100;
        let iterations = 10000;

        // How often was each integer picked?
        let mut counts: HashMap<i32, usize> = HashMap::new();

        // Run many iterations, create a sample, and record how many
        // times each integer was picked.
        for _ in 0..iterations {
            let sample = reservoir_sample(0..n, k);

            for s in sample {
                *counts.entry(s).or_default() += 1;
            }
        }

        // Now check that each number appears roughly as many times
        // as we'd expect (within reasonable bounds).
        let total_samples = iterations * k;
        let expected = total_samples as f64 / n as f64;

        for item in 0..n {
            let item_count = counts.get(&item).copied().unwrap_or(0);

            let ratio = (item_count as f64) / expected;
            assert!(
                ratio > 0.8 && ratio < 1.2,
                "Distribution appears skewed for item {}: count={}, expected={}",
                item,
                item_count,
                expected
            );
        }
    }

    /// Returns true if two vectors contain the same items (but potentially
    /// in a different order), false otherwise.
    ///
    /// equivalent_items(vec![1, 3, 2], vec![3, 2, 1])
    /// => true
    ///
    /// equivalent_items(vec![4, 5, 6], vec![3, 2, 1])
    /// => false
    ///
    fn equivalent_items<T: Ord>(mut vec1: Vec<T>, mut vec2: Vec<T>) -> bool {
        vec1.sort();
        vec2.sort();

        vec1 == vec2
    }
}