From 9865a37575fba7a6eeb9e6d9c39bf5d48ca72782 Mon Sep 17 00:00:00 2001 From: Kang Ming Date: Thu, 16 May 2024 21:08:04 +0800 Subject: [PATCH] fix: improve logging structure (#1583) ## What kind of change does this PR introduce? * Remove unformatted logs which do not confirm to JSON * Previously, we were logging both `time` (not UTC) and `timestamp` (in UTC) which is redundant. I've opted to remove `timestamp` and just log the UTC time as the `time` field, which is supported by logrus * Previously, the `request_id` was not being logged because it was unable to retrieve the context properly. Now, the `request_id` field is added to every log entry, which allows us to filter by `request_id` to see the entire lifecycle of the request * Previously, panics weren't being handled properly and they were just logged as text instead of JSON. The server would return an empty reply, which leads to ugly responses like "Unexpected token < in JSON..." if using fetch in JS. Now, the server returns a proper 500 error response: `{"code":500,"error_code":"unexpected_failure","msg":"Internal Server Error"}` * Added tests for `recoverer` and `NewStructuredLogger` to prevent regression * Remove "request started" log since the `request_id` can be used to keep track of the entire request lifecycle. This cuts down on the noise to signal ratio as well. ## Log format * Panics are now logged like this (note the additional fields like `panic` and `stack` - which is a dump of the stack trace): ```json { "component":"api", "duration":6065700500, "level":"info", "method":"GET", "msg":"request completed", "panic":"test panic", "path":"/panic", "referer":"http://localhost:3001", "remote_addr":"127.0.0.1", "request_id":"4cde5f20-2c3c-4645-bc75-52d6231e22e2", "stack":"goroutine 82 [running]:...rest of stack trace omitted for brevity", "status":500, "time":"2024-05-15T09:37:42Z" } ``` * Requests that call `NewAuditLogEntry` will be logged with the `auth_event` payload in this format (note that the timestamp field no longer exists) ```json { "auth_event": { "action": "token_refreshed", "actor_id": "733fb34d-a6f2-43e1-976a-8e6a456b6889", "actor_name": "Kang Ming Tay", "actor_username": "kang.ming1996@gmail.com", "actor_via_sso": false, "log_type": "token" }, "component": "api", "duration": 75945042, "level": "info", "method": "POST", "msg": "request completed", "path": "/token", "referer": "http://localhost:3001", "remote_addr": "127.0.0.1", "request_id": "08c7e47b-42f4-44dc-a39b-7275ef5bbb45", "status": 200, "time": "2024-05-15T09:40:09Z" } ``` --- internal/api/api.go | 11 +-- internal/api/auth.go | 8 +- internal/api/context.go | 16 ---- internal/api/errors.go | 43 +++++----- internal/api/errors_test.go | 41 +++++++++ internal/api/external.go | 6 +- internal/api/external_oauth.go | 2 +- internal/api/helpers.go | 18 ---- internal/api/hooks.go | 2 +- internal/api/middleware.go | 8 +- internal/api/samlacs.go | 2 +- internal/api/token.go | 2 +- internal/api/token_oidc.go | 4 +- internal/api/verify.go | 4 +- internal/observability/logging.go | 23 ++++- internal/observability/request-logger.go | 85 ++++++++++++------- internal/observability/request-logger_test.go | 72 ++++++++++++++++ internal/utilities/context.go | 28 ++++++ 18 files changed, 260 insertions(+), 115 deletions(-) create mode 100644 internal/observability/request-logger_test.go create mode 100644 internal/utilities/context.go diff --git a/internal/api/api.go b/internal/api/api.go index 9c60dbe7e..26207861f 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -98,21 +98,20 @@ func NewAPIWithVersion(ctx context.Context, globalConfig *conf.GlobalConfigurati logger := observability.NewStructuredLogger(logrus.StandardLogger(), globalConfig) r := newRouter() + r.UseBypass(observability.AddRequestID(globalConfig)) + r.UseBypass(logger) + r.UseBypass(xffmw.Handler) + r.UseBypass(recoverer) if globalConfig.API.MaxRequestDuration > 0 { r.UseBypass(api.timeoutMiddleware(globalConfig.API.MaxRequestDuration)) } - r.Use(addRequestID(globalConfig)) - // request tracing should be added only when tracing or metrics is enabled if globalConfig.Tracing.Enabled || globalConfig.Metrics.Enabled { r.UseBypass(observability.RequestTracing()) } - r.UseBypass(xffmw.Handler) - r.Use(recoverer) - if globalConfig.DB.CleanupEnabled { cleanup := models.NewCleanup(globalConfig) r.UseBypass(api.databaseCleanup(cleanup)) @@ -121,7 +120,6 @@ func NewAPIWithVersion(ctx context.Context, globalConfig *conf.GlobalConfigurati r.Get("/health", api.HealthCheck) r.Route("/callback", func(r *router) { - r.UseBypass(logger) r.Use(api.isValidExternalHost) r.Use(api.loadFlowState) @@ -130,7 +128,6 @@ func NewAPIWithVersion(ctx context.Context, globalConfig *conf.GlobalConfigurati }) r.Route("/", func(r *router) { - r.UseBypass(logger) r.Use(api.isValidExternalHost) r.Get("/settings", api.Settings) diff --git a/internal/api/auth.go b/internal/api/auth.go index 4acbe4f95..c167d8212 100644 --- a/internal/api/auth.go +++ b/internal/api/auth.go @@ -4,7 +4,7 @@ import ( "context" "fmt" "net/http" - "time" + "strings" "github.com/gofrs/uuid" jwt "github.com/golang-jwt/jwt" @@ -44,11 +44,10 @@ func (a *API) requireNotAnonymous(w http.ResponseWriter, r *http.Request) (conte return ctx, nil } -func (a *API) requireAdmin(ctx context.Context, r *http.Request) (context.Context, error) { +func (a *API) requireAdmin(ctx context.Context) (context.Context, error) { // Find the administrative user claims := getClaims(ctx) if claims == nil { - fmt.Printf("[%s] %s %s %d %s\n", time.Now().Format("2006-01-02 15:04:05"), r.Method, r.RequestURI, http.StatusForbidden, "Invalid token") return nil, forbiddenError(ErrorCodeBadJWT, "Invalid token") } @@ -59,8 +58,7 @@ func (a *API) requireAdmin(ctx context.Context, r *http.Request) (context.Contex return withAdminUser(ctx, &models.User{Role: claims.Role, Email: storage.NullString(claims.Role)}), nil } - fmt.Printf("[%s] %s %s %d %s\n", time.Now().Format("2006-01-02 15:04:05"), r.Method, r.RequestURI, http.StatusForbidden, "this token needs role 'supabase_admin' or 'service_role'") - return nil, forbiddenError(ErrorCodeNotAdmin, "User not allowed") + return nil, forbiddenError(ErrorCodeNotAdmin, "User not allowed").WithInternalMessage(fmt.Sprintf("this token needs to have one of the following roles: %v", strings.Join(adminRoles, ", "))) } func (a *API) extractBearerToken(r *http.Request) (string, error) { diff --git a/internal/api/context.go b/internal/api/context.go index 501ab49f2..b357299a6 100644 --- a/internal/api/context.go +++ b/internal/api/context.go @@ -16,7 +16,6 @@ func (c contextKey) String() string { const ( tokenKey = contextKey("jwt") - requestIDKey = contextKey("request_id") inviteTokenKey = contextKey("invite_token") signatureKey = contextKey("signature") externalProviderTypeKey = contextKey("external_provider_type") @@ -57,21 +56,6 @@ func getClaims(ctx context.Context) *AccessTokenClaims { return token.Claims.(*AccessTokenClaims) } -// withRequestID adds the provided request ID to the context. -func withRequestID(ctx context.Context, id string) context.Context { - return context.WithValue(ctx, requestIDKey, id) -} - -// getRequestID reads the request ID from the context. -func getRequestID(ctx context.Context) string { - obj := ctx.Value(requestIDKey) - if obj == nil { - return "" - } - - return obj.(string) -} - // withUser adds the user to the context. func withUser(ctx context.Context, u *models.User) context.Context { return context.WithValue(ctx, userKey, u) diff --git a/internal/api/errors.go b/internal/api/errors.go index 7fc5472e0..2d40a53f4 100644 --- a/internal/api/errors.go +++ b/internal/api/errors.go @@ -148,27 +148,28 @@ func httpError(httpStatus int, errorCode ErrorCode, fmtString string, args ...in // Recoverer is a middleware that recovers from panics, logs the panic (and a // backtrace), and returns a HTTP 500 (Internal Server Error) status if // possible. Recoverer prints a request ID if one is provided. -func recoverer(w http.ResponseWriter, r *http.Request) (context.Context, error) { - defer func() { - if rvr := recover(); rvr != nil { - - logEntry := observability.GetLogEntry(r) - if logEntry != nil { - logEntry.Panic(rvr, debug.Stack()) - } else { - fmt.Fprintf(os.Stderr, "Panic: %+v\n", rvr) - debug.PrintStack() - } +func recoverer(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + defer func() { + if rvr := recover(); rvr != nil { + logEntry := observability.GetLogEntry(r) + if logEntry != nil { + logEntry.Panic(rvr, debug.Stack()) + } else { + fmt.Fprintf(os.Stderr, "Panic: %+v\n", rvr) + debug.PrintStack() + } - se := &HTTPError{ - HTTPStatus: http.StatusInternalServerError, - Message: http.StatusText(http.StatusInternalServerError), + se := &HTTPError{ + HTTPStatus: http.StatusInternalServerError, + Message: http.StatusText(http.StatusInternalServerError), + } + HandleResponseError(se, w, r) } - HandleResponseError(se, w, r) - } - }() - - return nil, nil + }() + next.ServeHTTP(w, r) + } + return http.HandlerFunc(fn) } // ErrorCause is an error interface that contains the method Cause() for returning root cause errors @@ -182,8 +183,8 @@ type HTTPErrorResponse20240101 struct { } func HandleResponseError(err error, w http.ResponseWriter, r *http.Request) { - log := observability.GetLogEntry(r) - errorID := getRequestID(r.Context()) + log := observability.GetLogEntry(r).Entry + errorID := utilities.GetRequestID(r.Context()) apiVersion, averr := DetermineClosestAPIVersion(r.Header.Get(APIVersionHeaderName)) if averr != nil { diff --git a/internal/api/errors_test.go b/internal/api/errors_test.go index fc6135205..5524672e7 100644 --- a/internal/api/errors_test.go +++ b/internal/api/errors_test.go @@ -1,11 +1,16 @@ package api import ( + "bytes" + "encoding/json" "net/http" "net/http/httptest" "testing" + "github.com/sirupsen/logrus" "github.com/stretchr/testify/require" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/observability" ) func TestHandleResponseErrorWithHTTPError(t *testing.T) { @@ -62,3 +67,39 @@ func TestHandleResponseErrorWithHTTPError(t *testing.T) { require.Equal(t, example.ExpectedBody, rec.Body.String()) } } + +func TestRecoverer(t *testing.T) { + var logBuffer bytes.Buffer + config, err := conf.LoadGlobal(apiTestConfig) + require.NoError(t, err) + require.NoError(t, observability.ConfigureLogging(&config.Logging)) + + // logrus should write to the buffer so we can check if the logs are output correctly + logrus.SetOutput(&logBuffer) + panicHandler := recoverer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + panic("test panic") + })) + + w := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodPost, "http://example.com", nil) + require.NoError(t, err) + + panicHandler.ServeHTTP(w, req) + + require.Equal(t, http.StatusInternalServerError, w.Code) + + var data HTTPError + + // panic should return an internal server error + require.NoError(t, json.NewDecoder(w.Body).Decode(&data)) + require.Equal(t, ErrorCodeUnexpectedFailure, data.ErrorCode) + require.Equal(t, http.StatusInternalServerError, data.HTTPStatus) + require.Equal(t, "Internal Server Error", data.Message) + + // panic should log the error message internally + var logs map[string]interface{} + require.NoError(t, json.NewDecoder(&logBuffer).Decode(&logs)) + require.Equal(t, "request panicked", logs["msg"]) + require.Equal(t, "test panic", logs["panic"]) + require.NotEmpty(t, logs["stack"]) +} diff --git a/internal/api/external.go b/internal/api/external.go index ab67d0d93..cf1736f03 100644 --- a/internal/api/external.go +++ b/internal/api/external.go @@ -71,7 +71,7 @@ func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Requ } redirectURL := utilities.GetReferrer(r, config) - log := observability.GetLogEntry(r) + log := observability.GetLogEntry(r).Entry log.WithField("provider", providerType).Info("Redirecting to external provider") if err := validatePKCEParams(codeChallengeMethod, codeChallenge); err != nil { return "", err @@ -573,8 +573,8 @@ func (a *API) Provider(ctx context.Context, name string, scopes string) (provide func (a *API) redirectErrors(handler apiHandler, w http.ResponseWriter, r *http.Request, u *url.URL) { ctx := r.Context() - log := observability.GetLogEntry(r) - errorID := getRequestID(ctx) + log := observability.GetLogEntry(r).Entry + errorID := utilities.GetRequestID(ctx) err := handler(w, r) if err != nil { q := getErrorQueryString(err, errorID, log, u.Query()) diff --git a/internal/api/external_oauth.go b/internal/api/external_oauth.go index 6c0972ea8..af3dd51f4 100644 --- a/internal/api/external_oauth.go +++ b/internal/api/external_oauth.go @@ -69,7 +69,7 @@ func (a *API) oAuthCallback(ctx context.Context, r *http.Request, providerType s return nil, badRequestError(ErrorCodeOAuthProviderNotSupported, "Unsupported provider: %+v", err).WithInternalError(err) } - log := observability.GetLogEntry(r) + log := observability.GetLogEntry(r).Entry log.WithFields(logrus.Fields{ "provider": providerType, "code": oauthCode, diff --git a/internal/api/helpers.go b/internal/api/helpers.go index d771dca40..bcdce1416 100644 --- a/internal/api/helpers.go +++ b/internal/api/helpers.go @@ -6,30 +6,12 @@ import ( "fmt" "net/http" - "github.com/gofrs/uuid" "github.com/pkg/errors" "github.com/supabase/auth/internal/conf" "github.com/supabase/auth/internal/models" "github.com/supabase/auth/internal/utilities" ) -func addRequestID(globalConfig *conf.GlobalConfiguration) middlewareHandler { - return func(w http.ResponseWriter, r *http.Request) (context.Context, error) { - id := "" - if globalConfig.API.RequestIDHeader != "" { - id = r.Header.Get(globalConfig.API.RequestIDHeader) - } - if id == "" { - uid := uuid.Must(uuid.NewV4()) - id = uid.String() - } - - ctx := r.Context() - ctx = withRequestID(ctx, id) - return ctx, nil - } -} - func sendJSON(w http.ResponseWriter, status int, obj interface{}) error { w.Header().Set("Content-Type", "application/json") b, err := json.Marshal(obj) diff --git a/internal/api/hooks.go b/internal/api/hooks.go index 4efc33512..a2f693200 100644 --- a/internal/api/hooks.go +++ b/internal/api/hooks.go @@ -85,7 +85,7 @@ func (a *API) runHTTPHook(r *http.Request, hookConfig conf.ExtensibilityPointCon ctx, cancel := context.WithTimeout(ctx, DefaultHTTPHookTimeout) defer cancel() - log := observability.GetLogEntry(r) + log := observability.GetLogEntry(r).Entry requestURL := hookConfig.URI hookLog := log.WithFields(logrus.Fields{ "component": "auth_hook", diff --git a/internal/api/middleware.go b/internal/api/middleware.go index 94efd156b..8360af0a0 100644 --- a/internal/api/middleware.go +++ b/internal/api/middleware.go @@ -62,7 +62,7 @@ func (a *API) limitHandler(lmt *limiter.Limiter) middlewareHandler { key := req.Header.Get(limitHeader) if key == "" { - log := observability.GetLogEntry(req) + log := observability.GetLogEntry(req).Entry log.WithField("header", limitHeader).Warn("request does not have a value for the rate limiting header, rate limiting is not applied") return c, nil } else { @@ -145,7 +145,7 @@ func (a *API) requireAdminCredentials(w http.ResponseWriter, req *http.Request) return nil, err } - return a.requireAdmin(ctx, req) + return a.requireAdmin(ctx) } func (a *API) requireEmailProvider(w http.ResponseWriter, req *http.Request) (context.Context, error) { @@ -212,7 +212,7 @@ func (a *API) isValidExternalHost(w http.ResponseWriter, req *http.Request) (con } if u, err = url.ParseRequestURI(baseUrl); err != nil { // fallback to the default hostname - log := observability.GetLogEntry(req) + log := observability.GetLogEntry(req).Entry log.WithField("request_url", baseUrl).Warn(err) if u, err = url.ParseRequestURI(config.API.ExternalURL); err != nil { return ctx, err @@ -251,7 +251,7 @@ func (a *API) databaseCleanup(cleanup *models.Cleanup) func(http.Handler) http.H } db := a.db.WithContext(r.Context()) - log := observability.GetLogEntry(r) + log := observability.GetLogEntry(r).Entry affectedRows, err := cleanup.Clean(db) if err != nil { diff --git a/internal/api/samlacs.go b/internal/api/samlacs.go index d50e16a29..0916a7235 100644 --- a/internal/api/samlacs.go +++ b/internal/api/samlacs.go @@ -49,7 +49,7 @@ func (a *API) SAMLACS(w http.ResponseWriter, r *http.Request) error { db := a.db.WithContext(ctx) config := a.config - log := observability.GetLogEntry(r) + log := observability.GetLogEntry(r).Entry relayStateValue := r.FormValue("RelayState") relayStateUUID := uuid.FromStringOrNil(relayStateValue) diff --git a/internal/api/token.go b/internal/api/token.go index 7d94af344..542c68edf 100644 --- a/internal/api/token.go +++ b/internal/api/token.go @@ -153,7 +153,7 @@ func (a *API) ResourceOwnerPasswordGrant(ctx context.Context, w http.ResponseWri if wpe, ok := err.(*WeakPasswordError); ok { weakPasswordError = wpe } else { - observability.GetLogEntry(r).WithError(err).Warn("Password strength check on sign-in failed") + observability.GetLogEntry(r).Entry.WithError(err).Warn("Password strength check on sign-in failed") } } } diff --git a/internal/api/token_oidc.go b/internal/api/token_oidc.go index bb4370402..1c728bf86 100644 --- a/internal/api/token_oidc.go +++ b/internal/api/token_oidc.go @@ -25,7 +25,7 @@ type IdTokenGrantParams struct { } func (p *IdTokenGrantParams) getProvider(ctx context.Context, config *conf.GlobalConfiguration, r *http.Request) (*oidc.Provider, *conf.OAuthProviderConfiguration, string, []string, error) { - log := observability.GetLogEntry(r) + log := observability.GetLogEntry(r).Entry var cfg *conf.OAuthProviderConfiguration var issuer string @@ -113,7 +113,7 @@ func (p *IdTokenGrantParams) getProvider(ctx context.Context, config *conf.Globa // IdTokenGrant implements the id_token grant type flow func (a *API) IdTokenGrant(ctx context.Context, w http.ResponseWriter, r *http.Request) error { - log := observability.GetLogEntry(r) + log := observability.GetLogEntry(r).Entry db := a.db.WithContext(ctx) config := a.config diff --git a/internal/api/verify.go b/internal/api/verify.go index 121cb4e2c..8a857f685 100644 --- a/internal/api/verify.go +++ b/internal/api/verify.go @@ -433,8 +433,8 @@ func (a *API) prepErrorRedirectURL(err *HTTPError, r *http.Request, rurl string, // Maintain separate query params for hash and query hq := url.Values{} - log := observability.GetLogEntry(r) - errorID := getRequestID(r.Context()) + log := observability.GetLogEntry(r).Entry + errorID := utilities.GetRequestID(r.Context()) err.ErrorID = errorID log.WithError(err.Cause()).Info(err.Error()) if str, ok := oauthErrorMap[err.HTTPStatus]; ok { diff --git a/internal/observability/logging.go b/internal/observability/logging.go index 2e2595679..ff8ac96ea 100644 --- a/internal/observability/logging.go +++ b/internal/observability/logging.go @@ -3,6 +3,7 @@ package observability import ( "os" "sync" + "time" "github.com/bombsimon/logrusr/v3" "github.com/gobuffalo/pop/v6" @@ -22,11 +23,31 @@ var ( loggingOnce sync.Once ) +type CustomFormatter struct { + logrus.JSONFormatter +} + +func NewCustomFormatter() *CustomFormatter { + return &CustomFormatter{ + JSONFormatter: logrus.JSONFormatter{ + DisableTimestamp: false, + TimestampFormat: time.RFC3339, + }, + } +} + +func (f *CustomFormatter) Format(entry *logrus.Entry) ([]byte, error) { + // logrus doesn't support formatting the time in UTC so we need to use a custom formatter + entry.Time = entry.Time.UTC() + return f.JSONFormatter.Format(entry) +} + func ConfigureLogging(config *conf.LoggingConfig) error { var err error loggingOnce.Do(func() { - logrus.SetFormatter(&logrus.JSONFormatter{}) + formatter := NewCustomFormatter() + logrus.SetFormatter(formatter) // use a file if you want if config.File != "" { diff --git a/internal/observability/request-logger.go b/internal/observability/request-logger.go index b928d85bc..3e7a7e356 100644 --- a/internal/observability/request-logger.go +++ b/internal/observability/request-logger.go @@ -6,13 +6,37 @@ import ( "time" chimiddleware "github.com/go-chi/chi/middleware" + "github.com/gofrs/uuid" "github.com/sirupsen/logrus" "github.com/supabase/auth/internal/conf" "github.com/supabase/auth/internal/utilities" ) +func AddRequestID(globalConfig *conf.GlobalConfiguration) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + id := uuid.Must(uuid.NewV4()).String() + if globalConfig.API.RequestIDHeader != "" { + id = r.Header.Get(globalConfig.API.RequestIDHeader) + } + ctx := r.Context() + ctx = utilities.WithRequestID(ctx, id) + next.ServeHTTP(w, r.WithContext(ctx)) + } + return http.HandlerFunc(fn) + } +} + func NewStructuredLogger(logger *logrus.Logger, config *conf.GlobalConfiguration) func(next http.Handler) http.Handler { - return chimiddleware.RequestLogger(&structuredLogger{logger, config}) + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/health" { + next.ServeHTTP(w, r) + } else { + chimiddleware.RequestLogger(&structuredLogger{logger, config})(next).ServeHTTP(w, r) + } + }) + } } type structuredLogger struct { @@ -22,65 +46,62 @@ type structuredLogger struct { func (l *structuredLogger) NewLogEntry(r *http.Request) chimiddleware.LogEntry { referrer := utilities.GetReferrer(r, l.Config) - entry := &structuredLoggerEntry{Logger: logrus.NewEntry(l.Logger)} + e := &logEntry{Entry: logrus.NewEntry(l.Logger)} logFields := logrus.Fields{ "component": "api", "method": r.Method, "path": r.URL.Path, "remote_addr": utilities.GetIPAddress(r), "referer": referrer, - "timestamp": time.Now().UTC().Format(time.RFC3339), } - if reqID := r.Context().Value("request_id"); reqID != nil { - logFields["request_id"] = reqID.(string) + if reqID := utilities.GetRequestID(r.Context()); reqID != "" { + logFields["request_id"] = reqID } - entry.Logger = entry.Logger.WithFields(logFields) - entry.Logger.Infoln("request started") - return entry + e.Entry = e.Entry.WithFields(logFields) + return e } -type structuredLoggerEntry struct { - Logger logrus.FieldLogger +// logEntry implements the chiMiddleware.LogEntry interface +type logEntry struct { + Entry *logrus.Entry } -func (l *structuredLoggerEntry) Write(status, bytes int, elapsed time.Duration) { - l.Logger = l.Logger.WithFields(logrus.Fields{ +func (e *logEntry) Write(status, bytes int, elapsed time.Duration) { + entry := e.Entry.WithFields(logrus.Fields{ "status": status, "duration": elapsed.Nanoseconds(), }) - - l.Logger.Info("request completed") + entry.Info("request completed") + e.Entry = entry } -func (l *structuredLoggerEntry) Panic(v interface{}, stack []byte) { - l.Logger.WithFields(logrus.Fields{ +func (e *logEntry) Panic(v interface{}, stack []byte) { + entry := e.Entry.WithFields(logrus.Fields{ "stack": string(stack), "panic": fmt.Sprintf("%+v", v), - }).Panic("unhandled request panic") + }) + entry.Error("request panicked") + e.Entry = entry } -func GetLogEntry(r *http.Request) logrus.FieldLogger { - entry, _ := chimiddleware.GetLogEntry(r).(*structuredLoggerEntry) - if entry == nil { - return logrus.NewEntry(logrus.StandardLogger()) +func GetLogEntry(r *http.Request) *logEntry { + l, _ := chimiddleware.GetLogEntry(r).(*logEntry) + if l == nil { + return &logEntry{Entry: logrus.NewEntry(logrus.StandardLogger())} } - return entry.Logger + return l } -func LogEntrySetField(r *http.Request, key string, value interface{}) logrus.FieldLogger { - if entry, ok := r.Context().Value(chimiddleware.LogEntryCtxKey).(*structuredLoggerEntry); ok { - entry.Logger = entry.Logger.WithField(key, value) - return entry.Logger +func LogEntrySetField(r *http.Request, key string, value interface{}) { + if l, ok := r.Context().Value(chimiddleware.LogEntryCtxKey).(*logEntry); ok { + l.Entry = l.Entry.WithField(key, value) } - return nil } -func LogEntrySetFields(r *http.Request, fields logrus.Fields) logrus.FieldLogger { - if entry, ok := r.Context().Value(chimiddleware.LogEntryCtxKey).(*structuredLoggerEntry); ok { - entry.Logger = entry.Logger.WithFields(fields) - return entry.Logger +func LogEntrySetFields(r *http.Request, fields logrus.Fields) { + if l, ok := r.Context().Value(chimiddleware.LogEntryCtxKey).(*logEntry); ok { + l.Entry = l.Entry.WithFields(fields) } - return nil } diff --git a/internal/observability/request-logger_test.go b/internal/observability/request-logger_test.go new file mode 100644 index 000000000..7ab244c3f --- /dev/null +++ b/internal/observability/request-logger_test.go @@ -0,0 +1,72 @@ +package observability + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" + "github.com/supabase/auth/internal/conf" +) + +const apiTestConfig = "../../hack/test.env" + +func TestLogger(t *testing.T) { + var logBuffer bytes.Buffer + config, err := conf.LoadGlobal(apiTestConfig) + require.NoError(t, err) + + config.Logging.Level = "info" + require.NoError(t, ConfigureLogging(&config.Logging)) + + // logrus should write to the buffer so we can check if the logs are output correctly + logrus.SetOutput(&logBuffer) + + // add request id header + config.API.RequestIDHeader = "X-Request-ID" + addRequestIdHandler := AddRequestID(config) + + logHandler := NewStructuredLogger(logrus.StandardLogger(), config)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + w := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodPost, "http://example.com/path", nil) + req.Header.Add("X-Request-ID", "test-request-id") + require.NoError(t, err) + addRequestIdHandler(logHandler).ServeHTTP(w, req) + require.Equal(t, http.StatusOK, w.Code) + + var logs map[string]interface{} + require.NoError(t, json.NewDecoder(&logBuffer).Decode(&logs)) + require.Equal(t, "api", logs["component"]) + require.Equal(t, http.MethodPost, logs["method"]) + require.Equal(t, "/path", logs["path"]) + require.Equal(t, "test-request-id", logs["request_id"]) + require.NotNil(t, logs["time"]) +} + +func TestExcludeHealthFromLogs(t *testing.T) { + var logBuffer bytes.Buffer + config, err := conf.LoadGlobal(apiTestConfig) + require.NoError(t, err) + + config.Logging.Level = "info" + require.NoError(t, ConfigureLogging(&config.Logging)) + + // logrus should write to the buffer so we can check if the logs are output correctly + logrus.SetOutput(&logBuffer) + + logHandler := NewStructuredLogger(logrus.StandardLogger(), config)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("ok")) + })) + w := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodGet, "http://example.com/health", nil) + require.NoError(t, err) + logHandler.ServeHTTP(w, req) + require.Equal(t, http.StatusOK, w.Code) + + require.Empty(t, logBuffer) +} diff --git a/internal/utilities/context.go b/internal/utilities/context.go new file mode 100644 index 000000000..2818fdd6c --- /dev/null +++ b/internal/utilities/context.go @@ -0,0 +1,28 @@ +package utilities + +import "context" + +type contextKey string + +func (c contextKey) String() string { + return "gotrue api context key " + string(c) +} + +const ( + requestIDKey = contextKey("request_id") +) + +// WithRequestID adds the provided request ID to the context. +func WithRequestID(ctx context.Context, id string) context.Context { + return context.WithValue(ctx, requestIDKey, id) +} + +// GetRequestID reads the request ID from the context. +func GetRequestID(ctx context.Context) string { + obj := ctx.Value(requestIDKey) + if obj == nil { + return "" + } + + return obj.(string) +}