Get the sampling code working
- ID
72223b6- date
2025-01-09 23:09:35+00:00- author
Alex Chan <alex@alexwlchan.net>- parent
d6a28f6- message
Get the sampling code working- changed files
1 file, 80 additions, 10 deletions
Changed files
src/sampling.rs (2787) → src/sampling.rs (5269)
diff --git a/src/sampling.rs b/src/sampling.rs
index 2b9e0b6..e74cfff 100644
--- a/src/sampling.rs
+++ b/src/sampling.rs
@@ -1,22 +1,27 @@
use rand::Rng;
use std::collections::HashMap;
-fn random_weight() -> i32 {
- rand::thread_rng().gen_range(i32::MIN..i32::MAX)
-}
-
+/// Choose a sample of `k` items from the iterator `items.
+///
+/// Each item has an equal chance of being picked -- that is, there's
+/// a 1/N chance of choosing an item, where N is the length of the iterator.
+///
+/// This implements "Algorithm L" for reservoir sampling, as described
+/// on the Wikipedia page:
+/// https://en.wikipedia.org/wiki/Reservoir_sampling#Optimal:_Algorithm_L
+///
pub fn reservoir_sample<T: std::fmt::Debug>(
mut items: impl Iterator<Item = T>,
- n: usize,
+ k: usize,
) -> Vec<T> {
// Create an empty reservoir.
//
// This is a map (weight) -> (item).
- let mut reservoir: HashMap<i32, T> = HashMap::with_capacity(n);
+ let mut reservoir: HashMap<i32, T> = HashMap::with_capacity(k);
// Fill the reservoir with the first n items. If there are less
// than n items, we can exit immediately.
- for _ in 1..=n {
+ for _ in 1..=k {
match items.next() {
Some(this_item) => reservoir.insert(random_weight(), this_item),
None => return reservoir.into_values().collect(),
@@ -37,19 +42,32 @@ pub fn reservoir_sample<T: std::fmt::Debug>(
continue;
}
- // Replace the item that had the max weight with the new item,
- // then recalculate the max weight.
+ // Otherwise, this item has a lower weight than the current item
+ // with max weight -- so we'll replace that item.
assert!(reservoir.remove(&max_weight).is_some());
reservoir.insert(this_weight, this_item);
+
+ // Recalculate the max weight for the new sample.
max_weight = *reservoir.keys().max().unwrap();
}
reservoir.into_values().collect()
}
+/// Create a random weight, i.e. an integer selected from
+/// a uniform distribution.
+///
+/// Note: most implementations use a float in [0, 1], but comparing floats
+/// is more annoying so I used ints instead (e.g. you can do easy equality/
+/// comparison with i32, but not f32).
+fn random_weight() -> i32 {
+ rand::thread_rng().gen_range(i32::MIN..i32::MAX)
+}
+
#[cfg(test)]
mod reservoir_sample_tests {
- use crate::sampling::reservoir_sample;
+ use super::*;
+ use std::collections::HashMap;
// If there are no items, then the sample is empty.
#[test]
@@ -80,6 +98,58 @@ mod reservoir_sample_tests {
assert!(equivalent_items(sample, vec!["a", "b", "c"]));
}
+ // It chooses items with a uniform distribution -- every item has
+ // an equal chance of being picked.
+ //
+ // We take a large number of samples of the integers 0..n, and check
+ // that each integer is picked about as many times as we expect.
+ #[test]
+ fn test_distribution() {
+ let k = 20;
+ let n = 100;
+ let iterations = 10000;
+
+ // How often was each integer picked?
+ let mut counts: HashMap<i32, usize> = HashMap::new();
+
+ // Run many iterations, create a sample, and record how many
+ // times each integer was picked.
+ for _ in 0..iterations {
+ let items = 0..n;
+ let sample = reservoir_sample(items, k);
+
+ for s in sample.into_iter() {
+ *counts.entry(s).or_insert(0) += 1;
+ }
+ }
+
+ // Now check that each number appears roughly as many times
+ // as we'd expect (within reasonable bounds).
+ let total_samples = iterations * k;
+ let expected = total_samples as f64 / n as f64;
+
+ for item in 0..n {
+ let item_count = *counts.get(&item).unwrap_or(&0);
+
+ let ratio = (item_count as f64) / expected;
+ assert!(
+ ratio > 0.8 && ratio < 1.2,
+ "Distribution appears skewed: count={}, expected={}",
+ item_count,
+ expected
+ );
+ }
+ }
+
+ /// Returns true if two vectors contain the same items (but potentially
+ /// in a different order), false otherwise.
+ ///
+ /// equivalent_items(vec![1, 3, 2], vec![3, 2, 1])
+ /// => true
+ ///
+ /// equivalent_items(vec![4, 5, 6], vec![3, 2, 1])
+ /// => false
+ ///
fn equivalent_items<T: std::cmp::PartialEq + std::cmp::Ord>(
mut vec1: Vec<T>,
mut vec2: Vec<T>,