Skip to content

Commit

Permalink
Clean up generics support
Browse files Browse the repository at this point in the history
  • Loading branch information
Josiah McMenamy committed Oct 28, 2024
1 parent 93b4a36 commit f71c93e
Show file tree
Hide file tree
Showing 7 changed files with 174 additions and 70 deletions.
29 changes: 5 additions & 24 deletions glang/coq.go
Original file line number Diff line number Diff line change
Expand Up @@ -232,25 +232,6 @@ func (e GallinaIdent) Coq(needs_paren bool) string {
return string(e)
}

// GallinaIdentGeneric is a identifier in Gallina (and not a variable)
//
// A GallinaIdentGeneric is translated literally to Coq along with the type args
type GallinaIdentGeneric struct {
Name GallinaIdent
TypeArgs []Type
}

func (e GallinaIdentGeneric) Coq(needs_paren bool) string {
out := e.Name.Coq(false)
for _, typeArg := range e.TypeArgs {
out += " " + typeArg.Coq(false)
}
if len(e.TypeArgs) > 0 {
return fmt.Sprintf("(%v)", out)
}
return out
}

// A Go qualified identifier, which is translated to a Gallina qualified
// identifier.
type PackageIdent struct {
Expand Down Expand Up @@ -438,13 +419,13 @@ type fieldVal struct {
// Relies on Coq record syntax to correctly order fields for the record's
// constructor.
type StructLiteral struct {
gallinaIdent GallinaIdentGeneric
elts []fieldVal
structType Expr
elts []fieldVal
}

// NewStructLiteral creates a StructLiteral with no values.
func NewStructLiteral(gallinaIdent GallinaIdentGeneric) StructLiteral {
return StructLiteral{gallinaIdent: gallinaIdent}
func NewStructLiteral(structType Expr) StructLiteral {
return StructLiteral{structType: structType}
}

// AddField appends a new (field, val) pair to a StructLiteral.
Expand All @@ -455,7 +436,7 @@ func (sl *StructLiteral) AddField(field string, value Expr) {
func (sl StructLiteral) Coq(needs_paren bool) string {
var pp buffer
method := "struct.make"
pp.Add("%s %s [{", method, sl.gallinaIdent.Coq(true))
pp.Add("%s %s [{", method, sl.structType.Coq(true))
pp.Indent(2)
for i, f := range sl.elts {
terminator := ";"
Expand Down
31 changes: 12 additions & 19 deletions goose.go
Original file line number Diff line number Diff line change
Expand Up @@ -606,8 +606,7 @@ func (ctx Ctx) fieldSelection(n locatable, index *[]int, curType *types.Type, ex
}
v := info.structType.Field(i)
*expr = glang.NewCallExpr(glang.GallinaIdent("struct.field_get"),
info.gallinaIdent, glang.GallinaString(v.Name()), *expr)
ctx.dep.addDep(info.gallinaIdent.Name.Coq(false))
ctx.structInfoToGlangExpr(info), glang.GallinaString(v.Name()), *expr)
*curType = v.Type()
}
return
Expand All @@ -619,7 +618,6 @@ func (ctx Ctx) fieldSelection(n locatable, index *[]int, curType *types.Type, ex
func (ctx Ctx) fieldAddrSelection(n locatable, index []int, curType *types.Type, expr *glang.Expr) {
for _, i := range index {
info, ok := ctx.getStructInfo(*curType)
ctx.dep.addDep(info.gallinaIdent.Name.Coq(false))
if !ok {
if _, ok := (*curType).(*types.Struct); ok {
ctx.unsupported(n, "anonymous struct")
Expand All @@ -632,7 +630,7 @@ func (ctx Ctx) fieldAddrSelection(n locatable, index []int, curType *types.Type,
v := info.structType.Field(i)

*expr = glang.NewCallExpr(glang.GallinaIdent("struct.field_ref"),
info.gallinaIdent, glang.GallinaString(v.Name()), *expr)
ctx.structInfoToGlangExpr(info), glang.GallinaString(v.Name()), *expr)
*curType = v.Type()
}
return
Expand Down Expand Up @@ -800,8 +798,7 @@ func (ctx Ctx) compositeLiteral(e *ast.CompositeLit) glang.Expr {
}

func (ctx Ctx) structLiteral(info structTypeInfo, e *ast.CompositeLit) glang.StructLiteral {
ctx.dep.addDep(info.gallinaIdent.Name.Coq(false))
lit := glang.NewStructLiteral(info.gallinaIdent)
lit := glang.NewStructLiteral(ctx.structInfoToGlangExpr(info))
isUnkeyedStruct := false

getFieldType := func(fieldName string) types.Type {
Expand Down Expand Up @@ -1094,13 +1091,9 @@ func (ctx Ctx) function(s *ast.Ident) glang.Expr {
return glang.GallinaIdent(s.Name)

}
glangTypeArgs := make([]glang.Type, typeArgs.Len())
for i := range glangTypeArgs {
glangTypeArgs[i] = ctx.glangType(s, typeArgs.At(i))
}
return glang.GallinaIdentGeneric{
Name: glang.GallinaIdent(s.Name),
TypeArgs: glangTypeArgs,
return glang.CallExpr{
MethodName: glang.GallinaIdent(s.Name),
Args: ctx.convertTypeArgsToGlang(nil, typeArgs),
}
}

Expand Down Expand Up @@ -1277,10 +1270,10 @@ func (ctx Ctx) indexExpr(e *ast.IndexExpr, isSpecial bool) glang.Expr {
return glang.CallExpr{}
}

// func (ctx Ctx) indexListExpr(e *ast.IndexListExpr) glang.Expr {
// // generic arguments are grabbed from go ast, ignore explicit type args
// return ctx.expr(e.X)
// }
func (ctx Ctx) indexListExpr(e *ast.IndexListExpr) glang.Expr {
// generic arguments are grabbed from go ast, ignore explicit type args
return ctx.expr(e.X)
}

func (ctx Ctx) derefExpr(e ast.Expr) glang.Expr {
return glang.DerefExpr{
Expand Down Expand Up @@ -1339,8 +1332,8 @@ func (ctx Ctx) exprSpecial(e ast.Expr, isSpecial bool) glang.Expr {
return ctx.sliceExpr(e)
case *ast.IndexExpr:
return ctx.indexExpr(e, isSpecial)
// case *ast.IndexListExpr:
// return ctx.indexListExpr(e)
case *ast.IndexListExpr:
return ctx.indexListExpr(e)
case *ast.UnaryExpr:
return ctx.unaryExpr(e, isSpecial)
case *ast.ParenExpr:
Expand Down
42 changes: 32 additions & 10 deletions testdata/examples/semantics/generics.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,41 @@ type genericStruct[A any, B any] struct {
y B
}

func identity[A any, B any](data genericStruct[A, B]) (A, B) {
return data.x, data.y
type genericStruct2[T any] struct {
g T
}

func testGenericStructs() bool {
z := genericStruct[uint64, uint64]{
x: uint64(7),
y: uint64(8),
}
type nonGenericStruct struct {
p uint64
}

f := identity[uint64, uint64]
type IntMap[T any] map[uint64]T

func identity[A any, B any](a A, b B) (A, B) {
return a, b
}

x, y := f(z)
func identity2[A any](a A) A {
return a
}

return x == z.x && y == z.y
func testGenericStructs() bool {
var intMap IntMap[uint64]
intMap = make(IntMap[uint64])
intMap[1] = 2
c := genericStruct2[uint64]{
g: 2,
}
u := genericStruct[string, uint64]{
x: "test",
y: uint64(7),
}
d := identity2[uint64](uint64(5))
_, d2 := identity("test", uint64(5))
g := identity[string, uint64]
_, b := g("test", uint64(3))
h := nonGenericStruct{
p: uint64(3),
}
return d+d2+c.g+u.y+b+h.p+intMap[1] == 27
}
101 changes: 101 additions & 0 deletions testdata/examples/semantics/semantics.gold.v
Original file line number Diff line number Diff line change
Expand Up @@ -818,6 +818,107 @@ Definition failing_testArgumentOrder : val :=
do: ("ok" <-[boolT] "$r0");;;
return: (![boolT] "ok")).

Definition genericStruct (A: go_type) (B: go_type) : go_type := structT [
"x" :: A;
"y" :: B
].

Definition genericStruct__mset : list (string * val) := [
].

Definition genericStruct__mset_ptr : list (string * val) := [
].

Definition genericStruct2 (T: go_type) : go_type := structT [
"g" :: T
].

Definition genericStruct2__mset : list (string * val) := [
].

Definition genericStruct2__mset_ptr : list (string * val) := [
].

Definition nonGenericStruct : go_type := structT [
"p" :: uint64T
].

Definition nonGenericStruct__mset : list (string * val) := [
].

Definition nonGenericStruct__mset_ptr : list (string * val) := [
].

Definition IntMap (T: go_type) : go_type := mapT uint64T T.

Definition IntMap__mset : list (string * val) := [
].

Definition IntMap__mset_ptr : list (string * val) := [
].

(* go: generics.go:18:6 *)
Definition identity (A: go_type) (B: go_type) : val :=
rec: "identity" "a" "b" :=
exception_do (let: "b" := (ref_ty B "b") in
let: "a" := (ref_ty A "a") in
return: (![A] "a", ![B] "b")).

(* go: generics.go:22:6 *)
Definition identity2 (A: go_type) : val :=
rec: "identity2" "a" :=
exception_do (let: "a" := (ref_ty A "a") in
return: (![A] "a")).

(* go: generics.go:26:6 *)
Definition testGenericStructs : val :=
rec: "testGenericStructs" <> :=
exception_do (let: "intMap" := (ref_ty (IntMap uint64T) (zero_val (IntMap uint64T))) in
let: "$r0" := (map.make uint64T uint64T #()) in
do: ("intMap" <-[IntMap uint64T] "$r0");;;
let: "$r0" := #(W64 2) in
do: (map.insert (![IntMap uint64T] "intMap") #(W64 1) "$r0");;;
let: "c" := (ref_ty (genericStruct2 uint64T) (zero_val (genericStruct2 uint64T))) in
let: "$r0" := (struct.make (genericStruct2 uint64T) [{
"g" ::= #(W64 2)
}]) in
do: ("c" <-[genericStruct2 uint64T] "$r0");;;
let: "u" := (ref_ty (genericStruct stringT uint64T) (zero_val (genericStruct stringT uint64T))) in
let: "$r0" := (struct.make (genericStruct stringT uint64T) [{
"x" ::= #(str "test");
"y" ::= #(W64 7)
}]) in
do: ("u" <-[genericStruct stringT uint64T] "$r0");;;
let: "d" := (ref_ty uint64T (zero_val uint64T)) in
let: "$r0" := (let: "$a0" := #(W64 5) in
(identity2 uint64T) "$a0") in
do: ("d" <-[uint64T] "$r0");;;
let: "d2" := (ref_ty uint64T (zero_val uint64T)) in
let: ("$ret0", "$ret1") := (let: "$a0" := #(str "test") in
let: "$a1" := #(W64 5) in
(identity stringT uint64T) "$a0" "$a1") in
let: "$r0" := "$ret0" in
let: "$r1" := "$ret1" in
do: "$r0";;;
do: ("d2" <-[uint64T] "$r1");;;
let: "g" := (ref_ty funcT (zero_val funcT)) in
let: "$r0" := (identity stringT uint64T) in
do: ("g" <-[funcT] "$r0");;;
let: "b" := (ref_ty uint64T (zero_val uint64T)) in
let: ("$ret0", "$ret1") := (let: "$a0" := #(str "test") in
let: "$a1" := #(W64 3) in
(![funcT] "g") "$a0" "$a1") in
let: "$r0" := "$ret0" in
let: "$r1" := "$ret1" in
do: "$r0";;;
do: ("b" <-[uint64T] "$r1");;;
let: "h" := (ref_ty nonGenericStruct (zero_val nonGenericStruct)) in
let: "$r0" := (struct.make nonGenericStruct [{
"p" ::= #(W64 3)
}]) in
do: ("h" <-[nonGenericStruct] "$r0");;;
return: ((((((((![uint64T] "d") + (![uint64T] "d2")) + (![uint64T] (struct.field_ref (genericStruct2 uint64T) "g" "c"))) + (![uint64T] (struct.field_ref (genericStruct stringT uint64T) "y" "u"))) + (![uint64T] "b")) + (![uint64T] (struct.field_ref nonGenericStruct "p" "h"))) + (Fst (map.get (![IntMap uint64T] "intMap") #(W64 1)))) = #(W64 27))).

(* go: int_conversions.go:3:6 *)
Definition testU64ToU32 : val :=
rec: "testU64ToU32" <> :=
Expand Down
3 changes: 0 additions & 3 deletions testdata/negative-tests/badgenerics/badgenerics.go

This file was deleted.

5 changes: 0 additions & 5 deletions testdata/negative-tests/badgenerics2/badgenerics2.go

This file was deleted.

33 changes: 24 additions & 9 deletions types.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,15 @@ func (ctx Ctx) glangType(n locatable, t types.Type) glang.Type {
return glang.TypeIdent("disk.Disk")
}
if info, ok := ctx.getStructInfo(t); ok {
ctx.dep.addDep(info.gallinaIdent.Name.Coq(false))
return info.gallinaIdent
return ctx.structInfoToGlangExpr(info)
}
ctx.dep.addDep(ctx.qualifiedName(t.Obj()))
if t.TypeArgs().Len() != 0 {
return glang.CallExpr{
MethodName: glang.TypeIdent(ctx.qualifiedName(t.Obj())),
Args: ctx.convertTypeArgsToGlang(nil, t.TypeArgs()),
}
}
return glang.TypeIdent(ctx.qualifiedName(t.Obj()))
case *types.Slice:
return glang.SliceType{Value: ctx.glangType(n, t.Elem())}
Expand Down Expand Up @@ -174,18 +179,30 @@ func getIntegerType(t types.Type) (intTypeInfo, bool) {
}
}

func (ctx Ctx) convertTypeArgsToGlang(l locatable, typeList *types.TypeList) (glangTypeArgs []glang.Type) {
glangTypeArgs = make([]glang.Type, typeList.Len())
func (ctx Ctx) convertTypeArgsToGlang(l locatable, typeList *types.TypeList) (glangTypeArgs []glang.Expr) {
glangTypeArgs = make([]glang.Expr, typeList.Len())
for i := range glangTypeArgs {
glangTypeArgs[i] = ctx.glangType(l, typeList.At(i))
}
return
}

type structTypeInfo struct {
name string
throughPointer bool
structType *types.Struct
gallinaIdent glang.GallinaIdentGeneric
typeArgs *types.TypeList
}

func (ctx Ctx) structInfoToGlangExpr(info structTypeInfo) glang.Expr {
ctx.dep.addDep(info.name)
if info.typeArgs.Len() == 0 {
return glang.GallinaIdent(info.name)
}
return glang.CallExpr{
MethodName: glang.GallinaIdent(info.name),
Args: ctx.convertTypeArgsToGlang(nil, info.typeArgs),
}
}

func (ctx Ctx) getStructInfo(t types.Type) (structTypeInfo, bool) {
Expand All @@ -198,10 +215,8 @@ func (ctx Ctx) getStructInfo(t types.Type) (structTypeInfo, bool) {
name := ctx.qualifiedName(t.Obj())
if structType, ok := t.Underlying().(*types.Struct); ok {
return structTypeInfo{
gallinaIdent: glang.GallinaIdentGeneric{
Name: glang.GallinaIdent(name),
TypeArgs: ctx.convertTypeArgsToGlang(nil, t.TypeArgs()),
},
name: name,
typeArgs: t.TypeArgs(),
throughPointer: throughPointer,
structType: structType,
}, true
Expand Down

0 comments on commit f71c93e

Please sign in to comment.