Skip to content

Commit

Permalink
fix: redirect invalid state errors to site url (#1722)
Browse files Browse the repository at this point in the history
  • Loading branch information
kangmingtay authored Aug 16, 2024
1 parent 4351226 commit b2b1123
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 22 deletions.
2 changes: 1 addition & 1 deletion internal/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne
tollbooth.NewLimiter(api.config.SAML.RateLimitAssertion/(60*5), &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30),
)).Post("/acs", api.SAMLACS)
)).Post("/acs", api.SamlAcs)
})
})

Expand Down
20 changes: 15 additions & 5 deletions internal/api/external.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ func (a *API) ExternalProviderCallback(w http.ResponseWriter, r *http.Request) e
if err != nil {
return err
}
a.redirectErrors(a.internalExternalProviderCallback, w, r, u)
redirectErrors(a.internalExternalProviderCallback, w, r, u)
return nil
}

Expand Down Expand Up @@ -478,18 +478,28 @@ func (a *API) processInvite(r *http.Request, tx *storage.Connection, userData *p
return user, nil
}

func (a *API) loadExternalState(ctx context.Context, state string) (context.Context, error) {
func (a *API) loadExternalState(ctx context.Context, r *http.Request) (context.Context, error) {
var state string
switch r.Method {
case http.MethodPost:
state = r.FormValue("state")
default:
state = r.URL.Query().Get("state")
}
if state == "" {
return ctx, badRequestError(ErrorCodeBadOAuthCallback, "OAuth state parameter missing")
}
config := a.config
claims := ExternalProviderClaims{}
p := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name}))
_, err := p.ParseWithClaims(state, &claims, func(token *jwt.Token) (interface{}, error) {
return []byte(config.JWT.Secret), nil
})
if err != nil {
return nil, badRequestError(ErrorCodeBadOAuthState, "OAuth callback with invalid state").WithInternalError(err)
return ctx, badRequestError(ErrorCodeBadOAuthState, "OAuth callback with invalid state").WithInternalError(err)
}
if claims.Provider == "" {
return nil, badRequestError(ErrorCodeBadOAuthState, "OAuth callback with invalid state (missing provider)")
return ctx, badRequestError(ErrorCodeBadOAuthState, "OAuth callback with invalid state (missing provider)")
}
if claims.InviteToken != "" {
ctx = withInviteToken(ctx, claims.InviteToken)
Expand Down Expand Up @@ -573,7 +583,7 @@ func (a *API) Provider(ctx context.Context, name string, scopes string) (provide
}
}

func (a *API) redirectErrors(handler apiHandler, w http.ResponseWriter, r *http.Request, u *url.URL) {
func redirectErrors(handler apiHandler, w http.ResponseWriter, r *http.Request, u *url.URL) {
ctx := r.Context()
log := observability.GetLogEntry(r).Entry
errorID := utilities.GetRequestID(ctx)
Expand Down
28 changes: 16 additions & 12 deletions internal/api/external_oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/sirupsen/logrus"
"github.com/supabase/auth/internal/api/provider"
"github.com/supabase/auth/internal/observability"
"github.com/supabase/auth/internal/utilities"
)

// OAuthProviderData contains the userData and token returned by the oauth provider
Expand All @@ -23,17 +24,6 @@ type OAuthProviderData struct {
// loadFlowState parses the `state` query parameter as a JWS payload,
// extracting the provider requested
func (a *API) loadFlowState(w http.ResponseWriter, r *http.Request) (context.Context, error) {
var state string
if r.Method == http.MethodPost {
state = r.FormValue("state")
} else {
state = r.URL.Query().Get("state")
}

if state == "" {
return nil, badRequestError(ErrorCodeBadOAuthCallback, "OAuth state parameter missing")
}

ctx := r.Context()
oauthToken := r.URL.Query().Get("oauth_token")
if oauthToken != "" {
Expand All @@ -43,7 +33,21 @@ func (a *API) loadFlowState(w http.ResponseWriter, r *http.Request) (context.Con
if oauthVerifier != "" {
ctx = withOAuthVerifier(ctx, oauthVerifier)
}
return a.loadExternalState(ctx, state)

var err error
ctx, err = a.loadExternalState(ctx, r)
if err != nil {
u, uerr := url.ParseRequestURI(a.config.SiteURL)
if uerr != nil {
return ctx, internalServerError("site url is improperly formatted").WithInternalError(uerr)
}

q := getErrorQueryString(err, utilities.GetRequestID(ctx), observability.GetLogEntry(r).Entry, u.Query())
u.RawQuery = q.Encode()

http.Redirect(w, r, u.String(), http.StatusSeeOther)
}
return ctx, err
}

func (a *API) oAuthCallback(ctx context.Context, r *http.Request, providerType string) (*OAuthProviderData, error) {
Expand Down
2 changes: 1 addition & 1 deletion internal/api/external_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ func (ts *ExternalTestSuite) TestRedirectErrorsShouldPreserveParams() {
parsedURL, err := url.Parse(c.RedirectURL)
require.Equal(ts.T(), err, nil)

ts.API.redirectErrors(ts.API.internalExternalProviderCallback, w, req, parsedURL)
redirectErrors(ts.API.internalExternalProviderCallback, w, req, parsedURL)

parsedParams, err := url.ParseQuery(parsedURL.RawQuery)
require.Equal(ts.T(), err, nil)
Expand Down
19 changes: 16 additions & 3 deletions internal/api/samlacs.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,22 @@ func IsSAMLMetadataStale(idpMetadata *saml.EntityDescriptor, samlProvider models
return hasValidityExpired || hasCacheDurationExceeded || needsForceUpdate
}

// SAMLACS implements the main Assertion Consumer Service endpoint behavior.
func (a *API) SAMLACS(w http.ResponseWriter, r *http.Request) error {
func (a *API) SamlAcs(w http.ResponseWriter, r *http.Request) error {
if err := a.handleSamlAcs(w, r); err != nil {
u, uerr := url.Parse(a.config.SiteURL)
if uerr != nil {
return internalServerError("site url is improperly formattted").WithInternalError(err)
}

q := getErrorQueryString(err, utilities.GetRequestID(r.Context()), observability.GetLogEntry(r).Entry, u.Query())
u.RawQuery = q.Encode()
http.Redirect(w, r, u.String(), http.StatusSeeOther)
}
return nil
}

// handleSamlAcs implements the main Assertion Consumer Service endpoint behavior.
func (a *API) handleSamlAcs(w http.ResponseWriter, r *http.Request) error {
ctx := r.Context()

db := a.db.WithContext(ctx)
Expand All @@ -61,7 +75,6 @@ func (a *API) SAMLACS(w http.ResponseWriter, r *http.Request) error {
var requestIds []string

var flowState *models.FlowState
flowState = nil
if relayStateUUID != uuid.Nil {
// relay state is a valid UUID, therefore this is likely a SP initiated flow

Expand Down

0 comments on commit b2b1123

Please sign in to comment.