Skip to content

Commit 4b5bc8d

Browse files
authored
Merge branch 'master' into cs/feat-percentage-based-db-conn-limits
2 parents f813ebd + 1f804a2 commit 4b5bc8d

File tree

13 files changed

+802
-31
lines changed

13 files changed

+802
-31
lines changed

internal/api/admin.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -514,6 +514,7 @@ func (a *API) adminUserDelete(w http.ResponseWriter, r *http.Request) error {
514514
user := getUser(ctx)
515515
config := a.config
516516
adminUser := getAdminUser(ctx)
517+
db := a.db.WithContext(ctx)
517518

518519
// ShouldSoftDelete defaults to false
519520
params := &adminUserDeleteParams{}
@@ -525,7 +526,7 @@ func (a *API) adminUserDelete(w http.ResponseWriter, r *http.Request) error {
525526
}
526527
}
527528

528-
err := a.db.Transaction(func(tx *storage.Connection) error {
529+
err := db.Transaction(func(tx *storage.Connection) error {
529530
if terr := models.NewAuditLogEntry(config.AuditLog, r, tx, adminUser, models.UserDeletedAction, "", map[string]interface{}{
530531
"user_id": user.ID,
531532
"user_email": user.Email,
@@ -575,8 +576,9 @@ func (a *API) adminUserDeleteFactor(w http.ResponseWriter, r *http.Request) erro
575576
config := a.config
576577
user := getUser(ctx)
577578
factor := getFactor(ctx)
579+
db := a.db.WithContext(ctx)
578580

579-
err := a.db.Transaction(func(tx *storage.Connection) error {
581+
err := db.Transaction(func(tx *storage.Connection) error {
580582
if terr := models.NewAuditLogEntry(config.AuditLog, r, tx, user, models.DeleteFactorAction, r.RemoteAddr, map[string]interface{}{
581583
"user_id": user.ID,
582584
"factor_id": factor.ID,
@@ -608,12 +610,13 @@ func (a *API) adminUserUpdateFactor(w http.ResponseWriter, r *http.Request) erro
608610
user := getUser(ctx)
609611
adminUser := getAdminUser(ctx)
610612
params := &adminUserUpdateFactorParams{}
613+
db := a.db.WithContext(ctx)
611614

612615
if err := retrieveRequestParams(r, params); err != nil {
613616
return err
614617
}
615618

616-
err := a.db.Transaction(func(tx *storage.Connection) error {
619+
err := db.Transaction(func(tx *storage.Connection) error {
617620
if params.FriendlyName != "" {
618621
if terr := factor.UpdateFriendlyName(tx, params.FriendlyName); terr != nil {
619622
return terr

internal/api/external.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Requ
8383

8484
flowStateID := ""
8585
if isPKCEFlow(flowType) {
86-
flowState, err := generateFlowState(a.db, providerType, models.OAuth, codeChallengeMethod, codeChallenge, nil)
86+
flowState, err := generateFlowState(db, providerType, models.OAuth, codeChallengeMethod, codeChallenge, nil)
8787
if err != nil {
8888
return "", err
8989
}
@@ -200,7 +200,7 @@ func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Re
200200
var flowState *models.FlowState
201201
// if there's a non-empty FlowStateID we perform PKCE Flow
202202
if flowStateID := getFlowStateID(ctx); flowStateID != "" {
203-
flowState, err = models.FindFlowStateByID(a.db, flowStateID)
203+
flowState, err = models.FindFlowStateByID(db, flowStateID)
204204
if models.IsNotFoundError(err) {
205205
return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeFlowStateNotFound, "Flow state not found").WithInternalError(err)
206206
} else if err != nil {
@@ -506,7 +506,7 @@ func (a *API) processInvite(r *http.Request, tx *storage.Connection, userData *p
506506
return user, nil
507507
}
508508

509-
func (a *API) loadExternalState(ctx context.Context, r *http.Request) (context.Context, error) {
509+
func (a *API) loadExternalState(ctx context.Context, r *http.Request, db *storage.Connection) (context.Context, error) {
510510
var state string
511511
switch r.Method {
512512
case http.MethodPost:
@@ -564,7 +564,7 @@ func (a *API) loadExternalState(ctx context.Context, r *http.Request) (context.C
564564
if err != nil {
565565
return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeBadOAuthState, "OAuth callback with invalid state (linking_target_id must be UUID)")
566566
}
567-
u, err := models.FindUserByID(a.db, linkingTargetUserID)
567+
u, err := models.FindUserByID(db, linkingTargetUserID)
568568
if err != nil {
569569
if models.IsNotFoundError(err) {
570570
return nil, apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeUserNotFound, "Linking target user not found")

internal/api/external_oauth.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ type OAuthProviderData struct {
2727
// extracting the provider requested
2828
func (a *API) loadFlowState(w http.ResponseWriter, r *http.Request) (context.Context, error) {
2929
ctx := r.Context()
30+
db := a.db.WithContext(ctx)
31+
3032
oauthToken := r.URL.Query().Get("oauth_token")
3133
if oauthToken != "" {
3234
ctx = withRequestToken(ctx, oauthToken)
@@ -37,7 +39,7 @@ func (a *API) loadFlowState(w http.ResponseWriter, r *http.Request) (context.Con
3739
}
3840

3941
var err error
40-
ctx, err = a.loadExternalState(ctx, r)
42+
ctx, err = a.loadExternalState(ctx, r, db)
4143
if err != nil {
4244
u, uerr := url.ParseRequestURI(a.config.SiteURL)
4345
if uerr != nil {

internal/api/hooks.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,31 +14,31 @@ import (
1414

1515
func (a *API) triggerBeforeUserCreated(
1616
r *http.Request,
17-
conn *storage.Connection,
17+
db *storage.Connection,
1818
user *models.User,
1919
) error {
2020
if !a.hooksMgr.Enabled(v0hooks.BeforeUserCreated) {
2121
return nil
2222
}
23-
if err := checkTX(conn); err != nil {
23+
if err := checkTX(db); err != nil {
2424
return err
2525
}
2626

2727
req := v0hooks.NewBeforeUserCreatedInput(r, user)
2828
res := new(v0hooks.BeforeUserCreatedOutput)
29-
return a.hooksMgr.InvokeHook(conn, r, req, res)
29+
return a.hooksMgr.InvokeHook(db, r, req, res)
3030
}
3131

3232
func (a *API) triggerBeforeUserCreatedExternal(
3333
r *http.Request,
34-
conn *storage.Connection,
34+
db *storage.Connection,
3535
userData *provider.UserProvidedData,
3636
providerType string,
3737
) error {
3838
if !a.hooksMgr.Enabled(v0hooks.BeforeUserCreated) {
3939
return nil
4040
}
41-
if err := checkTX(conn); err != nil {
41+
if err := checkTX(db); err != nil {
4242
return err
4343
}
4444

@@ -55,7 +55,7 @@ func (a *API) triggerBeforeUserCreatedExternal(
5555
err error
5656
decision models.AccountLinkingResult
5757
)
58-
err = a.db.Transaction(func(tx *storage.Connection) error {
58+
err = db.Transaction(func(tx *storage.Connection) error {
5959
decision, err = models.DetermineAccountLinking(
6060
tx, config, userData.Emails, aud,
6161
providerType, userData.Metadata.Subject)
@@ -93,7 +93,7 @@ func (a *API) triggerBeforeUserCreatedExternal(
9393
if err != nil {
9494
return err
9595
}
96-
return a.triggerBeforeUserCreated(r, conn, user)
96+
return a.triggerBeforeUserCreated(r, db, user)
9797
}
9898

9999
func checkTX(conn *storage.Connection) error {

internal/api/identity.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515

1616
func (a *API) DeleteIdentity(w http.ResponseWriter, r *http.Request) error {
1717
ctx := r.Context()
18+
db := a.db.WithContext(ctx)
1819
config := a.config
1920

2021
claims := getClaims(ctx)
@@ -49,7 +50,7 @@ func (a *API) DeleteIdentity(w http.ResponseWriter, r *http.Request) error {
4950
return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeIdentityNotFound, "Identity doesn't exist")
5051
}
5152

52-
err = a.db.Transaction(func(tx *storage.Connection) error {
53+
err = db.Transaction(func(tx *storage.Connection) error {
5354
if terr := models.NewAuditLogEntry(config.AuditLog, r, tx, user, models.IdentityUnlinkAction, "", map[string]interface{}{
5455
"identity_id": identityToBeDeleted.ID,
5556
"provider": identityToBeDeleted.Provider,

internal/api/magic_link.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ func (a *API) MagicLink(w http.ResponseWriter, r *http.Request) error {
130130
}
131131

132132
if isPKCEFlow(flowType) {
133-
if _, err = generateFlowState(a.db, models.MagicLink.String(), models.MagicLink, params.CodeChallengeMethod, params.CodeChallenge, &user.ID); err != nil {
133+
if _, err = generateFlowState(db, models.MagicLink.String(), models.MagicLink, params.CodeChallengeMethod, params.CodeChallenge, &user.ID); err != nil {
134134
return err
135135
}
136136
}

internal/api/mfa.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -418,7 +418,7 @@ func (a *API) challengePhoneFactor(w http.ResponseWriter, r *http.Request) error
418418
},
419419
}
420420
output := v0hooks.SendSMSOutput{}
421-
err := a.hooksMgr.InvokeHook(a.db, r, &input, &output)
421+
err := a.hooksMgr.InvokeHook(db, r, &input, &output)
422422
if err != nil {
423423
return apierrors.NewInternalServerError("error invoking hook")
424424
}

internal/api/oauthserver/authorize.go

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ type AuthorizeParams struct {
3535
// AuthorizationDetailsResponse represents the response for getting authorization details
3636
type AuthorizationDetailsResponse struct {
3737
AuthorizationID string `json:"authorization_id"`
38+
RedirectURI string `json:"redirect_uri,omitempty"`
3839
Client ClientDetailsResponse `json:"client,omitempty"`
3940
User UserDetailsResponse `json:"user,omitempty"`
4041
Scope string `json:"scope,omitempty"`
@@ -234,6 +235,7 @@ func (s *Server) OAuthServerGetAuthorization(w http.ResponseWriter, r *http.Requ
234235
// Build response with client and user details
235236
response := AuthorizationDetailsResponse{
236237
AuthorizationID: authorization.AuthorizationID,
238+
RedirectURI: authorization.RedirectURI,
237239
Client: ClientDetailsResponse{
238240
ClientID: authorization.Client.ID.String(),
239241
ClientName: utilities.StringValue(authorization.Client.ClientName),
@@ -369,13 +371,15 @@ func (s *Server) OAuthServerConsent(w http.ResponseWriter, r *http.Request) erro
369371

370372
// validateRequestOrigin checks if the request is coming from an authorized origin
371373
func (s *Server) validateRequestOrigin(r *http.Request) error {
372-
// Check referer header
373-
referer := r.Referer()
374-
if referer == "" {
375-
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "request must originate from authorized domain")
374+
// Check Origin header
375+
// browsers add this header by default, we can at least prevent some basic cross-origin attacks
376+
origin := r.Header.Get("Origin")
377+
if origin == "" {
378+
// Empty Origin header is ok (e.g., for backend-originated requests or mobile apps)
379+
return nil
376380
}
377381

378-
if !utilities.IsRedirectURLValid(s.config, referer) {
382+
if !utilities.IsRedirectURLValid(s.config, origin) {
379383
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "unauthorized request origin")
380384
}
381385

0 commit comments

Comments
 (0)