Skip to content

Commit

Permalink
Recover from panic in http and grpc handlers. (#2059)
Browse files Browse the repository at this point in the history
* Recover from panic in http and grpc handlers.

I don't see any good reason to crash any component during a bad request.

Signed-off-by: Cyril Tovena <cyril.tovena@gmail.com>

* Add alerts to the mixin for panics.

Signed-off-by: Cyril Tovena <cyril.tovena@gmail.com>

* 😡 gomod

Signed-off-by: Cyril Tovena <cyril.tovena@gmail.com>
  • Loading branch information
cyriltovena authored May 11, 2020
1 parent f5b9cff commit bce4470
Show file tree
Hide file tree
Showing 16 changed files with 372 additions and 121 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ require (
github.com/golang/snappy v0.0.1
github.com/gorilla/mux v1.7.1
github.com/gorilla/websocket v1.4.0
github.com/grpc-ecosystem/go-grpc-middleware v1.1.0
github.com/grpc-ecosystem/go-grpc-prometheus v1.2.1-0.20191002090509-6af20e3a5340 // indirect
github.com/grpc-ecosystem/grpc-opentracing v0.0.0-20180507213350-8e809c8a8645
github.com/hashicorp/golang-lru v0.5.3
Expand Down
2 changes: 1 addition & 1 deletion pkg/loki/fake_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ var fakeHTTPAuthMiddleware = middleware.Func(func(next http.Handler) http.Handle
})
})

var fakeGRPCAuthUniaryMiddleware = func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
var fakeGRPCAuthUnaryMiddleware = func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
ctx = user.InjectOrgID(ctx, "fake")
return handler(ctx, req)
}
Expand Down
45 changes: 21 additions & 24 deletions pkg/loki/loki.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"github.com/grafana/loki/pkg/querier"
"github.com/grafana/loki/pkg/querier/queryrange"
"github.com/grafana/loki/pkg/storage"
serverutil "github.com/grafana/loki/pkg/util/server"
"github.com/grafana/loki/pkg/util/validation"
)

Expand Down Expand Up @@ -140,37 +141,33 @@ func New(cfg Config) (*Loki, error) {
}

func (t *Loki) setupAuthMiddleware() {
t.cfg.Server.GRPCMiddleware = []grpc.UnaryServerInterceptor{serverutil.RecoveryGRPCUnaryInterceptor}
t.cfg.Server.GRPCStreamMiddleware = []grpc.StreamServerInterceptor{serverutil.RecoveryGRPCStreamInterceptor}
if t.cfg.AuthEnabled {
t.cfg.Server.GRPCMiddleware = []grpc.UnaryServerInterceptor{
middleware.ServerUserHeaderInterceptor,
}
t.cfg.Server.GRPCStreamMiddleware = []grpc.StreamServerInterceptor{
func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
switch info.FullMethod {
// Don't check auth header on TransferChunks, as we weren't originally
// sending it and this could cause transfers to fail on update.
//
// Also don't check auth /frontend.Frontend/Process, as this handles
// queries for multiple users.
case "/logproto.Ingester/TransferChunks", "/frontend.Frontend/Process":
return handler(srv, ss)
default:
return middleware.StreamServerUserHeaderInterceptor(srv, ss, info, handler)
}
},
}
t.cfg.Server.GRPCMiddleware = append(t.cfg.Server.GRPCMiddleware, middleware.ServerUserHeaderInterceptor)
t.cfg.Server.GRPCStreamMiddleware = append(t.cfg.Server.GRPCStreamMiddleware, GRPCStreamAuthInterceptor)
t.httpAuthMiddleware = middleware.AuthenticateUser
} else {
t.cfg.Server.GRPCMiddleware = []grpc.UnaryServerInterceptor{
fakeGRPCAuthUniaryMiddleware,
}
t.cfg.Server.GRPCStreamMiddleware = []grpc.StreamServerInterceptor{
fakeGRPCAuthStreamMiddleware,
}
t.cfg.Server.GRPCMiddleware = append(t.cfg.Server.GRPCMiddleware, fakeGRPCAuthUnaryMiddleware)
t.cfg.Server.GRPCStreamMiddleware = append(t.cfg.Server.GRPCStreamMiddleware, fakeGRPCAuthStreamMiddleware)
t.httpAuthMiddleware = fakeHTTPAuthMiddleware
}
}

var GRPCStreamAuthInterceptor = func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
switch info.FullMethod {
// Don't check auth header on TransferChunks, as we weren't originally
// sending it and this could cause transfers to fail on update.
//
// Also don't check auth /frontend.Frontend/Process, as this handles
// queries for multiple users.
case "/logproto.Ingester/TransferChunks", "/frontend.Frontend/Process":
return handler(srv, ss)
default:
return middleware.StreamServerUserHeaderInterceptor(srv, ss, info, handler)
}
}

func (t *Loki) initModuleServices(target moduleName) (map[moduleName]services.Service, error) {
servicesMap := map[moduleName]services.Service{}

Expand Down
13 changes: 11 additions & 2 deletions pkg/loki/modules.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
"github.com/grafana/loki/pkg/querier/queryrange"
loki_storage "github.com/grafana/loki/pkg/storage"
"github.com/grafana/loki/pkg/storage/stores/local"
serverutil "github.com/grafana/loki/pkg/util/server"
"github.com/grafana/loki/pkg/util/validation"
)

Expand Down Expand Up @@ -146,6 +147,7 @@ func (t *Loki) initDistributor() (services.Service, error) {
}

pushHandler := middleware.Merge(
serverutil.RecoveryHTTPMiddleware,
t.httpAuthMiddleware,
).Wrap(http.HandlerFunc(t.distributor.PushHandler))

Expand All @@ -167,8 +169,9 @@ func (t *Loki) initQuerier() (services.Service, error) {
}

httpMiddleware := middleware.Merge(
serverutil.RecoveryHTTPMiddleware,
t.httpAuthMiddleware,
querier.NewPrepopulateMiddleware(),
serverutil.NewPrepopulateMiddleware(),
)
t.server.HTTP.Handle("/loki/api/v1/query_range", httpMiddleware.Wrap(http.HandlerFunc(t.querier.RangeQueryHandler)))
t.server.HTTP.Handle("/loki/api/v1/query", httpMiddleware.Wrap(http.HandlerFunc(t.querier.InstantQueryHandler)))
Expand Down Expand Up @@ -295,7 +298,13 @@ func (t *Loki) initQueryFrontend() (_ services.Service, err error) {
t.frontend.Wrap(tripperware)
frontend.RegisterFrontendServer(t.server.GRPC, t.frontend)

frontendHandler := queryrange.StatsHTTPMiddleware.Wrap(t.httpAuthMiddleware.Wrap(t.frontend.Handler()))
frontendHandler := middleware.Merge(
serverutil.RecoveryHTTPMiddleware,
queryrange.StatsHTTPMiddleware,
t.httpAuthMiddleware,
serverutil.NewPrepopulateMiddleware(),
).Wrap(t.frontend.Handler())

t.server.HTTP.Handle("/loki/api/v1/query_range", frontendHandler)
t.server.HTTP.Handle("/loki/api/v1/query", frontendHandler)
t.server.HTTP.Handle("/loki/api/v1/label", frontendHandler)
Expand Down
84 changes: 24 additions & 60 deletions pkg/querier/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,18 @@ import (
"github.com/prometheus/prometheus/pkg/labels"
"github.com/prometheus/prometheus/promql"
"github.com/weaveworks/common/httpgrpc"
"github.com/weaveworks/common/middleware"
"github.com/weaveworks/common/user"

"github.com/grafana/loki/pkg/loghttp"
loghttp_legacy "github.com/grafana/loki/pkg/loghttp/legacy"
"github.com/grafana/loki/pkg/logql"
"github.com/grafana/loki/pkg/logql/marshal"
marshal_legacy "github.com/grafana/loki/pkg/logql/marshal/legacy"
serverutil "github.com/grafana/loki/pkg/util/server"
)

const (
wsPingPeriod = 1 * time.Second

// StatusClientClosedRequest is the status code for when a client request cancellation of an http request
StatusClientClosedRequest = 499
)

type QueryResponse struct {
Expand All @@ -41,24 +38,24 @@ func (q *Querier) RangeQueryHandler(w http.ResponseWriter, r *http.Request) {

request, err := loghttp.ParseRangeQuery(r)
if err != nil {
writeError(httpgrpc.Errorf(http.StatusBadRequest, err.Error()), w)
serverutil.WriteError(httpgrpc.Errorf(http.StatusBadRequest, err.Error()), w)
return
}

if err := q.validateEntriesLimits(ctx, request.Limit); err != nil {
writeError(err, w)
serverutil.WriteError(err, w)
return
}

query := q.engine.NewRangeQuery(request.Query, request.Start, request.End, request.Step, request.Interval, request.Direction, request.Limit)
result, err := query.Exec(ctx)
if err != nil {
writeError(err, w)
serverutil.WriteError(err, w)
return
}

if err := marshal.WriteQueryResponseJSON(result, w); err != nil {
writeError(err, w)
serverutil.WriteError(err, w)
return
}
}
Expand All @@ -71,24 +68,24 @@ func (q *Querier) InstantQueryHandler(w http.ResponseWriter, r *http.Request) {

request, err := loghttp.ParseInstantQuery(r)
if err != nil {
writeError(httpgrpc.Errorf(http.StatusBadRequest, err.Error()), w)
serverutil.WriteError(httpgrpc.Errorf(http.StatusBadRequest, err.Error()), w)
return
}

if err := q.validateEntriesLimits(ctx, request.Limit); err != nil {
writeError(err, w)
serverutil.WriteError(err, w)
return
}

query := q.engine.NewInstantQuery(request.Query, request.Ts, request.Direction, request.Limit)
result, err := query.Exec(ctx)
if err != nil {
writeError(err, w)
serverutil.WriteError(err, w)
return
}

if err := marshal.WriteQueryResponseJSON(result, w); err != nil {
writeError(err, w)
serverutil.WriteError(err, w)
return
}
}
Expand All @@ -101,41 +98,41 @@ func (q *Querier) LogQueryHandler(w http.ResponseWriter, r *http.Request) {

request, err := loghttp.ParseRangeQuery(r)
if err != nil {
writeError(httpgrpc.Errorf(http.StatusBadRequest, err.Error()), w)
serverutil.WriteError(httpgrpc.Errorf(http.StatusBadRequest, err.Error()), w)
return
}
request.Query, err = parseRegexQuery(r)
if err != nil {
writeError(httpgrpc.Errorf(http.StatusBadRequest, err.Error()), w)
serverutil.WriteError(httpgrpc.Errorf(http.StatusBadRequest, err.Error()), w)
return
}

expr, err := logql.ParseExpr(request.Query)
if err != nil {
writeError(err, w)
serverutil.WriteError(err, w)
return
}

// short circuit metric queries
if _, ok := expr.(logql.SampleExpr); ok {
writeError(httpgrpc.Errorf(http.StatusBadRequest, "legacy endpoints only support %s result type", logql.ValueTypeStreams), w)
serverutil.WriteError(httpgrpc.Errorf(http.StatusBadRequest, "legacy endpoints only support %s result type", logql.ValueTypeStreams), w)
return
}

if err := q.validateEntriesLimits(ctx, request.Limit); err != nil {
writeError(err, w)
serverutil.WriteError(err, w)
return
}

query := q.engine.NewRangeQuery(request.Query, request.Start, request.End, request.Step, request.Interval, request.Direction, request.Limit)
result, err := query.Exec(ctx)
if err != nil {
writeError(err, w)
serverutil.WriteError(err, w)
return
}

if err := marshal_legacy.WriteQueryResponseJSON(result, w); err != nil {
writeError(err, w)
serverutil.WriteError(err, w)
return
}
}
Expand All @@ -144,13 +141,13 @@ func (q *Querier) LogQueryHandler(w http.ResponseWriter, r *http.Request) {
func (q *Querier) LabelHandler(w http.ResponseWriter, r *http.Request) {
req, err := loghttp.ParseLabelQuery(r)
if err != nil {
writeError(httpgrpc.Errorf(http.StatusBadRequest, err.Error()), w)
serverutil.WriteError(httpgrpc.Errorf(http.StatusBadRequest, err.Error()), w)
return
}

resp, err := q.Label(r.Context(), req)
if err != nil {
writeError(err, w)
serverutil.WriteError(err, w)
return
}

Expand All @@ -160,7 +157,7 @@ func (q *Querier) LabelHandler(w http.ResponseWriter, r *http.Request) {
err = marshal_legacy.WriteLabelResponseJSON(*resp, w)
}
if err != nil {
writeError(err, w)
serverutil.WriteError(err, w)
return
}
}
Expand All @@ -174,13 +171,13 @@ func (q *Querier) TailHandler(w http.ResponseWriter, r *http.Request) {

req, err := loghttp.ParseTailQuery(r)
if err != nil {
writeError(httpgrpc.Errorf(http.StatusBadRequest, err.Error()), w)
serverutil.WriteError(httpgrpc.Errorf(http.StatusBadRequest, err.Error()), w)
return
}

req.Query, err = parseRegexQuery(r)
if err != nil {
writeError(httpgrpc.Errorf(http.StatusBadRequest, err.Error()), w)
serverutil.WriteError(httpgrpc.Errorf(http.StatusBadRequest, err.Error()), w)
return
}

Expand Down Expand Up @@ -281,39 +278,23 @@ func (q *Querier) TailHandler(w http.ResponseWriter, r *http.Request) {
func (q *Querier) SeriesHandler(w http.ResponseWriter, r *http.Request) {
req, err := loghttp.ParseSeriesQuery(r)
if err != nil {
writeError(httpgrpc.Errorf(http.StatusBadRequest, err.Error()), w)
serverutil.WriteError(httpgrpc.Errorf(http.StatusBadRequest, err.Error()), w)
return
}

resp, err := q.Series(r.Context(), req)
if err != nil {
writeError(err, w)
serverutil.WriteError(err, w)
return
}

err = marshal.WriteSeriesResponseJSON(*resp, w)
if err != nil {
writeError(err, w)
serverutil.WriteError(err, w)
return
}
}

// NewPrepopulateMiddleware creates a middleware which will parse incoming http forms.
// This is important because some endpoints can POST x-www-form-urlencoded bodies instead of GET w/ query strings.
func NewPrepopulateMiddleware() middleware.Interface {
return middleware.Func(func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
err := req.ParseForm()
if err != nil {
writeError(httpgrpc.Errorf(http.StatusBadRequest, err.Error()), w)
return

}
next.ServeHTTP(w, req)
})
})
}

// parseRegexQuery parses regex and query querystring from httpRequest and returns the combined LogQL query.
// This is used only to keep regexp query string support until it gets fully deprecated.
func parseRegexQuery(httpRequest *http.Request) (string, error) {
Expand All @@ -329,23 +310,6 @@ func parseRegexQuery(httpRequest *http.Request) (string, error) {
return query, nil
}

func writeError(err error, w http.ResponseWriter) {
switch {
case err == context.Canceled:
http.Error(w, err.Error(), StatusClientClosedRequest)
case err == context.DeadlineExceeded:
http.Error(w, err.Error(), http.StatusGatewayTimeout)
case logql.IsParseError(err):
http.Error(w, err.Error(), http.StatusBadRequest)
default:
if grpcErr, ok := httpgrpc.HTTPResponseFromError(err); ok {
http.Error(w, string(grpcErr.Body), int(grpcErr.Code))
return
}
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}

func (q *Querier) validateEntriesLimits(ctx context.Context, limit uint32) error {
userID, err := user.ExtractOrgID(ctx)
if err != nil {
Expand Down
31 changes: 31 additions & 0 deletions pkg/util/server/error.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package server

import (
"context"
"net/http"

"github.com/weaveworks/common/httpgrpc"

"github.com/grafana/loki/pkg/logql"
)

// StatusClientClosedRequest is the status code for when a client request cancellation of an http request
const StatusClientClosedRequest = 499

// WriteError write a go error with the correct status code.
func WriteError(err error, w http.ResponseWriter) {
switch {
case err == context.Canceled:
http.Error(w, err.Error(), StatusClientClosedRequest)
case err == context.DeadlineExceeded:
http.Error(w, err.Error(), http.StatusGatewayTimeout)
case logql.IsParseError(err):
http.Error(w, err.Error(), http.StatusBadRequest)
default:
if grpcErr, ok := httpgrpc.HTTPResponseFromError(err); ok {
http.Error(w, string(grpcErr.Body), int(grpcErr.Code))
return
}
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}
Loading

0 comments on commit bce4470

Please sign in to comment.