diff --git a/graphql/handler/apollofederatedtracingv1/tracing.go b/graphql/handler/apollofederatedtracingv1/tracing.go index 689c3c0640a..b186201bebe 100644 --- a/graphql/handler/apollofederatedtracingv1/tracing.go +++ b/graphql/handler/apollofederatedtracingv1/tracing.go @@ -41,7 +41,8 @@ func (Tracer) Validate(graphql.ExecutableSchema) error { } func (t *Tracer) shouldTrace(ctx context.Context) bool { - return graphql.GetOperationContext(ctx).Headers.Get("apollo-federation-include-trace") == "ftv1" + return graphql.HasOperationContext(ctx) && + graphql.GetOperationContext(ctx).Headers.Get("apollo-federation-include-trace") == "ftv1" } func (t *Tracer) getTreeBuilder(ctx context.Context) *TreeBuilder { diff --git a/graphql/handler/apollofederatedtracingv1/tracing_test.go b/graphql/handler/apollofederatedtracingv1/tracing_test.go index 9ed91f91f84..7217217ed6f 100644 --- a/graphql/handler/apollofederatedtracingv1/tracing_test.go +++ b/graphql/handler/apollofederatedtracingv1/tracing_test.go @@ -5,6 +5,7 @@ import ( "encoding/base64" "encoding/json" "fmt" + "io" "net/http" "net/http/httptest" "strings" @@ -14,7 +15,6 @@ import ( "github.com/99designs/gqlgen/graphql" "github.com/99designs/gqlgen/graphql/handler/apollofederatedtracingv1" "github.com/99designs/gqlgen/graphql/handler/apollofederatedtracingv1/generated" - "github.com/99designs/gqlgen/graphql/handler/apollotracing" "github.com/99designs/gqlgen/graphql/handler/extension" "github.com/99designs/gqlgen/graphql/handler/lru" "github.com/99designs/gqlgen/graphql/handler/testserver" @@ -25,6 +25,12 @@ import ( "google.golang.org/protobuf/proto" ) +type alwaysError struct{} + +func (a *alwaysError) Read(p []byte) (int, error) { + return 0, io.ErrUnexpectedEOF +} + func TestApolloTracing(t *testing.T) { h := testserver.New() h.AddTransport(transport.POST{}) @@ -89,7 +95,7 @@ func TestApolloTracing_withFail(t *testing.T) { h := testserver.New() h.AddTransport(transport.POST{}) h.Use(extension.AutomaticPersistedQuery{Cache: lru.New(100)}) - h.Use(apollotracing.Tracer{}) + h.Use(&apollofederatedtracingv1.Tracer{}) resp := doRequest(h, http.MethodPost, "/graphql", `{"operationName":"A","extensions":{"persistedQuery":{"version":1,"sha256Hash":"338bbc16ac780daf81845339fbf0342061c1e9d2b702c96d3958a13a557083a6"}}}`) assert.Equal(t, http.StatusOK, resp.Code, resp.Body.String()) @@ -103,8 +109,21 @@ func TestApolloTracing_withFail(t *testing.T) { require.Equal(t, "PersistedQueryNotFound", respData.Errors[0].Message) } +func TestApolloTracing_withUnexpectedEOF(t *testing.T) { + h := testserver.New() + h.AddTransport(transport.POST{}) + h.Use(&apollofederatedtracingv1.Tracer{}) + + resp := doRequestWithReader(h, http.MethodPost, "/graphql", &alwaysError{}) + assert.Equal(t, http.StatusOK, resp.Code) +} func doRequest(handler http.Handler, method, target, body string) *httptest.ResponseRecorder { - r := httptest.NewRequest(method, target, strings.NewReader(body)) + return doRequestWithReader(handler, method, target, strings.NewReader(body)) +} + +func doRequestWithReader(handler http.Handler, method string, target string, + reader io.Reader) *httptest.ResponseRecorder { + r := httptest.NewRequest(method, target, reader) r.Header.Set("Content-Type", "application/json") r.Header.Set("apollo-federation-include-trace", "ftv1") w := httptest.NewRecorder() diff --git a/graphql/handler/apollotracing/tracer.go b/graphql/handler/apollotracing/tracer.go index 4cfd1592bd3..d1e92bbfdcd 100644 --- a/graphql/handler/apollotracing/tracer.go +++ b/graphql/handler/apollotracing/tracer.go @@ -85,6 +85,10 @@ func (Tracer) InterceptField(ctx context.Context, next graphql.Resolver) (interf } func (Tracer) InterceptResponse(ctx context.Context, next graphql.ResponseHandler) *graphql.Response { + if !graphql.HasOperationContext(ctx) { + return next(ctx) + } + rc := graphql.GetOperationContext(ctx) start := rc.Stats.OperationStart diff --git a/graphql/handler/apollotracing/tracer_test.go b/graphql/handler/apollotracing/tracer_test.go index 0b0f30ef2a4..789448cae9a 100644 --- a/graphql/handler/apollotracing/tracer_test.go +++ b/graphql/handler/apollotracing/tracer_test.go @@ -2,6 +2,7 @@ package apollotracing_test import ( "encoding/json" + "io" "net/http" "net/http/httptest" "strings" @@ -20,6 +21,12 @@ import ( "github.com/vektah/gqlparser/v2/gqlerror" ) +type alwaysError struct{} + +func (a *alwaysError) Read(p []byte) (int, error) { + return 0, io.ErrUnexpectedEOF +} + func TestApolloTracing(t *testing.T) { now := time.Unix(0, 0) @@ -92,8 +99,22 @@ func TestApolloTracing_withFail(t *testing.T) { require.Equal(t, "PersistedQueryNotFound", respData.Errors[0].Message) } +func TestApolloTracing_withUnexpectedEOF(t *testing.T) { + h := testserver.New() + h.AddTransport(transport.POST{}) + h.Use(apollotracing.Tracer{}) + + resp := doRequestWithReader(h, http.MethodPost, "/graphql", &alwaysError{}) + assert.Equal(t, http.StatusOK, resp.Code) +} + func doRequest(handler http.Handler, method, target, body string) *httptest.ResponseRecorder { - r := httptest.NewRequest(method, target, strings.NewReader(body)) + return doRequestWithReader(handler, method, target, strings.NewReader(body)) +} + +func doRequestWithReader(handler http.Handler, method string, target string, + reader io.Reader) *httptest.ResponseRecorder { + r := httptest.NewRequest(method, target, reader) r.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder()