Merge pull request #4 from alexwlchan/f32-weights
- ID
66df6d7- date
2025-01-13 09:04:19+00:00- author
Alex Chan <alex@alexwlchan.net>- parents
631b299,1bb15ba- message
Merge pull request #4 from alexwlchan/f32-weights Use f64 for weight, not i32- changed files
4 files, 70 additions, 22 deletions
Changed files
CHANGELOG.md (54) → CHANGELOG.md (275)
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6a13ba2..938b8df 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,11 @@
# CHANGELOG
+## v1.0.1 - 2025-01-13
+
+Internal refactoring.
+
+This fixes a theoretically possible but statistically unlikely bug where `randline` could return fewer than *k* items, even when there were at least *k* items in the input.
+
## v1.0.0 - 2025-01-11
Initial release.
Cargo.lock (6718) → Cargo.lock (6718)
diff --git a/Cargo.lock b/Cargo.lock
index 28728ff..09c5f9a 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -168,7 +168,7 @@ dependencies = [
[[package]]
name = "randline"
-version = "1.0.0"
+version = "1.0.1"
dependencies = [
"assert_cmd",
"rand",
Cargo.toml (109) → Cargo.toml (109)
diff --git a/Cargo.toml b/Cargo.toml
index b92a0a5..ccffa1e 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "randline"
-version = "1.0.0"
+version = "1.0.1"
edition = "2021"
[dependencies]
src/sampling.rs (5944) → src/sampling.rs (7200)
diff --git a/src/sampling.rs b/src/sampling.rs
index ed08134..79b356c 100644
--- a/src/sampling.rs
+++ b/src/sampling.rs
@@ -1,5 +1,46 @@
use rand::Rng;
-use std::collections::HashMap;
+use std::cmp::Ordering;
+use std::collections::BinaryHeap;
+use std::ptr;
+
+struct ReservoirEntry<T> {
+ item: T,
+ weight: f64,
+}
+
+// Two items are only equal if they are identical -- that is, they're
+// the same underlying object in memory.
+//
+// [I suppose it's theoretically possible that there could be duplicate
+// reservoir entries, if the RNG was bugged and the input has repeated
+// values -- seems unlikely in practice, but this protects against it
+// just in case.]
+impl<T> PartialEq for ReservoirEntry<T> {
+ fn eq(&self, other: &Self) -> bool {
+ ptr::eq(self, other)
+ }
+}
+
+impl<T> Eq for ReservoirEntry<T> {}
+
+// Rust doesn't implement ordering for f64 because it includes NaN
+// which makes everything a mess. In particular NaN isn't comparable
+// with other floating-point numbers.
+//
+// We're generating all the f64 weights we'll be dealing with, so we
+// know we'll never have NaN in the mix -- we can do a partial comparison
+// and assert the two values are comparable when we unwrap.
+impl<T> PartialOrd for ReservoirEntry<T> {
+ fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+ Some(self.cmp(other))
+ }
+}
+
+impl<T> Ord for ReservoirEntry<T> {
+ fn cmp(&self, other: &Self) -> Ordering {
+ self.weight.partial_cmp(&other.weight).unwrap()
+ }
+}
/// Choose a sample of `k` items from the iterator `items`.
///
@@ -18,16 +59,17 @@ pub fn reservoir_sample<T>(mut items: impl Iterator<Item = T>, k: usize) -> Vec<
}
// Create an empty reservoir.
- //
- // This is a map (weight) -> (item).
- let mut reservoir: HashMap<i32, T> = HashMap::with_capacity(k);
+ let mut reservoir: BinaryHeap<ReservoirEntry<T>> = BinaryHeap::with_capacity(k);
// Fill the reservoir with the first k items. If there are fewer
// than k items, we can exit immediately.
for _ in 1..=k {
match items.next() {
- Some(this_item) => reservoir.insert(random_weight(), this_item),
- None => return reservoir.into_values().collect(),
+ Some(this_item) => reservoir.push(ReservoirEntry {
+ item: this_item,
+ weight: pick_weight(),
+ }),
+ None => return reservoir.into_vec().into_iter().map(|r| r.item).collect(),
};
}
@@ -37,12 +79,12 @@ pub fn reservoir_sample<T>(mut items: impl Iterator<Item = T>, k: usize) -> Vec<
// contains at least one item. Either `items` was non-empty, or if it
// was empty, then we'd already have returned when trying to fill the
// reservoir with the first k items.
- let mut max_weight: i32 = *reservoir.keys().max().unwrap();
+ let mut max_weight: f64 = reservoir.peek().unwrap().weight;
// Now go through the remaining items.
for this_item in items {
// Choose a weight for this item.
- let this_weight = random_weight();
+ let this_weight = pick_weight();
// If this is greater than the weights seen so far, we can ignore
// this item and move on to the next one.
@@ -52,24 +94,24 @@ pub fn reservoir_sample<T>(mut items: impl Iterator<Item = T>, k: usize) -> Vec<
// Otherwise, this item has a lower weight than the current item
// with max weight -- so we'll replace that item.
- assert!(reservoir.remove(&max_weight).is_some());
- reservoir.insert(this_weight, this_item);
+ assert!(reservoir.pop().is_some());
+ reservoir.push(ReservoirEntry {
+ item: this_item,
+ weight: this_weight,
+ });
// Recalculate the max weight for the new sample.
- max_weight = *reservoir.keys().max().unwrap();
+ max_weight = reservoir.peek().unwrap().weight;
}
- reservoir.into_values().collect()
+ let sample: Vec<T> = reservoir.into_vec().into_iter().map(|r| r.item).collect();
+ assert!(sample.len() == k);
+ sample
}
-/// Create a random weight, i.e. an integer selected from
-/// a uniform distribution.
-///
-/// Note: most implementations use a float in [0, 1], but comparing floats
-/// is more annoying so I used ints instead (e.g. you can do easy equality/
-/// comparison with i32, but not f32).
-fn random_weight() -> i32 {
- rand::thread_rng().gen_range(i32::MIN..i32::MAX)
+/// Create a random weight, i.e. a float selected from a uniform distribution over [0, 1).
+fn pick_weight() -> f64 {
+ rand::thread_rng().gen::<f64>()
}
#[cfg(test)]