Skip to content

Commit

Permalink
generics parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
monoidic committed Mar 4, 2024
1 parent 0583975 commit cb6d588
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 57 deletions.
2 changes: 1 addition & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ func lateInit() {
func main() {
flag.StringVar(&sourceRoot, "src", "", "path to directory with source code to examine")
flag.StringVar(&outPath, "out", "", "path to file to dump json results in")
flag.StringVar(&goVersion, "version", runtime.Version(), "go version to use, in go1.${minor}.${patch} format")
flag.StringVar(&goVersion, "version", strings.Split(runtime.Version(), " ")[0], "go version to use, in go1.${minor}.${patch} format")
flag.BoolVar(&permitInvalid, "permit_invalid", false, "permit \"invalid type\" results")
flag.BoolVar(&getCgo, "get_cgo", false, "get per-arch cgo definitions")
flag.Parse()
Expand Down
46 changes: 33 additions & 13 deletions parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@ import (
"golang.org/x/tools/go/types/typeutil"
)

type parseFunc func(fset *token.FileSet, filename string, src []byte) (*ast.File, error)

func parseDiscardFuncBody(fset *token.FileSet, filename string, src []byte) (*ast.File, error) {
f, err := parser.ParseFile(fset, filename, src, 0)
if err != nil {
Expand All @@ -32,10 +30,6 @@ func parseDiscardFuncBody(fset *token.FileSet, filename string, src []byte) (*as

func (pkg *pkgData) parseFunc(obj *types.Func) {
signature := obj.Type().(*types.Signature)
// do not handle generic functions
if signature.TypeParams() != nil {
return
}

name := obj.FullName()

Expand All @@ -52,6 +46,15 @@ func (pkg *pkgData) parseFunc(obj *types.Func) {
}
}

if tp := signature.TypeParams(); tp != nil {
pkg.GenericFuncs[name] = genericFuncData{
Params: params,
Results: results,
TypeParams: getTypeParamArr(tp),
}
return
}

pkg.Funcs[name] = funcData{
Params: params,
Results: results,
Expand All @@ -77,10 +80,6 @@ func (pkg *pkgData) parseType(obj *types.TypeName) {
}
panic(obj)
}
// do not handle generic types
if named.TypeParams() != nil {
return
}

name := getTypeName(obj)

Expand All @@ -89,7 +88,7 @@ func (pkg *pkgData) parseType(obj *types.TypeName) {

switch t := named.Underlying().(type) {
case *types.Struct:
pkg.parseStruct(pkg.getTypeName(obj.Type(), ""), t)
pkg.parseStruct(pkg.getTypeName(obj.Type(), ""), t, named.TypeParams())
case *types.Interface:
isInterface = true
case *types.Basic:
Expand Down Expand Up @@ -234,14 +233,16 @@ func (pkg *pkgData) getTypeName(iface types.Type, name string) string {
if name == "" {
panic(iface)
}
pkg.parseStruct(name, dt)
pkg.parseStruct(name, dt, nil)
return name
case *types.Alias:
obj := dt.Obj()
aliasName := getTypeName(obj)
targetName := pkg.getTypeName(types.Unalias(dt), "alias_"+aliasName)
pkg.Aliases[aliasName] = alias{Target: targetName}
return aliasName
case *types.TypeParam:
return dt.String()
default:
_ = dt.(*types.Named)
panic("unreachable")
Expand All @@ -258,7 +259,7 @@ func (pkg *pkgData) parseMethods(obj *types.TypeName) {
}
}

func (pkg *pkgData) parseStruct(name string, obj *types.Struct) {
func (pkg *pkgData) parseStruct(name string, obj *types.Struct, typeParams *types.TypeParamList) {
numFields := obj.NumFields()
fields := make([]namedType, numFields)
for i := range numFields {
Expand All @@ -274,9 +275,28 @@ func (pkg *pkgData) parseStruct(name string, obj *types.Struct) {
DataType: dataType,
}
}

if typeParams != nil {
pkg.GenericStructs[name] = genericStructDef{
Fields: fields,
TypeParams: getTypeParamArr(typeParams),
}
return
}

pkg.Structs[name] = structDef{Fields: fields}
}

func getTypeName(tn *types.TypeName) string {
return fmt.Sprintf("%s.%s", tn.Pkg().Path(), tn.Name())
}

func getTypeParamArr(typeParams *types.TypeParamList) []string {
tParamsArr := make([]string, typeParams.Len())

for i := range typeParams.Len() {
tParamsArr[i] = typeParams.At(i).String()
}

return tParamsArr
}
130 changes: 87 additions & 43 deletions types.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,23 +69,34 @@ type namedType struct {
DataType string
}

type equalsI interface {
equals(equalsI) bool
}

type funcData struct {
Params []namedType
Results []namedType
}

type equalsI interface {
equals(equalsI) bool
}

func (x funcData) equals(yI equalsI) bool {
y := yI.(funcData)
return len(x.Params) == len(y.Params) &&
len(x.Results) == len(y.Results) &&
slices.Equal(x.Params, y.Params) &&
return slices.Equal(x.Params, y.Params) &&
slices.Equal(x.Results, y.Results)
}

type genericFuncData struct {
Params []namedType
Results []namedType
TypeParams []string
}

func (x genericFuncData) equals(yI equalsI) bool {
y := yI.(genericFuncData)
return slices.Equal(x.Params, y.Params) &&
slices.Equal(x.Results, y.Results) &&
slices.Equal(x.TypeParams, y.TypeParams)
}

type typeData struct {
Underlying string
}
Expand All @@ -102,6 +113,17 @@ func (x structDef) equals(yI equalsI) bool {
return slices.Equal(x.Fields, yI.(structDef).Fields)
}

type genericStructDef struct {
Fields []namedType
TypeParams []string
}

func (x genericStructDef) equals(yI equalsI) bool {
y := yI.(genericStructDef)
return slices.Equal(x.Fields, y.Fields) &&
slices.Equal(x.TypeParams, y.TypeParams)
}

type alias struct {
Target string
}
Expand All @@ -117,30 +139,36 @@ func (x iface) equals(yI equalsI) bool {
}

type pkgData struct {
Funcs map[string]funcData
Types map[string]typeData
Structs map[string]structDef
Aliases map[string]alias
Interfaces map[string]iface
Funcs map[string]funcData
GenericFuncs map[string]genericFuncData
Types map[string]typeData
Structs map[string]structDef
GenericStructs map[string]genericStructDef
Aliases map[string]alias
Interfaces map[string]iface
}

func newPkgData() *pkgData {
return &pkgData{
Funcs: make(map[string]funcData),
Types: make(map[string]typeData),
Structs: make(map[string]structDef),
Aliases: make(map[string]alias),
Interfaces: make(map[string]iface),
Funcs: make(map[string]funcData),
GenericFuncs: make(map[string]genericFuncData),
Types: make(map[string]typeData),
Structs: make(map[string]structDef),
GenericStructs: make(map[string]genericStructDef),
Aliases: make(map[string]alias),
Interfaces: make(map[string]iface),
}
}

func (pkgD *pkgData) Clone() *pkgData {
return &pkgData{
Funcs: maps.Clone(pkgD.Funcs),
Types: maps.Clone(pkgD.Types),
Structs: maps.Clone(pkgD.Structs),
Aliases: maps.Clone(pkgD.Aliases),
Interfaces: maps.Clone(pkgD.Interfaces),
Funcs: maps.Clone(pkgD.Funcs),
GenericFuncs: maps.Clone(pkgD.GenericFuncs),
Types: maps.Clone(pkgD.Types),
Structs: maps.Clone(pkgD.Structs),
GenericStructs: maps.Clone(pkgD.GenericStructs),
Aliases: maps.Clone(pkgD.Aliases),
Interfaces: maps.Clone(pkgD.Interfaces),
}
}

Expand Down Expand Up @@ -171,19 +199,23 @@ func MapAndIn[V equalsI](x, y map[string]V) {
// get pkgData with definitions existing in both pkg And y
func (pkg *pkgData) And(y *pkgData) *pkgData {
return &pkgData{
Funcs: MapAnd(pkg.Funcs, y.Funcs),
Types: MapAnd(pkg.Types, y.Types),
Structs: MapAnd(pkg.Structs, y.Structs),
Aliases: MapAnd(pkg.Aliases, y.Aliases),
Interfaces: MapAnd(pkg.Interfaces, y.Interfaces),
Funcs: MapAnd(pkg.Funcs, y.Funcs),
GenericFuncs: MapAnd(pkg.GenericFuncs, y.GenericFuncs),
Types: MapAnd(pkg.Types, y.Types),
Structs: MapAnd(pkg.Structs, y.Structs),
GenericStructs: MapAnd(pkg.GenericStructs, y.GenericStructs),
Aliases: MapAnd(pkg.Aliases, y.Aliases),
Interfaces: MapAnd(pkg.Interfaces, y.Interfaces),
}
}

// in-place and
func (pkg *pkgData) AndIn(y *pkgData) {
MapAndIn(pkg.Funcs, y.Funcs)
MapAndIn(pkg.GenericFuncs, y.GenericFuncs)
MapAndIn(pkg.Types, y.Types)
MapAndIn(pkg.Structs, y.Structs)
MapAndIn(pkg.GenericStructs, y.GenericStructs)
MapAndIn(pkg.Aliases, y.Aliases)
MapAndIn(pkg.Interfaces, y.Interfaces)
}
Expand All @@ -210,19 +242,23 @@ func MapAndNotIn[V equalsI](x, y map[string]V) {
// return map with key-value pairs from pkg that do not have an equal pair in y
func (pkg *pkgData) AndNot(y *pkgData) *pkgData {
return &pkgData{
Funcs: MapAndNot(pkg.Funcs, y.Funcs),
Types: MapAndNot(pkg.Types, y.Types),
Structs: MapAndNot(pkg.Structs, y.Structs),
Aliases: MapAndNot(pkg.Aliases, y.Aliases),
Interfaces: MapAndNot(pkg.Interfaces, y.Interfaces),
Funcs: MapAndNot(pkg.Funcs, y.Funcs),
GenericFuncs: MapAndNot(pkg.GenericFuncs, y.GenericFuncs),
Types: MapAndNot(pkg.Types, y.Types),
Structs: MapAndNot(pkg.Structs, y.Structs),
GenericStructs: MapAndNot(pkg.GenericStructs, y.GenericStructs),
Aliases: MapAndNot(pkg.Aliases, y.Aliases),
Interfaces: MapAndNot(pkg.Interfaces, y.Interfaces),
}
}

// in-place version of andNot
func (pkg *pkgData) AndNotIn(y *pkgData) {
MapAndNotIn(pkg.Funcs, y.Funcs)
MapAndNotIn(pkg.GenericFuncs, y.GenericFuncs)
MapAndNotIn(pkg.Types, y.Types)
MapAndNotIn(pkg.Structs, y.Structs)
MapAndNotIn(pkg.GenericStructs, y.GenericStructs)
MapAndNotIn(pkg.Aliases, y.Aliases)
MapAndNotIn(pkg.Interfaces, y.Interfaces)
}
Expand All @@ -236,19 +272,23 @@ func MapMerge[T any](x, y map[string]T) map[string]T {
// return merged map with both x and y
func (pkg *pkgData) Merge(y *pkgData) *pkgData {
return &pkgData{
Funcs: MapMerge(pkg.Funcs, y.Funcs),
Types: MapMerge(pkg.Types, y.Types),
Structs: MapMerge(pkg.Structs, y.Structs),
Aliases: MapMerge(pkg.Aliases, y.Aliases),
Interfaces: MapMerge(pkg.Interfaces, y.Interfaces),
Funcs: MapMerge(pkg.Funcs, y.Funcs),
GenericFuncs: MapMerge(pkg.GenericFuncs, y.GenericFuncs),
Types: MapMerge(pkg.Types, y.Types),
Structs: MapMerge(pkg.Structs, y.Structs),
GenericStructs: MapMerge(pkg.GenericStructs, y.GenericStructs),
Aliases: MapMerge(pkg.Aliases, y.Aliases),
Interfaces: MapMerge(pkg.Interfaces, y.Interfaces),
}
}

// in-place version of merge
func (pkg *pkgData) MergeIn(y *pkgData) {
maps.Copy(pkg.Funcs, y.Funcs)
maps.Copy(pkg.GenericFuncs, y.GenericFuncs)
maps.Copy(pkg.Types, y.Types)
maps.Copy(pkg.Structs, y.Structs)
maps.Copy(pkg.GenericStructs, y.GenericStructs)
maps.Copy(pkg.Aliases, y.Aliases)
maps.Copy(pkg.Interfaces, y.Interfaces)
}
Expand All @@ -273,22 +313,26 @@ func mapNotIn[T any](x, y map[string]T) {
// remove keys existing in y from pkg
func (pkg *pkgData) Not(y *pkgData) *pkgData {
return &pkgData{
Funcs: mapNot(pkg.Funcs, y.Funcs),
Types: mapNot(pkg.Types, y.Types),
Structs: mapNot(pkg.Structs, y.Structs),
Aliases: mapNot(pkg.Aliases, y.Aliases),
Interfaces: mapNot(pkg.Interfaces, y.Interfaces),
Funcs: mapNot(pkg.Funcs, y.Funcs),
GenericFuncs: mapNot(pkg.GenericFuncs, y.GenericFuncs),
Types: mapNot(pkg.Types, y.Types),
Structs: mapNot(pkg.Structs, y.Structs),
GenericStructs: mapNot(pkg.GenericStructs, y.GenericStructs),
Aliases: mapNot(pkg.Aliases, y.Aliases),
Interfaces: mapNot(pkg.Interfaces, y.Interfaces),
}
}

func (pkg *pkgData) NotIn(y *pkgData) {
mapNotIn(pkg.Funcs, y.Funcs)
mapNotIn(pkg.GenericFuncs, y.GenericFuncs)
mapNotIn(pkg.Types, y.Types)
mapNotIn(pkg.Structs, y.Structs)
mapNotIn(pkg.GenericStructs, y.GenericStructs)
mapNotIn(pkg.Aliases, y.Aliases)
mapNotIn(pkg.Interfaces, y.Interfaces)
}

func (pkg *pkgData) empty() bool {
return (len(pkg.Funcs) + len(pkg.Types) + len(pkg.Structs) + len(pkg.Aliases) + len(pkg.Interfaces)) == 0
return (len(pkg.Funcs) + len(pkg.GenericFuncs) + len(pkg.Types) + len(pkg.Structs) + len(pkg.GenericStructs) + len(pkg.Aliases) + len(pkg.Interfaces)) == 0
}

0 comments on commit cb6d588

Please sign in to comment.