Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support generic type oveload method #358

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion builtin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -759,7 +759,7 @@ func TestCheckSigFuncExObjects(t *testing.T) {
count int
}{
{"TyOverloadFunc", sigFuncEx(nil, nil, &TyOverloadFunc{objs}), 2},
{"TyOverloadMethod", sigFuncEx(nil, nil, &TyOverloadMethod{objs}), 2},
{"TyOverloadMethod", sigFuncEx(nil, nil, &TyOverloadMethod{Methods: objs}), 2},
{"TyTemplateRecvMethod", sigFuncEx(nil, nil, &TyTemplateRecvMethod{types.NewParam(0, nil, "", tyInt)}), 1},
{"TyTemplateRecvMethod", sigFuncEx(nil, nil, &TyTemplateRecvMethod{fn}), 2},
{"TyOverloadNamed", sigFuncEx(nil, nil, &TyOverloadNamed{Types: []*types.Named{named}}), 1},
Expand Down
17 changes: 12 additions & 5 deletions codebuild.go
Original file line number Diff line number Diff line change
Expand Up @@ -1625,7 +1625,7 @@ retry:
return kind
}
}
if kind := p.method(t, name, aliasName, flag, arg, srcExpr); kind != MemberInvalid {
if kind := p.method(t, name, aliasName, flag, arg, srcExpr, t.TypeArgs() != nil); kind != MemberInvalid {
return kind
}
if fstruc {
Expand All @@ -1641,7 +1641,7 @@ retry:
}
case *types.Named:
named, typ = o, p.getUnderlying(o) // may cause to loadNamed (delay-loaded)
if kind := p.method(o, name, aliasName, flag, arg, srcExpr); kind != MemberInvalid {
if kind := p.method(o, name, aliasName, flag, arg, srcExpr, o.TypeArgs() != nil); kind != MemberInvalid {
return kind
}
if _, ok := typ.(*types.Struct); ok {
Expand All @@ -1657,7 +1657,7 @@ retry:
}
case *types.Interface:
o.Complete()
if kind := p.method(o, name, aliasName, flag, arg, srcExpr); kind != MemberInvalid {
if kind := p.method(o, name, aliasName, flag, arg, srcExpr, false); kind != MemberInvalid {
return kind
}
case *types.Basic, *types.Slice, *types.Map, *types.Chan:
Expand All @@ -1667,6 +1667,7 @@ retry:
}

type methodList interface {
types.Type
NumMethods() int
Method(i int) *types.Func
}
Expand Down Expand Up @@ -1713,7 +1714,7 @@ func (p *CodeBuilder) allowAccess(pkg *types.Package, name string) bool {
}

func (p *CodeBuilder) method(
o methodList, name, aliasName string, flag MemberFlag, arg *Element, src ast.Node) (kind MemberKind) {
o methodList, name, aliasName string, flag MemberFlag, arg *Element, src ast.Node, namedHasTypeArgs bool) (kind MemberKind) {
var found *types.Func
var exact bool
for i, n := 0, o.NumMethods(); i < n; i++ {
Expand All @@ -1738,7 +1739,13 @@ func (p *CodeBuilder) method(
if autoprop && !methodHasAutoProperty(typ, 0) {
return memberBad
}

if namedHasTypeArgs {
if t, ok := CheckFuncEx(typ.(*types.Signature)); ok {
if m, ok := t.(*TyOverloadMethod); ok && m.IsGeneric() {
typ = m.Instantiate(o.(*types.Named))
}
}
}
sel := selector(arg, found.Name())
ret := &internal.Elem{
Val: sel,
Expand Down
47 changes: 43 additions & 4 deletions func_ext.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,9 @@ func CheckOverloadFunc(sig *types.Signature) (funcs []types.Object, ok bool) {

// TyOverloadMethod: overload function type
type TyOverloadMethod struct {
Methods []types.Object
Methods []types.Object
indexs []int // func object indexs
instance map[*types.Named]*types.Signature // cache type signature for named
}

func (p *TyOverloadMethod) At(i int) types.Object { return p.Methods[i] }
Expand All @@ -185,9 +187,46 @@ func (p *TyOverloadMethod) Underlying() types.Type { return p }
func (p *TyOverloadMethod) String() string { return "TyOverloadMethod" }
func (p *TyOverloadMethod) funcEx() {}

// NewOverloadMethod creates an overload method.
func NewOverloadMethod(typ *types.Named, pos token.Pos, pkg *types.Package, name string, methods ...types.Object) *types.Func {
return newMethodEx(typ, pos, pkg, name, &TyOverloadMethod{methods})
func NewOverloadMethod(typ *types.Named, pos token.Pos, pkg *types.Package, name string, objectIndex map[types.Object]int, methods ...types.Object) *types.Func {
t := &TyOverloadMethod{Methods: methods}
if typ.TypeParams() != nil {
t.indexs = make([]int, len(methods))
for i, obj := range methods {
t.indexs[i] = objectIndex[obj]
}
t.instance = make(map[*types.Named]*types.Signature)
}
return newMethodEx(typ, pos, pkg, name, t)
}

func (m *TyOverloadMethod) IsGeneric() bool {
return len(m.indexs) != 0
}

func (m *TyOverloadMethod) Instantiate(named *types.Named) *types.Signature {
sig, ok := m.instance[named]
if !ok {
sig = newOverloadMethodType(named, m)
m.instance[named] = sig
}
return sig
}

func newOverloadMethodType(named *types.Named, m *TyOverloadMethod) *types.Signature {
var list methodList
switch t := named.Underlying().(type) {
case *types.Interface:
list = t
default:
list = named
}
pkg := named.Obj().Pkg()
recv := types.NewVar(token.NoPos, pkg, "", named)
methods := make([]types.Object, len(m.indexs))
for i, index := range m.indexs {
methods[i] = list.Method(index)
}
return sigFuncEx(pkg, recv, &TyOverloadMethod{Methods: methods})
}

// CheckOverloadMethod checks a func is overload method or not.
Expand Down
10 changes: 6 additions & 4 deletions import.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ func InitThisGopPkg(pkg *types.Package) {
scope := pkg.Scope()
gopos := make([]string, 0, 4)
overloads := make(map[omthd][]types.Object)
mobjectIndexs := make(map[types.Object]int)
onameds := make(map[string][]*types.Named)
names := scope.Names()
for _, name := range names {
Expand All @@ -120,6 +121,7 @@ func InitThisGopPkg(pkg *types.Package) {
mthd := mName[:len(mName)-3]
key := omthd{named, mthd}
overloads[key] = append(overloads[key], m)
mobjectIndexs[m] = i
}
}
if isOverload(name) { // overload named
Expand Down Expand Up @@ -150,15 +152,15 @@ func InitThisGopPkg(pkg *types.Package) {
}
}
if len(fns) > 0 {
newOverload(pkg, scope, m, fns)
newOverload(pkg, scope, m, fns, nil)
}
delete(overloads, m)
}
}
for key, items := range overloads {
off := len(key.name) + 2
fns := overloadFuncs(off, items)
newOverload(pkg, scope, key, fns)
newOverload(pkg, scope, key, fns, mobjectIndexs)
}
for name, items := range onameds {
off := len(name) + 2
Expand Down Expand Up @@ -282,7 +284,7 @@ func checkOverloads(scope *types.Scope, gopoName string) (ret []string, exists b
return
}

func newOverload(pkg *types.Package, scope *types.Scope, m omthd, fns []types.Object) {
func newOverload(pkg *types.Package, scope *types.Scope, m omthd, fns []types.Object, mobjectIndexs map[types.Object]int) {
if m.typ == nil {
if debugImport {
log.Println("==> NewOverloadFunc", m.name)
Expand All @@ -294,7 +296,7 @@ func newOverload(pkg *types.Package, scope *types.Scope, m omthd, fns []types.Ob
if debugImport {
log.Println("==> NewOverloadMethod", m.typ.Obj().Name(), m.name)
}
NewOverloadMethod(m.typ, token.NoPos, pkg, m.name, fns...)
NewOverloadMethod(m.typ, token.NoPos, pkg, m.name, mobjectIndexs, fns...)
}
}

Expand Down
32 changes: 32 additions & 0 deletions internal/foo/foo.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,4 +116,36 @@ type NodeSeter interface {
Attr__1(k, v string) (ret NodeSeter)
}

type Data[T any] struct {
data []T
}

func (p *Data[T]) Size() int {
return len(p.data)
}

func (p *Data[T]) Add__0(v ...T) {
p.data = append(p.data, v...)
}

func (p *Data[T]) Add__1(v Data[T]) {
p.data = append(p.data, v.data...)
}

func (p *Data[T]) IndexOf__0(v T) int {
return -1
}

func (p *Data[T]) IndexOf__1(pos int, v T) int {
return -1
}

type DataInterface[T any] interface {
Size() int
Add__0(v ...T)
Add__1(v DataInterface[T])
IndexOf__0(v T) int
IndexOf__1(pos int, v T) int
}

// -----------------------------------------------------------------------------
84 changes: 84 additions & 0 deletions typeparams_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1157,3 +1157,87 @@ func main() {
}
`)
}

func TestGenericTypeOverloadMethod(t *testing.T) {
pkg := newMainPackage()
foo := pkg.Import("github.com/goplus/gogen/internal/foo")
tyDataT := foo.Ref("Data").Type()
tyInt := types.Typ[types.Int]
tyData, _ := types.Instantiate(nil, tyDataT, []types.Type{tyInt}, true)
v := pkg.NewParam(token.NoPos, "v", tyData)
pkg.NewFunc(nil, "bar", types.NewTuple(v), nil, false).BodyStart(pkg).
DefineVarStart(token.NoPos, "n").Val(v).
Debug(func(cb *gogen.CodeBuilder) {
cb.Member("size", gogen.MemberFlagMethodAlias)
}).
Call(0).EndInit(1).EndStmt().
Val(v).
Debug(func(cb *gogen.CodeBuilder) {
cb.Member("add", gogen.MemberFlagMethodAlias)
}).
Val(0).Val(1).Call(2).EndStmt().
Val(v).
Debug(func(cb *gogen.CodeBuilder) {
cb.Member("add", gogen.MemberFlagMethodAlias)
}).
Val(v).Call(1).EndStmt().
DefineVarStart(token.NoPos, "i").Val(v).
Debug(func(cb *gogen.CodeBuilder) {
cb.Member("indexOf", gogen.MemberFlagMethodAlias)
}).
Val(0).Val(1).Call(2).EndInit(1).EndStmt().
End()
domTest(t, pkg, `package main

import "github.com/goplus/gogen/internal/foo"

func bar(v foo.Data[int]) {
n := v.Size()
v.Add__0(0, 1)
v.Add__1(v)
i := v.IndexOf__1(0, 1)
}
`)
}

func TestGenericInterfaceOverloadMethod(t *testing.T) {
pkg := newMainPackage()
foo := pkg.Import("github.com/goplus/gogen/internal/foo")
tyDataT := foo.Ref("DataInterface").Type()
tyInt := types.Typ[types.Int]
tyData, _ := types.Instantiate(nil, tyDataT, []types.Type{tyInt}, true)
v := pkg.NewParam(token.NoPos, "v", tyData)
pkg.NewFunc(nil, "bar", types.NewTuple(v), nil, false).BodyStart(pkg).
DefineVarStart(token.NoPos, "n").Val(v).
Debug(func(cb *gogen.CodeBuilder) {
cb.Member("size", gogen.MemberFlagMethodAlias)
}).
Call(0).EndInit(1).EndStmt().
Val(v).
Debug(func(cb *gogen.CodeBuilder) {
cb.Member("add", gogen.MemberFlagMethodAlias)
}).
Val(0).Val(1).Call(2).EndStmt().
Val(v).
Debug(func(cb *gogen.CodeBuilder) {
cb.Member("add", gogen.MemberFlagMethodAlias)
}).
Val(v).Call(1).EndStmt().
DefineVarStart(token.NoPos, "i").Val(v).
Debug(func(cb *gogen.CodeBuilder) {
cb.Member("indexOf", gogen.MemberFlagMethodAlias)
}).
Val(0).Val(1).Call(2).EndInit(1).EndStmt().
End()
domTest(t, pkg, `package main

import "github.com/goplus/gogen/internal/foo"

func bar(v foo.DataInterface[int]) {
n := v.Size()
v.Add__0(0, 1)
v.Add__1(v)
i := v.IndexOf__1(0, 1)
}
`)
}