Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: don't extract @goField twice #3234

Merged
merged 1 commit into from
Aug 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions codegen/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -320,24 +320,33 @@ func (c *Config) injectTypesFromSchema() error {
}
}

if schemaType.Kind == ast.Object || schemaType.Kind == ast.InputObject {
if schemaType.Kind == ast.Object ||
schemaType.Kind == ast.InputObject ||
schemaType.Kind == ast.Interface {
for _, field := range schemaType.Fields {
if fd := field.Directives.ForName("goField"); fd != nil {
forceResolver := c.Models[schemaType.Name].Fields[field.Name].Resolver
fieldName := c.Models[schemaType.Name].Fields[field.Name].FieldName

if ra := fd.Arguments.ForName("forceResolver"); ra != nil {
if fr, err := ra.Value.Value(nil); err == nil {
forceResolver = fr.(bool)
}
}

fieldName := c.Models[schemaType.Name].Fields[field.Name].FieldName
if na := fd.Arguments.ForName("name"); na != nil {
if fr, err := na.Value.Value(nil); err == nil {
fieldName = fr.(string)
}
}

omittable := c.Models[schemaType.Name].Fields[field.Name].Omittable
if arg := fd.Arguments.ForName("omittable"); arg != nil {
if k, err := arg.Value.Value(nil); err == nil {
val := k.(bool)
omittable = &val
}
}

if c.Models[schemaType.Name].Fields == nil {
c.Models[schemaType.Name] = TypeMapEntry{
Model: c.Models[schemaType.Name].Model,
Expand All @@ -349,6 +358,7 @@ func (c *Config) injectTypesFromSchema() error {
c.Models[schemaType.Name].Fields[field.Name] = TypeMapField{
FieldName: fieldName,
Resolver: forceResolver,
Omittable: omittable,
}
}
}
Expand Down Expand Up @@ -449,6 +459,7 @@ type TypeMapEntry struct {
type TypeMapField struct {
Resolver bool `yaml:"resolver"`
FieldName string `yaml:"fieldName"`
Omittable *bool `yaml:"omittable"`
GeneratedMethod string `yaml:"-"`
}

Expand Down
227 changes: 112 additions & 115 deletions plugin/modelgen/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,6 @@ type (

// DefaultFieldMutateHook is the default hook for the Plugin which applies the GoFieldHook and GoTagFieldHook.
func DefaultFieldMutateHook(td *ast.Definition, fd *ast.FieldDefinition, f *Field) (*Field, error) {
var err error
f, err = GoFieldHook(td, fd, f)
if err != nil {
return f, err
}
return GoTagFieldHook(td, fd, f)
}

Expand Down Expand Up @@ -337,117 +332,139 @@ func (m *Plugin) generateFields(cfg *config.Config, schemaType *ast.Definition)
binder := cfg.NewBinder()
fields := make([]*Field, 0)

var omittableType types.Type

for _, field := range schemaType.Fields {
var typ types.Type
fieldDef := cfg.Schema.Types[field.Type.Name()]

if cfg.Models.UserDefined(field.Type.Name()) {
var err error
typ, err = binder.FindTypeFromName(cfg.Models[field.Type.Name()].Model[0])
if err != nil {
return nil, err
}
} else {
switch fieldDef.Kind {
case ast.Scalar:
// no user defined model, referencing a default scalar
typ = types.NewNamed(
types.NewTypeName(0, cfg.Model.Pkg(), "string", nil),
nil,
nil,
)

case ast.Interface, ast.Union:
// no user defined model, referencing a generated interface type
typ = types.NewNamed(
types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil),
types.NewInterfaceType([]*types.Func{}, []types.Type{}),
nil,
)

case ast.Enum:
// no user defined model, must reference a generated enum
typ = types.NewNamed(
types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil),
nil,
nil,
)

case ast.Object, ast.InputObject:
// no user defined model, must reference a generated struct
typ = types.NewNamed(
types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil),
types.NewStruct(nil, nil),
nil,
)

default:
panic(fmt.Errorf("unknown ast type %s", fieldDef.Kind))
}
f, err := m.generateField(cfg, binder, schemaType, field)
if err != nil {
return nil, err
}

name := templates.ToGo(field.Name)
if nameOverride := cfg.Models[schemaType.Name].Fields[field.Name].FieldName; nameOverride != "" {
name = nameOverride
if f == nil {
continue
}

typ = binder.CopyModifiersFromAst(field.Type, typ)
fields = append(fields, f)
}

if cfg.StructFieldsAlwaysPointers {
if isStruct(typ) && (fieldDef.Kind == ast.Object || fieldDef.Kind == ast.InputObject) {
typ = types.NewPointer(typ)
}
fields = append(fields, getExtraFields(cfg, schemaType.Name)...)

return fields, nil
}

func (m *Plugin) generateField(
cfg *config.Config,
binder *config.Binder,
schemaType *ast.Definition,
field *ast.FieldDefinition,
) (*Field, error) {
var omittableType types.Type
var typ types.Type
fieldDef := cfg.Schema.Types[field.Type.Name()]

if cfg.Models.UserDefined(field.Type.Name()) {
var err error
typ, err = binder.FindTypeFromName(cfg.Models[field.Type.Name()].Model[0])
if err != nil {
return nil, err
}
} else {
switch fieldDef.Kind {
case ast.Scalar:
// no user defined model, referencing a default scalar
typ = types.NewNamed(
types.NewTypeName(0, cfg.Model.Pkg(), "string", nil),
nil,
nil,
)

case ast.Interface, ast.Union:
// no user defined model, referencing a generated interface type
typ = types.NewNamed(
types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil),
types.NewInterfaceType([]*types.Func{}, []types.Type{}),
nil,
)

case ast.Enum:
// no user defined model, must reference a generated enum
typ = types.NewNamed(
types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil),
nil,
nil,
)

f := &Field{
Name: field.Name,
GoName: name,
Type: typ,
Description: field.Description,
Tag: getStructTagFromField(cfg, field),
Omittable: cfg.NullableInputOmittable && schemaType.Kind == ast.InputObject && !field.Type.NonNull,
case ast.Object, ast.InputObject:
// no user defined model, must reference a generated struct
typ = types.NewNamed(
types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil),
types.NewStruct(nil, nil),
nil,
)

default:
panic(fmt.Errorf("unknown ast type %s", fieldDef.Kind))
}
}

if m.FieldHook != nil {
mf, err := m.FieldHook(schemaType, field, f)
if err != nil {
return nil, fmt.Errorf("generror: field %v.%v: %w", schemaType.Name, field.Name, err)
}
f = mf
name := templates.ToGo(field.Name)
if nameOverride := cfg.Models[schemaType.Name].Fields[field.Name].FieldName; nameOverride != "" {
name = nameOverride
}

typ = binder.CopyModifiersFromAst(field.Type, typ)

if cfg.StructFieldsAlwaysPointers {
if isStruct(typ) && (fieldDef.Kind == ast.Object || fieldDef.Kind == ast.InputObject) {
typ = types.NewPointer(typ)
}
}

if f.IsResolver && cfg.OmitResolverFields {
continue
f := &Field{
Name: field.Name,
GoName: name,
Type: typ,
Description: field.Description,
Tag: getStructTagFromField(cfg, field),
Omittable: cfg.NullableInputOmittable && schemaType.Kind == ast.InputObject && !field.Type.NonNull,
IsResolver: cfg.Models[schemaType.Name].Fields[field.Name].Resolver,
}

if omittable := cfg.Models[schemaType.Name].Fields[field.Name].Omittable; omittable != nil {
f.Omittable = *omittable
}

if m.FieldHook != nil {
mf, err := m.FieldHook(schemaType, field, f)
if err != nil {
return nil, fmt.Errorf("generror: field %v.%v: %w", schemaType.Name, field.Name, err)
}
f = mf
}

if f.Omittable {
if schemaType.Kind != ast.InputObject || field.Type.NonNull {
return nil, fmt.Errorf("generror: field %v.%v: omittable is only applicable to nullable input fields", schemaType.Name, field.Name)
}
if f.IsResolver && cfg.OmitResolverFields {
return nil, nil
}

var err error
if f.Omittable {
if schemaType.Kind != ast.InputObject || field.Type.NonNull {
return nil, fmt.Errorf("generror: field %v.%v: omittable is only applicable to nullable input fields", schemaType.Name, field.Name)
}

if omittableType == nil {
omittableType, err = binder.FindTypeFromName("github.com/99designs/gqlgen/graphql.Omittable")
if err != nil {
return nil, err
}
}
var err error

f.Type, err = binder.InstantiateType(omittableType, []types.Type{f.Type})
if omittableType == nil {
omittableType, err = binder.FindTypeFromName("github.com/99designs/gqlgen/graphql.Omittable")
if err != nil {
return nil, fmt.Errorf("generror: field %v.%v: %w", schemaType.Name, field.Name, err)
return nil, err
}
}

fields = append(fields, f)
f.Type, err = binder.InstantiateType(omittableType, []types.Type{f.Type})
if err != nil {
return nil, fmt.Errorf("generror: field %v.%v: %w", schemaType.Name, field.Name, err)
}
}

fields = append(fields, getExtraFields(cfg, schemaType.Name)...)

return fields, nil
return f, nil
}

func getExtraFields(cfg *config.Config, modelName string) []*Field {
Expand Down Expand Up @@ -636,29 +653,9 @@ func removeDuplicateTags(t string) string {
return returnTags
}

// GoFieldHook applies the goField directive to the generated Field f.
// GoFieldHook is a noop
// TODO: This will be removed in the next breaking release
func GoFieldHook(td *ast.Definition, fd *ast.FieldDefinition, f *Field) (*Field, error) {
args := make([]string, 0)
_ = args
for _, goField := range fd.Directives.ForNames("goField") {
if arg := goField.Arguments.ForName("name"); arg != nil {
if k, err := arg.Value.Value(nil); err == nil {
f.GoName = k.(string)
}
}

if arg := goField.Arguments.ForName("forceResolver"); arg != nil {
if k, err := arg.Value.Value(nil); err == nil {
f.IsResolver = k.(bool)
}
}

if arg := goField.Arguments.ForName("omittable"); arg != nil {
if k, err := arg.Value.Value(nil); err == nil {
f.Omittable = k.(bool)
}
}
}
return f, nil
}

Expand Down
Loading