diff --git a/codegen/field.go b/codegen/field.go index 74c077a71c9..00f9e1ca730 100644 --- a/codegen/field.go +++ b/codegen/field.go @@ -417,7 +417,7 @@ func (f *Field) ImplDirectives() []*Directive { loc = ast.LocationInputFieldDefinition } for i := range f.Directives { - if !f.Directives[i].Builtin && f.Directives[i].IsLocation(loc) { + if !f.Directives[i].Builtin && f.Directives[i].IsLocation(loc, ast.LocationObject) { d = append(d, f.Directives[i]) } } diff --git a/codegen/testserver/directive.graphql b/codegen/testserver/directive.graphql index 50c0f804608..4005cfd8907 100644 --- a/codegen/testserver/directive.graphql +++ b/codegen/testserver/directive.graphql @@ -6,7 +6,8 @@ directive @toNull on ARGUMENT_DEFINITION | INPUT_FIELD_DEFINITION | FIELD_DEFINI directive @directive1 on FIELD_DEFINITION directive @directive2 on FIELD_DEFINITION directive @unimplemented on FIELD_DEFINITION -directive @order(location: String!) on FIELD_DEFINITION | OBJECT +directive @order1(location: String!) on FIELD_DEFINITION | OBJECT +directive @order2(location: String!) on OBJECT extend type Query { directiveArg(arg: String! @length(min:1, max: 255, message: "invalid length")): String @@ -14,7 +15,7 @@ extend type Query { directiveInputNullable(arg: InputDirectives): String directiveInput(arg: InputDirectives!): String directiveInputType(arg: InnerInput! @custom): String - directiveObject: ObjectDirectives @order(location: "Query_field") + directiveObject: ObjectDirectives @order1(location: "Query_field") directiveObjectWithCustomGoModel: ObjectDirectivesWithCustomGoModel directiveFieldDef(ret: String!): String! @length(min: 1, message: "not valid") directiveField: String @@ -41,7 +42,7 @@ input InnerDirectives { message: String! @length(min: 1, message: "not valid") } -type ObjectDirectives @order(location: "ObjectDirectives_object") { +type ObjectDirectives @order1(location: "ObjectDirectives_object_1") @order2(location: "ObjectDirectives_object_2") { text: String! @length(min: 0, max: 7, message: "not valid") nullableText: String @toNull order: [String!]! diff --git a/codegen/testserver/directive_test.go b/codegen/testserver/directive_test.go index c58ed5c2dd8..50342558fdf 100644 --- a/codegen/testserver/directive_test.go +++ b/codegen/testserver/directive_test.go @@ -160,7 +160,14 @@ func TestDirectives(t *testing.T) { Directive2: func(ctx context.Context, obj interface{}, next graphql.Resolver) (res interface{}, err error) { return next(ctx) }, - Order: func(ctx context.Context, obj interface{}, next graphql.Resolver, location string) (res interface{}, err error) { + Order1: func(ctx context.Context, obj interface{}, next graphql.Resolver, location string) (res interface{}, err error) { + order := []string{location} + res, err = next(ctx) + od := res.(*ObjectDirectives) + od.Order = append(order, od.Order...) + return od, err + }, + Order2: func(ctx context.Context, obj interface{}, next graphql.Resolver, location string) (res interface{}, err error) { order := []string{location} res, err = next(ctx) od := res.(*ObjectDirectives) @@ -378,7 +385,8 @@ func TestDirectives(t *testing.T) { require.Equal(t, "Ok", resp.DirectiveObject.Text) require.True(t, resp.DirectiveObject.NullableText == nil) require.Equal(t, "Query_field", resp.DirectiveObject.Order[0]) - require.Equal(t, "ObjectDirectives_object", resp.DirectiveObject.Order[1]) + require.Equal(t, "ObjectDirectives_object_2", resp.DirectiveObject.Order[1]) + require.Equal(t, "ObjectDirectives_object_1", resp.DirectiveObject.Order[2]) }) t.Run("when directive returns nil & custom go field is not nilable", func(t *testing.T) { var resp struct { diff --git a/codegen/testserver/generated.go b/codegen/testserver/generated.go index 6eab307c39c..154059b9ffd 100644 --- a/codegen/testserver/generated.go +++ b/codegen/testserver/generated.go @@ -62,7 +62,8 @@ type DirectiveRoot struct { Logged func(ctx context.Context, obj interface{}, next graphql.Resolver, id string) (res interface{}, err error) MakeNil func(ctx context.Context, obj interface{}, next graphql.Resolver) (res interface{}, err error) MakeTypedNil func(ctx context.Context, obj interface{}, next graphql.Resolver) (res interface{}, err error) - Order func(ctx context.Context, obj interface{}, next graphql.Resolver, location string) (res interface{}, err error) + Order1 func(ctx context.Context, obj interface{}, next graphql.Resolver, location string) (res interface{}, err error) + Order2 func(ctx context.Context, obj interface{}, next graphql.Resolver, location string) (res interface{}, err error) Range func(ctx context.Context, obj interface{}, next graphql.Resolver, min *int, max *int) (res interface{}, err error) ToNull func(ctx context.Context, obj interface{}, next graphql.Resolver) (res interface{}, err error) Unimplemented func(ctx context.Context, obj interface{}, next graphql.Resolver) (res interface{}, err error) @@ -1793,7 +1794,8 @@ directive @toNull on ARGUMENT_DEFINITION | INPUT_FIELD_DEFINITION | FIELD_DEFINI directive @directive1 on FIELD_DEFINITION directive @directive2 on FIELD_DEFINITION directive @unimplemented on FIELD_DEFINITION -directive @order(location: String!) on FIELD_DEFINITION | OBJECT +directive @order1(location: String!) on FIELD_DEFINITION | OBJECT +directive @order2(location: String!) on OBJECT extend type Query { directiveArg(arg: String! @length(min:1, max: 255, message: "invalid length")): String @@ -1801,7 +1803,7 @@ extend type Query { directiveInputNullable(arg: InputDirectives): String directiveInput(arg: InputDirectives!): String directiveInputType(arg: InnerInput! @custom): String - directiveObject: ObjectDirectives @order(location: "Query_field") + directiveObject: ObjectDirectives @order1(location: "Query_field") directiveObjectWithCustomGoModel: ObjectDirectivesWithCustomGoModel directiveFieldDef(ret: String!): String! @length(min: 1, message: "not valid") directiveField: String @@ -1828,7 +1830,7 @@ input InnerDirectives { message: String! @length(min: 1, message: "not valid") } -type ObjectDirectives @order(location: "ObjectDirectives_object") { +type ObjectDirectives @order1(location: "ObjectDirectives_object_1") @order2(location: "ObjectDirectives_object_2") { text: String! @length(min: 0, max: 7, message: "not valid") nullableText: String @toNull order: [String!]! @@ -2330,7 +2332,21 @@ func (ec *executionContext) dir_logged_args(ctx context.Context, rawArgs map[str return args, nil } -func (ec *executionContext) dir_order_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { +func (ec *executionContext) dir_order1_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { + var err error + args := map[string]interface{}{} + var arg0 string + if tmp, ok := rawArgs["location"]; ok { + arg0, err = ec.unmarshalNString2string(ctx, tmp) + if err != nil { + return nil, err + } + } + args["location"] = arg0 + return args, nil +} + +func (ec *executionContext) dir_order2_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { var err error args := map[string]interface{}{} var arg0 string @@ -6154,27 +6170,37 @@ func (ec *executionContext) _Query_directiveObject(ctx context.Context, field gr return ec.resolvers.Query().DirectiveObject(rctx) } directive1 := func(ctx context.Context) (interface{}, error) { - location, err := ec.unmarshalNString2string(ctx, "ObjectDirectives_object") + location, err := ec.unmarshalNString2string(ctx, "ObjectDirectives_object_1") if err != nil { return nil, err } - if ec.directives.Order == nil { - return nil, errors.New("directive order is not implemented") + if ec.directives.Order1 == nil { + return nil, errors.New("directive order1 is not implemented") } - return ec.directives.Order(ctx, nil, directive0, location) + return ec.directives.Order1(ctx, nil, directive0, location) } directive2 := func(ctx context.Context) (interface{}, error) { + location, err := ec.unmarshalNString2string(ctx, "ObjectDirectives_object_2") + if err != nil { + return nil, err + } + if ec.directives.Order2 == nil { + return nil, errors.New("directive order2 is not implemented") + } + return ec.directives.Order2(ctx, nil, directive1, location) + } + directive3 := func(ctx context.Context) (interface{}, error) { location, err := ec.unmarshalNString2string(ctx, "Query_field") if err != nil { return nil, err } - if ec.directives.Order == nil { - return nil, errors.New("directive order is not implemented") + if ec.directives.Order1 == nil { + return nil, errors.New("directive order1 is not implemented") } - return ec.directives.Order(ctx, nil, directive1, location) + return ec.directives.Order1(ctx, nil, directive2, location) } - tmp, err := directive2(rctx) + tmp, err := directive3(rctx) if err != nil { return nil, err } diff --git a/example/type-system-extension/generated.go b/example/type-system-extension/generated.go index 318ff3baec1..17e6d08a18d 100644 --- a/example/type-system-extension/generated.go +++ b/example/type-system-extension/generated.go @@ -396,8 +396,28 @@ func (ec *executionContext) _MyMutation_createTodo(ctx context.Context, field gr } fc.Args = args resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { - ctx = rctx // use context from middleware stack in children - return ec.resolvers.MyMutation().CreateTodo(rctx, args["todo"].(TodoInput)) + directive0 := func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return ec.resolvers.MyMutation().CreateTodo(rctx, args["todo"].(TodoInput)) + } + directive1 := func(ctx context.Context) (interface{}, error) { + if ec.directives.ObjectLogging == nil { + return nil, errors.New("directive objectLogging is not implemented") + } + return ec.directives.ObjectLogging(ctx, nil, directive0) + } + + tmp, err := directive1(rctx) + if err != nil { + return nil, err + } + if tmp == nil { + return nil, nil + } + if data, ok := tmp.(*Todo); ok { + return data, nil + } + return nil, fmt.Errorf(`unexpected type %T from directive, should be *github.com/99designs/gqlgen/example/type-system-extension.Todo`, tmp) }) if err != nil { ec.Error(ctx, err) @@ -430,8 +450,28 @@ func (ec *executionContext) _MyQuery_todos(ctx context.Context, field graphql.Co ctx = graphql.WithFieldContext(ctx, fc) resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { - ctx = rctx // use context from middleware stack in children - return ec.resolvers.MyQuery().Todos(rctx) + directive0 := func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return ec.resolvers.MyQuery().Todos(rctx) + } + directive1 := func(ctx context.Context) (interface{}, error) { + if ec.directives.ObjectLogging == nil { + return nil, errors.New("directive objectLogging is not implemented") + } + return ec.directives.ObjectLogging(ctx, nil, directive0) + } + + tmp, err := directive1(rctx) + if err != nil { + return nil, err + } + if tmp == nil { + return nil, nil + } + if data, ok := tmp.([]*Todo); ok { + return data, nil + } + return nil, fmt.Errorf(`unexpected type %T from directive, should be []*github.com/99designs/gqlgen/example/type-system-extension.Todo`, tmp) }) if err != nil { ec.Error(ctx, err) @@ -471,8 +511,28 @@ func (ec *executionContext) _MyQuery_todo(ctx context.Context, field graphql.Col } fc.Args = args resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { - ctx = rctx // use context from middleware stack in children - return ec.resolvers.MyQuery().Todo(rctx, args["id"].(string)) + directive0 := func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return ec.resolvers.MyQuery().Todo(rctx, args["id"].(string)) + } + directive1 := func(ctx context.Context) (interface{}, error) { + if ec.directives.ObjectLogging == nil { + return nil, errors.New("directive objectLogging is not implemented") + } + return ec.directives.ObjectLogging(ctx, nil, directive0) + } + + tmp, err := directive1(rctx) + if err != nil { + return nil, err + } + if tmp == nil { + return nil, nil + } + if data, ok := tmp.(*Todo); ok { + return data, nil + } + return nil, fmt.Errorf(`unexpected type %T from directive, should be *github.com/99designs/gqlgen/example/type-system-extension.Todo`, tmp) }) if err != nil { ec.Error(ctx, err)