diff --git a/.commitlintrc.js b/.commitlintrc.js new file mode 100644 index 0000000..626b89b --- /dev/null +++ b/.commitlintrc.js @@ -0,0 +1,9 @@ +module.exports = { + extends: ["@commitlint/config-conventional"], + rules: { + "body-max-line-length": [2, "always", 200], + "subject-case": [2, "never", ["start-case", "pascal-case", "upper-case"]], + "subject-empty": [1, "never"], + "type-empty": [1, "never"], + }, +}; diff --git a/.golangci.yaml b/.golangci.yaml new file mode 100644 index 0000000..1112feb --- /dev/null +++ b/.golangci.yaml @@ -0,0 +1,109 @@ +issues: + max-same-issues: 0 + exclude-use-default: false + exclude-rules: + - path: '_test\.go' + linters: + - bodyclose + - gocognit + - goconst + - gocyclo + - gosec + - lll + - prealloc + + # Overly picky + - linters: [revive] + text: 'package-comments' + - linters: [revive] + text: 'if-return' + + # Duplicates of errcheck + - linters: [gosec] + text: 'G104: Errors unhandled' + - linters: [gosec] + text: 'G307: Deferring unsafe method' + # Not a good rule since it ignores defaults + - linters: [gosec] + text: 'G112: Potential Slowloris Attack because ReadHeaderTimeout is not configured in the http.Server' + + # Contexts are best assigned defensively + - linters: [ineffassign] + text: 'ineffectual assignment to `ctx`' + - linters: [staticcheck] + text: 'SA4006: this value of `ctx` is never used' + + # Irrelevant for test examples + - linters: [gocritic] + path: example_test\.go + text: 'exitAfterDefer' + +run: + timeout: 5m + +linters: + enable: + - bodyclose + - errcheck + - errchkjson + - exportloopref + - goconst + - gocognit + - gocritic + - gocyclo + - godot + - gofumpt + - goimports + - gosec + - lll + - misspell + - nakedret + - nolintlint + - prealloc + - revive + - unconvert + - unparam + +linters-settings: + errcheck: + exclude-functions: + # Errors we wouldn't act on after checking + - (*database/sql.DB).Close + - (*database/sql.Rows).Close + - (io.Closer).Close + - (*os.File).Close + - 
(net/http.ResponseWriter).Write + + # Handled by errchkjson + - encoding/json.Marshal + - encoding/json.MarshalIndent + - (*encoding/json.Encoder).Encode + + gocognit: + min-complexity: 10 + + goconst: + min-len: 0 + min-occurrences: 3 + + gocritic: + disabled-checks: + - appendAssign + + gocyclo: + min-complexity: 10 + + goimports: + local-prefixes: github.com/morningconsult/grace + + golint: + min-confidence: 0 + + govet: + check-shadowing: true + + nakedret: + max-func-lines: 0 + + revive: + confidence: 0 diff --git a/README.md b/README.md index 7ac568b..eddb5f7 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,209 @@ -# .github -Meta repository for all Morning Consult projects +# grace +[![Go Reference](https://pkg.go.dev/badge/github.com/morningconsult/grace.svg)](https://pkg.go.dev/github.com/morningconsult/grace) -# What is this? +A Go library for starting and stopping applications gracefully. -It is supposed to apply default metadata files to all projects. I learned about it from [terraform-aws-modules](https://github.com/terraform-aws-modules/.github) \ No newline at end of file +Grace facilitates gracefully starting and stopping a Go web application. +It helps with waiting for dependencies - such as sidecar upstreams - to be available +and handling operating system signals to shut down. + +Requires Go >= 1.21. + +## Usage + +In your project directory: + +```shell +go get github.com/morningconsult/grace +``` + +## Features + +* Graceful handling of upstream dependencies that might not be available when + your application starts +* Graceful shutdown of multiple HTTP servers when operating system signals are + received, allowing in-flight requests to finish. +* Automatic startup and control of a dedicated health check HTTP server. +* Passing of signal context to other non-HTTP components with a generic + function signature. 
+ +### Gracefully shutting down an application + +Many HTTP applications need to handle graceful shutdowns so that in-flight requests +are not terminated, leaving an unsatisfactory experience for the requester. Grace +helps with this by catching operating system signals and allowing your HTTP servers +to finish processing requests before being forcefully stopped. + +To use this, add something similar to the following example to the end of your +application's entrypoint. `grace.Run` should be returned in your entrypoint/main +function. + +An absolute minimal configuration to get a graceful server would be the following: + +```go +ctx := context.Background() +httpHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Write([]byte("hello there")) +}) + +// This is the absolute minimum configuration necessary to have a gracefully +// shutdown server. +g := grace.New(ctx, grace.WithServer("localhost:9090", httpHandler)) +err := g.Run(ctx) +``` + +Additionally, it will also handle setting up a health check server with any check functions +necessary. The health server will be shut down as soon as a signal is caught. This +helps to ensure that the orchestration system running your application marks it as unhealthy +and stops sending it any new requests, while the in-flight requests to your actual +application are still allowed to finish gracefully. 
+ +An minimal example with a health check server and your application server would be similar +to the following: + +```go +httpHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Write([]byte("hello there")) +}) + +dbPinger := grace.HealthCheckerFunc(func(ctx context.Context) error { + // ping database + return nil +}) + +g := grace.New( + ctx, + grace.WithHealthCheckServer("localhost:9092", grace.WithCheckers(dbPinger)), + grace.WithServer("localhost:9090", httpHandler, grace.WithServerName("api")), +) +``` + +A full example with multiple servers, background jobs, and health checks: + +```go +// Set up database pools, other application things, server handlers, +// etc. +httpHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Write([]byte("hello there")) +}) + +metricsHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Write([]byte("here are the metrics")) +}) + +dbPinger := grace.HealthCheckerFunc(func(ctx context.Context) error { + // ping database + return nil +}) + +redisPinger := grace.HealthCheckerFunc(func(ctx context.Context) error { + // ping redis. + return nil +}) + +bgWorker := func(ctx context.Context) error { + // Start some background work + return nil +} + +// Create the new grace instance with your addresses/handlers. +// Here, we create: +// +// 1. A health check server listening on 0.0.0.0:9092 that will +// respond to requests at /-/live and /-/ready, running the dbPinger +// and redisPinger functions for each request to /-/ready. +// This overrides the default endpoints of /livez and /readyz. +// 2. Our application server on localhost:9090 with the httpHandler. +// It specifies the default read and write timeouts, and a graceful +// stop timeout of 10 seconds. +// 3. Our metrics server on localhost:9091, with a shorter stop timeout +// of 5 seconds. +// 4. 
A function to start a background worker process that will be called +// with the context to be notified from OS signals, allowing for background +// processes to also get stopped when a signal is received. +// 5. A custom list of operating system signals to intercept that override the +// defaults. +g := grace.New( + ctx, + grace.WithHealthCheckServer( + "0.0.0.0:9092", + grace.WithCheckers(dbPinger, redisPinger), + grace.WithLivenessEndpoint("/-/live"), + grace.WithReadinessEndpoint("/-/ready"), + ), + grace.WithServer( + "localhost:9090", + httpHandler, + grace.WithServerName("api"), + grace.WithServerReadTimeout(grace.DefaultReadTimeout), + grace.WithServerStopTimeout(10*time.Second), + grace.WithServerWriteTimeout(grace.DefaultWriteTimeout), + ), + grace.WithServer( + "localhost:9091", + metricsHandler, + grace.WithServerName("metrics"), + grace.WithServerStopTimeout(5*time.Second), + ), + grace.WithBackgroundJobs(bgWorker), + grace.WithStopSignals( + os.Interrupt, + syscall.SIGHUP, + syscall.SIGTERM, + ), +) + +if err = g.Run(ctx); err != nil { + log.Fatal(err) +} +``` + +### Waiting for dependencies + +If your application has upstream dependencies, such as a sidecar that exposes a +remote database, you can use grace to wait for them to be available before +attempting a connection. + +At the top of your application's entrypoint (before setting up database connections!) 
+use the `Wait` method to wait for specific addresses to respond to TCP/HTTP pings before +continuing with your application setup: + +```go +err := grace.Wait( + ctx, + 10*time.Second, + grace.WithWaitForTCP("localhost:6379"), // redis + grace.WithWaitForTCP("localhost:5432"), // postgres + grace.WithWaitForHTTP("http://localhost:9200"), // elasticsearch + grace.WithWaitForHTTP("http://localhost:19000/ready"), // envoy sidecar +) +if err != nil { + log.Fatal(err) +} +``` + +## Local Development + +### Testing + +#### Linting + +The project uses [`golangci-lint`](https://golangci-lint.run) for linting. Run +with + +```sh +golangci-lint run +``` + +Configuration is found in: + +- `./.golangci.yaml` - Linter configuration. + +#### Unit Tests + +Run unit tests with + +```sh +go test ./... +``` diff --git a/example_test.go b/example_test.go new file mode 100644 index 0000000..875dfe5 --- /dev/null +++ b/example_test.go @@ -0,0 +1,207 @@ +package grace_test + +import ( + "context" + "errors" + "log" + "net" + "net/http" + "os" + "syscall" + "time" + + "github.com/morningconsult/grace" +) + +func Example_minimal() { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + + // Set up database pools, other application things, server handlers, + // etc. + // .... + + httpHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Write([]byte("hello there")) + }) + + // This is the absolute minimum configuration necessary to have a gracefully + // shutdown server. + g := grace.New(ctx, grace.WithServer("localhost:9090", httpHandler)) + if err := g.Run(ctx); err != nil { + log.Fatal(err) + } + + // Output: +} + +func Example_minimal_with_healthcheck() { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + + // Set up database pools, other application things, server handlers, + // etc. + // .... 
+ + httpHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Write([]byte("hello there")) + }) + + dbPinger := grace.HealthCheckerFunc(func(ctx context.Context) error { + // ping a database, etc. + return nil + }) + + // This is the minimum configuration for a gracefully shutdown server + // along with a health check server. This is most likely what you would + // want to implement. + g := grace.New( + ctx, + grace.WithHealthCheckServer( + "localhost:9092", + grace.WithCheckers(dbPinger), + ), + grace.WithServer( + "localhost:9090", + httpHandler, + grace.WithServerName("api"), + ), + ) + + if err := g.Run(ctx); err != nil { + log.Fatal(err) + } + + // Output: +} + +func Example_full() { + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + // Get addresses of dependencies, such as redis, postgres, etc + // from CLI flags or other configuration. Wait for them to be available + // before proceeding with setting up database connections and such. + err := grace.Wait(ctx, 10*time.Second, grace.WithWaitForTCP("example.com:80")) + if err != nil { + log.Fatal(err) + } + + // Set up database pools, other application things, server handlers, + // etc. + // .... + + httpHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Write([]byte("hello there")) + }) + + metricsHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Write([]byte("here are the metrics")) + }) + + dbPinger := grace.HealthCheckerFunc(func(ctx context.Context) error { + // ping database + return nil + }) + + redisPinger := grace.HealthCheckerFunc(func(ctx context.Context) error { + // ping redis. 
+ return nil + }) + + bgWorker := func(ctx context.Context) error { + // Start some background work + return nil + } + + otherBackgroundWorker := func(ctx context.Context) error { + // Start some more background work + return nil + } + + // Create the new grace instance with your addresses/handlers. + g := grace.New( + ctx, + grace.WithHealthCheckServer( + "localhost:9092", + grace.WithCheckers(dbPinger, redisPinger), + grace.WithLivenessEndpoint("/-/live"), + grace.WithReadinessEndpoint("/-/ready"), + ), + grace.WithServer( + "localhost:9090", + httpHandler, + grace.WithServerName("api"), + grace.WithServerReadTimeout(grace.DefaultReadTimeout), + grace.WithServerStopTimeout(10*time.Second), + grace.WithServerWriteTimeout(grace.DefaultWriteTimeout), + ), + grace.WithServer( + "localhost:9091", + metricsHandler, + grace.WithServerName("metrics"), + grace.WithServerStopTimeout(5*time.Second), + ), + grace.WithBackgroundJobs( + bgWorker, + otherBackgroundWorker, + ), + grace.WithStopSignals( + os.Interrupt, + syscall.SIGHUP, + syscall.SIGTERM, + ), + ) + + if err = g.Run(ctx); err != nil { + log.Fatal(err) + } + + // Output: +} + +func ExampleWait() { + ctx := context.Background() + + es := &http.Server{ + Addr: "localhost:9200", + Handler: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + }), + } + defer es.Shutdown(ctx) //nolint:errcheck + + go func() { + if err := es.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + log.Fatal(err) + } + }() + + pg, err := net.Listen("tcp", "localhost:6379") + if err != nil { + log.Fatal(err) + } + defer pg.Close() //nolint:errcheck + + redis, err := net.Listen("tcp", "localhost:5432") + if err != nil { + log.Fatal(err) + } + defer redis.Close() //nolint:errcheck + + // Get addresses of dependencies, such as redis, postgres, etc + // from CLI flags or other configuration. 
Wait for them to be available + // before proceeding with setting up database connections and such. + err = grace.Wait( + ctx, + 50*time.Millisecond, + grace.WithWaitForTCP("localhost:6379"), + grace.WithWaitForTCP("localhost:5432"), + grace.WithWaitForHTTP("http://localhost:9200"), + ) + if err != nil { + log.Fatal(err) + } + + // Output: +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..e23cd05 --- /dev/null +++ b/go.mod @@ -0,0 +1,19 @@ +module github.com/morningconsult/grace + +go 1.21 + +require ( + github.com/hashicorp/go-cleanhttp v0.5.2 + github.com/morningconsult/serrors v0.4.0 + github.com/stretchr/testify v1.8.1 + golang.org/x/sync v0.1.0 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/google/go-cmp v0.5.9 // indirect + github.com/kr/text v0.2.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..0f6abca --- /dev/null +++ b/go.sum @@ -0,0 +1,35 @@ +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ= +github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48= +github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= +github.com/kr/pretty 
v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/morningconsult/serrors v0.4.0 h1:9PHh5LSaEgVNdKuYMRGI0P4iVqpeK+kMZRK4P+Hlaqc= +github.com/morningconsult/serrors v0.4.0/go.mod h1:f5FEn6fh+5pGYKanDfI9BJhmWfWS4koN6dy8jR6XYNw= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= +golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v3 
v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/grace.go b/grace.go new file mode 100644 index 0000000..bb8604e --- /dev/null +++ b/grace.go @@ -0,0 +1,409 @@ +package grace + +import ( + "context" + "errors" + "log/slog" + "net" + "net/http" + "os" + "os/signal" + "syscall" + "time" + + "github.com/morningconsult/serrors" + "golang.org/x/sync/errgroup" +) + +const ( + // DefaultReadTimeout is the default timeout for reading http requests for + // a server. + DefaultReadTimeout = 1 * time.Minute + + // DefaultWriteTimeout is the default timeout for writing http responses for + // a server. + DefaultWriteTimeout = 5 * time.Minute + + // DefaultStopTimeout is the default timeout for stopping a server after a + // signal is encountered. + DefaultStopTimeout = 10 * time.Second + + // DefaultLivenessEndpoint is the default liveness endpoint for the health server. + DefaultLivenessEndpoint = "/livez" + + // DefaultReadinessEndpoint is the default readiness endpoint for the health server. + DefaultReadinessEndpoint = "/readyz" +) + +// Grace handles graceful shutdown of http servers. +// +// Each of the servers specified will be started in order, with graceful +// handling of OS signals to allow any in-flight requests to complete +// before stopping them entirely. +// +// Additionally, a health check server will be started to receive health +// requests from external orchestration systems to confirm the aliveness +// of the application if added using [WithHealthCheckServer]. +type Grace struct { + health healthConfig + backgroundJobs []BackgroundJobFunc + logger *slog.Logger + servers []graceServer + stopSignals []os.Signal +} + +// config configures a new [Grace]. 
+type config struct { + BackgroundJobs []BackgroundJobFunc + Health healthConfig + Logger *slog.Logger + Servers []serverConfig + StopSignals []os.Signal +} + +// An Option is used to modify the grace config. +type Option func(cfg config) config + +// WithStopSignals sets the stop signals to listen for. +// +// StopSignals are the signals to listen for to gracefully stop servers when +// encountered. If not specified, it defaults to [os.Interrupt], +// [syscall.SIGHUP], and [syscall.SIGTERM]. +func WithStopSignals(signals ...os.Signal) Option { + return func(cfg config) config { + cfg.StopSignals = signals + return cfg + } +} + +// WithLogger configures the logger to use. +func WithLogger(logger *slog.Logger) Option { + return func(cfg config) config { + cfg.Logger = logger + + for i := range cfg.Servers { + cfg.Servers[i].Logger = logger + } + + return cfg + } +} + +// healthConfig configures the [config.Health] of grace. +type healthConfig struct { + Addr string + Checkers []HealthChecker + LivenessEndpoint string + ReadinessEndpoint string +} + +// A HealthOption is is used to modify the grace health check server config. +type HealthOption func(cfg healthConfig) healthConfig + +// WithHealthCheckServer adds a health check server to be run on the provided +// address in the form "ip:port" or "host:port". The checkers are the health +// checking functions to run for each request to the health check server. +func WithHealthCheckServer(addr string, opts ...HealthOption) Option { + return func(cfg config) config { + health := healthConfig{ + Addr: addr, + LivenessEndpoint: DefaultLivenessEndpoint, + ReadinessEndpoint: DefaultReadinessEndpoint, + } + + for _, op := range opts { + health = op(health) + } + + cfg.Health = health + return cfg + } +} + +// WithCheckers sets the [HealthChecker] functions to the health server will run. 
+func WithCheckers(checkers ...HealthChecker) HealthOption { + return func(cfg healthConfig) healthConfig { + cfg.Checkers = checkers + return cfg + } +} + +// WithLivenessEndpoint sets the liveness endpoint for the health check server. +// If not used, it will default to [DefaultLivenessEndpoint]. +func WithLivenessEndpoint(endpoint string) HealthOption { + return func(cfg healthConfig) healthConfig { + cfg.LivenessEndpoint = endpoint + return cfg + } +} + +// WithReadinessEndpoint sets the liveness endpoint for the health check server. +// If not used, it will default to [DefaultReadinessEndpoint]. +func WithReadinessEndpoint(endpoint string) HealthOption { + return func(cfg healthConfig) healthConfig { + cfg.ReadinessEndpoint = endpoint + return cfg + } +} + +// BackgroundJobFunc is a function to invoke with the context returned from +// [signal.NotifyContext]. This can be used to ensure that non-http servers +// in the application, such as background workers, can also be tied into the +// signal context. +// +// The function will be called within a [golang.org/x/sync/errgroup.Group] and +// must be blocking. +type BackgroundJobFunc func(ctx context.Context) error + +// WithBackgroundJobs sets the [BackgroundJobFunc] functions that will be +// invoked when [Run] is called. +func WithBackgroundJobs(jobs ...BackgroundJobFunc) Option { + return func(cfg config) config { + cfg.BackgroundJobs = jobs + return cfg + } +} + +// serverConfig is the configuration for a single server. +type serverConfig struct { + Addr string + Handler http.Handler + Logger *slog.Logger + Name string + ReadTimeout time.Duration + StopTimeout time.Duration + WriteTimeout time.Duration +} + +// A ServerOption is used to modify a server config. +type ServerOption func(cfg serverConfig) serverConfig + +// WithServer adds a new server to be handled by grace with the provided address +// and [http.Handler]. 
The address of the server to listen on +// should be in the form 'ip:port' or 'host:port'. +// +// The server's [http.Server.BaseContext] will be set to the context used when [New] +// is invoked. +func WithServer(addr string, handler http.Handler, options ...ServerOption) Option { + return func(cfg config) config { + srv := serverConfig{ + Addr: addr, + Handler: handler, + Logger: cfg.Logger, + ReadTimeout: DefaultReadTimeout, + StopTimeout: DefaultStopTimeout, + WriteTimeout: DefaultWriteTimeout, + } + for _, fn := range options { + srv = fn(srv) + } + + cfg.Servers = append(cfg.Servers, srv) + return cfg + } +} + +// WithServerName sets the name of the server, which is a helpful name for the server +// for logging purposes. +func WithServerName(name string) ServerOption { + return func(cfg serverConfig) serverConfig { + cfg.Name = name + return cfg + } +} + +// WithServerStopTimeout sets the stop timeout for the server. +// +// The StopTimeout is the amount of time to wait for the server to exit +// before forcing a shutdown. This determines the period that the +// "graceful" shutdown will last. +// +// If not used, the StopTimeout defaults to [DefaultStopTimeout]. +// A timeout of 0 will result in the server being shut down immediately. +func WithServerStopTimeout(timeout time.Duration) ServerOption { + return func(cfg serverConfig) serverConfig { + cfg.StopTimeout = timeout + return cfg + } +} + +// WithServerReadTimeout sets the read timeout for the server. +// +// ReadTimeout is the [http.Server.ReadTimeout] for the server. +// If not used, the ReadTimeout defaults to [DefaultReadTimeout]. +func WithServerReadTimeout(timeout time.Duration) ServerOption { + return func(cfg serverConfig) serverConfig { + cfg.ReadTimeout = timeout + return cfg + } +} + +// WithServerWriteTimeout sets the read timeout for the server. +// +// WriteTimeout is the [http.Server.WriteTimeout] for the server. +// If not used, the WriteTimeout defaults to [DefaultWriteTimeout]. 
+func WithServerWriteTimeout(timeout time.Duration) ServerOption { + return func(cfg serverConfig) serverConfig { + cfg.WriteTimeout = timeout + return cfg + } +} + +// New creates a new grace. Specify one or more [Option] to configure the new +// grace client. +// +// The provided context will be used as the base context for all created servers. +// +// New does not start listening for OS signals, it only creates the new grace that +// can be started by calling [Grace.Run]. +func New(ctx context.Context, options ...Option) Grace { + cfg := config{ + Logger: slog.Default(), + StopSignals: []os.Signal{os.Interrupt, syscall.SIGHUP, syscall.SIGTERM}, + } + + for _, op := range options { + cfg = op(cfg) + } + + srvs := make([]graceServer, 0, len(cfg.Servers)) + for _, srv := range cfg.Servers { + srvs = append(srvs, newGraceServer(ctx, srv)) + } + + return Grace{ + backgroundJobs: cfg.BackgroundJobs, + health: cfg.Health, + logger: cfg.Logger, + servers: srvs, + stopSignals: cfg.StopSignals, + } +} + +// Run starts all of the registered servers and creates a new health check server, +// if configured with [WithHealthCheckServer]. +// +// The created health check server will not be gracefully shutdown and will +// instead be stopped as soon as any stop signals are encountered or +// the context is finished. This is to ensure that any health checks to the +// application begin to fail immediately. +// +// They will all be stopped gracefully when the configured stop signals +// are encountered or the provided context is finished. +func (g Grace) Run(ctx context.Context) error { + ctx, stop := signal.NotifyContext(ctx, g.stopSignals...) 
+ defer stop() + + if g.health.Addr != "" { + g.servers = append(g.servers, newGraceServer(ctx, serverConfig{ + Addr: g.health.Addr, + Handler: newHealthHandler( + g.logger, + g.health.LivenessEndpoint, + g.health.ReadinessEndpoint, + g.health.Checkers..., + ), + Logger: g.logger, + Name: "health", + ReadTimeout: DefaultReadTimeout, + WriteTimeout: DefaultWriteTimeout, + })) + } + + return g.listenAndServe(ctx) +} + +// listenAndServe starts the given servers with graceful shutdown handling for +// interrupts. +func (g Grace) listenAndServe(ctx context.Context) error { + eg, ctx := errgroup.WithContext(ctx) + + for _, server := range g.servers { + server := server + eg.Go(func() error { + return server.start(ctx) + }) + + eg.Go(func() error { + // We need to block on the context being done otherwise we would be + // shutting down immediately. The context will be done when the parent + // context passed to listenAndServe gets canceled, or from the first + // error returned from any other goroutine in the group. + <-ctx.Done() + return server.stop(ctx) + }) + } + + for _, job := range g.backgroundJobs { + job := job + eg.Go(func() error { + return job(ctx) + }) + } + + return eg.Wait() +} + +// newGraceServer creates a new [graceServer] from the provided configuration. +func newGraceServer(ctx context.Context, cfg serverConfig) graceServer { + return graceServer{ + HTTPServer: &http.Server{ + BaseContext: func(net.Listener) context.Context { return ctx }, + Addr: cfg.Addr, + Handler: cfg.Handler, + ReadTimeout: cfg.ReadTimeout, + WriteTimeout: cfg.WriteTimeout, + }, + Logger: cfg.Logger, + Name: cfg.Name, + StopTimeout: cfg.StopTimeout, + } +} + +// graceServer is a single http graceServer with its human-readable name +// and stop timeout. +type graceServer struct { + HTTPServer *http.Server + Logger *slog.Logger + Name string + StopTimeout time.Duration +} + +// start starts the graceServer. 
+func (gs graceServer) start(ctx context.Context) error { + gs.Logger.InfoContext(ctx, "server listening", + "server", gs.Name, + "address", gs.HTTPServer.Addr, + ) + if err := gs.HTTPServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) { + return serrors.WithStack(err) + } + return nil +} + +// stop gracefully shuts down a grace HTTP server. +func (gs graceServer) stop(ctx context.Context) error { + // We detach the passed context here because it could be canceled already, + // which would defeat the purpose of gracefully draining. + ctx, cancel := context.WithTimeout(context.WithoutCancel(ctx), gs.StopTimeout) + defer cancel() + + gs.Logger.InfoContext(ctx, "server shutting down", + "server", gs.Name, + "address", gs.HTTPServer.Addr, + ) + defer gs.Logger.InfoContext(ctx, "server shut down", + "server", gs.Name, + "address", gs.HTTPServer.Addr, + ) + + gs.HTTPServer.SetKeepAlivesEnabled(false) + err := gs.HTTPServer.Shutdown(ctx) + if errors.Is(err, context.DeadlineExceeded) && gs.StopTimeout == 0 { + // If the server has a StopTimeout of 0, it always ends up with + // its context deadline exceeded. Since this was explicitly set to + // 0 by the user, we ignore this as an error. 
+ return nil + } + return serrors.WithStack(err) +} diff --git a/grace_test.go b/grace_test.go new file mode 100644 index 0000000..320a489 --- /dev/null +++ b/grace_test.go @@ -0,0 +1,237 @@ +package grace_test + +import ( + "bytes" + "context" + "errors" + "fmt" + "log/slog" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" + + "github.com/morningconsult/grace" +) + +func TestGrace_Run(t *testing.T) { + t.Parallel() + + t.Run("error address in use", func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + addr := newTestAddr(t) + handler := http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}) + grc := grace.New( + ctx, + grace.WithServer( + addr, + handler, + grace.WithServerReadTimeout(time.Millisecond), + grace.WithServerWriteTimeout(time.Microsecond), + grace.WithServerStopTimeout(time.Millisecond), + ), + grace.WithServer(addr, handler), + grace.WithStopSignals(os.Kill), + ) + + err := grc.Run(ctx) + require.Error(t, err, "wanted start error from grace") + + wantError := fmt.Sprintf("listen tcp %s: listen: address already in use", addr) + if strings.Contains(err.Error(), "bind") { + wantError = fmt.Sprintf("listen tcp %s: bind: address already in use", addr) + } + require.EqualError(t, err, wantError) + }) + + t.Run("error background job", func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + grc := grace.New( + ctx, + grace.WithBackgroundJobs(func(ctx context.Context) error { + return errors.New("wombat") + }), + ) + + err := grc.Run(ctx) + require.EqualError(t, err, "wombat") + }) + + t.Run("success shutdown context done", func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + addr1 := newRandomAddr(t) + addr2 := newRandomAddr(t) + healthAddr := newRandomAddr(t) + handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusTeapot) 
+ }) + + var checkerWasCalled bool + t.Cleanup(func() { + assert.True(t, checkerWasCalled, "wanted checker to be called") + }) + + grc := grace.New( + ctx, + grace.WithHealthCheckServer( + healthAddr, + grace.WithCheckers(grace.HealthCheckerFunc(func(ctx context.Context) error { + checkerWasCalled = true + return nil + })), + grace.WithLivenessEndpoint("/foo"), + grace.WithReadinessEndpoint("/bar"), + ), + grace.WithServer( + addr1, + handler, + grace.WithServerName("test1"), + grace.WithServerStopTimeout(200*time.Millisecond), + ), + grace.WithServer( + addr2, + handler, + grace.WithServerName("test2"), + grace.WithServerStopTimeout(200*time.Millisecond), + ), + ) + + ctx, cancel := context.WithCancel(ctx) + g := &errgroup.Group{} + g.Go(func() error { + return grc.Run(ctx) + }) + + err := grace.Wait( + ctx, + 3*time.Second, + grace.WithWaitForTCP(addr1), + grace.WithWaitForTCP(addr2), + grace.WithWaitForTCP(healthAddr), + ) + require.NoError(t, err) + + // This is to test the server is fully online. + for _, addr := range []string{addr1, addr2} { + res, err := http.Get("http://" + addr) + require.NoError(t, err) + assert.Equal(t, http.StatusTeapot, res.StatusCode) + } + + // And the health check server is online. 
+		for _, path := range []string{"/foo", "/bar"} {
+			res, err := http.Get("http://" + healthAddr + path)
+			require.NoError(t, err)
+			assert.Equal(t, http.StatusOK, res.StatusCode)
+		}
+
+		cancel()
+		require.NoError(t, g.Wait())
+	})
+
+	t.Run("error shutdown deadline exceeded", func(t *testing.T) {
+		t.Parallel()
+
+		ctx := context.Background()
+		addr1 := newRandomAddr(t)
+		addr2 := newRandomAddr(t)
+		healthAddr := newRandomAddr(t)
+
+		reqCh := make(chan struct{})
+		handler := http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
+			reqCh <- struct{}{}
+			time.Sleep(10 * time.Second)
+		})
+
+		grc := grace.New(
+			ctx,
+			grace.WithHealthCheckServer(healthAddr),
+			grace.WithServer(
+				addr1,
+				handler,
+				grace.WithServerName("test1"),
+				grace.WithServerStopTimeout(0),
+			),
+			grace.WithServer(
+				addr2,
+				handler,
+				grace.WithServerName("test2"),
+				grace.WithServerStopTimeout(100*time.Millisecond),
+			),
+		)
+
+		ctx, cancel := context.WithCancel(ctx)
+		g := &errgroup.Group{}
+		g.Go(func() error {
+			return grc.Run(ctx)
+		})
+
+		err := grace.Wait(
+			ctx,
+			time.Second,
+			grace.WithWaitForTCP(addr1),
+			grace.WithWaitForTCP(addr2),
+			grace.WithWaitForTCP(healthAddr),
+		)
+		require.NoError(t, err)
+
+		go func() {
+			http.Get("http://" + addr2) //nolint:errcheck
+		}()
+
+		<-reqCh
+		cancel()
+		require.EqualError(t, g.Wait(), context.DeadlineExceeded.Error())
+	})
+
+	t.Run("WithLogger", func(t *testing.T) {
+		t.Parallel()
+
+		var buf bytes.Buffer
+		logger := slog.New(slog.NewJSONHandler(&buf, nil))
+
+		ctx := context.Background()
+		addr := newTestAddr(t)
+		handler := http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})
+		grc := grace.New(
+			ctx,
+			grace.WithServer(addr, handler),
+			grace.WithLogger(logger),
+		)
+
+		err := grc.Run(ctx)
+		require.Error(t, err, "wanted start error from grace")
+
+		assert.NotZero(t, buf)
+	})
+}
+
+// newTestAddr starts a new httptest server and returns its address.
+// The test server will have a noop handler for all requests. 
+func newTestAddr(t *testing.T) string { + ts := httptest.NewServer(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})) + t.Cleanup(ts.Close) + return strings.TrimPrefix(ts.URL, "http://") +} + +// newRandomAddr is used to get an open port for test servers to prevent tests +// from failing due to a static port being in use. +func newRandomAddr(t *testing.T) string { + t.Helper() + ts := httptest.NewServer(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})) + addr := strings.TrimPrefix(ts.URL, "http://") + ts.Close() + return addr +} diff --git a/health.go b/health.go new file mode 100644 index 0000000..9df442e --- /dev/null +++ b/health.go @@ -0,0 +1,76 @@ +package grace + +import ( + "context" + "fmt" + "io" + "log/slog" + "net/http" + + "golang.org/x/sync/errgroup" +) + +// HealthChecker is something that needs its health checked to be "ready". +type HealthChecker interface { + CheckHealth(ctx context.Context) error +} + +// HealthCheckerFunc is a function that can be used as a HealthChecker. +type HealthCheckerFunc func(context.Context) error + +// CheckHealth checks the health of a resource using the HealthCheckerFunc. +func (hcf HealthCheckerFunc) CheckHealth(ctx context.Context) error { + return hcf(ctx) +} + +// newHealthHandler returns an http.Handler capable of serving health checks. +// +// The handler returned is kubernetes aware, in that it serves a "liveness" +// endpoint under livenessEndpoint, and a "readiness" endpoint under readinessEndpoint. 
+func newHealthHandler( + logger *slog.Logger, + livenessEndpoint string, + readinessEndpoint string, + checkers ...HealthChecker, +) http.Handler { + mux := http.NewServeMux() + + mux.HandleFunc(livenessEndpoint, func(rw http.ResponseWriter, r *http.Request) { + rw.Header().Set("Content-Type", "application/json") + io.WriteString(rw, `{"healthy":true}`) //nolint: errcheck + }) + + mux.HandleFunc(readinessEndpoint, func(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + // We do not make a group with a shared context in order to avoid a single + // check failing causing all of the other checks to erroneously fail and + // lead to it being difficult to determine which actually failed. + g := &errgroup.Group{} + for _, checker := range checkers { + checker := checker + g.Go(func() error { + err := checker.CheckHealth(ctx) + if err != nil { + logger.ErrorContext(ctx, "checking health", + "error", err, + "checker", fmt.Sprintf("%T", checker), + ) + } + return err + }) + } + + if err := g.Wait(); err != nil { + rw.Header().Set("Content-Type", "application/json") + rw.WriteHeader(http.StatusInternalServerError) + io.WriteString(rw, `{"ready":false}`) //nolint: errcheck + return + } + + rw.Header().Set("Content-Type", "application/json") + io.WriteString(rw, `{"ready":true}`) //nolint: errcheck + }) + + return mux +} diff --git a/health_test.go b/health_test.go new file mode 100644 index 0000000..f1a61fd --- /dev/null +++ b/health_test.go @@ -0,0 +1,125 @@ +package grace_test + +import ( + "context" + "errors" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" + + "github.com/morningconsult/grace" +) + +func TestCheckerFunc_CheckHealth(t *testing.T) { + t.Parallel() + + called := false + wantErr := errors.New("foo") + + f := func(ctx context.Context) error { + called = true + return wantErr + } + + checker := grace.HealthCheckerFunc(f) + err := 
checker.CheckHealth(context.Background()) + + assert.True(t, called) + assert.Equal(t, wantErr, err) +} + +func Test_NewHealthHandler(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + checkers []grace.HealthChecker + wantStatus int + }{ + { + name: "no checkers", + wantStatus: http.StatusOK, + }, + { + name: "one checker success", + checkers: []grace.HealthChecker{ + mockChecker{nil}, + }, + wantStatus: http.StatusOK, + }, + { + name: "one checker failure", + checkers: []grace.HealthChecker{ + mockChecker{errors.New("oh no")}, + }, + wantStatus: http.StatusInternalServerError, + }, + { + name: "multi checker success", + checkers: []grace.HealthChecker{ + mockChecker{nil}, + mockChecker{nil}, + mockChecker{nil}, + }, + wantStatus: http.StatusOK, + }, + { + name: "multi checker one failure", + checkers: []grace.HealthChecker{ + mockChecker{nil}, + mockChecker{errors.New("oh no")}, + mockChecker{nil}, + }, + wantStatus: http.StatusInternalServerError, + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + addr := newRandomAddr(t) + ctx, cancel := context.WithCancel(context.Background()) + grc := grace.New(ctx, grace.WithHealthCheckServer(addr, grace.WithCheckers(tt.checkers...))) + + g, ctx := errgroup.WithContext(ctx) + g.Go(func() error { + return grc.Run(ctx) + }) + + err := grace.Wait(ctx, time.Second, grace.WithWaitForTCP(addr)) + require.NoError(t, err) + + t.Run("liveness", func(t *testing.T) { + res, err := http.Get("http://" + addr + "/livez") + require.NoError(t, err) + defer res.Body.Close() + assert.Equal(t, http.StatusOK, res.StatusCode) + }) + + t.Run("readiness", func(t *testing.T) { + res, err := http.Get("http://" + addr + "/readyz") + require.NoError(t, err) + defer res.Body.Close() + assert.Equal(t, tt.wantStatus, res.StatusCode) + }) + + cancel() + require.NoError(t, g.Wait()) + }) + } +} + +type mockChecker struct { + error +} + +func (m mockChecker) 
CheckHealth(context.Context) error { + return m.error +} diff --git a/wait.go b/wait.go new file mode 100644 index 0000000..a04450c --- /dev/null +++ b/wait.go @@ -0,0 +1,203 @@ +package grace + +import ( + "context" + "log/slog" + "net" + "net/http" + "regexp" + "time" + + "github.com/hashicorp/go-cleanhttp" + "github.com/morningconsult/serrors" + "golang.org/x/sync/errgroup" +) + +// Waiter is something that waits for a thing to be "ready". +type Waiter interface { + Wait(ctx context.Context) error +} + +// WaiterFunc is a function that can be used as a Waiter. +type WaiterFunc func(context.Context) error + +// Wait waits for a resource using the WaiterFunc. +func (w WaiterFunc) Wait(ctx context.Context) error { + return w(ctx) +} + +// Wait waits for all the provided checker pings to be successful until +// the specified timeout is exceeded. It will block until all of the pings are +// successful and return nil, or return an error if any checker is failing by +// the time the timeout elapses. +// +// Wait can be used to wait for dependent services like sidecar upstreams to +// be available before proceeding with other parts of an application startup. +func Wait(ctx context.Context, timeout time.Duration, opts ...WaitOption) error { + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + cfg := waitConfig{ + logger: slog.Default(), + } + + for _, opt := range opts { + cfg = opt(cfg) + } + + g, ctx := errgroup.WithContext(ctx) + for _, waiter := range cfg.waiters { + waiter := waiter + g.Go(func() error { + return waiter.Wait(ctx) + }) + } + + return serrors.WithStack(g.Wait()) +} + +// WaitOption is a configurable option for [Wait]. +type WaitOption func(cfg waitConfig) waitConfig + +type waitConfig struct { + logger *slog.Logger + waiters []Waiter +} + +// WithWaitLogger configures the logger to use when calling [Wait]. 
+func WithWaitLogger(logger *slog.Logger) WaitOption {
+	return func(cfg waitConfig) waitConfig {
+		cfg.logger = logger
+
+		for i, waiter := range cfg.waiters {
+			switch waiter := waiter.(type) {
+			case httpWaiter:
+				waiter.logger = logger
+				cfg.waiters[i] = waiter
+			case tcpWaiter:
+				waiter.logger = logger
+				cfg.waiters[i] = waiter
+			}
+		}
+
+		return cfg
+	}
+}
+
+// WithWaiter adds a waiter for use with [Wait].
+func WithWaiter(w Waiter) WaitOption {
+	return func(cfg waitConfig) waitConfig {
+		cfg.waiters = append(cfg.waiters, w)
+		return cfg
+	}
+}
+
+// WithWaiterFunc adds a waiter for use with [Wait].
+func WithWaiterFunc(w WaiterFunc) WaitOption {
+	return func(cfg waitConfig) waitConfig {
+		cfg.waiters = append(cfg.waiters, w)
+		return cfg
+	}
+}
+
+// urlRegexp is used to remove any protocol or path that might be present
+// when creating a tcp waiter.
+var urlRegexp = regexp.MustCompile("^(https?://)?(?P<host>.+):(?P<port>[0-9]+)(.*)?")
+
+// WithWaitForTCP makes a new TCP waiter that will ping an address and return
+// once it is reachable.
+func WithWaitForTCP(addr string) WaitOption {
+	return func(cfg waitConfig) waitConfig {
+		cfg.waiters = append(cfg.waiters,
+			tcpWaiter{
+				addr:   urlRegexp.ReplaceAllString(addr, "$host:$port"),
+				logger: cfg.logger,
+			},
+		)
+		return cfg
+	}
+}
+
+type tcpWaiter struct {
+	addr   string
+	logger *slog.Logger
+}
+
+// Wait waits for something to be listening on the given TCP address.
+func (w tcpWaiter) Wait(ctx context.Context) error {
+	for {
+		if err := checkContextDone(ctx, w.logger, w.addr); err != nil {
+			return err
+		}
+
+		d := net.Dialer{
+			Timeout: 300 * time.Millisecond,
+		}
+		conn, _ := d.DialContext(ctx, "tcp", w.addr)
+		if conn != nil {
+			w.logger.DebugContext(ctx, "established connection to address",
+				"address", w.addr,
+			)
+			defer conn.Close() //nolint:errcheck
+			return nil
+		}
+	}
+}
+
+// WithWaitForHTTP makes a new HTTP waiter that will make GET requests to a URL
+// until it returns a non-500 error code. 
All statuses below 500 mean the dependency +// is accepting requests, even if the check is unauthorized or invalid. +func WithWaitForHTTP(url string) WaitOption { + return func(cfg waitConfig) waitConfig { + cfg.waiters = append(cfg.waiters, + httpWaiter{ + client: cleanhttp.DefaultClient(), + logger: cfg.logger, + url: url, + }, + ) + return cfg + } +} + +type httpWaiter struct { + client *http.Client + logger *slog.Logger + url string +} + +// Wait waits for something to be accepting HTTP requests. +func (w httpWaiter) Wait(ctx context.Context) error { + for { + if err := checkContextDone(ctx, w.logger, w.url); err != nil { + return err + } + + res, _ := w.client.Get(w.url) + if res == nil { + continue + } + res.Body.Close() + + if res.StatusCode < http.StatusInternalServerError { + w.logger.DebugContext(ctx, "established connection to address", + "address", w.url, + ) + return nil + } + } +} + +// checkContextDone checks if the provided context is done, and returns +// an error if it is. 
+func checkContextDone(ctx context.Context, logger *slog.Logger, addr string) error { + select { + case <-ctx.Done(): + logger.DebugContext(ctx, "failed to establish connection to address", + "address", addr, + ) + return serrors.Errorf("timed out connecting to %q", addr) + default: + return nil + } +} diff --git a/wait_test.go b/wait_test.go new file mode 100644 index 0000000..ec2e01c --- /dev/null +++ b/wait_test.go @@ -0,0 +1,164 @@ +package grace_test + +import ( + "bytes" + "context" + "fmt" + "log/slog" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/morningconsult/grace" +) + +func Test_Wait(t *testing.T) { + t.Parallel() + + t.Run("success tcp", func(t *testing.T) { + t.Parallel() + + addr1 := newTestAddr(t) + addr2 := newTestAddr(t) + + ctx := context.Background() + + var buf bytes.Buffer + logger := slog.New(slog.NewJSONHandler(&buf, &slog.HandlerOptions{ + Level: slog.LevelDebug, + })) + + err := grace.Wait( + ctx, + 50*time.Millisecond, + grace.WithWaitForTCP(addr1), + grace.WithWaitForTCP(addr2), + grace.WithWaiter(grace.WaiterFunc(func(_ context.Context) error { + return nil + })), + grace.WithWaiterFunc(func(_ context.Context) error { + return nil + }), + grace.WithWaitLogger(logger), + ) + require.NoError(t, err) + + assert.NotZero(t, buf) + }) + + t.Run("success tcp from url", func(t *testing.T) { + t.Parallel() + + urls := []string{ + "https://" + newTestAddr(t) + "/foo/bar", + newTestAddr(t) + "/foo/bar", + "https://" + newTestAddr(t), + } + for _, url := range urls { + ctx := context.Background() + err := grace.Wait(ctx, 50*time.Millisecond, grace.WithWaitForTCP(url)) + assert.NoError(t, err) + } + }) + + t.Run("success http", func(t *testing.T) { + t.Parallel() + + addr1 := "http://" + newTestAddr(t) + addr2 := "http://" + newTestAddr(t) + + var wasCalled bool + t.Cleanup(func() { assert.True(t, wasCalled, "wanted handler 
called") }) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wasCalled = true + // This ensures that statuses < 500 are still considered online. + w.WriteHeader(http.StatusUnauthorized) + })) + t.Cleanup(ts.Close) + + ctx := context.Background() + + var buf bytes.Buffer + logger := slog.New(slog.NewJSONHandler(&buf, &slog.HandlerOptions{ + Level: slog.LevelDebug, + })) + + err := grace.Wait( + ctx, + 50*time.Millisecond, + grace.WithWaitForHTTP(addr1), + grace.WithWaitForHTTP(addr2), + grace.WithWaitForHTTP(ts.URL), + grace.WithWaitLogger(logger), + ) + require.NoError(t, err) + + assert.NotZero(t, buf) + }) + + t.Run("timeout tcp", func(t *testing.T) { + t.Parallel() + + addr1 := newTestAddr(t) + addr2 := newRandomAddr(t) + + ctx := context.Background() + err := grace.Wait( + ctx, + 50*time.Millisecond, + grace.WithWaitForTCP(addr1), + grace.WithWaitForTCP(addr2), + ) + require.Error(t, err) + + wantError := fmt.Sprintf("timed out connecting to %q", addr1) + if strings.Contains(err.Error(), addr2) { + wantError = fmt.Sprintf("timed out connecting to %q", addr2) + } + require.EqualError(t, err, wantError) + }) + + t.Run("timeout http", func(t *testing.T) { + t.Parallel() + + addr1 := "http://" + newTestAddr(t) + addr2 := "http://" + newRandomAddr(t) + + var wasCalled bool + t.Cleanup(func() { + assert.True(t, wasCalled, "wanted http server called") + }) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wasCalled = true + w.WriteHeader(http.StatusInternalServerError) + })) + t.Cleanup(ts.Close) + addr3 := ts.URL + + ctx := context.Background() + err := grace.Wait( + ctx, + 50*time.Millisecond, + grace.WithWaitForHTTP(addr1), + grace.WithWaitForHTTP(addr2), + grace.WithWaitForHTTP(addr3), + ) + require.Error(t, err) + + var wantAddr string + switch { + case strings.Contains(err.Error(), addr1): + wantAddr = addr1 + case strings.Contains(err.Error(), addr2): + wantAddr = addr2 + 
case strings.Contains(err.Error(), addr3): + wantAddr = addr3 + } + require.EqualError(t, err, fmt.Sprintf("timed out connecting to %q", wantAddr)) + }) +}