diff --git a/src/librustc/hir/mod.rs b/src/librustc/hir/mod.rs index a3a133daa09..cc0d49c1a36 100644 --- a/src/librustc/hir/mod.rs +++ b/src/librustc/hir/mod.rs @@ -684,6 +684,16 @@ pub enum Mutability { MutImmutable, } +impl Mutability { + /// Return MutMutable only if both arguments are mutable. + pub fn and(self, other: Self) -> Self { + match self { + MutMutable => other, + MutImmutable => MutImmutable, + } + } +} + #[derive(Clone, PartialEq, Eq, RustcEncodable, RustcDecodable, Hash, Debug, Copy)] pub enum BinOp_ { /// The `+` operator (addition) diff --git a/src/librustc/ich/impls_mir.rs b/src/librustc/ich/impls_mir.rs index eb0c62a1161..dc41f981ed5 100644 --- a/src/librustc/ich/impls_mir.rs +++ b/src/librustc/ich/impls_mir.rs @@ -243,6 +243,8 @@ for mir::StatementKind<'tcx> { } } +impl_stable_hash_for!(struct mir::ValidationOperand<'tcx> { lval, ty, re, mutbl }); + impl_stable_hash_for!(enum mir::ValidationOp { Acquire, Release, Suspend(extent) }); impl<'a, 'gcx, 'tcx> HashStable> for mir::Lvalue<'tcx> { diff --git a/src/librustc/mir/mod.rs b/src/librustc/mir/mod.rs index dcab476ec23..4655f8a9c15 100644 --- a/src/librustc/mir/mod.rs +++ b/src/librustc/mir/mod.rs @@ -25,7 +25,7 @@ use ty::{self, AdtDef, ClosureSubsts, Region, Ty}; use ty::fold::{TypeFoldable, TypeFolder, TypeVisitor}; use util::ppaux; use rustc_back::slice; -use hir::InlineAsm; +use hir::{self, InlineAsm}; use std::ascii; use std::borrow::{Cow}; use std::cell::Ref; @@ -826,7 +826,7 @@ pub enum StatementKind<'tcx> { }, /// Assert the given lvalues to be valid inhabitants of their type. - Validate(ValidationOp, Vec<(Ty<'tcx>, Lvalue<'tcx>)>), + Validate(ValidationOp, Vec>), /// Mark one terminating point of an extent (i.e. static region). /// (The starting point(s) arise implicitly from borrows.) @@ -855,6 +855,28 @@ impl Debug for ValidationOp { } } +#[derive(Clone, RustcEncodable, RustcDecodable)] +pub struct ValidationOperand<'tcx> { + pub lval: Lvalue<'tcx>, + pub ty: Ty<'tcx>, + pub re: Option, + pub mutbl: hir::Mutability, +} + +impl<'tcx> Debug for ValidationOperand<'tcx> { + fn fmt(&self, fmt: &mut Formatter) -> fmt::Result { + write!(fmt, "{:?}@{:?}", self.lval, self.ty)?; + if let Some(ce) = self.re { + // (reuse lifetime rendering policy from ppaux.) + write!(fmt, "/{}", ty::ReScope(ce))?; + } + if let hir::MutImmutable = self.mutbl { + write!(fmt, " (imm)")?; + } + Ok(()) + } +} + impl<'tcx> Debug for Statement<'tcx> { fn fmt(&self, fmt: &mut Formatter) -> fmt::Result { use self::StatementKind::*; @@ -1505,6 +1527,21 @@ impl<'tcx> TypeFoldable<'tcx> for BasicBlockData<'tcx> { } } +impl<'tcx> TypeFoldable<'tcx> for ValidationOperand<'tcx> { + fn super_fold_with<'gcx: 'tcx, F: TypeFolder<'gcx, 'tcx>>(&self, folder: &mut F) -> Self { + ValidationOperand { + lval: self.lval.fold_with(folder), + ty: self.ty.fold_with(folder), + re: self.re, + mutbl: self.mutbl, + } + } + + fn super_visit_with>(&self, visitor: &mut V) -> bool { + self.lval.visit_with(visitor) || self.ty.visit_with(visitor) + } +} + impl<'tcx> TypeFoldable<'tcx> for Statement<'tcx> { fn super_fold_with<'gcx: 'tcx, F: TypeFolder<'gcx, 'tcx>>(&self, folder: &mut F) -> Self { use mir::StatementKind::*; @@ -1531,7 +1568,7 @@ impl<'tcx> TypeFoldable<'tcx> for Statement<'tcx> { Validate(ref op, ref lvals) => Validate(op.clone(), - lvals.iter().map(|ty_and_lval| ty_and_lval.fold_with(folder)).collect()), + lvals.iter().map(|operand| operand.fold_with(folder)).collect()), Nop => Nop, }; diff --git a/src/librustc/mir/visit.rs b/src/librustc/mir/visit.rs index 5284a613239..a05007503ce 100644 --- a/src/librustc/mir/visit.rs +++ b/src/librustc/mir/visit.rs @@ -334,9 +334,10 @@ macro_rules! make_mir_visitor { } StatementKind::EndRegion(_) => {} StatementKind::Validate(_, ref $($mutability)* lvalues) => { - for & $($mutability)* (ref $($mutability)* ty, ref $($mutability)* lvalue) in lvalues { - self.visit_ty(ty, Lookup::Loc(location)); - self.visit_lvalue(lvalue, LvalueContext::Validate, location); + for operand in lvalues { + self.visit_lvalue(& $($mutability)* operand.lval, + LvalueContext::Validate, location); + self.visit_ty(& $($mutability)* operand.ty, Lookup::Loc(location)); } } StatementKind::SetDiscriminant{ ref $($mutability)* lvalue, .. } => { diff --git a/src/librustc_mir/transform/add_validation.rs b/src/librustc_mir/transform/add_validation.rs index b79c1a2d6fd..1fe16fb98f2 100644 --- a/src/librustc_mir/transform/add_validation.rs +++ b/src/librustc_mir/transform/add_validation.rs @@ -14,34 +14,67 @@ //! of MIR building, and only after this pass we think of the program has having the //! normal MIR semantics. -use rustc::ty::{TyCtxt, RegionKind}; +use rustc::ty::{self, TyCtxt, RegionKind}; +use rustc::hir; use rustc::mir::*; use rustc::mir::transform::{MirPass, MirSource}; +use rustc::middle::region::CodeExtent; pub struct AddValidation; - -fn is_lvalue_shared<'a, 'tcx, D>(lval: &Lvalue<'tcx>, local_decls: &D, tcx: TyCtxt<'a, 'tcx, 'tcx>) -> bool +/// Determine the "context" of the lval: Mutability and region. +fn lval_context<'a, 'tcx, D>( + lval: &Lvalue<'tcx>, + local_decls: &D, + tcx: TyCtxt<'a, 'tcx, 'tcx> +) -> (Option, hir::Mutability) where D: HasLocalDecls<'tcx> { use rustc::mir::Lvalue::*; match *lval { - Local { .. } => false, - Static(_) => true, + Local { .. } => (None, hir::MutMutable), + Static(_) => (None, hir::MutImmutable), Projection(ref proj) => { - // If the base is shared, things stay shared - if is_lvalue_shared(&proj.base, local_decls, tcx) { - return true; - } - // A Deref projection may make things shared match proj.elem { ProjectionElem::Deref => { - // Computing the inside the recursion makes this quadratic. We don't expect deep paths though. + // Computing the inside the recursion makes this quadratic. + // We don't expect deep paths though. let ty = proj.base.ty(local_decls, tcx).to_ty(tcx); - !ty.is_mutable_pointer() + // A Deref projection may restrict the context, this depends on the type + // being deref'd. + let context = match ty.sty { + ty::TyRef(re, tam) => { + let re = match re { + &RegionKind::ReScope(ce) => Some(ce), + &RegionKind::ReErased => + bug!("AddValidation pass must be run before erasing lifetimes"), + _ => None + }; + (re, tam.mutbl) + } + ty::TyRawPtr(_) => + // There is no guarantee behind even a mutable raw pointer, + // no write locks are acquired there, so we also don't want to + // release any. + (None, hir::MutImmutable), + ty::TyAdt(adt, _) if adt.is_box() => (None, hir::MutMutable), + _ => bug!("Deref on a non-pointer type {:?}", ty), + }; + // "Intersect" this restriction with proj.base. + if let (Some(_), hir::MutImmutable) = context { + // This is already as restricted as it gets, no need to even recurse + context + } else { + let base_context = lval_context(&proj.base, local_decls, tcx); + // The region of the outermost Deref is always most restrictive. + let re = context.0.or(base_context.0); + let mutbl = context.1.and(base_context.1); + (re, mutbl) + } + } - _ => false, + _ => lval_context(&proj.base, local_decls, tcx), } } } @@ -52,41 +85,49 @@ impl MirPass for AddValidation { tcx: TyCtxt<'a, 'tcx, 'tcx>, _: MirSource, mir: &mut Mir<'tcx>) { + let local_decls = mir.local_decls.clone(); // TODO: Find a way to get rid of this clone. + + /// Convert an lvalue to a validation operand. + let lval_to_operand = |lval: Lvalue<'tcx>| -> ValidationOperand<'tcx> { + let (re, mutbl) = lval_context(&lval, &local_decls, tcx); + let ty = lval.ty(&local_decls, tcx).to_ty(tcx); + ValidationOperand { lval, ty, re, mutbl } + }; + // PART 1 // Add an AcquireValid at the beginning of the start block. if mir.arg_count > 0 { let acquire_stmt = Statement { source_info: SourceInfo { scope: ARGUMENT_VISIBILITY_SCOPE, - span: mir.span, // TODO: Consider using just the span covering the function argument declaration + span: mir.span, // TODO: Consider using just the span covering the function + // argument declaration. }, kind: StatementKind::Validate(ValidationOp::Acquire, // Skip return value, go over all the arguments mir.local_decls.iter_enumerated().skip(1).take(mir.arg_count) - .map(|(local, local_decl)| (local_decl.ty, Lvalue::Local(local))).collect() + .map(|(local, _)| lval_to_operand(Lvalue::Local(local))).collect() ) }; mir.basic_blocks_mut()[START_BLOCK].statements.insert(0, acquire_stmt); } // PART 2 - // Add ReleaseValid/AcquireValid around function call terminators. We don't use a visitor because - // we need to access the block that a Call jumps to. - let mut returns : Vec<(SourceInfo, Lvalue<'tcx>, BasicBlock)> = Vec::new(); // Here we collect the destinations. - let local_decls = mir.local_decls.clone(); // TODO: Find a way to get rid of this clone. + // Add ReleaseValid/AcquireValid around function call terminators. We don't use a visitor + // because we need to access the block that a Call jumps to. + let mut returns : Vec<(SourceInfo, Lvalue<'tcx>, BasicBlock)> = Vec::new(); for block_data in mir.basic_blocks_mut() { match block_data.terminator { - Some(Terminator { kind: TerminatorKind::Call { ref args, ref destination, .. }, source_info }) => { + Some(Terminator { kind: TerminatorKind::Call { ref args, ref destination, .. }, + source_info }) => { // Before the call: Release all arguments let release_stmt = Statement { source_info, kind: StatementKind::Validate(ValidationOp::Release, args.iter().filter_map(|op| { match op { - &Operand::Consume(ref lval) => { - let ty = lval.ty(&local_decls, tcx).to_ty(tcx); - Some((ty, lval.clone())) - }, + &Operand::Consume(ref lval) => + Some(lval_to_operand(lval.clone())), &Operand::Constant(..) => { None }, } }).collect()) @@ -97,13 +138,15 @@ impl MirPass for AddValidation { returns.push((source_info, destination.0.clone(), destination.1)); } } - Some(Terminator { kind: TerminatorKind::Drop { location: ref lval, .. }, source_info }) | - Some(Terminator { kind: TerminatorKind::DropAndReplace { location: ref lval, .. }, source_info }) => { + Some(Terminator { kind: TerminatorKind::Drop { location: ref lval, .. }, + source_info }) | + Some(Terminator { kind: TerminatorKind::DropAndReplace { location: ref lval, .. }, + source_info }) => { // Before the call: Release all arguments - let ty = lval.ty(&local_decls, tcx).to_ty(tcx); let release_stmt = Statement { source_info, - kind: StatementKind::Validate(ValidationOp::Release, vec![(ty, lval.clone())]) + kind: StatementKind::Validate(ValidationOp::Release, + vec![lval_to_operand(lval.clone())]), }; block_data.statements.push(release_stmt); // drop doesn't return anything, so we need no acquire. @@ -115,20 +158,20 @@ impl MirPass for AddValidation { } // Now we go over the returns we collected to acquire the return values. for (source_info, dest_lval, dest_block) in returns { - let ty = dest_lval.ty(&local_decls, tcx).to_ty(tcx); let acquire_stmt = Statement { source_info, - kind: StatementKind::Validate(ValidationOp::Acquire, vec![(ty, dest_lval)]) + kind: StatementKind::Validate(ValidationOp::Acquire, + vec![lval_to_operand(dest_lval)]), }; mir.basic_blocks_mut()[dest_block].statements.insert(0, acquire_stmt); } // PART 3 - // Add ReleaseValid/AcquireValid around Ref. Again an iterator does not seem very suited as - // we need to add new statements before and after each Ref. + // Add ReleaseValid/AcquireValid around Ref. Again an iterator does not seem very suited + // as we need to add new statements before and after each Ref. for block_data in mir.basic_blocks_mut() { - // We want to insert statements around Ref commands as we iterate. To this end, we iterate backwards - // using indices. + // We want to insert statements around Ref commands as we iterate. To this end, we + // iterate backwards using indices. for i in (0..block_data.statements.len()).rev() { let (dest_lval, re, src_lval) = match block_data.statements[i].kind { StatementKind::Assign(ref dest_lval, Rvalue::Ref(re, _, ref src_lval)) => { @@ -137,27 +180,25 @@ impl MirPass for AddValidation { _ => continue, }; // So this is a ref, and we got all the data we wanted. - let dest_ty = dest_lval.ty(&local_decls, tcx).to_ty(tcx); let acquire_stmt = Statement { source_info: block_data.statements[i].source_info, - kind: StatementKind::Validate(ValidationOp::Acquire, vec![(dest_ty, dest_lval)]), + kind: StatementKind::Validate(ValidationOp::Acquire, + vec![lval_to_operand(dest_lval)]), }; block_data.statements.insert(i+1, acquire_stmt); - // The source is released until the region of the borrow ends -- but not if it is shared. - if !is_lvalue_shared(&src_lval, &local_decls, tcx) { - let src_ty = src_lval.ty(&local_decls, tcx).to_ty(tcx); - let op = match re { - &RegionKind::ReScope(ce) => ValidationOp::Suspend(ce), - &RegionKind::ReErased => bug!("AddValidation pass must be run before erasing lifetimes"), - _ => ValidationOp::Release, - }; - let release_stmt = Statement { - source_info: block_data.statements[i].source_info, - kind: StatementKind::Validate(op, vec![(src_ty, src_lval)]), - }; - block_data.statements.insert(i, release_stmt); - } + // The source is released until the region of the borrow ends. + let op = match re { + &RegionKind::ReScope(ce) => ValidationOp::Suspend(ce), + &RegionKind::ReErased => + bug!("AddValidation pass must be run before erasing lifetimes"), + _ => ValidationOp::Release, + }; + let release_stmt = Statement { + source_info: block_data.statements[i].source_info, + kind: StatementKind::Validate(op, vec![lval_to_operand(src_lval)]), + }; + block_data.statements.insert(i, release_stmt); } } }