Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(auth): add login/ logout routes and createUser mutation #4293

Merged
merged 8 commits into from
Aug 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions app/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,13 @@ input CreateTraceAnnotationInput {
metadata: JSON! = {}
}

input CreateUserInput {
email: String!
username: String
password: String!
role: UserRoleInput!
}

enum DataQualityMetric {
cardinality
percentEmpty
Expand Down Expand Up @@ -899,6 +906,7 @@ type Mutation {
patchTraceAnnotations(input: [PatchAnnotationInput!]!): TraceAnnotationMutationPayload!
deleteTraceAnnotations(input: DeleteAnnotationsInput!): TraceAnnotationMutationPayload!
createSystemApiKey(input: CreateApiKeyInput!): CreateSystemApiKeyMutationPayload!
createUser(input: CreateUserInput!): UserMutationPayload!
}

"""An object with a Globally Unique ID"""
Expand Down Expand Up @@ -1449,12 +1457,21 @@ type UserEdge {
node: User!
}

type UserMutationPayload {
user: User!
}

type UserRole implements Node {
"""The Globally Unique ID of this object"""
id: GlobalID!
name: String!
}

enum UserRoleInput {
ADMIN
MEMBER
}

type ValidationResult {
isValid: Boolean!
errorMessage: String
Expand Down
45 changes: 45 additions & 0 deletions src/phoenix/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import re
from hashlib import pbkdf2_hmac


def compute_password_hash(*, password: str, salt: str) -> str:
"""
Salts and hashes a password using PBKDF2, HMAC, and SHA256.
"""
password_bytes = password.encode("utf-8")
salt_bytes = salt.encode("utf-8")
password_hash_bytes = pbkdf2_hmac("sha256", password_bytes, salt_bytes, NUM_ITERATIONS)
password_hash = password_hash_bytes.hex()
return password_hash


def is_valid_password(*, password: str, salt: str, password_hash: str) -> bool:
"""
Determines whether the password is valid by salting and hashing the password
and comparing against the existing hash value.
"""
return password_hash == compute_password_hash(password=password, salt=salt)


def validate_email_format(email: str) -> None:
"""
Checks that the email has a valid format.
"""
if EMAIL_PATTERN.match(email) is None:
raise ValueError("Invalid email address")


def validate_password_format(password: str) -> None:
"""
Checks that the password has a valid format.
"""
if not password:
raise ValueError("Password must be non-empty")
if any(char.isspace() for char in password):
raise ValueError("Password cannot contain whitespace characters")
if not password.isascii():
raise ValueError("Password can contain only ASCII characters")


EMAIL_PATTERN = re.compile(r"^[^@\s]+@[^@\s]+[.][^@\s]+\Z")
NUM_ITERATIONS = 10_000
Empty file removed src/phoenix/auth/__init__.py
Empty file.
15 changes: 0 additions & 15 deletions src/phoenix/auth/utils.py

This file was deleted.

9 changes: 9 additions & 0 deletions src/phoenix/server/api/input_types/UserRoleInput.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from enum import Enum

import strawberry


@strawberry.enum
class UserRoleInput(Enum):
ADMIN = "ADMIN"
MEMBER = "MEMBER"
2 changes: 2 additions & 0 deletions src/phoenix/server/api/mutations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from phoenix.server.api.mutations.project_mutations import ProjectMutationMixin
from phoenix.server.api.mutations.span_annotations_mutations import SpanAnnotationMutationMixin
from phoenix.server.api.mutations.trace_annotations_mutations import TraceAnnotationMutationMixin
from phoenix.server.api.mutations.user_mutations import UserMutationMixin


@strawberry.type
Expand All @@ -18,5 +19,6 @@ class Mutation(
SpanAnnotationMutationMixin,
TraceAnnotationMutationMixin,
ApiKeyMutationMixin,
UserMutationMixin,
):
pass
89 changes: 89 additions & 0 deletions src/phoenix/server/api/mutations/user_mutations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import asyncio
from typing import Optional

import strawberry
from sqlalchemy import insert, select
from sqlean.dbapi2 import IntegrityError # type: ignore[import-untyped]
from strawberry.types import Info

from phoenix.auth import compute_password_hash, validate_email_format, validate_password_format
from phoenix.db import models
from phoenix.server.api.context import Context
from phoenix.server.api.input_types.UserRoleInput import UserRoleInput
from phoenix.server.api.types.User import User
from phoenix.server.api.types.UserRole import UserRole


@strawberry.input
class CreateUserInput:
email: str
username: Optional[str]
password: str
role: UserRoleInput


@strawberry.type
class UserMutationPayload:
user: User


@strawberry.type
class UserMutationMixin:
@strawberry.mutation
async def create_user(
self,
info: Info[Context, None],
input: CreateUserInput,
) -> UserMutationPayload:
validate_email_format(email := input.email)
validate_password_format(password := input.password)
role_name = input.role.value
user_role_id = (
select(models.UserRole.id).where(models.UserRole.name == role_name).scalar_subquery()
)
secret = info.context.get_secret()
loop = asyncio.get_running_loop()
password_hash = await loop.run_in_executor(
executor=None,
func=lambda: compute_password_hash(password=password, salt=secret),
)
try:
async with info.context.db() as session:
user = await session.scalar(
insert(models.User)
.values(
user_role_id=user_role_id,
username=input.username,
email=email,
auth_method="LOCAL",
password_hash=password_hash,
reset_password=True,
)
.returning(models.User)
)
assert user is not None
except IntegrityError as error:
raise ValueError(_get_user_create_error_message(error))
return UserMutationPayload(
user=User(
id_attr=user.id,
email=user.email,
username=user.username,
created_at=user.created_at,
role=UserRole(id_attr=user.user_role_id, name=role_name),
)
)


def _get_user_create_error_message(error: IntegrityError) -> str:
"""
Gets a user-facing error message to explain why user creation failed.
"""
original_error_message = str(error)
username_already_exists = "users.username" in original_error_message
email_already_exists = "users.email" in original_error_message
if username_already_exists:
return "Username already exists"
elif email_already_exists:
return "Email already exists"
return "Failed to create user"
52 changes: 52 additions & 0 deletions src/phoenix/server/api/routers/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import asyncio
from datetime import timedelta

from fastapi import APIRouter, Form, Request, Response
from sqlalchemy import select
from starlette.status import HTTP_204_NO_CONTENT, HTTP_401_UNAUTHORIZED
from typing_extensions import Annotated

from phoenix.auth import is_valid_password
from phoenix.db import models

router = APIRouter(include_in_schema=False)

PHOENIX_ACCESS_TOKEN_COOKIE_NAME = "phoenix-access-token"
PHOENIX_ACCESS_TOKEN_COOKIE_MAX_AGE_IN_SECONDS = int(timedelta(days=31).total_seconds())


@router.post("/login")
async def login(
request: Request,
email: Annotated[str, Form()],
password: Annotated[str, Form()],
) -> Response:
async with request.app.state.db() as session:
if (
user := await session.scalar(select(models.User).where(models.User.email == email))
) is None or (password_hash := user.password_hash) is None:
return Response(status_code=HTTP_401_UNAUTHORIZED)
secret = request.app.state.get_secret()
loop = asyncio.get_running_loop()
if not await loop.run_in_executor(
executor=None,
func=lambda: is_valid_password(password=password, salt=secret, password_hash=password_hash),
):
return Response(status_code=HTTP_401_UNAUTHORIZED)
response = Response(status_code=HTTP_204_NO_CONTENT)
response.set_cookie(
key=PHOENIX_ACCESS_TOKEN_COOKIE_NAME,
value="token", # todo: compute access token
secure=True,
httponly=True,
samesite="strict",
max_age=PHOENIX_ACCESS_TOKEN_COOKIE_MAX_AGE_IN_SECONDS,
)
return response


@router.post("/logout")
async def logout() -> Response:
response = Response(status_code=HTTP_204_NO_CONTENT)
response.delete_cookie(key=PHOENIX_ACCESS_TOKEN_COOKIE_NAME)
return response
27 changes: 25 additions & 2 deletions src/phoenix/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
from functools import cached_property
from pathlib import Path
from types import MethodType
from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -30,6 +31,7 @@
AsyncSession,
async_sessionmaker,
)
from starlette.datastructures import State as StarletteState
from starlette.exceptions import HTTPException
from starlette.middleware import Middleware
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
Expand Down Expand Up @@ -80,6 +82,7 @@
TokenCountDataLoader,
TraceRowIdsDataLoader,
)
from phoenix.server.api.routers.auth import router as auth_router
from phoenix.server.api.routers.v1 import REST_API_VERSION
from phoenix.server.api.routers.v1 import router as v1_router
from phoenix.server.api.schema import schema
Expand Down Expand Up @@ -529,6 +532,8 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
app.include_router(router)
app.include_router(graphql_router)
app.add_middleware(GZipMiddleware)
if authentication_enabled:
app.include_router(auth_router)
if serve_ui:
app.mount(
"/",
Expand All @@ -547,12 +552,30 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
),
name="static",
)

app.state.db = db
app = _update_app_state(app, db=db, secret=secret)
if tracer_provider:
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor

FastAPIInstrumentor().instrument(tracer_provider=tracer_provider)
FastAPIInstrumentor.instrument_app(app, tracer_provider=tracer_provider)
clean_ups.append(FastAPIInstrumentor().uninstrument)
return app


def _update_app_state(app: FastAPI, /, *, db: DbSessionFactory, secret: Optional[str]) -> FastAPI:
"""
Dynamically updates the app's `state` to include useful fields and methods
(at the time of this writing, FastAPI does not support setting this state
during the creation of the app).
"""
app.state.db = db
app.state._secret = secret

def get_secret(self: StarletteState) -> str:
if (secret := self._secret) is None:
raise ValueError("app secret is not set")
assert isinstance(secret, str)
return secret

app.state.get_secret = MethodType(get_secret, app.state)
return app
51 changes: 51 additions & 0 deletions tests/test_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import pytest

from phoenix.auth import validate_email_format, validate_password_format


def test_validate_email_format_does_not_raise_on_valid_format() -> None:
validate_email_format("user@domain.com")


@pytest.mark.parametrize(
"email",
(
pytest.param("userdomain.com", id="missing-@"),
pytest.param("user@domain", id="missing-top-level-domain-name"),
pytest.param("user@domain.", id="empty-top-level-domain-name"),
pytest.param("user@.com", id="missing-domain-name"),
pytest.param("@domain.com", id="missing-username"),
pytest.param("user @domain.com", id="username-contains-space"),
pytest.param("user@do main.com", id="domain-name-contains-space"),
pytest.param("user@domain.c om", id="top-level-domain-name-contains-space"),
pytest.param(" user@domain.com", id="leading-space"),
pytest.param("user@domain.com ", id="trailing-space"),
pytest.param(" user@domain.com", id="leading-space"),
pytest.param("\nuser@domain.com ", id="leading-newline"),
pytest.param("user@domain.com\n", id="trailing-newline"),
),
)
def test_validate_email_format_raises_on_invalid_format(email: str) -> None:
with pytest.raises(ValueError):
validate_email_format(email)


def test_validate_password_format_does_not_raise_on_valid_format() -> None:
validate_password_format("Password1234!")


@pytest.mark.parametrize(
"password",
(
pytest.param("", id="empty"),
pytest.param("pass word", id="contains-space"),
pytest.param("pass\nword", id="contains-newline"),
pytest.param("password\n", id="trailing-newline"),
pytest.param("P@ßwø®∂!ñ", id="contains-non-ascii-chars"),
pytest.param("안녕하세요", id="korean"),
pytest.param("🚀", id="emoji"),
),
)
def test_validate_password_format_raises_on_invalid_format(password: str) -> None:
with pytest.raises(ValueError):
validate_password_format(password)
Loading