Skip to main content

Merge pull request #4 from alexwlchan/f32-weights

ID
66df6d7
date
2025-01-13 09:04:19+00:00
author
Alex Chan <alex@alexwlchan.net>
parents
631b299, 1bb15ba
message
Merge pull request #4 from alexwlchan/f32-weights

Use f64 for weight, not i32
changed files
4 files, 70 additions, 22 deletions

Changed files

CHANGELOG.md (54) → CHANGELOG.md (275)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6a13ba2..938b8df 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,11 @@
 # CHANGELOG
 
+## v1.0.1 - 2025-01-13
+
+Internal refactoring.
+
+This fixes a theoretically possible but statistically unlikely bug where `randline` could return less than *k* items, even when there were less than *k* items in the input.
+
 ## v1.0.0 - 2025-01-11
 
 Initial release.

Cargo.lock (6718) → Cargo.lock (6718)

diff --git a/Cargo.lock b/Cargo.lock
index 28728ff..09c5f9a 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -168,7 +168,7 @@ dependencies = [
 
 [[package]]
 name = "randline"
-version = "1.0.0"
+version = "1.0.1"
 dependencies = [
  "assert_cmd",
  "rand",

Cargo.toml (109) → Cargo.toml (109)

diff --git a/Cargo.toml b/Cargo.toml
index b92a0a5..ccffa1e 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "randline"
-version = "1.0.0"
+version = "1.0.1"
 edition = "2021"
 
 [dependencies]

src/sampling.rs (5944) → src/sampling.rs (7200)

diff --git a/src/sampling.rs b/src/sampling.rs
index ed08134..79b356c 100644
--- a/src/sampling.rs
+++ b/src/sampling.rs
@@ -1,5 +1,46 @@
 use rand::Rng;
-use std::collections::HashMap;
+use std::cmp::Ordering;
+use std::collections::BinaryHeap;
+use std::ptr;
+
+struct ReservoirEntry<T> {
+    item: T,
+    weight: f64,
+}
+
+// Two items are only equal if they are identical -- that is, they're
+// the same underlying object in memory.
+//
+// [I suppose it's theoretically possible that there could be duplicate
+// reservoir entries, if the RNG was bugged and the input has repeated
+// values -- seems unlikely in practice, but this protects against it
+// just in case.]
+impl<T> PartialEq for ReservoirEntry<T> {
+    fn eq(&self, other: &Self) -> bool {
+        ptr::eq(self, other)
+    }
+}
+
+impl<T> Eq for ReservoirEntry<T> {}
+
+// Rust doesn't implement ordering for f64 because it includes NaN
+// which makes everything a mess.  In particular NaN isn't comparable
+// with other floating-point numbers.
+//
+// We're generating all the f64 weights we'll be dealing with, so we
+// know we'll never have NaN in the mix -- we can do a partial comparison
+// and assert the two values are comparable when we unwrap.
+impl<T> PartialOrd for ReservoirEntry<T> {
+    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+        Some(self.cmp(other))
+    }
+}
+
+impl<T> Ord for ReservoirEntry<T> {
+    fn cmp(&self, other: &Self) -> Ordering {
+        self.weight.partial_cmp(&other.weight).unwrap()
+    }
+}
 
 /// Choose a sample of `k` items from the iterator `items.
 ///
@@ -18,16 +59,17 @@ pub fn reservoir_sample<T>(mut items: impl Iterator<Item = T>, k: usize) -> Vec<
     }
 
     // Create an empty reservoir.
-    //
-    // This is a map (weight) -> (item).
-    let mut reservoir: HashMap<i32, T> = HashMap::with_capacity(k);
+    let mut reservoir: BinaryHeap<ReservoirEntry<T>> = BinaryHeap::with_capacity(k);
 
     // Fill the reservoir with the first k items.  If there are less
     // than n items, we can exit immediately.
     for _ in 1..=k {
         match items.next() {
-            Some(this_item) => reservoir.insert(random_weight(), this_item),
-            None => return reservoir.into_values().collect(),
+            Some(this_item) => reservoir.push(ReservoirEntry {
+                item: this_item,
+                weight: pick_weight(),
+            }),
+            None => return reservoir.into_vec().into_iter().map(|r| r.item).collect(),
         };
     }
 
@@ -37,12 +79,12 @@ pub fn reservoir_sample<T>(mut items: impl Iterator<Item = T>, k: usize) -> Vec<
     // contains at least one item.  Either `items` was non-empty, or if itwas
     // was empty, then we'd already have returned when trying to fill the
     // reservoir with the first k items.
-    let mut max_weight: i32 = *reservoir.keys().max().unwrap();
+    let mut max_weight: f64 = reservoir.peek().unwrap().weight;
 
     // Now go through the remaining items.
     for this_item in items {
         // Choose a weight for this item.
-        let this_weight = random_weight();
+        let this_weight = pick_weight();
 
         // If this is greater than the weights seen so far, we can ignore
         // this item and move on to the next one.
@@ -52,24 +94,24 @@ pub fn reservoir_sample<T>(mut items: impl Iterator<Item = T>, k: usize) -> Vec<
 
         // Otherwise, this item has a lower weight than the current item
         // with max weight -- so we'll replace that item.
-        assert!(reservoir.remove(&max_weight).is_some());
-        reservoir.insert(this_weight, this_item);
+        assert!(reservoir.pop().is_some());
+        reservoir.push(ReservoirEntry {
+            item: this_item,
+            weight: this_weight,
+        });
 
         // Recalculate the max weight for the new sample.
-        max_weight = *reservoir.keys().max().unwrap();
+        max_weight = reservoir.peek().unwrap().weight;
     }
 
-    reservoir.into_values().collect()
+    let sample: Vec<T> = reservoir.into_vec().into_iter().map(|r| r.item).collect();
+    assert!(sample.len() == k);
+    sample
 }
 
-/// Create a random weight, i.e. an integer selected from
-/// a uniform distribution.
-///
-/// Note: most implementations use a float in [0, 1], but comparing floats
-/// is more annoying so I used ints instead (e.g. you can do easy equality/
-/// comparison with i32, but not f32).
-fn random_weight() -> i32 {
-    rand::thread_rng().gen_range(i32::MIN..i32::MAX)
+/// Create a random weight, i.e. a float selected from a uniform random.
+fn pick_weight() -> f64 {
+    rand::thread_rng().gen::<f64>()
 }
 
 #[cfg(test)]