diff --git a/library/alloc/src/collections/btree/node.rs b/library/alloc/src/collections/btree/node.rs index 1346ad19fe2..ce4a7e1bdd7 100644 --- a/library/alloc/src/collections/btree/node.rs +++ b/library/alloc/src/collections/btree/node.rs @@ -316,7 +316,9 @@ impl NodeRef { /// Note that, despite being safe, calling this function can have the side effect /// of invalidating mutable references that unsafe code has created. pub fn len(&self) -> usize { - self.as_leaf().len as usize + // Crucially, we only access the `len` field here. There might be outstanding mutable references + // to keys/values that we must not invalidate. + unsafe { (*self.as_leaf()).len as usize } } /// Returns the height of this node in the whole tree. Zero height denotes the @@ -334,11 +336,14 @@ impl NodeRef { /// If the node is a leaf, this function simply opens up its data. /// If the node is an internal node, so not a leaf, it does have all the data a leaf has /// (header, keys and values), and this function exposes that. - fn as_leaf(&self) -> &LeafNode { + /// + /// Returns a raw ptr to avoid invalidating other references to this node + /// (such as during iteration). + fn as_leaf(&self) -> *const LeafNode { // The node must be valid for at least the LeafNode portion. // This is not a reference in the NodeRef type because we don't know if // it should be unique or shared. - unsafe { self.node.as_ref() } + self.node.as_ptr() } /// Borrows a view into the keys stored in the node. @@ -361,7 +366,7 @@ impl NodeRef { pub fn ascend( self, ) -> Result, marker::Edge>, Self> { - let parent_as_leaf = self.as_leaf().parent as *const LeafNode; + let parent_as_leaf = unsafe { (*self.as_leaf()).parent as *const LeafNode }; if let Some(non_zero) = NonNull::new(parent_as_leaf as *mut _) { Ok(Handle { node: NodeRef { @@ -370,7 +375,7 @@ impl NodeRef { root: self.root, _marker: PhantomData, }, - idx: unsafe { usize::from(*self.as_leaf().parent_idx.as_ptr()) }, + idx: unsafe { usize::from(*(*self.as_leaf()).parent_idx.as_ptr()) }, _marker: PhantomData, }) } else { @@ -475,13 +480,13 @@ impl<'a, K, V, Type> NodeRef, K, V, Type> { impl<'a, K: 'a, V: 'a, Type> NodeRef, K, V, Type> { fn into_key_slice(self) -> &'a [K] { unsafe { - slice::from_raw_parts(MaybeUninit::slice_as_ptr(&self.as_leaf().keys), self.len()) + slice::from_raw_parts(MaybeUninit::slice_as_ptr(&(*self.as_leaf()).keys), self.len()) } } fn into_val_slice(self) -> &'a [V] { unsafe { - slice::from_raw_parts(MaybeUninit::slice_as_ptr(&self.as_leaf().vals), self.len()) + slice::from_raw_parts(MaybeUninit::slice_as_ptr(&(*self.as_leaf()).vals), self.len()) } } }