Skip to content

Commit

Permalink
Add check for wrong error assertion
Browse files Browse the repository at this point in the history
Check this patterns:
`Expect(err).To(Equal(nil))`
`Expect(err).To(BeNil())`
`Expect(err == nil).To(BeTrue/BeFalse())`
`Expect(err == nil).To(Equal(true/false))`

Support new suppress comment: `ginkgo-linter:ignore-err-assert-warning`

Support new parameter: `Suppress-err-assertion`
  • Loading branch information
nunnatsa committed Jul 18, 2022
1 parent 4122b19 commit 3af492b
Show file tree
Hide file tree
Showing 14 changed files with 539 additions and 37 deletions.
11 changes: 7 additions & 4 deletions .github/workflows/sanity.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,12 @@ jobs:

- name: Functional Test
run: |-
set -x
cp ginkgolinter testdata/src/a
cd testdata/src/a
[[ $(./ginkgolinter ./... 2>&1 | wc -l) == 1969 ]]
[[ $(./ginkgolinter --suppress-len-assertion=true ./... 2>&1 | wc -l) == 1355 ]]
[[ $(./ginkgolinter --suppress-nil-assertion=true ./... 2>&1 | wc -l) == 614 ]]
[[ $(./ginkgolinter --suppress-nil-assertion=true --suppress-len-assertion=true ./... 2>&1 | wc -l) == 0 ]]
[[ $(./ginkgolinter ./... 2>&1 | wc -l) == 2071 ]]
[[ $(./ginkgolinter --suppress-len-assertion=true --suppress-err-assertion=true ./... 2>&1 | wc -l) == 1360 ]]
[[ $(./ginkgolinter --suppress-nil-assertion=true --suppress-err-assertion=true ./... 2>&1 | wc -l) == 614 ]]
[[ $(./ginkgolinter --suppress-nil-assertion=true --suppress-len-assertion=true ./... 2>&1 | wc -l) == 97 ]]
[[ $(./ginkgolinter --suppress-nil-assertion=true --suppress-len-assertion=true --suppress-err-assertion=true ./... 2>&1 | wc -l) == 0 ]]
26 changes: 24 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,28 @@ Or even (double negative):

`Ω(x != nil).Should(Not(BeTrue()))` => `Ω(x).Should(BeNil())`

### Wrong Error Assertion
The linter finds assertion of errors compared with nil, or to be equal nil, or to be nil. The linter suggests to use `Succeed` for functions or `HaveOccurred` for error values..

There are several wrong patterns:

```go
Expect(err == nil).To(Equal(true)) // should be: Expect(err).ToNot(HaveOccurred())
Expect(err == nil).To(BeFalse()) // should be: Expect(err).To(HaveOccurred())
Expect(err != nil).To(BeTrue()) // should be: Expect(err).To(HaveOccurred())
Expect(funcReturnsError()).To(BeNil()) // should be: Expect(HaveOccurred).To(Succeed())

and so on
```
It also supports the embedded `Not()` matcher; e.g.

`Ω(err == nil).Should(Not(BeTrue()))` => `Ω(x).Should(HaveOccurred())`

## Suppress the linter
### Suppress warning from command line
* Use the `suppress-len-assertion=true` flag to suppress the wrong length assertion warning
* Use the `suppress-nil-assertion=true` flag to suppress the wrong nil assertion warning
* Use the `--suppress-len-assertion=true` flag to suppress the wrong length assertion warning
* Use the `--suppress-nil-assertion=true` flag to suppress the wrong nil assertion warning
* Use the `--suppress-err-assertion=true` flag to suppress the wrong error assertion warning

### Suppress warning from the code
To suppress the wrong length assertion warning, add a comment with (only)
Expand All @@ -94,6 +112,10 @@ To suppress the wrong nil assertion warning, add a comment with (only)

`ginkgo-linter:ignore-nil-assert-warning`.

To suppress the wrong error assertion warning, add a comment with (only)

`ginkgo-linter:ignore-err-assert-warning`.

There are two options to use these comments:
1. If the comment is at the top of the file, supress the warning for the whole file; e.g.:
```go
Expand Down
161 changes: 139 additions & 22 deletions ginkgo_linter.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"go/ast"
"go/printer"
"go/token"
gotypes "go/types"
"golang.org/x/tools/go/analysis"

"github.com/nunnatsa/ginkgolinter/gomegahandler"
Expand Down Expand Up @@ -40,9 +41,14 @@ const (
linterName = "ginkgo-linter"
wrongLengthWarningTemplate = linterName + ": wrong length assertion; consider using `%s` instead"
wrongNilWarningTemplate = linterName + ": wrong nil assertion; consider using `%s` instead"
wrongErrWarningTemplate = linterName + ": wrong error assertion; consider using `%s` instead"
beEmpty = "BeEmpty"
beNil = "BeNil"
equal = "Equal"
not = "Not"
haveLen = "HaveLen"
succeed = "Succeed"
haveOccurred = "HaveOccurred"
expect = "Expect"
omega = "Ω"
expectWithOffset = "ExpectWithOffset"
Expand Down Expand Up @@ -82,14 +88,13 @@ This should be replaced with:
}

a.Flags.Init("ginkgolinter", flag.ExitOnError)
a.Flags.Var(&linter.suppress.Len, "Suppress-len-assertion", "Suppress warning for wrong length assertions")
a.Flags.Var(&linter.suppress.Nil, "Suppress-nil-assertion", "Suppress warning for wrong nil assertions")
a.Flags.Var(&linter.suppress.Len, "suppress-len-assertion", "Suppress warning for wrong length assertions")
a.Flags.Var(&linter.suppress.Nil, "suppress-nil-assertion", "Suppress warning for wrong nil assertions")
a.Flags.Var(&linter.suppress.Err, "suppress-err-assertion", "Suppress warning for wrong error assertions")

return a
}

//var Analyzer = NewAnalyzer()

// main assertion function
func (l *ginkgoLinter) run(pass *analysis.Pass) (interface{}, error) {
if l.suppress.AllTrue() {
Expand Down Expand Up @@ -152,15 +157,25 @@ func checkExpression(pass *analysis.Pass, exprSuppress types.Suppress, actualArg
if !bool(exprSuppress.Len) && isActualIsLenFunc(actualArg) {

return checkLengthMatcher(assertionExp, pass, handler, oldExpr)
} else if !exprSuppress.Nil {
} else {
if nilable, compOp := getNilableFromComparison(actualArg); nilable != nil {
if IsExprError(pass, nilable) {
if exprSuppress.Err {
return true
}
} else if exprSuppress.Nil {
return true
}

return checkNilMatcher(assertionExp, pass, nilable, handler, compOp == token.NEQ, oldExpr)

} else if IsExprError(pass, actualArg) {
return bool(exprSuppress.Err) || checkNilError(pass, assertionExp, handler, actualArg, oldExpr)

} else {
return checkEqualNil(pass, assertionExp, handler, oldExpr)
return bool(exprSuppress.Nil) || checkEqualNil(pass, assertionExp, handler, actualArg, oldExpr)
}

}
return true
}

// Check if the "actual" argument is a call to the golang built-in len() function
Expand All @@ -187,7 +202,7 @@ func checkLengthMatcher(exp *ast.CallExpr, pass *analysis.Pass, handler gomegaha
}

switch matcherFuncName {
case "Equal":
case equal:
handleEqualMatcher(matcher, pass, exp, handler, oldExp)
return false

Expand All @@ -198,7 +213,7 @@ func checkLengthMatcher(exp *ast.CallExpr, pass *analysis.Pass, handler gomegaha
case "BeNumerically":
return handleBeNumerically(matcher, pass, exp, handler, oldExp)

case "Not":
case not:
reverseAssertionFuncLogic(exp)
exp.Args[0] = exp.Args[0].(*ast.CallExpr).Args[0]
return checkLengthMatcher(exp, pass, handler, oldExp)
Expand All @@ -221,7 +236,7 @@ func checkNilMatcher(exp *ast.CallExpr, pass *analysis.Pass, nilable ast.Expr, h
}

switch matcherFuncName {
case "Equal":
case equal:
handleEqualNilMatcher(matcher, pass, exp, handler, nilable, notEqual, oldExp)

case "BeTrue":
Expand All @@ -231,7 +246,7 @@ func checkNilMatcher(exp *ast.CallExpr, pass *analysis.Pass, nilable ast.Expr, h
reverseAssertionFuncLogic(exp)
handleNilBeBoolMatcher(pass, exp, handler, nilable, notEqual, oldExp)

case "Not":
case not:
reverseAssertionFuncLogic(exp)
exp.Args[0] = exp.Args[0].(*ast.CallExpr).Args[0]
return checkNilMatcher(exp, pass, nilable, handler, notEqual, oldExp)
Expand All @@ -242,7 +257,7 @@ func checkNilMatcher(exp *ast.CallExpr, pass *analysis.Pass, nilable ast.Expr, h
return false
}

func checkEqualNil(pass *analysis.Pass, assertionExp *ast.CallExpr, handler gomegahandler.Handler, oldExpr string) bool {
func checkNilError(pass *analysis.Pass, assertionExp *ast.CallExpr, handler gomegahandler.Handler, actualArg ast.Expr, oldExpr string) bool {
if len(assertionExp.Args) == 0 {
return true
}
Expand All @@ -258,7 +273,58 @@ func checkEqualNil(pass *analysis.Pass, assertionExp *ast.CallExpr, handler gome
}

switch funcName {
case "Equal":
case beNil: // no additional processing needed.
case equal:

if len(equalFuncExpr.Args) == 0 {
return true
}

nilable, ok := equalFuncExpr.Args[0].(*ast.Ident)
if !ok || nilable.Name != "nil" {
return true
}

case not:
reverseAssertionFuncLogic(assertionExp)
assertionExp.Args[0] = assertionExp.Args[0].(*ast.CallExpr).Args[0]
return checkNilError(pass, assertionExp, handler, actualArg, oldExpr)
default:
return true
}

var newFuncName string
if _, ok := actualArg.(*ast.CallExpr); ok {
newFuncName = succeed
} else {
reverseAssertionFuncLogic(assertionExp)
newFuncName = haveOccurred
}

handler.ReplaceFunction(equalFuncExpr, ast.NewIdent(newFuncName))
equalFuncExpr.Args = nil

report(pass, assertionExp, wrongErrWarningTemplate, oldExpr)
return false
}

func checkEqualNil(pass *analysis.Pass, assertionExp *ast.CallExpr, handler gomegahandler.Handler, actualArg ast.Expr, oldExpr string) bool {
if len(assertionExp.Args) == 0 {
return true
}

equalFuncExpr, ok := assertionExp.Args[0].(*ast.CallExpr)
if !ok {
return true
}

funcName, ok := handler.GetActualFuncName(equalFuncExpr)
if !ok {
return true
}

switch funcName {
case equal:
if len(equalFuncExpr.Args) == 0 {
return true
}
Expand All @@ -275,10 +341,10 @@ func checkEqualNil(pass *analysis.Pass, assertionExp *ast.CallExpr, handler gome

return false

case "Not":
case not:
reverseAssertionFuncLogic(assertionExp)
assertionExp.Args[0] = assertionExp.Args[0].(*ast.CallExpr).Args[0]
return checkEqualNil(pass, assertionExp, handler, oldExpr)
return checkEqualNil(pass, assertionExp, handler, actualArg, oldExpr)
default:
return true
}
Expand Down Expand Up @@ -427,19 +493,36 @@ func handleEqualNilMatcher(matcher *ast.CallExpr, pass *analysis.Pass, exp *ast.
return
}

handler.ReplaceFunction(exp.Args[0].(*ast.CallExpr), ast.NewIdent(beNil))
newFuncName, isError := handleNilComparisonErr(pass, exp, nilable)

handler.ReplaceFunction(exp.Args[0].(*ast.CallExpr), ast.NewIdent(newFuncName))
exp.Args[0].(*ast.CallExpr).Args = nil

reportNilAssertion(pass, exp, handler, nilable, notEqual, oldExp)
reportNilAssertion(pass, exp, handler, nilable, notEqual, oldExp, isError)
}

func handleNilBeBoolMatcher(pass *analysis.Pass, exp *ast.CallExpr, handler gomegahandler.Handler, nilable ast.Expr, notEqual bool, oldExp string) {
handler.ReplaceFunction(exp.Args[0].(*ast.CallExpr), ast.NewIdent(beNil))
newFuncName, isError := handleNilComparisonErr(pass, exp, nilable)
handler.ReplaceFunction(exp.Args[0].(*ast.CallExpr), ast.NewIdent(newFuncName))
exp.Args[0].(*ast.CallExpr).Args = nil

reportNilAssertion(pass, exp, handler, nilable, notEqual, oldExp)
reportNilAssertion(pass, exp, handler, nilable, notEqual, oldExp, isError)
}

func handleNilComparisonErr(pass *analysis.Pass, exp *ast.CallExpr, nilable ast.Expr) (string, bool) {
newFuncName := beNil
isError := IsExprError(pass, nilable)
if isError {
if _, ok := nilable.(*ast.CallExpr); ok {
newFuncName = succeed
} else {
reverseAssertionFuncLogic(exp)
newFuncName = haveOccurred
}
}

return newFuncName, isError
}
func isAssertionFunc(name string) bool {
switch name {
case "To", "ToNot", "NotTo", "Should", "ShouldNot":
Expand All @@ -454,7 +537,7 @@ func reportLengthAssertion(pass *analysis.Pass, expr *ast.CallExpr, handler gome
report(pass, expr, wrongLengthWarningTemplate, oldExpr)
}

func reportNilAssertion(pass *analysis.Pass, expr *ast.CallExpr, handler gomegahandler.Handler, nilable ast.Expr, notEqual bool, oldExpr string) {
func reportNilAssertion(pass *analysis.Pass, expr *ast.CallExpr, handler gomegahandler.Handler, nilable ast.Expr, notEqual bool, oldExpr string, isError bool) {
changed := replaceNilActualArg(expr.Fun.(*ast.SelectorExpr).X.(*ast.CallExpr), handler, nilable)
if !changed {
return
Expand All @@ -463,8 +546,12 @@ func reportNilAssertion(pass *analysis.Pass, expr *ast.CallExpr, handler gomegah
if notEqual {
reverseAssertionFuncLogic(expr)
}
template := wrongNilWarningTemplate
if isError {
template = wrongErrWarningTemplate
}

report(pass, expr, wrongNilWarningTemplate, oldExpr)
report(pass, expr, template, oldExpr)
}

func report(pass *analysis.Pass, expr *ast.CallExpr, messageTemplate, oldExpr string) {
Expand Down Expand Up @@ -514,3 +601,33 @@ func goFmt(fset *token.FileSet, x ast.Expr) string {
_ = printer.Fprint(&b, fset, x)
return b.String()
}

var errorType *gotypes.Interface

func init() {
errorType = gotypes.Universe.Lookup("error").Type().Underlying().(*gotypes.Interface)
}

func IsError(t gotypes.Type) bool {
return gotypes.Implements(t, errorType)
}

func IsExprError(pass *analysis.Pass, expr ast.Expr) bool {
actualArgType := pass.TypesInfo.TypeOf(expr)
switch t := actualArgType.(type) {
case *gotypes.Named:
if IsError(actualArgType) {
return true
}
case *gotypes.Tuple:
if t.Len() > 0 {
switch t0 := t.At(0).Type().(type) {
case *gotypes.Named, *gotypes.Pointer:
if IsError(t0) {
return true
}
}
}
}
return false
}
Loading

0 comments on commit 3af492b

Please sign in to comment.