Skip to content

Commit

Permalink
Only use one gofunc per subscription
Browse files Browse the repository at this point in the history
  • Loading branch information
vektah committed Feb 22, 2018
1 parent 79a7037 commit 5ebd157
Show file tree
Hide file tree
Showing 9 changed files with 88 additions and 86 deletions.
18 changes: 9 additions & 9 deletions codegen/templates/field.gotpl
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,22 @@
{{ $object := $field.Object }}

{{- if $object.Stream }}
func (ec *executionContext) _{{$object.GQLType}}_{{$field.GQLName}}(field graphql.CollectedField) <-chan graphql.Marshaler {
channel := make(chan graphql.Marshaler, 1)
func (ec *executionContext) _{{$object.GQLType}}_{{$field.GQLName}}(field graphql.CollectedField) func() graphql.Marshaler {
{{- template "args.gotpl" $field.Args }}
results, err := ec.resolvers.{{ $object.GQLType }}_{{ $field.GQLName }}({{ $field.CallArgs }})
if err != nil {
ec.Error(err)
return nil
}
go func() {
for res := range results {
var out graphql.OrderedMap
out.Add(field.Alias, func() graphql.Marshaler { {{ $field.WriteJson }} }())
channel <- &out
return func() graphql.Marshaler {
res, ok := <-results
if !ok {
return nil
}
}()
return channel
var out graphql.OrderedMap
out.Add(field.Alias, func() graphql.Marshaler { {{ $field.WriteJson }} }())
return &out
}
}
{{ else }}
func (ec *executionContext) _{{$object.GQLType}}_{{$field.GQLName}}(field graphql.CollectedField, {{if not $object.Root}}obj *{{$object.FullName}}{{end}}) graphql.Marshaler {
Expand Down
50 changes: 26 additions & 24 deletions codegen/templates/file.gotpl
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,11 @@ func (e *executableSchema) Query(ctx context.Context, doc *query.Document, varia
ec := executionContext{resolvers: e.resolvers, variables: variables, doc: doc, ctx: ctx}

data := ec._{{.QueryRoot.GQLType}}(op.Selections)
var buf bytes.Buffer
data.MarshalGQL(&buf)

return &graphql.Response{
Data: data,
Data: buf.Bytes(),
Errors: ec.Errors,
}
{{- else }}
Expand All @@ -52,45 +54,45 @@ func (e *executableSchema) Mutation(ctx context.Context, doc *query.Document, va
ec := executionContext{resolvers: e.resolvers, variables: variables, doc: doc, ctx: ctx}

data := ec._{{.MutationRoot.GQLType}}(op.Selections)
var buf bytes.Buffer
data.MarshalGQL(&buf)

return &graphql.Response{
Data: data,
Data: buf.Bytes(),
Errors: ec.Errors,
}
{{- else }}
return &graphql.Response{Errors: []*errors.QueryError{ {Message: "mutations are not supported"} }}
{{- end }}
}

func (e *executableSchema) Subscription(ctx context.Context, doc *query.Document, variables map[string]interface{}, op *query.Operation) <-chan *graphql.Response {
func (e *executableSchema) Subscription(ctx context.Context, doc *query.Document, variables map[string]interface{}, op *query.Operation) func() *graphql.Response {
{{- if .SubscriptionRoot }}
events := make(chan *graphql.Response, 10)

ec := executionContext{resolvers: e.resolvers, variables: variables, doc: doc, ctx: ctx}

eventData := ec._{{.SubscriptionRoot.GQLType}}(op.Selections)
next := ec._{{.SubscriptionRoot.GQLType}}(op.Selections)
if ec.Errors != nil {
events<-&graphql.Response{
Data: graphql.Null,
Errors: ec.Errors,
return graphql.OneShot(&graphql.Response{Data: []byte("null"), Errors: ec.Errors})
}

var buf bytes.Buffer
return func() *graphql.Response {
buf.Reset()
data := next()
if data == nil {
return nil
}
data.MarshalGQL(&buf)

errs := ec.Errors
ec.Errors = nil
return &graphql.Response{
Data: buf.Bytes(),
Errors: errs,
}
close(events)
} else {
go func() {
for data := range eventData {
events <- &graphql.Response{
Data: data,
Errors: ec.Errors,
}
time.Sleep(20 * time.Millisecond)
}
}()
}
return events
{{- else }}
events := make(chan *graphql.Response, 1)
events<-&graphql.Response{Errors: []*errors.QueryError{ {Message: "subscriptions are not supported"} }}
return events
return graphql.OneShot(&graphql.Response{Errors: []*errors.QueryError{ {Message: "subscriptions are not supported"} }})
{{- end }}
}

Expand Down
2 changes: 1 addition & 1 deletion codegen/templates/object.gotpl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ var {{ $object.GQLType|lcFirst}}Implementors = {{$object.Implementors}}

// nolint: gocyclo, errcheck, gas, goconst
{{- if .Stream }}
func (ec *executionContext) _{{$object.GQLType}}(sel []query.Selection) <-chan graphql.Marshaler {
func (ec *executionContext) _{{$object.GQLType}}(sel []query.Selection) func() graphql.Marshaler {
fields := graphql.CollectFields(ec.doc, sel, {{$object.GQLType|lcFirst}}Implementors, ec.variables)

if len(fields) != 1 {
Expand Down
2 changes: 1 addition & 1 deletion graphql/exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ type ExecutableSchema interface {

Query(ctx context.Context, document *query.Document, variables map[string]interface{}, op *query.Operation) *Response
Mutation(ctx context.Context, document *query.Document, variables map[string]interface{}, op *query.Operation) *Response
Subscription(ctx context.Context, document *query.Document, variables map[string]interface{}, op *query.Operation) <-chan *Response
Subscription(ctx context.Context, document *query.Document, variables map[string]interface{}, op *query.Operation) func() *Response
}

func CollectFields(doc *query.Document, selSet []query.Selection, satisfies []string, variables map[string]interface{}) []CollectedField {
Expand Down
14 changes: 14 additions & 0 deletions graphql/oneshot.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package graphql

func OneShot(resp *Response) func() *Response {
var oneshot bool

return func() *Response {
if oneshot {
return nil
}
oneshot = true

return resp
}
}
21 changes: 3 additions & 18 deletions graphql/response.go
Original file line number Diff line number Diff line change
@@ -1,27 +1,12 @@
package graphql

import (
"io"
"encoding/json"

"github.com/vektah/gqlgen/neelance/errors"
)

type Response struct {
Data Marshaler
Errors []*errors.QueryError
}

func (r *Response) MarshalGQL(w io.Writer) {
result := &OrderedMap{}
if r.Data == nil {
result.Add("data", Null)
} else {
result.Add("data", r.Data)
}

if len(r.Errors) > 0 {
result.Add("errors", MarshalErrors(r.Errors))
}

result.MarshalGQL(w)
Data json.RawMessage `json:"data"`
Errors []*errors.QueryError `json:"errors,omitempty"`
}
20 changes: 14 additions & 6 deletions handler/graphql.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,17 @@ func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc

switch op.Type {
case query.Query:
exec.Query(r.Context(), doc, params.Variables, op).MarshalGQL(w)
b, err := json.Marshal(exec.Query(r.Context(), doc, params.Variables, op))
if err != nil {
panic(err)
}
w.Write(b)
case query.Mutation:
exec.Mutation(r.Context(), doc, params.Variables, op).MarshalGQL(w)
b, err := json.Marshal(exec.Mutation(r.Context(), doc, params.Variables, op))
if err != nil {
panic(err)
}
w.Write(b)
default:
sendErrorf(w, http.StatusBadRequest, "unsupported operation type")
}
Expand All @@ -101,11 +109,11 @@ func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc

func sendError(w http.ResponseWriter, code int, errs ...*errors.QueryError) {
w.WriteHeader(code)

resp := &graphql.Response{
Errors: errs,
b, err := json.Marshal(&graphql.Response{Errors: errs})
if err != nil {
panic(err)
}
resp.MarshalGQL(w)
w.Write(b)
}

func sendErrorf(w http.ResponseWriter, code int, format string, args ...interface{}) {
Expand Down
34 changes: 11 additions & 23 deletions handler/stub.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,7 @@ func (e *executableSchemaStub) Schema() *schema.Schema {
}

func (e *executableSchemaStub) Query(ctx context.Context, document *query.Document, variables map[string]interface{}, op *query.Operation) *graphql.Response {
data := graphql.OrderedMap{}
data.Add("name", graphql.MarshalString("test"))

return &graphql.Response{Data: &data}
return &graphql.Response{Data: []byte(`{"name":"test"}`)}
}

func (e *executableSchemaStub) Mutation(ctx context.Context, document *query.Document, variables map[string]interface{}, op *query.Operation) *graphql.Response {
Expand All @@ -36,25 +33,16 @@ func (e *executableSchemaStub) Mutation(ctx context.Context, document *query.Doc
}
}

func (e *executableSchemaStub) Subscription(ctx context.Context, document *query.Document, variables map[string]interface{}, op *query.Operation) <-chan *graphql.Response {
events := make(chan *graphql.Response, 0)

go func() {
for {
select {
case <-ctx.Done():
close(events)
return
default:
data := graphql.OrderedMap{}
data.Add("name", graphql.MarshalString("test"))

events <- &graphql.Response{
Data: &data,
}
func (e *executableSchemaStub) Subscription(ctx context.Context, document *query.Document, variables map[string]interface{}, op *query.Operation) func() *graphql.Response {
return func() *graphql.Response {
time.Sleep(20 * time.Millisecond)
select {
case <-ctx.Done():
return nil
default:
return &graphql.Response{
Data: []byte(`{"name":"test"}`),
}
time.Sleep(20 * time.Millisecond)
}
}()
return events
}
}
13 changes: 9 additions & 4 deletions handler/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,9 @@ func (c *wsConnection) subscribe(message *operationMessage) bool {
c.active[message.ID] = cancel
c.mu.Unlock()
go func() {
for result := range c.exec.Subscription(ctx, doc, params.Variables, op) {
next := c.exec.Subscription(ctx, doc, params.Variables, op)
for result := next(); result != nil; result = next() {
fmt.Println(result)
c.sendData(message.ID, result)
}

Expand All @@ -187,10 +189,13 @@ func (c *wsConnection) subscribe(message *operationMessage) bool {
}

func (c *wsConnection) sendData(id string, response *graphql.Response) {
var b bytes.Buffer
response.MarshalGQL(&b)
b, err := json.Marshal(response)
if err != nil {
c.sendError(id, errors.Errorf("unable to encode json response: %s", err.Error()))
return
}

c.write(&operationMessage{Type: dataMsg, ID: id, Payload: b.Bytes()})
c.write(&operationMessage{Type: dataMsg, ID: id, Payload: b})
}

func (c *wsConnection) sendError(id string, errors ...*errors.QueryError) {
Expand Down

0 comments on commit 5ebd157

Please sign in to comment.