diff --git a/analyzer/testdata/src/filtertest/f1.go b/analyzer/testdata/src/filtertest/f1.go index 2d54b33..c9f8905 100644 --- a/analyzer/testdata/src/filtertest/f1.go +++ b/analyzer/testdata/src/filtertest/f1.go @@ -1075,3 +1075,32 @@ type exampleStruct struct { w io.Writer buf *bytes.Buffer } + +type foo string + +func (foo) FooString(_ string) {} +func (foo) FooMap(_ map[string]string) {} +func (foo) FooArray(_ [32]byte) {} +func (foo) FooChan(_ chan string) {} +func (foo) FooType(_ io.Closer) {} +func (foo) FooWithResult(_ string) string { return "" } +func (foo) FooWithResult2(_ string) (io.Closer, error) { return nil, nil } +func (foo) FooWithResult3(_ string) (cl io.Closer, err error) { return } +func (foo) FooGrouped(_, _ io.Closer) {} +func (foo) FooFunc(_ func(x string) error) error { return nil } +func (foo) FooFunc2(func(string) error) error { return nil } + +func dynamicInterface() { + var f foo + f.FooString("") // want `\Qdynamic interface 1` + f.FooMap(nil) // want `\Qdynamic interface 2` + f.FooArray([32]byte{}) // want `\Qdynamic interface 3` + f.FooChan(nil) // want `\Qdynamic interface 4` + f.FooType(nil) // want `\Qdynamic interface 5` + f.FooGrouped(nil, nil) // want `\Qdynamic interface 6` + f.FooWithResult("") // want `\Qdynamic interface 7` + f.FooWithResult2("") // want `\Qdynamic interface 8` + f.FooWithResult3("") // want `\Qdynamic interface 9` + f.FooFunc(nil) // want `\Qdynamic interface 10` + f.FooFunc2(nil) // want `\Qdynamic interface 11` +} diff --git a/analyzer/testdata/src/filtertest/rules.go b/analyzer/testdata/src/filtertest/rules.go index fa13cbf..d5d6cf1 100644 --- a/analyzer/testdata/src/filtertest/rules.go +++ b/analyzer/testdata/src/filtertest/rules.go @@ -3,7 +3,9 @@ package gorules -import "github.com/quasilyte/go-ruleguard/dsl" +import ( + "github.com/quasilyte/go-ruleguard/dsl" +) func testRules(m dsl.Matcher) { m.Import(`github.com/quasilyte/go-ruleguard/analyzer/testdata/src/filtertest/foolib`) @@ -282,4 +284,27 @@ func testRules(m dsl.Matcher) { m.Match(`newIface("sink is interface{}").($_)`). Where(m["$$"].SinkType.Is(`interface{}`)). Report(`true`) + + m.Match(`$x.FooString($_)`). + Where(m["x"].Type.Implements(`interface { FooString(k string) }`)).Report(`dynamic interface 1`) + m.Match(`$x.FooMap($_)`). + Where(m["x"].Type.Implements(`interface { FooMap(k map[string]string) }`)).Report(`dynamic interface 2`) + m.Match(`$x.FooArray($_)`). + Where(m["x"].Type.Implements(`interface { FooArray(k [32]byte) }`)).Report(`dynamic interface 3`) + m.Match(`$x.FooChan($_)`). + Where(m["x"].Type.Implements(`interface { FooChan(k chan string) }`)).Report(`dynamic interface 4`) + m.Match(`$x.FooType($_)`). + Where(m["x"].Type.Implements(`interface { FooType(k io.Closer) }`)).Report(`dynamic interface 5`) + m.Match(`$x.FooGrouped($*_)`). + Where(m["x"].Type.Implements(`interface { FooGrouped(k io.Closer, l io.Closer) }`)).Report(`dynamic interface 6`) + m.Match(`$x.FooWithResult($_)`). + Where(m["x"].Type.Implements(`interface { FooWithResult(k string) string }`)).Report(`dynamic interface 7`) + m.Match(`$x.FooWithResult2($_)`). + Where(m["x"].Type.Implements(`interface { FooWithResult2(k string) (io.Closer, error) }`)).Report(`dynamic interface 8`) + m.Match(`$x.FooWithResult3($_)`). + Where(m["x"].Type.Implements(`interface { FooWithResult3(k string) (cl io.Closer, err error) }`)).Report(`dynamic interface 9`) + m.Match(`$x.FooFunc($_)`). + Where(m["x"].Type.Implements(`interface { FooFunc(x func (x string) error) error }`)).Report(`dynamic interface 10`) + m.Match(`$x.FooFunc2($_)`). + Where(m["x"].Type.Implements(`interface { FooFunc2(func (string) error) error }`)).Report(`dynamic interface 11`) } diff --git a/ruleguard/ir_loader.go b/ruleguard/ir_loader.go index c07a19f..9395fef 100644 --- a/ruleguard/ir_loader.go +++ b/ruleguard/ir_loader.go @@ -10,6 +10,7 @@ import ( "go/types" "io/ioutil" "regexp" + "strconv" "github.com/quasilyte/gogrep" "github.com/quasilyte/gogrep/nodetag" @@ -478,30 +479,51 @@ func (l *irLoader) unwrapInterfaceExpr(filter ir.FilterExpr) (*types.Interface, if err != nil { return nil, l.errorf(filter.Line, err, "parse %s type expr", typeString) } - qn, ok := n.(*ast.SelectorExpr) - if !ok { + + var iface *types.Interface + switch qn := n.(type) { + case *ast.SelectorExpr: + pkgName, ok := qn.X.(*ast.Ident) + if !ok { + return nil, l.errorf(filter.Line, nil, "invalid package name") + } + pkgPath, ok := l.itab.Lookup(pkgName.Name) + if !ok { + return nil, l.errorf(filter.Line, nil, "package %s is not imported", pkgName.Name) + } + pkg, err := l.importer.Import(pkgPath) + if err != nil { + return nil, l.importErrorf(filter.Line, err, "can't load %s", pkgPath) + } + obj := pkg.Scope().Lookup(qn.Sel.Name) + if obj == nil { + return nil, l.errorf(filter.Line, nil, "%s is not found in %s", qn.Sel.Name, pkgPath) + } + iface, ok = obj.Type().Underlying().(*types.Interface) + if !ok { + return nil, l.errorf(filter.Line, nil, "%s is not an interface type", qn.Sel.Name) + } + case *ast.InterfaceType: + methods := make([]*types.Func, 0, len(qn.Methods.List)) + for _, method := range qn.Methods.List { + fnType, ok := method.Type.(*ast.FuncType) + if !ok { + continue + } + + fn, err := l.mapAstFuncTypeToTypesFunc(method.Names[0].Name, fnType) + if err != nil { + return nil, fmt.Errorf("on unwrapInterfaceExpr: %w", err) + } + + methods = append(methods, fn) + } + + iface = types.NewInterfaceType(methods, nil).Complete() + default: return nil, l.errorf(filter.Line, nil, "can't resolve %s type; try a fully-qualified name", typeString) } - pkgName, ok := qn.X.(*ast.Ident) - if !ok { - return nil, l.errorf(filter.Line, nil, "invalid package name") - } - pkgPath, ok := l.itab.Lookup(pkgName.Name) - if !ok { - return nil, l.errorf(filter.Line, nil, "package %s is not imported", pkgName.Name) - } - pkg, err := l.importer.Import(pkgPath) - if err != nil { - return nil, l.importErrorf(filter.Line, err, "can't load %s", pkgPath) - } - obj := pkg.Scope().Lookup(qn.Sel.Name) - if obj == nil { - return nil, l.errorf(filter.Line, nil, "%s is not found in %s", qn.Sel.Name, pkgPath) - } - iface, ok := obj.Type().Underlying().(*types.Interface) - if !ok { - return nil, l.errorf(filter.Line, nil, "%s is not an interface type", qn.Sel.Name) - } + return iface, nil } @@ -881,6 +903,141 @@ func (l *irLoader) newBinaryExprFilter(filter ir.FilterExpr, info *filterInfo) ( return result, nil } +func (l *irLoader) mapAstFuncTypeToTypesFunc(name string, funcType *ast.FuncType) (*types.Func, error) { + var ( + vars []*types.Var + res []*types.Var + ) + + mapField := func(param *ast.Field, results []*types.Var) ([]*types.Var, error) { + tt, err := l.mapAstExprToTypesType(param.Type) + if err != nil { + return nil, err + } + + if param.Names != nil { + for _, name := range param.Names { // one param has several names when their type the same + results = append(results, types.NewVar(name.Pos(), nil, name.Name, tt)) + } + } else { // unnamed + results = append(results, types.NewVar(param.Pos(), nil, "", tt)) + } + return results, nil + } + + var err error + if funcType.Params != nil { + vars = make([]*types.Var, 0, len(funcType.Params.List)) + for _, param := range funcType.Params.List { + if vars, err = mapField(param, vars); err != nil { + return nil, err + } + } + } + + if funcType.Results != nil { + res = make([]*types.Var, 0, len(funcType.Results.List)) + for _, param := range funcType.Results.List { + if res, err = mapField(param, res); err != nil { + return nil, err + } + } + } + + return types.NewFunc(funcType.Pos(), + nil, + name, + types.NewSignature(nil, types.NewTuple(vars...), types.NewTuple(res...), false), + ), nil +} + +func (l *irLoader) mapAstExprToTypesType(param ast.Expr) (types.Type, error) { + switch p := param.(type) { + case *ast.Ident: + return typematch.BuiltinTypeByName[p.Name], nil + case *ast.StarExpr: + el, err := l.mapAstExprToTypesType(p.X) + if err != nil { + return nil, err + } + + return types.NewPointer(el), nil + case *ast.FuncType: + fn, err := l.mapAstFuncTypeToTypesFunc("", p) + if err != nil { + return nil, err + } + + return fn.Type(), nil + case *ast.Ellipsis: + //TODO + return nil, l.errorf(int(p.Pos()), nil, "on mapAstExprToTypesType: variadic types not supported") + case *ast.ChanType: + var dir types.ChanDir + switch { + case p.Dir&ast.SEND != 0 && p.Dir&ast.RECV != 0: + dir = types.SendRecv + case p.Dir&ast.SEND != 0: + dir = types.SendOnly + case p.Dir&ast.RECV != 0: + dir = types.RecvOnly + default: + return nil, nil + } + + v, err := l.mapAstExprToTypesType(p.Value) + if err != nil { + return nil, err + } + + return types.NewChan(dir, v), nil + case *ast.MapType: + key, err := l.mapAstExprToTypesType(p.Key) + if err != nil { + return nil, err + } + + val, err := l.mapAstExprToTypesType(p.Value) + if err != nil { + return nil, err + } + + return types.NewMap(key, val), nil + case *ast.ArrayType: + arrLen, err := strconv.ParseInt(p.Len.(*ast.BasicLit).Value, 10, 64) + if err != nil { + return nil, l.errorf(int(p.Pos()), nil, "invalid length provided: "+err.Error()) + } + + val, err := l.mapAstExprToTypesType(p.Elt) + if err != nil { + return nil, err + } + + return types.NewArray(val, arrLen), nil + case *ast.SelectorExpr: + pkgName, ok := p.X.(*ast.Ident) + if !ok { + return nil, l.errorf(int(p.Pos()), nil, "invalid package name") + } + pkgPath, ok := l.itab.Lookup(pkgName.Name) + if !ok { + return nil, l.errorf(int(p.Pos()), nil, "package %s is not imported", pkgName.Name) + } + pkg, err := l.importer.Import(pkgPath) + if err != nil { + return nil, l.importErrorf(int(p.Pos()), err, "can't load %s", pkgPath) + } + obj := pkg.Scope().Lookup(p.Sel.Name) + if obj == nil { + return nil, l.errorf(int(p.Pos()), nil, "%s is not found in %s", p.Sel.Name, pkgPath) + } + + return obj.Type(), nil + } + return nil, l.errorf(int(param.Pos()), nil, "unsupported statement provided: %T", param) +} + type filterInfo struct { Vars map[string]struct{} diff --git a/ruleguard/typematch/typematch.go b/ruleguard/typematch/typematch.go index b747403..ac9984d 100644 --- a/ruleguard/typematch/typematch.go +++ b/ruleguard/typematch/typematch.go @@ -135,7 +135,7 @@ func Parse(ctx *Context, s string) (*Pattern, error) { } var ( - builtinTypeByName = map[string]types.Type{ + BuiltinTypeByName = map[string]types.Type{ "bool": types.Typ[types.Bool], "int": types.Typ[types.Int], "int8": types.Typ[types.Int8], @@ -167,7 +167,7 @@ var ( func parseExpr(ctx *Context, e ast.Expr) *pattern { switch e := e.(type) { case *ast.Ident: - basic, ok := builtinTypeByName[e.Name] + basic, ok := BuiltinTypeByName[e.Name] if ok { return &pattern{op: opBuiltinType, value: basic} }