Skip to content

Commit

Permalink
Allow structs to be self-verifying.
Browse files Browse the repository at this point in the history
  • Loading branch information
nolag committed Oct 31, 2024
1 parent 2faccef commit a8c88c5
Show file tree
Hide file tree
Showing 31 changed files with 784 additions and 56 deletions.
8 changes: 8 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ var (
tags []string
structNameFromTitle bool
minSizedInts bool
structVerify bool

errFlagFormat = errors.New("flag must be in the format URI=PACKAGE")

Expand Down Expand Up @@ -77,6 +78,7 @@ var (
Tags: tags,
OnlyModels: onlyModels,
MinSizedInts: minSizedInts,
StructVerify: structVerify,
}
for _, id := range allKeys(schemaPackageMap, schemaOutputMap, schemaRootTypeMap) {
mapping := generator.SchemaMapping{SchemaID: id}
Expand Down Expand Up @@ -174,6 +176,12 @@ also look for foo.json if --resolve-extension json is provided.`)
false,
"Uses sized int and uint values based on the min and max values for the field")

rootCmd.PersistentFlags().BoolVar(
&structVerify,
"struct-verify",
false,
"Add a Verify method to the generated struct that validates the struct against the schema")

abortWithErr(rootCmd.Execute())
}

Expand Down
1 change: 1 addition & 0 deletions pkg/generator/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ type Config struct {
OnlyModels bool
MinSizedInts bool
Loader schemas.Loader
StructVerify bool
}

type SchemaMapping struct {
Expand Down
4 changes: 4 additions & 0 deletions pkg/generator/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ func New(config Config) (*Generator, error) {
formatters = append(formatters, &yamlFormatter{})
}

if config.StructVerify {
formatters = append(formatters, &verifyFormatter{})
}

generator := &Generator{
caser: text.NewCaser(config.Capitalizations, config.ResolveExtensions),
config: config,
Expand Down
5 changes: 1 addition & 4 deletions pkg/generator/json_formatter.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,7 @@ func (jf *jsonFormatter) enumUnmarshal(
enumType.Generate(out)
out.Newline()

varName := "v"
if wrapInStruct {
varName += ".Value"
}
varName := enumVarName(wrapInStruct)

out.Printlnf("if err := json.Unmarshal(b, &%s); err != nil { return err }", varName)
out.Printlnf("var ok bool")
Expand Down
10 changes: 10 additions & 0 deletions pkg/generator/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,13 @@ func isNamedType(t codegen.Type) bool {

return false
}

func enumVarName(wrapInStruct bool) string {
varName := "v"

if wrapInStruct {
varName += ".Value"
}

return varName
}
195 changes: 195 additions & 0 deletions pkg/generator/verify_formatter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
package generator

import (
"strings"

"github.com/atombender/go-jsonschema/pkg/codegen"
)

type verifyFormatter struct{}

func (v verifyFormatter) addImport(_ *codegen.File) {}

func (v verifyFormatter) generate(declType codegen.TypeDecl, validators []validator) func(*codegen.Emitter) {
return func(out *codegen.Emitter) {
var prefix string
switch declType.Type.(type) {
// No need to dereference the struct just to verify it.
case *codegen.StructType:
prefix = "*"

default:
prefix = ""
}

out.Comment("Verify checks all fields on the struct match the schema.")
out.Printlnf("func (%s %s%s) Verify() error {", varNamePlainStruct, prefix, declType.Name)
out.Indent(1)

for _, va := range validators {
desc := va.desc()
if desc.beforeJSONUnmarshal || desc.requiresRawAfter || !desc.hasError {
continue
}

va.generate(out)
}

if stct, ok := declType.Type.(*codegen.StructType); ok {
for _, field := range stct.Fields {
name := strings.ToLower(field.Name[0:1]) + field.Name[1:]
if verifyEmit := v.verifyType(field.Type, name); verifyEmit != nil {
out.Printlnf("%s := %s", name, getPlainName(field.Name))
verifyEmit(out)
}
}
}

out.Printlnf("return nil")
out.Indent(-1)
out.Printlnf("}")
}
}

func (v verifyFormatter) verifyType(tpe codegen.Type, access string) func(*codegen.Emitter) {
// For some types, pointers are sometimes used and sometime not.
switch utpe := tpe.(type) {
case *codegen.ArrayType:
return v.verifyArray(*utpe, access)

case codegen.ArrayType:
return v.verifyArray(utpe, access)

case codegen.CustomNameType, *codegen.CustomNameType, codegen.NamedType, *codegen.NamedType, *codegen.StructType:
return func(out *codegen.Emitter) {
out.Printlnf("if err := %s.Verify(); err != nil {", access)
out.Indent(1)
out.Printlnf("return err")
out.Indent(-1)
out.Printlnf("}")
}

case *codegen.MapType:
return v.verifyMap(*utpe, access)

case codegen.MapType:
return v.verifyMap(utpe, access)

case *codegen.PointerType:
return v.verifyPointer(*utpe, access)

case codegen.PointerType:
return v.verifyPointer(utpe, access)

default:
return nil
}
}

func (v verifyFormatter) enumMarshal(_ codegen.TypeDecl) func(*codegen.Emitter) {
return func(out *codegen.Emitter) {}
}

func (v verifyFormatter) enumUnmarshal(
declType codegen.TypeDecl,
_ codegen.Type,
valueConstant *codegen.Var,
wrapInStruct bool,
) func(*codegen.Emitter) {
return func(out *codegen.Emitter) {
varName := enumVarName(wrapInStruct)

out.Comment("Verify checks all fields on the struct match the schema.")
out.Printlnf("func (%s %s) Verify() error {", varNamePlainStruct, declType.Name)
out.Indent(1)
out.Printlnf("for _, expected := range %s {", valueConstant.Name)
out.Indent(1)
out.Printlnf("if reflect.DeepEqual(%s, expected) { return nil }", varName)
out.Indent(-1)
out.Printlnf("}")
out.Printlnf(`return fmt.Errorf("invalid value (expected one of %%#v): %%#v", %s, %s)`,
valueConstant.Name, varName)
out.Indent(-1)
out.Printlnf("}")
}
}

func (v verifyFormatter) verifyArray(tpe codegen.ArrayType, access string) func(*codegen.Emitter) {
aaccess := "a" + access

verifyFn := v.verifyType(tpe.Type, aaccess)
if verifyFn == nil {
return nil
}

return func(out *codegen.Emitter) {
out.Printlnf("for _, %s := range %s {", aaccess, access)
out.Indent(1)
verifyFn(out)
out.Indent(-1)
out.Printlnf("}")
}
}

func (v verifyFormatter) verifyMap(tpe codegen.MapType, access string) func(*codegen.Emitter) {
keyAccess := "k" + access
valueAccess := "v" + access
verifyKeyFn := v.verifyType(tpe.KeyType, keyAccess)
verifyValueFn := v.verifyType(tpe.ValueType, valueAccess)

if verifyKeyFn == nil && verifyValueFn == nil {
return nil
}

if verifyKeyFn == nil {
keyAccess = "_"
}

if verifyValueFn == nil {
valueAccess = "_"
}

return func(out *codegen.Emitter) {
out.Printlnf("for %s, %s := range %s {", keyAccess, valueAccess, access)
out.Indent(1)

if verifyKeyFn != nil {
verifyKeyFn(out)
}

if verifyValueFn != nil {
verifyValueFn(out)
}

out.Indent(-1)
out.Printlnf("}")
}
}

func (v verifyFormatter) verifyPointer(tpe codegen.PointerType, access string) func(*codegen.Emitter) {
var prefix string
switch tpe.Type.(type) {
// Access the verify and fields without copying it.
case codegen.CustomNameType, *codegen.CustomNameType, codegen.NamedType, *codegen.NamedType:
prefix = ""

default:
prefix = "*"
}

paccess := "p" + access

verifyFn := v.verifyType(tpe.Type, paccess)
if verifyFn == nil {
return nil
}

return func(out *codegen.Emitter) {
out.Printlnf("if %s != nil {", access)
out.Printlnf("%s := %s%s", paccess, prefix, access)
out.Indent(1)
verifyFn(out)
out.Indent(-1)
out.Printlnf("}")
}
}
5 changes: 1 addition & 4 deletions pkg/generator/yaml_formatter.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,7 @@ func (yf *yamlFormatter) enumUnmarshal(
enumType.Generate(out)
out.Newline()

varName := "v"
if wrapInStruct {
varName += ".Value"
}
varName := enumVarName(wrapInStruct)

out.Printlnf("if err := value.Decode(&%s); err != nil { return err }", varName)
out.Printlnf("var ok bool")
Expand Down
Loading

0 comments on commit a8c88c5

Please sign in to comment.