Skip to content

Commit

Permalink
Merge pull request #11 from lfportal/default-signifies-exhaustive
Browse files Browse the repository at this point in the history
Add `default-signifies-exhasutive` flag to prevent default clause from automatically passing checks
  • Loading branch information
alecthomas authored Sep 24, 2024
2 parents 187668c + 6c750dd commit 42dc6df
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 21 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ mysumtype.go:18:2: exhaustiveness check failed for sum type 'MySumType': missing
```

Adding either a `default` clause or a clause to handle `*VariantB` will cause
exhaustive checks to pass.
exhaustive checks to pass. To prevent `default` clauses from automatically
passing checks, set the `-default-signifies-exhasutive=false` flag.

As a special case, if the type switch statement contains a `default` clause
that always panics, then exhaustiveness checks are still performed.
Expand Down
10 changes: 6 additions & 4 deletions check.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,15 @@ func (e inexhaustiveError) Names() []string {

// check does exhaustiveness checking for the given sum type definitions in the
// given package. Every instance of inexhaustive case analysis is returned.
func check(pkg *packages.Package, defs []sumTypeDef) []error {
func check(pkg *packages.Package, defs []sumTypeDef, config Config) []error {
var errs []error
for _, astfile := range pkg.Syntax {
ast.Inspect(astfile, func(n ast.Node) bool {
swtch, ok := n.(*ast.TypeSwitchStmt)
if !ok {
return true
}
if err := checkSwitch(pkg, defs, swtch); err != nil {
if err := checkSwitch(pkg, defs, swtch, config); err != nil {
errs = append(errs, err)
}
return true
Expand All @@ -67,8 +67,9 @@ func checkSwitch(
pkg *packages.Package,
defs []sumTypeDef,
swtch *ast.TypeSwitchStmt,
config Config,
) error {
def, missing := missingVariantsInSwitch(pkg, defs, swtch)
def, missing := missingVariantsInSwitch(pkg, defs, swtch, config)
if len(missing) > 0 {
return inexhaustiveError{
Position: pkg.Fset.Position(swtch.Pos()),
Expand All @@ -87,6 +88,7 @@ func missingVariantsInSwitch(
pkg *packages.Package,
defs []sumTypeDef,
swtch *ast.TypeSwitchStmt,
config Config,
) (*sumTypeDef, []types.Object) {
asserted := findTypeAssertExpr(swtch)
ty := pkg.TypesInfo.TypeOf(asserted)
Expand All @@ -97,7 +99,7 @@ func missingVariantsInSwitch(
return nil, nil
}
variantExprs, hasDefault := switchVariants(swtch)
if hasDefault && !defaultClauseAlwaysPanics(swtch) {
if config.DefaultSignifiesExhaustive && hasDefault && !defaultClauseAlwaysPanics(swtch) {
// A catch-all case defeats all exhaustiveness checks.
return def, nil
}
Expand Down
53 changes: 42 additions & 11 deletions check_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,12 @@ func main() {
tmpdir, pkgs := setupPackages(t, code)
defer teardownPackage(t, tmpdir)

errs := Run(pkgs)
errs := Run(pkgs, Config{DefaultSignifiesExhaustive: true})
assert.Equal(t, 1, len(errs))
assert.Equal(t, []string{"B"}, missingNames(t, errs[0]))
}

// TestMissingTwo tests that we detect a two missing variants.
// TestMissingTwo tests that we detect two missing variants.
func TestMissingTwo(t *testing.T) {
code := `
package gochecksumtype
Expand All @@ -60,7 +60,7 @@ func main() {
tmpdir, pkgs := setupPackages(t, code)
defer teardownPackage(t, tmpdir)

errs := Run(pkgs)
errs := Run(pkgs, Config{DefaultSignifiesExhaustive: true})
assert.Equal(t, 1, len(errs))
assert.Equal(t, []string{"B", "C"}, missingNames(t, errs[0]))
}
Expand Down Expand Up @@ -91,7 +91,7 @@ func main() {
tmpdir, pkgs := setupPackages(t, code)
defer teardownPackage(t, tmpdir)

errs := Run(pkgs)
errs := Run(pkgs, Config{DefaultSignifiesExhaustive: true})
assert.Equal(t, 1, len(errs))
assert.Equal(t, []string{"B"}, missingNames(t, errs[0]))
}
Expand Down Expand Up @@ -122,13 +122,13 @@ func main() {
tmpdir, pkgs := setupPackages(t, code)
defer teardownPackage(t, tmpdir)

errs := Run(pkgs)
errs := Run(pkgs, Config{DefaultSignifiesExhaustive: true})
assert.Equal(t, 0, len(errs))
}

// TestNoMissingDefault tests that even if we have a missing variant, a default
// case should thwart exhaustiveness checking.
func TestNoMissingDefault(t *testing.T) {
// TestNoMissingDefaultWithDefaultSignifiesExhaustive tests that even if we have a missing variant, a default
// case should thwart exhaustiveness checking when Config.DefaultSignifiesExhaustive is true.
func TestNoMissingDefaultWithDefaultSignifiesExhaustive(t *testing.T) {
code := `
package gochecksumtype
Expand All @@ -152,10 +152,41 @@ func main() {
tmpdir, pkgs := setupPackages(t, code)
defer teardownPackage(t, tmpdir)

errs := Run(pkgs)
errs := Run(pkgs, Config{DefaultSignifiesExhaustive: true})
assert.Equal(t, 0, len(errs))
}

// TestNoMissingDefaultAndDefaultDoesNotSignifiesExhaustive tests that even if we have a missing variant, a default
// case should thwart exhaustiveness checking when Config.DefaultSignifiesExhaustive is false.
func TestNoMissingDefaultAndDefaultDoesNotSignifiesExhaustive(t *testing.T) {
code := `
package gochecksumtype
//sumtype:decl
type T interface { sealed() }
type A struct {}
func (a *A) sealed() {}
type B struct {}
func (b *B) sealed() {}
func main() {
switch T(nil).(type) {
case *A:
default:
println("legit catch all goes here")
}
}
`
tmpdir, pkgs := setupPackages(t, code)
defer teardownPackage(t, tmpdir)

errs := Run(pkgs, Config{DefaultSignifiesExhaustive: false})
assert.Equal(t, 1, len(errs))
assert.Equal(t, []string{"B"}, missingNames(t, errs[0]))
}

// TestNotSealed tests that we report an error if one tries to declare a sum
// type with an unsealed interface.
func TestNotSealed(t *testing.T) {
Expand All @@ -170,7 +201,7 @@ func main() {}
tmpdir, pkgs := setupPackages(t, code)
defer teardownPackage(t, tmpdir)

errs := Run(pkgs)
errs := Run(pkgs, Config{DefaultSignifiesExhaustive: true})
assert.Equal(t, 1, len(errs))
assert.Equal(t, "T", errs[0].(unsealedError).Decl.TypeName)
}
Expand All @@ -189,7 +220,7 @@ func main() {}
tmpdir, pkgs := setupPackages(t, code)
defer teardownPackage(t, tmpdir)

errs := Run(pkgs)
errs := Run(pkgs, Config{DefaultSignifiesExhaustive: true})
assert.Equal(t, 1, len(errs))
assert.Equal(t, "T", errs[0].(notInterfaceError).Decl.TypeName)
}
Expand Down
17 changes: 14 additions & 3 deletions cmd/go-check-sumtype/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,22 @@ import (

func main() {
log.SetFlags(0)

defaultSignifiesExhaustive := flag.Bool(
"default-signifies-exhaustive",
true,
"Presence of \"default\" case in switch statements satisfies exhaustiveness, if all members are not listed.",
)

flag.Parse()
if len(flag.Args()) < 1 {
if flag.NArg() < 1 {
log.Fatalf("Usage: sumtype <packages>\n")
}
args := os.Args[1:]
args := os.Args[flag.NFlag()+1:]

config := gochecksumtype.Config{
DefaultSignifiesExhaustive: *defaultSignifiesExhaustive,
}

conf := &packages.Config{
Mode: packages.NeedSyntax | packages.NeedTypesInfo | packages.NeedTypes | packages.NeedTypesSizes |
Expand All @@ -37,7 +48,7 @@ func main() {
if err != nil {
log.Fatal(err)
}
if errs := gochecksumtype.Run(pkgs); len(errs) > 0 {
if errs := gochecksumtype.Run(pkgs, config); len(errs) > 0 {
var list []string
for _, err := range errs {
list = append(list, err.Error())
Expand Down
5 changes: 5 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package gochecksumtype

type Config struct {
DefaultSignifiesExhaustive bool
}
4 changes: 2 additions & 2 deletions run.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package gochecksumtype
import "golang.org/x/tools/go/packages"

// Run sumtype checking on the given packages.
func Run(pkgs []*packages.Package) []error {
func Run(pkgs []*packages.Package, config Config) []error {
var errs []error

decls, err := findSumTypeDecls(pkgs)
Expand All @@ -18,7 +18,7 @@ func Run(pkgs []*packages.Package) []error {
}

for _, pkg := range pkgs {
if pkgErrs := check(pkg, defs); pkgErrs != nil {
if pkgErrs := check(pkg, defs, config); pkgErrs != nil {
errs = append(errs, pkgErrs...)
}
}
Expand Down

0 comments on commit 42dc6df

Please sign in to comment.