Skip to content

Commit

Permalink
Rudimentary support for passing type parameters into generic functions.
Browse files Browse the repository at this point in the history
Instead of generating an independent function instance for every
combination of type parameters at compile time we construct generic
function instances at runtime using "generic factory functions". Such a
factory takes type params as arguments and returns a concrete instance
of the function for the given type params (type param values are
captured by the returned function as a closure and can be used as
necessary).

Here is an abbreviated example of how a generic function is compiled and
called:

```
// Go:
func F[T any](t T) {}
f(1)

// JS:
F = function(T){ return function(t) {}; };
F($Int)(1);
```

This approach minimizes the size of the generated JS source, which is
critical for the client-side use case, at the cost of runtime
performance. See gopherjs#1013 (comment)
for the detailed description.

Note that the implementation in this commit is far from complete:

  - Generic function instances are not cached.
  - Generic types are not supported.
  - Declaring types dependent on type parameters doesn't work correctly.
  - Operators (such as `+`) do not work correctly with generic
    arguments.
  • Loading branch information
nevkontakte committed Oct 15, 2022
1 parent 900dda7 commit 0406604
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 20 deletions.
48 changes: 46 additions & 2 deletions compiler/expressions.go
Original file line number Diff line number Diff line change
Expand Up @@ -492,10 +492,18 @@ func (fc *funcContext) translateExpr(expr ast.Expr) *expression {
)
case *types.Basic:
return fc.formatExpr("%e.charCodeAt(%f)", e.X, e.Index)
case *types.Signature:
return fc.translateGenericInstance(e)
default:
panic(fmt.Sprintf("Unhandled IndexExpr: %T\n", t))
panic(fmt.Errorf("unhandled IndexExpr: %T", t))
}
case *ast.IndexListExpr:
switch t := fc.pkgCtx.TypeOf(e.X).Underlying().(type) {
case *types.Signature:
return fc.translateGenericInstance(e)
default:
panic(fmt.Errorf("unhandled IndexListExpr: %T", t))
}

case *ast.SliceExpr:
if b, isBasic := fc.pkgCtx.TypeOf(e.X).Underlying().(*types.Basic); isBasic && isString(b) {
switch {
Expand Down Expand Up @@ -749,6 +757,10 @@ func (fc *funcContext) translateExpr(expr ast.Expr) *expression {
case *types.Var, *types.Const:
return fc.formatExpr("%s", fc.objectName(o))
case *types.Func:
if _, ok := fc.pkgCtx.Info.Instances[e]; ok {
// Generic function call with auto-inferred types.
return fc.translateGenericInstance(e)
}
return fc.formatExpr("%s", fc.objectName(o))
case *types.TypeName:
return fc.formatExpr("%s", fc.typeName(o.Type()))
Expand Down Expand Up @@ -788,6 +800,38 @@ func (fc *funcContext) translateExpr(expr ast.Expr) *expression {
}
}

// translateGenericInstance translates a generic function instantiation.
//
// The returned JS expression evaluates into a callable function with type params
// substituted.
func (fc *funcContext) translateGenericInstance(e ast.Expr) *expression {
var identifier *ast.Ident
switch e := e.(type) {
case *ast.Ident:
identifier = e
case *ast.IndexExpr:
identifier = e.X.(*ast.Ident)
case *ast.IndexListExpr:
identifier = e.X.(*ast.Ident)
default:
err := bailout(fmt.Errorf("unexpected generic instantiation expression type %T at %s", e, fc.pkgCtx.fileSet.Position(e.Pos())))
panic(err)
}

instance, ok := fc.pkgCtx.Info.Instances[identifier]
if !ok {
err := fmt.Errorf("no matching generic instantiation for %q at %s", identifier, fc.pkgCtx.fileSet.Position(identifier.Pos()))
bailout(err)
}
typeParams := []string{}
for i := 0; i < instance.TypeArgs.Len(); i++ {
t := instance.TypeArgs.At(i)
typeParams = append(typeParams, fc.typeName(t))
}
o := fc.pkgCtx.Uses[identifier]
return fc.formatExpr("%s(%s)", fc.objectName(o), strings.Join(typeParams, ", "))
}

func (fc *funcContext) translateCall(e *ast.CallExpr, sig *types.Signature, fun *expression) *expression {
args := fc.translateArgs(sig, e.Args, e.Ellipsis.IsValid())
if fc.Blocking[e] {
Expand Down
24 changes: 21 additions & 3 deletions compiler/package.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ func Compile(importPath string, files []*ast.File, fileSet *token.FileSet, impor
Implicits: make(map[ast.Node]types.Object),
Selections: make(map[*ast.SelectorExpr]*types.Selection),
Scopes: make(map[ast.Node]*types.Scope),
Instances: make(map[*ast.Ident]types.Instance),
}

var errList ErrorList
Expand Down Expand Up @@ -294,7 +295,7 @@ func Compile(importPath string, files []*ast.File, fileSet *token.FileSet, impor
// but now we do it here to maintain previous behavior.
continue
}
funcCtx.pkgCtx.pkgVars[importedPkg.Path()] = funcCtx.newVariable(importedPkg.Name(), true)
funcCtx.pkgCtx.pkgVars[importedPkg.Path()] = funcCtx.newVariable(importedPkg.Name(), varPackage)
importedPaths = append(importedPaths, importedPkg.Path())
}
sort.Strings(importedPaths)
Expand Down Expand Up @@ -521,7 +522,7 @@ func Compile(importPath string, files []*ast.File, fileSet *token.FileSet, impor
d.DeclCode = funcCtx.CatchOutput(0, func() {
typeName := funcCtx.objectName(o)
lhs := typeName
if isPkgLevel(o) {
if typeVarLevel(o) == varPackage {
lhs += " = $pkg." + encodeIdent(o.Name())
}
size := int64(0)
Expand Down Expand Up @@ -898,5 +899,22 @@ func translateFunction(typ *ast.FuncType, recv *ast.Ident, body *ast.BlockStmt,

c.pkgCtx.escapingVars = prevEV

return params, fmt.Sprintf("function%s(%s) {\n%s%s}", functionName, strings.Join(params, ", "), bodyOutput, c.Indentation(0))
if !c.sigTypes.IsGeneric() {
return params, fmt.Sprintf("function%s(%s) {\n%s%s}", functionName, strings.Join(params, ", "), bodyOutput, c.Indentation(0))
}

// Generic functions are generated as factories to allow passing type parameters
// from the call site.
// TODO(nevkontakte): Cache function instances for a given combination of type
// parameters.
// TODO(nevkontakte): Generate type parameter arguments and derive all dependent
// types inside the function.
typeParams := []string{}
for i := 0; i < c.sigTypes.Sig.TypeParams().Len(); i++ {
typeParam := c.sigTypes.Sig.TypeParams().At(i)
typeParams = append(typeParams, c.typeName(typeParam))
}

return params, fmt.Sprintf("function%s(%s){ return function(%s) {\n%s%s}; }",
functionName, strings.Join(typeParams, ", "), strings.Join(params, ", "), bodyOutput, c.Indentation(0))
}
2 changes: 1 addition & 1 deletion compiler/statements.go
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ func (fc *funcContext) translateStmt(stmt ast.Stmt, label *types.Label) {
for _, spec := range decl.Specs {
o := fc.pkgCtx.Defs[spec.(*ast.TypeSpec).Name].(*types.TypeName)
fc.pkgCtx.typeNames = append(fc.pkgCtx.typeNames, o)
fc.pkgCtx.objectNames[o] = fc.newVariable(o.Name(), true)
fc.pkgCtx.objectNames[o] = fc.newVariable(o.Name(), varPackage)
fc.pkgCtx.dependencies[o] = true
}
case token.CONST:
Expand Down
69 changes: 55 additions & 14 deletions compiler/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,9 +237,23 @@ func (fc *funcContext) newConst(t types.Type, value constant.Value) ast.Expr {
// local variable name. In this context "local" means "in scope of the current"
// functionContext.
func (fc *funcContext) newLocalVariable(name string) string {
return fc.newVariable(name, false)
return fc.newVariable(name, varFuncLocal)
}

// varLevel specifies at which level a JavaScript variable should be declared.
type varLevel int

const (
// A variable defined at a function level (e.g. local variables).
varFuncLocal = iota
// A variable that should be declared in a generic type or function factory.
// This is mainly for type parameters and generic-dependent types.
varGenericFactory
// A variable that should be declared in a package factory. This user is for
// top-level functions, types, etc.
varPackage
)

// newVariable assigns a new JavaScript variable name for the given Go variable
// or type.
//
Expand All @@ -252,7 +266,7 @@ func (fc *funcContext) newLocalVariable(name string) string {
// to this functionContext, as well as all parents, but not to the list of local
// variables. If false, it is added to this context only, as well as the list of
// local vars.
func (fc *funcContext) newVariable(name string, pkgLevel bool) string {
func (fc *funcContext) newVariable(name string, level varLevel) string {
if name == "" {
panic("newVariable: empty name")
}
Expand All @@ -261,7 +275,7 @@ func (fc *funcContext) newVariable(name string, pkgLevel bool) string {
i := 0
for {
offset := int('a')
if pkgLevel {
if level == varPackage {
offset = int('A')
}
j := i
Expand All @@ -286,9 +300,22 @@ func (fc *funcContext) newVariable(name string, pkgLevel bool) string {
varName = fmt.Sprintf("%s$%d", name, n)
}

if pkgLevel {
for c2 := fc.parent; c2 != nil; c2 = c2.parent {
c2.allVars[name] = n + 1
// Package-level variables are registered in all outer scopes.
if level == varPackage {
for c := fc.parent; c != nil; c = c.parent {
c.allVars[name] = n + 1
}
return varName
}

// Generic-factory level variables are registered in outer scopes up to the
// level of the generic function or method.
if level == varGenericFactory {
for c := fc; c != nil; c = c.parent {
c.allVars[name] = n + 1
if c.sigTypes.IsGeneric() {
break
}
}
return varName
}
Expand Down Expand Up @@ -331,14 +358,20 @@ func isVarOrConst(o types.Object) bool {
return false
}

func isPkgLevel(o types.Object) bool {
return o.Parent() != nil && o.Parent().Parent() == types.Universe
func typeVarLevel(o types.Object) varLevel {
if _, ok := o.Type().(*types.TypeParam); ok {
return varGenericFactory
}
if o.Parent() != nil && o.Parent().Parent() == types.Universe {
return varPackage
}
return varFuncLocal
}

// objectName returns a JS identifier corresponding to the given types.Object.
// Repeated calls for the same object will return the same name.
func (fc *funcContext) objectName(o types.Object) string {
if isPkgLevel(o) {
if typeVarLevel(o) == varPackage {
fc.pkgCtx.dependencies[o] = true

if o.Pkg() != fc.pkgCtx.Pkg || (isVarOrConst(o) && o.Exported()) {
Expand All @@ -348,7 +381,7 @@ func (fc *funcContext) objectName(o types.Object) string {

name, ok := fc.pkgCtx.objectNames[o]
if !ok {
name = fc.newVariable(o.Name(), isPkgLevel(o))
name = fc.newVariable(o.Name(), typeVarLevel(o))
fc.pkgCtx.objectNames[o] = name
}

Expand All @@ -359,13 +392,13 @@ func (fc *funcContext) objectName(o types.Object) string {
}

func (fc *funcContext) varPtrName(o *types.Var) string {
if isPkgLevel(o) && o.Exported() {
if typeVarLevel(o) == varPackage && o.Exported() {
return fc.pkgVar(o.Pkg()) + "." + o.Name() + "$ptr"
}

name, ok := fc.pkgCtx.varPtrNames[o]
if !ok {
name = fc.newVariable(o.Name()+"$ptr", isPkgLevel(o))
name = fc.newVariable(o.Name()+"$ptr", typeVarLevel(o))
fc.pkgCtx.varPtrNames[o] = name
}
return name
Expand All @@ -385,6 +418,8 @@ func (fc *funcContext) typeName(ty types.Type) string {
return "$error"
}
return fc.objectName(t.Obj())
case *types.TypeParam:
return fc.objectName(t.Obj())
case *types.Interface:
if t.Empty() {
return "$emptyInterface"
Expand All @@ -397,8 +432,8 @@ func (fc *funcContext) typeName(ty types.Type) string {
// repeatedly.
anonType, ok := fc.pkgCtx.anonTypeMap.At(ty).(*types.TypeName)
if !ok {
fc.initArgs(ty) // cause all embedded types to be registered
varName := fc.newVariable(strings.ToLower(typeKind(ty)[5:])+"Type", true)
fc.initArgs(ty) // cause all dependency types to be registered
varName := fc.newVariable(strings.ToLower(typeKind(ty)[5:])+"Type", varPackage)
anonType = types.NewTypeName(token.NoPos, fc.pkgCtx.Pkg, varName, ty) // fake types.TypeName
fc.pkgCtx.anonTypes = append(fc.pkgCtx.anonTypes, anonType)
fc.pkgCtx.anonTypeMap.Set(ty, anonType)
Expand Down Expand Up @@ -815,6 +850,12 @@ func (st signatureTypes) HasNamedResults() bool {
return st.HasResults() && st.Sig.Results().At(0).Name() != ""
}

// IsGeneric returns true if the signature represents a generic function or a
// method of a generic type.
func (st signatureTypes) IsGeneric() bool {
return st.Sig.TypeParams().Len() > 0 || st.Sig.RecvTypeParams().Len() > 0
}

// ErrorAt annotates an error with a position in the source code.
func ErrorAt(err error, fset *token.FileSet, pos token.Pos) error {
return fmt.Errorf("%s: %w", fset.Position(pos), err)
Expand Down

0 comments on commit 0406604

Please sign in to comment.