diff --git a/builtins.go b/builtins.go index 66914fa..74f8df1 100644 --- a/builtins.go +++ b/builtins.go @@ -3,7 +3,7 @@ package musttag // builtins is a set of functions supported out of the box. var builtins = []Func{ // https://pkg.go.dev/encoding/json - {Name: "encoding/json.Marshal", Tag: "json", ArgPos: 0}, + {Name: "encoding/json.Marshal", Tag: "json", ArgPos: 0, ifaceWhitelist: []string{"Marshaler"}}, {Name: "encoding/json.MarshalIndent", Tag: "json", ArgPos: 0}, {Name: "encoding/json.Unmarshal", Tag: "json", ArgPos: 1}, {Name: "(*encoding/json.Encoder).Encode", Tag: "json", ArgPos: 0}, diff --git a/musttag.go b/musttag.go index 7f4e05e..1bd462e 100644 --- a/musttag.go +++ b/musttag.go @@ -24,6 +24,10 @@ type Func struct { Name string // Name is the full name of the function, including the package. Tag string // Tag is the struct tag whose presence should be ensured. ArgPos int // ArgPos is the position of the argument to check. + + // a list of interfaces from the same package; + // if at least one is implemented by the argument, no check is performed. + ifaceWhitelist []string } func (fn Func) shortName() string { @@ -31,6 +35,14 @@ func (fn Func) shortName() string { return path.Base(name) } +func (fn Func) pkgPath() string { + name := strings.NewReplacer("*", "", "(", "", ")", "").Replace(fn.Name) + if idx := strings.LastIndex(name, "."); idx != -1 { + return name[:idx] + } + return "" +} + // New creates a new musttag analyzer. // To report a custom function provide its description via Func, // it will be added to the builtin ones. @@ -144,13 +156,37 @@ func run(pass *analysis.Pass, mainModule string, funcs map[string]Func) (any, er initialPos = arg.Pos() } + argType := pass.TypesInfo.TypeOf(arg) + if argType == nil { + return // no type info found. + } + + for _, pkg := range pass.Pkg.Imports() { + if pkg.Path() != fn.pkgPath() { + continue + } + for _, ifaceName := range fn.ifaceWhitelist { + obj := pkg.Scope().Lookup(ifaceName) + if obj == nil { + continue + } + iface, ok := obj.Type().Underlying().(*types.Interface) + if !ok { + continue + } + if types.Implements(argType, iface) { + return // the argument implements an (Un)Marshaler interface, no need to check; see issue #64. + } + } + break + } + checker := checker{ mainModule: mainModule, seenTypes: make(map[string]struct{}), } - t := pass.TypesInfo.TypeOf(arg) - st, ok := checker.parseStructType(t, initialPos) + st, ok := checker.parseStructType(argType, initialPos) if !ok { return // not a struct argument. } diff --git a/testdata/src/tests/tests.go b/testdata/src/tests/tests.go index 93c8b89..b4b3a98 100644 --- a/testdata/src/tests/tests.go +++ b/testdata/src/tests/tests.go @@ -94,3 +94,13 @@ func nothingToReport() { json.NewEncoder(nil).Encode(Foo{}) json.NewDecoder(nil).Decode(&Foo{}) } + +type marshaler struct{} + +func (marshaler) MarshalJSON() ([]byte, error) { return nil, nil } + +func implementsInterface() { + var m marshaler + json.Marshal(m) + json.Marshal(marshaler{}) +}