From 471b3aff5b91d8482753785e7ecc7eedd301579b Mon Sep 17 00:00:00 2001 From: Fredrik de Vibe Date: Sat, 14 Oct 2023 22:02:02 +0200 Subject: [PATCH] Add support for external functions handling rows.Err() --- passes/rowserr/rowserr.go | 155 +++++++++++++++++------------ passes/rowserr/testdata/src/a/a.go | 38 +++++++ 2 files changed, 131 insertions(+), 62 deletions(-) diff --git a/passes/rowserr/rowserr.go b/passes/rowserr/rowserr.go index a142a67..a52b443 100644 --- a/passes/rowserr/rowserr.go +++ b/passes/rowserr/rowserr.go @@ -1,6 +1,7 @@ package rowserr import ( + "fmt" "go/ast" "go/types" @@ -100,88 +101,118 @@ func (r runner) run(pass *analysis.Pass, pkgPath string) { for _, b := range f.Blocks { for i := range b.Instrs { - if r.errCallMissing(b, i) { - pass.Reportf(b.Instrs[i].Pos(), "rows.Err must be checked") + if r.errCallMissing(b.Instrs[i]) { + pass.Reportf(b.Instrs[i].Pos(), fmt.Sprintf("rows.Err must be checked %d", b.Instrs[i].Pos())) } } } } } -func (r *runner) errCallMissing(b *ssa.BasicBlock, i int) (ret bool) { - call, ok := r.getCallReturnsRow(b.Instrs[i]) +func (r *runner) errCalled(resRef ssa.Instruction) bool { + switch resRef := resRef.(type) { + case *ssa.Phi: + for _, rf := range *resRef.Referrers() { + if r.errCalled(rf) { + return true + } + } + case *ssa.Store: // Call in Closure function + for _, aref := range *resRef.Addr.Referrers() { + switch c := aref.(type) { + case *ssa.MakeClosure: + f := c.Fn.(*ssa.Function) + called := r.isClosureCalled(c) + if r.calledInFunc(f, called) { + return true + } + case *ssa.UnOp: + for _, rf := range *c.Referrers() { + if r.errCalled(rf) { + return true + } + } + } + } + case *ssa.Call: // Indirect function call + if r.isErrCall(resRef) { + return true + } + if r.errCalledInFunc(resRef.Call.Value, resRef.Parent().Name()) { + return true + } + case *ssa.FieldAddr: + for _, bRef := range *resRef.Referrers() { + bOp, ok := r.getBodyOp(bRef) + if !ok { + continue + } + + for _, ccall := range *bOp.Referrers() { + if r.isErrCall(ccall) { + return true + } + } + } + case *ssa.Defer: + if r.isErrCall(resRef) { + return true + } + if r.errCalledInFunc(resRef.Call.Value, resRef.Parent().Name()) { + return true + } + } + + return false +} + +func (r *runner) errCalledInFunc(val ssa.Value, name string) bool { + var ( + f *ssa.Function + ok bool + ) + + if f, ok = val.(*ssa.Function); !ok { + var c *ssa.MakeClosure + if c, ok = val.(*ssa.MakeClosure); ok { + f, ok = c.Fn.(*ssa.Function) + } + if !ok { + return false + } + } + + for _, b := range f.Blocks { + for i := range b.Instrs { + if !r.errCallMissing(b.Instrs[i]) { + return true + } + } + } + return false +} + +func (r *runner) errCallMissing(instr ssa.Instruction) (ret bool) { + call, ok := r.getCallReturnsRow(instr) if !ok { return false } for _, cRef := range *call.Referrers() { val, ok := r.getRowsVal(cRef) + // fmt.Printf("- %++v\n", val) if !ok { + // fmt.Printf(" nok\n") continue } if len(*val.Referrers()) == 0 { + // fmt.Printf(" 0 refs") continue } resRefs := *val.Referrers() - var errCalled func(resRef ssa.Instruction) bool - errCalled = func(resRef ssa.Instruction) bool { - switch resRef := resRef.(type) { - case *ssa.Phi: - for _, rf := range *resRef.Referrers() { - if errCalled(rf) { - return true - } - } - case *ssa.Store: // Call in Closure function - for _, aref := range *resRef.Addr.Referrers() { - switch c := aref.(type) { - case *ssa.MakeClosure: - f := c.Fn.(*ssa.Function) - called := r.isClosureCalled(c) - if r.calledInFunc(f, called) { - return true - } - case *ssa.UnOp: - for _, rf := range *c.Referrers() { - if errCalled(rf) { - return true - } - } - } - } - case *ssa.Call: // Indirect function call - if r.isErrCall(resRef) { - return true - } - if f, ok := resRef.Call.Value.(*ssa.Function); ok { - for _, b := range f.Blocks { - for i := range b.Instrs { - if !r.errCallMissing(b, i) { - return true - } - } - } - } - case *ssa.FieldAddr: - for _, bRef := range *resRef.Referrers() { - bOp, ok := r.getBodyOp(bRef) - if !ok { - continue - } - - for _, ccall := range *bOp.Referrers() { - if r.isErrCall(ccall) { - return true - } - } - } - } - - return false - } for _, resRef := range resRefs { - if errCalled(resRef) { + if r.errCalled(resRef) { return false } } @@ -294,7 +325,7 @@ func (r *runner) calledInFunc(f *ssa.Function, called bool) bool { } } default: - if r.errCallMissing(b, i) || !called { + if r.errCallMissing(b.Instrs[i]) || !called { return false } } diff --git a/passes/rowserr/testdata/src/a/a.go b/passes/rowserr/testdata/src/a/a.go index 8e47425..a7b88b0 100644 --- a/passes/rowserr/testdata/src/a/a.go +++ b/passes/rowserr/testdata/src/a/a.go @@ -98,3 +98,41 @@ func f10() { } resCloser(rows) } + +func standAloneCloser(rs *sql.Rows) { + _ = rs.Err() +} + +func f11() { + rows, _ := db.Query("") // OK + defer standAloneCloser(rows) +} + +func f12() { + rows, _ := db.Query("") // OK + defer func(rows *sql.Rows) { + standAloneCloser(rows) + }(rows) +} + +func returningCloser(rs *sql.Rows) error { + return rs.Err() +} + +func f13() (err error) { + rows, _ := db.Query("") // OK + defer func(rows *sql.Rows) { + returningCloser(rows) + }(rows) + + return err +} + +func f14() (err error) { + rows, _ := db.Query("") // OK + defer func(rows *sql.Rows) { + err = returningCloser(rows) + }(rows) + + return err +}