Skip to content

Commit

Permalink
make codegen more flexible
Browse files Browse the repository at this point in the history
  • Loading branch information
Pipello committed Oct 8, 2024
1 parent 60c9d18 commit fed350f
Show file tree
Hide file tree
Showing 8 changed files with 181 additions and 64 deletions.
22 changes: 22 additions & 0 deletions definition/convert.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package definition

import (
"database/sql"

"google.golang.org/protobuf/types/known/timestamppb"
)

func ProtoTimeToSql(protoTime *timestamppb.Timestamp) sql.NullTime {
sqlTime := sql.NullTime{}
if protoTime != nil {
sqlTime.Scan(protoTime.AsTime())
}
return sqlTime
}

func SQLTimeToProto(sqlTime sql.NullTime) *timestamppb.Timestamp {
if sqlTime.Valid {
timestamppb.New(sqlTime.Time)
}
return nil
}
26 changes: 13 additions & 13 deletions definition/model.go.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,26 @@
package models

import (
pb "iot-device-register/api"
"database/sql"
"gorm.io/gorm"
"google.golang.org/protobuf/types/known/timestamppb"

"gorm.io/gorm"
pb "{{.RepositoryName}}/api"
"github.com/Pipello/codegen/definition"
)
{{ range $i, $field := .Fields}}
{{- if gt (len $field.Choices) 0 }}
var {{ $field.LowerCaseName }}Choices = []{{$field.Type}}{{"{"}}
var {{ $field.LowerCaseName }}Choices = []{{$field.GetGORMType}}{{"{"}}
{{- range $j, $choice := .Choices }}
{{ $choice.GetValue $field.Type }},
{{ $choice.GetValue $field.GetGORMType }},
{{- end }}
{{"}"}}
{{ end }}
{{- end }}
type {{.Name}} struct {
gorm.Model
{{- range .Fields}}
{{.Name}} {{if .Repeated}}[]{{end}}{{if .Optional}}*{{end}}{{.Type}} `json:"{{.ToSnakeCase}}"{{if .GormTag}} gorm:"{{.GormTag}}"{{end}}`
{{.Name}} {{if .Repeated}}[]{{end}}{{if .Optional}}*{{end}}{{.GetGORMType}} `json:"{{.ToSnakeCase}}"{{if .GormTag}} gorm:"{{.GormTag}}"{{end}}`
{{- end}}
}

Expand All @@ -40,21 +43,18 @@ func (m *{{.Name}}) ToProto() *pb.{{.Name}} {
}
{{- range .Fields -}}
{{- if and .Relationship .Repeated }}
{{ .LowerCaseName }} := []*pb.{{.Type}}{}
{{ .LowerCaseName }} := []*pb.{{.GetGORMType}}{}
for _, item := range m.{{ .Name }} {
{{ .LowerCaseName }} = append({{ .LowerCaseName }}, item.ToProto())
}
{{- end }}
{{- end }}
return &pb.{{.Name}}{
Id: uint64(m.ID),
CreatedAt: m.CreatedAt.String(),
UpdatedAt: m.UpdatedAt.String(),
{{- range .Fields }}
{{ if not .Relationship -}}{{ .GoCamelCaseName }}: m.{{ .Name }},
{{- else if not .Repeated -}}{{ .GoCamelCaseName }}: m.{{ .Name }}.ToProto(),
{{- else }}{{ .GoCamelCaseName }}: {{ .LowerCaseName }},
{{- end -}}
CreatedAt: timestamppb.New(m.CreatedAt),
UpdatedAt: timestamppb.New(m.UpdatedAt),
{{- range .Fields }}
{{ .GoCamelCaseName }}: {{ .ValueToProto }},
{{- end }}
}
}
Expand Down
149 changes: 116 additions & 33 deletions definition/model_definition.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,19 @@ func (c Choice) GetValue(t string) string {
return fmt.Sprint(c.Value)
}

type FieldType int

const (
IntegerType FieldType = iota
UnsignedIntegerType
StringType
Timestamp
Relationship
)

type Field struct {
Name string
Type string
Type FieldType
Optional bool
Repeated bool
Relationship bool
Expand All @@ -39,14 +49,57 @@ type Field struct {
Choices []Choice
}

func (f *Field) GetZeroValue() string {
if f.Type == "string" {
func (f *Field) GetProtoType() string {
switch f.Type {
case IntegerType:
return "int64"
case UnsignedIntegerType:
return "uint64"
case StringType:
return "string"
case Timestamp:
return "google.protobuf.Timestamp"
}
return ""
}

func (f *Field) GetGORMType() string {
switch f.Type {
case IntegerType:
return "int64"
case UnsignedIntegerType:
return "uint64"
case StringType:
return "string"
case Timestamp:
return "sql.NullTime"
}
return ""
}

func (f *Field) ValueToProto() string {
if f.Type == Timestamp {
return "definition.SQLTimeToProto(m." + f.Name + ")"
}

if f.Relationship {
if f.Repeated {
return f.LowerCaseName()
}
return "m." + f.Name + ".ToProto()"
}
return "m." + f.Name
}

func (f *Field) ValueToSQL(model string) string {
if f.Relationship {
return ""
}
if f.Type == "uint64" {
return "0"
accessor := strings.Join([]string{"req", model, f.GoCamelCaseName()}, ".")
if f.Type == Timestamp {
return "definition.ProtoTimeToSql(" + accessor + ")"
}
return "nil"
return accessor
}

func (f *Field) GoCamelCaseName() string {
Expand Down Expand Up @@ -81,43 +134,44 @@ const (
Delete
)

type Model struct {
Name string
Table string
Methods allowedMethods
Fields []*Field
type Schema struct {
Name string
Table string
Methods allowedMethods
Fields []*Field
CustomValidation string
RepositoryName string
}

func (m *Model) HasGet() bool {
func (m *Schema) HasGet() bool {
return m.Methods&Get > 0
}

func (m *Model) HasList() bool {
func (m *Schema) HasList() bool {
return m.Methods&List > 0
}

func (m *Model) HasCreate() bool {
func (m *Schema) HasCreate() bool {
return m.Methods&Create > 0
}

func (m *Model) HasUpdate() bool {
func (m *Schema) HasUpdate() bool {
return m.Methods&Update > 0
}

func (m *Model) HasDelete() bool {
func (m *Schema) HasDelete() bool {
return m.Methods&Delete > 0
}

func (m *Model) AutoFillProtoIndex() {
func (m *Schema) AutoFillProtoIndex() {
for i, f := range m.Fields {
if f.ProtoIndex == 0 {
f.ProtoIndex = i + 4
}
}
}

func (m *Model) GenerateDBModel() error {
func (m *Schema) GenerateDBModel() error {
t := template.Must(template.ParseFiles(getFilePath("model.go.tpl")))
fileName := "./internal/models/" + strings.ToLower(m.Name) + ".go"
m.readModelCustomBlock(fileName)
Expand All @@ -133,7 +187,7 @@ func (m *Model) GenerateDBModel() error {
return nil
}

func (m *Model) readModelCustomBlock(path string) {
func (m *Schema) readModelCustomBlock(path string) {
f, err := os.ReadFile(path)
if err != nil {
return
Expand All @@ -145,7 +199,7 @@ func (m *Model) readModelCustomBlock(path string) {
}
}

func (m *Model) GenerateService() error {
func (m *Schema) GenerateService() error {
t := template.Must(template.ParseFiles(getFilePath("service.go.tpl")))
fileName := "./internal/services/" + strings.ToLower(m.Name) + ".go"
outFile, err := os.Create(fileName)
Expand All @@ -156,23 +210,30 @@ func (m *Model) GenerateService() error {
return t.Execute(outFile, m)
}

func (m *Model) LowercaseName() string {
func (m *Schema) LowercaseName() string {
return strings.ToLower(m.Name)
}

func (m *Model) SnakeCaseName() string {
func (m *Schema) SnakeCaseName() string {
return strcase.ToSnake(m.Name)
}

type ModelGenerator struct {
Models []*Model
type CompleteGenerator struct {
Schemas []*Schema
Config Config
CustomMethods string
CustomMessages string
}

func (m *ModelGenerator) readCustomBlocks() {
path := "./api/iot_collector_service.proto"
f, err := os.ReadFile(path)
func (m *CompleteGenerator) getProtoServiceFilePath() string {
if m.Config.ProtoBufFileName != "" {
return "./api/" + m.Config.ProtoBufFileName
}
return fmt.Sprintf("./api/%s.proto", strcase.ToSnake(m.Config.GRPCServiceName))
}

func (m *CompleteGenerator) readCustomBlocks() {
f, err := os.ReadFile(m.getProtoServiceFilePath())
if err != nil {
return
}
Expand All @@ -187,17 +248,17 @@ func (m *ModelGenerator) readCustomBlocks() {
}
}

func (g *ModelGenerator) GenerateServiceProto() error {
func (g *CompleteGenerator) GenerateServiceProto() error {
t := template.Must(template.ParseFiles(getFilePath("service.proto.tpl")))
outFile, err := os.Create("./api/iot_collector_service.proto")
outFile, err := os.Create(g.getProtoServiceFilePath())
if err != nil {
return err
}
defer outFile.Close()
return t.Execute(outFile, g)
}

func (g *ModelGenerator) GenerateServer() error {
func (g *CompleteGenerator) GenerateServer() error {
t := template.Must(template.ParseFiles(getFilePath("server.go.tpl")))
outFile, err := os.Create("./internal/server/server_generated.go")
if err != nil {
Expand All @@ -207,15 +268,16 @@ func (g *ModelGenerator) GenerateServer() error {
return t.Execute(outFile, g)
}

func (g *ModelGenerator) GenerateFiles() error {
func (g *CompleteGenerator) GenerateFiles() error {
g.readCustomBlocks()
if err := g.GenerateServiceProto(); err != nil {
return err
}
if err := g.GenerateServer(); err != nil {
return err
}
for _, m := range g.Models {
for _, m := range g.Schemas {
m.RepositoryName = g.Config.RepositoryName
err := m.GenerateDBModel()
if err != nil {
return err
Expand All @@ -228,8 +290,29 @@ func (g *ModelGenerator) GenerateFiles() error {
return nil
}

func (g *CompleteGenerator) Generate() error {
return g.GenerateFiles()
}

func getFilePath(name string) string {
_, dir, _, _ := runtime.Caller(0)
dirName := filepath.Dir(dir)
return filepath.Join(dirName, name)
}
}

type Config struct {
ProtoBufFileName string
GRPCServiceName string
RepositoryName string
}

type Generator interface {
Generate() error
}

func NewGenerator(schemas []*Schema, config Config) Generator {
return &CompleteGenerator{
Schemas: schemas,
Config: config,
}
}
4 changes: 2 additions & 2 deletions definition/server.go.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ package server

import (
"context"
pb "iot-device-register/api"
pb "{{.Config.RepositoryName}}/api"
)


{{ range .Models }}
{{ range .Schemas }}
// {{.Name}} methods

{{ if .HasGet -}}
Expand Down
11 changes: 6 additions & 5 deletions definition/service.go.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@ import (
"fmt"
"context"
"errors"
pb "iot-device-register/api"
"iot-device-register/internal/models"

"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"gorm.io/gorm"

pb "{{.RepositoryName}}/api"
"{{.RepositoryName}}/internal/models"
"github.com/Pipello/codegen/definition"
)
{{ $modelName := .Name }}
type {{$modelName}}Service struct {
Expand Down Expand Up @@ -79,7 +80,7 @@ func (s *{{$modelName}}Service) List(ctx context.Context, req *pb.List{{$modelNa
func (s *{{$modelName}}Service) Create(ctx context.Context, req *pb.Create{{$modelName}}Request) (*models.{{$modelName}}, error) {
item := models.{{$modelName}}{
{{- range $index, $field := .Fields}}
{{if not $field.Relationship}}{{$field.Name}}: req.{{$modelName}}.{{$field.GoCamelCaseName}},{{end}}
{{if not $field.Relationship}}{{$field.Name}}: {{ $field.ValueToSQL $modelName }},{{end}}
{{- end }}
}
if err := item.Validate(s.db); err != nil {
Expand Down Expand Up @@ -109,7 +110,7 @@ func (s *{{$modelName}}Service) Update(ctx context.Context, req *pb.Update{{$mod
{{- range $index, $field := .Fields }}
{{- if not $field.Relationship }}
case "{{ $field.ToSnakeCase }}":
item.{{ $field.Name }} = req.{{$modelName}}.{{ $field.GoCamelCaseName }}
item.{{ $field.Name }} = {{ $field.ValueToSQL $modelName }}
{{- end }}
{{- end }}
}
Expand Down
Loading

0 comments on commit fed350f

Please sign in to comment.