diff --git a/clippy_lints/src/if_let_mutex.rs b/clippy_lints/src/if_let_mutex.rs index 3998360e4bf..5f83d952fe3 100644 --- a/clippy_lints/src/if_let_mutex.rs +++ b/clippy_lints/src/if_let_mutex.rs @@ -1,6 +1,7 @@ use crate::utils::{match_type, paths, span_lint_and_help}; -use if_chain::if_chain; -use rustc_hir::{Arm, Expr, ExprKind, MatchSource, Stmt, StmtKind}; +use rustc::hir::map::Map; +use rustc_hir::intravisit::{self as visit, NestedVisitorMap, Visitor}; +use rustc_hir::{Arm, Expr, ExprKind, MatchSource, StmtKind}; use rustc_lint::{LateContext, LateLintPass}; use rustc_session::{declare_lint_pass, declare_tool_lint}; @@ -40,100 +41,115 @@ declare_lint_pass!(IfLetMutex => [IF_LET_MUTEX]); impl LateLintPass<'_, '_> for IfLetMutex { fn check_expr(&mut self, cx: &LateContext<'_, '_>, ex: &'_ Expr<'_>) { - if_chain! { - if let ExprKind::Match(ref op, ref arms, MatchSource::IfLetDesugar { + let mut arm_visit = ArmVisitor { + arm_mutex: false, + arm_lock: false, + cx, + }; + let mut op_visit = IfLetMutexVisitor { + op_mutex: false, + op_lock: false, + cx, + }; + if let ExprKind::Match( + ref op, + ref arms, + MatchSource::IfLetDesugar { contains_else_clause: true, - }) = ex.kind; // if let ... {} else {} - if let ExprKind::MethodCall(_, _, ref args) = op.kind; - let ty = cx.tables.expr_ty(&args[0]); - if match_type(cx, ty, &paths::MUTEX); // make sure receiver is Mutex - if method_chain_names(op, 10).iter().any(|s| s == "lock"); // and lock is called + }, + ) = ex.kind + { + op_visit.visit_expr(op); + if op_visit.op_mutex && op_visit.op_lock { + for arm in *arms { + arm_visit.visit_arm(arm); + } - if arms.iter().any(|arm| matching_arm(arm, op, ex, cx)); - then { - span_lint_and_help( - cx, - IF_LET_MUTEX, - ex.span, - "calling `Mutex::lock` inside the scope of another `Mutex::lock` causes a deadlock", - "move the lock call outside of the `if let ...` expression", - ); + if arm_visit.arm_mutex && arm_visit.arm_lock { + span_lint_and_help( + cx, + IF_LET_MUTEX, + ex.span, + "calling `Mutex::lock` inside the scope of another `Mutex::lock` causes a deadlock", + "move the lock call outside of the `if let ...` expression", + ); + } } } } } -fn matching_arm(arm: &Arm<'_>, op: &Expr<'_>, ex: &Expr<'_>, cx: &LateContext<'_, '_>) -> bool { - if let ExprKind::Block(ref block, _l) = arm.body.kind { - block.stmts.iter().any(|stmt| matching_stmt(stmt, op, ex, cx)) - } else { - false - } +/// Checks if `Mutex::lock` is called in the `if let _ = expr. +pub struct IfLetMutexVisitor<'tcx, 'l> { + pub op_mutex: bool, + pub op_lock: bool, + pub cx: &'tcx LateContext<'tcx, 'l>, } -fn matching_stmt(stmt: &Stmt<'_>, op: &Expr<'_>, ex: &Expr<'_>, cx: &LateContext<'_, '_>) -> bool { - match stmt.kind { - StmtKind::Local(l) => if_chain! { - if let Some(ex) = l.init; - if let ExprKind::MethodCall(_, _, _) = op.kind; - if method_chain_names(ex, 10).iter().any(|s| s == "lock"); // and lock is called - then { - match_type_method_chain(cx, ex, 5) - } else { - false - } - }, - StmtKind::Expr(e) => if_chain! { - if let ExprKind::MethodCall(_, _, _) = e.kind; - if method_chain_names(e, 10).iter().any(|s| s == "lock"); // and lock is called - then { - match_type_method_chain(cx, ex, 5) - } else { - false - } - }, - StmtKind::Semi(e) => if_chain! { - if let ExprKind::MethodCall(_, _, _) = e.kind; - if method_chain_names(e, 10).iter().any(|s| s == "lock"); // and lock is called - then { - match_type_method_chain(cx, ex, 5) - } else { - false - } - }, - _ => false, - } -} +impl<'tcx, 'l> Visitor<'tcx> for IfLetMutexVisitor<'tcx, 'l> { + type Map = Map<'tcx>; -/// Return the names of `max_depth` number of methods called in the chain. -fn method_chain_names<'tcx>(expr: &'tcx Expr<'tcx>, max_depth: usize) -> Vec { - let mut method_names = Vec::with_capacity(max_depth); - let mut current = expr; - for _ in 0..max_depth { - if let ExprKind::MethodCall(path, _, args) = ¤t.kind { - if args.iter().any(|e| e.span.from_expansion()) { - break; + fn visit_expr(&mut self, expr: &'tcx Expr<'_>) { + if let ExprKind::MethodCall(path, _span, args) = &expr.kind { + if path.ident.to_string() == "lock" { + self.op_lock = true; + } + let ty = self.cx.tables.expr_ty(&args[0]); + if match_type(self.cx, ty, &paths::MUTEX) { + self.op_mutex = true; } - method_names.push(path.ident.to_string()); - current = &args[0]; - } else { - break; } + visit::walk_expr(self, expr); + } + + fn nested_visit_map(&mut self) -> NestedVisitorMap { + NestedVisitorMap::None } - method_names } -/// Check that lock is called on a `Mutex`. -fn match_type_method_chain<'tcx>(cx: &LateContext<'_, '_>, expr: &'tcx Expr<'tcx>, max_depth: usize) -> bool { - let mut current = expr; - for _ in 0..max_depth { - if let ExprKind::MethodCall(_, _, args) = ¤t.kind { - let ty = cx.tables.expr_ty(&args[0]); - if match_type(cx, ty, &paths::MUTEX) { - return true; - } - current = &args[0]; - } - } - false +/// Checks if `Mutex::lock` is called in any of the branches. +pub struct ArmVisitor<'tcx, 'l> { + pub arm_mutex: bool, + pub arm_lock: bool, + pub cx: &'tcx LateContext<'tcx, 'l>, +} + +impl<'tcx, 'l> Visitor<'tcx> for ArmVisitor<'tcx, 'l> { + type Map = Map<'tcx>; + + fn visit_expr(&mut self, expr: &'tcx Expr<'_>) { + if let ExprKind::MethodCall(path, _span, args) = &expr.kind { + if path.ident.to_string() == "lock" { + self.arm_lock = true; + } + let ty = self.cx.tables.expr_ty(&args[0]); + if match_type(self.cx, ty, &paths::MUTEX) { + self.arm_mutex = true; + } + } + visit::walk_expr(self, expr); + } + + fn visit_arm(&mut self, arm: &'tcx Arm<'_>) { + if let ExprKind::Block(ref block, _l) = arm.body.kind { + for stmt in block.stmts { + match stmt.kind { + StmtKind::Local(loc) => { + if let Some(expr) = loc.init { + self.visit_expr(expr) + } + }, + StmtKind::Expr(expr) => self.visit_expr(expr), + StmtKind::Semi(expr) => self.visit_expr(expr), + // we don't care about `Item` + _ => {}, + } + } + }; + visit::walk_arm(self, arm); + } + + fn nested_visit_map(&mut self) -> NestedVisitorMap { + NestedVisitorMap::None + } } diff --git a/tests/ui/if_let_mutex.rs b/tests/ui/if_let_mutex.rs index 059764e9b21..1d7cc756dc5 100644 --- a/tests/ui/if_let_mutex.rs +++ b/tests/ui/if_let_mutex.rs @@ -3,9 +3,9 @@ use std::sync::Mutex; fn do_stuff(_: T) {} -fn foo() { - let m = Mutex::new(1u8); +fn if_let() { + let m = Mutex::new(1u8); if let Err(locked) = m.lock() { do_stuff(locked); } else {