From 20d59f10b601577683d05bcd7d2128ff4bc462a0 Mon Sep 17 00:00:00 2001 From: Kang Ming Date: Wed, 3 Jul 2024 12:13:22 -0700 Subject: [PATCH] feat: add `password_hash` and `id` fields to admin create user (#1641) ## What kind of change does this PR introduce? * Add a `password_hash` field to admin create user, which allows an admin to create a user with a given password hash (argon2 or bcrypt) * Add an `id` field to admin create user, which allows an admin to create a user with a custom id * To prevent someone from creating a bunch of users with a high bcrypt hashing cost, we opt to rehash the password with the default cost (10) on subsequent sign-in. ## What is the current behavior? * Only plaintext passwords are allowed, which will subsequently be hashed internally ## What is the new behavior? Example request using the bcrypt hash of "test": ```bash $ curl -X POST 'http://localhost:9999/admin/users' \ -H 'Authorization: Bearer ' \ -H 'Content-Type: application/json' \ -d '{"email": "foo@example.com", "password_hash": "$2y$10$SXEz2HeT8PUIGQXo9yeUIem8KzNxgG0d7o/.eGj2rj8KbRgAuRVlq"}' ``` Example request using a custom id: ```bash $ curl -X POST 'http://localhost:9999/admin/users' \ -H 'Authorization: Bearer ' \ -H 'Content-Type: application/json' \ -d '{"id": "2a8813c2-bda7-47f0-94a6-49fcfdf61a70", "email": "foo@example.com"}' ``` Feel free to include screenshots if it includes visual changes. ## Additional context Add any other context or screenshots. --- internal/api/admin.go | 65 ++++++++++++++------ internal/api/admin_test.go | 112 ++++++++++++++++++++++++++++++++--- internal/api/token.go | 2 +- internal/api/user.go | 2 +- internal/api/user_test.go | 6 +- internal/crypto/password.go | 73 +++++++++++++++-------- internal/models/user.go | 44 +++++++++++++- internal/models/user_test.go | 92 ++++++++++++++++++++++++++++ 8 files changed, 340 insertions(+), 56 deletions(-) diff --git a/internal/api/admin.go b/internal/api/admin.go index 1cda8e264..ecd9c2053 100644 --- a/internal/api/admin.go +++ b/internal/api/admin.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" "net/http" - "strings" "time" "github.com/fatih/structs" @@ -18,11 +17,13 @@ import ( ) type AdminUserParams struct { + Id string `json:"id"` Aud string `json:"aud"` Role string `json:"role"` Email string `json:"email"` Phone string `json:"phone"` Password *string `json:"password"` + PasswordHash string `json:"password_hash"` EmailConfirm bool `json:"email_confirm"` PhoneConfirm bool `json:"phone_confirm"` UserMetaData map[string]interface{} `json:"user_metadata"` @@ -156,6 +157,7 @@ func (a *API) adminUserUpdate(w http.ResponseWriter, r *http.Request) error { } } + var banDuration *time.Duration if params.BanDuration != "" { duration := time.Duration(0) if params.BanDuration != "none" { @@ -164,9 +166,7 @@ func (a *API) adminUserUpdate(w http.ResponseWriter, r *http.Request) error { return badRequestError(ErrorCodeValidationFailed, "invalid format for ban duration: %v", err) } } - if terr := user.Ban(a.db, duration); terr != nil { - return terr - } + banDuration = &duration } if params.Password != nil { @@ -291,6 +291,12 @@ func (a *API) adminUserUpdate(w http.ResponseWriter, r *http.Request) error { } } + if banDuration != nil { + if terr := user.Ban(tx, *banDuration); terr != nil { + return terr + } + } + if terr := models.NewAuditLogEntry(r, tx, adminUser, models.UserModifiedAction, "", map[string]interface{}{ "user_id": user.ID, "user_email": user.Email, @@ -356,7 +362,11 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error { providers = append(providers, "phone") } - if params.Password == nil || *params.Password == "" { + if params.Password != nil && params.PasswordHash != "" { + return badRequestError(ErrorCodeValidationFailed, "Only a password or a password hash should be provided") + } + + if (params.Password == nil || *params.Password == "") && params.PasswordHash == "" { password, err := password.Generate(64, 10, 0, false, true) if err != nil { return internalServerError("Error generating password").WithInternalError(err) @@ -364,11 +374,28 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error { params.Password = &password } - user, err := models.NewUser(params.Phone, params.Email, *params.Password, aud, params.UserMetaData) + var user *models.User + if params.PasswordHash != "" { + user, err = models.NewUserWithPasswordHash(params.Phone, params.Email, params.PasswordHash, aud, params.UserMetaData) + } else { + user, err = models.NewUser(params.Phone, params.Email, *params.Password, aud, params.UserMetaData) + } + if err != nil { return internalServerError("Error creating user").WithInternalError(err) } + if params.Id != "" { + customId, err := uuid.FromString(params.Id) + if err != nil { + return badRequestError(ErrorCodeValidationFailed, "ID must conform to the uuid v4 format") + } + if customId == uuid.Nil { + return badRequestError(ErrorCodeValidationFailed, "ID cannot be a nil uuid") + } + user.ID = customId + } + user.AppMetaData = map[string]interface{}{ // TODO: Deprecate "provider" field // default to the first provider in the providers slice @@ -376,6 +403,18 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error { "providers": providers, } + var banDuration *time.Duration + if params.BanDuration != "" { + duration := time.Duration(0) + if params.BanDuration != "none" { + duration, err = time.ParseDuration(params.BanDuration) + if err != nil { + return badRequestError(ErrorCodeValidationFailed, "invalid format for ban duration: %v", err) + } + } + banDuration = &duration + } + err = db.Transaction(func(tx *storage.Connection) error { if terr := tx.Create(user); terr != nil { return terr @@ -442,15 +481,8 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error { } } - if params.BanDuration != "" { - duration := time.Duration(0) - if params.BanDuration != "none" { - duration, err = time.ParseDuration(params.BanDuration) - if err != nil { - return badRequestError(ErrorCodeValidationFailed, "invalid format for ban duration: %v", err) - } - } - if terr := user.Ban(a.db, duration); terr != nil { + if banDuration != nil { + if terr := user.Ban(tx, *banDuration); terr != nil { return terr } } @@ -459,9 +491,6 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error { }) if err != nil { - if strings.Contains("invalid format for ban duration", err.Error()) { - return err - } return internalServerError("Database error creating new user").WithInternalError(err) } diff --git a/internal/api/admin_test.go b/internal/api/admin_test.go index 135616c1f..e1b2c0328 100644 --- a/internal/api/admin_test.go +++ b/internal/api/admin_test.go @@ -10,6 +10,7 @@ import ( "testing" "time" + "github.com/gofrs/uuid" jwt "github.com/golang-jwt/jwt/v5" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -244,6 +245,7 @@ func (ts *AdminTestSuite) TestAdminUserCreate() { "isAuthenticated": true, "provider": "phone", "providers": []string{"phone"}, + "password": "test1", }, }, { @@ -259,6 +261,7 @@ func (ts *AdminTestSuite) TestAdminUserCreate() { "isAuthenticated": true, "provider": "email", "providers": []string{"email", "phone"}, + "password": "test1", }, }, { @@ -288,6 +291,7 @@ func (ts *AdminTestSuite) TestAdminUserCreate() { "isAuthenticated": false, "provider": "email", "providers": []string{"email"}, + "password": "", }, }, { @@ -304,6 +308,39 @@ func (ts *AdminTestSuite) TestAdminUserCreate() { "isAuthenticated": true, "provider": "email", "providers": []string{"email"}, + "password": "test1", + }, + }, + { + desc: "With password hash", + params: map[string]interface{}{ + "email": "test5@example.com", + "password_hash": "$2y$10$SXEz2HeT8PUIGQXo9yeUIem8KzNxgG0d7o/.eGj2rj8KbRgAuRVlq", + }, + expected: map[string]interface{}{ + "email": "test5@example.com", + "phone": "", + "isAuthenticated": true, + "provider": "email", + "providers": []string{"email"}, + "password": "test", + }, + }, + { + desc: "With custom id", + params: map[string]interface{}{ + "id": "fc56ab41-2010-4870-a9b9-767c1dc573fb", + "email": "test6@example.com", + "password": "test", + }, + expected: map[string]interface{}{ + "id": "fc56ab41-2010-4870-a9b9-767c1dc573fb", + "email": "test6@example.com", + "phone": "", + "isAuthenticated": true, + "provider": "email", + "providers": []string{"email"}, + "password": "test", }, }, } @@ -345,15 +382,18 @@ func (ts *AdminTestSuite) TestAdminUserCreate() { } } - var expectedPassword string - if _, ok := c.params["password"]; ok { - expectedPassword = fmt.Sprintf("%v", c.params["password"]) + if _, ok := c.expected["password"]; ok { + expectedPassword := fmt.Sprintf("%v", c.expected["password"]) + isAuthenticated, _, err := u.Authenticate(context.Background(), ts.API.db, expectedPassword, ts.API.config.Security.DBEncryption.DecryptionKeys, ts.API.config.Security.DBEncryption.Encrypt, ts.API.config.Security.DBEncryption.EncryptionKeyID) + require.NoError(ts.T(), err) + require.Equal(ts.T(), c.expected["isAuthenticated"], isAuthenticated) } - isAuthenticated, _, err := u.Authenticate(context.Background(), expectedPassword, ts.API.config.Security.DBEncryption.DecryptionKeys, ts.API.config.Security.DBEncryption.Encrypt, ts.API.config.Security.DBEncryption.EncryptionKeyID) - require.NoError(ts.T(), err) - - assert.Equal(ts.T(), c.expected["isAuthenticated"], isAuthenticated) + if id, ok := c.expected["id"]; ok { + uid, err := uuid.FromString(id.(string)) + require.NoError(ts.T(), err) + require.Equal(ts.T(), uid, data.ID) + } // remove created user after each case require.NoError(ts.T(), ts.API.db.Destroy(u)) @@ -820,5 +860,63 @@ func (ts *AdminTestSuite) TestAdminUserUpdateFactor() { require.Equal(ts.T(), c.ExpectedCode, w.Code) }) } +} + +func (ts *AdminTestSuite) TestAdminUserCreateValidationErrors() { + cases := []struct { + desc string + params map[string]interface{} + }{ + { + desc: "create user without email and phone", + params: map[string]interface{}{ + "password": "test_password", + }, + }, + { + desc: "create user with password and password hash", + params: map[string]interface{}{ + "email": "test@example.com", + "password": "test_password", + "password_hash": "$2y$10$Tk6yEdmTbb/eQ/haDMaCsuCsmtPVprjHMcij1RqiJdLGPDXnL3L1a", + }, + }, + { + desc: "invalid ban duration", + params: map[string]interface{}{ + "email": "test@example.com", + "ban_duration": "never", + }, + }, + { + desc: "custom id is nil", + params: map[string]interface{}{ + "id": "00000000-0000-0000-0000-000000000000", + "email": "test@example.com", + }, + }, + { + desc: "bad id format", + params: map[string]interface{}{ + "id": "bad_uuid_format", + "email": "test@example.com", + }, + }, + } + for _, c := range cases { + ts.Run(c.desc, func() { + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.params)) + req := httptest.NewRequest(http.MethodPost, "/admin/users", &buffer) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusBadRequest, w.Code, w) + data := map[string]interface{}{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + require.Equal(ts.T(), data["error_code"], ErrorCodeValidationFailed) + }) + + } } diff --git a/internal/api/token.go b/internal/api/token.go index 0d03d4fd4..22c42e9b1 100644 --- a/internal/api/token.go +++ b/internal/api/token.go @@ -145,7 +145,7 @@ func (a *API) ResourceOwnerPasswordGrant(ctx context.Context, w http.ResponseWri return oauthError("invalid_grant", InvalidLoginMessage) } - isValidPassword, shouldReEncrypt, err := user.Authenticate(ctx, params.Password, config.Security.DBEncryption.DecryptionKeys, config.Security.DBEncryption.Encrypt, config.Security.DBEncryption.EncryptionKeyID) + isValidPassword, shouldReEncrypt, err := user.Authenticate(ctx, db, params.Password, config.Security.DBEncryption.DecryptionKeys, config.Security.DBEncryption.Encrypt, config.Security.DBEncryption.EncryptionKeyID) if err != nil { return err } diff --git a/internal/api/user.go b/internal/api/user.go index 960322c94..4cb1fc1d9 100644 --- a/internal/api/user.go +++ b/internal/api/user.go @@ -157,7 +157,7 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { isSamePassword := false if user.HasPassword() { - auth, _, err := user.Authenticate(ctx, password, config.Security.DBEncryption.DecryptionKeys, false, "") + auth, _, err := user.Authenticate(ctx, db, password, config.Security.DBEncryption.DecryptionKeys, false, "") if err != nil { return err } diff --git a/internal/api/user_test.go b/internal/api/user_test.go index 8272bb87e..af9cfec37 100644 --- a/internal/api/user_test.go +++ b/internal/api/user_test.go @@ -310,7 +310,7 @@ func (ts *UserTestSuite) TestUserUpdatePassword() { u, err = models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) require.NoError(ts.T(), err) - isAuthenticated, _, err := u.Authenticate(context.Background(), c.newPassword, ts.API.config.Security.DBEncryption.DecryptionKeys, ts.API.config.Security.DBEncryption.Encrypt, ts.API.config.Security.DBEncryption.EncryptionKeyID) + isAuthenticated, _, err := u.Authenticate(context.Background(), ts.API.db, c.newPassword, ts.API.config.Security.DBEncryption.DecryptionKeys, ts.API.config.Security.DBEncryption.Encrypt, ts.API.config.Security.DBEncryption.EncryptionKeyID) require.NoError(ts.T(), err) require.Equal(ts.T(), c.expected.isAuthenticated, isAuthenticated) @@ -372,7 +372,7 @@ func (ts *UserTestSuite) TestUserUpdatePasswordNoReauthenticationRequired() { u, err = models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) require.NoError(ts.T(), err) - isAuthenticated, _, err := u.Authenticate(context.Background(), c.newPassword, ts.API.config.Security.DBEncryption.DecryptionKeys, ts.API.config.Security.DBEncryption.Encrypt, ts.API.config.Security.DBEncryption.EncryptionKeyID) + isAuthenticated, _, err := u.Authenticate(context.Background(), ts.API.db, c.newPassword, ts.API.config.Security.DBEncryption.DecryptionKeys, ts.API.config.Security.DBEncryption.Encrypt, ts.API.config.Security.DBEncryption.EncryptionKeyID) require.NoError(ts.T(), err) require.Equal(ts.T(), c.expected.isAuthenticated, isAuthenticated) @@ -430,7 +430,7 @@ func (ts *UserTestSuite) TestUserUpdatePasswordReauthentication() { u, err = models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) require.NoError(ts.T(), err) - isAuthenticated, _, err := u.Authenticate(context.Background(), "newpass", ts.Config.Security.DBEncryption.DecryptionKeys, ts.Config.Security.DBEncryption.Encrypt, ts.Config.Security.DBEncryption.EncryptionKeyID) + isAuthenticated, _, err := u.Authenticate(context.Background(), ts.API.db, "newpass", ts.Config.Security.DBEncryption.DecryptionKeys, ts.Config.Security.DBEncryption.Encrypt, ts.Config.Security.DBEncryption.EncryptionKeyID) require.NoError(ts.T(), err) require.True(ts.T(), isAuthenticated) diff --git a/internal/crypto/password.go b/internal/crypto/password.go index 554daccaa..dca101450 100644 --- a/internal/crypto/password.go +++ b/internal/crypto/password.go @@ -32,6 +32,8 @@ const ( // BCrypt hashed passwords have a 72 character limit MaxPasswordLength = 72 + + Argon2Prefix = "$argon2" ) // PasswordHashCost is the current pasword hashing cost @@ -54,11 +56,23 @@ var ErrArgon2MismatchedHashAndPassword = errors.New("crypto: argon2 hash and pas // argon2HashRegexp https://github.com/P-H-C/phc-string-format/blob/master/phc-sf-spec.md#argon2-encoding var argon2HashRegexp = regexp.MustCompile("^[$](?Pargon2(d|i|id))[$]v=(?P(16|19))[$]m=(?P[0-9]+),t=(?P[0-9]+),p=(?P

[0-9]+)(,keyid=(?P[^,]+))?(,data=(?P[^$]+))?[$](?P[^$]+)[$](?P.+)$") -func compareHashAndPasswordArgon2(ctx context.Context, hash, password string) error { +type Argon2HashInput struct { + alg string + v string + memory uint64 + time uint64 + threads uint64 + keyid string + data string + salt []byte + rawHash []byte +} + +func ParseArgon2Hash(hash string) (*Argon2HashInput, error) { submatch := argon2HashRegexp.FindStringSubmatchIndex(hash) if submatch == nil { - return errors.New("crypto: incorrect argon2 hash format") + return nil, errors.New("crypto: incorrect argon2 hash format") } alg := string(argon2HashRegexp.ExpandString(nil, "$alg", hash, submatch)) @@ -72,58 +86,68 @@ func compareHashAndPasswordArgon2(ctx context.Context, hash, password string) er hashB64 := string(argon2HashRegexp.ExpandString(nil, "$hash", hash, submatch)) if alg != "argon2i" && alg != "argon2id" { - return fmt.Errorf("crypto: argon2 hash uses unsupported algorithm %q only argon2i and argon2id supported", alg) + return nil, fmt.Errorf("crypto: argon2 hash uses unsupported algorithm %q only argon2i and argon2id supported", alg) } if v != "19" { - return fmt.Errorf("crypto: argon2 hash uses unsupported version %q only %d is supported", v, argon2.Version) + return nil, fmt.Errorf("crypto: argon2 hash uses unsupported version %q only %d is supported", v, argon2.Version) } if data != "" { - return fmt.Errorf("crypto: argon2 hashes with the data parameter not supported") + return nil, fmt.Errorf("crypto: argon2 hashes with the data parameter not supported") } if keyid != "" { - return fmt.Errorf("crypto: argon2 hashes with the keyid parameter not supported") + return nil, fmt.Errorf("crypto: argon2 hashes with the keyid parameter not supported") } memory, err := strconv.ParseUint(m, 10, 32) if err != nil { - return fmt.Errorf("crypto: argon2 hash has invalid m parameter %q %w", m, err) + return nil, fmt.Errorf("crypto: argon2 hash has invalid m parameter %q %w", m, err) } time, err := strconv.ParseUint(t, 10, 32) if err != nil { - return fmt.Errorf("crypto: argon2 hash has invalid t parameter %q %w", t, err) + return nil, fmt.Errorf("crypto: argon2 hash has invalid t parameter %q %w", t, err) } threads, err := strconv.ParseUint(p, 10, 8) if err != nil { - return fmt.Errorf("crypto: argon2 hash has invalid p parameter %q %w", p, err) + return nil, fmt.Errorf("crypto: argon2 hash has invalid p parameter %q %w", p, err) } rawHash, err := base64.RawStdEncoding.DecodeString(hashB64) if err != nil { - return fmt.Errorf("crypto: argon2 hash has invalid base64 in the hash section %w", err) + return nil, fmt.Errorf("crypto: argon2 hash has invalid base64 in the hash section %w", err) } salt, err := base64.RawStdEncoding.DecodeString(saltB64) if err != nil { - return fmt.Errorf("crypto: argon2 hash has invalid base64 in the salt section %w", err) + return nil, fmt.Errorf("crypto: argon2 hash has invalid base64 in the salt section %w", err) } - var match bool - var derivedKey []byte + input := Argon2HashInput{alg, v, memory, time, threads, keyid, data, salt, rawHash} + + return &input, nil +} + +func compareHashAndPasswordArgon2(ctx context.Context, hash, password string) error { + input, err := ParseArgon2Hash(hash) + if err != nil { + return err + } attributes := []attribute.KeyValue{ - attribute.String("alg", alg), - attribute.String("v", v), - attribute.Int64("m", int64(memory)), - attribute.Int64("t", int64(time)), - attribute.Int("p", int(threads)), - attribute.Int("len", len(rawHash)), + attribute.String("alg", input.alg), + attribute.String("v", input.v), + attribute.Int64("m", int64(input.memory)), + attribute.Int64("t", int64(input.time)), + attribute.Int("p", int(input.threads)), + attribute.Int("len", len(input.rawHash)), } + var match bool + var derivedKey []byte compareHashAndPasswordSubmittedCounter.Add(ctx, 1, metric.WithAttributes(attributes...)) defer func() { attributes = append(attributes, attribute.Bool( @@ -134,15 +158,15 @@ func compareHashAndPasswordArgon2(ctx context.Context, hash, password string) er compareHashAndPasswordCompletedCounter.Add(ctx, 1, metric.WithAttributes(attributes...)) }() - switch alg { + switch input.alg { case "argon2i": - derivedKey = argon2.Key([]byte(password), salt, uint32(time), uint32(memory)*1024, uint8(threads), uint32(len(rawHash))) + derivedKey = argon2.Key([]byte(password), input.salt, uint32(input.time), uint32(input.memory)*1024, uint8(input.threads), uint32(len(input.rawHash))) case "argon2id": - derivedKey = argon2.IDKey([]byte(password), salt, uint32(time), uint32(memory)*1024, uint8(threads), uint32(len(rawHash))) + derivedKey = argon2.IDKey([]byte(password), input.salt, uint32(input.time), uint32(input.memory)*1024, uint8(input.threads), uint32(len(input.rawHash))) } - match = subtle.ConstantTimeCompare(derivedKey, rawHash) == 0 + match = subtle.ConstantTimeCompare(derivedKey, input.rawHash) == 0 if !match { return ErrArgon2MismatchedHashAndPassword @@ -155,7 +179,7 @@ func compareHashAndPasswordArgon2(ctx context.Context, hash, password string) er // password, returns nil if equal otherwise an error. Context can be used to // cancel the hashing if the algorithm supports it. func CompareHashAndPassword(ctx context.Context, hash, password string) error { - if strings.HasPrefix(hash, "$argon2") { + if strings.HasPrefix(hash, Argon2Prefix) { return compareHashAndPasswordArgon2(ctx, hash, password) } @@ -181,7 +205,6 @@ func CompareHashAndPassword(ctx context.Context, hash, password string) error { }() err = bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)) - return err } diff --git a/internal/models/user.go b/internal/models/user.go index 0d074562a..3c871e543 100644 --- a/internal/models/user.go +++ b/internal/models/user.go @@ -14,6 +14,7 @@ import ( "github.com/pkg/errors" "github.com/supabase/auth/internal/crypto" "github.com/supabase/auth/internal/storage" + "golang.org/x/crypto/bcrypt" ) // User respresents a registered user with email/password authentication @@ -71,6 +72,31 @@ type User struct { DONTUSEINSTANCEID uuid.UUID `json:"-" db:"instance_id"` } +func NewUserWithPasswordHash(phone, email, passwordHash, aud string, userData map[string]interface{}) (*User, error) { + if strings.HasPrefix(passwordHash, crypto.Argon2Prefix) { + _, err := crypto.ParseArgon2Hash(passwordHash) + if err != nil { + return nil, err + } + } else { + // verify that the hash is a bcrypt hash + _, err := bcrypt.Cost([]byte(passwordHash)) + if err != nil { + return nil, err + } + } + id := uuid.Must(uuid.NewV4()) + user := &User{ + ID: id, + Aud: aud, + Email: storage.NullString(strings.ToLower(email)), + Phone: storage.NullString(phone), + UserMetaData: userData, + EncryptedPassword: &passwordHash, + } + return user, nil +} + // NewUser initializes a new user from an email, password and user data. func NewUser(phone, email, password, aud string, userData map[string]interface{}) (*User, error) { passwordHash := "" @@ -351,7 +377,7 @@ func (u *User) UpdatePassword(tx *storage.Connection, sessionID *uuid.UUID) erro } // Authenticate a user from a password -func (u *User) Authenticate(ctx context.Context, password string, decryptionKeys map[string]string, encrypt bool, encryptionKeyID string) (bool, bool, error) { +func (u *User) Authenticate(ctx context.Context, tx *storage.Connection, password string, decryptionKeys map[string]string, encrypt bool, encryptionKeyID string) (bool, bool, error) { if u.EncryptedPassword == nil { return false, false, nil } @@ -370,6 +396,22 @@ func (u *User) Authenticate(ctx context.Context, password string, decryptionKeys compareErr := crypto.CompareHashAndPassword(ctx, hash, password) + if !strings.HasPrefix(hash, crypto.Argon2Prefix) { + // check if cost exceeds default cost or is too low + cost, err := bcrypt.Cost([]byte(hash)) + if err != nil { + return compareErr == nil, false, err + } + + if cost > bcrypt.DefaultCost || cost == bcrypt.MinCost { + // don't bother with encrypting the password in Authenticate + // since it's handled separately + if err := u.SetPassword(ctx, password, false, "", ""); err != nil { + return compareErr == nil, false, err + } + } + } + return compareErr == nil, encrypt && (es == nil || es.ShouldReEncrypt(encryptionKeyID)), nil } diff --git a/internal/models/user_test.go b/internal/models/user_test.go index 011cf28f0..47d16178d 100644 --- a/internal/models/user_test.go +++ b/internal/models/user_test.go @@ -1,6 +1,7 @@ package models import ( + "context" "strings" "testing" @@ -11,6 +12,7 @@ import ( "github.com/supabase/auth/internal/crypto" "github.com/supabase/auth/internal/storage" "github.com/supabase/auth/internal/storage/test" + "golang.org/x/crypto/bcrypt" ) const modelsTestConfig = "../../hack/test.env" @@ -378,3 +380,93 @@ func (ts *UserTestSuite) TestSetPasswordTooLong() { err = user.SetPassword(ts.db.Context(), strings.Repeat("a", crypto.MaxPasswordLength), false, "", "") require.NoError(ts.T(), err) } + +func (ts *UserTestSuite) TestNewUserWithPasswordHashSuccess() { + cases := []struct { + desc string + hash string + }{ + { + desc: "Valid bcrypt hash", + hash: "$2y$10$SXEz2HeT8PUIGQXo9yeUIem8KzNxgG0d7o/.eGj2rj8KbRgAuRVlq", + }, + { + desc: "Valid argon2i hash", + hash: "$argon2i$v=19$m=16,t=2,p=1$bGJRWThNOHJJTVBSdHl2dQ$NfEnUOuUpb7F2fQkgFUG4g", + }, + { + desc: "Valid argon2id hash", + hash: "$argon2id$v=19$m=32,t=3,p=2$SFVpOWJ0eXhjRzVkdGN1RQ$RXnb8rh7LaDcn07xsssqqulZYXOM/EUCEFMVcAcyYVk", + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + u, err := NewUserWithPasswordHash("", "", c.hash, "", nil) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), u) + }) + } +} + +func (ts *UserTestSuite) TestNewUserWithPasswordHashFailure() { + cases := []struct { + desc string + hash string + }{ + { + desc: "Invalid argon2i hash", + hash: "$argon2id$test", + }, + { + desc: "Invalid bcrypt hash", + hash: "plaintest_password", + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + u, err := NewUserWithPasswordHash("", "", c.hash, "", nil) + require.Error(ts.T(), err) + require.Nil(ts.T(), u) + }) + } +} + +func (ts *UserTestSuite) TestAuthenticate() { + // every case uses "test" as the password + cases := []struct { + desc string + hash string + expectedHashCost int + }{ + { + desc: "Invalid bcrypt hash cost of 11", + hash: "$2y$11$4lH57PU7bGATpRcx93vIoObH3qDmft/pytbOzDG9/1WsyNmN5u4di", + expectedHashCost: bcrypt.MinCost, + }, + { + desc: "Valid bcrypt hash cost of 10", + hash: "$2y$10$va66S4MxFrH6G6L7BzYl0.QgcYgvSr/F92gc.3botlz7bG4p/g/1i", + expectedHashCost: bcrypt.DefaultCost, + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + u, err := NewUserWithPasswordHash("", "", c.hash, "", nil) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(u)) + require.NotNil(ts.T(), u) + + isAuthenticated, _, err := u.Authenticate(context.Background(), ts.db, "test", nil, false, "") + require.NoError(ts.T(), err) + require.True(ts.T(), isAuthenticated) + + // check hash cost + hashCost, err := bcrypt.Cost([]byte(*u.EncryptedPassword)) + require.NoError(ts.T(), err) + require.Equal(ts.T(), c.expectedHashCost, hashCost) + }) + } +}