diff --git a/CHANGELOG.next.asciidoc b/CHANGELOG.next.asciidoc index 639448917..b98d908e0 100644 --- a/CHANGELOG.next.asciidoc +++ b/CHANGELOG.next.asciidoc @@ -12,6 +12,7 @@ - Add error detail to catch-all HTTP error response. {pull}1854[1854] - Fix issue were errors where being ignored written to elasticsearch. {pull}1896[1896] - LoadServerLimits will not overwrite specified limits when loading default/agent number specified values. {issue}1841[1841] {pull}1912[1912] +- Use seperate rate limiters for internal and external API listeners. {issue}1859[1859] {pull}1904[1904] ==== New Features diff --git a/cmd/fleet/main.go b/cmd/fleet/main.go index 976564cb8..e61113820 100644 --- a/cmd/fleet/main.go +++ b/cmd/fleet/main.go @@ -907,10 +907,10 @@ func (f *FleetServer) runSubsystems(ctx context.Context, cfg *config.Config, g * ack := api.NewAckT(&cfg.Inputs[0].Server, bulker, f.cache) st := api.NewStatusT(&cfg.Inputs[0].Server, bulker, f.cache) - router := api.NewRouter(ctx, bulker, ct, et, at, ack, st, sm, tracer, f.bi) + router := api.NewRouter(&cfg.Inputs[0].Server, bulker, ct, et, at, ack, st, sm, tracer, f.bi) g.Go(loggedRunFunc(ctx, "Http server", func(ctx context.Context) error { - return api.Run(ctx, router, &cfg.Inputs[0].Server) + return router.Run(ctx) })) return err diff --git a/internal/pkg/api/error.go b/internal/pkg/api/error.go index a044d6cdb..032fba855 100644 --- a/internal/pkg/api/error.go +++ b/internal/pkg/api/error.go @@ -12,7 +12,6 @@ import ( "strings" "github.com/elastic/fleet-server/v7/internal/pkg/dl" - "github.com/elastic/fleet-server/v7/internal/pkg/limit" "github.com/elastic/fleet-server/v7/internal/pkg/logger" "github.com/pkg/errors" @@ -43,7 +42,6 @@ type HTTPErrResp struct { // NewHTTPErrResp creates an ErrResp from a go error func NewHTTPErrResp(err error) HTTPErrResp { - errTable := []struct { target error meta HTTPErrResp @@ -57,24 +55,6 @@ func NewHTTPErrResp(err error) HTTPErrResp { zerolog.WarnLevel, }, }, - { - limit.ErrRateLimit, - HTTPErrResp{ - http.StatusTooManyRequests, - "RateLimit", - "exceeded the rate limit", - zerolog.DebugLevel, - }, - }, - { - limit.ErrMaxLimit, - HTTPErrResp{ - http.StatusTooManyRequests, - "MaxLimit", - "exceeded the max limit", - zerolog.DebugLevel, - }, - }, { ErrAPIKeyNotEnabled, HTTPErrResp{ diff --git a/internal/pkg/api/handleAck.go b/internal/pkg/api/handleAck.go index c73808694..c8b0b3a64 100644 --- a/internal/pkg/api/handleAck.go +++ b/internal/pkg/api/handleAck.go @@ -25,7 +25,6 @@ import ( "github.com/elastic/fleet-server/v7/internal/pkg/config" "github.com/elastic/fleet-server/v7/internal/pkg/dl" "github.com/elastic/fleet-server/v7/internal/pkg/es" - "github.com/elastic/fleet-server/v7/internal/pkg/limit" "github.com/elastic/fleet-server/v7/internal/pkg/logger" "github.com/elastic/fleet-server/v7/internal/pkg/model" "github.com/elastic/fleet-server/v7/internal/pkg/policy" @@ -42,27 +41,21 @@ func (e *HTTPError) Error() string { type AckT struct { cfg *config.Server - limit *limit.Limiter bulk bulk.Bulk cache cache.Cache } func NewAckT(cfg *config.Server, bulker bulk.Bulk, cache cache.Cache) *AckT { - log.Info(). - Interface("limits", cfg.Limits.AckLimit). - Msg("Setting config ack_limits") - return &AckT{ cfg: cfg, bulk: bulker, cache: cache, - limit: limit.NewLimiter(&cfg.Limits.AckLimit), } } -func (rt Router) handleAcks(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { +//nolint:dupl // function body calls different internal handler then handleCheckin +func (rt *Router) handleAcks(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { start := time.Now() - id := ps.ByName("id") reqID := r.Header.Get(logger.HeaderRequestID) @@ -91,12 +84,6 @@ func (rt Router) handleAcks(w http.ResponseWriter, r *http.Request, ps httproute } func (ack *AckT) handleAcks(zlog *zerolog.Logger, w http.ResponseWriter, r *http.Request, id string) error { - limitF, err := ack.limit.Acquire() - if err != nil { - return err - } - defer limitF() - agent, err := authAgent(r, &id, ack.bulk, ack.cache) if err != nil { return err @@ -107,10 +94,6 @@ func (ack *AckT) handleAcks(zlog *zerolog.Logger, w http.ResponseWriter, r *http return ctx.Str(LogAccessAPIKeyID, agent.AccessAPIKeyID) }) - // Metrics; serenity now. - dfunc := cntAcks.IncStart() - defer dfunc() - return ack.processRequest(*zlog, w, r, agent) } diff --git a/internal/pkg/api/handleArtifacts.go b/internal/pkg/api/handleArtifacts.go index 13c0879c9..a5a362b3a 100644 --- a/internal/pkg/api/handleArtifacts.go +++ b/internal/pkg/api/handleArtifacts.go @@ -19,7 +19,6 @@ import ( "github.com/elastic/fleet-server/v7/internal/pkg/cache" "github.com/elastic/fleet-server/v7/internal/pkg/config" "github.com/elastic/fleet-server/v7/internal/pkg/dl" - "github.com/elastic/fleet-server/v7/internal/pkg/limit" "github.com/elastic/fleet-server/v7/internal/pkg/logger" "github.com/elastic/fleet-server/v7/internal/pkg/model" "github.com/elastic/fleet-server/v7/internal/pkg/throttle" @@ -46,24 +45,17 @@ type ArtifactT struct { bulker bulk.Bulk cache cache.Cache esThrottle *throttle.Throttle - limit *limit.Limiter } func NewArtifactT(cfg *config.Server, bulker bulk.Bulk, cache cache.Cache) *ArtifactT { - log.Info(). - Interface("limits", cfg.Limits.ArtifactLimit). - Int("maxParallel", defaultMaxParallel). - Msg("Artifact install limits") - return &ArtifactT{ bulker: bulker, cache: cache, - limit: limit.NewLimiter(&cfg.Limits.ArtifactLimit), esThrottle: throttle.NewThrottle(defaultMaxParallel), } } -func (rt Router) handleArtifacts(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { +func (rt *Router) handleArtifacts(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { start := time.Now() var ( @@ -112,12 +104,6 @@ func (rt Router) handleArtifacts(w http.ResponseWriter, r *http.Request, ps http } func (at ArtifactT) handleArtifacts(zlog *zerolog.Logger, r *http.Request, id, sha2 string) (io.Reader, error) { - limitF, err := at.limit.Acquire() - if err != nil { - return nil, err - } - defer limitF() - // Authenticate the APIKey; retrieve agent record. // Note: This is going to be a bit slow even if we hit the cache on the api key. // In order to validate that the agent still has that api key, we fetch the agent record from elastic. @@ -131,10 +117,6 @@ func (at ArtifactT) handleArtifacts(zlog *zerolog.Logger, r *http.Request, id, s return ctx.Str(LogAccessAPIKeyID, agent.AccessAPIKeyID) }) - // Metrics; serenity now. - dfunc := cntArtifacts.IncStart() - defer dfunc() - return at.processRequest(r.Context(), *zlog, agent, id, sha2) } diff --git a/internal/pkg/api/handleCheckin.go b/internal/pkg/api/handleCheckin.go index 2752dd147..96446c9be 100644 --- a/internal/pkg/api/handleCheckin.go +++ b/internal/pkg/api/handleCheckin.go @@ -22,7 +22,6 @@ import ( "github.com/elastic/fleet-server/v7/internal/pkg/checkin" "github.com/elastic/fleet-server/v7/internal/pkg/config" "github.com/elastic/fleet-server/v7/internal/pkg/dl" - "github.com/elastic/fleet-server/v7/internal/pkg/limit" "github.com/elastic/fleet-server/v7/internal/pkg/logger" "github.com/elastic/fleet-server/v7/internal/pkg/model" "github.com/elastic/fleet-server/v7/internal/pkg/monitor" @@ -48,7 +47,8 @@ const ( kEncodingGzip = "gzip" ) -func (rt Router) handleCheckin(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { +//nolint:dupl // function body calls different internal hander then handleAck +func (rt *Router) handleCheckin(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { start := time.Now() id := ps.ByName("id") @@ -65,12 +65,6 @@ func (rt Router) handleCheckin(w http.ResponseWriter, r *http.Request, ps httpro cntCheckin.IncError(err) resp := NewHTTPErrResp(err) - // Log this as warn for visibility that limit has been reached. - // This allows customers to tune the configuration on detection of threshold. - if errors.Is(err, limit.ErrMaxLimit) { - resp.Level = zerolog.WarnLevel - } - zlog.WithLevel(resp.Level). Err(err). Int(ECSHTTPResponseCode, resp.StatusCode). @@ -93,7 +87,6 @@ type CheckinT struct { ad *action.Dispatcher tr *action.TokenResolver bulker bulk.Bulk - limit *limit.Limiter } func NewCheckinT( @@ -107,14 +100,6 @@ func NewCheckinT( tr *action.TokenResolver, bulker bulk.Bulk, ) *CheckinT { - - log.Info(). - Interface("limits", cfg.Limits.CheckinLimit). - Dur("long_poll_timeout", cfg.Timeouts.CheckinLongPoll). - Dur("long_poll_timestamp", cfg.Timeouts.CheckinTimestamp). - Dur("long_poll_jitter", cfg.Timeouts.CheckinJitter). - Msg("Checkin install limits") - ct := &CheckinT{ verCon: verCon, cfg: cfg, @@ -124,7 +109,6 @@ func NewCheckinT( gcp: gcp, ad: ad, tr: tr, - limit: limit.NewLimiter(&cfg.Limits.CheckinLimit), bulker: bulker, } @@ -132,15 +116,8 @@ func NewCheckinT( } func (ct *CheckinT) handleCheckin(zlog *zerolog.Logger, w http.ResponseWriter, r *http.Request, id string) error { - start := time.Now() - limitF, err := ct.limit.Acquire() - if err != nil { - return err - } - defer limitF() - agent, err := authAgent(r, &id, ct.bulker, ct.cache) if err != nil { return err @@ -158,11 +135,6 @@ func (ct *CheckinT) handleCheckin(zlog *zerolog.Logger, w http.ResponseWriter, r // Safely check if the agent version is different, return empty string otherwise newVer := agent.CheckDifferentVersion(ver) - - // Metrics; serenity now. - dfunc := cntCheckin.IncStart() - defer dfunc() - return ct.processRequest(*zlog, w, r, start, agent, newVer) } diff --git a/internal/pkg/api/handleEnroll.go b/internal/pkg/api/handleEnroll.go index 9123723d6..f08e6d770 100644 --- a/internal/pkg/api/handleEnroll.go +++ b/internal/pkg/api/handleEnroll.go @@ -16,7 +16,6 @@ import ( "github.com/elastic/fleet-server/v7/internal/pkg/cache" "github.com/elastic/fleet-server/v7/internal/pkg/config" "github.com/elastic/fleet-server/v7/internal/pkg/dl" - "github.com/elastic/fleet-server/v7/internal/pkg/limit" "github.com/elastic/fleet-server/v7/internal/pkg/logger" "github.com/elastic/fleet-server/v7/internal/pkg/model" "github.com/elastic/fleet-server/v7/internal/pkg/rollback" @@ -49,25 +48,19 @@ type EnrollerT struct { cfg *config.Server bulker bulk.Bulk cache cache.Cache - limit *limit.Limiter } func NewEnrollerT(verCon version.Constraints, cfg *config.Server, bulker bulk.Bulk, c cache.Cache) (*EnrollerT, error) { - log.Info(). - Interface("limits", cfg.Limits.EnrollLimit). - Msg("Setting config enroll_limit") - return &EnrollerT{ verCon: verCon, cfg: cfg, - limit: limit.NewLimiter(&cfg.Limits.EnrollLimit), bulker: bulker, cache: c, }, nil } -func (rt Router) handleEnroll(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { +func (rt *Router) handleEnroll(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { start := time.Now() // Work around wonky router rule @@ -129,13 +122,6 @@ func (rt Router) handleEnroll(w http.ResponseWriter, r *http.Request, ps httprou } func (et *EnrollerT) handleEnroll(rb *rollback.Rollback, zlog *zerolog.Logger, w http.ResponseWriter, r *http.Request) (*EnrollResponse, error) { - - limitF, err := et.limit.Acquire() - if err != nil { - return nil, err - } - defer limitF() - key, err := authAPIKey(r, et.bulker, et.cache) if err != nil { return nil, err @@ -151,10 +137,6 @@ func (et *EnrollerT) handleEnroll(rb *rollback.Rollback, zlog *zerolog.Logger, w return nil, err } - // Metrics; serenity now. - dfunc := cntEnroll.IncStart() - defer dfunc() - return et.processRequest(rb, *zlog, w, r, key.ID, ver) } diff --git a/internal/pkg/api/handleStatus.go b/internal/pkg/api/handleStatus.go index 8c242058c..ff9dce021 100644 --- a/internal/pkg/api/handleStatus.go +++ b/internal/pkg/api/handleStatus.go @@ -16,7 +16,6 @@ import ( "github.com/elastic/fleet-server/v7/internal/pkg/bulk" "github.com/elastic/fleet-server/v7/internal/pkg/cache" "github.com/elastic/fleet-server/v7/internal/pkg/config" - "github.com/elastic/fleet-server/v7/internal/pkg/limit" "github.com/elastic/fleet-server/v7/internal/pkg/logger" "github.com/julienschmidt/httprouter" @@ -33,7 +32,6 @@ type AuthFunc func(*http.Request) (*apikey.APIKey, error) type StatusT struct { cfg *config.Server - limit *limit.Limiter bulk bulk.Bulk cache cache.Cache authfn AuthFunc @@ -42,15 +40,10 @@ type StatusT struct { type OptFunc func(*StatusT) func NewStatusT(cfg *config.Server, bulker bulk.Bulk, cache cache.Cache, opts ...OptFunc) *StatusT { - log.Info(). - Interface("limits", cfg.Limits.StatusLimit). - Msg("Setting config status_limits") - st := &StatusT{ cfg: cfg, bulk: bulker, cache: cache, - limit: limit.NewLimiter(&cfg.Limits.StatusLimit), } st.authfn = st.authenticate @@ -68,14 +61,7 @@ func (st StatusT) authenticate(r *http.Request) (*apikey.APIKey, error) { return authAPIKey(r, st.bulk, st.cache) } -func (st StatusT) handleStatus(_ *zerolog.Logger, r *http.Request, rt *Router) (resp StatusResponse, status proto.StateObserved_Status, err error) { - limitF, err := st.limit.Acquire() - // When failing to acquire a limiter send an error response. - if err != nil { - return - } - defer limitF() - +func (st StatusT) handleStatus(_ *zerolog.Logger, r *http.Request, rt *Router) (resp StatusResponse, status proto.StateObserved_Status) { authed := true if _, aerr := st.authfn(r); aerr != nil { log.Debug().Err(aerr).Msg("unauthenticated status request, return short status response only") @@ -96,16 +82,12 @@ func (st StatusT) handleStatus(_ *zerolog.Logger, r *http.Request, rt *Router) ( } } - return resp, status, nil + return resp, status } -func (rt Router) handleStatus(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { +func (rt *Router) handleStatus(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { start := time.Now() - - dfunc := cntStatus.IncStart() - defer dfunc() - reqID := r.Header.Get(logger.HeaderRequestID) zlog := log.With(). @@ -113,22 +95,7 @@ func (rt Router) handleStatus(w http.ResponseWriter, r *http.Request, _ httprout Str("mod", kStatusMod). Logger() - resp, status, err := rt.st.handleStatus(&zlog, r, &rt) - if err != nil { - cntStatus.IncError(err) - resp := NewHTTPErrResp(err) - - zlog.WithLevel(resp.Level). - Err(err). - Int(ECSHTTPResponseCode, resp.StatusCode). - Int64(ECSEventDuration, time.Since(start).Nanoseconds()). - Msg("fail status") - - if rerr := resp.Write(w); rerr != nil { - zlog.Error().Err(rerr).Msg("fail writing error response") - } - return - } + resp, status := rt.st.handleStatus(&zlog, r, rt) data, err := json.Marshal(&resp) if err != nil { diff --git a/internal/pkg/api/router.go b/internal/pkg/api/router.go index bfa34dc05..7d219cbdf 100644 --- a/internal/pkg/api/router.go +++ b/internal/pkg/api/router.go @@ -6,10 +6,16 @@ package api import ( "context" + "crypto/tls" + "errors" + "net" "net/http" + "github.com/elastic/elastic-agent-libs/transport/tlscommon" "github.com/elastic/fleet-server/v7/internal/pkg/build" "github.com/elastic/fleet-server/v7/internal/pkg/bulk" + "github.com/elastic/fleet-server/v7/internal/pkg/config" + "github.com/elastic/fleet-server/v7/internal/pkg/limit" "github.com/elastic/fleet-server/v7/internal/pkg/logger" "github.com/elastic/fleet-server/v7/internal/pkg/policy" "github.com/julienschmidt/httprouter" @@ -27,7 +33,8 @@ const ( ) type Router struct { - ctx context.Context + ctx context.Context // used only by handleEnroll, set at start of Run func + cfg *config.Server bulker bulk.Bulk ct *CheckinT et *EnrollerT @@ -35,12 +42,13 @@ type Router struct { ack *AckT st *StatusT sm policy.SelfMonitor + tracer *apm.Tracer bi build.Info } -func NewRouter(ctx context.Context, bulker bulk.Bulk, ct *CheckinT, et *EnrollerT, at *ArtifactT, ack *AckT, st *StatusT, sm policy.SelfMonitor, tracer *apm.Tracer, bi build.Info) *httprouter.Router { - r := Router{ - ctx: ctx, +func NewRouter(cfg *config.Server, bulker bulk.Bulk, ct *CheckinT, et *EnrollerT, at *ArtifactT, ack *AckT, st *StatusT, sm policy.SelfMonitor, tracer *apm.Tracer, bi build.Info) *Router { + rt := &Router{ + cfg: cfg, bulker: bulker, ct: ct, et: et, @@ -48,9 +56,18 @@ func NewRouter(ctx context.Context, bulker bulk.Bulk, ct *CheckinT, et *Enroller at: at, ack: ack, st: st, + tracer: tracer, bi: bi, } + return rt +} + +// Create a new httprouter, the passed addr is only added as a label in log messages +func (rt *Router) newHTTPRouter(addr string) *httprouter.Router { + log.Info().Str("addr", addr).Interface("limits", rt.cfg.Limits).Msg("fleet-server creating new limiter") + limiter := limit.NewHTTPWrapper(addr, &rt.cfg.Limits) + routes := []struct { method string path string @@ -59,43 +76,43 @@ func NewRouter(ctx context.Context, bulker bulk.Bulk, ct *CheckinT, et *Enroller { http.MethodGet, RouteStatus, - r.handleStatus, + limiter.WrapStatus(rt.handleStatus, &cntStatus), }, { http.MethodPost, RouteEnroll, - r.handleEnroll, + limiter.WrapEnroll(rt.handleEnroll, &cntEnroll), }, { http.MethodPost, RouteCheckin, - r.handleCheckin, + limiter.WrapCheckin(rt.handleCheckin, &cntCheckin), }, { http.MethodPost, RouteAcks, - r.handleAcks, + limiter.WrapAck(rt.handleAcks, &cntAcks), }, { http.MethodGet, RouteArtifacts, - r.handleArtifacts, + limiter.WrapArtifact(rt.handleArtifacts, &cntArtifacts), }, } router := httprouter.New() - // Install routes for _, rte := range routes { log.Info(). + Str("addr", addr). Str("method", rte.method). Str("path", rte.path). Msg("fleet-server route added") handler := rte.handler - if tracer != nil { + if rt.tracer != nil { handler = apmhttprouter.Wrap( - rte.handler, rte.path, apmhttprouter.WithTracer(tracer), + rte.handler, rte.path, apmhttprouter.WithTracer(rt.tracer), ) } router.Handle( @@ -104,8 +121,120 @@ func NewRouter(ctx context.Context, bulker bulk.Bulk, ct *CheckinT, et *Enroller logger.HTTPHandler(handler), ) } + log.Info().Str("addr", addr).Msg("fleet-server routes set up") + return router +} - log.Info().Msg("fleet-server routes set up") +// Run starts the api server on the listeners configured in the config. +// Each listener has a unique limit.Limiter to allow for non-global rate limits. +func (rt *Router) Run(ctx context.Context) error { + rt.ctx = ctx - return router + listeners := rt.cfg.BindEndpoints() + rdto := rt.cfg.Timeouts.Read + wrto := rt.cfg.Timeouts.Write + idle := rt.cfg.Timeouts.Idle + rdhr := rt.cfg.Timeouts.ReadHeader + mhbz := rt.cfg.Limits.MaxHeaderByteSize + bctx := func(net.Listener) context.Context { return ctx } + + errChan := make(chan error) + baseCtx, cancel := context.WithCancel(ctx) + defer cancel() + + for _, addr := range listeners { + log.Info(). + Str("bind", addr). + Dur("rdTimeout", rdto). + Dur("wrTimeout", wrto). + Msg("server listening") + + server := http.Server{ + Addr: addr, + ReadTimeout: rdto, + WriteTimeout: wrto, + IdleTimeout: idle, + ReadHeaderTimeout: rdhr, + Handler: rt.newHTTPRouter(addr), // Note that we use a different router for each listener instead of wrapping with different middleware instances as it is cleaner to do + BaseContext: bctx, + ConnState: diagConn, + MaxHeaderBytes: mhbz, + ErrorLog: errLogger(), + } + + forceCh := make(chan struct{}) + defer close(forceCh) + + // handler to close server + go func() { + select { + case <-ctx.Done(): + log.Debug().Msg("force server close on ctx.Done()") + err := server.Close() + if err != nil { + log.Error().Err(err).Msg("error while closing server") + } + case <-forceCh: + log.Debug().Msg("go routine forced closed on exit") + } + }() + + var listenCfg net.ListenConfig + + ln, err := listenCfg.Listen(ctx, "tcp", addr) + if err != nil { + return err + } + + // Bind the deferred Close() to the stack variable to handle case where 'ln' is wrapped + defer func() { + err := ln.Close() + if err != nil { + log.Error().Err(err).Msg("error while closing listener.") + } + }() + + // Conn Limiter must be before the TLS handshake in the stack; + // The server should not eat the cost of the handshake if there + // is no capacity to service the connection. + // Also, it appears the HTTP2 implementation depends on the tls.Listener + // being at the top of the stack. + ln = wrapConnLimitter(ctx, ln, rt.cfg) + + if rt.cfg.TLS != nil && rt.cfg.TLS.IsEnabled() { + commonTLSCfg, err := tlscommon.LoadTLSServerConfig(rt.cfg.TLS) + if err != nil { + return err + } + server.TLSConfig = commonTLSCfg.BuildServerConfig(rt.cfg.Host) + + // Must enable http/2 in the configuration explicitly. + // (see https://golang.org/pkg/net/http/#Server.Serve) + server.TLSConfig.NextProtos = []string{"h2", "http/1.1"} + + ln = tls.NewListener(ln, server.TLSConfig) + + } else { + log.Warn().Msg("Exposed over insecure HTTP; enablement of TLS is strongly recommended") + } + + log.Debug().Msgf("Listening on %s", addr) + + go func(_ context.Context, errChan chan error, ln net.Listener) { + if err := server.Serve(ln); err != nil && !errors.Is(err, http.ErrServerClosed) { + errChan <- err + } + }(baseCtx, errChan, ln) + + } + + select { + case err := <-errChan: + if !errors.Is(err, context.Canceled) { + return err + } + case <-baseCtx.Done(): + } + + return nil } diff --git a/internal/pkg/api/server_test.go b/internal/pkg/api/router_test.go similarity index 95% rename from internal/pkg/api/server_test.go rename to internal/pkg/api/router_test.go index 44a223c6c..a874958fd 100644 --- a/internal/pkg/api/server_test.go +++ b/internal/pkg/api/router_test.go @@ -48,13 +48,13 @@ func TestRun(t *testing.T) { et, err := NewEnrollerT(verCon, cfg, nil, c) require.NoError(t, err) - router := NewRouter(ctx, bulker, ct, et, nil, nil, nil, nil, nil, fbuild.Info{}) + router := NewRouter(cfg, bulker, ct, et, nil, nil, nil, nil, nil, fbuild.Info{}) errCh := make(chan error) var wg sync.WaitGroup wg.Add(1) go func() { - err = Run(ctx, router, cfg) + err = router.Run(ctx) wg.Done() }() var errFromChan error diff --git a/internal/pkg/api/server.go b/internal/pkg/api/server.go index 8787b2c34..32ab05358 100644 --- a/internal/pkg/api/server.go +++ b/internal/pkg/api/server.go @@ -6,13 +6,10 @@ package api import ( "context" - "crypto/tls" - "errors" slog "log" "net" "net/http" - "github.com/elastic/elastic-agent-libs/transport/tlscommon" "github.com/elastic/fleet-server/v7/internal/pkg/config" "github.com/elastic/fleet-server/v7/internal/pkg/limit" "github.com/elastic/fleet-server/v7/internal/pkg/logger" @@ -39,117 +36,6 @@ func diagConn(c net.Conn, s http.ConnState) { } } -// Run runs the passed router with the config. -func Run(ctx context.Context, router http.Handler, cfg *config.Server) error { - listeners := cfg.BindEndpoints() - rdto := cfg.Timeouts.Read - wrto := cfg.Timeouts.Write - idle := cfg.Timeouts.Idle - rdhr := cfg.Timeouts.ReadHeader - mhbz := cfg.Limits.MaxHeaderByteSize - bctx := func(net.Listener) context.Context { return ctx } - - errChan := make(chan error) - cancelCtx, cancel := context.WithCancel(ctx) - defer cancel() - - for _, addr := range listeners { - log.Info(). - Str("bind", addr). - Dur("rdTimeout", rdto). - Dur("wrTimeout", wrto). - Msg("server listening") - - server := http.Server{ - Addr: addr, - ReadTimeout: rdto, - WriteTimeout: wrto, - IdleTimeout: idle, - ReadHeaderTimeout: rdhr, - Handler: router, - BaseContext: bctx, - ConnState: diagConn, - MaxHeaderBytes: mhbz, - ErrorLog: errLogger(), - } - - forceCh := make(chan struct{}) - defer close(forceCh) - - // handler to close server - go func() { - select { - case <-ctx.Done(): - log.Debug().Msg("force server close on ctx.Done()") - err := server.Close() - if err != nil { - log.Error().Err(err).Msg("error while closing server") - } - case <-forceCh: - log.Debug().Msg("go routine forced closed on exit") - } - }() - - var listenCfg net.ListenConfig - - ln, err := listenCfg.Listen(ctx, "tcp", addr) - if err != nil { - return err - } - - // Bind the deferred Close() to the stack variable to handle case where 'ln' is wrapped - defer func() { - err := ln.Close() - if err != nil { - log.Error().Err(err).Msg("error while closing listener.") - } - }() - - // Conn Limiter must be before the TLS handshake in the stack; - // The server should not eat the cost of the handshake if there - // is no capacity to service the connection. - // Also, it appears the HTTP2 implementation depends on the tls.Listener - // being at the top of the stack. - ln = wrapConnLimitter(ctx, ln, cfg) - - if cfg.TLS != nil && cfg.TLS.IsEnabled() { - commonTLSCfg, err := tlscommon.LoadTLSServerConfig(cfg.TLS) - if err != nil { - return err - } - server.TLSConfig = commonTLSCfg.BuildServerConfig(cfg.Host) - - // Must enable http/2 in the configuration explicitly. - // (see https://golang.org/pkg/net/http/#Server.Serve) - server.TLSConfig.NextProtos = []string{"h2", "http/1.1"} - - ln = tls.NewListener(ln, server.TLSConfig) - - } else { - log.Warn().Msg("Exposed over insecure HTTP; enablement of TLS is strongly recommended") - } - - log.Debug().Msgf("Listening on %s", addr) - - go func(_ context.Context, errChan chan error, ln net.Listener) { - if err := server.Serve(ln); err != nil && !errors.Is(err, http.ErrServerClosed) { - errChan <- err - } - }(cancelCtx, errChan, ln) - - } - - select { - case err := <-errChan: - if !errors.Is(err, context.Canceled) { - return err - } - case <-cancelCtx.Done(): - } - - return nil -} - func wrapConnLimitter(_ context.Context, ln net.Listener, cfg *config.Server) net.Listener { hardLimit := cfg.Limits.MaxConnections diff --git a/internal/pkg/limit/error.go b/internal/pkg/limit/error.go new file mode 100644 index 000000000..65bea753b --- /dev/null +++ b/internal/pkg/limit/error.go @@ -0,0 +1,51 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. + +package limit + +import ( + "encoding/json" + "errors" + "net/http" + + "github.com/rs/zerolog/log" +) + +var ( + ErrRateLimit = errors.New("rate limit") + ErrMaxLimit = errors.New("max limit") +) + +// writeError recreates the behaviour of api/error.go. +// It is defined separately here to stop a circular import +func writeError(w http.ResponseWriter, err error) error { + resp := struct { + Status int `json:"statusCode"` + Error string `json:"error"` + Message string `json:"message"` + }{ + Status: http.StatusTooManyRequests, + Error: "UnknownLimiterError", + Message: "unknown limiter error encountered", + } + switch { + case errors.Is(err, ErrRateLimit): + resp.Error = "RateLimit" + resp.Message = "exceeded the rate limit" + case errors.Is(err, ErrMaxLimit): + resp.Error = "MaxLimit" + resp.Message = "exceeded the max limit" + default: + log.Error().Err(err).Msg("Encountered unknown limiter error") + } + p, wErr := json.Marshal(&resp) + if wErr != nil { + return wErr + } + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.Header().Set("X-Content-Type-Options", "nosniff") + w.WriteHeader(http.StatusTooManyRequests) + _, wErr = w.Write(p) + return wErr +} diff --git a/internal/pkg/limit/error_test.go b/internal/pkg/limit/error_test.go new file mode 100644 index 000000000..829e99e79 --- /dev/null +++ b/internal/pkg/limit/error_test.go @@ -0,0 +1,56 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. + +package limit + +import ( + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestWriteError(t *testing.T) { + tests := []struct { + name string + err error + want string + }{{ + name: "unknown", + err: errors.New("unknown"), + want: "UnknownLimiterError", + }, { + name: "rate limit", + err: ErrRateLimit, + want: "RateLimit", + }, { + name: "max limit", + err: ErrMaxLimit, + want: "MaxLimit", + }} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + + err := writeError(w, tt.err) + require.NoError(t, err) + resp := w.Result() + defer resp.Body.Close() + require.Equal(t, http.StatusTooManyRequests, resp.StatusCode) + + var body struct { + Status int `json:"statusCode"` + Error string `json:"error"` + } + dec := json.NewDecoder(resp.Body) + err = dec.Decode(&body) + require.NoError(t, err) + require.Equal(t, http.StatusTooManyRequests, body.Status) + require.Equal(t, tt.want, body.Error) + }) + } +} diff --git a/internal/pkg/limit/httpwrapper.go b/internal/pkg/limit/httpwrapper.go new file mode 100644 index 000000000..5f9860c0b --- /dev/null +++ b/internal/pkg/limit/httpwrapper.go @@ -0,0 +1,65 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. + +package limit + +import ( + "github.com/elastic/fleet-server/v7/internal/pkg/config" + "github.com/julienschmidt/httprouter" + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" +) + +// HTTPWrapper enforces rate limits for each API endpoint. +type HTTPWrapper struct { + checkin *limiter + artifact *limiter + enroll *limiter + ack *limiter + status *limiter + log zerolog.Logger +} + +// Create a new HTTPWrapper using the specified limits. +func NewHTTPWrapper(addr string, cfg *config.ServerLimits) *HTTPWrapper { + return &HTTPWrapper{ + checkin: newLimiter(&cfg.CheckinLimit), + artifact: newLimiter(&cfg.ArtifactLimit), + enroll: newLimiter(&cfg.EnrollLimit), + ack: newLimiter(&cfg.AckLimit), + status: newLimiter(&cfg.StatusLimit), + log: log.With().Str("addr", addr).Logger(), + } +} + +// WhapCheckin wraps the checkin handler with the rate limiter and tracks statistics for the endpoint. +func (l *HTTPWrapper) WrapCheckin(h httprouter.Handle, i StatIncer) httprouter.Handle { + return l.checkin.wrap(l.log.With().Str("route", "checkin").Logger(), zerolog.WarnLevel, h, i) +} + +// WhapArtifact wraps the artifact handler with the rate limiter and tracks statistics for the endpoint. +func (l *HTTPWrapper) WrapArtifact(h httprouter.Handle, i StatIncer) httprouter.Handle { + return l.artifact.wrap(l.log.With().Str("route", "artifact").Logger(), zerolog.DebugLevel, h, i) +} + +// WhapEnroll wraps the enroll handler with the rate limiter and tracks statistics for the endpoint. +func (l *HTTPWrapper) WrapEnroll(h httprouter.Handle, i StatIncer) httprouter.Handle { + return l.enroll.wrap(l.log.With().Str("route", "enroll").Logger(), zerolog.DebugLevel, h, i) +} + +// WhapAck wraps the ack handler with the rate limiter and tracks statistics for the endpoint. +func (l *HTTPWrapper) WrapAck(h httprouter.Handle, i StatIncer) httprouter.Handle { + return l.ack.wrap(l.log.With().Str("route", "ack").Logger(), zerolog.DebugLevel, h, i) +} + +// WhapStatus wraps the checkin handler with the rate limiter and tracks statistics for the endpoint. +func (l *HTTPWrapper) WrapStatus(h httprouter.Handle, i StatIncer) httprouter.Handle { + return l.status.wrap(l.log.With().Str("route", "status").Logger(), zerolog.DebugLevel, h, i) +} + +// StatIncer is the interface used to count statistics associated with an endpoint. +type StatIncer interface { + IncError(error) + IncStart() func() +} diff --git a/internal/pkg/limit/limiter.go b/internal/pkg/limit/limiter.go index 05a4a8262..98aabd540 100644 --- a/internal/pkg/limit/limiter.go +++ b/internal/pkg/limit/limiter.go @@ -2,38 +2,34 @@ // or more contributor license agreements. Licensed under the Elastic License; // you may not use this file except in compliance with the Elastic License. -// Package limit provides the ability to set the maximum connections that a server should handle +// Package limit provides the ability to rate limit the api server. package limit import ( - "errors" + "net/http" "time" "github.com/elastic/fleet-server/v7/internal/pkg/config" + "github.com/julienschmidt/httprouter" + "github.com/rs/zerolog" "golang.org/x/sync/semaphore" "golang.org/x/time/rate" ) -type Limiter struct { +type releaseFunc func() + +type limiter struct { rateLimit *rate.Limiter maxLimit *semaphore.Weighted } -type ReleaseFunc func() - -var ( - ErrRateLimit = errors.New("rate limit") - ErrMaxLimit = errors.New("max limit") -) - -func NewLimiter(cfg *config.Limit) *Limiter { - +func newLimiter(cfg *config.Limit) *limiter { if cfg == nil { - return &Limiter{} + return &limiter{} } - l := &Limiter{} + l := &limiter{} if cfg.Interval != time.Duration(0) { l.rateLimit = rate.NewLimiter(rate.Every(cfg.Interval), cfg.Burst) @@ -46,7 +42,7 @@ func NewLimiter(cfg *config.Limit) *Limiter { return l } -func (l *Limiter) Acquire() (ReleaseFunc, error) { +func (l *limiter) acquire() (releaseFunc, error) { releaseFunc := noop if l.rateLimit != nil && !l.rateLimit.Allow() { @@ -63,11 +59,30 @@ func (l *Limiter) Acquire() (ReleaseFunc, error) { return releaseFunc, nil } -func (l *Limiter) release() { +func (l *limiter) release() { if l.maxLimit != nil { l.maxLimit.Release(1) } } +func (l *limiter) wrap(logger zerolog.Logger, level zerolog.Level, h httprouter.Handle, i StatIncer) httprouter.Handle { + return func(w http.ResponseWriter, r *http.Request, p httprouter.Params) { + dfunc := i.IncStart() + defer dfunc() + + lf, err := l.acquire() + if err != nil { + logger.WithLevel(level).Err(err).Msg("limit reached") + if wErr := writeError(w, err); wErr != nil { + logger.Error().Err(wErr).Msg("fail writing error response") + } + i.IncError(err) + return + } + defer lf() + h(w, r, p) + } +} + func noop() { } diff --git a/internal/pkg/limit/limiter_test.go b/internal/pkg/limit/limiter_test.go new file mode 100644 index 000000000..6c4df66ef --- /dev/null +++ b/internal/pkg/limit/limiter_test.go @@ -0,0 +1,97 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. + +package limit + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/julienschmidt/httprouter" + "github.com/rs/zerolog" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "golang.org/x/sync/semaphore" + "golang.org/x/time/rate" +) + +type mockIncer struct { + mock.Mock +} + +func (m *mockIncer) IncError(err error) { + m.Called(err) +} + +func (m *mockIncer) IncStart() func() { + args := m.Called() + return args.Get(0).(func()) +} + +func stubHandle() httprouter.Handle { + return func(w http.ResponseWriter, r *http.Request, p httprouter.Params) { + w.WriteHeader(http.StatusOK) + } +} + +func TestWrap(t *testing.T) { + t.Run("no limits reached", func(t *testing.T) { + var b bool + var fdec = func() { b = true } + i := &mockIncer{} + i.On("IncStart").Return(fdec).Once() + l := &limiter{} + + h := l.wrap(zerolog.Nop(), zerolog.DebugLevel, stubHandle(), i) + w := httptest.NewRecorder() + h(w, &http.Request{}, httprouter.Params{}) + + resp := w.Result() + resp.Body.Close() + i.AssertExpectations(t) + assert.True(t, b, "expected dec func to have been called") + assert.Equal(t, http.StatusOK, resp.StatusCode) + }) + t.Run("max limit reached", func(t *testing.T) { + var b bool + var fdec = func() { b = true } + i := &mockIncer{} + i.On("IncStart").Return(fdec).Once() + i.On("IncError", ErrMaxLimit).Once() + l := &limiter{ + maxLimit: semaphore.NewWeighted(0), + } + + h := l.wrap(zerolog.Nop(), zerolog.DebugLevel, stubHandle(), i) + w := httptest.NewRecorder() + h(w, &http.Request{}, httprouter.Params{}) + + resp := w.Result() + resp.Body.Close() + i.AssertExpectations(t) + assert.True(t, b, "expected dec func to have been called") + assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode) + }) + t.Run("rate limit reached", func(t *testing.T) { + var b bool + var fdec = func() { b = true } + i := &mockIncer{} + i.On("IncStart").Return(fdec).Once() + i.On("IncError", ErrRateLimit).Once() + l := &limiter{ + rateLimit: rate.NewLimiter(rate.Limit(0), 0), + } + + h := l.wrap(zerolog.Nop(), zerolog.DebugLevel, stubHandle(), i) + w := httptest.NewRecorder() + h(w, &http.Request{}, httprouter.Params{}) + + resp := w.Result() + resp.Body.Close() + i.AssertExpectations(t) + assert.True(t, b, "expected dec func to have been called") + assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode) + }) +}