diff --git a/compiler/rustc_mir/src/transform/dest_prop.rs b/compiler/rustc_mir/src/transform/dest_prop.rs index cb4321ace7f..20f8f820176 100644 --- a/compiler/rustc_mir/src/transform/dest_prop.rs +++ b/compiler/rustc_mir/src/transform/dest_prop.rs @@ -109,8 +109,8 @@ use rustc_index::{ use rustc_middle::mir::tcx::PlaceTy; use rustc_middle::mir::visit::{MutVisitor, PlaceContext, Visitor}; use rustc_middle::mir::{ - traversal, Body, Local, LocalKind, Location, Operand, Place, PlaceElem, Rvalue, Statement, - StatementKind, Terminator, TerminatorKind, + traversal, Body, InlineAsmOperand, Local, LocalKind, Location, Operand, Place, PlaceElem, + Rvalue, Statement, StatementKind, Terminator, TerminatorKind, }; use rustc_middle::ty::{self, Ty, TyCtxt}; @@ -397,7 +397,9 @@ impl<'tcx> MutVisitor<'tcx> for Replacer<'tcx> { } } -struct Conflicts { +struct Conflicts<'a> { + relevant_locals: &'a BitSet, + /// The conflict matrix. It is always symmetric and the adjacency matrix of the corresponding /// conflict graph. matrix: BitMatrix, @@ -406,30 +408,21 @@ struct Conflicts { unify_cache: BitSet, } -impl Conflicts { +impl Conflicts<'a> { fn build<'tcx>( tcx: TyCtxt<'tcx>, body: &'_ Body<'tcx>, source: MirSource<'tcx>, - relevant_locals: &BitSet, + relevant_locals: &'a BitSet, ) -> Self { // We don't have to look out for locals that have their address taken, since // `find_candidates` already takes care of that. - let mut conflicts = BitMatrix::from_row_n( + let conflicts = BitMatrix::from_row_n( &BitSet::new_empty(body.local_decls.len()), body.local_decls.len(), ); - let mut record_conflicts = |new_conflicts: &mut BitSet<_>| { - // Remove all locals that are not candidates. - new_conflicts.intersect(relevant_locals); - - for local in new_conflicts.iter() { - conflicts.union_row_with(&new_conflicts, local); - } - }; - let def_id = source.def_id(); let mut init = MaybeInitializedLocals .into_engine(tcx, body, def_id) @@ -494,6 +487,12 @@ impl Conflicts { }, ); + let mut this = Self { + relevant_locals, + matrix: conflicts, + unify_cache: BitSet::new_empty(body.local_decls.len()), + }; + let mut live_and_init_locals = Vec::new(); // Visit only reachable basic blocks. The exact order is not important. @@ -511,14 +510,22 @@ impl Conflicts { BitSet::new_empty(body.local_decls.len()) }); - // First, go forwards for `MaybeInitializedLocals`. - for statement_index in 0..=data.statements.len() { - let loc = Location { block, statement_index }; + // First, go forwards for `MaybeInitializedLocals` and apply intra-statement/terminator + // conflicts. + for (i, statement) in data.statements.iter().enumerate() { + this.record_statement_conflicts(statement); + + let loc = Location { block, statement_index: i }; init.seek_before_primary_effect(loc); - live_and_init_locals[statement_index].clone_from(init.get()); + live_and_init_locals[i].clone_from(init.get()); } + this.record_terminator_conflicts(data.terminator()); + let term_loc = Location { block, statement_index: data.statements.len() }; + init.seek_before_primary_effect(term_loc); + live_and_init_locals[term_loc.statement_index].clone_from(init.get()); + // Now, go backwards and union with the liveness results. for statement_index in (0..=data.statements.len()).rev() { let loc = Location { block, statement_index }; @@ -528,7 +535,7 @@ impl Conflicts { trace!("record conflicts at {:?}", loc); - record_conflicts(&mut live_and_init_locals[statement_index]); + this.record_conflicts(&mut live_and_init_locals[statement_index]); } init.seek_to_block_end(block); @@ -537,10 +544,187 @@ impl Conflicts { conflicts.intersect(live.get()); trace!("record conflicts at end of {:?}", block); - record_conflicts(&mut conflicts); + this.record_conflicts(&mut conflicts); } - Self { matrix: conflicts, unify_cache: BitSet::new_empty(body.local_decls.len()) } + this + } + + fn record_conflicts(&mut self, new_conflicts: &mut BitSet) { + // Remove all locals that are not candidates. + new_conflicts.intersect(self.relevant_locals); + + for local in new_conflicts.iter() { + self.matrix.union_row_with(&new_conflicts, local); + } + } + + /// Records locals that must not overlap during the evaluation of `stmt`. These locals conflict + /// and must not be merged. + fn record_statement_conflicts(&mut self, stmt: &Statement<'_>) { + match &stmt.kind { + // While the left and right sides of an assignment must not overlap, we do not mark + // conflicts here as that would make this optimization useless. When we optimize, we + // eliminate the resulting self-assignments automatically. + StatementKind::Assign(_) => {} + + StatementKind::LlvmInlineAsm(asm) => { + // Inputs and outputs must not overlap. + for (_, input) in &*asm.inputs { + if let Some(in_place) = input.place() { + if !in_place.is_indirect() { + for out_place in &*asm.outputs { + if !out_place.is_indirect() && !in_place.is_indirect() { + self.matrix.insert(in_place.local, out_place.local); + self.matrix.insert(out_place.local, in_place.local); + } + } + } + } + } + } + + StatementKind::SetDiscriminant { .. } + | StatementKind::StorageLive(_) + | StatementKind::StorageDead(_) + | StatementKind::Retag(_, _) + | StatementKind::FakeRead(_, _) + | StatementKind::AscribeUserType(_, _) + | StatementKind::Nop => {} + } + } + + fn record_terminator_conflicts(&mut self, term: &Terminator<'_>) { + match &term.kind { + TerminatorKind::DropAndReplace { location, value, target: _, unwind: _ } => { + if let Some(place) = value.place() { + if !place.is_indirect() && !location.is_indirect() { + self.matrix.insert(place.local, location.local); + self.matrix.insert(location.local, place.local); + } + } + } + TerminatorKind::Yield { value, resume: _, resume_arg, drop: _ } => { + if let Some(place) = value.place() { + if !place.is_indirect() && !resume_arg.is_indirect() { + self.matrix.insert(place.local, resume_arg.local); + self.matrix.insert(resume_arg.local, place.local); + } + } + } + TerminatorKind::Call { + func, + args, + destination: Some((dest_place, _)), + cleanup: _, + from_hir_call: _, + } => { + // No arguments may overlap with the destination. + for arg in args.iter().chain(Some(func)) { + if let Some(place) = arg.place() { + if !place.is_indirect() && !dest_place.is_indirect() { + self.matrix.insert(dest_place.local, place.local); + self.matrix.insert(place.local, dest_place.local); + } + } + } + } + TerminatorKind::InlineAsm { + template: _, + operands, + options: _, + line_spans: _, + destination: _, + } => { + // The intended semantics here aren't documented, we just assume that nothing that + // could be written to by the assembly may overlap with any other operands. + for op in operands { + match op { + InlineAsmOperand::Out { reg: _, late: _, place: Some(dest_place) } + | InlineAsmOperand::InOut { + reg: _, + late: _, + in_value: _, + out_place: Some(dest_place), + } => { + // For output place `place`, add all places accessed by the inline asm. + for op in operands { + match op { + InlineAsmOperand::In { reg: _, value } => { + if let Some(p) = value.place() { + if !p.is_indirect() && !dest_place.is_indirect() { + self.matrix.insert(p.local, dest_place.local); + self.matrix.insert(dest_place.local, p.local); + } + } + } + InlineAsmOperand::Out { + reg: _, + late: _, + place: Some(place), + } => { + if !place.is_indirect() && !dest_place.is_indirect() { + self.matrix.insert(place.local, dest_place.local); + self.matrix.insert(dest_place.local, place.local); + } + } + InlineAsmOperand::InOut { + reg: _, + late: _, + in_value, + out_place, + } => { + if let Some(place) = in_value.place() { + if !place.is_indirect() && !dest_place.is_indirect() { + self.matrix.insert(place.local, dest_place.local); + self.matrix.insert(dest_place.local, place.local); + } + } + + if let Some(place) = out_place { + if !place.is_indirect() && !dest_place.is_indirect() { + self.matrix.insert(place.local, dest_place.local); + self.matrix.insert(dest_place.local, place.local); + } + } + } + InlineAsmOperand::Out { reg: _, late: _, place: None } + | InlineAsmOperand::Const { value: _ } + | InlineAsmOperand::SymFn { value: _ } + | InlineAsmOperand::SymStatic { value: _ } => {} + } + } + } + InlineAsmOperand::Const { value } => { + assert!(value.place().is_none()); + } + InlineAsmOperand::InOut { + reg: _, + late: _, + in_value: _, + out_place: None, + } + | InlineAsmOperand::In { reg: _, value: _ } + | InlineAsmOperand::Out { reg: _, late: _, place: None } + | InlineAsmOperand::SymFn { value: _ } + | InlineAsmOperand::SymStatic { value: _ } => {} + } + } + } + + TerminatorKind::Goto { .. } + | TerminatorKind::Call { destination: None, .. } + | TerminatorKind::SwitchInt { .. } + | TerminatorKind::Resume + | TerminatorKind::Abort + | TerminatorKind::Return + | TerminatorKind::Unreachable + | TerminatorKind::Drop { .. } + | TerminatorKind::Assert { .. } + | TerminatorKind::GeneratorDrop + | TerminatorKind::FalseEdges { .. } + | TerminatorKind::FalseUnwind { .. } => {} + } } fn contains(&self, a: Local, b: Local) -> bool {