From 63281468765b56c4581cec4fc9cbb384f39b68fe Mon Sep 17 00:00:00 2001 From: Cyril Tovena Date: Wed, 26 Feb 2020 12:30:46 -0500 Subject: [PATCH] Fixes querier cancellation error to return 499 instead of 500. (#1745) Signed-off-by: Cyril Tovena --- pkg/querier/http.go | 84 ++++++++++++++++++++-------------------- pkg/querier/http_test.go | 33 ++++++++++++++++ 2 files changed, 74 insertions(+), 43 deletions(-) diff --git a/pkg/querier/http.go b/pkg/querier/http.go index a289e3d4b5dec..ee6988ebc4cf9 100644 --- a/pkg/querier/http.go +++ b/pkg/querier/http.go @@ -2,7 +2,6 @@ package querier import ( "context" - "fmt" "net/http" "time" @@ -23,6 +22,9 @@ import ( const ( wsPingPeriod = 1 * time.Second + + // StatusClientClosedRequest is the status code for when a client request cancellation of an http request + StatusClientClosedRequest = 499 ) type QueryResponse struct { @@ -38,22 +40,18 @@ func (q *Querier) RangeQueryHandler(w http.ResponseWriter, r *http.Request) { request, err := loghttp.ParseRangeQuery(r) if err != nil { - http.Error(w, httpgrpc.Errorf(http.StatusBadRequest, err.Error()).Error(), http.StatusBadRequest) + writeError(httpgrpc.Errorf(http.StatusBadRequest, err.Error()), w) return } query := q.engine.NewRangeQuery(request.Query, request.Start, request.End, request.Step, request.Direction, request.Limit) result, err := query.Exec(ctx) if err != nil { - if logql.IsParseError(err) { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - http.Error(w, err.Error(), http.StatusInternalServerError) + writeError(err, w) return } if err := marshal.WriteQueryResponseJSON(result, w); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + writeError(err, w) return } } @@ -66,22 +64,18 @@ func (q *Querier) InstantQueryHandler(w http.ResponseWriter, r *http.Request) { request, err := loghttp.ParseInstantQuery(r) if err != nil { - http.Error(w, httpgrpc.Errorf(http.StatusBadRequest, err.Error()).Error(), http.StatusBadRequest) + writeError(httpgrpc.Errorf(http.StatusBadRequest, err.Error()), w) return } query := q.engine.NewInstantQuery(request.Query, request.Ts, request.Direction, request.Limit) result, err := query.Exec(ctx) if err != nil { - if logql.IsParseError(err) { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - http.Error(w, err.Error(), http.StatusInternalServerError) + writeError(err, w) return } if err := marshal.WriteQueryResponseJSON(result, w); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + writeError(err, w) return } } @@ -94,44 +88,36 @@ func (q *Querier) LogQueryHandler(w http.ResponseWriter, r *http.Request) { request, err := loghttp.ParseRangeQuery(r) if err != nil { - http.Error(w, httpgrpc.Errorf(http.StatusBadRequest, err.Error()).Error(), http.StatusBadRequest) + writeError(httpgrpc.Errorf(http.StatusBadRequest, err.Error()), w) return } request.Query, err = parseRegexQuery(r) if err != nil { - http.Error(w, httpgrpc.Errorf(http.StatusBadRequest, err.Error()).Error(), http.StatusBadRequest) + writeError(httpgrpc.Errorf(http.StatusBadRequest, err.Error()), w) return } expr, err := logql.ParseExpr(request.Query) if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) + writeError(err, w) return } // short circuit metric queries if _, ok := expr.(logql.SampleExpr); ok { - http.Error( - w, - fmt.Sprintf("legacy endpoints only support %s result type", logql.ValueTypeStreams), - http.StatusBadRequest, - ) + writeError(httpgrpc.Errorf(http.StatusBadRequest, "legacy endpoints only support %s result type", logql.ValueTypeStreams), w) return } query := q.engine.NewRangeQuery(request.Query, request.Start, request.End, request.Step, request.Direction, request.Limit) result, err := query.Exec(ctx) if err != nil { - if logql.IsParseError(err) { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - http.Error(w, err.Error(), http.StatusInternalServerError) + writeError(err, w) return } if err := marshal_legacy.WriteQueryResponseJSON(result, w); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + writeError(err, w) return } } @@ -140,13 +126,13 @@ func (q *Querier) LogQueryHandler(w http.ResponseWriter, r *http.Request) { func (q *Querier) LabelHandler(w http.ResponseWriter, r *http.Request) { req, err := loghttp.ParseLabelQuery(r) if err != nil { - http.Error(w, httpgrpc.Errorf(http.StatusBadRequest, err.Error()).Error(), http.StatusBadRequest) + writeError(httpgrpc.Errorf(http.StatusBadRequest, err.Error()), w) return } resp, err := q.Label(r.Context(), req) if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + writeError(err, w) return } @@ -156,7 +142,7 @@ func (q *Querier) LabelHandler(w http.ResponseWriter, r *http.Request) { err = marshal_legacy.WriteLabelResponseJSON(*resp, w) } if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + writeError(err, w) return } } @@ -169,13 +155,13 @@ func (q *Querier) TailHandler(w http.ResponseWriter, r *http.Request) { req, err := loghttp.ParseTailQuery(r) if err != nil { - http.Error(w, httpgrpc.Errorf(http.StatusBadRequest, err.Error()).Error(), http.StatusBadRequest) + writeError(httpgrpc.Errorf(http.StatusBadRequest, err.Error()), w) return } req.Query, err = parseRegexQuery(r) if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) + writeError(httpgrpc.Errorf(http.StatusBadRequest, err.Error()), w) return } @@ -276,19 +262,19 @@ func (q *Querier) TailHandler(w http.ResponseWriter, r *http.Request) { func (q *Querier) SeriesHandler(w http.ResponseWriter, r *http.Request) { req, err := loghttp.ParseSeriesQuery(r) if err != nil { - http.Error(w, httpgrpc.Errorf(http.StatusBadRequest, err.Error()).Error(), http.StatusBadRequest) + writeError(httpgrpc.Errorf(http.StatusBadRequest, err.Error()), w) return } resp, err := q.Series(r.Context(), req) if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + writeError(err, w) return } err = marshal.WriteSeriesResponseJSON(*resp, w) if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + writeError(err, w) return } } @@ -300,12 +286,7 @@ func NewPrepopulateMiddleware() middleware.Interface { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { err := req.ParseForm() if err != nil { - status := http.StatusBadRequest - http.Error( - w, - httpgrpc.Errorf(http.StatusBadRequest, err.Error()).Error(), - status, - ) + writeError(httpgrpc.Errorf(http.StatusBadRequest, err.Error()), w) return } @@ -328,3 +309,20 @@ func parseRegexQuery(httpRequest *http.Request) (string, error) { } return query, nil } + +func writeError(err error, w http.ResponseWriter) { + switch { + case err == context.Canceled: + http.Error(w, err.Error(), StatusClientClosedRequest) + case err == context.DeadlineExceeded: + http.Error(w, err.Error(), http.StatusGatewayTimeout) + case logql.IsParseError(err): + http.Error(w, err.Error(), http.StatusBadRequest) + default: + if grpcErr, ok := httpgrpc.HTTPResponseFromError(err); ok { + http.Error(w, string(grpcErr.Body), int(grpcErr.Code)) + return + } + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} diff --git a/pkg/querier/http_test.go b/pkg/querier/http_test.go index 70eae4525880b..e1387209a9b12 100644 --- a/pkg/querier/http_test.go +++ b/pkg/querier/http_test.go @@ -2,13 +2,19 @@ package querier import ( "bytes" + "context" + "errors" "io" + "io/ioutil" "net/http" "net/http/httptest" "net/url" "testing" "github.com/stretchr/testify/require" + "github.com/weaveworks/common/httpgrpc" + + "github.com/grafana/loki/pkg/logql" ) func TestPrepopulate(t *testing.T) { @@ -104,3 +110,30 @@ func TestPrepopulate(t *testing.T) { }) } } + +func Test_writeError(t *testing.T) { + for _, tt := range []struct { + name string + + err error + msg string + expectedStatus int + }{ + {"cancelled", context.Canceled, context.Canceled.Error(), StatusClientClosedRequest}, + {"deadline", context.DeadlineExceeded, context.DeadlineExceeded.Error(), http.StatusGatewayTimeout}, + {"parse error", logql.ParseError{}, "parse error : ", http.StatusBadRequest}, + {"httpgrpc", httpgrpc.Errorf(http.StatusBadRequest, errors.New("foo").Error()), "foo", http.StatusBadRequest}, + {"internal", errors.New("foo"), "foo", http.StatusInternalServerError}, + } { + t.Run(tt.name, func(t *testing.T) { + rec := httptest.NewRecorder() + writeError(tt.err, rec) + require.Equal(t, tt.expectedStatus, rec.Result().StatusCode) + b, err := ioutil.ReadAll(rec.Result().Body) + if err != nil { + t.Fatal(err) + } + require.Equal(t, tt.msg, string(b[:len(b)-1])) + }) + } +}