323 lines
8.7 KiB
Go
323 lines
8.7 KiB
Go
|
// Copyright 2016 The Go Authors. All rights reserved.
|
||
|
// Use of this source code is governed by a BSD-style
|
||
|
// license that can be found in the LICENSE file.
|
||
|
|
||
|
package main
|
||
|
|
||
|
import (
|
||
|
"cmd/vet/internal/cfg"
|
||
|
"fmt"
|
||
|
"go/ast"
|
||
|
"go/types"
|
||
|
"strconv"
|
||
|
)
|
||
|
|
||
|
func init() {
|
||
|
register("lostcancel",
|
||
|
"check for failure to call cancelation function returned by context.WithCancel",
|
||
|
checkLostCancel,
|
||
|
funcDecl, funcLit)
|
||
|
}
|
||
|
|
||
|
// debugLostCancel, when true, makes the check print each analyzed
// function's CFG and the block path it finds to an offending return.
const (
	debugLostCancel = false
)

// contextPackage is the import path of the package whose With* functions
// return the cancel functions this check tracks.
var (
	contextPackage = "context"
)
|
||
|
|
||
|
// checkLostCancel reports a failure to the call the cancel function
|
||
|
// returned by context.WithCancel, either because the variable was
|
||
|
// assigned to the blank identifier, or because there exists a
|
||
|
// control-flow path from the call to a return statement and that path
|
||
|
// does not "use" the cancel function. Any reference to the variable
|
||
|
// counts as a use, even within a nested function literal.
|
||
|
//
|
||
|
// checkLostCancel analyzes a single named or literal function.
|
||
|
func checkLostCancel(f *File, node ast.Node) {
|
||
|
// Fast path: bypass check if file doesn't use context.WithCancel.
|
||
|
if !hasImport(f.file, contextPackage) {
|
||
|
return
|
||
|
}
|
||
|
|
||
|
// Maps each cancel variable to its defining ValueSpec/AssignStmt.
|
||
|
cancelvars := make(map[*types.Var]ast.Node)
|
||
|
|
||
|
// Find the set of cancel vars to analyze.
|
||
|
stack := make([]ast.Node, 0, 32)
|
||
|
ast.Inspect(node, func(n ast.Node) bool {
|
||
|
switch n.(type) {
|
||
|
case *ast.FuncLit:
|
||
|
if len(stack) > 0 {
|
||
|
return false // don't stray into nested functions
|
||
|
}
|
||
|
case nil:
|
||
|
stack = stack[:len(stack)-1] // pop
|
||
|
return true
|
||
|
}
|
||
|
stack = append(stack, n) // push
|
||
|
|
||
|
// Look for [{AssignStmt,ValueSpec} CallExpr SelectorExpr]:
|
||
|
//
|
||
|
// ctx, cancel := context.WithCancel(...)
|
||
|
// ctx, cancel = context.WithCancel(...)
|
||
|
// var ctx, cancel = context.WithCancel(...)
|
||
|
//
|
||
|
if isContextWithCancel(f, n) && isCall(stack[len(stack)-2]) {
|
||
|
var id *ast.Ident // id of cancel var
|
||
|
stmt := stack[len(stack)-3]
|
||
|
switch stmt := stmt.(type) {
|
||
|
case *ast.ValueSpec:
|
||
|
if len(stmt.Names) > 1 {
|
||
|
id = stmt.Names[1]
|
||
|
}
|
||
|
case *ast.AssignStmt:
|
||
|
if len(stmt.Lhs) > 1 {
|
||
|
id, _ = stmt.Lhs[1].(*ast.Ident)
|
||
|
}
|
||
|
}
|
||
|
if id != nil {
|
||
|
if id.Name == "_" {
|
||
|
f.Badf(id.Pos(), "the cancel function returned by context.%s should be called, not discarded, to avoid a context leak",
|
||
|
n.(*ast.SelectorExpr).Sel.Name)
|
||
|
} else if v, ok := f.pkg.uses[id].(*types.Var); ok {
|
||
|
cancelvars[v] = stmt
|
||
|
} else if v, ok := f.pkg.defs[id].(*types.Var); ok {
|
||
|
cancelvars[v] = stmt
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return true
|
||
|
})
|
||
|
|
||
|
if len(cancelvars) == 0 {
|
||
|
return // no need to build CFG
|
||
|
}
|
||
|
|
||
|
// Tell the CFG builder which functions never return.
|
||
|
info := &types.Info{Uses: f.pkg.uses, Selections: f.pkg.selectors}
|
||
|
mayReturn := func(call *ast.CallExpr) bool {
|
||
|
name := callName(info, call)
|
||
|
return !noReturnFuncs[name]
|
||
|
}
|
||
|
|
||
|
// Build the CFG.
|
||
|
var g *cfg.CFG
|
||
|
var sig *types.Signature
|
||
|
switch node := node.(type) {
|
||
|
case *ast.FuncDecl:
|
||
|
obj := f.pkg.defs[node.Name]
|
||
|
if obj == nil {
|
||
|
return // type error (e.g. duplicate function declaration)
|
||
|
}
|
||
|
sig, _ = obj.Type().(*types.Signature)
|
||
|
g = cfg.New(node.Body, mayReturn)
|
||
|
case *ast.FuncLit:
|
||
|
sig, _ = f.pkg.types[node.Type].Type.(*types.Signature)
|
||
|
g = cfg.New(node.Body, mayReturn)
|
||
|
}
|
||
|
|
||
|
// Print CFG.
|
||
|
if debugLostCancel {
|
||
|
fmt.Println(g.Format(f.fset))
|
||
|
}
|
||
|
|
||
|
// Examine the CFG for each variable in turn.
|
||
|
// (It would be more efficient to analyze all cancelvars in a
|
||
|
// single pass over the AST, but seldom is there more than one.)
|
||
|
for v, stmt := range cancelvars {
|
||
|
if ret := lostCancelPath(f, g, v, stmt, sig); ret != nil {
|
||
|
lineno := f.fset.Position(stmt.Pos()).Line
|
||
|
f.Badf(stmt.Pos(), "the %s function is not used on all paths (possible context leak)", v.Name())
|
||
|
f.Badf(ret.Pos(), "this return statement may be reached without using the %s var defined on line %d", v.Name(), lineno)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// isCall reports whether n is a function call expression.
func isCall(n ast.Node) bool {
	switch n.(type) {
	case *ast.CallExpr:
		return true
	default:
		return false
	}
}
|
||
|
|
||
|
// hasImport reports whether file f has an import declaration whose
// unquoted path equals path. The error from unquoting is deliberately
// ignored; this is only a cheap pre-filter.
func hasImport(f *ast.File, path string) bool {
	for i := range f.Imports {
		unquoted, _ := strconv.Unquote(f.Imports[i].Path.Value)
		if unquoted == path {
			return true
		}
	}
	return false
}
|
||
|
|
||
|
// isContextWithCancel reports whether n is one of the qualified identifiers
|
||
|
// context.With{Cancel,Timeout,Deadline}.
|
||
|
func isContextWithCancel(f *File, n ast.Node) bool {
|
||
|
if sel, ok := n.(*ast.SelectorExpr); ok {
|
||
|
switch sel.Sel.Name {
|
||
|
case "WithCancel", "WithTimeout", "WithDeadline":
|
||
|
if x, ok := sel.X.(*ast.Ident); ok {
|
||
|
if pkgname, ok := f.pkg.uses[x].(*types.PkgName); ok {
|
||
|
return pkgname.Imported().Path() == contextPackage
|
||
|
}
|
||
|
// Import failed, so we can't check package path.
|
||
|
// Just check the local package name (heuristic).
|
||
|
return x.Name == "context"
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
return false
|
||
|
}
|
||
|
|
||
|
// lostCancelPath finds a path through the CFG, from stmt (which defines
|
||
|
// the 'cancel' variable v) to a return statement, that doesn't "use" v.
|
||
|
// If it finds one, it returns the return statement (which may be synthetic).
|
||
|
// sig is the function's type, if known.
|
||
|
func lostCancelPath(f *File, g *cfg.CFG, v *types.Var, stmt ast.Node, sig *types.Signature) *ast.ReturnStmt {
|
||
|
vIsNamedResult := sig != nil && tupleContains(sig.Results(), v)
|
||
|
|
||
|
// uses reports whether stmts contain a "use" of variable v.
|
||
|
uses := func(f *File, v *types.Var, stmts []ast.Node) bool {
|
||
|
found := false
|
||
|
for _, stmt := range stmts {
|
||
|
ast.Inspect(stmt, func(n ast.Node) bool {
|
||
|
switch n := n.(type) {
|
||
|
case *ast.Ident:
|
||
|
if f.pkg.uses[n] == v {
|
||
|
found = true
|
||
|
}
|
||
|
case *ast.ReturnStmt:
|
||
|
// A naked return statement counts as a use
|
||
|
// of the named result variables.
|
||
|
if n.Results == nil && vIsNamedResult {
|
||
|
found = true
|
||
|
}
|
||
|
}
|
||
|
return !found
|
||
|
})
|
||
|
}
|
||
|
return found
|
||
|
}
|
||
|
|
||
|
// blockUses computes "uses" for each block, caching the result.
|
||
|
memo := make(map[*cfg.Block]bool)
|
||
|
blockUses := func(f *File, v *types.Var, b *cfg.Block) bool {
|
||
|
res, ok := memo[b]
|
||
|
if !ok {
|
||
|
res = uses(f, v, b.Nodes)
|
||
|
memo[b] = res
|
||
|
}
|
||
|
return res
|
||
|
}
|
||
|
|
||
|
// Find the var's defining block in the CFG,
|
||
|
// plus the rest of the statements of that block.
|
||
|
var defblock *cfg.Block
|
||
|
var rest []ast.Node
|
||
|
outer:
|
||
|
for _, b := range g.Blocks {
|
||
|
for i, n := range b.Nodes {
|
||
|
if n == stmt {
|
||
|
defblock = b
|
||
|
rest = b.Nodes[i+1:]
|
||
|
break outer
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
if defblock == nil {
|
||
|
panic("internal error: can't find defining block for cancel var")
|
||
|
}
|
||
|
|
||
|
// Is v "used" in the remainder of its defining block?
|
||
|
if uses(f, v, rest) {
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// Does the defining block return without using v?
|
||
|
if ret := defblock.Return(); ret != nil {
|
||
|
return ret
|
||
|
}
|
||
|
|
||
|
// Search the CFG depth-first for a path, from defblock to a
|
||
|
// return block, in which v is never "used".
|
||
|
seen := make(map[*cfg.Block]bool)
|
||
|
var search func(blocks []*cfg.Block) *ast.ReturnStmt
|
||
|
search = func(blocks []*cfg.Block) *ast.ReturnStmt {
|
||
|
for _, b := range blocks {
|
||
|
if !seen[b] {
|
||
|
seen[b] = true
|
||
|
|
||
|
// Prune the search if the block uses v.
|
||
|
if blockUses(f, v, b) {
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
// Found path to return statement?
|
||
|
if ret := b.Return(); ret != nil {
|
||
|
if debugLostCancel {
|
||
|
fmt.Printf("found path to return in block %s\n", b)
|
||
|
}
|
||
|
return ret // found
|
||
|
}
|
||
|
|
||
|
// Recur
|
||
|
if ret := search(b.Succs); ret != nil {
|
||
|
if debugLostCancel {
|
||
|
fmt.Printf(" from block %s\n", b)
|
||
|
}
|
||
|
return ret
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
return search(defblock.Succs)
|
||
|
}
|
||
|
|
||
|
// tupleContains reports whether v is identical (by pointer) to one of the
// variables in tuple.
func tupleContains(tuple *types.Tuple, v *types.Var) bool {
	for i, n := 0, tuple.Len(); i < n; i++ {
		if tuple.At(i) == v {
			return true
		}
	}
	return false
}
|
||
|
|
||
|
// noReturnFuncs is the set of functions and methods that never return
// normally to their caller. Keys use the canonical names produced by
// callName: "name" for builtins, "(recvtype).Method" for methods, and
// "pkgpath.Func" for package-level functions. checkLostCancel consults
// it to tell the CFG builder which calls end the current path.
var noReturnFuncs = map[string]bool{
	"(*log.Logger).Fatal":       true,
	"(*log.Logger).Fatalf":      true,
	"(*log.Logger).Fatalln":     true,
	"(*log.Logger).Panic":       true,
	"(*log.Logger).Panicf":      true,
	"(*log.Logger).Panicln":     true,
	"(*testing.common).FailNow": true,
	"(*testing.common).Fatal":   true,
	"(*testing.common).Fatalf":  true,
	"(*testing.common).Skip":    true,
	"(*testing.common).SkipNow": true,
	"(*testing.common).Skipf":   true,
	"log.Fatal":                 true,
	"log.Fatalf":                true,
	"log.Fatalln":               true,
	"log.Panic":                 true,
	"log.Panicf":                true,
	"log.Panicln":               true,
	"os.Exit":                   true,
	"panic":                     true,
	"runtime.Goexit":            true,
	"syscall.Exit":              true,
}
|
||
|
|
||
|
// callName returns the canonical name of the builtin, method, or
// function called by call, if known, or "" otherwise.
//
// Names take one of three forms:
//   - "panic" for a builtin;
//   - "(*testing.common).Fatal" for a method, whether called through a
//     method value (t.Fatal(...)) or a method expression
//     ((*testing.T).Fatal(t, ...)); the receiver shown is the method's
//     declared receiver type;
//   - "os.Exit" for any other resolved selector: the package path of the
//     selected object, then its name.
//
// info must supply the Uses and Selections maps.
func callName(info *types.Info, call *ast.CallExpr) string {
	switch fun := call.Fun.(type) {
	case *ast.Ident:
		// builtin, e.g. "panic"
		if obj, ok := info.Uses[fun].(*types.Builtin); ok {
			return obj.Name()
		}
	case *ast.SelectorExpr:
		if sel, ok := info.Selections[fun]; ok {
			switch sel.Kind() {
			case types.MethodVal, types.MethodExpr:
				// method call, e.g. "(*testing.common).Fatal"
				meth := sel.Obj()
				return fmt.Sprintf("(%s).%s",
					meth.Type().(*types.Signature).Recv().Type(),
					meth.Name())
			}
		}
		// Objects from the universe scope (e.g. the Error method of the
		// predeclared error type) have no package; guard against that.
		if obj, ok := info.Uses[fun.Sel]; ok && obj.Pkg() != nil {
			// qualified identifier, e.g. "os.Exit"
			return fmt.Sprintf("%s.%s",
				obj.Pkg().Path(),
				obj.Name())
		}
	}

	// function with no name, or defined in missing imported package
	return ""
}
|