Skip to content

Commit

Permalink
feat: do not report types implementing sql.Scanner
Browse files Browse the repository at this point in the history
  • Loading branch information
tmzane committed Nov 10, 2024
1 parent 3799ac8 commit 974874b
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 145 deletions.
168 changes: 51 additions & 117 deletions builtins.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,131 +3,65 @@ package musttag
// builtins is a set of functions supported out of the box.
var builtins = []Func{
// https://pkg.go.dev/encoding/json
{
Name: "encoding/json.Marshal", Tag: "json", ArgPos: 0,
ifaceWhitelist: []string{"encoding/json.Marshaler", "encoding.TextMarshaler"},
},
{
Name: "encoding/json.MarshalIndent", Tag: "json", ArgPos: 0,
ifaceWhitelist: []string{"encoding/json.Marshaler", "encoding.TextMarshaler"},
},
{
Name: "encoding/json.Unmarshal", Tag: "json", ArgPos: 1,
ifaceWhitelist: []string{"encoding/json.Unmarshaler", "encoding.TextUnmarshaler"},
},
{
Name: "(*encoding/json.Encoder).Encode", Tag: "json", ArgPos: 0,
ifaceWhitelist: []string{"encoding/json.Marshaler", "encoding.TextMarshaler"},
},
{
Name: "(*encoding/json.Decoder).Decode", Tag: "json", ArgPos: 0,
ifaceWhitelist: []string{"encoding/json.Unmarshaler", "encoding.TextUnmarshaler"},
},
{"encoding/json.Marshal", "json", 0, []string{"encoding/json.Marshaler", "encoding.TextMarshaler"}},
{"encoding/json.MarshalIndent", "json", 0, []string{"encoding/json.Marshaler", "encoding.TextMarshaler"}},
{"encoding/json.Unmarshal", "json", 1, []string{"encoding/json.Unmarshaler", "encoding.TextUnmarshaler"}},
{"(*encoding/json.Encoder).Encode", "json", 0, []string{"encoding/json.Marshaler", "encoding.TextMarshaler"}},
{"(*encoding/json.Decoder).Decode", "json", 0, []string{"encoding/json.Unmarshaler", "encoding.TextUnmarshaler"}},

// https://pkg.go.dev/encoding/xml
{
Name: "encoding/xml.Marshal", Tag: "xml", ArgPos: 0,
ifaceWhitelist: []string{"encoding/xml.Marshaler", "encoding.TextMarshaler"},
},
{
Name: "encoding/xml.MarshalIndent", Tag: "xml", ArgPos: 0,
ifaceWhitelist: []string{"encoding/xml.Marshaler", "encoding.TextMarshaler"},
},
{
Name: "encoding/xml.Unmarshal", Tag: "xml", ArgPos: 1,
ifaceWhitelist: []string{"encoding/xml.Unmarshaler", "encoding.TextUnmarshaler"},
},
{
Name: "(*encoding/xml.Encoder).Encode", Tag: "xml", ArgPos: 0,
ifaceWhitelist: []string{"encoding/xml.Marshaler", "encoding.TextMarshaler"},
},
{
Name: "(*encoding/xml.Decoder).Decode", Tag: "xml", ArgPos: 0,
ifaceWhitelist: []string{"encoding/xml.Unmarshaler", "encoding.TextUnmarshaler"},
},
{
Name: "(*encoding/xml.Encoder).EncodeElement", Tag: "xml", ArgPos: 0,
ifaceWhitelist: []string{"encoding/xml.Marshaler", "encoding.TextMarshaler"},
},
{
Name: "(*encoding/xml.Decoder).DecodeElement", Tag: "xml", ArgPos: 0,
ifaceWhitelist: []string{"encoding/xml.Unmarshaler", "encoding.TextUnmarshaler"},
},
{"encoding/xml.Marshal", "xml", 0, []string{"encoding/xml.Marshaler", "encoding.TextMarshaler"}},
{"encoding/xml.MarshalIndent", "xml", 0, []string{"encoding/xml.Marshaler", "encoding.TextMarshaler"}},
{"encoding/xml.Unmarshal", "xml", 1, []string{"encoding/xml.Unmarshaler", "encoding.TextUnmarshaler"}},
{"(*encoding/xml.Encoder).Encode", "xml", 0, []string{"encoding/xml.Marshaler", "encoding.TextMarshaler"}},
{"(*encoding/xml.Decoder).Decode", "xml", 0, []string{"encoding/xml.Unmarshaler", "encoding.TextUnmarshaler"}},
{"(*encoding/xml.Encoder).EncodeElement", "xml", 0, []string{"encoding/xml.Marshaler", "encoding.TextMarshaler"}},
{"(*encoding/xml.Decoder).DecodeElement", "xml", 0, []string{"encoding/xml.Unmarshaler", "encoding.TextUnmarshaler"}},

// https://pkg.go.dev/gopkg.in/yaml.v3
{
Name: "gopkg.in/yaml.v3.Marshal", Tag: "yaml", ArgPos: 0,
ifaceWhitelist: []string{"gopkg.in/yaml.v3.Marshaler"},
},
{
Name: "gopkg.in/yaml.v3.Unmarshal", Tag: "yaml", ArgPos: 1,
ifaceWhitelist: []string{"gopkg.in/yaml.v3.Unmarshaler"},
},
{
Name: "(*gopkg.in/yaml.v3.Encoder).Encode", Tag: "yaml", ArgPos: 0,
ifaceWhitelist: []string{"gopkg.in/yaml.v3.Marshaler"},
},
{
Name: "(*gopkg.in/yaml.v3.Decoder).Decode", Tag: "yaml", ArgPos: 0,
ifaceWhitelist: []string{"gopkg.in/yaml.v3.Unmarshaler"},
},
{"gopkg.in/yaml.v3.Marshal", "yaml", 0, []string{"gopkg.in/yaml.v3.Marshaler"}},
{"gopkg.in/yaml.v3.Unmarshal", "yaml", 1, []string{"gopkg.in/yaml.v3.Unmarshaler"}},
{"(*gopkg.in/yaml.v3.Encoder).Encode", "yaml", 0, []string{"gopkg.in/yaml.v3.Marshaler"}},
{"(*gopkg.in/yaml.v3.Decoder).Decode", "yaml", 0, []string{"gopkg.in/yaml.v3.Unmarshaler"}},

// https://pkg.go.dev/github.com/BurntSushi/toml
{
Name: "github.com/BurntSushi/toml.Unmarshal", Tag: "toml", ArgPos: 1,
ifaceWhitelist: []string{"github.com/BurntSushi/toml.Unmarshaler", "encoding.TextUnmarshaler"},
},
{
Name: "github.com/BurntSushi/toml.Decode", Tag: "toml", ArgPos: 1,
ifaceWhitelist: []string{"github.com/BurntSushi/toml.Unmarshaler", "encoding.TextUnmarshaler"},
},
{
Name: "github.com/BurntSushi/toml.DecodeFS", Tag: "toml", ArgPos: 2,
ifaceWhitelist: []string{"github.com/BurntSushi/toml.Unmarshaler", "encoding.TextUnmarshaler"},
},
{
Name: "github.com/BurntSushi/toml.DecodeFile", Tag: "toml", ArgPos: 1,
ifaceWhitelist: []string{"github.com/BurntSushi/toml.Unmarshaler", "encoding.TextUnmarshaler"},
},
{
Name: "(*github.com/BurntSushi/toml.Encoder).Encode", Tag: "toml", ArgPos: 0,
ifaceWhitelist: []string{"encoding.TextMarshaler"},
},
{
Name: "(*github.com/BurntSushi/toml.Decoder).Decode", Tag: "toml", ArgPos: 0,
ifaceWhitelist: []string{"github.com/BurntSushi/toml.Unmarshaler", "encoding.TextUnmarshaler"},
},
{"github.com/BurntSushi/toml.Unmarshal", "toml", 1, []string{"github.com/BurntSushi/toml.Unmarshaler", "encoding.TextUnmarshaler"}},
{"github.com/BurntSushi/toml.Decode", "toml", 1, []string{"github.com/BurntSushi/toml.Unmarshaler", "encoding.TextUnmarshaler"}},
{"github.com/BurntSushi/toml.DecodeFS", "toml", 2, []string{"github.com/BurntSushi/toml.Unmarshaler", "encoding.TextUnmarshaler"}},
{"github.com/BurntSushi/toml.DecodeFile", "toml", 1, []string{"github.com/BurntSushi/toml.Unmarshaler", "encoding.TextUnmarshaler"}},
{"(*github.com/BurntSushi/toml.Encoder).Encode", "toml", 0, []string{"encoding.TextMarshaler"}},
{"(*github.com/BurntSushi/toml.Decoder).Decode", "toml", 0, []string{"github.com/BurntSushi/toml.Unmarshaler", "encoding.TextUnmarshaler"}},

// https://pkg.go.dev/github.com/mitchellh/mapstructure
{Name: "github.com/mitchellh/mapstructure.Decode", Tag: "mapstructure", ArgPos: 1},
{Name: "github.com/mitchellh/mapstructure.DecodeMetadata", Tag: "mapstructure", ArgPos: 1},
{Name: "github.com/mitchellh/mapstructure.WeakDecode", Tag: "mapstructure", ArgPos: 1},
{Name: "github.com/mitchellh/mapstructure.WeakDecodeMetadata", Tag: "mapstructure", ArgPos: 1},
{"github.com/mitchellh/mapstructure.Decode", "mapstructure", 1, nil},
{"github.com/mitchellh/mapstructure.DecodeMetadata", "mapstructure", 1, nil},
{"github.com/mitchellh/mapstructure.WeakDecode", "mapstructure", 1, nil},
{"github.com/mitchellh/mapstructure.WeakDecodeMetadata", "mapstructure", 1, nil},

// https://pkg.go.dev/github.com/jmoiron/sqlx
{Name: "github.com/jmoiron/sqlx.Get", Tag: "db", ArgPos: 1},
{Name: "github.com/jmoiron/sqlx.GetContext", Tag: "db", ArgPos: 2},
{Name: "github.com/jmoiron/sqlx.Select", Tag: "db", ArgPos: 1},
{Name: "github.com/jmoiron/sqlx.SelectContext", Tag: "db", ArgPos: 2},
{Name: "github.com/jmoiron/sqlx.StructScan", Tag: "db", ArgPos: 1},
{Name: "(*github.com/jmoiron/sqlx.Conn).GetContext", Tag: "db", ArgPos: 1},
{Name: "(*github.com/jmoiron/sqlx.Conn).SelectContext", Tag: "db", ArgPos: 1},
{Name: "(*github.com/jmoiron/sqlx.DB).Get", Tag: "db", ArgPos: 0},
{Name: "(*github.com/jmoiron/sqlx.DB).GetContext", Tag: "db", ArgPos: 1},
{Name: "(*github.com/jmoiron/sqlx.DB).Select", Tag: "db", ArgPos: 0},
{Name: "(*github.com/jmoiron/sqlx.DB).SelectContext", Tag: "db", ArgPos: 1},
{Name: "(*github.com/jmoiron/sqlx.NamedStmt).Get", Tag: "db", ArgPos: 0},
{Name: "(*github.com/jmoiron/sqlx.NamedStmt).GetContext", Tag: "db", ArgPos: 1},
{Name: "(*github.com/jmoiron/sqlx.NamedStmt).Select", Tag: "db", ArgPos: 0},
{Name: "(*github.com/jmoiron/sqlx.NamedStmt).SelectContext", Tag: "db", ArgPos: 1},
{Name: "(*github.com/jmoiron/sqlx.Row).StructScan", Tag: "db", ArgPos: 0},
{Name: "(*github.com/jmoiron/sqlx.Rows).StructScan", Tag: "db", ArgPos: 0},
{Name: "(*github.com/jmoiron/sqlx.Stmt).Get", Tag: "db", ArgPos: 0},
{Name: "(*github.com/jmoiron/sqlx.Stmt).GetContext", Tag: "db", ArgPos: 1},
{Name: "(*github.com/jmoiron/sqlx.Stmt).Select", Tag: "db", ArgPos: 0},
{Name: "(*github.com/jmoiron/sqlx.Stmt).SelectContext", Tag: "db", ArgPos: 1},
{Name: "(*github.com/jmoiron/sqlx.Tx).Get", Tag: "db", ArgPos: 0},
{Name: "(*github.com/jmoiron/sqlx.Tx).GetContext", Tag: "db", ArgPos: 1},
{Name: "(*github.com/jmoiron/sqlx.Tx).Select", Tag: "db", ArgPos: 0},
{Name: "(*github.com/jmoiron/sqlx.Tx).SelectContext", Tag: "db", ArgPos: 1},
{"github.com/jmoiron/sqlx.Get", "db", 1, []string{"database/sql.Scanner"}},
{"github.com/jmoiron/sqlx.GetContext", "db", 2, []string{"database/sql.Scanner"}},
{"github.com/jmoiron/sqlx.Select", "db", 1, []string{"database/sql.Scanner"}},
{"github.com/jmoiron/sqlx.SelectContext", "db", 2, []string{"database/sql.Scanner"}},
{"github.com/jmoiron/sqlx.StructScan", "db", 1, []string{"database/sql.Scanner"}},
{"(*github.com/jmoiron/sqlx.Conn).GetContext", "db", 1, []string{"database/sql.Scanner"}},
{"(*github.com/jmoiron/sqlx.Conn).SelectContext", "db", 1, []string{"database/sql.Scanner"}},
{"(*github.com/jmoiron/sqlx.DB).Get", "db", 0, []string{"database/sql.Scanner"}},
{"(*github.com/jmoiron/sqlx.DB).GetContext", "db", 1, []string{"database/sql.Scanner"}},
{"(*github.com/jmoiron/sqlx.DB).Select", "db", 0, []string{"database/sql.Scanner"}},
{"(*github.com/jmoiron/sqlx.DB).SelectContext", "db", 1, []string{"database/sql.Scanner"}},
{"(*github.com/jmoiron/sqlx.NamedStmt).Get", "db", 0, []string{"database/sql.Scanner"}},
{"(*github.com/jmoiron/sqlx.NamedStmt).GetContext", "db", 1, []string{"database/sql.Scanner"}},
{"(*github.com/jmoiron/sqlx.NamedStmt).Select", "db", 0, []string{"database/sql.Scanner"}},
{"(*github.com/jmoiron/sqlx.NamedStmt).SelectContext", "db", 1, []string{"database/sql.Scanner"}},
{"(*github.com/jmoiron/sqlx.Row).StructScan", "db", 0, []string{"database/sql.Scanner"}},
{"(*github.com/jmoiron/sqlx.Rows).StructScan", "db", 0, []string{"database/sql.Scanner"}},
{"(*github.com/jmoiron/sqlx.Stmt).Get", "db", 0, []string{"database/sql.Scanner"}},
{"(*github.com/jmoiron/sqlx.Stmt).GetContext", "db", 1, []string{"database/sql.Scanner"}},
{"(*github.com/jmoiron/sqlx.Stmt).Select", "db", 0, []string{"database/sql.Scanner"}},
{"(*github.com/jmoiron/sqlx.Stmt).SelectContext", "db", 1, []string{"database/sql.Scanner"}},
{"(*github.com/jmoiron/sqlx.Tx).Get", "db", 0, []string{"database/sql.Scanner"}},
{"(*github.com/jmoiron/sqlx.Tx).GetContext", "db", 1, []string{"database/sql.Scanner"}},
{"(*github.com/jmoiron/sqlx.Tx).Select", "db", 0, []string{"database/sql.Scanner"}},
{"(*github.com/jmoiron/sqlx.Tx).SelectContext", "db", 1, []string{"database/sql.Scanner"}},
}
48 changes: 20 additions & 28 deletions musttag.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,17 +91,17 @@ func run(pass *analysis.Pass, mainModule string, funcs map[string]Func) (_ any,

call, ok := node.(*ast.CallExpr)
if !ok {
return // not a function call.
return
}

callee := typeutil.StaticCallee(pass.TypesInfo, call)
if callee == nil {
return // not a static call.
return
}

fn, ok := funcs[cutVendor(callee.FullName())]
if !ok {
return // unsupported function.
return
}

if len(call.Args) <= fn.ArgPos {
Expand All @@ -116,7 +116,7 @@ func run(pass *analysis.Pass, mainModule string, funcs map[string]Func) (_ any,

typ := pass.TypesInfo.TypeOf(arg)
if typ == nil {
return // no type info found.
return
}

checker := checker{
Expand All @@ -125,9 +125,8 @@ func run(pass *analysis.Pass, mainModule string, funcs map[string]Func) (_ any,
ifaceWhitelist: fn.ifaceWhitelist,
imports: pass.Pkg.Imports(),
}

if valid := checker.checkType(typ, fn.Tag); valid {
return // nothing to report.
if checker.isValidType(typ, fn.Tag) {
return
}

pass.Reportf(arg.Pos(), "the given struct should be annotated with the `%s` tag", fn.Tag)
Expand All @@ -143,43 +142,32 @@ type checker struct {
imports []*types.Package
}

func (c *checker) checkType(typ types.Type, tag string) bool {
func (c *checker) isValidType(typ types.Type, tag string) bool {
if _, ok := c.seenTypes[typ.String()]; ok {
return true // already checked.
return true
}
c.seenTypes[typ.String()] = struct{}{}

styp, ok := c.parseStruct(typ)
if !ok {
return true // not a struct.
return true
}

return c.checkStruct(styp, tag)
return c.isValidStruct(styp, tag)
}

// recursively unwrap a type until we get to an underlying
// raw struct type that should have its fields checked
//
// SomeStruct -> struct{SomeStructField: ... }
// []*SomeStruct -> struct{SomeStructField: ... }
// ...
//
// exits early if it hits a type that implements a whitelisted interface
func (c *checker) parseStruct(typ types.Type) (*types.Struct, bool) {
if implementsInterface(typ, c.ifaceWhitelist, c.imports) {
return nil, false // the type implements a Marshaler interface; see issue #64.
return nil, false
}

switch typ := typ.(type) {
case *types.Pointer:
return c.parseStruct(typ.Elem())

case *types.Array:
return c.parseStruct(typ.Elem())

case *types.Slice:
return c.parseStruct(typ.Elem())

case *types.Map:
return c.parseStruct(typ.Elem())

Expand All @@ -205,7 +193,7 @@ func (c *checker) parseStruct(typ types.Type) (*types.Struct, bool) {
}
}

func (c *checker) checkStruct(styp *types.Struct, tag string) (valid bool) {
func (c *checker) isValidStruct(styp *types.Struct, tag string) bool {
for i := 0; i < styp.NumFields(); i++ {
field := styp.Field(i)
if !field.Exported() {
Expand All @@ -214,18 +202,18 @@ func (c *checker) checkStruct(styp *types.Struct, tag string) (valid bool) {

tagValue, ok := reflect.StructTag(styp.Tag(i)).Lookup(tag)
if !ok {
// tag is not required for embedded types; see issue #12.
// tag is not required for embedded types.
if !field.Embedded() {
return false
}
}

// Do not recurse into ignored fields.
// the field is explicitly ignored.
if tagValue == "-" {
continue
}

if valid := c.checkType(field.Type(), tag); !valid {
if !c.isValidType(field.Type(), tag) {
return false
}
}
Expand Down Expand Up @@ -254,25 +242,29 @@ func implementsInterface(typ types.Type, ifaces []string, imports []*types.Packa
}

for _, ifacePath := range ifaces {
// "encoding/json.Marshaler" -> "encoding/json" + "Marshaler"
// e.g. "encoding/json.Marshaler" -> "encoding/json" + "Marshaler".
idx := strings.LastIndex(ifacePath, ".")
if idx == -1 {
continue
}

pkgName, ifaceName := ifacePath[:idx], ifacePath[idx+1:]

scope, ok := findScope(pkgName)
if !ok {
continue
}

obj := scope.Lookup(ifaceName)
if obj == nil {
continue
}

iface, ok := obj.Type().Underlying().(*types.Interface)
if !ok {
continue
}

if types.Implements(typ, iface) || types.Implements(types.NewPointer(typ), iface) {
return true
}
Expand Down
31 changes: 31 additions & 0 deletions testdata/src/tests/builtins.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ type TextMarshaler struct{ NoTag string }
func (TextMarshaler) MarshalText() ([]byte, error) { return nil, nil }
func (*TextMarshaler) UnmarshalText([]byte) error { return nil }

type Scanner struct{ NotTag string }

func (*Scanner) Scan(any) error { return nil }

func testJSON() {
var st Struct
json.Marshal(st) // want "the given struct should be annotated with the `json` tag"
Expand Down Expand Up @@ -154,6 +158,33 @@ func testSQLX() {
new(sqlx.Tx).GetContext(nil, &st, "") // want "the given struct should be annotated with the `db` tag"
new(sqlx.Tx).Select(&st, "") // want "the given struct should be annotated with the `db` tag"
new(sqlx.Tx).SelectContext(nil, &st, "") // want "the given struct should be annotated with the `db` tag"

var sc Scanner
sqlx.Get(nil, &sc, "")
sqlx.GetContext(nil, nil, &sc, "")
sqlx.Select(nil, &sc, "")
sqlx.SelectContext(nil, nil, &sc, "")
sqlx.StructScan(nil, &sc)
new(sqlx.Conn).GetContext(nil, &sc, "")
new(sqlx.Conn).SelectContext(nil, &sc, "")
new(sqlx.DB).Get(&sc, "")
new(sqlx.DB).GetContext(nil, &sc, "")
new(sqlx.DB).Select(&sc, "")
new(sqlx.DB).SelectContext(nil, &sc, "")
new(sqlx.NamedStmt).Get(&sc, nil)
new(sqlx.NamedStmt).GetContext(nil, &sc, nil)
new(sqlx.NamedStmt).Select(&sc, nil)
new(sqlx.NamedStmt).SelectContext(nil, &sc, nil)
new(sqlx.Row).StructScan(&sc)
new(sqlx.Rows).StructScan(&sc)
new(sqlx.Stmt).Get(&sc)
new(sqlx.Stmt).GetContext(nil, &sc)
new(sqlx.Stmt).Select(&sc)
new(sqlx.Stmt).SelectContext(nil, &sc)
new(sqlx.Tx).Get(&sc, "")
new(sqlx.Tx).GetContext(nil, &sc, "")
new(sqlx.Tx).Select(&sc, "")
new(sqlx.Tx).SelectContext(nil, &sc, "")
}

func testCustom() {
Expand Down

0 comments on commit 974874b

Please sign in to comment.