From 5ebd157c798c7742656f802f9ecad5ada6c29183 Mon Sep 17 00:00:00 2001 From: Adam Scarr Date: Thu, 22 Feb 2018 20:57:37 +1100 Subject: [PATCH] Only use one gofunc per subscription --- codegen/templates/field.gotpl | 18 ++++++------ codegen/templates/file.gotpl | 50 ++++++++++++++++++---------------- codegen/templates/object.gotpl | 2 +- graphql/exec.go | 2 +- graphql/oneshot.go | 14 ++++++++++ graphql/response.go | 21 ++------------ handler/graphql.go | 20 ++++++++++---- handler/stub.go | 34 ++++++++--------------- handler/websocket.go | 13 ++++++--- 9 files changed, 88 insertions(+), 86 deletions(-) create mode 100644 graphql/oneshot.go diff --git a/codegen/templates/field.gotpl b/codegen/templates/field.gotpl index 0ca8f55e678..23747c24a89 100644 --- a/codegen/templates/field.gotpl +++ b/codegen/templates/field.gotpl @@ -2,22 +2,22 @@ {{ $object := $field.Object }} {{- if $object.Stream }} - func (ec *executionContext) _{{$object.GQLType}}_{{$field.GQLName}}(field graphql.CollectedField) <-chan graphql.Marshaler { - channel := make(chan graphql.Marshaler, 1) + func (ec *executionContext) _{{$object.GQLType}}_{{$field.GQLName}}(field graphql.CollectedField) func() graphql.Marshaler { {{- template "args.gotpl" $field.Args }} results, err := ec.resolvers.{{ $object.GQLType }}_{{ $field.GQLName }}({{ $field.CallArgs }}) if err != nil { ec.Error(err) return nil } - go func() { - for res := range results { - var out graphql.OrderedMap - out.Add(field.Alias, func() graphql.Marshaler { {{ $field.WriteJson }} }()) - channel <- &out + return func() graphql.Marshaler { + res, ok := <-results + if !ok { + return nil } - }() - return channel + var out graphql.OrderedMap + out.Add(field.Alias, func() graphql.Marshaler { {{ $field.WriteJson }} }()) + return &out + } } {{ else }} func (ec *executionContext) _{{$object.GQLType}}_{{$field.GQLName}}(field graphql.CollectedField, {{if not $object.Root}}obj *{{$object.FullName}}{{end}}) graphql.Marshaler { diff --git a/codegen/templates/file.gotpl b/codegen/templates/file.gotpl index 5241caa72c9..61d1f9b0509 100644 --- a/codegen/templates/file.gotpl +++ b/codegen/templates/file.gotpl @@ -37,9 +37,11 @@ func (e *executableSchema) Query(ctx context.Context, doc *query.Document, varia ec := executionContext{resolvers: e.resolvers, variables: variables, doc: doc, ctx: ctx} data := ec._{{.QueryRoot.GQLType}}(op.Selections) + var buf bytes.Buffer + data.MarshalGQL(&buf) return &graphql.Response{ - Data: data, + Data: buf.Bytes(), Errors: ec.Errors, } {{- else }} @@ -52,9 +54,11 @@ func (e *executableSchema) Mutation(ctx context.Context, doc *query.Document, va ec := executionContext{resolvers: e.resolvers, variables: variables, doc: doc, ctx: ctx} data := ec._{{.MutationRoot.GQLType}}(op.Selections) + var buf bytes.Buffer + data.MarshalGQL(&buf) return &graphql.Response{ - Data: data, + Data: buf.Bytes(), Errors: ec.Errors, } {{- else }} @@ -62,35 +66,33 @@ func (e *executableSchema) Mutation(ctx context.Context, doc *query.Document, va {{- end }} } -func (e *executableSchema) Subscription(ctx context.Context, doc *query.Document, variables map[string]interface{}, op *query.Operation) <-chan *graphql.Response { +func (e *executableSchema) Subscription(ctx context.Context, doc *query.Document, variables map[string]interface{}, op *query.Operation) func() *graphql.Response { {{- if .SubscriptionRoot }} - events := make(chan *graphql.Response, 10) - ec := executionContext{resolvers: e.resolvers, variables: variables, doc: doc, ctx: ctx} - eventData := ec._{{.SubscriptionRoot.GQLType}}(op.Selections) + next := ec._{{.SubscriptionRoot.GQLType}}(op.Selections) if ec.Errors != nil { - events<-&graphql.Response{ - Data: graphql.Null, - Errors: ec.Errors, + return graphql.OneShot(&graphql.Response{Data: []byte("null"), Errors: ec.Errors}) + } + + var buf bytes.Buffer + return func() *graphql.Response { + buf.Reset() + data := next() + if data == nil { + return nil + } + data.MarshalGQL(&buf) + + errs := ec.Errors + ec.Errors = nil + return &graphql.Response{ + Data: buf.Bytes(), + Errors: errs, } - close(events) - } else { - go func() { - for data := range eventData { - events <- &graphql.Response{ - Data: data, - Errors: ec.Errors, - } - time.Sleep(20 * time.Millisecond) - } - }() } - return events {{- else }} - events := make(chan *graphql.Response, 1) - events<-&graphql.Response{Errors: []*errors.QueryError{ {Message: "subscriptions are not supported"} }} - return events + return graphql.OneShot(&graphql.Response{Errors: []*errors.QueryError{ {Message: "subscriptions are not supported"} }}) {{- end }} } diff --git a/codegen/templates/object.gotpl b/codegen/templates/object.gotpl index 3bdbbf8a32c..28b56a1c684 100644 --- a/codegen/templates/object.gotpl +++ b/codegen/templates/object.gotpl @@ -4,7 +4,7 @@ var {{ $object.GQLType|lcFirst}}Implementors = {{$object.Implementors}} // nolint: gocyclo, errcheck, gas, goconst {{- if .Stream }} -func (ec *executionContext) _{{$object.GQLType}}(sel []query.Selection) <-chan graphql.Marshaler { +func (ec *executionContext) _{{$object.GQLType}}(sel []query.Selection) func() graphql.Marshaler { fields := graphql.CollectFields(ec.doc, sel, {{$object.GQLType|lcFirst}}Implementors, ec.variables) if len(fields) != 1 { diff --git a/graphql/exec.go b/graphql/exec.go index 451b465221d..93ddb80f09e 100644 --- a/graphql/exec.go +++ b/graphql/exec.go @@ -13,7 +13,7 @@ type ExecutableSchema interface { Query(ctx context.Context, document *query.Document, variables map[string]interface{}, op *query.Operation) *Response Mutation(ctx context.Context, document *query.Document, variables map[string]interface{}, op *query.Operation) *Response - Subscription(ctx context.Context, document *query.Document, variables map[string]interface{}, op *query.Operation) <-chan *Response + Subscription(ctx context.Context, document *query.Document, variables map[string]interface{}, op *query.Operation) func() *Response } func CollectFields(doc *query.Document, selSet []query.Selection, satisfies []string, variables map[string]interface{}) []CollectedField { diff --git a/graphql/oneshot.go b/graphql/oneshot.go new file mode 100644 index 00000000000..dd31f5baa79 --- /dev/null +++ b/graphql/oneshot.go @@ -0,0 +1,14 @@ +package graphql + +func OneShot(resp *Response) func() *Response { + var oneshot bool + + return func() *Response { + if oneshot { + return nil + } + oneshot = true + + return resp + } +} diff --git a/graphql/response.go b/graphql/response.go index 089fbaea4ee..c09d2f0ed7d 100644 --- a/graphql/response.go +++ b/graphql/response.go @@ -1,27 +1,12 @@ package graphql import ( - "io" + "encoding/json" "github.com/vektah/gqlgen/neelance/errors" ) type Response struct { - Data Marshaler - Errors []*errors.QueryError -} - -func (r *Response) MarshalGQL(w io.Writer) { - result := &OrderedMap{} - if r.Data == nil { - result.Add("data", Null) - } else { - result.Add("data", r.Data) - } - - if len(r.Errors) > 0 { - result.Add("errors", MarshalErrors(r.Errors)) - } - - result.MarshalGQL(w) + Data json.RawMessage `json:"data"` + Errors []*errors.QueryError `json:"errors,omitempty"` } diff --git a/handler/graphql.go b/handler/graphql.go index 67c6ed8ecdb..ddd71721a0b 100644 --- a/handler/graphql.go +++ b/handler/graphql.go @@ -90,9 +90,17 @@ func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc switch op.Type { case query.Query: - exec.Query(r.Context(), doc, params.Variables, op).MarshalGQL(w) + b, err := json.Marshal(exec.Query(r.Context(), doc, params.Variables, op)) + if err != nil { + panic(err) + } + w.Write(b) case query.Mutation: - exec.Mutation(r.Context(), doc, params.Variables, op).MarshalGQL(w) + b, err := json.Marshal(exec.Mutation(r.Context(), doc, params.Variables, op)) + if err != nil { + panic(err) + } + w.Write(b) default: sendErrorf(w, http.StatusBadRequest, "unsupported operation type") } @@ -101,11 +109,11 @@ func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc func sendError(w http.ResponseWriter, code int, errs ...*errors.QueryError) { w.WriteHeader(code) - - resp := &graphql.Response{ - Errors: errs, + b, err := json.Marshal(&graphql.Response{Errors: errs}) + if err != nil { + panic(err) } - resp.MarshalGQL(w) + w.Write(b) } func sendErrorf(w http.ResponseWriter, code int, format string, args ...interface{}) { diff --git a/handler/stub.go b/handler/stub.go index 0afae2d3525..e3d445e8fc6 100644 --- a/handler/stub.go +++ b/handler/stub.go @@ -24,10 +24,7 @@ func (e *executableSchemaStub) Schema() *schema.Schema { } func (e *executableSchemaStub) Query(ctx context.Context, document *query.Document, variables map[string]interface{}, op *query.Operation) *graphql.Response { - data := graphql.OrderedMap{} - data.Add("name", graphql.MarshalString("test")) - - return &graphql.Response{Data: &data} + return &graphql.Response{Data: []byte(`{"name":"test"}`)} } func (e *executableSchemaStub) Mutation(ctx context.Context, document *query.Document, variables map[string]interface{}, op *query.Operation) *graphql.Response { @@ -36,25 +33,16 @@ func (e *executableSchemaStub) Mutation(ctx context.Context, document *query.Doc } } -func (e *executableSchemaStub) Subscription(ctx context.Context, document *query.Document, variables map[string]interface{}, op *query.Operation) <-chan *graphql.Response { - events := make(chan *graphql.Response, 0) - - go func() { - for { - select { - case <-ctx.Done(): - close(events) - return - default: - data := graphql.OrderedMap{} - data.Add("name", graphql.MarshalString("test")) - - events <- &graphql.Response{ - Data: &data, - } +func (e *executableSchemaStub) Subscription(ctx context.Context, document *query.Document, variables map[string]interface{}, op *query.Operation) func() *graphql.Response { + return func() *graphql.Response { + time.Sleep(20 * time.Millisecond) + select { + case <-ctx.Done(): + return nil + default: + return &graphql.Response{ + Data: []byte(`{"name":"test"}`), } - time.Sleep(20 * time.Millisecond) } - }() - return events + } } diff --git a/handler/websocket.go b/handler/websocket.go index e43eeaf2b61..ce22e803ec2 100644 --- a/handler/websocket.go +++ b/handler/websocket.go @@ -171,7 +171,9 @@ func (c *wsConnection) subscribe(message *operationMessage) bool { c.active[message.ID] = cancel c.mu.Unlock() go func() { - for result := range c.exec.Subscription(ctx, doc, params.Variables, op) { + next := c.exec.Subscription(ctx, doc, params.Variables, op) + for result := next(); result != nil; result = next() { + fmt.Println(result) c.sendData(message.ID, result) } @@ -187,10 +189,13 @@ func (c *wsConnection) subscribe(message *operationMessage) bool { } func (c *wsConnection) sendData(id string, response *graphql.Response) { - var b bytes.Buffer - response.MarshalGQL(&b) + b, err := json.Marshal(response) + if err != nil { + c.sendError(id, errors.Errorf("unable to encode json response: %s", err.Error())) + return + } - c.write(&operationMessage{Type: dataMsg, ID: id, Payload: b.Bytes()}) + c.write(&operationMessage{Type: dataMsg, ID: id, Payload: b}) } func (c *wsConnection) sendError(id string, errors ...*errors.QueryError) {