Skip to main content

Get the sampling code working

ID
72223b6
date
2025-01-09 23:09:35+00:00
author
Alex Chan <alex@alexwlchan.net>
parent
d6a28f6
message
Get the sampling code working
changed files
1 file, 80 additions, 10 deletions

Changed files

src/sampling.rs (2787) → src/sampling.rs (5269)

diff --git a/src/sampling.rs b/src/sampling.rs
index 2b9e0b6..e74cfff 100644
--- a/src/sampling.rs
+++ b/src/sampling.rs
@@ -1,22 +1,27 @@
 use rand::Rng;
 use std::collections::HashMap;
 
-fn random_weight() -> i32 {
-    rand::thread_rng().gen_range(i32::MIN..i32::MAX)
-}
-
+/// Choose a sample of `k` items from the iterator `items`.
+///
+/// Each item has an equal chance of being picked -- that is, there's
+/// a k/N chance of choosing any given item, where N is the length of the iterator.
+///
+/// This implements "Algorithm L" for reservoir sampling, as described
+/// on the Wikipedia page:
+/// https://en.wikipedia.org/wiki/Reservoir_sampling#Optimal:_Algorithm_L
+///
 pub fn reservoir_sample<T: std::fmt::Debug>(
     mut items: impl Iterator<Item = T>,
-    n: usize,
+    k: usize,
 ) -> Vec<T> {
     // Create an empty reservoir.
     //
     // This is a map (weight) -> (item).
-    let mut reservoir: HashMap<i32, T> = HashMap::with_capacity(n);
+    let mut reservoir: HashMap<i32, T> = HashMap::with_capacity(k);
 
     // Fill the reservoir with the first n items.  If there are less
     // than n items, we can exit immediately.
-    for _ in 1..=n {
+    for _ in 1..=k {
         match items.next() {
             Some(this_item) => reservoir.insert(random_weight(), this_item),
             None => return reservoir.into_values().collect(),
@@ -37,19 +42,32 @@ pub fn reservoir_sample<T: std::fmt::Debug>(
             continue;
         }
 
-        // Replace the item that had the max weight with the new item,
-        // then recalculate the max weight.
+        // Otherwise, this item has a lower weight than the current item
+        // with max weight -- so we'll replace that item.
         assert!(reservoir.remove(&max_weight).is_some());
         reservoir.insert(this_weight, this_item);
+
+        // Recalculate the max weight for the new sample.
         max_weight = *reservoir.keys().max().unwrap();
     }
 
     reservoir.into_values().collect()
 }
 
+/// Create a random weight, i.e. an integer selected from
+/// a uniform distribution.
+///
+/// Note: most implementations use a float in [0, 1], but comparing floats
+/// is more annoying so I used ints instead (e.g. you can do easy equality/
+/// comparison with i32, but not f32).
+fn random_weight() -> i32 {
+    rand::thread_rng().gen_range(i32::MIN..i32::MAX)
+}
+
 #[cfg(test)]
 mod reservoir_sample_tests {
-    use crate::sampling::reservoir_sample;
+    use super::*;
+    use std::collections::HashMap;
 
     // If there are no items, then the sample is empty.
     #[test]
@@ -80,6 +98,58 @@ mod reservoir_sample_tests {
         assert!(equivalent_items(sample, vec!["a", "b", "c"]));
     }
 
+    // It chooses items with a uniform distribution -- every item has
+    // an equal chance of being picked.
+    //
+    // We take a large number of samples of the integers 0..n, and check
+    // that each integer is picked about as many times as we expect.
+    #[test]
+    fn test_distribution() {
+        let k = 20;
+        let n = 100;
+        let iterations = 10000;
+
+        // How often was each integer picked?
+        let mut counts: HashMap<i32, usize> = HashMap::new();
+
+        // Run many iterations, create a sample, and record how many
+        // times each integer was picked.
+        for _ in 0..iterations {
+            let items = 0..n;
+            let sample = reservoir_sample(items, k);
+
+            for s in sample.into_iter() {
+                *counts.entry(s).or_insert(0) += 1;
+            }
+        }
+
+        // Now check that each number appears roughly as many times
+        // as we'd expect (within reasonable bounds).
+        let total_samples = iterations * k;
+        let expected = total_samples as f64 / n as f64;
+
+        for item in 0..n {
+            let item_count = *counts.get(&item).unwrap_or(&0);
+
+            let ratio = (item_count as f64) / expected;
+            assert!(
+                ratio > 0.8 && ratio < 1.2,
+                "Distribution appears skewed: count={}, expected={}",
+                item_count,
+                expected
+            );
+        }
+    }
+
+    /// Returns true if two vectors contain the same items (but potentially
+    /// in a different order), false otherwise.
+    ///
+    ///     equivalent_items(vec![1, 3, 2], vec![3, 2, 1])
+    ///     => true
+    ///
+    ///     equivalent_items(vec![4, 5, 6], vec![3, 2, 1])
+    ///     => false
+    ///
     fn equivalent_items<T: std::cmp::PartialEq + std::cmp::Ord>(
         mut vec1: Vec<T>,
         mut vec2: Vec<T>,