std::rand: move Weighted to distributions.
A user constructs the WeightedChoice distribution and then samples from it, which allows it to use binary search internally.
This commit is contained in:
parent
83aa1abb19
commit
0bba73c0d1
|
@ -20,8 +20,11 @@ that do not need to record state.
|
|||
|
||||
*/
|
||||
|
||||
use iter::range;
|
||||
use option::{Some, None};
|
||||
use num;
|
||||
use rand::{Rng,Rand};
|
||||
use clone::Clone;
|
||||
|
||||
pub use self::range::Range;
|
||||
|
||||
|
@ -61,8 +64,128 @@ impl<Sup: Rand> IndependentSample<Sup> for RandSample<Sup> {
|
|||
}
|
||||
}
|
||||
|
||||
mod ziggurat_tables;
|
||||
/// A value with a particular weight for use with `WeightedChoice`.
|
||||
pub struct Weighted<T> {
|
||||
/// The numerical weight of this item
|
||||
weight: uint,
|
||||
/// The actual item which is being weighted
|
||||
item: T,
|
||||
}
|
||||
|
||||
/// A distribution that selects from a finite collection of weighted items.
|
||||
///
|
||||
/// Each item has an associated weight that influences how likely it
|
||||
/// is to be chosen: higher weight is more likely.
|
||||
///
|
||||
/// The `Clone` restriction is a limitation of the `Sample` and
|
||||
/// `IndepedentSample` traits. Note that `&T` is (cheaply) `Clone` for
|
||||
/// all `T`, as is `uint`, so one can store references or indices into
|
||||
/// another vector.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use std::rand;
|
||||
/// use std::rand::distributions::{Weighted, WeightedChoice, IndepedentSample};
|
||||
///
|
||||
/// fn main() {
|
||||
/// let wc = WeightedChoice::new(~[Weighted { weight: 2, item: 'a' },
|
||||
/// Weighted { weight: 4, item: 'b' },
|
||||
/// Weighted { weight: 1, item: 'c' }]);
|
||||
/// let rng = rand::task_rng();
|
||||
/// for _ in range(0, 16) {
|
||||
/// // on average prints 'a' 4 times, 'b' 8 and 'c' twice.
|
||||
/// println!("{}", wc.ind_sample(rng));
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
pub struct WeightedChoice<T> {
|
||||
priv items: ~[Weighted<T>],
|
||||
priv weight_range: Range<uint>
|
||||
}
|
||||
|
||||
impl<T: Clone> WeightedChoice<T> {
|
||||
/// Create a new `WeightedChoice`.
|
||||
///
|
||||
/// Fails if:
|
||||
/// - `v` is empty
|
||||
/// - the total weight is 0
|
||||
/// - the total weight is larger than a `uint` can contain.
|
||||
pub fn new(mut items: ~[Weighted<T>]) -> WeightedChoice<T> {
|
||||
// strictly speaking, this is subsumed by the total weight == 0 case
|
||||
assert!(!items.is_empty(), "WeightedChoice::new called with no items");
|
||||
|
||||
let mut running_total = 0u;
|
||||
|
||||
// we convert the list from individual weights to cumulative
|
||||
// weights so we can binary search. This *could* drop elements
|
||||
// with weight == 0 as an optimisation.
|
||||
for item in items.mut_iter() {
|
||||
running_total = running_total.checked_add(&item.weight)
|
||||
.expect("WeightedChoice::new called with a total weight larger \
|
||||
than a uint can contain");
|
||||
|
||||
item.weight = running_total;
|
||||
}
|
||||
assert!(running_total != 0, "WeightedChoice::new called with a total weight of 0");
|
||||
|
||||
WeightedChoice {
|
||||
items: items,
|
||||
// we're likely to be generating numbers in this range
|
||||
// relatively often, so might as well cache it
|
||||
weight_range: Range::new(0, running_total)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Clone> Sample<T> for WeightedChoice<T> {
|
||||
fn sample<R: Rng>(&mut self, rng: &mut R) -> T { self.ind_sample(rng) }
|
||||
}
|
||||
|
||||
impl<T: Clone> IndependentSample<T> for WeightedChoice<T> {
|
||||
fn ind_sample<R: Rng>(&self, rng: &mut R) -> T {
|
||||
// we want to find the first element that has cumulative
|
||||
// weight > sample_weight, which we do by binary since the
|
||||
// cumulative weights of self.items are sorted.
|
||||
|
||||
// choose a weight in [0, total_weight)
|
||||
let sample_weight = self.weight_range.ind_sample(rng);
|
||||
|
||||
// short circuit when it's the first item
|
||||
if sample_weight < self.items[0].weight {
|
||||
return self.items[0].item.clone();
|
||||
}
|
||||
|
||||
let mut idx = 0;
|
||||
let mut modifier = self.items.len();
|
||||
|
||||
// now we know that every possibility has an element to the
|
||||
// left, so we can just search for the last element that has
|
||||
// cumulative weight <= sample_weight, then the next one will
|
||||
// be "it". (Note that this greatest element will never be the
|
||||
// last element of the vector, since sample_weight is chosen
|
||||
// in [0, total_weight) and the cumulative weight of the last
|
||||
// one is exactly the total weight.)
|
||||
while modifier > 1 {
|
||||
let i = idx + modifier / 2;
|
||||
if self.items[i].weight <= sample_weight {
|
||||
// we're small, so look to the right, but allow this
|
||||
// exact element still.
|
||||
idx = i;
|
||||
// we need the `/ 2` to round up otherwise we'll drop
|
||||
// the trailing elements when `modifier` is odd.
|
||||
modifier += 1;
|
||||
} else {
|
||||
// otherwise we're too big, so go left. (i.e. do
|
||||
// nothing)
|
||||
}
|
||||
modifier /= 2;
|
||||
}
|
||||
return self.items[idx + 1].item.clone();
|
||||
}
|
||||
}
|
||||
|
||||
mod ziggurat_tables;
|
||||
|
||||
/// Sample a random number using the Ziggurat method (specifically the
|
||||
/// ZIGNOR variant from Doornik 2005). Most of the arguments are
|
||||
|
@ -302,6 +425,18 @@ mod tests {
|
|||
}
|
||||
}
|
||||
|
||||
// 0, 1, 2, 3, ...
|
||||
struct CountingRng { i: u32 }
|
||||
impl Rng for CountingRng {
|
||||
fn next_u32(&mut self) -> u32 {
|
||||
self.i += 1;
|
||||
self.i - 1
|
||||
}
|
||||
fn next_u64(&mut self) -> u64 {
|
||||
self.next_u32() as u64
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rand_sample() {
|
||||
let mut rand_sample = RandSample::<ConstRand>;
|
||||
|
@ -344,6 +479,77 @@ mod tests {
|
|||
fn test_exp_invalid_lambda_neg() {
|
||||
Exp::new(-10.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_weighted_choice() {
|
||||
// this makes assumptions about the internal implementation of
|
||||
// WeightedChoice, specifically: it doesn't reorder the items,
|
||||
// it doesn't do weird things to the RNG (so 0 maps to 0, 1 to
|
||||
// 1, internally; modulo a modulo operation).
|
||||
|
||||
macro_rules! t (
|
||||
($items:expr, $expected:expr) => {{
|
||||
let wc = WeightedChoice::new($items);
|
||||
let expected = $expected;
|
||||
|
||||
let mut rng = CountingRng { i: 0 };
|
||||
|
||||
for &val in expected.iter() {
|
||||
assert_eq!(wc.ind_sample(&mut rng), val)
|
||||
}
|
||||
}}
|
||||
);
|
||||
|
||||
t!(~[Weighted { weight: 1, item: 10}], ~[10]);
|
||||
|
||||
// skip some
|
||||
t!(~[Weighted { weight: 0, item: 20},
|
||||
Weighted { weight: 2, item: 21},
|
||||
Weighted { weight: 0, item: 22},
|
||||
Weighted { weight: 1, item: 23}],
|
||||
~[21,21, 23]);
|
||||
|
||||
// different weights
|
||||
t!(~[Weighted { weight: 4, item: 30},
|
||||
Weighted { weight: 3, item: 31}],
|
||||
~[30,30,30,30, 31,31,31]);
|
||||
|
||||
// check that we're binary searching
|
||||
// correctly with some vectors of odd
|
||||
// length.
|
||||
t!(~[Weighted { weight: 1, item: 40},
|
||||
Weighted { weight: 1, item: 41},
|
||||
Weighted { weight: 1, item: 42},
|
||||
Weighted { weight: 1, item: 43},
|
||||
Weighted { weight: 1, item: 44}],
|
||||
~[40, 41, 42, 43, 44]);
|
||||
t!(~[Weighted { weight: 1, item: 50},
|
||||
Weighted { weight: 1, item: 51},
|
||||
Weighted { weight: 1, item: 52},
|
||||
Weighted { weight: 1, item: 53},
|
||||
Weighted { weight: 1, item: 54},
|
||||
Weighted { weight: 1, item: 55},
|
||||
Weighted { weight: 1, item: 56}],
|
||||
~[50, 51, 52, 53, 54, 55, 56]);
|
||||
}
|
||||
|
||||
#[test] #[should_fail]
|
||||
fn test_weighted_choice_no_items() {
|
||||
WeightedChoice::<int>::new(~[]);
|
||||
}
|
||||
#[test] #[should_fail]
|
||||
fn test_weighted_choice_zero_weight() {
|
||||
WeightedChoice::new(~[Weighted { weight: 0, item: 0},
|
||||
Weighted { weight: 0, item: 1}]);
|
||||
}
|
||||
#[test] #[should_fail]
|
||||
fn test_weighted_choice_weight_overflows() {
|
||||
let x = (-1) as uint / 2; // x + x + 2 is the overflow
|
||||
WeightedChoice::new(~[Weighted { weight: x, item: 0 },
|
||||
Weighted { weight: 1, item: 1 },
|
||||
Weighted { weight: x, item: 2 },
|
||||
Weighted { weight: 1, item: 3 }]);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
|
@ -100,14 +100,6 @@ pub trait Rand {
|
|||
fn rand<R: Rng>(rng: &mut R) -> Self;
|
||||
}
|
||||
|
||||
/// A value with a particular weight compared to other values
|
||||
pub struct Weighted<T> {
|
||||
/// The numerical weight of this item
|
||||
weight: uint,
|
||||
/// The actual item which is being weighted
|
||||
item: T,
|
||||
}
|
||||
|
||||
/// A random number generator
|
||||
pub trait Rng {
|
||||
/// Return the next random u32. This rarely needs to be called
|
||||
|
@ -334,91 +326,6 @@ pub trait Rng {
|
|||
}
|
||||
}
|
||||
|
||||
/// Choose an item respecting the relative weights, failing if the sum of
|
||||
/// the weights is 0
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use std::rand;
|
||||
/// use std::rand::Rng;
|
||||
///
|
||||
/// fn main() {
|
||||
/// let mut rng = rand::rng();
|
||||
/// let x = [rand::Weighted {weight: 4, item: 'a'},
|
||||
/// rand::Weighted {weight: 2, item: 'b'},
|
||||
/// rand::Weighted {weight: 2, item: 'c'}];
|
||||
/// println!("{}", rng.choose_weighted(x));
|
||||
/// }
|
||||
/// ```
|
||||
fn choose_weighted<T:Clone>(&mut self, v: &[Weighted<T>]) -> T {
|
||||
self.choose_weighted_option(v).expect("Rng.choose_weighted: total weight is 0")
|
||||
}
|
||||
|
||||
/// Choose Some(item) respecting the relative weights, returning none if
|
||||
/// the sum of the weights is 0
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use std::rand;
|
||||
/// use std::rand::Rng;
|
||||
///
|
||||
/// fn main() {
|
||||
/// let mut rng = rand::rng();
|
||||
/// let x = [rand::Weighted {weight: 4, item: 'a'},
|
||||
/// rand::Weighted {weight: 2, item: 'b'},
|
||||
/// rand::Weighted {weight: 2, item: 'c'}];
|
||||
/// println!("{:?}", rng.choose_weighted_option(x));
|
||||
/// }
|
||||
/// ```
|
||||
fn choose_weighted_option<T:Clone>(&mut self, v: &[Weighted<T>])
|
||||
-> Option<T> {
|
||||
let mut total = 0u;
|
||||
for item in v.iter() {
|
||||
total += item.weight;
|
||||
}
|
||||
if total == 0u {
|
||||
return None;
|
||||
}
|
||||
let chosen = self.gen_range(0u, total);
|
||||
let mut so_far = 0u;
|
||||
for item in v.iter() {
|
||||
so_far += item.weight;
|
||||
if so_far > chosen {
|
||||
return Some(item.item.clone());
|
||||
}
|
||||
}
|
||||
unreachable!();
|
||||
}
|
||||
|
||||
/// Return a vec containing copies of the items, in order, where
|
||||
/// the weight of the item determines how many copies there are
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use std::rand;
|
||||
/// use std::rand::Rng;
|
||||
///
|
||||
/// fn main() {
|
||||
/// let mut rng = rand::rng();
|
||||
/// let x = [rand::Weighted {weight: 4, item: 'a'},
|
||||
/// rand::Weighted {weight: 2, item: 'b'},
|
||||
/// rand::Weighted {weight: 2, item: 'c'}];
|
||||
/// println!("{}", rng.weighted_vec(x));
|
||||
/// }
|
||||
/// ```
|
||||
fn weighted_vec<T:Clone>(&mut self, v: &[Weighted<T>]) -> ~[T] {
|
||||
let mut r = ~[];
|
||||
for item in v.iter() {
|
||||
for _ in range(0u, item.weight) {
|
||||
r.push(item.item.clone());
|
||||
}
|
||||
}
|
||||
r
|
||||
}
|
||||
|
||||
/// Shuffle a vec
|
||||
///
|
||||
/// # Example
|
||||
|
@ -860,44 +767,6 @@ mod test {
|
|||
assert_eq!(r.choose_option(v), Some(&i));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_choose_weighted() {
|
||||
let mut r = rng();
|
||||
assert!(r.choose_weighted([
|
||||
Weighted { weight: 1u, item: 42 },
|
||||
]) == 42);
|
||||
assert!(r.choose_weighted([
|
||||
Weighted { weight: 0u, item: 42 },
|
||||
Weighted { weight: 1u, item: 43 },
|
||||
]) == 43);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_choose_weighted_option() {
|
||||
let mut r = rng();
|
||||
assert!(r.choose_weighted_option([
|
||||
Weighted { weight: 1u, item: 42 },
|
||||
]) == Some(42));
|
||||
assert!(r.choose_weighted_option([
|
||||
Weighted { weight: 0u, item: 42 },
|
||||
Weighted { weight: 1u, item: 43 },
|
||||
]) == Some(43));
|
||||
let v: Option<int> = r.choose_weighted_option([]);
|
||||
assert!(v.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_weighted_vec() {
|
||||
let mut r = rng();
|
||||
let empty: ~[int] = ~[];
|
||||
assert_eq!(r.weighted_vec([]), empty);
|
||||
assert!(r.weighted_vec([
|
||||
Weighted { weight: 0u, item: 3u },
|
||||
Weighted { weight: 1u, item: 2u },
|
||||
Weighted { weight: 2u, item: 1u },
|
||||
]) == ~[2u, 1u, 1u]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_shuffle() {
|
||||
let mut r = rng();
|
||||
|
|
Loading…
Reference in New Issue