Skip to content

Commit

Permalink
multipart/mixed transport support for deferred queries (#3341)
Browse files Browse the repository at this point in the history
  • Loading branch information
giulio-opal authored Oct 22, 2024
1 parent acd9b6a commit abc7c61
Show file tree
Hide file tree
Showing 3 changed files with 359 additions and 3 deletions.
34 changes: 31 additions & 3 deletions graphql/handler/testserver/testserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"strings"
"time"

"github.com/vektah/gqlparser/v2"
Expand Down Expand Up @@ -43,6 +44,32 @@ func New() *TestServer {
switch rc.Operation.Operation {
case ast.Query:
ran := false
// If the query contains @defer, we will mimick a deferred response.
if strings.Contains(rc.RawQuery, "@defer") {
initialResponse := true
return func(context context.Context) *graphql.Response {
select {
case <-ctx.Done():
return nil
case <-next:
if initialResponse {
initialResponse = false
hasNext := true
return &graphql.Response{
Data: []byte(`{"name":null}`),
HasNext: &hasNext,
}
}
hasNext := false
return &graphql.Response{
Data: []byte(`{"name":"test"}`),
HasNext: &hasNext,
}
case <-completeSubscription:
return nil
}
}
}
return func(ctx context.Context) *graphql.Response {
if ran {
return nil
Expand All @@ -59,9 +86,10 @@ func New() *TestServer {
},
},
})
res, err := graphql.GetOperationContext(ctx).ResolverMiddleware(ctx, func(ctx context.Context) (any, error) {
return &graphql.Response{Data: []byte(`{"name":"test"}`)}, nil
})
res, err := graphql.GetOperationContext(ctx).
ResolverMiddleware(ctx, func(ctx context.Context) (any, error) {
return &graphql.Response{Data: []byte(`{"name":"test"}`)}, nil
})
if err != nil {
panic(err)
}
Expand Down
160 changes: 160 additions & 0 deletions graphql/handler/transport/http_multipart_mixed.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
package transport

import (
"encoding/json"
"fmt"
"io"
"log"
"mime"
"net/http"
"strings"

"github.com/vektah/gqlparser/v2/gqlerror"

"github.com/99designs/gqlgen/graphql"
)

// MultipartMixed is a transport that supports the multipart/mixed spec
type MultipartMixed struct {
Boundary string
}

var _ graphql.Transport = MultipartMixed{}

// Supports checks if the request supports the multipart/mixed spec
// Might be worth check the spec required, but Apollo Client mislabel the spec in the headers.
func (t MultipartMixed) Supports(r *http.Request) bool {
if !strings.Contains(r.Header.Get("Accept"), "multipart/mixed") {
return false
}
mediaType, _, err := mime.ParseMediaType(r.Header.Get("Content-Type"))
if err != nil {
return false
}
return r.Method == http.MethodPost && mediaType == "application/json"
}

// Do implements the multipart/mixed spec as a multipart/mixed response
func (t MultipartMixed) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecutor) {
// Implements the multipart/mixed spec as a multipart/mixed response:
// * https://github.com/graphql/graphql-wg/blob/e4ef5f9d5997815d9de6681655c152b6b7838b4c/rfcs/DeferStream.md
// 2022/08/23 as implemented by gqlgen.
// * https://github.com/graphql/graphql-wg/blob/f22ea7748c6ebdf88fdbf770a8d9e41984ebd429/rfcs/DeferStream.md June 2023 Spec for the
// `incremental` field
// Follows the format that is used in the Apollo Client tests:
// https://github.com/apollographql/apollo-client/blob/v3.11.8/src/link/http/__tests__/responseIterator.ts#L68
// Apollo Client, despite mentioning in its requests that they require the 2022 spec, it wants the
// `incremental` field to be an array of responses, not a single response. Theoretically we could
// batch responses in the `incremental` field, if we wanted to optimize this code.
ctx := r.Context()
flusher, ok := w.(http.Flusher)
if !ok {
SendErrorf(w, http.StatusInternalServerError, "streaming unsupported")
return
}
defer flusher.Flush()

w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
// This header will be replaced below, but it's required in case we return errors.
w.Header().Set("Content-Type", "application/json")

boundary := t.Boundary
if boundary == "" {
boundary = "graphql"
}

params := &graphql.RawParams{}
start := graphql.Now()
params.Headers = r.Header
params.ReadTime = graphql.TraceTiming{
Start: start,
End: graphql.Now(),
}

bodyString, err := getRequestBody(r)
if err != nil {
gqlErr := gqlerror.Errorf("could not get json request body: %+v", err)
resp := exec.DispatchError(ctx, gqlerror.List{gqlErr})
log.Printf("could not get json request body: %+v", err.Error())
writeJson(w, resp)
return
}

bodyReader := io.NopCloser(strings.NewReader(bodyString))
if err = jsonDecode(bodyReader, &params); err != nil {
w.WriteHeader(http.StatusBadRequest)
gqlErr := gqlerror.Errorf(
"json request body could not be decoded: %+v body:%s",
err,
bodyString,
)
resp := exec.DispatchError(ctx, gqlerror.List{gqlErr})
log.Printf("decoding error: %+v body:%s", err.Error(), bodyString)
writeJson(w, resp)
return
}

rc, opErr := exec.CreateOperationContext(ctx, params)
ctx = graphql.WithOperationContext(ctx, rc)

// Example of the response format (note the new lines are important!):
// --graphql
// Content-Type: application/json
//
// {"data":{"apps":{"apps":[ .. ],"totalNumApps":161,"__typename":"AppsOutput"}},"hasNext":true}
//
// --graphql
// Content-Type: application/json
//
// {"incremental":[{"data":{"groupAccessCount":0},"label":"test","path":["apps","apps",7],"hasNext":true}],"hasNext":true}

if opErr != nil {
w.WriteHeader(statusFor(opErr))

resp := exec.DispatchError(ctx, opErr)
writeJson(w, resp)
return
}

w.Header().Set(
"Content-Type",
fmt.Sprintf(`multipart/mixed;boundary="%s";deferSpec=20220824`, boundary),
)

responses, ctx := exec.DispatchOperation(ctx, rc)
initialResponse := true
for {
response := responses(ctx)
if response == nil {
break
}

fmt.Fprintf(w, "--%s\r\n", boundary)
fmt.Fprintf(w, "Content-Type: application/json\r\n\r\n")

if initialResponse {
writeJson(w, response)
initialResponse = false
} else {
writeIncrementalJson(w, response, response.HasNext)
}
fmt.Fprintf(w, "\r\n\r\n")
flusher.Flush()
}
}

func writeIncrementalJson(w io.Writer, response *graphql.Response, hasNext *bool) {
// TODO: Remove this wrapper on response once gqlgen supports the 2023 spec
b, err := json.Marshal(struct {
Incremental []graphql.Response `json:"incremental"`
HasNext *bool `json:"hasNext"`
}{
Incremental: []graphql.Response{*response},
HasNext: hasNext,
})
if err != nil {
panic(err)
}
w.Write(b)
}
168 changes: 168 additions & 0 deletions graphql/handler/transport/http_multipart_mixed_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
package transport_test

import (
"bufio"
"io"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/99designs/gqlgen/graphql/handler/testserver"
"github.com/99designs/gqlgen/graphql/handler/transport"
)

func TestMultipartMixed(t *testing.T) {
initialize := func() *testserver.TestServer {
h := testserver.New()
h.AddTransport(transport.MultipartMixed{})
return h
}

initializeWithServer := func() (*testserver.TestServer, *httptest.Server) {
h := initialize()
return h, httptest.NewServer(h)
}

createHTTPRequest := func(url string, query string) *http.Request {
req, err := http.NewRequest("POST", url, strings.NewReader(query))
require.NoError(t, err, "Request threw error -> %s", err)
req.Header.Set("Accept", "multipart/mixed")
req.Header.Set("content-type", "application/json; charset=utf-8")
return req
}

doRequest := func(handler http.Handler, target, body string) *httptest.ResponseRecorder {
r := createHTTPRequest(target, body)
w := httptest.NewRecorder()

handler.ServeHTTP(w, r)
return w
}

t.Run("decode failure", func(t *testing.T) {
handler, srv := initializeWithServer()
resp := doRequest(handler, srv.URL, "notjson")
assert.Equal(t, http.StatusBadRequest, resp.Code, resp.Body.String())
assert.Equal(t, "application/json", resp.Header().Get("Content-Type"))
assert.Equal(
t,
`{"errors":[{"message":"json request body could not be decoded: invalid character 'o' in literal null (expecting 'u') body:notjson"}],"data":null}`,
resp.Body.String(),
)
})

t.Run("parse failure", func(t *testing.T) {
handler, srv := initializeWithServer()
resp := doRequest(handler, srv.URL, `{"query": "!"}`)
assert.Equal(t, http.StatusUnprocessableEntity, resp.Code, resp.Body.String())
assert.Equal(t, "application/json", resp.Header().Get("Content-Type"))
assert.Equal(
t,
`{"errors":[{"message":"Unexpected !","locations":[{"line":1,"column":1}],"extensions":{"code":"GRAPHQL_PARSE_FAILED"}}],"data":null}`,
resp.Body.String(),
)
})

t.Run("validation failure", func(t *testing.T) {
handler, srv := initializeWithServer()
resp := doRequest(handler, srv.URL, `{"query": "{ title }"}`)
assert.Equal(t, http.StatusUnprocessableEntity, resp.Code, resp.Body.String())
assert.Equal(t, "application/json", resp.Header().Get("Content-Type"))
assert.Equal(
t,
`{"errors":[{"message":"Cannot query field \"title\" on type \"Query\".","locations":[{"line":1,"column":3}],"extensions":{"code":"GRAPHQL_VALIDATION_FAILED"}}],"data":null}`,
resp.Body.String(),
)
})

t.Run("invalid variable", func(t *testing.T) {
handler, srv := initializeWithServer()
resp := doRequest(handler, srv.URL,
`{"query": "query($id:Int!){find(id:$id)}","variables":{"id":false}}`,
)
assert.Equal(t, http.StatusUnprocessableEntity, resp.Code, resp.Body.String())
assert.Equal(t, "application/json", resp.Header().Get("Content-Type"))
assert.Equal(
t,
`{"errors":[{"message":"cannot use bool as Int","path":["variable","id"],"extensions":{"code":"GRAPHQL_VALIDATION_FAILED"}}],"data":null}`,
resp.Body.String(),
)
})

readLine := func(br *bufio.Reader) string {
bs, err := br.ReadString('\n')
require.NoError(t, err)
return bs
}

t.Run("initial and incremental patches", func(t *testing.T) {
handler, srv := initializeWithServer()
defer srv.Close()

var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
handler.SendNextSubscriptionMessage()
}()

client := &http.Client{}
req := createHTTPRequest(
srv.URL,
`{"query":"query { ... @defer { name } }"}`,
)
res, err := client.Do(req)
require.NoError(t, err, "Request threw error -> %s", err)
defer func() {
require.NoError(t, res.Body.Close())
}()

assert.Equal(t, 200, res.StatusCode, "Request return wrong status -> %d", res.Status)
assert.Equal(t, "keep-alive", res.Header.Get("Connection"))
assert.Contains(t, res.Header.Get("Content-Type"), "multipart/mixed")
assert.Contains(t, res.Header.Get("Content-Type"), `boundary="graphql"`)

br := bufio.NewReader(res.Body)

assert.Equal(t, "--graphql\r\n", readLine(br))
assert.Equal(t, "Content-Type: application/json\r\n", readLine(br))
assert.Equal(t, "\r\n", readLine(br))
assert.Equal(t,
"{\"data\":{\"name\":null},\"hasNext\":true}\r\n",
readLine(br),
)
assert.Equal(t, "\r\n", readLine(br))

wg.Add(1)
go func() {
defer wg.Done()
handler.SendNextSubscriptionMessage()
}()

assert.Equal(t, "--graphql\r\n", readLine(br))
assert.Equal(t, "Content-Type: application/json\r\n", readLine(br))
assert.Equal(t, "\r\n", readLine(br))
assert.Equal(
t,
"{\"incremental\":[{\"data\":{\"name\":\"test\"},\"hasNext\":false}],\"hasNext\":false}\r\n",
readLine(br),
)
assert.Equal(t, "\r\n", readLine(br))

wg.Add(1)
go func() {
defer wg.Done()
handler.SendCompleteSubscriptionMessage()
}()

_, err = br.ReadByte()
assert.Equal(t, err, io.EOF)

wg.Wait()
})
}

0 comments on commit abc7c61

Please sign in to comment.