diff --git a/codegen/config/config.go b/codegen/config/config.go index 54776083d6d..eca0f5a3d98 100644 --- a/codegen/config/config.go +++ b/codegen/config/config.go @@ -320,24 +320,33 @@ func (c *Config) injectTypesFromSchema() error { } } - if schemaType.Kind == ast.Object || schemaType.Kind == ast.InputObject { + if schemaType.Kind == ast.Object || + schemaType.Kind == ast.InputObject || + schemaType.Kind == ast.Interface { for _, field := range schemaType.Fields { if fd := field.Directives.ForName("goField"); fd != nil { forceResolver := c.Models[schemaType.Name].Fields[field.Name].Resolver - fieldName := c.Models[schemaType.Name].Fields[field.Name].FieldName - if ra := fd.Arguments.ForName("forceResolver"); ra != nil { if fr, err := ra.Value.Value(nil); err == nil { forceResolver = fr.(bool) } } + fieldName := c.Models[schemaType.Name].Fields[field.Name].FieldName if na := fd.Arguments.ForName("name"); na != nil { if fr, err := na.Value.Value(nil); err == nil { fieldName = fr.(string) } } + omittable := c.Models[schemaType.Name].Fields[field.Name].Omittable + if arg := fd.Arguments.ForName("omittable"); arg != nil { + if k, err := arg.Value.Value(nil); err == nil { + val := k.(bool) + omittable = &val + } + } + if c.Models[schemaType.Name].Fields == nil { c.Models[schemaType.Name] = TypeMapEntry{ Model: c.Models[schemaType.Name].Model, @@ -349,6 +358,7 @@ func (c *Config) injectTypesFromSchema() error { c.Models[schemaType.Name].Fields[field.Name] = TypeMapField{ FieldName: fieldName, Resolver: forceResolver, + Omittable: omittable, } } } @@ -449,6 +459,7 @@ type TypeMapEntry struct { type TypeMapField struct { Resolver bool `yaml:"resolver"` FieldName string `yaml:"fieldName"` + Omittable *bool `yaml:"omittable"` GeneratedMethod string `yaml:"-"` } diff --git a/plugin/modelgen/models.go b/plugin/modelgen/models.go index 0ab8902c2d0..660e3537672 100644 --- a/plugin/modelgen/models.go +++ b/plugin/modelgen/models.go @@ -26,11 +26,6 @@ type ( // DefaultFieldMutateHook is the default hook for the Plugin which applies the GoFieldHook and GoTagFieldHook. func DefaultFieldMutateHook(td *ast.Definition, fd *ast.FieldDefinition, f *Field) (*Field, error) { - var err error - f, err = GoFieldHook(td, fd, f) - if err != nil { - return f, err - } return GoTagFieldHook(td, fd, f) } @@ -337,117 +332,139 @@ func (m *Plugin) generateFields(cfg *config.Config, schemaType *ast.Definition) binder := cfg.NewBinder() fields := make([]*Field, 0) - var omittableType types.Type - for _, field := range schemaType.Fields { - var typ types.Type - fieldDef := cfg.Schema.Types[field.Type.Name()] - - if cfg.Models.UserDefined(field.Type.Name()) { - var err error - typ, err = binder.FindTypeFromName(cfg.Models[field.Type.Name()].Model[0]) - if err != nil { - return nil, err - } - } else { - switch fieldDef.Kind { - case ast.Scalar: - // no user defined model, referencing a default scalar - typ = types.NewNamed( - types.NewTypeName(0, cfg.Model.Pkg(), "string", nil), - nil, - nil, - ) - - case ast.Interface, ast.Union: - // no user defined model, referencing a generated interface type - typ = types.NewNamed( - types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil), - types.NewInterfaceType([]*types.Func{}, []types.Type{}), - nil, - ) - - case ast.Enum: - // no user defined model, must reference a generated enum - typ = types.NewNamed( - types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil), - nil, - nil, - ) - - case ast.Object, ast.InputObject: - // no user defined model, must reference a generated struct - typ = types.NewNamed( - types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil), - types.NewStruct(nil, nil), - nil, - ) - - default: - panic(fmt.Errorf("unknown ast type %s", fieldDef.Kind)) - } + f, err := m.generateField(cfg, binder, schemaType, field) + if err != nil { + return nil, err } - name := templates.ToGo(field.Name) - if nameOverride := cfg.Models[schemaType.Name].Fields[field.Name].FieldName; nameOverride != "" { - name = nameOverride + if f == nil { + continue } - typ = binder.CopyModifiersFromAst(field.Type, typ) + fields = append(fields, f) + } - if cfg.StructFieldsAlwaysPointers { - if isStruct(typ) && (fieldDef.Kind == ast.Object || fieldDef.Kind == ast.InputObject) { - typ = types.NewPointer(typ) - } + fields = append(fields, getExtraFields(cfg, schemaType.Name)...) + + return fields, nil +} + +func (m *Plugin) generateField( + cfg *config.Config, + binder *config.Binder, + schemaType *ast.Definition, + field *ast.FieldDefinition, +) (*Field, error) { + var omittableType types.Type + var typ types.Type + fieldDef := cfg.Schema.Types[field.Type.Name()] + + if cfg.Models.UserDefined(field.Type.Name()) { + var err error + typ, err = binder.FindTypeFromName(cfg.Models[field.Type.Name()].Model[0]) + if err != nil { + return nil, err } + } else { + switch fieldDef.Kind { + case ast.Scalar: + // no user defined model, referencing a default scalar + typ = types.NewNamed( + types.NewTypeName(0, cfg.Model.Pkg(), "string", nil), + nil, + nil, + ) + + case ast.Interface, ast.Union: + // no user defined model, referencing a generated interface type + typ = types.NewNamed( + types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil), + types.NewInterfaceType([]*types.Func{}, []types.Type{}), + nil, + ) + + case ast.Enum: + // no user defined model, must reference a generated enum + typ = types.NewNamed( + types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil), + nil, + nil, + ) - f := &Field{ - Name: field.Name, - GoName: name, - Type: typ, - Description: field.Description, - Tag: getStructTagFromField(cfg, field), - Omittable: cfg.NullableInputOmittable && schemaType.Kind == ast.InputObject && !field.Type.NonNull, + case ast.Object, ast.InputObject: + // no user defined model, must reference a generated struct + typ = types.NewNamed( + types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil), + types.NewStruct(nil, nil), + nil, + ) + + default: + panic(fmt.Errorf("unknown ast type %s", fieldDef.Kind)) } + } - if m.FieldHook != nil { - mf, err := m.FieldHook(schemaType, field, f) - if err != nil { - return nil, fmt.Errorf("generror: field %v.%v: %w", schemaType.Name, field.Name, err) - } - f = mf + name := templates.ToGo(field.Name) + if nameOverride := cfg.Models[schemaType.Name].Fields[field.Name].FieldName; nameOverride != "" { + name = nameOverride + } + + typ = binder.CopyModifiersFromAst(field.Type, typ) + + if cfg.StructFieldsAlwaysPointers { + if isStruct(typ) && (fieldDef.Kind == ast.Object || fieldDef.Kind == ast.InputObject) { + typ = types.NewPointer(typ) } + } - if f.IsResolver && cfg.OmitResolverFields { - continue + f := &Field{ + Name: field.Name, + GoName: name, + Type: typ, + Description: field.Description, + Tag: getStructTagFromField(cfg, field), + Omittable: cfg.NullableInputOmittable && schemaType.Kind == ast.InputObject && !field.Type.NonNull, + IsResolver: cfg.Models[schemaType.Name].Fields[field.Name].Resolver, + } + + if omittable := cfg.Models[schemaType.Name].Fields[field.Name].Omittable; omittable != nil { + f.Omittable = *omittable + } + + if m.FieldHook != nil { + mf, err := m.FieldHook(schemaType, field, f) + if err != nil { + return nil, fmt.Errorf("generror: field %v.%v: %w", schemaType.Name, field.Name, err) } + f = mf + } - if f.Omittable { - if schemaType.Kind != ast.InputObject || field.Type.NonNull { - return nil, fmt.Errorf("generror: field %v.%v: omittable is only applicable to nullable input fields", schemaType.Name, field.Name) - } + if f.IsResolver && cfg.OmitResolverFields { + return nil, nil + } - var err error + if f.Omittable { + if schemaType.Kind != ast.InputObject || field.Type.NonNull { + return nil, fmt.Errorf("generror: field %v.%v: omittable is only applicable to nullable input fields", schemaType.Name, field.Name) + } - if omittableType == nil { - omittableType, err = binder.FindTypeFromName("github.com/99designs/gqlgen/graphql.Omittable") - if err != nil { - return nil, err - } - } + var err error - f.Type, err = binder.InstantiateType(omittableType, []types.Type{f.Type}) + if omittableType == nil { + omittableType, err = binder.FindTypeFromName("github.com/99designs/gqlgen/graphql.Omittable") if err != nil { - return nil, fmt.Errorf("generror: field %v.%v: %w", schemaType.Name, field.Name, err) + return nil, err } } - fields = append(fields, f) + f.Type, err = binder.InstantiateType(omittableType, []types.Type{f.Type}) + if err != nil { + return nil, fmt.Errorf("generror: field %v.%v: %w", schemaType.Name, field.Name, err) + } } - fields = append(fields, getExtraFields(cfg, schemaType.Name)...) - - return fields, nil + return f, nil } func getExtraFields(cfg *config.Config, modelName string) []*Field { @@ -636,29 +653,9 @@ func removeDuplicateTags(t string) string { return returnTags } -// GoFieldHook applies the goField directive to the generated Field f. +// GoFieldHook is a noop +// TODO: This will be removed in the next breaking release func GoFieldHook(td *ast.Definition, fd *ast.FieldDefinition, f *Field) (*Field, error) { - args := make([]string, 0) - _ = args - for _, goField := range fd.Directives.ForNames("goField") { - if arg := goField.Arguments.ForName("name"); arg != nil { - if k, err := arg.Value.Value(nil); err == nil { - f.GoName = k.(string) - } - } - - if arg := goField.Arguments.ForName("forceResolver"); arg != nil { - if k, err := arg.Value.Value(nil); err == nil { - f.IsResolver = k.(bool) - } - } - - if arg := goField.Arguments.ForName("omittable"); arg != nil { - if k, err := arg.Value.Value(nil); err == nil { - f.Omittable = k.(bool) - } - } - } return f, nil }