Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: improve logging structure #1583

Merged
merged 8 commits into from
May 16, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 4 additions & 7 deletions internal/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,21 +98,20 @@ func NewAPIWithVersion(ctx context.Context, globalConfig *conf.GlobalConfigurati
logger := observability.NewStructuredLogger(logrus.StandardLogger(), globalConfig)

r := newRouter()
r.UseBypass(observability.AddRequestID(globalConfig))
r.UseBypass(logger)
r.UseBypass(xffmw.Handler)
r.UseBypass(recoverer)

if globalConfig.API.MaxRequestDuration > 0 {
r.UseBypass(api.timeoutMiddleware(globalConfig.API.MaxRequestDuration))
}

r.Use(addRequestID(globalConfig))

// request tracing should be added only when tracing or metrics is enabled
if globalConfig.Tracing.Enabled || globalConfig.Metrics.Enabled {
r.UseBypass(observability.RequestTracing())
}

r.UseBypass(xffmw.Handler)
r.Use(recoverer)

if globalConfig.DB.CleanupEnabled {
cleanup := models.NewCleanup(globalConfig)
r.UseBypass(api.databaseCleanup(cleanup))
Expand All @@ -121,7 +120,6 @@ func NewAPIWithVersion(ctx context.Context, globalConfig *conf.GlobalConfigurati
r.Get("/health", api.HealthCheck)

r.Route("/callback", func(r *router) {
r.UseBypass(logger)
r.Use(api.isValidExternalHost)
r.Use(api.loadFlowState)

Expand All @@ -130,7 +128,6 @@ func NewAPIWithVersion(ctx context.Context, globalConfig *conf.GlobalConfigurati
})

r.Route("/", func(r *router) {
r.UseBypass(logger)
r.Use(api.isValidExternalHost)

r.Get("/settings", api.Settings)
Expand Down
8 changes: 3 additions & 5 deletions internal/api/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import (
"context"
"fmt"
"net/http"
"time"
"strings"

"github.com/gofrs/uuid"
jwt "github.com/golang-jwt/jwt"
Expand Down Expand Up @@ -44,11 +44,10 @@ func (a *API) requireNotAnonymous(w http.ResponseWriter, r *http.Request) (conte
return ctx, nil
}

func (a *API) requireAdmin(ctx context.Context, r *http.Request) (context.Context, error) {
func (a *API) requireAdmin(ctx context.Context) (context.Context, error) {
// Find the administrative user
claims := getClaims(ctx)
if claims == nil {
fmt.Printf("[%s] %s %s %d %s\n", time.Now().Format("2006-01-02 15:04:05"), r.Method, r.RequestURI, http.StatusForbidden, "Invalid token")
return nil, forbiddenError(ErrorCodeBadJWT, "Invalid token")
}

Expand All @@ -59,8 +58,7 @@ func (a *API) requireAdmin(ctx context.Context, r *http.Request) (context.Contex
return withAdminUser(ctx, &models.User{Role: claims.Role, Email: storage.NullString(claims.Role)}), nil
}

fmt.Printf("[%s] %s %s %d %s\n", time.Now().Format("2006-01-02 15:04:05"), r.Method, r.RequestURI, http.StatusForbidden, "this token needs role 'supabase_admin' or 'service_role'")
return nil, forbiddenError(ErrorCodeNotAdmin, "User not allowed")
return nil, forbiddenError(ErrorCodeNotAdmin, "User not allowed").WithInternalMessage(fmt.Sprintf("this token needs to have one of the following roles: %v", strings.Join(adminRoles, ", ")))
}

func (a *API) extractBearerToken(r *http.Request) (string, error) {
Expand Down
16 changes: 0 additions & 16 deletions internal/api/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ func (c contextKey) String() string {

const (
tokenKey = contextKey("jwt")
requestIDKey = contextKey("request_id")
inviteTokenKey = contextKey("invite_token")
signatureKey = contextKey("signature")
externalProviderTypeKey = contextKey("external_provider_type")
Expand Down Expand Up @@ -57,21 +56,6 @@ func getClaims(ctx context.Context) *AccessTokenClaims {
return token.Claims.(*AccessTokenClaims)
}

// withRequestID adds the provided request ID to the context.
func withRequestID(ctx context.Context, id string) context.Context {
return context.WithValue(ctx, requestIDKey, id)
}

// getRequestID reads the request ID from the context.
func getRequestID(ctx context.Context) string {
obj := ctx.Value(requestIDKey)
if obj == nil {
return ""
}

return obj.(string)
}

// withUser adds the user to the context.
func withUser(ctx context.Context, u *models.User) context.Context {
return context.WithValue(ctx, userKey, u)
Expand Down
43 changes: 22 additions & 21 deletions internal/api/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,27 +148,28 @@ func httpError(httpStatus int, errorCode ErrorCode, fmtString string, args ...in
// Recoverer is a middleware that recovers from panics, logs the panic (and a
// backtrace), and returns a HTTP 500 (Internal Server Error) status if
// possible. Recoverer prints a request ID if one is provided.
func recoverer(w http.ResponseWriter, r *http.Request) (context.Context, error) {
defer func() {
if rvr := recover(); rvr != nil {

logEntry := observability.GetLogEntry(r)
if logEntry != nil {
logEntry.Panic(rvr, debug.Stack())
} else {
fmt.Fprintf(os.Stderr, "Panic: %+v\n", rvr)
debug.PrintStack()
}
func recoverer(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
defer func() {
if rvr := recover(); rvr != nil {
logEntry := observability.GetLogEntry(r)
if logEntry != nil {
logEntry.Panic(rvr, debug.Stack())
} else {
fmt.Fprintf(os.Stderr, "Panic: %+v\n", rvr)
debug.PrintStack()
}

se := &HTTPError{
HTTPStatus: http.StatusInternalServerError,
Message: http.StatusText(http.StatusInternalServerError),
se := &HTTPError{
HTTPStatus: http.StatusInternalServerError,
Message: http.StatusText(http.StatusInternalServerError),
}
HandleResponseError(se, w, r)
}
HandleResponseError(se, w, r)
}
}()

return nil, nil
}()
next.ServeHTTP(w, r)
}
return http.HandlerFunc(fn)
}

// ErrorCause is an error interface that contains the method Cause() for returning root cause errors
Expand All @@ -182,8 +183,8 @@ type HTTPErrorResponse20240101 struct {
}

func HandleResponseError(err error, w http.ResponseWriter, r *http.Request) {
log := observability.GetLogEntry(r)
errorID := getRequestID(r.Context())
log := observability.GetLogEntry(r).Entry
errorID := utilities.GetRequestID(r.Context())

apiVersion, averr := DetermineClosestAPIVersion(r.Header.Get(APIVersionHeaderName))
if averr != nil {
Expand Down
41 changes: 41 additions & 0 deletions internal/api/errors_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
package api

import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"

"github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
"github.com/supabase/auth/internal/conf"
"github.com/supabase/auth/internal/observability"
)

func TestHandleResponseErrorWithHTTPError(t *testing.T) {
Expand Down Expand Up @@ -62,3 +67,39 @@ func TestHandleResponseErrorWithHTTPError(t *testing.T) {
require.Equal(t, example.ExpectedBody, rec.Body.String())
}
}

func TestRecoverer(t *testing.T) {
var logBuffer bytes.Buffer
config, err := conf.LoadGlobal(apiTestConfig)
require.NoError(t, err)
require.NoError(t, observability.ConfigureLogging(&config.Logging))

// logrus should write to the buffer so we can check if the logs are output correctly
logrus.SetOutput(&logBuffer)
panicHandler := recoverer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
panic("test panic")
}))

w := httptest.NewRecorder()
req, err := http.NewRequest(http.MethodPost, "http://example.com", nil)
require.NoError(t, err)

panicHandler.ServeHTTP(w, req)

require.Equal(t, http.StatusInternalServerError, w.Code)

var data HTTPError

// panic should return an internal server error
require.NoError(t, json.NewDecoder(w.Body).Decode(&data))
require.Equal(t, ErrorCodeUnexpectedFailure, data.ErrorCode)
require.Equal(t, http.StatusInternalServerError, data.HTTPStatus)
require.Equal(t, "Internal Server Error", data.Message)

// panic should log the error message internally
var logs map[string]interface{}
require.NoError(t, json.NewDecoder(&logBuffer).Decode(&logs))
require.Equal(t, "request panicked", logs["msg"])
require.Equal(t, "test panic", logs["panic"])
require.NotEmpty(t, logs["stack"])
}
6 changes: 3 additions & 3 deletions internal/api/external.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Requ
}

redirectURL := utilities.GetReferrer(r, config)
log := observability.GetLogEntry(r)
log := observability.GetLogEntry(r).Entry
log.WithField("provider", providerType).Info("Redirecting to external provider")
if err := validatePKCEParams(codeChallengeMethod, codeChallenge); err != nil {
return "", err
Expand Down Expand Up @@ -573,8 +573,8 @@ func (a *API) Provider(ctx context.Context, name string, scopes string) (provide

func (a *API) redirectErrors(handler apiHandler, w http.ResponseWriter, r *http.Request, u *url.URL) {
ctx := r.Context()
log := observability.GetLogEntry(r)
errorID := getRequestID(ctx)
log := observability.GetLogEntry(r).Entry
errorID := utilities.GetRequestID(ctx)
err := handler(w, r)
if err != nil {
q := getErrorQueryString(err, errorID, log, u.Query())
Expand Down
2 changes: 1 addition & 1 deletion internal/api/external_oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func (a *API) oAuthCallback(ctx context.Context, r *http.Request, providerType s
return nil, badRequestError(ErrorCodeOAuthProviderNotSupported, "Unsupported provider: %+v", err).WithInternalError(err)
}

log := observability.GetLogEntry(r)
log := observability.GetLogEntry(r).Entry
log.WithFields(logrus.Fields{
"provider": providerType,
"code": oauthCode,
Expand Down
18 changes: 0 additions & 18 deletions internal/api/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,30 +6,12 @@ import (
"fmt"
"net/http"

"github.com/gofrs/uuid"
"github.com/pkg/errors"
"github.com/supabase/auth/internal/conf"
"github.com/supabase/auth/internal/models"
"github.com/supabase/auth/internal/utilities"
)

func addRequestID(globalConfig *conf.GlobalConfiguration) middlewareHandler {
return func(w http.ResponseWriter, r *http.Request) (context.Context, error) {
id := ""
if globalConfig.API.RequestIDHeader != "" {
id = r.Header.Get(globalConfig.API.RequestIDHeader)
}
if id == "" {
uid := uuid.Must(uuid.NewV4())
id = uid.String()
}

ctx := r.Context()
ctx = withRequestID(ctx, id)
return ctx, nil
}
}

func sendJSON(w http.ResponseWriter, status int, obj interface{}) error {
w.Header().Set("Content-Type", "application/json")
b, err := json.Marshal(obj)
Expand Down
2 changes: 1 addition & 1 deletion internal/api/hooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func (a *API) runHTTPHook(r *http.Request, hookConfig conf.ExtensibilityPointCon
ctx, cancel := context.WithTimeout(ctx, DefaultHTTPHookTimeout)
defer cancel()

log := observability.GetLogEntry(r)
log := observability.GetLogEntry(r).Entry
requestURL := hookConfig.URI
hookLog := log.WithFields(logrus.Fields{
"component": "auth_hook",
Expand Down
8 changes: 4 additions & 4 deletions internal/api/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func (a *API) limitHandler(lmt *limiter.Limiter) middlewareHandler {
key := req.Header.Get(limitHeader)

if key == "" {
log := observability.GetLogEntry(req)
log := observability.GetLogEntry(req).Entry
log.WithField("header", limitHeader).Warn("request does not have a value for the rate limiting header, rate limiting is not applied")
return c, nil
} else {
Expand Down Expand Up @@ -145,7 +145,7 @@ func (a *API) requireAdminCredentials(w http.ResponseWriter, req *http.Request)
return nil, err
}

return a.requireAdmin(ctx, req)
return a.requireAdmin(ctx)
}

func (a *API) requireEmailProvider(w http.ResponseWriter, req *http.Request) (context.Context, error) {
Expand Down Expand Up @@ -212,7 +212,7 @@ func (a *API) isValidExternalHost(w http.ResponseWriter, req *http.Request) (con
}
if u, err = url.ParseRequestURI(baseUrl); err != nil {
// fallback to the default hostname
log := observability.GetLogEntry(req)
log := observability.GetLogEntry(req).Entry
log.WithField("request_url", baseUrl).Warn(err)
if u, err = url.ParseRequestURI(config.API.ExternalURL); err != nil {
return ctx, err
Expand Down Expand Up @@ -251,7 +251,7 @@ func (a *API) databaseCleanup(cleanup *models.Cleanup) func(http.Handler) http.H
}

db := a.db.WithContext(r.Context())
log := observability.GetLogEntry(r)
log := observability.GetLogEntry(r).Entry

affectedRows, err := cleanup.Clean(db)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion internal/api/samlacs.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func (a *API) SAMLACS(w http.ResponseWriter, r *http.Request) error {

db := a.db.WithContext(ctx)
config := a.config
log := observability.GetLogEntry(r)
log := observability.GetLogEntry(r).Entry

relayStateValue := r.FormValue("RelayState")
relayStateUUID := uuid.FromStringOrNil(relayStateValue)
Expand Down
2 changes: 1 addition & 1 deletion internal/api/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ func (a *API) ResourceOwnerPasswordGrant(ctx context.Context, w http.ResponseWri
if wpe, ok := err.(*WeakPasswordError); ok {
weakPasswordError = wpe
} else {
observability.GetLogEntry(r).WithError(err).Warn("Password strength check on sign-in failed")
observability.GetLogEntry(r).Entry.WithError(err).Warn("Password strength check on sign-in failed")
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions internal/api/token_oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ type IdTokenGrantParams struct {
}

func (p *IdTokenGrantParams) getProvider(ctx context.Context, config *conf.GlobalConfiguration, r *http.Request) (*oidc.Provider, *conf.OAuthProviderConfiguration, string, []string, error) {
log := observability.GetLogEntry(r)
log := observability.GetLogEntry(r).Entry

var cfg *conf.OAuthProviderConfiguration
var issuer string
Expand Down Expand Up @@ -113,7 +113,7 @@ func (p *IdTokenGrantParams) getProvider(ctx context.Context, config *conf.Globa

// IdTokenGrant implements the id_token grant type flow
func (a *API) IdTokenGrant(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
log := observability.GetLogEntry(r)
log := observability.GetLogEntry(r).Entry

db := a.db.WithContext(ctx)
config := a.config
Expand Down
4 changes: 2 additions & 2 deletions internal/api/verify.go
Original file line number Diff line number Diff line change
Expand Up @@ -426,8 +426,8 @@ func (a *API) prepErrorRedirectURL(err *HTTPError, r *http.Request, rurl string,

// Maintain separate query params for hash and query
hq := url.Values{}
log := observability.GetLogEntry(r)
errorID := getRequestID(r.Context())
log := observability.GetLogEntry(r).Entry
errorID := utilities.GetRequestID(r.Context())
err.ErrorID = errorID
log.WithError(err.Cause()).Info(err.Error())
if str, ok := oauthErrorMap[err.HTTPStatus]; ok {
Expand Down
Loading
Loading