diff --git a/src/librustc_data_structures/bitvec.rs b/src/librustc_data_structures/bitvec.rs
index 80cdb0e4417..54565afa4c6 100644
--- a/src/librustc_data_structures/bitvec.rs
+++ b/src/librustc_data_structures/bitvec.rs
@@ -8,19 +8,28 @@
 // option. This file may not be copied, modified, or distributed
 // except according to those terms.

+use std::collections::BTreeMap;
+use std::collections::btree_map::Entry;
+use std::marker::PhantomData;
 use std::iter::FromIterator;
+use indexed_vec::{Idx, IndexVec};
+
+type Word = u128;
+const WORD_BITS: usize = 128;

 /// A very simple BitVector type.
 #[derive(Clone, Debug, PartialEq)]
 pub struct BitVector {
-    data: Vec<u64>,
+    data: Vec<Word>,
 }

 impl BitVector {
     #[inline]
     pub fn new(num_bits: usize) -> BitVector {
-        let num_words = u64s(num_bits);
-        BitVector { data: vec![0; num_words] }
+        let num_words = words(num_bits);
+        BitVector {
+            data: vec![0; num_words],
+        }
     }

     #[inline]
@@ -78,7 +87,7 @@ impl BitVector {

     #[inline]
     pub fn grow(&mut self, num_bits: usize) {
-        let num_words = u64s(num_bits);
+        let num_words = words(num_bits);
         if self.data.len() < num_words {
             self.data.resize(num_words, 0)
         }
@@ -96,8 +105,8 @@ impl BitVector {
 }

 pub struct BitVectorIter<'a> {
-    iter: ::std::slice::Iter<'a, u64>,
-    current: u64,
+    iter: ::std::slice::Iter<'a, Word>,
+    current: Word,
     idx: usize,
 }

@@ -107,10 +116,10 @@ impl<'a> Iterator for BitVectorIter<'a> {
         while self.current == 0 {
             self.current = if let Some(&i) = self.iter.next() {
                 if i == 0 {
-                    self.idx += 64;
+                    self.idx += WORD_BITS;
                     continue;
                 } else {
-                    self.idx = u64s(self.idx) * 64;
+                    self.idx = words(self.idx) * WORD_BITS;
                     i
                 }
             } else {
@@ -126,12 +135,15 @@ impl<'a> Iterator for BitVectorIter<'a> {
 }

 impl FromIterator<bool> for BitVector {
-    fn from_iter<I>(iter: I) -> BitVector where I: IntoIterator<Item=bool> {
+    fn from_iter<I>(iter: I) -> BitVector
+    where
+        I: IntoIterator<Item = bool>,
+    {
         let iter = iter.into_iter();
         let (len, _) = iter.size_hint();
-        // Make the minimum length for the bitvector 64 bits since that's
+        // Make the minimum length for the bitvector WORD_BITS bits since that's
         // the smallest non-zero size anyway.
-        let len = if len < 64 { 64 } else { len };
+        let len = if len < WORD_BITS { WORD_BITS } else { len };
         let mut bv = BitVector::new(len);
         for (idx, val) in iter.enumerate() {
             if idx > len {
@@ -152,32 +164,32 @@ impl FromIterator<bool> for BitVector {
 #[derive(Clone, Debug)]
 pub struct BitMatrix {
     columns: usize,
-    vector: Vec<u64>,
+    vector: Vec<Word>,
 }

 impl BitMatrix {
     /// Create a new `rows x columns` matrix, initially empty.
     pub fn new(rows: usize, columns: usize) -> BitMatrix {
         // For every element, we need one bit for every other
-        // element. Round up to an even number of u64s.
-        let u64s_per_row = u64s(columns);
+        // element. Round up to an even number of words.
+        let words_per_row = words(columns);
         BitMatrix {
             columns,
-            vector: vec![0; rows * u64s_per_row],
+            vector: vec![0; rows * words_per_row],
         }
     }

     /// The range of bits for a given row.
     fn range(&self, row: usize) -> (usize, usize) {
-        let u64s_per_row = u64s(self.columns);
-        let start = row * u64s_per_row;
-        (start, start + u64s_per_row)
+        let words_per_row = words(self.columns);
+        let start = row * words_per_row;
+        (start, start + words_per_row)
     }

     /// Sets the cell at `(row, column)` to true. Put another way, add
     /// `column` to the bitset for `row`.
     ///
-    /// Returns true if this changed the matrix, and false otherwies.
+    /// Returns true if this changed the matrix, and false otherwise.
     pub fn add(&mut self, row: usize, column: usize) -> bool {
         let (start, _) = self.range(row);
         let (word, mask) = word_mask(column);
@@ -208,12 +220,12 @@ impl BitMatrix {
         let mut result = Vec::with_capacity(self.columns);
         for (base, (i, j)) in (a_start..a_end).zip(b_start..b_end).enumerate() {
             let mut v = self.vector[i] & self.vector[j];
-            for bit in 0..64 {
+            for bit in 0..WORD_BITS {
                 if v == 0 {
                     break;
                 }
                 if v & 0x1 != 0 {
-                    result.push(base * 64 + bit);
+                    result.push(base * WORD_BITS + bit);
                 }
                 v >>= 1;
             }
@@ -254,15 +266,214 @@ impl BitMatrix {
     }
 }

-#[inline]
-fn u64s(elements: usize) -> usize {
-    (elements + 63) / 64
+#[derive(Clone, Debug)]
+pub struct SparseBitMatrix<R, C>
+where
+    R: Idx,
+    C: Idx,
+{
+    vector: IndexVec<R, SparseBitSet<C>>,
+}
+
+impl<R: Idx, C: Idx> SparseBitMatrix<R, C> {
+    /// Create a new `rows x columns` matrix, initially empty.
+    pub fn new(rows: R, _columns: C) -> SparseBitMatrix<R, C> {
+        SparseBitMatrix {
+            vector: IndexVec::from_elem_n(SparseBitSet::new(), rows.index()),
+        }
+    }
+
+    /// Sets the cell at `(row, column)` to true. Put another way, insert
+    /// `column` into the bitset for `row`.
+    ///
+    /// Returns true if this changed the matrix, and false otherwise.
+    pub fn add(&mut self, row: R, column: C) -> bool {
+        self.vector[row].insert(column)
+    }
+
+    /// Do the bits from `row` contain `column`? Put another way, is
+    /// the matrix cell at `(row, column)` true? Put yet another way,
+    /// if the matrix represents (transitive) reachability, can
+    /// `row` reach `column`?
+    pub fn contains(&self, row: R, column: C) -> bool {
+        self.vector[row].contains(column)
+    }
+
+    /// Add the bits from row `read` to the bits from row `write`,
+    /// return true if anything changed.
+    ///
+    /// This is used when computing transitive reachability because,
+    /// if there is an edge `write -> read`, then `write` can reach
+    /// everything that `read` can reach (and potentially more), so
+    /// `read`'s bits must be added to `write`'s.
+    pub fn merge(&mut self, read: R, write: R) -> bool {
+        let mut changed = false;
+
+        if read != write {
+            let (bit_set_read, bit_set_write) = self.vector.pick2_mut(read, write);
+
+            for read_val in bit_set_read.iter() {
+                changed = changed | bit_set_write.insert(read_val);
+            }
+        }
+
+        changed
+    }
+
+    /// Iterates through all the columns set to true in a given row of
+    /// the matrix.
+    pub fn iter<'a>(&'a self, row: R) -> impl Iterator<Item = C> + 'a {
+        self.vector[row].iter()
+    }
+}
+
+#[derive(Clone, Debug)]
+pub struct SparseBitSet<I: Idx> {
+    chunk_bits: BTreeMap<u32, Word>,
+    _marker: PhantomData<I>,
+}
+
+#[derive(Copy, Clone)]
+pub struct SparseChunk<I> {
+    key: u32,
+    bits: Word,
+    _marker: PhantomData<I>,
+}
+
+impl<I: Idx> SparseChunk<I> {
+    pub fn one(index: I) -> Self {
+        let index = index.index();
+        let key_usize = index / 128;
+        let key = key_usize as u32;
+        assert_eq!(key as usize, key_usize);
+        SparseChunk {
+            key,
+            bits: 1 << (index % 128),
+            _marker: PhantomData,
+        }
+    }
+
+    pub fn any(&self) -> bool {
+        self.bits != 0
+    }
+
+    pub fn iter(&self) -> impl Iterator<Item = I> {
+        let base = self.key as usize * 128;
+        let mut bits = self.bits;
+        (0..128)
+            .map(move |i| {
+                let current_bits = bits;
+                bits >>= 1;
+                (i, current_bits)
+            })
+            .take_while(|&(_, bits)| bits != 0)
+            .filter_map(move |(i, bits)| {
+                if (bits & 1) != 0 {
+                    Some(I::new(base + i))
+                } else {
+                    None
+                }
+            })
+    }
+}
+
+impl<I: Idx> SparseBitSet<I> {
+    pub fn new() -> Self {
+        SparseBitSet {
+            chunk_bits: BTreeMap::new(),
+            _marker: PhantomData,
+        }
+    }
+
+    pub fn capacity(&self) -> usize {
+        self.chunk_bits.len() * 128
+    }
+
+    pub fn contains_chunk(&self, chunk: SparseChunk<I>) -> SparseChunk<I> {
+        SparseChunk {
+            bits: self.chunk_bits
+                .get(&chunk.key)
+                .map_or(0, |bits| bits & chunk.bits),
+            ..chunk
+        }
+    }
+
+    pub fn insert_chunk(&mut self, chunk: SparseChunk<I>) -> SparseChunk<I> {
+        if chunk.bits == 0 {
+            return chunk;
+        }
+        let bits = self.chunk_bits.entry(chunk.key).or_insert(0);
+        let old_bits = *bits;
+        let new_bits = old_bits | chunk.bits;
+        *bits = new_bits;
+        let changed = new_bits ^ old_bits;
+        SparseChunk {
+            bits: changed,
+            ..chunk
+        }
+    }
+
+    pub fn remove_chunk(&mut self, chunk: SparseChunk<I>) -> SparseChunk<I> {
+        if chunk.bits == 0 {
+            return chunk;
+        }
+        let changed = match self.chunk_bits.entry(chunk.key) {
+            Entry::Occupied(mut bits) => {
+                let old_bits = *bits.get();
+                let new_bits = old_bits & !chunk.bits;
+                if new_bits == 0 {
+                    bits.remove();
+                } else {
+                    bits.insert(new_bits);
+                }
+                new_bits ^ old_bits
+            }
+            Entry::Vacant(_) => 0,
+        };
+        SparseChunk {
+            bits: changed,
+            ..chunk
+        }
+    }
+
+    pub fn clear(&mut self) {
+        self.chunk_bits.clear();
+    }
+
+    pub fn chunks<'a>(&'a self) -> impl Iterator<Item = SparseChunk<I>> + 'a {
+        self.chunk_bits.iter().map(|(&key, &bits)| SparseChunk {
+            key,
+            bits,
+            _marker: PhantomData,
+        })
+    }
+
+    pub fn contains(&self, index: I) -> bool {
+        self.contains_chunk(SparseChunk::one(index)).any()
+    }
+
+    pub fn insert(&mut self, index: I) -> bool {
+        self.insert_chunk(SparseChunk::one(index)).any()
+    }
+
+    pub fn remove(&mut self, index: I) -> bool {
+        self.remove_chunk(SparseChunk::one(index)).any()
+    }
+
+    pub fn iter<'a>(&'a self) -> impl Iterator<Item = I> + 'a {
+        self.chunks().flat_map(|chunk| chunk.iter())
+    }
 }

 #[inline]
-fn word_mask(index: usize) -> (usize, u64) {
-    let word = index / 64;
-    let mask = 1 << (index % 64);
+fn words(elements: usize) -> usize {
+    (elements + WORD_BITS - 1) / WORD_BITS
+}
+
+#[inline]
+fn word_mask(index: usize) -> (usize, Word) {
+    let word = index / WORD_BITS;
+    let mask = 1 << (index % WORD_BITS);
     (word, mask)
 }

@@ -278,11 +489,12 @@ fn bitvec_iter_works() {
     bitvec.insert(65);
     bitvec.insert(66);
     bitvec.insert(99);
-    assert_eq!(bitvec.iter().collect::<Vec<_>>(),
-               [1, 10, 19, 62, 63, 64, 65, 66, 99]);
+    assert_eq!(
+        bitvec.iter().collect::<Vec<_>>(),
+        [1, 10, 19, 62, 63, 64, 65, 66, 99]
+    );
 }

-
 #[test]
 fn bitvec_iter_works_2() {
     let mut bitvec = BitVector::new(319);
@@ -314,24 +526,24 @@ fn union_two_vecs() {
 #[test]
 fn grow() {
     let mut vec1 = BitVector::new(65);
-    for index in 0 .. 65 {
+    for index in 0..65 {
         assert!(vec1.insert(index));
         assert!(!vec1.insert(index));
     }
     vec1.grow(128);

     // Check if the bits set before growing are still set
-    for index in 0 .. 65 {
+    for index in 0..65 {
         assert!(vec1.contains(index));
     }

     // Check if the new bits are all un-set
-    for index in 65 .. 128 {
+    for index in 65..128 {
         assert!(!vec1.contains(index));
     }

     // Check that we can set all new bits without running out of bounds
-    for index in 65 .. 128 {
+    for index in 65..128 {
         assert!(vec1.insert(index));
         assert!(!vec1.insert(index));
     }
diff --git a/src/librustc_data_structures/indexed_vec.rs b/src/librustc_data_structures/indexed_vec.rs
index b11ca107af7..3e94b3f4d30 100644
--- a/src/librustc_data_structures/indexed_vec.rs
+++ b/src/librustc_data_structures/indexed_vec.rs
@@ -482,6 +482,21 @@ impl<I: Idx, T> IndexVec<I, T> {
     pub fn get_mut(&mut self, index: I) -> Option<&mut T> {
         self.raw.get_mut(index.index())
     }
+
+    /// Return mutable references to two distinct elements, `a` and `b`. Panics if `a == b`.
+    #[inline]
+    pub fn pick2_mut(&mut self, a: I, b: I) -> (&mut T, &mut T) {
+        let (ai, bi) = (a.index(), b.index());
+        assert!(ai != bi);
+
+        if ai < bi {
+            let (c1, c2) = self.raw.split_at_mut(bi);
+            (&mut c1[ai], &mut c2[0])
+        } else {
+            let (c2, c1) = self.pick2_mut(b, a);
+            (c1, c2)
+        }
+    }
 }

 impl<I: Idx, T: Clone> IndexVec<I, T> {
diff --git a/src/librustc_mir/borrow_check/nll/region_infer/values.rs b/src/librustc_mir/borrow_check/nll/region_infer/values.rs
index 45236bbc4aa..e6f2a43bfc8 100644
--- a/src/librustc_mir/borrow_check/nll/region_infer/values.rs
+++ b/src/librustc_mir/borrow_check/nll/region_infer/values.rs
@@ -9,7 +9,7 @@
 // except according to those terms.

 use std::rc::Rc;
-use rustc_data_structures::bitvec::BitMatrix;
+use rustc_data_structures::bitvec::SparseBitMatrix;
 use rustc_data_structures::fx::FxHashMap;
 use rustc_data_structures::indexed_vec::Idx;
 use rustc_data_structures::indexed_vec::IndexVec;
@@ -69,9 +69,7 @@ impl RegionValueElements {

     /// Iterates over the `RegionElementIndex` for all points in the CFG.
     pub(super) fn all_point_indices<'a>(&'a self) -> impl Iterator<Item = RegionElementIndex> + 'a {
-        (0..self.num_points).map(move |i| {
-            RegionElementIndex::new(i + self.num_universal_regions)
-        })
+        (0..self.num_points).map(move |i| RegionElementIndex::new(i + self.num_universal_regions))
     }

     /// Iterates over the `RegionElementIndex` for all points in the CFG.
@@ -132,7 +130,7 @@ impl RegionValueElements {
 }

 /// A newtype for the integers that represent one of the possible
-/// elements in a region. These are the rows in the `BitMatrix` that
+/// elements in a region. These are the rows in the `SparseBitMatrix` that
 /// is used to store the values of all regions. They have the following
 /// convention:
 ///
@@ -154,7 +152,6 @@ pub(super) enum RegionElement {
     UniversalRegion(RegionVid),
 }

-
 pub(super) trait ToElementIndex {
     fn to_element_index(self, elements: &RegionValueElements) -> RegionElementIndex;
 }
@@ -184,18 +181,18 @@ impl ToElementIndex for RegionElementIndex {
 }

 /// Stores the values for a set of regions. These are stored in a
-/// compact `BitMatrix` representation, with one row per region
+/// compact `SparseBitMatrix` representation, with one row per region
 /// variable. The columns consist of either universal regions or
 /// points in the CFG.
 #[derive(Clone)]
 pub(super) struct RegionValues {
     elements: Rc<RegionValueElements>,
-    matrix: BitMatrix,
+    matrix: SparseBitMatrix<RegionVid, RegionElementIndex>,

     /// If cause tracking is enabled, maps from a pair (r, e)
     /// consisting of a region `r` that contains some element `e` to
     /// the reason that the element is contained. There should be an
-    /// entry for every bit set to 1 in `BitMatrix`.
+    /// entry for every bit set to 1 in `SparseBitMatrix`.
     causes: Option<CauseMap>,
 }

@@ -214,7 +211,10 @@ impl RegionValues {

         Self {
             elements: elements.clone(),
-            matrix: BitMatrix::new(num_region_variables, elements.num_elements()),
+            matrix: SparseBitMatrix::new(
+                RegionVid::new(num_region_variables),
+                RegionElementIndex::new(elements.num_elements()),
+            ),
             causes: if track_causes.0 {
                 Some(CauseMap::default())
             } else {
@@ -238,7 +238,7 @@ impl RegionValues {
     where
         F: FnOnce(&CauseMap) -> Cause,
     {
-        if self.matrix.add(r.index(), i.index()) {
+        if self.matrix.add(r, i) {
             debug!("add(r={:?}, i={:?})", r, self.elements.to_element(i));

             if let Some(causes) = &mut self.causes {
@@ -289,13 +289,12 @@ impl RegionValues {
         constraint_location: Location,
         constraint_span: Span,
     ) -> bool {
-        // We could optimize this by improving `BitMatrix::merge` so
+        // We could optimize this by improving `SparseBitMatrix::merge` so
         // it does not always merge an entire row. That would
         // complicate causal tracking though.
         debug!(
             "add_universal_regions_outlived_by(from_region={:?}, to_region={:?})",
-            from_region,
-            to_region
+            from_region, to_region
         );
         let mut changed = false;
         for elem in self.elements.all_universal_region_indices() {
@@ -315,7 +314,7 @@ impl RegionValues {
     /// True if the region `r` contains the given element.
     pub(super) fn contains<E: ToElementIndex>(&self, r: RegionVid, elem: E) -> bool {
         let i = self.elements.index(elem);
-        self.matrix.contains(r.index(), i.index())
+        self.matrix.contains(r, i)
     }

     /// Iterate over the value of the region `r`, yielding up element
@@ -325,9 +324,7 @@ impl RegionValues {
         &'a self,
         r: RegionVid,
     ) -> impl Iterator<Item = RegionElementIndex> + 'a {
-        self.matrix
-            .iter(r.index())
-            .map(move |i| RegionElementIndex::new(i))
+        self.matrix.iter(r).map(move |i| i)
     }

     /// Returns just the universal regions that are contained in a given region's value.
@@ -415,9 +412,7 @@ impl RegionValues {
             assert_eq!(location1.block, location2.block);
             str.push_str(&format!(
                 "{:?}[{}..={}]",
-                location1.block,
-                location1.statement_index,
-                location2.statement_index
+                location1.block, location1.statement_index, location2.statement_index
             ));
         }
     }
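
For reviewers: below is a rough usage sketch of the `SparseBitMatrix` API added above. It is illustrative only and not part of the patch; the `RowIdx`/`ColIdx` newtypes, their hand-written `Idx` impls, and the assumption that `rustc_data_structures` can be used as a dependency are all hypothetical (in-tree callers such as `RegionValues` use index types like `RegionVid` instead), and the exact trait bounds required by `Idx` may differ from what is derived here.

```rust
// Illustration only -- not part of this diff.
use rustc_data_structures::bitvec::SparseBitMatrix;
use rustc_data_structures::indexed_vec::Idx;

// Hypothetical index newtypes standing in for the usual rustc index types.
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
struct RowIdx(usize);

#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
struct ColIdx(usize);

impl Idx for RowIdx {
    fn new(i: usize) -> Self { RowIdx(i) }
    fn index(self) -> usize { self.0 }
}

impl Idx for ColIdx {
    fn new(i: usize) -> Self { ColIdx(i) }
    fn index(self) -> usize { self.0 }
}

fn main() {
    // Four rows; the column argument is ignored because rows are stored
    // sparsely, so large column indices cost nothing until they are set.
    let mut matrix: SparseBitMatrix<RowIdx, ColIdx> =
        SparseBitMatrix::new(RowIdx::new(4), ColIdx::new(0));

    // `add` reports whether the bit actually changed.
    assert!(matrix.add(RowIdx::new(0), ColIdx::new(100)));
    assert!(!matrix.add(RowIdx::new(0), ColIdx::new(100)));

    // `merge(read, write)` ORs row `read` into row `write` (exercising
    // `pick2_mut` internally) and reports whether `write` grew.
    assert!(matrix.merge(RowIdx::new(0), RowIdx::new(1)));
    assert!(matrix.contains(RowIdx::new(1), ColIdx::new(100)));

    // Iterate the columns set in a row.
    let cols: Vec<ColIdx> = matrix.iter(RowIdx::new(1)).collect();
    assert_eq!(cols, vec![ColIdx::new(100)]);
}
```

The underlying `SparseBitSet` keyed each row's storage by `index / 128` in a `BTreeMap<u32, Word>` with one 128-bit `Word` per chunk, so a row that touches only a handful of columns allocates only the chunks it needs, unlike the dense `BitMatrix` it replaces in `RegionValues`.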