Skip to content

Commit

Permalink
Fixes querier cancellation error to return 499 instead of 500. (#1745)
Browse files Browse the repository at this point in the history
Signed-off-by: Cyril Tovena <cyril.tovena@gmail.com>
  • Loading branch information
cyriltovena authored Feb 26, 2020
1 parent d27f119 commit 6328146
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 43 deletions.
84 changes: 41 additions & 43 deletions pkg/querier/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package querier

import (
"context"
"fmt"
"net/http"
"time"

Expand All @@ -23,6 +22,9 @@ import (

const (
wsPingPeriod = 1 * time.Second

// StatusClientClosedRequest is the status code for when a client request cancellation of an http request
StatusClientClosedRequest = 499
)

type QueryResponse struct {
Expand All @@ -38,22 +40,18 @@ func (q *Querier) RangeQueryHandler(w http.ResponseWriter, r *http.Request) {

request, err := loghttp.ParseRangeQuery(r)
if err != nil {
http.Error(w, httpgrpc.Errorf(http.StatusBadRequest, err.Error()).Error(), http.StatusBadRequest)
writeError(httpgrpc.Errorf(http.StatusBadRequest, err.Error()), w)
return
}
query := q.engine.NewRangeQuery(request.Query, request.Start, request.End, request.Step, request.Direction, request.Limit)
result, err := query.Exec(ctx)
if err != nil {
if logql.IsParseError(err) {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
http.Error(w, err.Error(), http.StatusInternalServerError)
writeError(err, w)
return
}

if err := marshal.WriteQueryResponseJSON(result, w); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
writeError(err, w)
return
}
}
Expand All @@ -66,22 +64,18 @@ func (q *Querier) InstantQueryHandler(w http.ResponseWriter, r *http.Request) {

request, err := loghttp.ParseInstantQuery(r)
if err != nil {
http.Error(w, httpgrpc.Errorf(http.StatusBadRequest, err.Error()).Error(), http.StatusBadRequest)
writeError(httpgrpc.Errorf(http.StatusBadRequest, err.Error()), w)
return
}
query := q.engine.NewInstantQuery(request.Query, request.Ts, request.Direction, request.Limit)
result, err := query.Exec(ctx)
if err != nil {
if logql.IsParseError(err) {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
http.Error(w, err.Error(), http.StatusInternalServerError)
writeError(err, w)
return
}

if err := marshal.WriteQueryResponseJSON(result, w); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
writeError(err, w)
return
}
}
Expand All @@ -94,44 +88,36 @@ func (q *Querier) LogQueryHandler(w http.ResponseWriter, r *http.Request) {

request, err := loghttp.ParseRangeQuery(r)
if err != nil {
http.Error(w, httpgrpc.Errorf(http.StatusBadRequest, err.Error()).Error(), http.StatusBadRequest)
writeError(httpgrpc.Errorf(http.StatusBadRequest, err.Error()), w)
return
}
request.Query, err = parseRegexQuery(r)
if err != nil {
http.Error(w, httpgrpc.Errorf(http.StatusBadRequest, err.Error()).Error(), http.StatusBadRequest)
writeError(httpgrpc.Errorf(http.StatusBadRequest, err.Error()), w)
return
}

expr, err := logql.ParseExpr(request.Query)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
writeError(err, w)
return
}

// short circuit metric queries
if _, ok := expr.(logql.SampleExpr); ok {
http.Error(
w,
fmt.Sprintf("legacy endpoints only support %s result type", logql.ValueTypeStreams),
http.StatusBadRequest,
)
writeError(httpgrpc.Errorf(http.StatusBadRequest, "legacy endpoints only support %s result type", logql.ValueTypeStreams), w)
return
}

query := q.engine.NewRangeQuery(request.Query, request.Start, request.End, request.Step, request.Direction, request.Limit)
result, err := query.Exec(ctx)
if err != nil {
if logql.IsParseError(err) {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
http.Error(w, err.Error(), http.StatusInternalServerError)
writeError(err, w)
return
}

if err := marshal_legacy.WriteQueryResponseJSON(result, w); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
writeError(err, w)
return
}
}
Expand All @@ -140,13 +126,13 @@ func (q *Querier) LogQueryHandler(w http.ResponseWriter, r *http.Request) {
func (q *Querier) LabelHandler(w http.ResponseWriter, r *http.Request) {
req, err := loghttp.ParseLabelQuery(r)
if err != nil {
http.Error(w, httpgrpc.Errorf(http.StatusBadRequest, err.Error()).Error(), http.StatusBadRequest)
writeError(httpgrpc.Errorf(http.StatusBadRequest, err.Error()), w)
return
}

resp, err := q.Label(r.Context(), req)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
writeError(err, w)
return
}

Expand All @@ -156,7 +142,7 @@ func (q *Querier) LabelHandler(w http.ResponseWriter, r *http.Request) {
err = marshal_legacy.WriteLabelResponseJSON(*resp, w)
}
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
writeError(err, w)
return
}
}
Expand All @@ -169,13 +155,13 @@ func (q *Querier) TailHandler(w http.ResponseWriter, r *http.Request) {

req, err := loghttp.ParseTailQuery(r)
if err != nil {
http.Error(w, httpgrpc.Errorf(http.StatusBadRequest, err.Error()).Error(), http.StatusBadRequest)
writeError(httpgrpc.Errorf(http.StatusBadRequest, err.Error()), w)
return
}

req.Query, err = parseRegexQuery(r)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
writeError(httpgrpc.Errorf(http.StatusBadRequest, err.Error()), w)
return
}

Expand Down Expand Up @@ -276,19 +262,19 @@ func (q *Querier) TailHandler(w http.ResponseWriter, r *http.Request) {
func (q *Querier) SeriesHandler(w http.ResponseWriter, r *http.Request) {
req, err := loghttp.ParseSeriesQuery(r)
if err != nil {
http.Error(w, httpgrpc.Errorf(http.StatusBadRequest, err.Error()).Error(), http.StatusBadRequest)
writeError(httpgrpc.Errorf(http.StatusBadRequest, err.Error()), w)
return
}

resp, err := q.Series(r.Context(), req)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
writeError(err, w)
return
}

err = marshal.WriteSeriesResponseJSON(*resp, w)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
writeError(err, w)
return
}
}
Expand All @@ -300,12 +286,7 @@ func NewPrepopulateMiddleware() middleware.Interface {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
err := req.ParseForm()
if err != nil {
status := http.StatusBadRequest
http.Error(
w,
httpgrpc.Errorf(http.StatusBadRequest, err.Error()).Error(),
status,
)
writeError(httpgrpc.Errorf(http.StatusBadRequest, err.Error()), w)
return

}
Expand All @@ -328,3 +309,20 @@ func parseRegexQuery(httpRequest *http.Request) (string, error) {
}
return query, nil
}

func writeError(err error, w http.ResponseWriter) {
switch {
case err == context.Canceled:
http.Error(w, err.Error(), StatusClientClosedRequest)
case err == context.DeadlineExceeded:
http.Error(w, err.Error(), http.StatusGatewayTimeout)
case logql.IsParseError(err):
http.Error(w, err.Error(), http.StatusBadRequest)
default:
if grpcErr, ok := httpgrpc.HTTPResponseFromError(err); ok {
http.Error(w, string(grpcErr.Body), int(grpcErr.Code))
return
}
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}
33 changes: 33 additions & 0 deletions pkg/querier/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,19 @@ package querier

import (
"bytes"
"context"
"errors"
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"testing"

"github.com/stretchr/testify/require"
"github.com/weaveworks/common/httpgrpc"

"github.com/grafana/loki/pkg/logql"
)

func TestPrepopulate(t *testing.T) {
Expand Down Expand Up @@ -104,3 +110,30 @@ func TestPrepopulate(t *testing.T) {
})
}
}

func Test_writeError(t *testing.T) {
for _, tt := range []struct {
name string

err error
msg string
expectedStatus int
}{
{"cancelled", context.Canceled, context.Canceled.Error(), StatusClientClosedRequest},
{"deadline", context.DeadlineExceeded, context.DeadlineExceeded.Error(), http.StatusGatewayTimeout},
{"parse error", logql.ParseError{}, "parse error : ", http.StatusBadRequest},
{"httpgrpc", httpgrpc.Errorf(http.StatusBadRequest, errors.New("foo").Error()), "foo", http.StatusBadRequest},
{"internal", errors.New("foo"), "foo", http.StatusInternalServerError},
} {
t.Run(tt.name, func(t *testing.T) {
rec := httptest.NewRecorder()
writeError(tt.err, rec)
require.Equal(t, tt.expectedStatus, rec.Result().StatusCode)
b, err := ioutil.ReadAll(rec.Result().Body)
if err != nil {
t.Fatal(err)
}
require.Equal(t, tt.msg, string(b[:len(b)-1]))
})
}
}

0 comments on commit 6328146

Please sign in to comment.