check for cycles when unifying const variables

This commit is contained in:
Bastian Kauschke 2020-09-09 09:43:47 +02:00
parent 39245400c5
commit 073127a04f
3 changed files with 237 additions and 14 deletions

View File

@ -45,7 +45,7 @@ use rustc_middle::ty::relate::{self, Relate, RelateResult, TypeRelation};
use rustc_middle::ty::subst::SubstsRef;
use rustc_middle::ty::{self, InferConst, ToPredicate, Ty, TyCtxt, TypeFoldable};
use rustc_middle::ty::{IntType, UintType};
use rustc_span::DUMMY_SP;
use rustc_span::{Span, DUMMY_SP};
/// Small-storage-optimized implementation of a map
/// made specifically for caching results.
@ -219,11 +219,11 @@ impl<'infcx, 'tcx> InferCtxt<'infcx, 'tcx> {
}
(ty::ConstKind::Infer(InferConst::Var(vid)), _) => {
return self.unify_const_variable(a_is_expected, vid, b);
return self.unify_const_variable(relation.param_env(), vid, b, a_is_expected);
}
(_, ty::ConstKind::Infer(InferConst::Var(vid))) => {
return self.unify_const_variable(!a_is_expected, vid, a);
return self.unify_const_variable(relation.param_env(), vid, a, !a_is_expected);
}
(ty::ConstKind::Unevaluated(..), _) if self.tcx.lazy_normalization() => {
// FIXME(#59490): Need to remove the leak check to accommodate
@ -247,17 +247,66 @@ impl<'infcx, 'tcx> InferCtxt<'infcx, 'tcx> {
ty::relate::super_relate_consts(relation, a, b)
}
pub fn unify_const_variable(
/// Unifies the const variable `target_vid` with the given constant.
///
/// This also tests if the given const `ct` contains an inference variable which was previously
/// unioned with `target_vid`. If this is the case, inferring `target_vid` to `ct`
/// would result in an infinite type as we continously replace an inference variable
/// in `ct` with `ct` itself.
///
/// This is especially important as unevaluated consts use their parents generics.
/// They therefore often contain unused substs, making these errors far more likely.
///
/// A good example of this is the following:
///
/// ```rust
/// #![feature(const_generics)]
///
/// fn bind<const N: usize>(value: [u8; N]) -> [u8; 3 + 4] {
/// todo!()
/// }
///
/// fn main() {
/// let mut arr = Default::default();
/// arr = bind(arr);
/// }
/// ```
///
/// Here `3 + 4` ends up as `ConstKind::Unevaluated` which uses the generics
/// of `fn bind` (meaning that its substs contain `N`).
///
/// `bind(arr)` now infers that the type of `arr` must be `[u8; N]`.
/// The assignment `arr = bind(arr)` now tries to equate `N` with `3 + 4`.
///
/// As `3 + 4` contains `N` in its substs, this must not succeed.
///
/// See `src/test/ui/const-generics/occurs-check/` for more examples where this is relevant.
fn unify_const_variable(
&self,
param_env: ty::ParamEnv<'tcx>,
target_vid: ty::ConstVid<'tcx>,
ct: &'tcx ty::Const<'tcx>,
vid_is_expected: bool,
vid: ty::ConstVid<'tcx>,
value: &'tcx ty::Const<'tcx>,
) -> RelateResult<'tcx, &'tcx ty::Const<'tcx>> {
let (for_universe, span) = {
let mut inner = self.inner.borrow_mut();
let variable_table = &mut inner.const_unification_table();
let var_value = variable_table.probe_value(target_vid);
match var_value.val {
ConstVariableValue::Known { value } => {
bug!("instantiating {:?} which has a known value {:?}", target_vid, value)
}
ConstVariableValue::Unknown { universe } => (universe, var_value.origin.span),
}
};
let value = ConstInferUnifier { infcx: self, span, param_env, for_universe, target_vid }
.relate(ct, ct)?;
self.inner
.borrow_mut()
.const_unification_table()
.unify_var_value(
vid,
target_vid,
ConstVarValue {
origin: ConstVariableOrigin {
kind: ConstVariableOriginKind::ConstInference,
@ -266,8 +315,8 @@ impl<'infcx, 'tcx> InferCtxt<'infcx, 'tcx> {
val: ConstVariableValue::Known { value },
},
)
.map_err(|e| const_unification_error(vid_is_expected, e))?;
Ok(value)
.map(|()| value)
.map_err(|e| const_unification_error(vid_is_expected, e))
}
fn unify_integral_variable(
@ -422,7 +471,7 @@ impl<'infcx, 'tcx> CombineFields<'infcx, 'tcx> {
let for_universe = match self.infcx.inner.borrow_mut().type_variables().probe(for_vid) {
v @ TypeVariableValue::Known { .. } => {
panic!("instantiating {:?} which has a known value {:?}", for_vid, v,)
bug!("instantiating {:?} which has a known value {:?}", for_vid, v,)
}
TypeVariableValue::Unknown { universe } => universe,
};
@ -740,7 +789,6 @@ impl TypeRelation<'tcx> for Generalizer<'_, 'tcx> {
}
}
}
ty::ConstKind::Unevaluated(..) if self.tcx().lazy_normalization() => Ok(c),
_ => relate::super_relate_consts(self, c, c),
}
}
@ -790,3 +838,175 @@ fn float_unification_error<'tcx>(
let (ty::FloatVarValue(a), ty::FloatVarValue(b)) = v;
TypeError::FloatMismatch(ty::relate::expected_found_bool(a_is_expected, a, b))
}
struct ConstInferUnifier<'cx, 'tcx> {
infcx: &'cx InferCtxt<'cx, 'tcx>,
span: Span,
param_env: ty::ParamEnv<'tcx>,
for_universe: ty::UniverseIndex,
/// The vid of the const variable that is in the process of being
/// instantiated; if we find this within the const we are folding,
/// that means we would have created a cyclic const.
target_vid: ty::ConstVid<'tcx>,
}
// We use `TypeRelation` here to propagate `RelateResult` upwards.
//
// Both inputs are expected to be the same.
impl TypeRelation<'tcx> for ConstInferUnifier<'_, 'tcx> {
fn tcx(&self) -> TyCtxt<'tcx> {
self.infcx.tcx
}
fn param_env(&self) -> ty::ParamEnv<'tcx> {
self.param_env
}
fn tag(&self) -> &'static str {
"ConstInferUnifier"
}
fn a_is_expected(&self) -> bool {
true
}
fn relate_with_variance<T: Relate<'tcx>>(
&mut self,
_variance: ty::Variance,
a: T,
b: T,
) -> RelateResult<'tcx, T> {
// We don't care about variance here.
self.relate(a, b)
}
fn binders<T>(
&mut self,
a: ty::Binder<T>,
b: ty::Binder<T>,
) -> RelateResult<'tcx, ty::Binder<T>>
where
T: Relate<'tcx>,
{
Ok(ty::Binder::bind(self.relate(a.skip_binder(), b.skip_binder())?))
}
fn tys(&mut self, t: Ty<'tcx>, _t: Ty<'tcx>) -> RelateResult<'tcx, Ty<'tcx>> {
debug_assert_eq!(t, _t);
debug!("ConstInferUnifier: t={:?}", t);
match t.kind() {
&ty::Infer(ty::TyVar(vid)) => {
let vid = self.infcx.inner.borrow_mut().type_variables().root_var(vid);
let probe = self.infcx.inner.borrow_mut().type_variables().probe(vid);
match probe {
TypeVariableValue::Known { value: u } => {
debug!("ConstOccursChecker: known value {:?}", u);
self.tys(u, u)
}
TypeVariableValue::Unknown { universe } => {
if self.for_universe.can_name(universe) {
return Ok(t);
}
let origin =
*self.infcx.inner.borrow_mut().type_variables().var_origin(vid);
let new_var_id = self.infcx.inner.borrow_mut().type_variables().new_var(
self.for_universe,
false,
origin,
);
let u = self.tcx().mk_ty_var(new_var_id);
debug!(
"ConstInferUnifier: replacing original vid={:?} with new={:?}",
vid, u
);
Ok(u)
}
}
}
_ => relate::super_relate_tys(self, t, t),
}
}
fn regions(
&mut self,
r: ty::Region<'tcx>,
_r: ty::Region<'tcx>,
) -> RelateResult<'tcx, ty::Region<'tcx>> {
debug_assert_eq!(r, _r);
debug!("ConstInferUnifier: r={:?}", r);
match r {
// Never make variables for regions bound within the type itself,
// nor for erased regions.
ty::ReLateBound(..) | ty::ReErased => {
return Ok(r);
}
ty::RePlaceholder(..)
| ty::ReVar(..)
| ty::ReEmpty(_)
| ty::ReStatic
| ty::ReEarlyBound(..)
| ty::ReFree(..) => {
// see common code below
}
}
let r_universe = self.infcx.universe_of_region(r);
if self.for_universe.can_name(r_universe) {
return Ok(r);
} else {
// FIXME: This is non-ideal because we don't give a
// very descriptive origin for this region variable.
Ok(self.infcx.next_region_var_in_universe(MiscVariable(self.span), self.for_universe))
}
}
fn consts(
&mut self,
c: &'tcx ty::Const<'tcx>,
_c: &'tcx ty::Const<'tcx>,
) -> RelateResult<'tcx, &'tcx ty::Const<'tcx>> {
debug_assert_eq!(c, _c);
debug!("ConstInferUnifier: c={:?}", c);
match c.val {
ty::ConstKind::Infer(InferConst::Var(vid)) => {
let mut inner = self.infcx.inner.borrow_mut();
let variable_table = &mut inner.const_unification_table();
// Check if the current unification would end up
// unifying `target_vid` with a const which contains
// an inference variable which is unioned with `target_vid`.
//
// Not doing so can easily result in stack overflows.
if variable_table.unioned(self.target_vid, vid) {
return Err(TypeError::CyclicConst(c));
}
let var_value = variable_table.probe_value(vid);
match var_value.val {
ConstVariableValue::Known { value: u } => self.consts(u, u),
ConstVariableValue::Unknown { universe } => {
if self.for_universe.can_name(universe) {
Ok(c)
} else {
let new_var_id = variable_table.new_key(ConstVarValue {
origin: var_value.origin,
val: ConstVariableValue::Unknown { universe: self.for_universe },
});
Ok(self.tcx().mk_const_var(new_var_id, c.ty))
}
}
}
}
_ => relate::super_relate_consts(self, c, c),
}
}
}

View File

@ -56,6 +56,7 @@ pub enum TypeError<'tcx> {
/// created a cycle (because it appears somewhere within that
/// type).
CyclicTy(Ty<'tcx>),
CyclicConst(&'tcx ty::Const<'tcx>),
ProjectionMismatched(ExpectedFound<DefId>),
ExistentialMismatch(ExpectedFound<&'tcx ty::List<ty::ExistentialPredicate<'tcx>>>),
ObjectUnsafeCoercion(DefId),
@ -100,6 +101,7 @@ impl<'tcx> fmt::Display for TypeError<'tcx> {
match *self {
CyclicTy(_) => write!(f, "cyclic type of infinite size"),
CyclicConst(_) => write!(f, "encountered a self-referencing constant"),
Mismatch => write!(f, "types differ"),
UnsafetyMismatch(values) => {
write!(f, "expected {} fn, found {} fn", values.expected, values.found)
@ -195,9 +197,9 @@ impl<'tcx> TypeError<'tcx> {
pub fn must_include_note(&self) -> bool {
use self::TypeError::*;
match self {
CyclicTy(_) | UnsafetyMismatch(_) | Mismatch | AbiMismatch(_) | FixedArraySize(_)
| Sorts(_) | IntMismatch(_) | FloatMismatch(_) | VariadicMismatch(_)
| TargetFeatureCast(_) => false,
CyclicTy(_) | CyclicConst(_) | UnsafetyMismatch(_) | Mismatch | AbiMismatch(_)
| FixedArraySize(_) | Sorts(_) | IntMismatch(_) | FloatMismatch(_)
| VariadicMismatch(_) | TargetFeatureCast(_) => false,
Mutability
| TupleSize(_)

View File

@ -689,6 +689,7 @@ impl<'a, 'tcx> Lift<'tcx> for ty::error::TypeError<'a> {
Traits(x) => Traits(x),
VariadicMismatch(x) => VariadicMismatch(x),
CyclicTy(t) => return tcx.lift(&t).map(|t| CyclicTy(t)),
CyclicConst(ct) => return tcx.lift(&ct).map(|ct| CyclicConst(ct)),
ProjectionMismatched(x) => ProjectionMismatched(x),
Sorts(ref x) => return tcx.lift(x).map(Sorts),
ExistentialMismatch(ref x) => return tcx.lift(x).map(ExistentialMismatch),