From abc7c61556bd2490d228c561f27e29b89e779bcd Mon Sep 17 00:00:00 2001 From: Giulio Collura Date: Mon, 21 Oct 2024 17:44:15 -0700 Subject: [PATCH] `multipart/mixed` transport support for deferred queries (#3341) --- graphql/handler/testserver/testserver.go | 34 +++- .../handler/transport/http_multipart_mixed.go | 160 +++++++++++++++++ .../transport/http_multipart_mixed_test.go | 168 ++++++++++++++++++ 3 files changed, 359 insertions(+), 3 deletions(-) create mode 100644 graphql/handler/transport/http_multipart_mixed.go create mode 100644 graphql/handler/transport/http_multipart_mixed_test.go diff --git a/graphql/handler/testserver/testserver.go b/graphql/handler/testserver/testserver.go index 33713b26d5d..9cd6eeb7a5c 100644 --- a/graphql/handler/testserver/testserver.go +++ b/graphql/handler/testserver/testserver.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "strings" "time" "github.com/vektah/gqlparser/v2" @@ -43,6 +44,32 @@ func New() *TestServer { switch rc.Operation.Operation { case ast.Query: ran := false + // If the query contains @defer, we will mimick a deferred response. + if strings.Contains(rc.RawQuery, "@defer") { + initialResponse := true + return func(context context.Context) *graphql.Response { + select { + case <-ctx.Done(): + return nil + case <-next: + if initialResponse { + initialResponse = false + hasNext := true + return &graphql.Response{ + Data: []byte(`{"name":null}`), + HasNext: &hasNext, + } + } + hasNext := false + return &graphql.Response{ + Data: []byte(`{"name":"test"}`), + HasNext: &hasNext, + } + case <-completeSubscription: + return nil + } + } + } return func(ctx context.Context) *graphql.Response { if ran { return nil @@ -59,9 +86,10 @@ func New() *TestServer { }, }, }) - res, err := graphql.GetOperationContext(ctx).ResolverMiddleware(ctx, func(ctx context.Context) (any, error) { - return &graphql.Response{Data: []byte(`{"name":"test"}`)}, nil - }) + res, err := graphql.GetOperationContext(ctx). + ResolverMiddleware(ctx, func(ctx context.Context) (any, error) { + return &graphql.Response{Data: []byte(`{"name":"test"}`)}, nil + }) if err != nil { panic(err) } diff --git a/graphql/handler/transport/http_multipart_mixed.go b/graphql/handler/transport/http_multipart_mixed.go new file mode 100644 index 00000000000..362cae51020 --- /dev/null +++ b/graphql/handler/transport/http_multipart_mixed.go @@ -0,0 +1,160 @@ +package transport + +import ( + "encoding/json" + "fmt" + "io" + "log" + "mime" + "net/http" + "strings" + + "github.com/vektah/gqlparser/v2/gqlerror" + + "github.com/99designs/gqlgen/graphql" +) + +// MultipartMixed is a transport that supports the multipart/mixed spec +type MultipartMixed struct { + Boundary string +} + +var _ graphql.Transport = MultipartMixed{} + +// Supports checks if the request supports the multipart/mixed spec +// Might be worth check the spec required, but Apollo Client mislabel the spec in the headers. +func (t MultipartMixed) Supports(r *http.Request) bool { + if !strings.Contains(r.Header.Get("Accept"), "multipart/mixed") { + return false + } + mediaType, _, err := mime.ParseMediaType(r.Header.Get("Content-Type")) + if err != nil { + return false + } + return r.Method == http.MethodPost && mediaType == "application/json" +} + +// Do implements the multipart/mixed spec as a multipart/mixed response +func (t MultipartMixed) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecutor) { + // Implements the multipart/mixed spec as a multipart/mixed response: + // * https://github.com/graphql/graphql-wg/blob/e4ef5f9d5997815d9de6681655c152b6b7838b4c/rfcs/DeferStream.md + // 2022/08/23 as implemented by gqlgen. + // * https://github.com/graphql/graphql-wg/blob/f22ea7748c6ebdf88fdbf770a8d9e41984ebd429/rfcs/DeferStream.md June 2023 Spec for the + // `incremental` field + // Follows the format that is used in the Apollo Client tests: + // https://github.com/apollographql/apollo-client/blob/v3.11.8/src/link/http/__tests__/responseIterator.ts#L68 + // Apollo Client, despite mentioning in its requests that they require the 2022 spec, it wants the + // `incremental` field to be an array of responses, not a single response. Theoretically we could + // batch responses in the `incremental` field, if we wanted to optimize this code. + ctx := r.Context() + flusher, ok := w.(http.Flusher) + if !ok { + SendErrorf(w, http.StatusInternalServerError, "streaming unsupported") + return + } + defer flusher.Flush() + + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + // This header will be replaced below, but it's required in case we return errors. + w.Header().Set("Content-Type", "application/json") + + boundary := t.Boundary + if boundary == "" { + boundary = "graphql" + } + + params := &graphql.RawParams{} + start := graphql.Now() + params.Headers = r.Header + params.ReadTime = graphql.TraceTiming{ + Start: start, + End: graphql.Now(), + } + + bodyString, err := getRequestBody(r) + if err != nil { + gqlErr := gqlerror.Errorf("could not get json request body: %+v", err) + resp := exec.DispatchError(ctx, gqlerror.List{gqlErr}) + log.Printf("could not get json request body: %+v", err.Error()) + writeJson(w, resp) + return + } + + bodyReader := io.NopCloser(strings.NewReader(bodyString)) + if err = jsonDecode(bodyReader, ¶ms); err != nil { + w.WriteHeader(http.StatusBadRequest) + gqlErr := gqlerror.Errorf( + "json request body could not be decoded: %+v body:%s", + err, + bodyString, + ) + resp := exec.DispatchError(ctx, gqlerror.List{gqlErr}) + log.Printf("decoding error: %+v body:%s", err.Error(), bodyString) + writeJson(w, resp) + return + } + + rc, opErr := exec.CreateOperationContext(ctx, params) + ctx = graphql.WithOperationContext(ctx, rc) + + // Example of the response format (note the new lines are important!): + // --graphql + // Content-Type: application/json + // + // {"data":{"apps":{"apps":[ .. ],"totalNumApps":161,"__typename":"AppsOutput"}},"hasNext":true} + // + // --graphql + // Content-Type: application/json + // + // {"incremental":[{"data":{"groupAccessCount":0},"label":"test","path":["apps","apps",7],"hasNext":true}],"hasNext":true} + + if opErr != nil { + w.WriteHeader(statusFor(opErr)) + + resp := exec.DispatchError(ctx, opErr) + writeJson(w, resp) + return + } + + w.Header().Set( + "Content-Type", + fmt.Sprintf(`multipart/mixed;boundary="%s";deferSpec=20220824`, boundary), + ) + + responses, ctx := exec.DispatchOperation(ctx, rc) + initialResponse := true + for { + response := responses(ctx) + if response == nil { + break + } + + fmt.Fprintf(w, "--%s\r\n", boundary) + fmt.Fprintf(w, "Content-Type: application/json\r\n\r\n") + + if initialResponse { + writeJson(w, response) + initialResponse = false + } else { + writeIncrementalJson(w, response, response.HasNext) + } + fmt.Fprintf(w, "\r\n\r\n") + flusher.Flush() + } +} + +func writeIncrementalJson(w io.Writer, response *graphql.Response, hasNext *bool) { + // TODO: Remove this wrapper on response once gqlgen supports the 2023 spec + b, err := json.Marshal(struct { + Incremental []graphql.Response `json:"incremental"` + HasNext *bool `json:"hasNext"` + }{ + Incremental: []graphql.Response{*response}, + HasNext: hasNext, + }) + if err != nil { + panic(err) + } + w.Write(b) +} diff --git a/graphql/handler/transport/http_multipart_mixed_test.go b/graphql/handler/transport/http_multipart_mixed_test.go new file mode 100644 index 00000000000..e590074e773 --- /dev/null +++ b/graphql/handler/transport/http_multipart_mixed_test.go @@ -0,0 +1,168 @@ +package transport_test + +import ( + "bufio" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/99designs/gqlgen/graphql/handler/testserver" + "github.com/99designs/gqlgen/graphql/handler/transport" +) + +func TestMultipartMixed(t *testing.T) { + initialize := func() *testserver.TestServer { + h := testserver.New() + h.AddTransport(transport.MultipartMixed{}) + return h + } + + initializeWithServer := func() (*testserver.TestServer, *httptest.Server) { + h := initialize() + return h, httptest.NewServer(h) + } + + createHTTPRequest := func(url string, query string) *http.Request { + req, err := http.NewRequest("POST", url, strings.NewReader(query)) + require.NoError(t, err, "Request threw error -> %s", err) + req.Header.Set("Accept", "multipart/mixed") + req.Header.Set("content-type", "application/json; charset=utf-8") + return req + } + + doRequest := func(handler http.Handler, target, body string) *httptest.ResponseRecorder { + r := createHTTPRequest(target, body) + w := httptest.NewRecorder() + + handler.ServeHTTP(w, r) + return w + } + + t.Run("decode failure", func(t *testing.T) { + handler, srv := initializeWithServer() + resp := doRequest(handler, srv.URL, "notjson") + assert.Equal(t, http.StatusBadRequest, resp.Code, resp.Body.String()) + assert.Equal(t, "application/json", resp.Header().Get("Content-Type")) + assert.Equal( + t, + `{"errors":[{"message":"json request body could not be decoded: invalid character 'o' in literal null (expecting 'u') body:notjson"}],"data":null}`, + resp.Body.String(), + ) + }) + + t.Run("parse failure", func(t *testing.T) { + handler, srv := initializeWithServer() + resp := doRequest(handler, srv.URL, `{"query": "!"}`) + assert.Equal(t, http.StatusUnprocessableEntity, resp.Code, resp.Body.String()) + assert.Equal(t, "application/json", resp.Header().Get("Content-Type")) + assert.Equal( + t, + `{"errors":[{"message":"Unexpected !","locations":[{"line":1,"column":1}],"extensions":{"code":"GRAPHQL_PARSE_FAILED"}}],"data":null}`, + resp.Body.String(), + ) + }) + + t.Run("validation failure", func(t *testing.T) { + handler, srv := initializeWithServer() + resp := doRequest(handler, srv.URL, `{"query": "{ title }"}`) + assert.Equal(t, http.StatusUnprocessableEntity, resp.Code, resp.Body.String()) + assert.Equal(t, "application/json", resp.Header().Get("Content-Type")) + assert.Equal( + t, + `{"errors":[{"message":"Cannot query field \"title\" on type \"Query\".","locations":[{"line":1,"column":3}],"extensions":{"code":"GRAPHQL_VALIDATION_FAILED"}}],"data":null}`, + resp.Body.String(), + ) + }) + + t.Run("invalid variable", func(t *testing.T) { + handler, srv := initializeWithServer() + resp := doRequest(handler, srv.URL, + `{"query": "query($id:Int!){find(id:$id)}","variables":{"id":false}}`, + ) + assert.Equal(t, http.StatusUnprocessableEntity, resp.Code, resp.Body.String()) + assert.Equal(t, "application/json", resp.Header().Get("Content-Type")) + assert.Equal( + t, + `{"errors":[{"message":"cannot use bool as Int","path":["variable","id"],"extensions":{"code":"GRAPHQL_VALIDATION_FAILED"}}],"data":null}`, + resp.Body.String(), + ) + }) + + readLine := func(br *bufio.Reader) string { + bs, err := br.ReadString('\n') + require.NoError(t, err) + return bs + } + + t.Run("initial and incremental patches", func(t *testing.T) { + handler, srv := initializeWithServer() + defer srv.Close() + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + handler.SendNextSubscriptionMessage() + }() + + client := &http.Client{} + req := createHTTPRequest( + srv.URL, + `{"query":"query { ... @defer { name } }"}`, + ) + res, err := client.Do(req) + require.NoError(t, err, "Request threw error -> %s", err) + defer func() { + require.NoError(t, res.Body.Close()) + }() + + assert.Equal(t, 200, res.StatusCode, "Request return wrong status -> %d", res.Status) + assert.Equal(t, "keep-alive", res.Header.Get("Connection")) + assert.Contains(t, res.Header.Get("Content-Type"), "multipart/mixed") + assert.Contains(t, res.Header.Get("Content-Type"), `boundary="graphql"`) + + br := bufio.NewReader(res.Body) + + assert.Equal(t, "--graphql\r\n", readLine(br)) + assert.Equal(t, "Content-Type: application/json\r\n", readLine(br)) + assert.Equal(t, "\r\n", readLine(br)) + assert.Equal(t, + "{\"data\":{\"name\":null},\"hasNext\":true}\r\n", + readLine(br), + ) + assert.Equal(t, "\r\n", readLine(br)) + + wg.Add(1) + go func() { + defer wg.Done() + handler.SendNextSubscriptionMessage() + }() + + assert.Equal(t, "--graphql\r\n", readLine(br)) + assert.Equal(t, "Content-Type: application/json\r\n", readLine(br)) + assert.Equal(t, "\r\n", readLine(br)) + assert.Equal( + t, + "{\"incremental\":[{\"data\":{\"name\":\"test\"},\"hasNext\":false}],\"hasNext\":false}\r\n", + readLine(br), + ) + assert.Equal(t, "\r\n", readLine(br)) + + wg.Add(1) + go func() { + defer wg.Done() + handler.SendCompleteSubscriptionMessage() + }() + + _, err = br.ReadByte() + assert.Equal(t, err, io.EOF) + + wg.Wait() + }) +}