Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for dynamic interfaces #392

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
Open
29 changes: 29 additions & 0 deletions analyzer/testdata/src/filtertest/f1.go
Original file line number Diff line number Diff line change
Expand Up @@ -1075,3 +1075,32 @@ type exampleStruct struct {
w io.Writer
buf *bytes.Buffer
}

type foo string

func (foo) FooString(_ string) {}
func (foo) FooMap(_ map[string]string) {}
func (foo) FooArray(_ [32]byte) {}
func (foo) FooChan(_ chan string) {}
func (foo) FooType(_ io.Closer) {}
func (foo) FooWithResult(_ string) string { return "" }
func (foo) FooWithResult2(_ string) (io.Closer, error) { return nil, nil }
func (foo) FooWithResult3(_ string) (cl io.Closer, err error) { return }
func (foo) FooGrouped(_, _ io.Closer) {}
func (foo) FooFunc(_ func(x string) error) error { return nil }
func (foo) FooFunc2(func(string) error) error { return nil }

func dynamicInterface() {
var f foo
f.FooString("") // want `\Qdynamic interface 1`
f.FooMap(nil) // want `\Qdynamic interface 2`
f.FooArray([32]byte{}) // want `\Qdynamic interface 3`
f.FooChan(nil) // want `\Qdynamic interface 4`
f.FooType(nil) // want `\Qdynamic interface 5`
f.FooGrouped(nil, nil) // want `\Qdynamic interface 6`
f.FooWithResult("") // want `\Qdynamic interface 7`
f.FooWithResult2("") // want `\Qdynamic interface 8`
f.FooWithResult3("") // want `\Qdynamic interface 9`
f.FooFunc(nil) // want `\Qdynamic interface 10`
f.FooFunc2(nil) // want `\Qdynamic interface 11`
}
27 changes: 26 additions & 1 deletion analyzer/testdata/src/filtertest/rules.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

package gorules

import "github.com/quasilyte/go-ruleguard/dsl"
import (
"github.com/quasilyte/go-ruleguard/dsl"
)

func testRules(m dsl.Matcher) {
m.Import(`github.com/quasilyte/go-ruleguard/analyzer/testdata/src/filtertest/foolib`)
Expand Down Expand Up @@ -282,4 +284,27 @@ func testRules(m dsl.Matcher) {
m.Match(`newIface("sink is interface{}").($_)`).
Where(m["$$"].SinkType.Is(`interface{}`)).
Report(`true`)

m.Match(`$x.FooString($_)`).
Where(m["x"].Type.Implements(`interface { FooString(k string) }`)).Report(`dynamic interface 1`)
m.Match(`$x.FooMap($_)`).
Where(m["x"].Type.Implements(`interface { FooMap(k map[string]string) }`)).Report(`dynamic interface 2`)
m.Match(`$x.FooArray($_)`).
Where(m["x"].Type.Implements(`interface { FooArray(k [32]byte) }`)).Report(`dynamic interface 3`)
m.Match(`$x.FooChan($_)`).
Where(m["x"].Type.Implements(`interface { FooChan(k chan string) }`)).Report(`dynamic interface 4`)
m.Match(`$x.FooType($_)`).
Where(m["x"].Type.Implements(`interface { FooType(k io.Closer) }`)).Report(`dynamic interface 5`)
m.Match(`$x.FooGrouped($*_)`).
Where(m["x"].Type.Implements(`interface { FooGrouped(k io.Closer, l io.Closer) }`)).Report(`dynamic interface 6`)
m.Match(`$x.FooWithResult($_)`).
Where(m["x"].Type.Implements(`interface { FooWithResult(k string) string }`)).Report(`dynamic interface 7`)
m.Match(`$x.FooWithResult2($_)`).
Where(m["x"].Type.Implements(`interface { FooWithResult2(k string) (io.Closer, error) }`)).Report(`dynamic interface 8`)
m.Match(`$x.FooWithResult3($_)`).
Where(m["x"].Type.Implements(`interface { FooWithResult3(k string) (cl io.Closer, err error) }`)).Report(`dynamic interface 9`)
m.Match(`$x.FooFunc($_)`).
Where(m["x"].Type.Implements(`interface { FooFunc(x func (x string) error) error }`)).Report(`dynamic interface 10`)
m.Match(`$x.FooFunc2($_)`).
Where(m["x"].Type.Implements(`interface { FooFunc2(func (string) error) error }`)).Report(`dynamic interface 11`)
}
201 changes: 179 additions & 22 deletions ruleguard/ir_loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"go/types"
"io/ioutil"
"regexp"
"strconv"

"github.com/quasilyte/gogrep"
"github.com/quasilyte/gogrep/nodetag"
Expand Down Expand Up @@ -478,30 +479,51 @@ func (l *irLoader) unwrapInterfaceExpr(filter ir.FilterExpr) (*types.Interface,
if err != nil {
return nil, l.errorf(filter.Line, err, "parse %s type expr", typeString)
}
qn, ok := n.(*ast.SelectorExpr)
if !ok {

var iface *types.Interface
switch qn := n.(type) {
case *ast.SelectorExpr:
pkgName, ok := qn.X.(*ast.Ident)
if !ok {
return nil, l.errorf(filter.Line, nil, "invalid package name")
}
pkgPath, ok := l.itab.Lookup(pkgName.Name)
if !ok {
return nil, l.errorf(filter.Line, nil, "package %s is not imported", pkgName.Name)
}
pkg, err := l.importer.Import(pkgPath)
if err != nil {
return nil, l.importErrorf(filter.Line, err, "can't load %s", pkgPath)
}
obj := pkg.Scope().Lookup(qn.Sel.Name)
if obj == nil {
return nil, l.errorf(filter.Line, nil, "%s is not found in %s", qn.Sel.Name, pkgPath)
}
iface, ok = obj.Type().Underlying().(*types.Interface)
if !ok {
return nil, l.errorf(filter.Line, nil, "%s is not an interface type", qn.Sel.Name)
}
case *ast.InterfaceType:
methods := make([]*types.Func, 0, len(qn.Methods.List))
for _, method := range qn.Methods.List {
fnType, ok := method.Type.(*ast.FuncType)
if !ok {
continue
}

fn, err := l.mapAstFuncTypeToTypesFunc(method.Names[0].Name, fnType)
if err != nil {
return nil, fmt.Errorf("on unwrapInterfaceExpr: %w", err)
}

methods = append(methods, fn)
}

iface = types.NewInterfaceType(methods, nil).Complete()
default:
return nil, l.errorf(filter.Line, nil, "can't resolve %s type; try a fully-qualified name", typeString)
}
pkgName, ok := qn.X.(*ast.Ident)
if !ok {
return nil, l.errorf(filter.Line, nil, "invalid package name")
}
pkgPath, ok := l.itab.Lookup(pkgName.Name)
if !ok {
return nil, l.errorf(filter.Line, nil, "package %s is not imported", pkgName.Name)
}
pkg, err := l.importer.Import(pkgPath)
if err != nil {
return nil, l.importErrorf(filter.Line, err, "can't load %s", pkgPath)
}
obj := pkg.Scope().Lookup(qn.Sel.Name)
if obj == nil {
return nil, l.errorf(filter.Line, nil, "%s is not found in %s", qn.Sel.Name, pkgPath)
}
iface, ok := obj.Type().Underlying().(*types.Interface)
if !ok {
return nil, l.errorf(filter.Line, nil, "%s is not an interface type", qn.Sel.Name)
}

return iface, nil
}

Expand Down Expand Up @@ -881,6 +903,141 @@ func (l *irLoader) newBinaryExprFilter(filter ir.FilterExpr, info *filterInfo) (
return result, nil
}

func (l *irLoader) mapAstFuncTypeToTypesFunc(name string, funcType *ast.FuncType) (*types.Func, error) {
var (
vars []*types.Var
res []*types.Var
)

mapField := func(param *ast.Field, results []*types.Var) ([]*types.Var, error) {
tt, err := l.mapAstExprToTypesType(param.Type)
if err != nil {
return nil, err
}

if param.Names != nil {
for _, name := range param.Names { // one param has several names when their type the same
results = append(results, types.NewVar(name.Pos(), nil, name.Name, tt))
}
} else { // unnamed
results = append(results, types.NewVar(param.Pos(), nil, "", tt))
}
return results, nil
}

var err error
if funcType.Params != nil {
vars = make([]*types.Var, 0, len(funcType.Params.List))
for _, param := range funcType.Params.List {
if vars, err = mapField(param, vars); err != nil {
return nil, err
}
}
}

if funcType.Results != nil {
res = make([]*types.Var, 0, len(funcType.Results.List))
for _, param := range funcType.Results.List {
if res, err = mapField(param, res); err != nil {
return nil, err
}
}
}

return types.NewFunc(funcType.Pos(),
nil,
name,
types.NewSignature(nil, types.NewTuple(vars...), types.NewTuple(res...), false),
), nil
}

func (l *irLoader) mapAstExprToTypesType(param ast.Expr) (types.Type, error) {
switch p := param.(type) {
case *ast.Ident:
return typematch.BuiltinTypeByName[p.Name], nil
case *ast.StarExpr:
el, err := l.mapAstExprToTypesType(p.X)
if err != nil {
return nil, err
}

return types.NewPointer(el), nil
case *ast.FuncType:
fn, err := l.mapAstFuncTypeToTypesFunc("", p)
if err != nil {
return nil, err
}

return fn.Type(), nil
case *ast.Ellipsis:
//TODO
return nil, l.errorf(int(p.Pos()), nil, "on mapAstExprToTypesType: variadic types not supported")
case *ast.ChanType:
var dir types.ChanDir
switch {
case p.Dir&ast.SEND != 0 && p.Dir&ast.RECV != 0:
dir = types.SendRecv
case p.Dir&ast.SEND != 0:
dir = types.SendOnly
case p.Dir&ast.RECV != 0:
dir = types.RecvOnly
default:
return nil, nil
}

v, err := l.mapAstExprToTypesType(p.Value)
if err != nil {
return nil, err
}

return types.NewChan(dir, v), nil
case *ast.MapType:
key, err := l.mapAstExprToTypesType(p.Key)
if err != nil {
return nil, err
}

val, err := l.mapAstExprToTypesType(p.Value)
if err != nil {
return nil, err
}

return types.NewMap(key, val), nil
case *ast.ArrayType:
arrLen, err := strconv.ParseInt(p.Len.(*ast.BasicLit).Value, 10, 64)
if err != nil {
return nil, l.errorf(int(p.Pos()), nil, "invalid length provided: "+err.Error())
}

val, err := l.mapAstExprToTypesType(p.Elt)
if err != nil {
return nil, err
}

return types.NewArray(val, arrLen), nil
case *ast.SelectorExpr:
pkgName, ok := p.X.(*ast.Ident)
if !ok {
return nil, l.errorf(int(p.Pos()), nil, "invalid package name")
}
pkgPath, ok := l.itab.Lookup(pkgName.Name)
if !ok {
return nil, l.errorf(int(p.Pos()), nil, "package %s is not imported", pkgName.Name)
}
pkg, err := l.importer.Import(pkgPath)
if err != nil {
return nil, l.importErrorf(int(p.Pos()), err, "can't load %s", pkgPath)
}
obj := pkg.Scope().Lookup(p.Sel.Name)
if obj == nil {
return nil, l.errorf(int(p.Pos()), nil, "%s is not found in %s", p.Sel.Name, pkgPath)
}

return obj.Type(), nil
}
return nil, l.errorf(int(param.Pos()), nil, "unsupported statement provided: %T", param)
}

type filterInfo struct {
Vars map[string]struct{}

Expand Down
4 changes: 2 additions & 2 deletions ruleguard/typematch/typematch.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ func Parse(ctx *Context, s string) (*Pattern, error) {
}

var (
builtinTypeByName = map[string]types.Type{
BuiltinTypeByName = map[string]types.Type{
"bool": types.Typ[types.Bool],
"int": types.Typ[types.Int],
"int8": types.Typ[types.Int8],
Expand Down Expand Up @@ -167,7 +167,7 @@ var (
func parseExpr(ctx *Context, e ast.Expr) *pattern {
switch e := e.(type) {
case *ast.Ident:
basic, ok := builtinTypeByName[e.Name]
basic, ok := BuiltinTypeByName[e.Name]
if ok {
return &pattern{op: opBuiltinType, value: basic}
}
Expand Down