Skip to content

Commit

Permalink
feat(auth): secure graphql api when auth is enabled (#4508)
Browse files Browse the repository at this point in the history
  • Loading branch information
axiomofjoy authored and RogerHYang committed Sep 21, 2024
1 parent a76638b commit 39b1e07
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 27 deletions.
22 changes: 11 additions & 11 deletions integration_tests/auth/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,12 +202,12 @@ def test_admin(
passwords: Iterator[_Password],
) -> None:
password = secret if use_secret else next(passwords)
with pytest.raises(Unauthorized):
with pytest.raises(HTTPStatusError, match="401 Unauthorized"):
create_system_api_key(None, name=fake.unique.pystr())
with expectation:
with log_in(password, email=email) as (token, _):
create_system_api_key(token, name=fake.unique.pystr())
with pytest.raises(Unauthorized):
with pytest.raises(HTTPStatusError, match="401 Unauthorized"):
create_system_api_key(token, name=fake.unique.pystr())

def test_end_to_end_credentials_flow(
Expand Down Expand Up @@ -257,7 +257,7 @@ def test_end_to_end_credentials_flow(
resp.raise_for_status()

# original access token is invalid after refresh
with pytest.raises(Unauthorized):
with pytest.raises(HTTPStatusError, match="401 Unauthorized"):
create_system_api_key(browser_0_access_token_0, name="api-key-2")

# user logs into second browser
Expand All @@ -283,9 +283,9 @@ def test_end_to_end_credentials_flow(
resp.raise_for_status()

# user is logged out of both browsers
with pytest.raises(Unauthorized):
with pytest.raises(HTTPStatusError, match="401 Unauthorized"):
create_system_api_key(browser_0_access_token_1, name="api-key-4")
with pytest.raises(Unauthorized):
with pytest.raises(HTTPStatusError, match="401 Unauthorized"):
create_system_api_key(browser_1_access_token_0, name="api-key-5")

@pytest.mark.parametrize(
Expand All @@ -311,7 +311,7 @@ def test_create_user(
email = profile.email
username = profile.username
password = profile.password
with pytest.raises(Unauthorized):
with pytest.raises(HTTPStatusError, match="401 Unauthorized"):
create_user(None, email=email, password=password, username=username, role=role)
with log_in(secret, email=admin_email) as (token, _):
create_user(token, email=email, password=password, username=username, role=role)
Expand Down Expand Up @@ -354,7 +354,7 @@ def test_user_can_change_password_for_self(
log_in(password, email=email).__enter__()
patch_viewer((old_token := token), (old_password := password), new_password=new_password)
another_password = f"another_password_{next(passwords)}"
with pytest.raises(Unauthorized):
with pytest.raises(HTTPStatusError, match="401 Unauthorized"):
patch_viewer(old_token, new_password, new_password=another_password)
with pytest.raises(HTTPStatusError, match="401 Unauthorized"):
log_in(old_password, email=email).__enter__()
Expand All @@ -376,7 +376,7 @@ def test_user_can_change_username_for_self(
token, password = user.token, user.profile.password
new_username = f"new_username_{next(usernames)}"
for _password in (None, password):
with pytest.raises(Unauthorized):
with pytest.raises(HTTPStatusError, match="401 Unauthorized"):
patch_viewer(None, _password, new_username=new_username)
patch_viewer(token, None, new_username=new_username)
another_username = f"another_username_{next(usernames)}"
Expand All @@ -403,7 +403,7 @@ def test_only_admin_can_change_role_for_non_self(
non_self = get_new_user(UserRoleInput.MEMBER)
assert user.gid != non_self.gid
token, gid = user.token, non_self.gid
with pytest.raises(Unauthorized):
with pytest.raises(HTTPStatusError, match="401 Unauthorized"):
patch_user(None, gid, new_role=UserRoleInput.ADMIN)
with expectation:
patch_user(token, gid, new_role=UserRoleInput.ADMIN)
Expand Down Expand Up @@ -431,7 +431,7 @@ def test_only_admin_can_change_password_for_non_self(
new_password = f"new_password_{next(passwords)}"
assert new_password != old_password
token, gid = user.token, non_self.gid
with pytest.raises(Unauthorized):
with pytest.raises(HTTPStatusError, match="401 Unauthorized"):
patch_user(None, gid, new_password=new_password)
with expectation as e:
patch_user(token, gid, new_password=new_password)
Expand Down Expand Up @@ -465,7 +465,7 @@ def test_only_admin_can_change_username_for_non_self(
new_username = f"new_username_{next(usernames)}"
assert new_username != old_username
token, gid = user.token, non_self.gid
with pytest.raises(Unauthorized):
with pytest.raises(HTTPStatusError, match="401 Unauthorized"):
patch_user(None, gid, new_username=new_username)
with expectation:
patch_user(token, gid, new_username=new_username)
Expand Down
17 changes: 3 additions & 14 deletions src/phoenix/server/api/routers/v1/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.security import APIKeyHeader
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
from starlette.status import HTTP_403_FORBIDDEN

from phoenix.auth import ClaimSetStatus
from phoenix.config import ENABLE_AUTH
from phoenix.server.bearer_auth import PhoenixUser
from phoenix.server.bearer_auth import is_authenticated

from .datasets import router as datasets_router
from .evaluations import router as evaluations_router
Expand All @@ -29,16 +28,6 @@ async def prevent_access_in_read_only_mode(request: Request) -> None:
)


async def authentication(request: Request) -> None:
if not isinstance((user := request.user), PhoenixUser):
raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Invalid token")
claims = user.claims
if claims.status is ClaimSetStatus.EXPIRED:
raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Expired token")
if claims.status is not ClaimSetStatus.VALID:
raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Invalid token")


dependencies = [Depends(prevent_access_in_read_only_mode)]
if ENABLE_AUTH:
dependencies.append(
Expand All @@ -51,7 +40,7 @@ async def authentication(request: Request) -> None:
)
)
)
dependencies.append(Depends(authentication))
dependencies.append(Depends(is_authenticated))

router = APIRouter(
prefix="/v1",
Expand Down
8 changes: 6 additions & 2 deletions src/phoenix/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
)

import strawberry
from fastapi import APIRouter, FastAPI
from fastapi import APIRouter, Depends, FastAPI
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.responses import FileResponse
from fastapi.utils import is_body_allowed_for_status_code
Expand Down Expand Up @@ -94,7 +94,7 @@
from phoenix.server.api.routers import auth_router, v1_router
from phoenix.server.api.routers.v1 import REST_API_VERSION
from phoenix.server.api.schema import schema
from phoenix.server.bearer_auth import BearerTokenAuthBackend
from phoenix.server.bearer_auth import BearerTokenAuthBackend, is_authenticated
from phoenix.server.dml_event import DmlEvent
from phoenix.server.dml_event_handler import DmlEventHandler
from phoenix.server.grpc_server import GrpcServer
Expand Down Expand Up @@ -444,6 +444,7 @@ def create_graphql_router(
model: Model,
export_path: Path,
last_updated_at: CanGetLastUpdatedAt,
authentication_enabled: bool,
corpus: Optional[Model] = None,
cache_for_dataloaders: Optional[CacheForDataLoaders] = None,
event_queue: CanPutItem[DmlEvent],
Expand All @@ -459,6 +460,7 @@ def create_graphql_router(
model (Model): The Model representing inferences (legacy)
export_path (Path): the file path to export data to for download (legacy)
last_updated_at (CanGetLastUpdatedAt): How to get the last updated timestamp for updates.
authentication_enabled (bool): Whether authentication is enabled.
event_queue (CanPutItem[DmlEvent]): The event queue for DML events.
corpus (Optional[Model], optional): the corpus for UMAP projection. Defaults to None.
cache_for_dataloaders (Optional[CacheForDataLoaders], optional): GraphQL data loaders.
Expand Down Expand Up @@ -545,6 +547,7 @@ def get_context() -> Context:
context_getter=get_context,
include_in_schema=False,
prefix="/graphql",
dependencies=(Depends(is_authenticated),) if authentication_enabled else (),
)


Expand Down Expand Up @@ -685,6 +688,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
),
model=model,
corpus=corpus,
authentication_enabled=authentication_enabled,
export_path=export_path,
last_updated_at=last_updated_at,
event_queue=dml_event_handler,
Expand Down
16 changes: 16 additions & 0 deletions src/phoenix/server/bearer_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@
from typing import Any, Awaitable, Callable, Optional, Tuple

import grpc
from fastapi import Request
from grpc_interceptor import AsyncServerInterceptor
from grpc_interceptor.exceptions import Unauthenticated
from starlette.authentication import AuthCredentials, AuthenticationBackend, BaseUser
from starlette.exceptions import HTTPException
from starlette.requests import HTTPConnection
from starlette.status import HTTP_401_UNAUTHORIZED

from phoenix.auth import (
PHOENIX_ACCESS_TOKEN_COOKIE_NAME,
Expand Down Expand Up @@ -81,3 +84,16 @@ async def intercept(
return await method(request_or_iterator, context)
raise Unauthenticated()
raise Unauthenticated()


async def is_authenticated(request: Request) -> None:
"""
Raises a 401 if the request is not authenticated.
"""
if not isinstance((user := request.user), PhoenixUser):
raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Invalid token")
claims = user.claims
if claims.status is ClaimSetStatus.EXPIRED:
raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Expired token")
if claims.status is not ClaimSetStatus.VALID:
raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Invalid token")

0 comments on commit 39b1e07

Please sign in to comment.