Implement Set container on top of a bit vector

This commit is contained in:
Alex Crichton 2013-02-17 20:01:47 -05:00
parent 393a4b41f6
commit bf8ed45adc

View File

@ -8,10 +8,12 @@
// option. This file may not be copied, modified, or distributed
// except according to those terms.
use core::container::{Container, Mutable, Set};
use core::num::NumCast;
use core::ops;
use core::prelude::*;
use core::uint;
use core::vec::{cast_to_mut, from_elem};
use core::vec::from_elem;
use core::vec;
struct SmallBitv {
@ -133,18 +135,15 @@ impl BigBitv {
let len = b.storage.len();
assert (self.storage.len() == len);
let mut changed = false;
do uint::range(0, len) |i| {
for uint::range(0, len) |i| {
let mask = big_mask(nbits, i);
let w0 = self.storage[i] & mask;
let w1 = b.storage[i] & mask;
let w = op(w0, w1) & mask;
if w0 != w {
unsafe {
changed = true;
self.storage[i] = w;
}
changed = true;
self.storage[i] = w;
}
true
}
changed
}
@ -556,16 +555,317 @@ pub fn from_fn(len: uint, f: fn(index: uint) -> bool) -> Bitv {
bitv
}
impl ops::Index<uint,bool> for Bitv {
pure fn index(&self, i: uint) -> bool {
self.get(i)
}
}
#[inline(always)]
pure fn iterate_bits(base: uint, bits: uint, f: fn(uint) -> bool) -> bool {
if bits == 0 {
return true;
}
for uint::range(0, uint::bits) |i| {
if bits & (1 << i) != 0 {
if !f(base + i) {
return false;
}
}
}
return true;
}
/// An implementation of a set using a bit vector as an underlying
/// representation for holding numerical elements.
///
/// It should also be noted that the amount of storage necessary for holding a
/// set of objects is proportional to the maximum of the objects when viewed
/// as a uint.
pub struct BitvSet {
priv size: uint,
// In theory this is a Bitv instead of always a BigBitv, but knowing that
// there's an array of storage makes our lives a whole lot easier when
// performing union/intersection/etc operations
priv bitv: BigBitv
}
impl BitvSet {
/// Creates a new bit vector set with initially no contents
static fn new() -> BitvSet {
BitvSet{ size: 0, bitv: BigBitv::new(~[0]) }
}
/// Creates a new bit vector set from the given bit vector
static fn from_bitv(bitv: Bitv) -> BitvSet {
let mut size = 0;
for bitv.ones |_| {
size += 1;
}
let Bitv{rep, _} = bitv;
match rep {
Big(~b) => BitvSet{ size: size, bitv: b },
Small(~SmallBitv{bits}) =>
BitvSet{ size: size, bitv: BigBitv{ storage: ~[bits] } },
}
}
/// Returns the capacity in bits for this bit vector. Inserting any
/// element less than this amount will not trigger a resizing.
pure fn capacity(&self) -> uint { self.bitv.storage.len() * uint::bits }
/// Consumes this set to return the underlying bit vector
fn unwrap(self) -> Bitv {
let cap = self.capacity();
let BitvSet{bitv, _} = self;
return Bitv{ nbits:cap, rep: Big(~bitv) };
}
#[inline(always)]
priv fn other_op(&mut self, other: &BitvSet, f: fn(uint, uint) -> uint) {
fn nbits(mut w: uint) -> uint {
let mut bits = 0;
for uint::bits.times {
if w == 0 {
break;
}
bits += w & 1;
w >>= 1;
}
return bits;
}
if self.capacity() < other.capacity() {
self.bitv.storage.grow(other.capacity() / uint::bits, &0);
}
for other.bitv.storage.eachi |i, &w| {
let old = self.bitv.storage[i];
let new = f(old, w);
self.bitv.storage[i] = new;
self.size += nbits(new) - nbits(old);
}
}
/// Union in-place with the specified other bit vector
fn union_with(&mut self, other: &BitvSet) {
self.other_op(other, |w1, w2| w1 | w2);
}
/// Intersect in-place with the specified other bit vector
fn intersect_with(&mut self, other: &BitvSet) {
self.other_op(other, |w1, w2| w1 & w2);
}
/// Difference in-place with the specified other bit vector
fn difference_with(&mut self, other: &BitvSet) {
self.other_op(other, |w1, w2| w1 & !w2);
}
/// Symmetric difference in-place with the specified other bit vector
fn symmetric_difference_with(&mut self, other: &BitvSet) {
self.other_op(other, |w1, w2| w1 ^ w2);
}
}
impl BaseIter<uint> for BitvSet {
pure fn size_hint(&self) -> Option<uint> { Some(self.len()) }
pure fn each(&self, blk: fn(v: &uint) -> bool) {
for self.bitv.storage.eachi |i, &w| {
if !iterate_bits(i * uint::bits, w, |b| blk(&b)) {
return;
}
}
}
}
impl cmp::Eq for BitvSet {
pure fn eq(&self, other: &BitvSet) -> bool {
if self.size != other.size {
return false;
}
for self.each_common(other) |_, w1, w2| {
if w1 != w2 {
return false;
}
}
for self.each_outlier(other) |_, _, w| {
if w != 0 {
return false;
}
}
return true;
}
pure fn ne(&self, other: &BitvSet) -> bool { !self.eq(other) }
}
impl Container for BitvSet {
pure fn len(&self) -> uint { self.size }
pure fn is_empty(&self) -> bool { self.size == 0 }
}
impl Mutable for BitvSet {
fn clear(&mut self) {
for self.bitv.each_storage |w| { *w = 0; }
self.size = 0;
}
}
impl Set<uint> for BitvSet {
pure fn contains(&self, value: &uint) -> bool {
*value < self.bitv.storage.len() * uint::bits && self.bitv.get(*value)
}
fn insert(&mut self, value: uint) -> bool {
if self.contains(&value) {
return false;
}
let nbits = self.capacity();
if value >= nbits {
let newsize = uint::max(value, nbits * 2) / uint::bits + 1;
assert newsize > self.bitv.storage.len();
self.bitv.storage.grow(newsize, &0);
}
self.size += 1;
self.bitv.set(value, true);
return true;
}
fn remove(&mut self, value: &uint) -> bool {
if !self.contains(value) {
return false;
}
self.size -= 1;
self.bitv.set(*value, false);
// Attempt to truncate our storage
let mut i = self.bitv.storage.len();
while i > 1 && self.bitv.storage[i - 1] == 0 {
i -= 1;
}
self.bitv.storage.truncate(i);
return true;
}
pure fn is_disjoint(&self, other: &BitvSet) -> bool {
for self.intersection(other) |_| {
return false;
}
return true;
}
pure fn is_subset(&self, other: &BitvSet) -> bool {
for self.each_common(other) |_, w1, w2| {
if w1 & w2 != w1 {
return false;
}
}
/* If anything is not ours, then everything is not ours so we're
definitely a subset in that case. Otherwise if there's any stray
ones that 'other' doesn't have, we're not a subset. */
for self.each_outlier(other) |mine, _, w| {
if !mine {
return true;
} else if w != 0 {
return false;
}
}
return true;
}
pure fn is_superset(&self, other: &BitvSet) -> bool {
other.is_subset(self)
}
pure fn difference(&self, other: &BitvSet, f: fn(&uint) -> bool) {
for self.each_common(other) |i, w1, w2| {
if !iterate_bits(i, w1 & !w2, |b| f(&b)) {
return;
}
}
/* everything we have that they don't also shows up */
self.each_outlier(other, |mine, i, w|
!mine || iterate_bits(i, w, |b| f(&b))
);
}
pure fn symmetric_difference(&self, other: &BitvSet,
f: fn(&uint) -> bool) {
for self.each_common(other) |i, w1, w2| {
if !iterate_bits(i, w1 ^ w2, |b| f(&b)) {
return;
}
}
self.each_outlier(other, |_, i, w|
iterate_bits(i, w, |b| f(&b))
);
}
pure fn intersection(&self, other: &BitvSet, f: fn(&uint) -> bool) {
for self.each_common(other) |i, w1, w2| {
if !iterate_bits(i, w1 & w2, |b| f(&b)) {
return;
}
}
}
pure fn union(&self, other: &BitvSet, f: fn(&uint) -> bool) {
for self.each_common(other) |i, w1, w2| {
if !iterate_bits(i, w1 | w2, |b| f(&b)) {
return;
}
}
self.each_outlier(other, |_, i, w|
iterate_bits(i, w, |b| f(&b))
);
}
}
priv impl BitvSet {
/// Visits each of the words that the two bit vectors (self and other)
/// both have in common. The three yielded arguments are (bit location,
/// w1, w2) where the bit location is the number of bits offset so far,
/// and w1/w2 are the words coming from the two vectors self, other.
pure fn each_common(&self, other: &BitvSet,
f: fn(uint, uint, uint) -> bool) {
let min = uint::min(self.bitv.storage.len(),
other.bitv.storage.len());
for self.bitv.storage.view(0, min).eachi |i, &w| {
if !f(i * uint::bits, w, other.bitv.storage[i]) {
return;
}
}
}
/// Visits each word in self or other that extends beyond the other. This
/// will only iterate through one of the vectors, and it only iterates
/// over the portion that doesn't overlap with the other one.
///
/// The yielded arguments are a bool, the bit offset, and a word. The bool
/// is true if the word comes from 'self', and false if it comes from
/// 'other'.
pure fn each_outlier(&self, other: &BitvSet,
f: fn(bool, uint, uint) -> bool) {
let len1 = self.bitv.storage.len();
let len2 = other.bitv.storage.len();
let min = uint::min(len1, len2);
/* only one of these loops will execute and that's the point */
for self.bitv.storage.view(min, len1).eachi |i, &w| {
if !f(true, (i + min) * uint::bits, w) {
return;
}
}
for other.bitv.storage.view(min, len2).eachi |i, &w| {
if !f(false, (i + min) * uint::bits, w) {
return;
}
}
}
}
#[cfg(test)]
mod tests {
use core::prelude::*;
@ -946,48 +1246,178 @@ mod tests {
#[test]
pub fn test_small_difference() {
let mut b1 = Bitv::new(3, false);
let mut b2 = Bitv::new(3, false);
b1.set(0, true);
b1.set(1, true);
b2.set(1, true);
b2.set(2, true);
assert b1.difference(&b2);
assert b1[0];
assert !b1[1];
assert !b1[2];
let mut b1 = Bitv::new(3, false);
let mut b2 = Bitv::new(3, false);
b1.set(0, true);
b1.set(1, true);
b2.set(1, true);
b2.set(2, true);
assert b1.difference(&b2);
assert b1[0];
assert !b1[1];
assert !b1[2];
}
#[test]
pub fn test_big_difference() {
let mut b1 = Bitv::new(100, false);
let mut b2 = Bitv::new(100, false);
b1.set(0, true);
b1.set(40, true);
b2.set(40, true);
b2.set(80, true);
assert b1.difference(&b2);
assert b1[0];
assert !b1[40];
assert !b1[80];
let mut b1 = Bitv::new(100, false);
let mut b2 = Bitv::new(100, false);
b1.set(0, true);
b1.set(40, true);
b2.set(40, true);
b2.set(80, true);
assert b1.difference(&b2);
assert b1[0];
assert !b1[40];
assert !b1[80];
}
#[test]
pub fn test_small_clear() {
let mut b = Bitv::new(14, true);
b.clear();
for b.ones |i| {
fail!(fmt!("found 1 at %?", i));
}
let mut b = Bitv::new(14, true);
b.clear();
for b.ones |i| {
fail!(fmt!("found 1 at %?", i));
}
}
#[test]
pub fn test_big_clear() {
let mut b = Bitv::new(140, true);
b.clear();
for b.ones |i| {
fail!(fmt!("found 1 at %?", i));
}
let mut b = Bitv::new(140, true);
b.clear();
for b.ones |i| {
fail!(fmt!("found 1 at %?", i));
}
}
#[test]
pub fn test_bitv_set_basic() {
let mut b = BitvSet::new();
assert b.insert(3);
assert !b.insert(3);
assert b.contains(&3);
assert b.insert(400);
assert !b.insert(400);
assert b.contains(&400);
assert b.len() == 2;
}
#[test]
fn test_bitv_set_intersection() {
let mut a = BitvSet::new();
let mut b = BitvSet::new();
assert a.insert(11);
assert a.insert(1);
assert a.insert(3);
assert a.insert(77);
assert a.insert(103);
assert a.insert(5);
assert b.insert(2);
assert b.insert(11);
assert b.insert(77);
assert b.insert(5);
assert b.insert(3);
let mut i = 0;
let expected = [3, 5, 11, 77];
for a.intersection(&b) |x| {
assert *x == expected[i];
i += 1
}
assert i == expected.len();
}
#[test]
fn test_bitv_set_difference() {
let mut a = BitvSet::new();
let mut b = BitvSet::new();
assert a.insert(1);
assert a.insert(3);
assert a.insert(5);
assert a.insert(200);
assert a.insert(500);
assert b.insert(3);
assert b.insert(200);
let mut i = 0;
let expected = [1, 5, 500];
for a.difference(&b) |x| {
assert *x == expected[i];
i += 1
}
assert i == expected.len();
}
#[test]
fn test_bitv_set_symmetric_difference() {
let mut a = BitvSet::new();
let mut b = BitvSet::new();
assert a.insert(1);
assert a.insert(3);
assert a.insert(5);
assert a.insert(9);
assert a.insert(11);
assert b.insert(3);
assert b.insert(9);
assert b.insert(14);
assert b.insert(220);
let mut i = 0;
let expected = [1, 5, 11, 14, 220];
for a.symmetric_difference(&b) |x| {
assert *x == expected[i];
i += 1
}
assert i == expected.len();
}
#[test]
pub fn test_bitv_set_union() {
let mut a = BitvSet::new();
let mut b = BitvSet::new();
assert a.insert(1);
assert a.insert(3);
assert a.insert(5);
assert a.insert(9);
assert a.insert(11);
assert a.insert(160);
assert a.insert(19);
assert a.insert(24);
assert b.insert(1);
assert b.insert(5);
assert b.insert(9);
assert b.insert(13);
assert b.insert(19);
let mut i = 0;
let expected = [1, 3, 5, 9, 11, 13, 19, 24, 160];
for a.union(&b) |x| {
assert *x == expected[i];
i += 1
}
assert i == expected.len();
}
#[test]
pub fn test_bitv_remove() {
let mut a = BitvSet::new();
assert a.insert(1);
assert a.remove(&1);
assert a.insert(100);
assert a.remove(&100);
assert a.insert(1000);
assert a.remove(&1000);
assert a.capacity() == uint::bits;
}
}