-
Notifications
You must be signed in to change notification settings - Fork 285
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(auth): add login/ logout routes and createUser mutation (#4293)
- Loading branch information
1 parent
43b8a65
commit a3ff0f6
Showing
10 changed files
with
290 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
import re | ||
from hashlib import pbkdf2_hmac | ||
|
||
|
||
def compute_password_hash(*, password: str, salt: str) -> str: | ||
""" | ||
Salts and hashes a password using PBKDF2, HMAC, and SHA256. | ||
""" | ||
password_bytes = password.encode("utf-8") | ||
salt_bytes = salt.encode("utf-8") | ||
password_hash_bytes = pbkdf2_hmac("sha256", password_bytes, salt_bytes, NUM_ITERATIONS) | ||
password_hash = password_hash_bytes.hex() | ||
return password_hash | ||
|
||
|
||
def is_valid_password(*, password: str, salt: str, password_hash: str) -> bool: | ||
""" | ||
Determines whether the password is valid by salting and hashing the password | ||
and comparing against the existing hash value. | ||
""" | ||
return password_hash == compute_password_hash(password=password, salt=salt) | ||
|
||
|
||
def validate_email_format(email: str) -> None: | ||
""" | ||
Checks that the email has a valid format. | ||
""" | ||
if EMAIL_PATTERN.match(email) is None: | ||
raise ValueError("Invalid email address") | ||
|
||
|
||
def validate_password_format(password: str) -> None: | ||
""" | ||
Checks that the password has a valid format. | ||
""" | ||
if not password: | ||
raise ValueError("Password must be non-empty") | ||
if any(char.isspace() for char in password): | ||
raise ValueError("Password cannot contain whitespace characters") | ||
if not password.isascii(): | ||
raise ValueError("Password can contain only ASCII characters") | ||
|
||
|
||
EMAIL_PATTERN = re.compile(r"^[^@\s]+@[^@\s]+[.][^@\s]+\Z") | ||
NUM_ITERATIONS = 10_000 |
Empty file.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from enum import Enum | ||
|
||
import strawberry | ||
|
||
|
||
@strawberry.enum | ||
class UserRoleInput(Enum): | ||
ADMIN = "ADMIN" | ||
MEMBER = "MEMBER" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
import asyncio | ||
from typing import Optional | ||
|
||
import strawberry | ||
from sqlalchemy import insert, select | ||
from sqlean.dbapi2 import IntegrityError # type: ignore[import-untyped] | ||
from strawberry.types import Info | ||
|
||
from phoenix.auth import compute_password_hash, validate_email_format, validate_password_format | ||
from phoenix.db import models | ||
from phoenix.server.api.context import Context | ||
from phoenix.server.api.input_types.UserRoleInput import UserRoleInput | ||
from phoenix.server.api.types.User import User | ||
from phoenix.server.api.types.UserRole import UserRole | ||
|
||
|
||
@strawberry.input | ||
class CreateUserInput: | ||
email: str | ||
username: Optional[str] | ||
password: str | ||
role: UserRoleInput | ||
|
||
|
||
@strawberry.type | ||
class UserMutationPayload: | ||
user: User | ||
|
||
|
||
@strawberry.type | ||
class UserMutationMixin: | ||
@strawberry.mutation | ||
async def create_user( | ||
self, | ||
info: Info[Context, None], | ||
input: CreateUserInput, | ||
) -> UserMutationPayload: | ||
validate_email_format(email := input.email) | ||
validate_password_format(password := input.password) | ||
role_name = input.role.value | ||
user_role_id = ( | ||
select(models.UserRole.id).where(models.UserRole.name == role_name).scalar_subquery() | ||
) | ||
secret = info.context.get_secret() | ||
loop = asyncio.get_running_loop() | ||
password_hash = await loop.run_in_executor( | ||
executor=None, | ||
func=lambda: compute_password_hash(password=password, salt=secret), | ||
) | ||
try: | ||
async with info.context.db() as session: | ||
user = await session.scalar( | ||
insert(models.User) | ||
.values( | ||
user_role_id=user_role_id, | ||
username=input.username, | ||
email=email, | ||
auth_method="LOCAL", | ||
password_hash=password_hash, | ||
reset_password=True, | ||
) | ||
.returning(models.User) | ||
) | ||
assert user is not None | ||
except IntegrityError as error: | ||
raise ValueError(_get_user_create_error_message(error)) | ||
return UserMutationPayload( | ||
user=User( | ||
id_attr=user.id, | ||
email=user.email, | ||
username=user.username, | ||
created_at=user.created_at, | ||
role=UserRole(id_attr=user.user_role_id, name=role_name), | ||
) | ||
) | ||
|
||
|
||
def _get_user_create_error_message(error: IntegrityError) -> str: | ||
""" | ||
Gets a user-facing error message to explain why user creation failed. | ||
""" | ||
original_error_message = str(error) | ||
username_already_exists = "users.username" in original_error_message | ||
email_already_exists = "users.email" in original_error_message | ||
if username_already_exists: | ||
return "Username already exists" | ||
elif email_already_exists: | ||
return "Email already exists" | ||
return "Failed to create user" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import asyncio | ||
from datetime import timedelta | ||
|
||
from fastapi import APIRouter, Form, Request, Response | ||
from sqlalchemy import select | ||
from starlette.status import HTTP_204_NO_CONTENT, HTTP_401_UNAUTHORIZED | ||
from typing_extensions import Annotated | ||
|
||
from phoenix.auth import is_valid_password | ||
from phoenix.db import models | ||
|
||
router = APIRouter(include_in_schema=False) | ||
|
||
PHOENIX_ACCESS_TOKEN_COOKIE_NAME = "phoenix-access-token" | ||
PHOENIX_ACCESS_TOKEN_COOKIE_MAX_AGE_IN_SECONDS = int(timedelta(days=31).total_seconds()) | ||
|
||
|
||
@router.post("/login") | ||
async def login( | ||
request: Request, | ||
email: Annotated[str, Form()], | ||
password: Annotated[str, Form()], | ||
) -> Response: | ||
async with request.app.state.db() as session: | ||
if ( | ||
user := await session.scalar(select(models.User).where(models.User.email == email)) | ||
) is None or (password_hash := user.password_hash) is None: | ||
return Response(status_code=HTTP_401_UNAUTHORIZED) | ||
secret = request.app.state.get_secret() | ||
loop = asyncio.get_running_loop() | ||
if not await loop.run_in_executor( | ||
executor=None, | ||
func=lambda: is_valid_password(password=password, salt=secret, password_hash=password_hash), | ||
): | ||
return Response(status_code=HTTP_401_UNAUTHORIZED) | ||
response = Response(status_code=HTTP_204_NO_CONTENT) | ||
response.set_cookie( | ||
key=PHOENIX_ACCESS_TOKEN_COOKIE_NAME, | ||
value="token", # todo: compute access token | ||
secure=True, | ||
httponly=True, | ||
samesite="strict", | ||
max_age=PHOENIX_ACCESS_TOKEN_COOKIE_MAX_AGE_IN_SECONDS, | ||
) | ||
return response | ||
|
||
|
||
@router.post("/logout") | ||
async def logout() -> Response: | ||
response = Response(status_code=HTTP_204_NO_CONTENT) | ||
response.delete_cookie(key=PHOENIX_ACCESS_TOKEN_COOKIE_NAME) | ||
return response |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
import pytest | ||
|
||
from phoenix.auth import validate_email_format, validate_password_format | ||
|
||
|
||
def test_validate_email_format_does_not_raise_on_valid_format() -> None: | ||
validate_email_format("user@domain.com") | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"email", | ||
( | ||
pytest.param("userdomain.com", id="missing-@"), | ||
pytest.param("user@domain", id="missing-top-level-domain-name"), | ||
pytest.param("user@domain.", id="empty-top-level-domain-name"), | ||
pytest.param("user@.com", id="missing-domain-name"), | ||
pytest.param("@domain.com", id="missing-username"), | ||
pytest.param("user @domain.com", id="username-contains-space"), | ||
pytest.param("user@do main.com", id="domain-name-contains-space"), | ||
pytest.param("user@domain.c om", id="top-level-domain-name-contains-space"), | ||
pytest.param(" user@domain.com", id="leading-space"), | ||
pytest.param("user@domain.com ", id="trailing-space"), | ||
pytest.param(" user@domain.com", id="leading-space"), | ||
pytest.param("\nuser@domain.com ", id="leading-newline"), | ||
pytest.param("user@domain.com\n", id="trailing-newline"), | ||
), | ||
) | ||
def test_validate_email_format_raises_on_invalid_format(email: str) -> None: | ||
with pytest.raises(ValueError): | ||
validate_email_format(email) | ||
|
||
|
||
def test_validate_password_format_does_not_raise_on_valid_format() -> None: | ||
validate_password_format("Password1234!") | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"password", | ||
( | ||
pytest.param("", id="empty"), | ||
pytest.param("pass word", id="contains-space"), | ||
pytest.param("pass\nword", id="contains-newline"), | ||
pytest.param("password\n", id="trailing-newline"), | ||
pytest.param("P@ßwø®∂!ñ", id="contains-non-ascii-chars"), | ||
pytest.param("안녕하세요", id="korean"), | ||
pytest.param("🚀", id="emoji"), | ||
), | ||
) | ||
def test_validate_password_format_raises_on_invalid_format(password: str) -> None: | ||
with pytest.raises(ValueError): | ||
validate_password_format(password) |