diff --git a/library/std/src/collections/hash/map/tests.rs b/library/std/src/collections/hash/map/tests.rs index 4283d80b78e..467968354e2 100644 --- a/library/std/src/collections/hash/map/tests.rs +++ b/library/std/src/collections/hash/map/tests.rs @@ -924,3 +924,164 @@ fn test_raw_entry() { } } } + +mod test_drain_filter { + use super::*; + + use crate::panic::{catch_unwind, AssertUnwindSafe}; + use crate::sync::atomic::{AtomicUsize, Ordering}; + + trait EqSorted: Iterator { + fn eq_sorted>(self, other: I) -> bool; + } + + impl EqSorted for T + where + T::Item: Eq + Ord, + { + fn eq_sorted>(self, other: I) -> bool { + let mut v: Vec<_> = self.collect(); + v.sort_unstable(); + v.into_iter().eq(other) + } + } + + #[test] + fn empty() { + let mut map: HashMap = HashMap::new(); + map.drain_filter(|_, _| unreachable!("there's nothing to decide on")); + assert!(map.is_empty()); + } + + #[test] + fn consuming_nothing() { + let pairs = (0..3).map(|i| (i, i)); + let mut map: HashMap<_, _> = pairs.collect(); + assert!(map.drain_filter(|_, _| false).eq_sorted(crate::iter::empty())); + assert_eq!(map.len(), 3); + } + + #[test] + fn consuming_all() { + let pairs = (0..3).map(|i| (i, i)); + let mut map: HashMap<_, _> = pairs.clone().collect(); + assert!(map.drain_filter(|_, _| true).eq_sorted(pairs)); + assert!(map.is_empty()); + } + + #[test] + fn mutating_and_keeping() { + let pairs = (0..3).map(|i| (i, i)); + let mut map: HashMap<_, _> = pairs.collect(); + assert!( + map.drain_filter(|_, v| { + *v += 6; + false + }) + .eq_sorted(crate::iter::empty()) + ); + assert!(map.keys().copied().eq_sorted(0..3)); + assert!(map.values().copied().eq_sorted(6..9)); + } + + #[test] + fn mutating_and_removing() { + let pairs = (0..3).map(|i| (i, i)); + let mut map: HashMap<_, _> = pairs.collect(); + assert!( + map.drain_filter(|_, v| { + *v += 6; + true + }) + .eq_sorted((0..3).map(|i| (i, i + 6))) + ); + assert!(map.is_empty()); + } + + #[test] + fn drop_panic_leak() { + static PREDS: AtomicUsize = AtomicUsize::new(0); + static DROPS: AtomicUsize = AtomicUsize::new(0); + + struct D; + impl Drop for D { + fn drop(&mut self) { + if DROPS.fetch_add(1, Ordering::SeqCst) == 1 { + panic!("panic in `drop`"); + } + } + } + + let mut map = (0..3).map(|i| (i, D)).collect::>(); + + catch_unwind(move || { + drop(map.drain_filter(|_, _| { + PREDS.fetch_add(1, Ordering::SeqCst); + true + })) + }) + .unwrap_err(); + + assert_eq!(PREDS.load(Ordering::SeqCst), 3); + assert_eq!(DROPS.load(Ordering::SeqCst), 3); + } + + #[test] + fn pred_panic_leak() { + static PREDS: AtomicUsize = AtomicUsize::new(0); + static DROPS: AtomicUsize = AtomicUsize::new(0); + + struct D; + impl Drop for D { + fn drop(&mut self) { + DROPS.fetch_add(1, Ordering::SeqCst); + } + } + + let mut map = (0..3).map(|i| (i, D)).collect::>(); + + catch_unwind(AssertUnwindSafe(|| { + drop(map.drain_filter(|_, _| match PREDS.fetch_add(1, Ordering::SeqCst) { + 0 => true, + _ => panic!(), + })) + })) + .unwrap_err(); + + assert_eq!(PREDS.load(Ordering::SeqCst), 2); + assert_eq!(DROPS.load(Ordering::SeqCst), 1); + assert_eq!(map.len(), 2); + } + + // Same as above, but attempt to use the iterator again after the panic in the predicate + #[test] + fn pred_panic_reuse() { + static PREDS: AtomicUsize = AtomicUsize::new(0); + static DROPS: AtomicUsize = AtomicUsize::new(0); + + struct D; + impl Drop for D { + fn drop(&mut self) { + DROPS.fetch_add(1, Ordering::SeqCst); + } + } + + let mut map = (0..3).map(|i| (i, D)).collect::>(); + + { + let mut it = map.drain_filter(|_, _| match PREDS.fetch_add(1, Ordering::SeqCst) { + 0 => true, + _ => panic!(), + }); + catch_unwind(AssertUnwindSafe(|| while it.next().is_some() {})).unwrap_err(); + // Iterator behaviour after a panic is explicitly unspecified, + // so this is just the current implementation: + let result = catch_unwind(AssertUnwindSafe(|| it.next())); + assert!(result.is_err()); + } + + assert_eq!(PREDS.load(Ordering::SeqCst), 3); + assert_eq!(DROPS.load(Ordering::SeqCst), 1); + assert_eq!(map.len(), 2); + } +} diff --git a/library/std/src/collections/hash/set/tests.rs b/library/std/src/collections/hash/set/tests.rs index 3582390cee4..40f8467fd93 100644 --- a/library/std/src/collections/hash/set/tests.rs +++ b/library/std/src/collections/hash/set/tests.rs @@ -1,6 +1,9 @@ use super::super::map::RandomState; use super::HashSet; +use crate::panic::{catch_unwind, AssertUnwindSafe}; +use crate::sync::atomic::{AtomicU32, Ordering}; + #[test] fn test_zero_capacities() { type HS = HashSet; @@ -413,3 +416,71 @@ fn test_retain() { assert!(set.contains(&4)); assert!(set.contains(&6)); } + +#[test] +fn test_drain_filter() { + let mut x: HashSet<_> = [1].iter().copied().collect(); + let mut y: HashSet<_> = [1].iter().copied().collect(); + + x.drain_filter(|_| true); + y.drain_filter(|_| false); + assert_eq!(x.len(), 0); + assert_eq!(y.len(), 1); +} + +#[test] +fn test_drain_filter_drop_panic_leak() { + static PREDS: AtomicU32 = AtomicU32::new(0); + static DROPS: AtomicU32 = AtomicU32::new(0); + + #[derive(PartialEq, Eq, PartialOrd, Hash)] + struct D(i32); + impl Drop for D { + fn drop(&mut self) { + if DROPS.fetch_add(1, Ordering::SeqCst) == 1 { + panic!("panic in `drop`"); + } + } + } + + let mut set = (0..3).map(|i| D(i)).collect::>(); + + catch_unwind(move || { + drop(set.drain_filter(|_| { + PREDS.fetch_add(1, Ordering::SeqCst); + true + })) + }) + .ok(); + + assert_eq!(PREDS.load(Ordering::SeqCst), 3); + assert_eq!(DROPS.load(Ordering::SeqCst), 3); +} + +#[test] +fn test_drain_filter_pred_panic_leak() { + static PREDS: AtomicU32 = AtomicU32::new(0); + static DROPS: AtomicU32 = AtomicU32::new(0); + + #[derive(PartialEq, Eq, PartialOrd, Hash)] + struct D; + impl Drop for D { + fn drop(&mut self) { + DROPS.fetch_add(1, Ordering::SeqCst); + } + } + + let mut set: HashSet<_> = (0..3).map(|_| D).collect(); + + catch_unwind(AssertUnwindSafe(|| { + drop(set.drain_filter(|_| match PREDS.fetch_add(1, Ordering::SeqCst) { + 0 => true, + _ => panic!(), + })) + })) + .ok(); + + assert_eq!(PREDS.load(Ordering::SeqCst), 1); + assert_eq!(DROPS.load(Ordering::SeqCst), 3); + assert_eq!(set.len(), 0); +}