Skip to content

Commit

Permalink
follow sql.Rows across ssa.UnOps (#17)
Browse files Browse the repository at this point in the history
* rename noCallClose and isCloseCall

These functions look for Err() calls. Now the naming reflects that a
bit better.

* rename getResVal to getRows val

* rename getReqCall to getCallReturnsRow

Signed-off-by: Steven Danna <StevenSDanna@gmail.com>

* rename notCheck to errCallMissing

* remove unused isNamedType function

* reverse logic from errNotCalled to errCalled

I think that this structuring will be a bit eaiser to follow for
future changes.

* remove ineffective append

We only range over this slice once and the append doesn't impact the
range.

* follows sql.Rows across ssa.UnOps

Previously, code such as the following:

```go
rows, err := db.Query("select 1")
if err != nil {
	return err
}
defer func() { _ = rows.Close() }()
for rows.Next() {}
return rows.Err()
```

Would produce a warning despite the clear call to rows.Err(). It
appears this is cause by the fact that the closure around rows.Close()
induces a level of indirection that we could not follow:

```
func issue16() error:
0:                                                                entry P:0 S:2
	t0 = new *database/sql.Rows (rows)                  **database/sql.Rows
	t1 = *db                                               *database/sql.DB
	t2 = (*database/sql.DB).Query(t1, "select 1":string, nil:[]interface{}...) (*database/sql.Rows, error)
	t3 = extract t2 #0                                   *database/sql.Rows
	*t0 = t3
	t4 = extract t2 #1                                                error
	t5 = t4 != nil:error                                               bool
	if t5 goto 1 else 2
1:                                                              if.then P:1 S:0
	rundefers
	return t4
2:                                                              if.done P:1 S:1
	t6 = make closure issue16$1 [t0]                                 func()
	defer t6()
	jump 5
3:                                                              recover P:0 S:0
	return nil:error
4:                                                             for.done P:1 S:0
	t7 = *t0                                             *database/sql.Rows
	t8 = (*database/sql.Rows).Err(t7)                                 error
	rundefers
	return t8
5:                                                             for.loop P:2 S:2
	t9 = *t0                                             *database/sql.Rows
	t10 = (*database/sql.Rows).Next(t9)                                bool
	if t10 goto 5 else 4
```

Fixes #16

* bugfix: fix bad return in run method

```
passes/rowserr/rowserr.go:67:3: too many arguments to return
	have (nil, nil)
	want ()
```
  • Loading branch information
stevendanna authored Mar 15, 2021
1 parent 080ff0b commit d907ca7
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 36 deletions.
74 changes: 38 additions & 36 deletions passes/rowserr/rowserr.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func (r runner) run(pass *analysis.Pass, pkgPath string) {
rowsType := pkg.Type(rowsName)
if rowsType == nil {
// skip checking
return nil, nil
return
}

r.rowsObj = rowsType.Object()
Expand Down Expand Up @@ -102,66 +102,69 @@ func (r runner) run(pass *analysis.Pass, pkgPath string) {

for _, b := range f.Blocks {
for i := range b.Instrs {
if r.notCheck(b, i) {
if r.errCallMissing(b, i) {
pass.Reportf(b.Instrs[i].Pos(), "rows.Err must be checked")
}
}
}
}
}

func (r *runner) notCheck(b *ssa.BasicBlock, i int) (ret bool) {
call, ok := r.getReqCall(b.Instrs[i])
func (r *runner) errCallMissing(b *ssa.BasicBlock, i int) (ret bool) {
call, ok := r.getCallReturnsRow(b.Instrs[i])
if !ok {
return false
}

for _, cRef := range *call.Referrers() {
val, ok := r.getResVal(cRef)
val, ok := r.getRowsVal(cRef)
if !ok {
continue
}
if len(*val.Referrers()) == 0 {
return true
continue
}

resRefs := *val.Referrers()
var notCallClose func(resRef ssa.Instruction) bool
notCallClose = func(resRef ssa.Instruction) bool {
var errCalled func(resRef ssa.Instruction) bool
errCalled = func(resRef ssa.Instruction) bool {
switch resRef := resRef.(type) {
case *ssa.Phi:
resRefs = append(resRefs, *resRef.Referrers()...)
for _, rf := range *resRef.Referrers() {
if !notCallClose(rf) {
return false
if errCalled(rf) {
return true
}
}

case *ssa.Store: // Call in Closure function
if len(*resRef.Addr.Referrers()) == 0 {
return true
}

for _, aref := range *resRef.Addr.Referrers() {
if c, ok := aref.(*ssa.MakeClosure); ok {
switch c := aref.(type) {
case *ssa.MakeClosure:
f := c.Fn.(*ssa.Function)
if r.noImportedDBSQL(f) {
// skip this
return false
continue
}
called := r.isClosureCalled(c)

return r.calledInFunc(f, called)
if r.calledInFunc(f, called) {
return true
}
case *ssa.UnOp:
for _, rf := range *c.Referrers() {
if errCalled(rf) {
return true
}
}
}
}
case *ssa.Call: // Indirect function call
if r.isCloseCall(resRef) {
return false
if r.isErrCall(resRef) {
return true
}
if f, ok := resRef.Call.Value.(*ssa.Function); ok {
for _, b := range f.Blocks {
for i := range b.Instrs {
return r.notCheck(b, i)
if !r.errCallMissing(b, i) {
return true
}
}
}
}
Expand All @@ -173,18 +176,18 @@ func (r *runner) notCheck(b *ssa.BasicBlock, i int) (ret bool) {
}

for _, ccall := range *bOp.Referrers() {
if r.isCloseCall(ccall) {
return false
if r.isErrCall(ccall) {
return true
}
}
}
}

return true
return false
}

for _, resRef := range resRefs {
if !notCallClose(resRef) {
if errCalled(resRef) {
return false
}
}
Expand All @@ -193,7 +196,7 @@ func (r *runner) notCheck(b *ssa.BasicBlock, i int) (ret bool) {
return true
}

func (r *runner) getReqCall(instr ssa.Instruction) (*ssa.Call, bool) {
func (r *runner) getCallReturnsRow(instr ssa.Instruction) (*ssa.Call, bool) {
call, ok := instr.(*ssa.Call)
if !ok {
return nil, false
Expand All @@ -213,7 +216,7 @@ func (r *runner) getReqCall(instr ssa.Instruction) (*ssa.Call, bool) {
return call, true
}

func (r *runner) getResVal(instr ssa.Instruction) (ssa.Value, bool) {
func (r *runner) getRowsVal(instr ssa.Instruction) (ssa.Value, bool) {
switch instr := instr.(type) {
case *ssa.Call:
if len(instr.Call.Args) == 1 && types.Identical(instr.Call.Args[0].Type(), r.rowsTyp) {
Expand Down Expand Up @@ -241,7 +244,7 @@ func (r *runner) getBodyOp(instr ssa.Instruction) (*ssa.UnOp, bool) {
return op, true
}

func (r *runner) isCloseCall(ccall ssa.Instruction) bool {
func (r *runner) isErrCall(ccall ssa.Instruction) bool {
switch ccall := ccall.(type) {
case *ssa.Defer:
if ccall.Call.Value != nil && ccall.Call.Value.Name() == errMethod {
Expand Down Expand Up @@ -311,19 +314,18 @@ func (r *runner) calledInFunc(f *ssa.Function, called bool) bool {
if vCall, ok := v.(*ssa.Call); ok {
if vCall.Call.Value != nil && vCall.Call.Value.Name() == errMethod {
if called {
return false
return true
}
}
}
}
}
default:
if r.notCheck(b, i) || !called {
return true
if r.errCallMissing(b, i) || !called {
return false
}
}
}
}

return true
return false
}
12 changes: 12 additions & 0 deletions passes/rowserr/testdata/src/a/issue16.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package a

func issue16() error {
rows, err := db.Query("select 1")
if err != nil {
return err
}
defer func() { _ = rows.Close() }()
for rows.Next() {
}
return rows.Err()
}

0 comments on commit d907ca7

Please sign in to comment.