diff --git a/graphql/handler/transport/headers.go b/graphql/handler/transport/headers.go new file mode 100644 index 00000000000..bc4e572444b --- /dev/null +++ b/graphql/handler/transport/headers.go @@ -0,0 +1,17 @@ +package transport + +import "net/http" + +func writeHeaders(w http.ResponseWriter, headers map[string][]string) { + if len(headers) == 0 { + headers = map[string][]string{ + "Content-Type": {"application/json"}, + } + } + + for key, values := range headers { + for _, value := range values { + w.Header().Add(key, value) + } + } +} diff --git a/graphql/handler/transport/headers_test.go b/graphql/handler/transport/headers_test.go new file mode 100644 index 00000000000..5e8160861a2 --- /dev/null +++ b/graphql/handler/transport/headers_test.go @@ -0,0 +1,163 @@ +package transport_test + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/99designs/gqlgen/graphql" + "github.com/99designs/gqlgen/graphql/handler" + "github.com/99designs/gqlgen/graphql/handler/testserver" + "github.com/99designs/gqlgen/graphql/handler/transport" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/vektah/gqlparser/v2" + "github.com/vektah/gqlparser/v2/ast" +) + +func TestHeadersWithPOST(t *testing.T) { + t.Run("Headers not set", func(t *testing.T) { + h := testserver.New() + h.AddTransport(transport.POST{}) + + resp := doRequest(h, "POST", "/graphql", `{"query":"{ name }"}`) + assert.Equal(t, http.StatusOK, resp.Code) + assert.Equal(t, 1, len(resp.Header())) + assert.Equal(t, "application/json", resp.Header().Get("Content-Type")) + }) + + t.Run("Headers set", func(t *testing.T) { + headers := map[string][]string{ + "Content-Type": {"application/json; charset: utf8"}, + "Other-Header": {"dummy-post", "another-one"}, + } + + h := testserver.New() + h.AddTransport(transport.POST{ResponseHeaders: headers}) + + resp := doRequest(h, "POST", "/graphql", `{"query":"{ name }"}`) + assert.Equal(t, http.StatusOK, resp.Code) + assert.Equal(t, 2, len(resp.Header())) + assert.Equal(t, "application/json; charset: utf8", resp.Header().Get("Content-Type")) + assert.Equal(t, "dummy-post", resp.Header().Get("Other-Header")) + assert.Equal(t, "another-one", resp.Header().Values("Other-Header")[1]) + }) +} + +func TestHeadersWithGET(t *testing.T) { + t.Run("Headers not set", func(t *testing.T) { + h := testserver.New() + h.AddTransport(transport.GET{}) + + resp := doRequest(h, "GET", "/graphql?query={name}", "") + assert.Equal(t, http.StatusOK, resp.Code) + assert.Equal(t, 1, len(resp.Header())) + assert.Equal(t, "application/json", resp.Header().Get("Content-Type")) + }) + + t.Run("Headers set", func(t *testing.T) { + headers := map[string][]string{ + "Content-Type": {"application/json; charset: utf8"}, + "Other-Header": {"dummy-get"}, + } + + h := testserver.New() + h.AddTransport(transport.GET{ResponseHeaders: headers}) + + resp := doRequest(h, "GET", "/graphql?query={name}", "") + assert.Equal(t, http.StatusOK, resp.Code) + assert.Equal(t, 2, len(resp.Header())) + assert.Equal(t, "application/json; charset: utf8", resp.Header().Get("Content-Type")) + assert.Equal(t, "dummy-get", resp.Header().Get("Other-Header")) + }) +} + +func TestHeadersWithMULTIPART(t *testing.T) { + t.Run("Headers not set", func(t *testing.T) { + es := &graphql.ExecutableSchemaMock{ + ExecFunc: func(ctx context.Context) graphql.ResponseHandler { + return graphql.OneShot(graphql.ErrorResponse(ctx, "not implemented")) + }, + SchemaFunc: func() *ast.Schema { + return gqlparser.MustLoadSchema(&ast.Source{Input: ` + type Mutation { + singleUpload(file: Upload!): String! + } + scalar Upload + `}) + }, + } + + h := handler.New(es) + h.AddTransport(transport.MultipartForm{}) + + es.ExecFunc = func(ctx context.Context) graphql.ResponseHandler { + return graphql.OneShot(&graphql.Response{Data: []byte(`{"singleUpload":"test"}`)}) + } + + operations := `{ "query": "mutation ($file: Upload!) { singleUpload(file: $file) }", "variables": { "file": null } }` + mapData := `{ "0": ["variables.file"] }` + files := []file{ + { + mapKey: "0", + name: "a.txt", + content: "test1", + contentType: "text/plain", + }, + } + req := createUploadRequest(t, operations, mapData, files) + + resp := httptest.NewRecorder() + h.ServeHTTP(resp, req) + require.Equal(t, http.StatusOK, resp.Code, resp.Body.String()) + assert.Equal(t, 1, len(resp.Header())) + assert.Equal(t, "application/json", resp.Header().Get("Content-Type")) + }) + + t.Run("Headers set", func(t *testing.T) { + es := &graphql.ExecutableSchemaMock{ + ExecFunc: func(ctx context.Context) graphql.ResponseHandler { + return graphql.OneShot(graphql.ErrorResponse(ctx, "not implemented")) + }, + SchemaFunc: func() *ast.Schema { + return gqlparser.MustLoadSchema(&ast.Source{Input: ` + type Mutation { + singleUpload(file: Upload!): String! + } + scalar Upload + `}) + }, + } + + h := handler.New(es) + headers := map[string][]string{ + "Content-Type": {"application/json; charset: utf8"}, + "Other-Header": {"dummy-multipart"}, + } + h.AddTransport(transport.MultipartForm{ResponseHeaders: headers}) + + es.ExecFunc = func(ctx context.Context) graphql.ResponseHandler { + return graphql.OneShot(&graphql.Response{Data: []byte(`{"singleUpload":"test"}`)}) + } + + operations := `{ "query": "mutation ($file: Upload!) { singleUpload(file: $file) }", "variables": { "file": null } }` + mapData := `{ "0": ["variables.file"] }` + files := []file{ + { + mapKey: "0", + name: "a.txt", + content: "test1", + contentType: "text/plain", + }, + } + req := createUploadRequest(t, operations, mapData, files) + + resp := httptest.NewRecorder() + h.ServeHTTP(resp, req) + require.Equal(t, http.StatusOK, resp.Code, resp.Body.String()) + assert.Equal(t, 2, len(resp.Header())) + assert.Equal(t, "application/json; charset: utf8", resp.Header().Get("Content-Type")) + assert.Equal(t, "dummy-multipart", resp.Header().Get("Other-Header")) + }) +} diff --git a/graphql/handler/transport/http_form.go b/graphql/handler/transport/http_form.go index 3d3477b9ba6..b9eb5f8f433 100644 --- a/graphql/handler/transport/http_form.go +++ b/graphql/handler/transport/http_form.go @@ -20,6 +20,10 @@ type MultipartForm struct { // as multipart/form-data in memory, with the remainder stored on disk in // temporary files. MaxMemory int64 + + // Map of all headers that are added to graphql response. If not + // set, only one header: Content-Type: application/json will be set. + ResponseHeaders map[string][]string } var _ graphql.Transport = MultipartForm{} @@ -52,7 +56,7 @@ func (f MultipartForm) maxMemory() int64 { } func (f MultipartForm) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecutor) { - w.Header().Set("Content-Type", "application/json") + writeHeaders(w, f.ResponseHeaders) start := graphql.Now() diff --git a/graphql/handler/transport/http_get.go b/graphql/handler/transport/http_get.go index 8114ba66a37..324fd986834 100644 --- a/graphql/handler/transport/http_get.go +++ b/graphql/handler/transport/http_get.go @@ -15,7 +15,11 @@ import ( // GET implements the GET side of the default HTTP transport // defined in https://github.com/APIs-guru/graphql-over-http#get -type GET struct{} +type GET struct { + // Map of all headers that are added to graphql response. If not + // set, only one header: Content-Type: application/json will be set. + ResponseHeaders map[string][]string +} var _ graphql.Transport = GET{} @@ -34,7 +38,7 @@ func (h GET) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecut writeJsonError(w, err.Error()) return } - w.Header().Set("Content-Type", "application/json") + writeHeaders(w, h.ResponseHeaders) raw := &graphql.RawParams{ Query: query.Get("query"), diff --git a/graphql/handler/transport/http_post.go b/graphql/handler/transport/http_post.go index 092e707f35d..a37010ab74b 100644 --- a/graphql/handler/transport/http_post.go +++ b/graphql/handler/transport/http_post.go @@ -14,7 +14,11 @@ import ( // POST implements the POST side of the default HTTP transport // defined in https://github.com/APIs-guru/graphql-over-http#post -type POST struct{} +type POST struct { + // Map of all headers that are added to graphql response. If not + // set, only one header: Content-Type: application/json will be set. + ResponseHeaders map[string][]string +} var _ graphql.Transport = POST{} @@ -44,7 +48,7 @@ func getRequestBody(r *http.Request) (string, error) { func (h POST) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecutor) { ctx := r.Context() - w.Header().Set("Content-Type", "application/json") + writeHeaders(w, h.ResponseHeaders) params := &graphql.RawParams{} start := graphql.Now() params.Headers = r.Header