diff --git a/rules/bad_defer.go b/rules/bad_defer.go index 13b42070da..f6ca0be81f 100644 --- a/rules/bad_defer.go +++ b/rules/bad_defer.go @@ -38,10 +38,11 @@ func contains(methods []string, method string) bool { func (r *badDefer) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) { if deferStmt, ok := n.(*ast.DeferStmt); ok { for _, deferTyp := range r.types { - if typ, method, err := gosec.GetCallInfo(deferStmt.Call, c); err == nil { - if normalize(typ) == deferTyp.typ && contains(deferTyp.methods, method) { - return gosec.NewIssue(c, n, r.ID(), fmt.Sprintf(r.What, method, typ), r.Severity, r.Confidence), nil - } + if issue := r.checkChild(n, c, deferStmt.Call, deferTyp); issue != nil { + return issue, nil + } + if issue := r.checkFunction(n, c, deferStmt, deferTyp); issue != nil { + return issue, nil } } } @@ -49,6 +50,42 @@ func (r *badDefer) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) { return nil, nil } +func (r *badDefer) checkChild(n ast.Node, c *gosec.Context, callExp *ast.CallExpr, deferTyp deferType) *gosec.Issue { + if typ, method, err := gosec.GetCallInfo(callExp, c); err == nil { + if normalize(typ) == deferTyp.typ && contains(deferTyp.methods, method) { + return gosec.NewIssue(c, n, r.ID(), fmt.Sprintf(r.What, method, typ), r.Severity, r.Confidence) + } + } + return nil +} + +func (r *badDefer) checkFunction(n ast.Node, c *gosec.Context, deferStmt *ast.DeferStmt, deferTyp deferType) *gosec.Issue { + if anonFunc, isAnonFunc := deferStmt.Call.Fun.(*ast.FuncLit); isAnonFunc { + for _, subElem := range anonFunc.Body.List { + if issue := r.checkStmt(n, c, subElem, deferTyp); issue != nil { + return issue + } + } + } + return nil +} + +func (r *badDefer) checkStmt(n ast.Node, c *gosec.Context, subElem ast.Stmt, deferTyp deferType) *gosec.Issue { + switch stmt := subElem.(type) { + case *ast.AssignStmt: + for _, rh := range stmt.Rhs { + if e, isCallExp := rh.(*ast.CallExpr); isCallExp { + return r.checkChild(n, c, e, deferTyp) + } + } + case *ast.IfStmt: + if s, is := stmt.Init.(*ast.AssignStmt); is { + return r.checkStmt(n, c, s, deferTyp) + } + } + return nil +} + // NewDeferredClosing detects unsafe defer of error returning methods func NewDeferredClosing(id string, conf gosec.Config) (gosec.Rule, []ast.Node) { return &badDefer{ diff --git a/testutils/source.go b/testutils/source.go index 5ce6ad22c6..b336a4834b 100644 --- a/testutils/source.go +++ b/testutils/source.go @@ -1984,7 +1984,6 @@ func main() { {[]string{`package main import ( - "bufio" "fmt" "io/ioutil" "os" @@ -2016,16 +2015,86 @@ func main() { defer check(err) fmt.Printf("wrote %d bytes\n", n2) - n3, err := f.WriteString("writes\n") - fmt.Printf("wrote %d bytes\n", n3) +}`}, 1, gosec.NewConfig()}, + {[]string{`package main - f.Sync() +import ( + "fmt" + "io/ioutil" + "log" + "os" +) - w := bufio.NewWriter(f) - n4, err := w.WriteString("buffered\n") - fmt.Printf("wrote %d bytes\n", n4) +func check(e error) { + if e != nil { + panic(e) + } +} - w.Flush() +func main() { + + d1 := []byte("hello\ngo\n") + err := ioutil.WriteFile("/tmp/dat1", d1, 0744) + check(err) + + allowed := ioutil.WriteFile("/tmp/dat1", d1, 0600) + check(allowed) + + f, err := os.Create("/tmp/dat2") + check(err) + + defer func() { + if err := f.Close(); err != nil { + log.Println(err) + } + }() + + d2 := []byte{115, 111, 109, 101, 10} + n2, err := f.Write(d2) + + defer check(err) + fmt.Printf("wrote %d bytes\n", n2) + +}`}, 1, gosec.NewConfig()}, + {[]string{`package main + +import ( + "fmt" + "io/ioutil" + "log" + "os" +) + +func check(e error) { + if e != nil { + panic(e) + } +} + +func main() { + + d1 := []byte("hello\ngo\n") + err := ioutil.WriteFile("/tmp/dat1", d1, 0744) + check(err) + + allowed := ioutil.WriteFile("/tmp/dat1", d1, 0600) + check(allowed) + + f, err := os.Create("/tmp/dat2") + check(err) + + defer func() { + err := f.Close() + if err != nil { + log.Println(err) + } + }() + + d2 := []byte{115, 111, 109, 101, 10} + n2, err := f.Write(d2) + + defer check(err) + fmt.Printf("wrote %d bytes\n", n2) }`}, 1, gosec.NewConfig()}, }