Skip to content

Commit

Permalink
fix: increase performance (#56)
Browse files Browse the repository at this point in the history
  • Loading branch information
ldez authored Jun 23, 2023
1 parent 31c8451 commit ea68c39
Show file tree
Hide file tree
Showing 11 changed files with 165 additions and 226 deletions.
57 changes: 28 additions & 29 deletions musttag.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ import (
"go/token"
"go/types"
"path"
"path/filepath"
"reflect"
"regexp"
"strconv"
"strings"

Expand Down Expand Up @@ -43,16 +43,23 @@ func New(funcs ...Func) *analysis.Analyzer {
Requires: []*analysis.Analyzer{inspect.Analyzer},
Run: func(pass *analysis.Pass) (any, error) {
l := len(builtins) + len(funcs) + len(flagFuncs)
m := make(map[string]Func, l)
f := make(map[string]Func, l)

toMap := func(slice []Func) {
for _, fn := range slice {
m[fn.Name] = fn
f[fn.Name] = fn
}
}
toMap(builtins)
toMap(funcs)
toMap(flagFuncs)
return run(pass, m)

mainModule, err := getMainModule()
if err != nil {
return nil, err
}

return run(pass, mainModule, f)
},
}
}
Expand Down Expand Up @@ -81,27 +88,16 @@ func flags(funcs *[]Func) flag.FlagSet {
}

// for tests only.
var (
report = func(pass *analysis.Pass, st *structType, fn Func, fnPos token.Position) {
const format = "`%s` should be annotated with the `%s` tag as it is passed to `%s` at %s"
pass.Reportf(st.Pos, format, st.Name, fn.Tag, fn.shortName(), fnPos)
}
var report = func(pass *analysis.Pass, st *structType, fn Func, fnPos token.Position) {
const format = "`%s` should be annotated with the `%s` tag as it is passed to `%s` at %s"
pass.Reportf(st.Pos, format, st.Name, fn.Tag, fn.shortName(), fnPos)
}

// HACK: mainModulePackages() does not return packages from `testdata`,
// because it is ignored by the go tool, and thus, by the `go list` command.
// For tests to pass we need to add the packages with tests to the main module manually.
testPackages []string
)
var cleanFullName = regexp.MustCompile(`([^*/(]+/vendor/)`)

// run starts the analysis.
func run(pass *analysis.Pass, funcs map[string]Func) (any, error) {
moduleDir, modulePackages, err := mainModule()
if err != nil {
return nil, err
}
for _, pkg := range testPackages {
modulePackages[pkg] = struct{}{}
}
func run(pass *analysis.Pass, mainModule string, funcs map[string]Func) (any, error) {
var err error

walk := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
filter := []ast.Node{(*ast.CallExpr)(nil)}
Expand All @@ -116,12 +112,13 @@ func run(pass *analysis.Pass, funcs map[string]Func) (any, error) {
return // not a function call.
}

caller := typeutil.StaticCallee(pass.TypesInfo, call)
if caller == nil {
callee := typeutil.StaticCallee(pass.TypesInfo, call)
if callee == nil {
return // not a static call.
}

fn, ok := funcs[caller.FullName()]
name := cleanFullName.ReplaceAllString(callee.FullName(), "")
fn, ok := funcs[name]
if !ok {
return // the function is not supported.
}
Expand All @@ -148,7 +145,7 @@ func run(pass *analysis.Pass, funcs map[string]Func) (any, error) {
}

checker := checker{
mainModule: modulePackages,
mainModule: mainModule,
seenTypes: make(map[string]struct{}),
}

Expand All @@ -164,7 +161,6 @@ func run(pass *analysis.Pass, funcs map[string]Func) (any, error) {
}

p := pass.Fset.Position(call.Pos())
p.Filename, _ = filepath.Rel(moduleDir, p.Filename)
report(pass, result, fn, p)
})

Expand All @@ -181,7 +177,7 @@ type structType struct {

// checker parses and checks struct types.
type checker struct {
mainModule map[string]struct{} // do not check types outside of the main module; see issue #17.
mainModule string
seenTypes map[string]struct{} // prevent panic on recursive types; see issue #16.
}

Expand All @@ -202,13 +198,16 @@ func (c *checker) parseStructType(t types.Type, pos token.Pos) (*structType, boo
if pkg == nil {
return nil, false
}
if _, ok := c.mainModule[pkg.Path()]; !ok {

if !strings.HasPrefix(pkg.Path(), c.mainModule) {
return nil, false
}

s, ok := t.Underlying().(*types.Struct)
if !ok {
return nil, false
}

return &structType{
Struct: s,
Pos: t.Obj().Pos(),
Expand Down
128 changes: 18 additions & 110 deletions musttag_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"go/token"
"io"
"os"
"os/exec"
"path/filepath"
"testing"

Expand All @@ -14,13 +15,6 @@ import (
)

func TestAnalyzer(t *testing.T) {
// NOTE: analysistest does not yet support modules;
// see https://github.com/golang/go/issues/37054 for details.
// To be able to run tests with external dependencies,
// we first need to write a GOPATH-like tree of stubs.
prepareTestFiles(t)
testPackages = []string{"tests", "builtins"}

testdata := analysistest.TestData()

t.Run("tests", func(t *testing.T) {
Expand All @@ -31,11 +25,15 @@ func TestAnalyzer(t *testing.T) {
pass.Reportf(st.Pos, fn.shortName())
}

setupTestData(t, testdata, "tests")

analyzer := New()
analysistest.Run(t, testdata, analyzer, "tests")
})

t.Run("builtins", func(t *testing.T) {
setupTestData(t, testdata, "builtins")

analyzer := New(
Func{Name: "example.com/custom.Marshal", Tag: "custom", ArgPos: 0},
Func{Name: "example.com/custom.Unmarshal", Tag: "custom", ArgPos: 1},
Expand All @@ -44,6 +42,8 @@ func TestAnalyzer(t *testing.T) {
})

t.Run("bad Func.ArgPos", func(t *testing.T) {
setupTestData(t, testdata, "tests")

analyzer := New(
Func{Name: "encoding/json.Marshal", Tag: "json", ArgPos: 10},
)
Expand Down Expand Up @@ -77,111 +77,19 @@ type nopT struct{}

func (nopT) Errorf(string, ...any) {}

func prepareTestFiles(t *testing.T) {
testdata := analysistest.TestData()

t.Cleanup(func() {
err := os.RemoveAll(filepath.Join(testdata, "src"))
assert.NoErr[F](t, err)
})

hardlink := func(dir, file string) {
target := filepath.Join(testdata, "src", dir, file)

err := os.MkdirAll(filepath.Dir(target), 0o777)
assert.NoErr[F](t, err)
// NOTE: analysistest does not yet support modules;
// see https://github.com/golang/go/issues/37054 for details.
func setupTestData(t *testing.T, testDataDir, dir string) {
t.Helper()

err = os.Link(filepath.Join(testdata, file), target)
assert.NoErr[F](t, err)
err := os.Chdir(filepath.Join(testDataDir, "src", dir))
if err != nil {
t.Fatal(err)
}

hardlink("tests", "tests.go")
hardlink("builtins", "builtins.go")

for file, data := range stubs {
target := filepath.Join(testdata, "src", file)

err := os.MkdirAll(filepath.Dir(target), 0o777)
assert.NoErr[F](t, err)

err = os.WriteFile(target, []byte(data), 0o666)
assert.NoErr[F](t, err)
output, err := exec.Command("go", "mod", "vendor").CombinedOutput()
if err != nil {
t.Log(string(output))
t.Fatal(err)
}
}

var stubs = map[string]string{
"gopkg.in/yaml.v3/yaml.go": `package yaml
import "io"
func Marshal(_ any) ([]byte, error) { return nil, nil }
func Unmarshal(_ []byte, _ any) error { return nil }
type Encoder struct{}
func NewEncoder(_ io.Writer) *Encoder { return nil }
func (*Encoder) Encode(_ any) error { return nil }
type Decoder struct{}
func NewDecoder(_ io.Reader) *Decoder { return nil }
func (*Decoder) Decode(_ any) error { return nil }`,

"github.com/BurntSushi/toml/toml.go": `package toml
import "io"
import "io/fs"
func Unmarshal(_ []byte, _ any) error { return nil }
type MetaData struct{}
func Decode(_ string, _ any) (MetaData, error) { return MetaData{}, nil }
func DecodeFS(_ fs.FS, _ string, _ any) (MetaData, error) { return MetaData{}, nil }
func DecodeFile(_ string, _ any) (MetaData, error) { return MetaData{}, nil }
type Encoder struct{}
func NewEncoder(_ io.Writer) *Encoder { return nil }
func (*Encoder) Encode(_ any) error { return nil }
type Decoder struct{}
func NewDecoder(_ io.Reader) *Decoder { return nil }
func (*Decoder) Decode(_ any) error { return nil }`,

"github.com/mitchellh/mapstructure/mapstructure.go": `package mapstructure
type Metadata struct{}
func Decode(_, _ any) error { return nil }
func DecodeMetadata(_, _ any, _ *Metadata) error { return nil }
func WeakDecode(_, _ any) error { return nil }
func WeakDecodeMetadata(_, _ any, _ *Metadata) error { return nil }`,

"github.com/jmoiron/sqlx/sqlx.go": `package sqlx
import "context"
type Queryer interface{}
type QueryerContext interface{}
type rowsi interface{}
func Get(Queryer, any, string, ...any) error { return nil }
func GetContext(context.Context, QueryerContext, any, string, ...any) error { return nil }
func Select(Queryer, any, string, ...any) error { return nil }
func SelectContext(context.Context, QueryerContext, any, string, ...any) error { return nil }
func StructScan(rowsi, any) error { return nil }
type Conn struct{}
func (*Conn) GetContext(context.Context, any, string, ...any) error { return nil }
func (*Conn) SelectContext(context.Context, any, string, ...any) error { return nil }
type DB struct{}
func (*DB) Get(any, string, ...any) error { return nil }
func (*DB) GetContext(context.Context, any, string, ...any) error { return nil }
func (*DB) Select(any, string, ...any) error { return nil }
func (*DB) SelectContext(context.Context, any, string, ...any) error { return nil }
type NamedStmt struct{}
func (n *NamedStmt) Get(any, any) error { return nil }
func (n *NamedStmt) GetContext(context.Context, any, any) error { return nil }
func (n *NamedStmt) Select(any, any) error { return nil }
func (n *NamedStmt) SelectContext(context.Context, any, any) error { return nil }
type Row struct{}
func (*Row) StructScan(any) error { return nil }
type Rows struct{}
func (*Rows) StructScan(any) error { return nil }
type Stmt struct{}
func (*Stmt) Get(any, ...any) error { return nil }
func (*Stmt) GetContext(context.Context, any, ...any) error { return nil }
func (*Stmt) Select(any, ...any) error { return nil }
func (*Stmt) SelectContext(context.Context, any, ...any) error { return nil }
type Tx struct{}
func (*Tx) Get(any, string, ...any) error { return nil }
func (*Tx) GetContext(context.Context, any, string, ...any) error { return nil }
func (*Tx) Select(any, string, ...any) error { return nil }
func (*Tx) SelectContext(context.Context, any, string, ...any) error { return nil }`,

"example.com/custom/custom.go": `package custom
func Marshal(_ any) ([]byte, error) { return nil, nil }
func Unmarshal(_ []byte, _ any) error { return nil }`,
}
1 change: 1 addition & 0 deletions testdata/src/builtins/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
vendor
Loading

0 comments on commit ea68c39

Please sign in to comment.