From fd9133b9c336e55d12bd5da1c610949473844dcf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Esteban=20K=C3=BCber?= Date: Mon, 17 Aug 2020 22:39:46 -0700 Subject: [PATCH] Suggest boxed trait objects in tail `match` and `if` expressions When encountering a `match` or `if` as a tail expression where the different arms do not have the same type *and* the return type of that `fn` is an `impl Trait`, check whether those arms can implement `Trait` and if so, suggest using boxed trait objects. --- .../src/infer/error_reporting/mod.rs | 66 +++++++++++++- compiler/rustc_middle/src/traits/mod.rs | 3 + compiler/rustc_typeck/src/check/_match.rs | 85 +++++++++++++++++-- compiler/rustc_typeck/src/check/mod.rs | 18 ++++ ...-to-type-err-cause-on-impl-trait-return.rs | 8 +- ...type-err-cause-on-impl-trait-return.stderr | 35 +++++++- 6 files changed, 203 insertions(+), 12 deletions(-) diff --git a/compiler/rustc_infer/src/infer/error_reporting/mod.rs b/compiler/rustc_infer/src/infer/error_reporting/mod.rs index 1225776db45..e772aacfbf9 100644 --- a/compiler/rustc_infer/src/infer/error_reporting/mod.rs +++ b/compiler/rustc_infer/src/infer/error_reporting/mod.rs @@ -617,11 +617,20 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> { ref prior_arms, last_ty, scrut_hir_id, + suggest_box, + arm_span, .. }) => match source { hir::MatchSource::IfLetDesugar { .. } => { let msg = "`if let` arms have incompatible types"; err.span_label(cause.span, msg); + if let Some(ret_sp) = suggest_box { + self.suggest_boxing_for_return_impl_trait( + err, + ret_sp, + prior_arms.iter().chain(std::iter::once(&arm_span)).map(|s| *s), + ); + } } hir::MatchSource::TryDesugar => { if let Some(ty::error::ExpectedFound { expected, .. }) = exp_found { @@ -675,9 +684,23 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> { Applicability::MachineApplicable, ); } + if let Some(ret_sp) = suggest_box { + // Get return type span and point to it. + self.suggest_boxing_for_return_impl_trait( + err, + ret_sp, + prior_arms.iter().chain(std::iter::once(&arm_span)).map(|s| *s), + ); + } } }, - ObligationCauseCode::IfExpression(box IfExpressionCause { then, outer, semicolon }) => { + ObligationCauseCode::IfExpression(box IfExpressionCause { + then, + else_sp, + outer, + semicolon, + suggest_box, + }) => { err.span_label(then, "expected because of this"); if let Some(sp) = outer { err.span_label(sp, "`if` and `else` have incompatible types"); @@ -690,11 +713,52 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> { Applicability::MachineApplicable, ); } + if let Some(ret_sp) = suggest_box { + self.suggest_boxing_for_return_impl_trait( + err, + ret_sp, + vec![then, else_sp].into_iter(), + ); + } } _ => (), } } + fn suggest_boxing_for_return_impl_trait( + &self, + err: &mut DiagnosticBuilder<'tcx>, + return_sp: Span, + arm_spans: impl Iterator, + ) { + let snippet = self + .tcx + .sess + .source_map() + .span_to_snippet(return_sp) + .unwrap_or_else(|_| "dyn Trait".to_string()); + err.span_suggestion_verbose( + return_sp, + "you could change the return type to be a boxed trait object", + format!("Box", &snippet[5..]), + Applicability::MaybeIncorrect, + ); + let sugg = arm_spans + .flat_map(|sp| { + vec![ + (sp.shrink_to_lo(), "Box::new(".to_string()), + (sp.shrink_to_hi(), ")".to_string()), + ] + .into_iter() + }) + .collect::>(); + err.multipart_suggestion( + "if you change the return type to expect trait objects box the returned expressions", + sugg, + Applicability::MaybeIncorrect, + ); + } + /// Given that `other_ty` is the same as a type argument for `name` in `sub`, populate `value` /// highlighting `name` and every type argument that isn't at `pos` (which is `other_ty`), and /// populate `other_value` with `other_ty`. diff --git a/compiler/rustc_middle/src/traits/mod.rs b/compiler/rustc_middle/src/traits/mod.rs index f86403fa502..71582b7ed73 100644 --- a/compiler/rustc_middle/src/traits/mod.rs +++ b/compiler/rustc_middle/src/traits/mod.rs @@ -350,13 +350,16 @@ pub struct MatchExpressionArmCause<'tcx> { pub prior_arms: Vec, pub last_ty: Ty<'tcx>, pub scrut_hir_id: hir::HirId, + pub suggest_box: Option, } #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct IfExpressionCause { pub then: Span, + pub else_sp: Span, pub outer: Option, pub semicolon: Option, + pub suggest_box: Option, } #[derive(Clone, Debug, PartialEq, Eq, Hash, Lift)] diff --git a/compiler/rustc_typeck/src/check/_match.rs b/compiler/rustc_typeck/src/check/_match.rs index afd4413069e..98e53a25879 100644 --- a/compiler/rustc_typeck/src/check/_match.rs +++ b/compiler/rustc_typeck/src/check/_match.rs @@ -1,12 +1,15 @@ use crate::check::coercion::CoerceMany; use crate::check::{Diverges, Expectation, FnCtxt, Needs}; -use rustc_hir as hir; -use rustc_hir::ExprKind; +use rustc_hir::{self as hir, ExprKind}; use rustc_infer::infer::type_variable::{TypeVariableOrigin, TypeVariableOriginKind}; -use rustc_middle::ty::Ty; +use rustc_infer::traits::Obligation; +use rustc_middle::ty::{self, ToPredicate, Ty}; use rustc_span::Span; -use rustc_trait_selection::traits::ObligationCauseCode; -use rustc_trait_selection::traits::{IfExpressionCause, MatchExpressionArmCause, ObligationCause}; +use rustc_trait_selection::opaque_types::InferCtxtExt as _; +use rustc_trait_selection::traits::query::evaluate_obligation::InferCtxtExt as _; +use rustc_trait_selection::traits::{ + IfExpressionCause, MatchExpressionArmCause, ObligationCause, ObligationCauseCode, +}; impl<'a, 'tcx> FnCtxt<'a, 'tcx> { pub fn check_match( @@ -14,7 +17,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { expr: &'tcx hir::Expr<'tcx>, scrut: &'tcx hir::Expr<'tcx>, arms: &'tcx [hir::Arm<'tcx>], - expected: Expectation<'tcx>, + orig_expected: Expectation<'tcx>, match_src: hir::MatchSource, ) -> Ty<'tcx> { let tcx = self.tcx; @@ -22,7 +25,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { use hir::MatchSource::*; let (source_if, if_no_else, force_scrutinee_bool) = match match_src { IfDesugar { contains_else_clause } => (true, !contains_else_clause, true), - IfLetDesugar { contains_else_clause } => (true, !contains_else_clause, false), + IfLetDesugar { contains_else_clause, .. } => (true, !contains_else_clause, false), WhileDesugar => (false, false, true), _ => (false, false, false), }; @@ -69,7 +72,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { // type in that case) let mut all_arms_diverge = Diverges::WarnedAlways; - let expected = expected.adjust_for_branches(self); + let expected = orig_expected.adjust_for_branches(self); let mut coercion = { let coerce_first = match expected { @@ -112,6 +115,59 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { self.check_expr_with_expectation(&arm.body, expected) }; all_arms_diverge &= self.diverges.get(); + + // When we have a `match` as a tail expression in a `fn` with a returned `impl Trait` + // we check if the different arms would work with boxed trait objects instead and + // provide a structured suggestion in that case. + let suggest_box = match ( + orig_expected, + self.body_id + .owner + .to_def_id() + .as_local() + .and_then(|id| self.ret_coercion_impl_trait.map(|ty| (id, ty))), + ) { + (Expectation::ExpectHasType(expected), Some((id, ty))) + if self.in_tail_expr && self.can_coerce(arm_ty, expected) => + { + let impl_trait_ret_ty = self.infcx.instantiate_opaque_types( + id, + self.body_id, + self.param_env, + &ty, + arm.body.span, + ); + let mut suggest_box = !impl_trait_ret_ty.obligations.is_empty(); + for o in impl_trait_ret_ty.obligations { + match o.predicate.skip_binders_unchecked() { + ty::PredicateAtom::Trait(t, constness) => { + let pred = ty::PredicateAtom::Trait( + ty::TraitPredicate { + trait_ref: ty::TraitRef { + def_id: t.def_id(), + substs: self.infcx.tcx.mk_substs_trait(arm_ty, &[]), + }, + }, + constness, + ); + let obl = Obligation::new( + o.cause.clone(), + self.param_env, + pred.to_predicate(self.infcx.tcx), + ); + suggest_box &= self.infcx.predicate_must_hold_modulo_regions(&obl); + if !suggest_box { + break; + } + } + _ => {} + } + } + if suggest_box { self.ret_type_span.clone() } else { None } + } + _ => None, + }; + if source_if { let then_expr = &arms[0].body; match (i, if_no_else) { @@ -119,7 +175,14 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { (_, true) => {} // Handled above to avoid duplicated type errors (#60254). (_, _) => { let then_ty = prior_arm_ty.unwrap(); - let cause = self.if_cause(expr.span, then_expr, &arm.body, then_ty, arm_ty); + let cause = self.if_cause( + expr.span, + then_expr, + &arm.body, + then_ty, + arm_ty, + suggest_box, + ); coercion.coerce(self, &cause, &arm.body, arm_ty); } } @@ -142,6 +205,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { prior_arms: other_arms.clone(), last_ty: prior_arm_ty.unwrap(), scrut_hir_id: scrut.hir_id, + suggest_box, }), ), }; @@ -266,6 +330,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { else_expr: &'tcx hir::Expr<'tcx>, then_ty: Ty<'tcx>, else_ty: Ty<'tcx>, + suggest_box: Option, ) -> ObligationCause<'tcx> { let mut outer_sp = if self.tcx.sess.source_map().is_multiline(span) { // The `if`/`else` isn't in one line in the output, include some context to make it @@ -353,8 +418,10 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { error_sp, ObligationCauseCode::IfExpression(box IfExpressionCause { then: then_sp, + else_sp: error_sp, outer: outer_sp, semicolon: remove_semicolon, + suggest_box, }), ) } diff --git a/compiler/rustc_typeck/src/check/mod.rs b/compiler/rustc_typeck/src/check/mod.rs index cb5f5731aa6..d6abf5f225d 100644 --- a/compiler/rustc_typeck/src/check/mod.rs +++ b/compiler/rustc_typeck/src/check/mod.rs @@ -570,6 +570,14 @@ pub struct FnCtxt<'a, 'tcx> { /// any). ret_coercion: Option>>, + ret_coercion_impl_trait: Option>, + + ret_type_span: Option, + + /// Used exclusively to reduce cost of advanced evaluation used for + /// more helpful diagnostics. + in_tail_expr: bool, + /// First span of a return site that we find. Used in error messages. ret_coercion_span: RefCell>, @@ -1302,10 +1310,15 @@ fn check_fn<'a, 'tcx>( let hir = tcx.hir(); let declared_ret_ty = fn_sig.output(); + let revealed_ret_ty = fcx.instantiate_opaque_types_from_value(fn_id, &declared_ret_ty, decl.output.span()); debug!("check_fn: declared_ret_ty: {}, revealed_ret_ty: {}", declared_ret_ty, revealed_ret_ty); fcx.ret_coercion = Some(RefCell::new(CoerceMany::new(revealed_ret_ty))); + fcx.ret_type_span = Some(decl.output.span()); + if let ty::Opaque(..) = declared_ret_ty.kind() { + fcx.ret_coercion_impl_trait = Some(declared_ret_ty); + } fn_sig = tcx.mk_fn_sig( fn_sig.inputs().iter().cloned(), revealed_ret_ty, @@ -1366,6 +1379,7 @@ fn check_fn<'a, 'tcx>( inherited.typeck_results.borrow_mut().liberated_fn_sigs_mut().insert(fn_id, fn_sig); + fcx.in_tail_expr = true; if let ty::Dynamic(..) = declared_ret_ty.kind() { // FIXME: We need to verify that the return type is `Sized` after the return expression has // been evaluated so that we have types available for all the nodes being returned, but that @@ -1385,6 +1399,7 @@ fn check_fn<'a, 'tcx>( fcx.require_type_is_sized(declared_ret_ty, decl.output.span(), traits::SizedReturnType); fcx.check_return_expr(&body.value); } + fcx.in_tail_expr = false; // We insert the deferred_generator_interiors entry after visiting the body. // This ensures that all nested generators appear before the entry of this generator. @@ -3084,6 +3099,9 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { param_env, err_count_on_creation: inh.tcx.sess.err_count(), ret_coercion: None, + ret_coercion_impl_trait: None, + ret_type_span: None, + in_tail_expr: false, ret_coercion_span: RefCell::new(None), resume_yield_tys: None, ps: RefCell::new(UnsafetyState::function(hir::Unsafety::Normal, hir::CRATE_HIR_ID)), diff --git a/src/test/ui/point-to-type-err-cause-on-impl-trait-return.rs b/src/test/ui/point-to-type-err-cause-on-impl-trait-return.rs index 6114e8c72a3..1a477e5db99 100644 --- a/src/test/ui/point-to-type-err-cause-on-impl-trait-return.rs +++ b/src/test/ui/point-to-type-err-cause-on-impl-trait-return.rs @@ -53,7 +53,13 @@ fn cat() -> impl std::fmt::Display { 1u32 //~ ERROR mismatched types } } - } +} + +fn dog() -> impl std::fmt::Display { + match 13 { + 0 => 0i32, + 1 => 1u32, //~ ERROR `match` arms have incompatible types + _ => 2u32, } } diff --git a/src/test/ui/point-to-type-err-cause-on-impl-trait-return.stderr b/src/test/ui/point-to-type-err-cause-on-impl-trait-return.stderr index 4e4d57f35d2..a198203e245 100644 --- a/src/test/ui/point-to-type-err-cause-on-impl-trait-return.stderr +++ b/src/test/ui/point-to-type-err-cause-on-impl-trait-return.stderr @@ -72,6 +72,17 @@ LL | | 1u32 | | ^^^^ expected `i32`, found `u32` LL | | } | |_____- `if` and `else` have incompatible types + | +help: you could change the return type to be a boxed trait object + | +LL | fn qux() -> Box { + | ^^^^^^^^^^^^^^^^^^^^^^^^^^ +help: if you change the return type to expect trait objects box the returned expressions + | +LL | Box::new(0i32) +LL | } else { +LL | Box::new(1u32) + | error[E0308]: mismatched types --> $DIR/point-to-type-err-cause-on-impl-trait-return.rs:35:14 @@ -136,6 +147,28 @@ help: you could change the return type to be a boxed trait object LL | fn cat() -> Box { | ^^^^^^^^^^^^^^^^^^^^^^^^^^ -error: aborting due to 7 previous errors +error[E0308]: `match` arms have incompatible types + --> $DIR/point-to-type-err-cause-on-impl-trait-return.rs:61:14 + | +LL | / match 13 { +LL | | 0 => 0i32, + | | ---- this is found to be of type `i32` +LL | | 1 => 1u32, + | | ^^^^ expected `i32`, found `u32` +LL | | _ => 2u32, +LL | | } + | |_____- `match` arms have incompatible types + | +help: you could change the return type to be a boxed trait object + | +LL | fn dog() -> Box { + | ^^^^^^^^^^^^^^^^^^^^^^^^^^ +help: if you change the return type to expect trait objects box the returned expressions + | +LL | 0 => Box::new(0i32), +LL | 1 => Box::new(1u32), + | + +error: aborting due to 8 previous errors For more information about this error, try `rustc --explain E0308`.