Skip to content

Commit

Permalink
add StatusCodeFunc for HTTP Transport implementations
Browse files Browse the repository at this point in the history
  • Loading branch information
vvakame committed Dec 24, 2019
1 parent f869f5a commit a3a6e40
Show file tree
Hide file tree
Showing 8 changed files with 95 additions and 24 deletions.
3 changes: 2 additions & 1 deletion graphql/executable_schema_mock.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion graphql/handler/apollotracing/tracer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func TestApolloTracing_withFail(t *testing.T) {
h.Use(apollotracing.Tracer{})

resp := doRequest(h, "POST", "/graphql", `{"operationName":"A","extensions":{"persistedQuery":{"version":1,"sha256Hash":"338bbc16ac780daf81845339fbf0342061c1e9d2b702c96d3958a13a557083a6"}}}`)
assert.Equal(t, http.StatusUnprocessableEntity, resp.Code, resp.Body.String())
assert.Equal(t, http.StatusOK, resp.Code, resp.Body.String())
b := resp.Body.Bytes()
t.Log(string(b))
var respData struct {
Expand Down
32 changes: 22 additions & 10 deletions graphql/handler/extension/apq_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,31 @@ func TestAPQIntegration(t *testing.T) {
h.Use(&extension.AutomaticPersistedQuery{Cache: graphql.MapCache{}})
h.AddTransport(&transport.POST{})

var stats *extension.ApqStats
h.AroundResponses(func(ctx context.Context, next graphql.ResponseHandler) *graphql.Response {
stats = extension.GetApqStats(ctx)
return next(ctx)
t.Run("hash only", func(t *testing.T) {
h.AroundResponses(func(ctx context.Context, next graphql.ResponseHandler) *graphql.Response {
return next(ctx)
})

resp := doRequest(h, "POST", "/graphql", `{"operationName":"A","extensions":{"persistedQuery":{"version":1,"sha256Hash":"338bbc16ac780daf81845339fbf0342061c1e9d2b702c96d3958a13a557083a6"}}}`)
require.Equal(t, http.StatusOK, resp.Code, resp.Body.String())
require.Equal(t, `{"errors":[{"message":"PersistedQueryNotFound"}],"data":null}`, resp.Body.String())
})

resp := doRequest(h, "POST", "/graphql", `{"query":"{ name }","extensions":{"persistedQuery":{"version":1,"sha256Hash":"30166fc3298853f22709fce1e4a00e98f1b6a3160eaaaf9cb3b7db6a16073b07"}}}`)
require.Equal(t, http.StatusOK, resp.Code, resp.Body.String())
require.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String())
t.Run("hash & query", func(t *testing.T) {
var stats *extension.ApqStats
h.AroundResponses(func(ctx context.Context, next graphql.ResponseHandler) *graphql.Response {
stats = extension.GetApqStats(ctx)
return next(ctx)
})

require.NotNil(t, stats)
require.True(t, stats.SentQuery)
require.Equal(t, "30166fc3298853f22709fce1e4a00e98f1b6a3160eaaaf9cb3b7db6a16073b07", stats.Hash)
resp := doRequest(h, "POST", "/graphql", `{"query":"{ name }","extensions":{"persistedQuery":{"version":1,"sha256Hash":"30166fc3298853f22709fce1e4a00e98f1b6a3160eaaaf9cb3b7db6a16073b07"}}}`)
require.Equal(t, http.StatusOK, resp.Code, resp.Body.String())
require.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String())

require.NotNil(t, stats)
require.True(t, stats.SentQuery)
require.Equal(t, "30166fc3298853f22709fce1e4a00e98f1b6a3160eaaaf9cb3b7db6a16073b07", stats.Hash)
})
}

func TestAPQ(t *testing.T) {
Expand Down
20 changes: 17 additions & 3 deletions graphql/handler/transport/http_form.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package transport

import (
"context"
"encoding/json"
"io"
"io/ioutil"
Expand All @@ -22,6 +23,8 @@ type MultipartForm struct {
// as multipart/form-data in memory, with the remainder stored on disk in
// temporary files.
MaxMemory int64

StatusCodeFunc func(ctx context.Context, resp *graphql.Response) int
}

var _ graphql.Transport = MultipartForm{}
Expand Down Expand Up @@ -188,11 +191,22 @@ func (f MultipartForm) Do(w http.ResponseWriter, r *http.Request, exec graphql.G

rc, gerr := exec.CreateOperationContext(r.Context(), &params)
if gerr != nil {
resp := exec.DispatchError(graphql.WithOperationContext(r.Context(), rc), gerr)
w.WriteHeader(http.StatusUnprocessableEntity)
ctx := graphql.WithOperationContext(r.Context(), rc)
resp := exec.DispatchError(ctx, gerr)
f.writeStatusCode(ctx, w, resp)
writeJson(w, resp)
return
}
responses, ctx := exec.DispatchOperation(r.Context(), rc)
writeJson(w, responses(ctx))
resp := responses(ctx)
f.writeStatusCode(ctx, w, resp)
writeJson(w, resp)
}

func (f MultipartForm) writeStatusCode(ctx context.Context, w http.ResponseWriter, resp *graphql.Response) {
if f.StatusCodeFunc == nil {
w.WriteHeader(httpStatusCode(resp))
} else {
w.WriteHeader(f.StatusCodeFunc(ctx, resp))
}
}
22 changes: 18 additions & 4 deletions graphql/handler/transport/http_get.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package transport

import (
"context"
"encoding/json"
"io"
"net/http"
Expand All @@ -12,7 +13,9 @@ import (

// GET implements the GET side of the default HTTP transport
// defined in https://github.com/APIs-guru/graphql-over-http#get
type GET struct{}
type GET struct {
StatusCodeFunc func(ctx context.Context, resp *graphql.Response) int
}

var _ graphql.Transport = GET{}

Expand Down Expand Up @@ -48,8 +51,9 @@ func (h GET) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecut

rc, err := exec.CreateOperationContext(r.Context(), raw)
if err != nil {
w.WriteHeader(http.StatusUnprocessableEntity)
resp := exec.DispatchError(graphql.WithOperationContext(r.Context(), rc), err)
ctx := graphql.WithOperationContext(r.Context(), rc)
resp := exec.DispatchError(ctx, err)
h.writeStatusCode(ctx, w, resp)
writeJson(w, resp)
return
}
Expand All @@ -61,7 +65,17 @@ func (h GET) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecut
}

responses, ctx := exec.DispatchOperation(r.Context(), rc)
writeJson(w, responses(ctx))
resp := responses(ctx)
h.writeStatusCode(ctx, w, resp)
writeJson(w, resp)
}

func (h GET) writeStatusCode(ctx context.Context, w http.ResponseWriter, resp *graphql.Response) {
if h.StatusCodeFunc == nil {
w.WriteHeader(httpStatusCode(resp))
} else {
w.WriteHeader(h.StatusCodeFunc(ctx, resp))
}
}

func jsonDecode(r io.Reader, val interface{}) error {
Expand Down
22 changes: 18 additions & 4 deletions graphql/handler/transport/http_post.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package transport

import (
"context"
"mime"
"net/http"

Expand All @@ -9,7 +10,9 @@ import (

// POST implements the POST side of the default HTTP transport
// defined in https://github.com/APIs-guru/graphql-over-http#post
type POST struct{}
type POST struct {
StatusCodeFunc func(ctx context.Context, resp *graphql.Response) int
}

var _ graphql.Transport = POST{}

Expand Down Expand Up @@ -38,11 +41,22 @@ func (h POST) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecu

rc, err := exec.CreateOperationContext(r.Context(), params)
if err != nil {
w.WriteHeader(http.StatusUnprocessableEntity)
resp := exec.DispatchError(graphql.WithOperationContext(r.Context(), rc), err)
ctx := graphql.WithOperationContext(r.Context(), rc)
resp := exec.DispatchError(ctx, err)
h.writeStatusCode(ctx, w, resp)
writeJson(w, resp)
return
}
responses, ctx := exec.DispatchOperation(r.Context(), rc)
writeJson(w, responses(ctx))
resp := responses(ctx)
h.writeStatusCode(ctx, w, resp)
writeJson(w, resp)
}

func (h POST) writeStatusCode(ctx context.Context, w http.ResponseWriter, resp *graphql.Response) {
if h.StatusCodeFunc == nil {
w.WriteHeader(httpStatusCode(resp))
} else {
w.WriteHeader(h.StatusCodeFunc(ctx, resp))
}
}
2 changes: 1 addition & 1 deletion graphql/handler/transport/http_post_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func TestPOST(t *testing.T) {

t.Run("execution failure", func(t *testing.T) {
resp := doRequest(h, "POST", "/graphql", `{"query": "mutation { name }"}`)
assert.Equal(t, http.StatusOK, resp.Code, resp.Body.String())
assert.Equal(t, http.StatusUnprocessableEntity, resp.Code, resp.Body.String())
assert.Equal(t, resp.Header().Get("Content-Type"), "application/json")
assert.Equal(t, `{"errors":[{"message":"mutations are not supported"}],"data":null}`, resp.Body.String())
})
Expand Down
16 changes: 16 additions & 0 deletions graphql/handler/transport/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"io"
"net/http"

"github.com/99designs/gqlgen/graphql"
"github.com/vektah/gqlparser/gqlerror"
Expand All @@ -28,3 +29,18 @@ func writeJsonErrorf(w io.Writer, format string, args ...interface{}) {
func writeJsonGraphqlError(w io.Writer, err ...*gqlerror.Error) {
writeJson(w, &graphql.Response{Errors: err})
}

func httpStatusCode(resp *graphql.Response) int {
if len(resp.Errors) == 0 {
return http.StatusOK
}

if len(resp.Data) != 0 {
return http.StatusOK
} else if len(resp.Errors) == 1 && resp.Errors[0].Message == "PersistedQueryNotFound" {
// for APQ
return http.StatusOK
}

return http.StatusUnprocessableEntity
}

0 comments on commit a3a6e40

Please sign in to comment.