Rollup merge of #60745 - wesleywiser:const_prop_into_terminators, r=oli-obk

Perform constant propagation into terminators

Perform constant propagation into MIR `Assert` and `SwitchInt` `Terminator`s which in some cases allows them to be removed by the branch simplification pass.

r? @oli-obk
This commit is contained in:
Mazdak Farrokhzad 2019-05-20 01:01:38 +02:00 committed by GitHub
commit 5c84d779b2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 152 additions and 73 deletions

View File

@ -546,6 +546,10 @@ impl<'a, 'mir, 'tcx> ConstPropagator<'a, 'mir, 'tcx> {
}
}
}
fn should_const_prop(&self) -> bool {
self.tcx.sess.opts.debugging_opts.mir_opt_level >= 2
}
}
fn type_size_of<'a, 'tcx>(tcx: TyCtxt<'a, 'tcx, 'tcx>,
@ -639,7 +643,7 @@ impl<'b, 'a, 'tcx> MutVisitor<'tcx> for ConstPropagator<'b, 'a, 'tcx> {
assert!(self.places[local].is_none());
self.places[local] = Some(value);
if self.tcx.sess.opts.debugging_opts.mir_opt_level >= 2 {
if self.should_const_prop() {
self.replace_with_const(rval, value, statement.source_info.span);
}
}
@ -656,75 +660,112 @@ impl<'b, 'a, 'tcx> MutVisitor<'tcx> for ConstPropagator<'b, 'a, 'tcx> {
location: Location,
) {
self.super_terminator(terminator, location);
let source_info = terminator.source_info;;
if let TerminatorKind::Assert { expected, msg, cond, .. } = &terminator.kind {
if let Some(value) = self.eval_operand(&cond, source_info) {
trace!("assertion on {:?} should be {:?}", value, expected);
let expected = ScalarMaybeUndef::from(Scalar::from_bool(*expected));
if expected != self.ecx.read_scalar(value).unwrap() {
// poison all places this operand references so that further code
// doesn't use the invalid value
match cond {
Operand::Move(ref place) | Operand::Copy(ref place) => {
let mut place = place;
while let Place::Projection(ref proj) = *place {
place = &proj.base;
let source_info = terminator.source_info;
match &mut terminator.kind {
TerminatorKind::Assert { expected, msg, ref mut cond, .. } => {
if let Some(value) = self.eval_operand(&cond, source_info) {
trace!("assertion on {:?} should be {:?}", value, expected);
let expected = ScalarMaybeUndef::from(Scalar::from_bool(*expected));
let value_const = self.ecx.read_scalar(value).unwrap();
if expected != value_const {
// poison all places this operand references so that further code
// doesn't use the invalid value
match cond {
Operand::Move(ref place) | Operand::Copy(ref place) => {
let mut place = place;
while let Place::Projection(ref proj) = *place {
place = &proj.base;
}
if let Place::Base(PlaceBase::Local(local)) = *place {
self.places[local] = None;
}
},
Operand::Constant(_) => {}
}
let span = terminator.source_info.span;
let hir_id = self
.tcx
.hir()
.as_local_hir_id(self.source.def_id())
.expect("some part of a failing const eval must be local");
use rustc::mir::interpret::InterpError::*;
let msg = match msg {
Overflow(_) |
OverflowNeg |
DivisionByZero |
RemainderByZero => msg.description().to_owned(),
BoundsCheck { ref len, ref index } => {
let len = self
.eval_operand(len, source_info)
.expect("len must be const");
let len = match self.ecx.read_scalar(len) {
Ok(ScalarMaybeUndef::Scalar(Scalar::Bits {
bits, ..
})) => bits,
other => bug!("const len not primitive: {:?}", other),
};
let index = self
.eval_operand(index, source_info)
.expect("index must be const");
let index = match self.ecx.read_scalar(index) {
Ok(ScalarMaybeUndef::Scalar(Scalar::Bits {
bits, ..
})) => bits,
other => bug!("const index not primitive: {:?}", other),
};
format!(
"index out of bounds: \
the len is {} but the index is {}",
len,
index,
)
},
// Need proper const propagator for these
_ => return,
};
self.tcx.lint_hir(
::rustc::lint::builtin::CONST_ERR,
hir_id,
span,
&msg,
);
} else {
if self.should_const_prop() {
if let ScalarMaybeUndef::Scalar(scalar) = value_const {
*cond = self.operand_from_scalar(
scalar,
self.tcx.types.bool,
source_info.span,
);
}
if let Place::Base(PlaceBase::Local(local)) = *place {
self.places[local] = None;
}
},
Operand::Constant(_) => {}
}
}
let span = terminator.source_info.span;
let hir_id = self
.tcx
.hir()
.as_local_hir_id(self.source.def_id())
.expect("some part of a failing const eval must be local");
use rustc::mir::interpret::InterpError::*;
let msg = match msg {
Overflow(_) |
OverflowNeg |
DivisionByZero |
RemainderByZero => msg.description().to_owned(),
BoundsCheck { ref len, ref index } => {
let len = self
.eval_operand(len, source_info)
.expect("len must be const");
let len = match self.ecx.read_scalar(len) {
Ok(ScalarMaybeUndef::Scalar(Scalar::Bits {
bits, ..
})) => bits,
other => bug!("const len not primitive: {:?}", other),
};
let index = self
.eval_operand(index, source_info)
.expect("index must be const");
let index = match self.ecx.read_scalar(index) {
Ok(ScalarMaybeUndef::Scalar(Scalar::Bits {
bits, ..
})) => bits,
other => bug!("const index not primitive: {:?}", other),
};
format!(
"index out of bounds: \
the len is {} but the index is {}",
len,
index,
)
},
// Need proper const propagator for these
_ => return,
};
self.tcx.lint_hir(
::rustc::lint::builtin::CONST_ERR,
hir_id,
span,
&msg,
);
}
}
},
TerminatorKind::SwitchInt { ref mut discr, switch_ty, .. } => {
if self.should_const_prop() {
if let Some(value) = self.eval_operand(&discr, source_info) {
if let ScalarMaybeUndef::Scalar(scalar) =
self.ecx.read_scalar(value).unwrap() {
*discr = self.operand_from_scalar(scalar, switch_ty, source_info.span);
}
}
}
},
//none of these have Operands to const-propagate
TerminatorKind::Goto { .. } |
TerminatorKind::Resume |
TerminatorKind::Abort |
TerminatorKind::Return |
TerminatorKind::Unreachable |
TerminatorKind::Drop { .. } |
TerminatorKind::DropAndReplace { .. } |
TerminatorKind::Yield { .. } |
TerminatorKind::GeneratorDrop |
TerminatorKind::FalseEdges { .. } |
TerminatorKind::FalseUnwind { .. } => { }
//FIXME(wesleywiser) Call does have Operands that could be const-propagated
TerminatorKind::Call { .. } => { }
}
}
}

View File

@ -23,7 +23,7 @@ fn main() {
// bb0: {
// ...
// _5 = const true;
// assert(move _5, "index out of bounds: the len is move _4 but the index is _3") -> bb1;
// assert(const true, "index out of bounds: the len is move _4 but the index is _3") -> bb1;
// }
// bb1: {
// _1 = _2[_3];

View File

@ -16,6 +16,6 @@ fn main() {
// bb0: {
// ...
// _2 = (const 2u32, const false);
// assert(!move (_2.1: bool), "attempt to add with overflow") -> bb1;
// assert(!const false, "attempt to add with overflow") -> bb1;
// }
// END rustc.main.ConstProp.after.mir

View File

@ -0,0 +1,38 @@
#[inline(never)]
fn foo(_: i32) { }
fn main() {
match 1 {
1 => foo(0),
_ => foo(-1),
}
}
// END RUST SOURCE
// START rustc.main.ConstProp.before.mir
// bb0: {
// ...
// _1 = const 1i32;
// switchInt(_1) -> [1i32: bb1, otherwise: bb2];
// }
// END rustc.main.ConstProp.before.mir
// START rustc.main.ConstProp.after.mir
// bb0: {
// ...
// switchInt(const 1i32) -> [1i32: bb1, otherwise: bb2];
// }
// END rustc.main.ConstProp.after.mir
// START rustc.main.SimplifyBranches-after-const-prop.before.mir
// bb0: {
// ...
// _1 = const 1i32;
// switchInt(const 1i32) -> [1i32: bb1, otherwise: bb2];
// }
// END rustc.main.SimplifyBranches-after-const-prop.before.mir
// START rustc.main.SimplifyBranches-after-const-prop.after.mir
// bb0: {
// ...
// _1 = const 1i32;
// goto -> bb1;
// }
// END rustc.main.SimplifyBranches-after-const-prop.after.mir

View File

@ -5,15 +5,15 @@ fn main() {
}
// END RUST SOURCE
// START rustc.main.SimplifyBranches-after-copy-prop.before.mir
// START rustc.main.SimplifyBranches-after-const-prop.before.mir
// bb0: {
// ...
// switchInt(const false) -> [false: bb3, otherwise: bb1];
// }
// END rustc.main.SimplifyBranches-after-copy-prop.before.mir
// START rustc.main.SimplifyBranches-after-copy-prop.after.mir
// END rustc.main.SimplifyBranches-after-const-prop.before.mir
// START rustc.main.SimplifyBranches-after-const-prop.after.mir
// bb0: {
// ...
// goto -> bb3;
// }
// END rustc.main.SimplifyBranches-after-copy-prop.after.mir
// END rustc.main.SimplifyBranches-after-const-prop.after.mir