Skip to content

Commit

Permalink
feat(auth): add login/ logout routes and createUser mutation (#4293)
Browse files Browse the repository at this point in the history
  • Loading branch information
axiomofjoy authored Aug 21, 2024
1 parent 43b8a65 commit a3ff0f6
Show file tree
Hide file tree
Showing 10 changed files with 290 additions and 17 deletions.
17 changes: 17 additions & 0 deletions app/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,13 @@ input CreateTraceAnnotationInput {
metadata: JSON! = {}
}

input CreateUserInput {
email: String!
username: String
password: String!
role: UserRoleInput!
}

enum DataQualityMetric {
cardinality
percentEmpty
Expand Down Expand Up @@ -899,6 +906,7 @@ type Mutation {
patchTraceAnnotations(input: [PatchAnnotationInput!]!): TraceAnnotationMutationPayload!
deleteTraceAnnotations(input: DeleteAnnotationsInput!): TraceAnnotationMutationPayload!
createSystemApiKey(input: CreateApiKeyInput!): CreateSystemApiKeyMutationPayload!
createUser(input: CreateUserInput!): UserMutationPayload!
}

"""An object with a Globally Unique ID"""
Expand Down Expand Up @@ -1449,12 +1457,21 @@ type UserEdge {
node: User!
}

type UserMutationPayload {
user: User!
}

type UserRole implements Node {
"""The Globally Unique ID of this object"""
id: GlobalID!
name: String!
}

enum UserRoleInput {
ADMIN
MEMBER
}

type ValidationResult {
isValid: Boolean!
errorMessage: String
Expand Down
45 changes: 45 additions & 0 deletions src/phoenix/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import re
from hashlib import pbkdf2_hmac


def compute_password_hash(*, password: str, salt: str) -> str:
"""
Salts and hashes a password using PBKDF2, HMAC, and SHA256.
"""
password_bytes = password.encode("utf-8")
salt_bytes = salt.encode("utf-8")
password_hash_bytes = pbkdf2_hmac("sha256", password_bytes, salt_bytes, NUM_ITERATIONS)
password_hash = password_hash_bytes.hex()
return password_hash


def is_valid_password(*, password: str, salt: str, password_hash: str) -> bool:
"""
Determines whether the password is valid by salting and hashing the password
and comparing against the existing hash value.
"""
return password_hash == compute_password_hash(password=password, salt=salt)


def validate_email_format(email: str) -> None:
"""
Checks that the email has a valid format.
"""
if EMAIL_PATTERN.match(email) is None:
raise ValueError("Invalid email address")


def validate_password_format(password: str) -> None:
"""
Checks that the password has a valid format.
"""
if not password:
raise ValueError("Password must be non-empty")
if any(char.isspace() for char in password):
raise ValueError("Password cannot contain whitespace characters")
if not password.isascii():
raise ValueError("Password can contain only ASCII characters")


EMAIL_PATTERN = re.compile(r"^[^@\s]+@[^@\s]+[.][^@\s]+\Z")
NUM_ITERATIONS = 10_000
Empty file removed src/phoenix/auth/__init__.py
Empty file.
15 changes: 0 additions & 15 deletions src/phoenix/auth/utils.py

This file was deleted.

9 changes: 9 additions & 0 deletions src/phoenix/server/api/input_types/UserRoleInput.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from enum import Enum

import strawberry


@strawberry.enum
class UserRoleInput(Enum):
ADMIN = "ADMIN"
MEMBER = "MEMBER"
2 changes: 2 additions & 0 deletions src/phoenix/server/api/mutations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from phoenix.server.api.mutations.project_mutations import ProjectMutationMixin
from phoenix.server.api.mutations.span_annotations_mutations import SpanAnnotationMutationMixin
from phoenix.server.api.mutations.trace_annotations_mutations import TraceAnnotationMutationMixin
from phoenix.server.api.mutations.user_mutations import UserMutationMixin


@strawberry.type
Expand All @@ -18,5 +19,6 @@ class Mutation(
SpanAnnotationMutationMixin,
TraceAnnotationMutationMixin,
ApiKeyMutationMixin,
UserMutationMixin,
):
pass
89 changes: 89 additions & 0 deletions src/phoenix/server/api/mutations/user_mutations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import asyncio
from typing import Optional

import strawberry
from sqlalchemy import insert, select
from sqlean.dbapi2 import IntegrityError # type: ignore[import-untyped]
from strawberry.types import Info

from phoenix.auth import compute_password_hash, validate_email_format, validate_password_format
from phoenix.db import models
from phoenix.server.api.context import Context
from phoenix.server.api.input_types.UserRoleInput import UserRoleInput
from phoenix.server.api.types.User import User
from phoenix.server.api.types.UserRole import UserRole


@strawberry.input
class CreateUserInput:
email: str
username: Optional[str]
password: str
role: UserRoleInput


@strawberry.type
class UserMutationPayload:
user: User


@strawberry.type
class UserMutationMixin:
@strawberry.mutation
async def create_user(
self,
info: Info[Context, None],
input: CreateUserInput,
) -> UserMutationPayload:
validate_email_format(email := input.email)
validate_password_format(password := input.password)
role_name = input.role.value
user_role_id = (
select(models.UserRole.id).where(models.UserRole.name == role_name).scalar_subquery()
)
secret = info.context.get_secret()
loop = asyncio.get_running_loop()
password_hash = await loop.run_in_executor(
executor=None,
func=lambda: compute_password_hash(password=password, salt=secret),
)
try:
async with info.context.db() as session:
user = await session.scalar(
insert(models.User)
.values(
user_role_id=user_role_id,
username=input.username,
email=email,
auth_method="LOCAL",
password_hash=password_hash,
reset_password=True,
)
.returning(models.User)
)
assert user is not None
except IntegrityError as error:
raise ValueError(_get_user_create_error_message(error))
return UserMutationPayload(
user=User(
id_attr=user.id,
email=user.email,
username=user.username,
created_at=user.created_at,
role=UserRole(id_attr=user.user_role_id, name=role_name),
)
)


def _get_user_create_error_message(error: IntegrityError) -> str:
"""
Gets a user-facing error message to explain why user creation failed.
"""
original_error_message = str(error)
username_already_exists = "users.username" in original_error_message
email_already_exists = "users.email" in original_error_message
if username_already_exists:
return "Username already exists"
elif email_already_exists:
return "Email already exists"
return "Failed to create user"
52 changes: 52 additions & 0 deletions src/phoenix/server/api/routers/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import asyncio
from datetime import timedelta

from fastapi import APIRouter, Form, Request, Response
from sqlalchemy import select
from starlette.status import HTTP_204_NO_CONTENT, HTTP_401_UNAUTHORIZED
from typing_extensions import Annotated

from phoenix.auth import is_valid_password
from phoenix.db import models

router = APIRouter(include_in_schema=False)

PHOENIX_ACCESS_TOKEN_COOKIE_NAME = "phoenix-access-token"
PHOENIX_ACCESS_TOKEN_COOKIE_MAX_AGE_IN_SECONDS = int(timedelta(days=31).total_seconds())


@router.post("/login")
async def login(
request: Request,
email: Annotated[str, Form()],
password: Annotated[str, Form()],
) -> Response:
async with request.app.state.db() as session:
if (
user := await session.scalar(select(models.User).where(models.User.email == email))
) is None or (password_hash := user.password_hash) is None:
return Response(status_code=HTTP_401_UNAUTHORIZED)
secret = request.app.state.get_secret()
loop = asyncio.get_running_loop()
if not await loop.run_in_executor(
executor=None,
func=lambda: is_valid_password(password=password, salt=secret, password_hash=password_hash),
):
return Response(status_code=HTTP_401_UNAUTHORIZED)
response = Response(status_code=HTTP_204_NO_CONTENT)
response.set_cookie(
key=PHOENIX_ACCESS_TOKEN_COOKIE_NAME,
value="token", # todo: compute access token
secure=True,
httponly=True,
samesite="strict",
max_age=PHOENIX_ACCESS_TOKEN_COOKIE_MAX_AGE_IN_SECONDS,
)
return response


@router.post("/logout")
async def logout() -> Response:
response = Response(status_code=HTTP_204_NO_CONTENT)
response.delete_cookie(key=PHOENIX_ACCESS_TOKEN_COOKIE_NAME)
return response
27 changes: 25 additions & 2 deletions src/phoenix/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
from functools import cached_property
from pathlib import Path
from types import MethodType
from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -30,6 +31,7 @@
AsyncSession,
async_sessionmaker,
)
from starlette.datastructures import State as StarletteState
from starlette.exceptions import HTTPException
from starlette.middleware import Middleware
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
Expand Down Expand Up @@ -80,6 +82,7 @@
TokenCountDataLoader,
TraceRowIdsDataLoader,
)
from phoenix.server.api.routers.auth import router as auth_router
from phoenix.server.api.routers.v1 import REST_API_VERSION
from phoenix.server.api.routers.v1 import router as v1_router
from phoenix.server.api.schema import schema
Expand Down Expand Up @@ -536,6 +539,8 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
app.include_router(router)
app.include_router(graphql_router)
app.add_middleware(GZipMiddleware)
if authentication_enabled:
app.include_router(auth_router)
if serve_ui:
app.mount(
"/",
Expand All @@ -554,12 +559,30 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
),
name="static",
)

app.state.db = db
app = _update_app_state(app, db=db, secret=secret)
if tracer_provider:
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor

FastAPIInstrumentor().instrument(tracer_provider=tracer_provider)
FastAPIInstrumentor.instrument_app(app, tracer_provider=tracer_provider)
shutdown_callbacks_list.append(FastAPIInstrumentor().uninstrument)
return app


def _update_app_state(app: FastAPI, /, *, db: DbSessionFactory, secret: Optional[str]) -> FastAPI:
"""
Dynamically updates the app's `state` to include useful fields and methods
(at the time of this writing, FastAPI does not support setting this state
during the creation of the app).
"""
app.state.db = db
app.state._secret = secret

def get_secret(self: StarletteState) -> str:
if (secret := self._secret) is None:
raise ValueError("app secret is not set")
assert isinstance(secret, str)
return secret

app.state.get_secret = MethodType(get_secret, app.state)
return app
51 changes: 51 additions & 0 deletions tests/test_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import pytest

from phoenix.auth import validate_email_format, validate_password_format


def test_validate_email_format_does_not_raise_on_valid_format() -> None:
validate_email_format("user@domain.com")


@pytest.mark.parametrize(
"email",
(
pytest.param("userdomain.com", id="missing-@"),
pytest.param("user@domain", id="missing-top-level-domain-name"),
pytest.param("user@domain.", id="empty-top-level-domain-name"),
pytest.param("user@.com", id="missing-domain-name"),
pytest.param("@domain.com", id="missing-username"),
pytest.param("user @domain.com", id="username-contains-space"),
pytest.param("user@do main.com", id="domain-name-contains-space"),
pytest.param("user@domain.c om", id="top-level-domain-name-contains-space"),
pytest.param(" user@domain.com", id="leading-space"),
pytest.param("user@domain.com ", id="trailing-space"),
pytest.param(" user@domain.com", id="leading-space"),
pytest.param("\nuser@domain.com ", id="leading-newline"),
pytest.param("user@domain.com\n", id="trailing-newline"),
),
)
def test_validate_email_format_raises_on_invalid_format(email: str) -> None:
with pytest.raises(ValueError):
validate_email_format(email)


def test_validate_password_format_does_not_raise_on_valid_format() -> None:
validate_password_format("Password1234!")


@pytest.mark.parametrize(
"password",
(
pytest.param("", id="empty"),
pytest.param("pass word", id="contains-space"),
pytest.param("pass\nword", id="contains-newline"),
pytest.param("password\n", id="trailing-newline"),
pytest.param("P@ßwø®∂!ñ", id="contains-non-ascii-chars"),
pytest.param("안녕하세요", id="korean"),
pytest.param("🚀", id="emoji"),
),
)
def test_validate_password_format_raises_on_invalid_format(password: str) -> None:
with pytest.raises(ValueError):
validate_password_format(password)

0 comments on commit a3ff0f6

Please sign in to comment.