Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

optional return pointers in unmarshalInput #2397

Merged
merged 2 commits into from
Sep 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions codegen/config/binder.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,15 +183,16 @@ func (b *Binder) PointerTo(ref *TypeReference) *TypeReference {

// TypeReference is used by args and field types. The Definition can refer to both input and output types.
type TypeReference struct {
Definition *ast.Definition
GQL *ast.Type
GO types.Type // Type of the field being bound. Could be a pointer or a value type of Target.
Target types.Type // The actual type that we know how to bind to. May require pointer juggling when traversing to fields.
CastType types.Type // Before calling marshalling functions cast from/to this base type
Marshaler *types.Func // When using external marshalling functions this will point to the Marshal function
Unmarshaler *types.Func // When using external marshalling functions this will point to the Unmarshal function
IsMarshaler bool // Does the type implement graphql.Marshaler and graphql.Unmarshaler
IsContext bool // Is the Marshaler/Unmarshaller the context version; applies to either the method or interface variety.
Definition *ast.Definition
GQL *ast.Type
GO types.Type // Type of the field being bound. Could be a pointer or a value type of Target.
Target types.Type // The actual type that we know how to bind to. May require pointer juggling when traversing to fields.
CastType types.Type // Before calling marshalling functions cast from/to this base type
Marshaler *types.Func // When using external marshalling functions this will point to the Marshal function
Unmarshaler *types.Func // When using external marshalling functions this will point to the Unmarshal function
IsMarshaler bool // Does the type implement graphql.Marshaler and graphql.Unmarshaler
IsContext bool // Is the Marshaler/Unmarshaller the context version; applies to either the method or interface variety.
PointersInUmarshalInput bool // Inverse values and pointers in return.
}

func (ref *TypeReference) Elem() *TypeReference {
Expand Down Expand Up @@ -412,6 +413,8 @@ func (b *Binder) TypeReference(schemaType *ast.Type, bindTarget types.Type) (ret
ref.GO = bindTarget
}

ref.PointersInUmarshalInput = b.cfg.ReturnPointersInUmarshalInput

return ref, nil
}

Expand Down
2 changes: 2 additions & 0 deletions codegen/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ type Config struct {
OmitSliceElementPointers bool `yaml:"omit_slice_element_pointers,omitempty"`
OmitGetters bool `yaml:"omit_getters,omitempty"`
StructFieldsAlwaysPointers bool `yaml:"struct_fields_always_pointers,omitempty"`
ReturnPointersInUmarshalInput bool `yaml:"return_pointers_in_unmarshalinput,omitempty"`
ResolversAlwaysReturnPointers bool `yaml:"resolvers_always_return_pointers,omitempty"`
SkipValidation bool `yaml:"skip_validation,omitempty"`
SkipModTidy bool `yaml:"skip_mod_tidy,omitempty"`
Expand All @@ -50,6 +51,7 @@ func DefaultConfig() *Config {
Directives: map[string]DirectiveConfig{},
Models: TypeMap{},
StructFieldsAlwaysPointers: true,
ReturnPointersInUmarshalInput: false,
ResolversAlwaysReturnPointers: true,
}
}
Expand Down
20 changes: 12 additions & 8 deletions codegen/input.gotpl
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
{{- range $input := .Inputs }}
{{- if not .HasUnmarshal }}
func (ec *executionContext) unmarshalInput{{ .Name }}(ctx context.Context, obj interface{}) ({{.Type | ref}}, error) {
{{- $it := "it" }}
{{- if .PointersInUmarshalInput }}
{{- $it = "&it" }}
{{- end }}
func (ec *executionContext) unmarshalInput{{ .Name }}(ctx context.Context, obj interface{}) ({{ if .PointersInUmarshalInput }}*{{ end }}{{.Type | ref}}, error) {
var it {{.Type | ref}}
asMap := map[string]interface{}{}
for k, v := range obj.(map[string]interface{}) {
Expand Down Expand Up @@ -31,12 +35,12 @@
{{ template "implDirectives" $field }}
tmp, err := directive{{$field.ImplDirectives|len}}(ctx)
if err != nil {
return it, graphql.ErrorOnPath(ctx, err)
return {{$it}}, graphql.ErrorOnPath(ctx, err)
}
if data, ok := tmp.({{ $field.TypeReference.GO | ref }}) ; ok {
{{- if $field.IsResolver }}
if err = ec.resolvers.{{ $field.ShortInvocation }}; err != nil {
return it, err
return {{$it}}, err
}
{{- else }}
it.{{$field.GoFieldName}} = data
Expand All @@ -49,29 +53,29 @@
{{- end }}
} else {
err := fmt.Errorf(`unexpected type %T from directive, should be {{ $field.TypeReference.GO }}`, tmp)
return it, graphql.ErrorOnPath(ctx, err)
return {{$it}}, graphql.ErrorOnPath(ctx, err)
}
{{- else }}
{{- if $field.IsResolver }}
data, err := ec.{{ $field.TypeReference.UnmarshalFunc }}(ctx, v)
if err != nil {
return it, err
return {{$it}}, err
}
if err = ec.resolvers.{{ $field.ShortInvocation }}; err != nil {
return it, err
return {{$it}}, err
}
{{- else }}
it.{{$field.GoFieldName}}, err = ec.{{ $field.TypeReference.UnmarshalFunc }}(ctx, v)
if err != nil {
return it, err
return {{$it}}, err
}
{{- end }}
{{- end }}
{{- end }}
}
}

return it, nil
return {{$it}}, nil
}
{{- end }}
{{ end }}
28 changes: 15 additions & 13 deletions codegen/object.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,15 @@ const (
type Object struct {
*ast.Definition

Type types.Type
ResolverInterface types.Type
Root bool
Fields []*Field
Implements []*ast.Definition
DisableConcurrency bool
Stream bool
Directives []*Directive
Type types.Type
ResolverInterface types.Type
Root bool
Fields []*Field
Implements []*ast.Definition
DisableConcurrency bool
Stream bool
Directives []*Directive
PointersInUmarshalInput bool
}

func (b *builder) buildObject(typ *ast.Definition) (*Object, error) {
Expand All @@ -42,11 +43,12 @@ func (b *builder) buildObject(typ *ast.Definition) (*Object, error) {
}
caser := cases.Title(language.English, cases.NoLower)
obj := &Object{
Definition: typ,
Root: b.Schema.Query == typ || b.Schema.Mutation == typ || b.Schema.Subscription == typ,
DisableConcurrency: typ == b.Schema.Mutation,
Stream: typ == b.Schema.Subscription,
Directives: dirs,
Definition: typ,
Root: b.Schema.Query == typ || b.Schema.Mutation == typ || b.Schema.Subscription == typ,
DisableConcurrency: typ == b.Schema.Mutation,
Stream: typ == b.Schema.Subscription,
Directives: dirs,
PointersInUmarshalInput: b.Config.ReturnPointersInUmarshalInput,
ResolverInterface: types.NewNamed(
types.NewTypeName(0, b.Config.Exec.Pkg(), caser.String(typ.Name)+"Resolver", nil),
nil,
Expand Down
6 changes: 4 additions & 2 deletions codegen/type.gotpl
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,11 @@
return res, graphql.ErrorOnPath(ctx, err)
{{- else }}
res, err := ec.unmarshalInput{{ $type.GQL.Name }}(ctx, v)
{{- if $type.IsNilable }}
{{- if and $type.IsNilable (not $type.PointersInUmarshalInput) }}
return &res, graphql.ErrorOnPath(ctx, err)
{{- else}}
{{- else if and (not $type.IsNilable) $type.PointersInUmarshalInput }}
return *res, graphql.ErrorOnPath(ctx, err)
{{- else }}
return res, graphql.ErrorOnPath(ctx, err)
{{- end }}
{{- end }}
Expand Down
3 changes: 3 additions & 0 deletions docs/content/config.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ resolver:
# Optional: turn off to make resolvers return values instead of pointers for structs
# resolvers_always_return_pointers: true

# Optional: turn on to return pointers instead of values in unmarshalInput
# return_pointers_in_unmarshalinput: false

# Optional: set to speed up generation time by not performing a final validation pass.
# skip_validation: true

Expand Down