Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
elias-orijtech committed Nov 24, 2023
1 parent 3f5f5eb commit d05f41e
Show file tree
Hide file tree
Showing 10 changed files with 1,735 additions and 61 deletions.
1 change: 1 addition & 0 deletions cmd/protoc-gen-go-pulsar/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

_ "github.com/cosmos/cosmos-proto/features/fastreflection"
_ "github.com/cosmos/cosmos-proto/features/protoc"
_ "github.com/cosmos/cosmos-proto/features/zeropb"
"github.com/cosmos/cosmos-proto/generator"
"google.golang.org/protobuf/reflect/protoreflect"

Expand Down
18 changes: 9 additions & 9 deletions features/protoc/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -1448,7 +1448,7 @@ func genMessageField(g *generator.GeneratedFile, f *fileInfo, m *messageInfo, fi
sf.append(oneof.GoName)
return
}
goType, pointer := fieldGoType(g, f, field)
goType, pointer := FieldGoType(g, field)
if pointer {
goType = "*" + goType
}
Expand Down Expand Up @@ -1490,7 +1490,7 @@ func genMessageDefaultDecls(g *generator.GeneratedFile, f *fileInfo, m *messageI
continue
}
name := "Default_" + m.GoIdent.GoName + "_" + field.GoName
goType, _ := fieldGoType(g, f, field)
goType, _ := FieldGoType(g, field)
defVal := field.Desc.Default()
switch field.Desc.Kind() {
case protoreflect.StringKind:
Expand Down Expand Up @@ -1610,7 +1610,7 @@ func genMessageGetterMethods(g *generator.GeneratedFile, f *fileInfo, m *message
}

// Getter for message field.
goType, pointer := fieldGoType(g, f, field)
goType, pointer := FieldGoType(g, field)
defaultValue := fieldDefaultValue(g, f, m, field)
g.Annotate(m.GoIdent.GoName+".Get"+field.GoName, field.Location)
leadingComments := appendDeprecationSuffix("",
Expand Down Expand Up @@ -1679,10 +1679,10 @@ func genMessageSetterMethods(g *generator.GeneratedFile, f *fileInfo, m *message
}
}

// fieldGoType returns the Go type used for a field.
// FieldGoType returns the Go type used for a field.
//
// If it returns pointer=true, the struct field is a pointer to the type.
func fieldGoType(g *generator.GeneratedFile, f *fileInfo, field *protogen.Field) (goType string, pointer bool) {
func FieldGoType(g *generator.GeneratedFile, field *protogen.Field) (goType string, pointer bool) {
if field.Desc.IsWeak() {
return "struct{}", false
}
Expand Down Expand Up @@ -1718,8 +1718,8 @@ func fieldGoType(g *generator.GeneratedFile, f *fileInfo, field *protogen.Field)
case field.Desc.IsList():
return "[]" + goType, false
case field.Desc.IsMap():
keyType, _ := fieldGoType(g, f, field.Message.Fields[0])
valType, _ := fieldGoType(g, f, field.Message.Fields[1])
keyType, _ := FieldGoType(g, field.Message.Fields[0])
valType, _ := FieldGoType(g, field.Message.Fields[1])
return fmt.Sprintf("map[%v]%v", keyType, valType), false
}
return goType, pointer
Expand Down Expand Up @@ -1779,7 +1779,7 @@ func genExtensions(g *generator.GeneratedFile, f *fileInfo) {
for _, x := range f.allExtensions {
g.P("{")
g.P("ExtendedType: (*", x.Extendee.GoIdent, ")(nil),")
goType, pointer := fieldGoType(g, f, x.Extension)
goType, pointer := FieldGoType(g, x.Extension)
if pointer {
goType = "*" + goType
}
Expand Down Expand Up @@ -1852,7 +1852,7 @@ func genMessageOneofWrapperTypes(g *generator.GeneratedFile, f *fileInfo, m *mes
g.Annotate(field.GoIdent.GoName, field.Location)
g.Annotate(field.GoIdent.GoName+"."+field.GoName, field.Location)
g.P("type ", field.GoIdent, " struct {")
goType, _ := fieldGoType(g, f, field)
goType, _ := FieldGoType(g, field)
tags := structTags{
{"protobuf", fieldProtobufTagValue(field)},
}
Expand Down
230 changes: 230 additions & 0 deletions features/zeropb/zeropb.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
package zeropb

import (
"github.com/cosmos/cosmos-proto/features/protoc"
"github.com/cosmos/cosmos-proto/generator"
"google.golang.org/protobuf/compiler/protogen"
"google.golang.org/protobuf/reflect/protoreflect"
)

const (
errorsPackage = protogen.GoImportPath("errors")
mathPackage = protogen.GoImportPath("math")
)

func init() {
generator.RegisterFeature("zeropb", func(gen *generator.GeneratedFile, _ *protogen.Plugin) generator.FeatureGenerator {
return zeropbFeature{
gen: gen,
}
})
}

type zeropbFeature struct {
gen *generator.GeneratedFile
}

func (g zeropbFeature) GenerateFile(file *protogen.File, _ *protogen.Plugin) bool {
for _, m := range file.Messages {
g.generateMessage(file, m)
}
return true // only do this once
}

func (g zeropbFeature) GenerateHelpers() {}

func (g zeropbFeature) generateMessage(f *protogen.File, m *protogen.Message) {
g.generateMarshal(f, m)
g.generateUnmarshal(f, m)
}

func (g zeropbFeature) generateMarshal(f *protogen.File, m *protogen.Message) {
g.gen.P("func (x *", m.GoIdent, ") MarshalZeroPB(buf []byte) (n int, err error) {")
g.gen.P("defer func() {")
g.gen.P(" if e := recover(); e != nil {")
g.gen.P(" err = ", errorsPackage.Ident("New"), `("buffer overflow")`)
g.gen.P(" }")
g.gen.P("}()")
for _, f := range m.Fields {
g.generateMarshalField(f)
}
g.gen.P("return n, nil")
g.gen.P("}")
}

func (g zeropbFeature) generateMarshalField(f *protogen.Field) {
d := f.Desc
switch {
case d.IsList():
g.gen.P("len_", d.Index(), " := uint16(len(x.", f.GoName, "))")
g.gen.P("if len(x.", f.GoName, ") != int(len_", d.Index(), ") {")
g.gen.P(" return n, ", errorsPackage.Ident("New"), `("field `, f.GoName, ` is too long")`)
g.gen.P("}")
g.gen.P("binary.LittleEndian.PutUint16(buf[n:], len_", d.Index(), ")")
g.gen.P("n += 2")
g.gen.P("for _, e := range x.", f.GoName, " {")
g.generateMarshalPrimitive(d, "e")
g.gen.P("}")
case d.IsMap():
g.gen.P("len_", d.Index(), " := uint16(len(x.", f.GoName, "))")
g.gen.P("if len(x.", f.GoName, ") != int(len_", d.Index(), ") {")
g.gen.P(" return n, ", errorsPackage.Ident("New"), `("field `, f.GoName, ` is too long")`)
g.gen.P("}")
g.gen.P("binary.LittleEndian.PutUint16(buf[n:], len_", d.Index(), ")")
g.gen.P("n += 2")
g.gen.P("for k, v := range x.", f.GoName, " {")
g.generateMarshalPrimitive(d.MapKey(), "k")
g.generateMarshalPrimitive(d.MapValue(), "v")
g.gen.P("}")
case d.ContainingOneof() != nil:
g.gen.P("// TODO: field ", f.GoName)
return
default:
g.generateMarshalPrimitive(d, "x."+f.GoName)
}
}

func (g zeropbFeature) generateMarshalPrimitive(d protoreflect.FieldDescriptor, name string) {
switch d.Kind() {
case protoreflect.FloatKind:
g.gen.P("binary.LittleEndian.PutUint32(buf[n:], ", mathPackage.Ident("Float32bits"), "(", name, "))")
g.gen.P("n += 4")
case protoreflect.DoubleKind:
g.gen.P("binary.LittleEndian.PutUint64(buf[n:], ", mathPackage.Ident("Float64bits"), "(", name, "))")
g.gen.P("n += 8")
case protoreflect.Sfixed32Kind, protoreflect.Fixed32Kind, protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Uint32Kind, protoreflect.EnumKind:
g.gen.P("binary.LittleEndian.PutUint32(buf[n:], uint32(", name, "))")
g.gen.P("n += 4")
case protoreflect.Sfixed64Kind, protoreflect.Fixed64Kind, protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Uint64Kind:
g.gen.P("binary.LittleEndian.PutUint64(buf[n:], uint64(", name, "))")
g.gen.P("n += 8")
case protoreflect.BoolKind:
g.gen.P("bool_", d.Index(), " := uint32(0)")
g.gen.P("if ", name, " {")
g.gen.P(" bool_", d.Index(), " = 1")
g.gen.P("}")
g.gen.P("binary.LittleEndian.PutUint32(buf[n:], bool_", d.Index(), ")")
g.gen.P("n += 4")
case protoreflect.StringKind, protoreflect.BytesKind:
g.gen.P("len_", d.Index(), " := uint16(len(", name, "))")
g.gen.P("if len(", name, ") != int(len_", d.Index(), ") {")
g.gen.P(" return n, ", errorsPackage.Ident("New"), `("field `, name, ` is too long")`)
g.gen.P("}")
g.gen.P("binary.LittleEndian.PutUint16(buf[n:], len_", d.Index(), ")")
g.gen.P("n += 2")
// Reslice buf to convert a truncated write into a buffer overflow error.
g.gen.P("copy(buf[n:n+len(", name, ")], ", name, ")")
g.gen.P("n += len(", name, ")")
case protoreflect.MessageKind:
g.gen.P("n_", d.Index(), ", err := ", name, ".MarshalZeroPB(buf[n:])")
g.gen.P("n += n_", d.Index())
g.gen.P("if err != nil {")
g.gen.P(" return n, err")
g.gen.P("}")
default:
g.gen.P("// TODO: field ", name)
g.gen.P("_ = ", name)
}
}

func (g zeropbFeature) generateUnmarshal(f *protogen.File, m *protogen.Message) {
g.gen.P("func (x *", m.GoIdent, ") UnmarshalZeroPB(buf []byte) (n int, err error) {")
g.gen.P("defer func() {")
g.gen.P(" if e := recover(); e != nil {")
g.gen.P(" err = ", errorsPackage.Ident("New"), `("buffer underflow")`)
g.gen.P(" }")
g.gen.P("}()")
for _, f := range m.Fields {
g.generateUnmarshalField(f)
}
g.gen.P("return n, nil")
g.gen.P("}")
}

func (g zeropbFeature) generateUnmarshalField(f *protogen.Field) {
d := f.Desc
switch {
case d.IsList():
g.gen.P("len_", d.Index(), " := int(binary.LittleEndian.Uint16(buf[n:]))")
g.gen.P("n += 2")
typ, pointer := protoc.FieldGoType(g.gen, f)
if pointer {
typ = "*" + typ
}
g.gen.P("x.", f.GoName, " = make(", typ, ", len_", d.Index(), ")")
g.gen.P("for i := range x.", f.GoName, "{")
g.generateUnmarshalPrimitive(d, "x."+f.GoName+"[i]")
g.gen.P("}")
case d.IsMap():
g.gen.P("len_", d.Index(), " := int(binary.LittleEndian.Uint16(buf[n:]))")
g.gen.P("n += 2")
typ, _ := protoc.FieldGoType(g.gen, f)
g.gen.P("x.", f.GoName, " = make(", typ, ", len_", d.Index(), ")")
keyType, _ := protoc.FieldGoType(g.gen, f.Message.Fields[0])
valType, _ := protoc.FieldGoType(g.gen, f.Message.Fields[1])
g.gen.P("for i := 0; i < len_", d.Index(), "; i++ {")
g.gen.P("var k ", keyType)
g.gen.P("var v ", valType)
g.generateUnmarshalPrimitive(d.MapKey(), "k")
g.generateUnmarshalPrimitive(d.MapValue(), "v")
g.gen.P(" x.", f.GoName, "[k] = v")
g.gen.P("}")
case d.ContainingOneof() != nil:
g.gen.P("// TODO: field ", f.GoName)
default:
g.generateUnmarshalPrimitive(d, "x."+f.GoName)
}
}

func (g zeropbFeature) generateUnmarshalPrimitive(d protoreflect.FieldDescriptor, name string) {
switch d.Kind() {
case protoreflect.FloatKind:
g.gen.P(name, " = float32(", mathPackage.Ident("Float32frombits"), "(binary.LittleEndian.Uint32(buf[n:])))")
g.gen.P("n += 4")
case protoreflect.DoubleKind:
g.gen.P(name, " = float64(", mathPackage.Ident("Float64frombits"), "(binary.LittleEndian.Uint64(buf[n:])))")
g.gen.P("n += 8")
case protoreflect.Sfixed32Kind, protoreflect.Int32Kind, protoreflect.Sint32Kind:
g.gen.P(name, " = int32(binary.LittleEndian.Uint32(buf[n:]))")
g.gen.P("n += 4")
case protoreflect.Fixed32Kind, protoreflect.Uint32Kind:
g.gen.P(name, " = binary.LittleEndian.Uint32(buf[n:])")
g.gen.P("n += 4")
case protoreflect.Sfixed64Kind, protoreflect.Int64Kind, protoreflect.Sint64Kind:
g.gen.P(name, " = int64(binary.LittleEndian.Uint64(buf[n:]))")
g.gen.P("n += 8")
case protoreflect.Fixed64Kind, protoreflect.Uint64Kind:
g.gen.P(name, " = binary.LittleEndian.Uint64(buf[n:])")
g.gen.P("n += 8")
case protoreflect.EnumKind:
g.gen.P(name, " = ", d.Enum().FullName(), "(binary.LittleEndian.Uint32(buf[n:]))")
g.gen.P("n += 4")
case protoreflect.BoolKind:
g.gen.P("bool_", d.Index(), " := binary.LittleEndian.Uint32(buf[n:])")
g.gen.P(name, " = false")
g.gen.P("if bool_", d.Index(), " != 0 {")
g.gen.P(" ", name, " = true")
g.gen.P("}")
g.gen.P("n += 4")
case protoreflect.StringKind:
g.gen.P("len_", d.Index(), " := int(binary.LittleEndian.Uint16(buf[n:]))")
g.gen.P("n += 2")
g.gen.P(name, " = string(buf[n:n+len_", d.Index(), "])")
g.gen.P("n += len_", d.Index())
case protoreflect.BytesKind:
g.gen.P("len_", d.Index(), " := int(binary.LittleEndian.Uint16(buf[n:]))")
g.gen.P("n += 2")
g.gen.P(name, " = append([]byte{}, buf[n:n+len_", d.Index(), "]...)")
g.gen.P("n += len_", d.Index())
case protoreflect.MessageKind:
g.gen.P(name, " = new(", d.Message().FullName(), ")")
g.gen.P("n_", d.Index(), ", err := ", name, ".UnmarshalZeroPB(buf[n:])")
g.gen.P("n += n_", d.Index())
g.gen.P("if err != nil {")
g.gen.P(" return n, err")
g.gen.P("}")
default:
g.gen.P("// TODO: field ", name)
g.gen.P("_ = ", name)
}
}
Loading

0 comments on commit d05f41e

Please sign in to comment.