Document BinaryHeap unsafe functions

This commit is contained in:
Giacomo Stevanato 2021-02-03 12:45:55 +01:00 committed by Mark Rousskov
parent e7c23ab933
commit 9b4e61255c
1 changed files with 112 additions and 48 deletions

View File

@ -275,7 +275,8 @@ impl<T: Ord + fmt::Debug> fmt::Debug for PeekMut<'_, T> {
impl<T: Ord> Drop for PeekMut<'_, T> {
fn drop(&mut self) {
if self.sift {
self.heap.sift_down(0);
// SAFETY: PeekMut is only instantiated for non-empty heaps.
unsafe { self.heap.sift_down(0) };
}
}
}
@ -431,7 +432,8 @@ impl<T: Ord> BinaryHeap<T> {
self.data.pop().map(|mut item| {
if !self.is_empty() {
swap(&mut item, &mut self.data[0]);
self.sift_down_to_bottom(0);
// SAFETY: !self.is_empty() means that self.len() > 0
unsafe { self.sift_down_to_bottom(0) };
}
item
})
@ -473,7 +475,9 @@ impl<T: Ord> BinaryHeap<T> {
pub fn push(&mut self, item: T) {
let old_len = self.len();
self.data.push(item);
self.sift_up(0, old_len);
// SAFETY: Since we pushed a new item it means that
// old_len = self.len() - 1 < self.len()
unsafe { self.sift_up(0, old_len) };
}
/// Consumes the `BinaryHeap` and returns a vector in sorted
@ -506,7 +510,10 @@ impl<T: Ord> BinaryHeap<T> {
let ptr = self.data.as_mut_ptr();
ptr::swap(ptr, ptr.add(end));
}
self.sift_down_range(0, end);
// SAFETY: `end` goes from `self.len() - 1` to 1 (both included) so:
// 0 < 1 <= end <= self.len() - 1 < self.len()
// Which means 0 < end and end < self.len().
unsafe { self.sift_down_range(0, end) };
}
self.into_vec()
}
@ -519,47 +526,82 @@ impl<T: Ord> BinaryHeap<T> {
// the hole is filled back at the end of its scope, even on panic.
// Using a hole reduces the constant factor compared to using swaps,
// which involves twice as many moves.
fn sift_up(&mut self, start: usize, pos: usize) -> usize {
unsafe {
// Take out the value at `pos` and create a hole.
let mut hole = Hole::new(&mut self.data, pos);
while hole.pos() > start {
let parent = (hole.pos() - 1) / 2;
if hole.element() <= hole.get(parent) {
break;
}
hole.move_to(parent);
/// # Safety
///
/// The caller must guarantee that `pos < self.len()`.
unsafe fn sift_up(&mut self, start: usize, pos: usize) -> usize {
// Take out the value at `pos` and create a hole.
// SAFETY: The caller guarantees that pos < self.len()
let mut hole = unsafe { Hole::new(&mut self.data, pos) };
while hole.pos() > start {
let parent = (hole.pos() - 1) / 2;
// SAFETY: hole.pos() > start >= 0, which means hole.pos() > 0
// and so hole.pos() - 1 can't underflow.
// This guarantees that parent < hole.pos() so
// it's a valid index and also != hole.pos().
if hole.element() <= unsafe { hole.get(parent) } {
break;
}
hole.pos()
// SAFETY: Same as above
unsafe { hole.move_to(parent) };
}
hole.pos()
}
/// Take an element at `pos` and move it down the heap,
/// while its children are larger.
fn sift_down_range(&mut self, pos: usize, end: usize) {
unsafe {
let mut hole = Hole::new(&mut self.data, pos);
let mut child = 2 * pos + 1;
while child < end - 1 {
// compare with the greater of the two children
child += (hole.get(child) <= hole.get(child + 1)) as usize;
// if we are already in order, stop.
if hole.element() >= hole.get(child) {
return;
}
hole.move_to(child);
child = 2 * hole.pos() + 1;
}
if child == end - 1 && hole.element() < hole.get(child) {
hole.move_to(child);
///
/// # Safety
///
/// The caller must guarantee that `pos < end <= self.len()`.
unsafe fn sift_down_range(&mut self, pos: usize, end: usize) {
// SAFETY: The caller guarantees that pos < end <= self.len().
let mut hole = unsafe { Hole::new(&mut self.data, pos) };
let mut child = 2 * hole.pos() + 1;
// Loop invariant: child == 2 * hole.pos() + 1.
while child < end - 1 {
// compare with the greater of the two children
// SAFETY: child < end - 1 < self.len() and
// child + 1 < end <= self.len(), so they're valid indexes.
// child == 2 * hole.pos() + 1 != hole.pos() and
// child + 1 == 2 * hole.pos() + 2 != hole.pos().
child += unsafe { hole.get(child) <= hole.get(child + 1) } as usize;
// if we are already in order, stop.
// SAFETY: child is now either the old child or the old child+1
// We already proven that both are < self.len() and != hole.pos()
if hole.element() >= unsafe { hole.get(child) } {
return;
}
// SAFETY: same as above.
unsafe { hole.move_to(child) };
child = 2 * hole.pos() + 1;
}
// SAFETY: && short circuit, which means that in the
// second condition it's already true that child == end - 1 < self.len().
if child == end - 1 && hole.element() < unsafe { hole.get(child) } {
// SAFETY: child is already proven to be a valid index and
// child == 2 * hole.pos() + 1 != hole.pos().
unsafe { hole.move_to(child) };
}
}
fn sift_down(&mut self, pos: usize) {
/// # Safety
///
/// The caller must guarantee that `pos < self.len()`.
unsafe fn sift_down(&mut self, pos: usize) {
let len = self.len();
self.sift_down_range(pos, len);
// SAFETY: pos < len is guaranteed by the caller and
// obviously len = self.len() <= self.len().
unsafe { self.sift_down_range(pos, len) };
}
/// Take an element at `pos` and move it all the way down the heap,
@ -567,30 +609,52 @@ impl<T: Ord> BinaryHeap<T> {
///
/// Note: This is faster when the element is known to be large / should
/// be closer to the bottom.
fn sift_down_to_bottom(&mut self, mut pos: usize) {
///
/// # Safety
///
/// The caller must guarantee that `pos < self.len()`.
unsafe fn sift_down_to_bottom(&mut self, mut pos: usize) {
let end = self.len();
let start = pos;
unsafe {
let mut hole = Hole::new(&mut self.data, pos);
let mut child = 2 * pos + 1;
while child < end - 1 {
child += (hole.get(child) <= hole.get(child + 1)) as usize;
hole.move_to(child);
child = 2 * hole.pos() + 1;
}
if child == end - 1 {
hole.move_to(child);
}
pos = hole.pos;
// SAFETY: The caller guarantees that pos < self.len().
let mut hole = unsafe { Hole::new(&mut self.data, pos) };
let mut child = 2 * hole.pos() + 1;
// Loop invariant: child == 2 * hole.pos() + 1.
while child < end - 1 {
// SAFETY: child < end - 1 < self.len() and
// child + 1 < end <= self.len(), so they're valid indexes.
// child == 2 * hole.pos() + 1 != hole.pos() and
// child + 1 == 2 * hole.pos() + 2 != hole.pos().
child += unsafe { hole.get(child) <= hole.get(child + 1) } as usize;
// SAFETY: Same as above
unsafe { hole.move_to(child) };
child = 2 * hole.pos() + 1;
}
self.sift_up(start, pos);
if child == end - 1 {
// SAFETY: child == end - 1 < self.len(), so it's a valid index
// and child == 2 * hole.pos() + 1 != hole.pos().
unsafe { hole.move_to(child) };
}
pos = hole.pos();
drop(hole);
// SAFETY: pos is the position in the hole and was already proven
// to be a valid index.
unsafe { self.sift_up(start, pos) };
}
fn rebuild(&mut self) {
let mut n = self.len() / 2;
while n > 0 {
n -= 1;
self.sift_down(n);
// SAFETY: n starts from self.len() / 2 and goes down to 0.
// The only case when !(n < self.len()) is if
// self.len() == 0, but it's ruled out by the loop condition.
unsafe { self.sift_down(n) };
}
}