diff --git a/ast.go b/ast.go index 86b70e94..e80f1abc 100644 --- a/ast.go +++ b/ast.go @@ -637,7 +637,7 @@ retry: if (flags & instrFlagOpFunc) != 0 { // from callOpFunc mfn.Type = o.Type() } else { - mfn.Type = methodTypeOf(o.Type()) + mfn.Type = methodCallSig(o.Type()) } if ret, err = matchFuncCall(pkg, &mfn, args, flags); err == nil { fn.Val, fn.Type = mfn.Val, mfn.Type diff --git a/builtin_test.go b/builtin_test.go index 83f55b65..2bd39f08 100644 --- a/builtin_test.go +++ b/builtin_test.go @@ -44,6 +44,32 @@ func getConf() *Config { return &Config{Fset: fset, Importer: imp} } +func TestErrMethodSigOf(t *testing.T) { + foo := types.NewPackage("github.com/bar/foo", "foo") + tn := types.NewTypeName(0, foo, "t", nil) + recv := types.NewNamed(tn, types.Typ[types.Int], nil) + t.Run("Go+ extended method", func(t *testing.T) { + defer func() { + if e := recover(); e != "can't call methodToFunc to Go+ extended method\n" { + t.Fatal("TestErrMethodSigOf:", e) + } + }() + methodSigOf(NewOverloadFunc(0, foo, "foo").Type(), memberFlagMethodToFunc, nil, nil) + }) + t.Run("recv not pointer", func(t *testing.T) { + defer func() { + if e := recover(); e != "recv of method github.com/bar/foo.t.bar isn't a pointer\n" { + t.Fatal("TestErrMethodSigOf:", e) + } + }() + method := types.NewSignatureType(types.NewVar(0, foo, "", recv), nil, nil, nil, nil, false) + arg := &Element{ + Type: &TypeType{typ: types.NewPointer(recv)}, + } + methodSigOf(method, memberFlagMethodToFunc, arg, &ast.SelectorExpr{Sel: ast.NewIdent("bar")}) + }) +} + func TestMatchOverloadNamedTypeCast(t *testing.T) { pkg := NewPackage("", "foo", nil) foo := types.NewPackage("github.com/bar/foo", "foo") diff --git a/codebuild.go b/codebuild.go index 266738e0..cff5237c 100644 --- a/codebuild.go +++ b/codebuild.go @@ -1496,18 +1496,24 @@ const ( MemberFlagMethodAlias MemberFlagAutoProperty MemberFlagRef MemberFlag = -1 + + // private state + memberFlagMethodToFunc MemberFlag = -2 ) func aliasNameOf(name string, flag MemberFlag) (string, MemberFlag) { + // flag > 0: (flag == MemberFlagMethodAlias || flag == MemberFlagAutoProperty) if flag > 0 && name != "" { if c := name[0]; c >= 'a' && c <= 'z' { return string(rune(c)+('A'-'a')) + name[1:], flag } + flag = MemberFlagVal } - return "", MemberFlagVal + return "", flag } -// Member func +// Member access member by its name. +// src should point to the full source node `x.sel` func (p *CodeBuilder) Member(name string, flag MemberFlag, src ...ast.Node) (kind MemberKind, err error) { srcExpr := getSrc(src) arg := p.stk.Get(-1) @@ -1523,29 +1529,13 @@ func (p *CodeBuilder) Member(name string, flag MemberFlag, src ...ast.Node) (kin kind = p.refMember(at, name, arg.Val, srcExpr) } else { t, isType := at.(*TypeType) - if isType { + if isType { // (T).method or (*T).method at = t.Type() - if flag == MemberFlagAutoProperty { - flag = MemberFlagVal // can't use auto property to type - } + flag = memberFlagMethodToFunc } aliasName, flag := aliasNameOf(name, flag) kind = p.findMember(at, name, aliasName, flag, arg, srcExpr) - if isType { - if kind == MemberMethod { - e := p.Get(-1) - if sig, ok := e.Type.(*types.Signature); ok { - sp := sig.Params() - spLen := sp.Len() - vars := make([]*types.Var, spLen+1) - vars[0] = types.NewVar(token.NoPos, nil, "", at) - for i := 0; i < spLen; i++ { - vars[i+1] = sp.At(i) - } - e.Type = types.NewSignatureType(nil, nil, nil, types.NewTuple(vars...), sig.Results(), sig.Variadic()) - return - } - } + if isType && kind != MemberMethod { code, pos := p.loadExpr(srcExpr) return MemberInvalid, p.newCodeError( pos, fmt.Sprintf("%s undefined (type %v has no method %s)", code, at, name)) @@ -1700,9 +1690,10 @@ func (p *CodeBuilder) method( if autoprop && !methodHasAutoProperty(typ, 0) { return memberBad } + sel := selector(arg, found.Name()) p.stk.Ret(1, &internal.Elem{ - Val: selector(arg, found.Name()), - Type: methodTypeOf(typ), + Val: sel, + Type: methodSigOf(typ, flag, arg, sel), Src: src, }) if p.rec != nil { @@ -1807,7 +1798,42 @@ func (p *CodeBuilder) field( return p.embeddedField(o, name, aliasName, flag, arg, src) } -func methodTypeOf(typ types.Type) types.Type { +func methodSigOf(typ types.Type, flag MemberFlag, arg *Element, sel *ast.SelectorExpr) types.Type { + if flag != memberFlagMethodToFunc { + return methodCallSig(typ) + } + + sig := typ.(*types.Signature) + if _, ok := CheckFuncEx(sig); ok { + log.Panicln("can't call methodToFunc to Go+ extended method") + } + + recv := sig.Recv().Type() + _, isPtr := recv.(*types.Pointer) // recv is a pointer + at := arg.Type.(*TypeType).typ + if t, ok := at.(*types.Pointer); ok { + if !isPtr { + if _, ok := recv.Underlying().(*types.Interface); !ok { // and recv isn't a interface + log.Panicf("recv of method %v.%s isn't a pointer\n", t.Elem(), sel.Sel.Name) + } + } + } else if isPtr { // use *T + at = types.NewPointer(at) + sel.X = &ast.StarExpr{X: sel.X} + } + sel.X = &ast.ParenExpr{X: sel.X} + + sp := sig.Params() + spLen := sp.Len() + vars := make([]*types.Var, spLen+1) + vars[0] = types.NewVar(token.NoPos, nil, "", at) + for i := 0; i < spLen; i++ { + vars[i+1] = sp.At(i) + } + return types.NewSignatureType(nil, nil, nil, types.NewTuple(vars...), sig.Results(), sig.Variadic()) +} + +func methodCallSig(typ types.Type) types.Type { sig := typ.(*types.Signature) if _, ok := CheckFuncEx(sig); ok { return typ diff --git a/package_test.go b/package_test.go index d6f9f810..568c4221 100644 --- a/package_test.go +++ b/package_test.go @@ -3467,7 +3467,7 @@ func (tt t) bar(info string) { fmt.Println(tt, info) } func main() { - v := foo.bar + v := (foo).bar var tt t = 123 v(tt, "hello") } diff --git a/type_var_and_const.go b/type_var_and_const.go index f947d3fb..5d193a0a 100644 --- a/type_var_and_const.go +++ b/type_var_and_const.go @@ -36,6 +36,18 @@ func (p *CodeBuilder) EndConst() *Element { // ---------------------------------------------------------------------------- +// MethodToFunc: +// +// (T).method +// (*T).method +func (pkg *Package) MethodToFunc(typ types.Type, name string, src ...ast.Node) (ret *Element, err error) { + _, err = pkg.cb.Typ(typ, src...).Member(name, MemberFlagVal, src...) + ret = pkg.cb.stk.Pop() + return +} + +// ---------------------------------------------------------------------------- + type TyState int const ( diff --git a/typeparams_test.go b/typeparams_test.go index ab0e5e45..5e88800f 100644 --- a/typeparams_test.go +++ b/typeparams_test.go @@ -14,6 +14,7 @@ package gox_test import ( + "bytes" "go/token" "go/types" "log" @@ -21,8 +22,96 @@ import ( "testing" "github.com/goplus/gox" + "github.com/goplus/gox/internal/go/format" ) +func formatElement(pkg *gox.Package, ret *gox.Element) string { + var b bytes.Buffer + err := format.Node(&b, pkg.Fset, ret.Val) + if err != nil { + log.Fatalln("format.Node failed:", err) + } + return b.String() +} + +func TestMethodToFunc(t *testing.T) { + const src = `package hello + +type Itf interface { + X() +} + +type Base struct { +} + +func (p Base) F() {} + +func (p *Base) PtrF() {} + +type Foo struct { + Itf + Base + Val byte +} + +func (a Foo) Bar() int { + return 0 +} + +func (a *Foo) PtrBar() string { + return "" +} + +var _ = (Foo).Bar +var _ = (*Foo).PtrBar +var _ = (Foo).F +var _ = (*Foo).PtrF +var _ = (Foo).X +var _ = (*Foo).X +var _ = (Itf).X +` + gt := newGoxTest() + _, err := gt.LoadGoPackage("hello", "foo.go", src) + if err != nil { + t.Fatal(err) + } + pkg := gt.NewPackage("", "main") + pkgRef := pkg.Import("hello") + objFoo := pkgRef.Ref("Foo") + objItf := pkgRef.Ref("Itf") + typ := objFoo.Type() + typItf := objItf.Type() + _, err = pkg.MethodToFunc(typ, "Val") + if err == nil || err.Error() != "-: undefined (type hello.Foo has no method Val)" { + t.Fatal("MethodToFunc failed:", err) + } + checkMethodToFunc(t, pkg, typ, "Bar", "(hello.Foo).Bar") + checkMethodToFunc(t, pkg, types.NewPointer(typ), "PtrBar", "(*hello.Foo).PtrBar") + checkMethodToFunc(t, pkg, typ, "PtrBar", "(*hello.Foo).PtrBar") + checkMethodToFunc(t, pkg, typ, "F", "(hello.Foo).F") + checkMethodToFunc(t, pkg, types.NewPointer(typ), "PtrF", "(*hello.Foo).PtrF") + checkMethodToFunc(t, pkg, typ, "PtrF", "(*hello.Foo).PtrF") + checkMethodToFunc(t, pkg, typItf, "X", "(hello.Itf).X") + checkMethodToFunc(t, pkg, typ, "X", "(hello.Foo).X") + checkMethodToFunc(t, pkg, types.NewPointer(typ), "X", "(*hello.Foo).X") +} + +func checkMethodToFunc(t *testing.T, pkg *gox.Package, typ types.Type, name, code string) { + t.Helper() + ret, err := pkg.MethodToFunc(typ, name) + if err != nil { + t.Fatal("MethodToFunc failed:", err) + } + if _, isPtr := typ.(*types.Pointer); isPtr { + if recv := ret.Type.(*types.Signature).Params().At(0); !types.Identical(recv.Type(), typ) { + t.Fatalf("MethodToFunc: ResultType: %v, Expected: %v\n", recv.Type(), typ) + } + } + if v := formatElement(pkg, ret); v != code { + t.Fatalf("MethodToFunc:\nResult:\n%s\nExpected:\n%s\n", v, code) + } +} + func TestOverloadNamed(t *testing.T) { const src = `package foo