Skip to content

Commit

Permalink
feat: add password_hash and id fields to admin create user (supab…
Browse files Browse the repository at this point in the history
…ase#1641)

## What kind of change does this PR introduce?
* Add a `password_hash` field to admin create user, which allows an
admin to create a user with a given password hash (argon2 or bcrypt)
* Add an `id` field to admin create user, which allows an admin to
create a user with a custom id
* To prevent someone from creating a bunch of users with a high bcrypt
hashing cost, we opt to rehash the password with the default cost (10)
on subsequent sign-in.

## What is the current behavior?
* Only plaintext passwords are allowed, which will subsequently be
hashed internally

## What is the new behavior?

Example request using the bcrypt hash of "test":
```bash
$ curl -X POST 'http://localhost:9999/admin/users' \
-H 'Authorization: Bearer <admin_jwt>' \
-H 'Content-Type: application/json' \
-d '{"email": "foo@example.com", "password_hash": "$2y$10$SXEz2HeT8PUIGQXo9yeUIem8KzNxgG0d7o/.eGj2rj8KbRgAuRVlq"}'
```

Example request using a custom id:
```bash
$ curl -X POST 'http://localhost:9999/admin/users' \
-H 'Authorization: Bearer <admin_jwt>' \
-H 'Content-Type: application/json' \
-d '{"id": "2a8813c2-bda7-47f0-94a6-49fcfdf61a70", "email": "foo@example.com"}'
```

Feel free to include screenshots if it includes visual changes.

## Additional context

Add any other context or screenshots.
  • Loading branch information
kangmingtay authored Jul 3, 2024
1 parent 3f70d9d commit 20d59f1
Show file tree
Hide file tree
Showing 8 changed files with 340 additions and 56 deletions.
65 changes: 47 additions & 18 deletions internal/api/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"encoding/json"
"net/http"
"strings"
"time"

"github.com/fatih/structs"
Expand All @@ -18,11 +17,13 @@ import (
)

type AdminUserParams struct {
Id string `json:"id"`
Aud string `json:"aud"`
Role string `json:"role"`
Email string `json:"email"`
Phone string `json:"phone"`
Password *string `json:"password"`
PasswordHash string `json:"password_hash"`
EmailConfirm bool `json:"email_confirm"`
PhoneConfirm bool `json:"phone_confirm"`
UserMetaData map[string]interface{} `json:"user_metadata"`
Expand Down Expand Up @@ -156,6 +157,7 @@ func (a *API) adminUserUpdate(w http.ResponseWriter, r *http.Request) error {
}
}

var banDuration *time.Duration
if params.BanDuration != "" {
duration := time.Duration(0)
if params.BanDuration != "none" {
Expand All @@ -164,9 +166,7 @@ func (a *API) adminUserUpdate(w http.ResponseWriter, r *http.Request) error {
return badRequestError(ErrorCodeValidationFailed, "invalid format for ban duration: %v", err)
}
}
if terr := user.Ban(a.db, duration); terr != nil {
return terr
}
banDuration = &duration
}

if params.Password != nil {
Expand Down Expand Up @@ -291,6 +291,12 @@ func (a *API) adminUserUpdate(w http.ResponseWriter, r *http.Request) error {
}
}

if banDuration != nil {
if terr := user.Ban(tx, *banDuration); terr != nil {
return terr
}
}

if terr := models.NewAuditLogEntry(r, tx, adminUser, models.UserModifiedAction, "", map[string]interface{}{
"user_id": user.ID,
"user_email": user.Email,
Expand Down Expand Up @@ -356,26 +362,59 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error {
providers = append(providers, "phone")
}

if params.Password == nil || *params.Password == "" {
if params.Password != nil && params.PasswordHash != "" {
return badRequestError(ErrorCodeValidationFailed, "Only a password or a password hash should be provided")
}

if (params.Password == nil || *params.Password == "") && params.PasswordHash == "" {
password, err := password.Generate(64, 10, 0, false, true)
if err != nil {
return internalServerError("Error generating password").WithInternalError(err)
}
params.Password = &password
}

user, err := models.NewUser(params.Phone, params.Email, *params.Password, aud, params.UserMetaData)
var user *models.User
if params.PasswordHash != "" {
user, err = models.NewUserWithPasswordHash(params.Phone, params.Email, params.PasswordHash, aud, params.UserMetaData)
} else {
user, err = models.NewUser(params.Phone, params.Email, *params.Password, aud, params.UserMetaData)
}

if err != nil {
return internalServerError("Error creating user").WithInternalError(err)
}

if params.Id != "" {
customId, err := uuid.FromString(params.Id)
if err != nil {
return badRequestError(ErrorCodeValidationFailed, "ID must conform to the uuid v4 format")
}
if customId == uuid.Nil {
return badRequestError(ErrorCodeValidationFailed, "ID cannot be a nil uuid")
}
user.ID = customId
}

user.AppMetaData = map[string]interface{}{
// TODO: Deprecate "provider" field
// default to the first provider in the providers slice
"provider": providers[0],
"providers": providers,
}

var banDuration *time.Duration
if params.BanDuration != "" {
duration := time.Duration(0)
if params.BanDuration != "none" {
duration, err = time.ParseDuration(params.BanDuration)
if err != nil {
return badRequestError(ErrorCodeValidationFailed, "invalid format for ban duration: %v", err)
}
}
banDuration = &duration
}

err = db.Transaction(func(tx *storage.Connection) error {
if terr := tx.Create(user); terr != nil {
return terr
Expand Down Expand Up @@ -442,15 +481,8 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error {
}
}

if params.BanDuration != "" {
duration := time.Duration(0)
if params.BanDuration != "none" {
duration, err = time.ParseDuration(params.BanDuration)
if err != nil {
return badRequestError(ErrorCodeValidationFailed, "invalid format for ban duration: %v", err)
}
}
if terr := user.Ban(a.db, duration); terr != nil {
if banDuration != nil {
if terr := user.Ban(tx, *banDuration); terr != nil {
return terr
}
}
Expand All @@ -459,9 +491,6 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error {
})

if err != nil {
if strings.Contains("invalid format for ban duration", err.Error()) {
return err
}
return internalServerError("Database error creating new user").WithInternalError(err)
}

Expand Down
112 changes: 105 additions & 7 deletions internal/api/admin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"testing"
"time"

"github.com/gofrs/uuid"
jwt "github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -244,6 +245,7 @@ func (ts *AdminTestSuite) TestAdminUserCreate() {
"isAuthenticated": true,
"provider": "phone",
"providers": []string{"phone"},
"password": "test1",
},
},
{
Expand All @@ -259,6 +261,7 @@ func (ts *AdminTestSuite) TestAdminUserCreate() {
"isAuthenticated": true,
"provider": "email",
"providers": []string{"email", "phone"},
"password": "test1",
},
},
{
Expand Down Expand Up @@ -288,6 +291,7 @@ func (ts *AdminTestSuite) TestAdminUserCreate() {
"isAuthenticated": false,
"provider": "email",
"providers": []string{"email"},
"password": "",
},
},
{
Expand All @@ -304,6 +308,39 @@ func (ts *AdminTestSuite) TestAdminUserCreate() {
"isAuthenticated": true,
"provider": "email",
"providers": []string{"email"},
"password": "test1",
},
},
{
desc: "With password hash",
params: map[string]interface{}{
"email": "test5@example.com",
"password_hash": "$2y$10$SXEz2HeT8PUIGQXo9yeUIem8KzNxgG0d7o/.eGj2rj8KbRgAuRVlq",
},
expected: map[string]interface{}{
"email": "test5@example.com",
"phone": "",
"isAuthenticated": true,
"provider": "email",
"providers": []string{"email"},
"password": "test",
},
},
{
desc: "With custom id",
params: map[string]interface{}{
"id": "fc56ab41-2010-4870-a9b9-767c1dc573fb",
"email": "test6@example.com",
"password": "test",
},
expected: map[string]interface{}{
"id": "fc56ab41-2010-4870-a9b9-767c1dc573fb",
"email": "test6@example.com",
"phone": "",
"isAuthenticated": true,
"provider": "email",
"providers": []string{"email"},
"password": "test",
},
},
}
Expand Down Expand Up @@ -345,15 +382,18 @@ func (ts *AdminTestSuite) TestAdminUserCreate() {
}
}

var expectedPassword string
if _, ok := c.params["password"]; ok {
expectedPassword = fmt.Sprintf("%v", c.params["password"])
if _, ok := c.expected["password"]; ok {
expectedPassword := fmt.Sprintf("%v", c.expected["password"])
isAuthenticated, _, err := u.Authenticate(context.Background(), ts.API.db, expectedPassword, ts.API.config.Security.DBEncryption.DecryptionKeys, ts.API.config.Security.DBEncryption.Encrypt, ts.API.config.Security.DBEncryption.EncryptionKeyID)
require.NoError(ts.T(), err)
require.Equal(ts.T(), c.expected["isAuthenticated"], isAuthenticated)
}

isAuthenticated, _, err := u.Authenticate(context.Background(), expectedPassword, ts.API.config.Security.DBEncryption.DecryptionKeys, ts.API.config.Security.DBEncryption.Encrypt, ts.API.config.Security.DBEncryption.EncryptionKeyID)
require.NoError(ts.T(), err)

assert.Equal(ts.T(), c.expected["isAuthenticated"], isAuthenticated)
if id, ok := c.expected["id"]; ok {
uid, err := uuid.FromString(id.(string))
require.NoError(ts.T(), err)
require.Equal(ts.T(), uid, data.ID)
}

// remove created user after each case
require.NoError(ts.T(), ts.API.db.Destroy(u))
Expand Down Expand Up @@ -820,5 +860,63 @@ func (ts *AdminTestSuite) TestAdminUserUpdateFactor() {
require.Equal(ts.T(), c.ExpectedCode, w.Code)
})
}
}

func (ts *AdminTestSuite) TestAdminUserCreateValidationErrors() {
cases := []struct {
desc string
params map[string]interface{}
}{
{
desc: "create user without email and phone",
params: map[string]interface{}{
"password": "test_password",
},
},
{
desc: "create user with password and password hash",
params: map[string]interface{}{
"email": "test@example.com",
"password": "test_password",
"password_hash": "$2y$10$Tk6yEdmTbb/eQ/haDMaCsuCsmtPVprjHMcij1RqiJdLGPDXnL3L1a",
},
},
{
desc: "invalid ban duration",
params: map[string]interface{}{
"email": "test@example.com",
"ban_duration": "never",
},
},
{
desc: "custom id is nil",
params: map[string]interface{}{
"id": "00000000-0000-0000-0000-000000000000",
"email": "test@example.com",
},
},
{
desc: "bad id format",
params: map[string]interface{}{
"id": "bad_uuid_format",
"email": "test@example.com",
},
},
}
for _, c := range cases {
ts.Run(c.desc, func() {
var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.params))
req := httptest.NewRequest(http.MethodPost, "/admin/users", &buffer)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token))
w := httptest.NewRecorder()
ts.API.handler.ServeHTTP(w, req)
require.Equal(ts.T(), http.StatusBadRequest, w.Code, w)

data := map[string]interface{}{}
require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data))
require.Equal(ts.T(), data["error_code"], ErrorCodeValidationFailed)
})

}
}
2 changes: 1 addition & 1 deletion internal/api/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ func (a *API) ResourceOwnerPasswordGrant(ctx context.Context, w http.ResponseWri
return oauthError("invalid_grant", InvalidLoginMessage)
}

isValidPassword, shouldReEncrypt, err := user.Authenticate(ctx, params.Password, config.Security.DBEncryption.DecryptionKeys, config.Security.DBEncryption.Encrypt, config.Security.DBEncryption.EncryptionKeyID)
isValidPassword, shouldReEncrypt, err := user.Authenticate(ctx, db, params.Password, config.Security.DBEncryption.DecryptionKeys, config.Security.DBEncryption.Encrypt, config.Security.DBEncryption.EncryptionKeyID)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion internal/api/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error {
isSamePassword := false

if user.HasPassword() {
auth, _, err := user.Authenticate(ctx, password, config.Security.DBEncryption.DecryptionKeys, false, "")
auth, _, err := user.Authenticate(ctx, db, password, config.Security.DBEncryption.DecryptionKeys, false, "")
if err != nil {
return err
}
Expand Down
6 changes: 3 additions & 3 deletions internal/api/user_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ func (ts *UserTestSuite) TestUserUpdatePassword() {
u, err = models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud)
require.NoError(ts.T(), err)

isAuthenticated, _, err := u.Authenticate(context.Background(), c.newPassword, ts.API.config.Security.DBEncryption.DecryptionKeys, ts.API.config.Security.DBEncryption.Encrypt, ts.API.config.Security.DBEncryption.EncryptionKeyID)
isAuthenticated, _, err := u.Authenticate(context.Background(), ts.API.db, c.newPassword, ts.API.config.Security.DBEncryption.DecryptionKeys, ts.API.config.Security.DBEncryption.Encrypt, ts.API.config.Security.DBEncryption.EncryptionKeyID)
require.NoError(ts.T(), err)

require.Equal(ts.T(), c.expected.isAuthenticated, isAuthenticated)
Expand Down Expand Up @@ -372,7 +372,7 @@ func (ts *UserTestSuite) TestUserUpdatePasswordNoReauthenticationRequired() {
u, err = models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud)
require.NoError(ts.T(), err)

isAuthenticated, _, err := u.Authenticate(context.Background(), c.newPassword, ts.API.config.Security.DBEncryption.DecryptionKeys, ts.API.config.Security.DBEncryption.Encrypt, ts.API.config.Security.DBEncryption.EncryptionKeyID)
isAuthenticated, _, err := u.Authenticate(context.Background(), ts.API.db, c.newPassword, ts.API.config.Security.DBEncryption.DecryptionKeys, ts.API.config.Security.DBEncryption.Encrypt, ts.API.config.Security.DBEncryption.EncryptionKeyID)
require.NoError(ts.T(), err)

require.Equal(ts.T(), c.expected.isAuthenticated, isAuthenticated)
Expand Down Expand Up @@ -430,7 +430,7 @@ func (ts *UserTestSuite) TestUserUpdatePasswordReauthentication() {
u, err = models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud)
require.NoError(ts.T(), err)

isAuthenticated, _, err := u.Authenticate(context.Background(), "newpass", ts.Config.Security.DBEncryption.DecryptionKeys, ts.Config.Security.DBEncryption.Encrypt, ts.Config.Security.DBEncryption.EncryptionKeyID)
isAuthenticated, _, err := u.Authenticate(context.Background(), ts.API.db, "newpass", ts.Config.Security.DBEncryption.DecryptionKeys, ts.Config.Security.DBEncryption.Encrypt, ts.Config.Security.DBEncryption.EncryptionKeyID)
require.NoError(ts.T(), err)

require.True(ts.T(), isAuthenticated)
Expand Down
Loading

0 comments on commit 20d59f1

Please sign in to comment.