From e00862ffe3a75f3a57359f4e2c7e11d73a376de9 Mon Sep 17 00:00:00 2001 From: Roger Yang Date: Fri, 6 Sep 2024 10:22:51 -0700 Subject: [PATCH 01/21] ci: remove unnecessary fixtures from integration tests --- integration_tests/__init__.py | 1 + integration_tests/auth/__init__.py | 0 integration_tests/auth/conftest.py | 571 ++++++++------------ integration_tests/auth/test_auth.py | 423 +++++---------- integration_tests/conftest.py | 277 ++++------ integration_tests/mypy.ini | 1 - integration_tests/py.typed | 0 integration_tests/server/__init__.py | 0 integration_tests/server/test_launch_app.py | 72 +-- 9 files changed, 493 insertions(+), 852 deletions(-) create mode 100644 integration_tests/__init__.py create mode 100644 integration_tests/auth/__init__.py create mode 100644 integration_tests/py.typed create mode 100644 integration_tests/server/__init__.py diff --git a/integration_tests/__init__.py b/integration_tests/__init__.py new file mode 100644 index 0000000000..9d48db4f9f --- /dev/null +++ b/integration_tests/__init__.py @@ -0,0 +1 @@ +from __future__ import annotations diff --git a/integration_tests/auth/__init__.py b/integration_tests/auth/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/integration_tests/auth/conftest.py b/integration_tests/auth/conftest.py index b52676af15..bbb44a9174 100644 --- a/integration_tests/auth/conftest.py +++ b/integration_tests/auth/conftest.py @@ -1,22 +1,12 @@ +from __future__ import annotations + import os import secrets -from contextlib import ExitStack, contextmanager +from contextlib import ExitStack from dataclasses import asdict, dataclass from datetime import datetime from itertools import count, starmap -from typing import ( - Any, - Callable, - ContextManager, - Dict, - Generator, - Iterator, - List, - Optional, - Protocol, - Tuple, - cast, -) +from typing import Any, Dict, Generator, Iterator, Optional, Protocol, Tuple, cast from unittest import mock from urllib.parse import urljoin @@ -28,407 +18,298 @@ PHOENIX_REFRESH_TOKEN_COOKIE_NAME, REQUIREMENTS_FOR_PHOENIX_SECRET, ) -from phoenix.config import ( - ENV_PHOENIX_ENABLE_AUTH, - ENV_PHOENIX_SECRET, - get_base_url, -) +from phoenix.config import ENV_PHOENIX_ENABLE_AUTH, ENV_PHOENIX_SECRET, get_base_url from phoenix.server.api.auth import IsAdmin, IsAuthenticated from phoenix.server.api.exceptions import Unauthorized from phoenix.server.api.input_types.UserRoleInput import UserRoleInput from typing_extensions import TypeAlias -_ProjectName: TypeAlias = str -_Name: TypeAlias = str -_ApiKey: TypeAlias = str -_GqlId: TypeAlias = str +from integration_tests.conftest import _httpx_client, _server -_Username: TypeAlias = str _Email: TypeAlias = str +_GqlId: TypeAlias = str +_Name: TypeAlias = str _Password: TypeAlias = str _Token: TypeAlias = str -_AccessToken: TypeAlias = str -_RefreshToken: TypeAlias = str - - -class _LogIn(Protocol): - def __call__( - self, - password: _Password, - /, - *, - email: _Email, - ) -> ContextManager[Tuple[_AccessToken, _RefreshToken]]: ... - - -class _LogOut(Protocol): - def __call__(self, token: _Token, /) -> None: ... - - -class _CreateUser(Protocol): - def __call__( - self, - token: Optional[_Token], - /, - *, - email: _Email, - password: _Password, - role: UserRoleInput, - username: Optional[_Username] = None, - ) -> _GqlId: ... - - -class _PatchUser(Protocol): - def __call__( - self, - token: Optional[_Token], - gid: _GqlId, - /, - *, - new_username: Optional[_Username] = None, - new_password: Optional[_Password] = None, - new_role: Optional[UserRoleInput] = None, - ) -> None: ... - - -class _PatchViewer(Protocol): - def __call__( - self, - token: Optional[_Token], - current_password: Optional[_Password], - /, - *, - new_username: Optional[_Username] = None, - new_password: Optional[_Password] = None, - ) -> None: ... - - -class _CreateSystemApiKey(Protocol): - def __call__( - self, - token: Optional[_Token], - /, - *, - name: _Name, - expires_at: Optional[datetime] = None, - ) -> Tuple[_ApiKey, _GqlId]: ... - - -class _DeleteSystemApiKey(Protocol): - def __call__( - self, - token: Optional[_Token], - gid: _GqlId, - /, - ) -> None: ... - - -class _GetGqlSpans(Protocol): - def __call__(self, *keys: str) -> Dict[_ProjectName, List[Dict[str, Any]]]: ... +_Username: TypeAlias = str + +_AccessToken: TypeAlias = _Token +_ApiKey: TypeAlias = _Token +_RefreshToken: TypeAlias = _Token + + +@dataclass(frozen=True) +class _Profile: + email: _Email + password: _Password + username: Optional[_Username] = None + + +@dataclass(frozen=True) +class _User: + gid: _GqlId + role: UserRoleInput + profile: _Profile + + +@dataclass(frozen=True) +class _LoggedInTokens: + access: _AccessToken + refresh: _RefreshToken + + def log_out(self) -> None: + _log_out(self.access) + + def __iter__(self) -> Iterator[_Token]: + yield self.access + yield self.refresh + + def __enter__(self) -> _LoggedInTokens: + return self + + def __exit__(self, *args: Any, **kwargs: Any) -> None: + self.log_out() + + +@dataclass(frozen=True) +class _LoggedInUser(_User): + tokens: _LoggedInTokens + + +class _UserGenerator(Protocol): + def send(self, role: UserRoleInput) -> _LoggedInUser: ... + + +class _GetNewUser(Protocol): + def __call__(self, role: UserRoleInput) -> _LoggedInUser: ... @pytest.fixture(scope="class") -def secret(fake: Faker) -> str: +def _secret() -> str: return secrets.token_hex(32) @pytest.fixture(autouse=True, scope="class") -def app( - secret: str, - env_phoenix_sql_database_url: Any, - server: Callable[[], ContextManager[None]], +def _app( + _secret: str, + _env_phoenix_sql_database_url: Any, ) -> Iterator[None]: values = ( (ENV_PHOENIX_ENABLE_AUTH, "true"), - (ENV_PHOENIX_SECRET, secret), + (ENV_PHOENIX_SECRET, _secret), ) with ExitStack() as stack: stack.enter_context(mock.patch.dict(os.environ, values)) - stack.enter_context(server()) + stack.enter_context(_server()) yield @pytest.fixture(scope="class") -def emails(fake: Faker) -> Iterator[_Email]: - return (fake.unique.email() for _ in count()) +def _emails(_fake: Faker) -> Iterator[_Email]: + return (_fake.unique.email() for _ in count()) @pytest.fixture(scope="class") -def passwords(fake: Faker) -> Iterator[_Password]: - return (fake.unique.password(**asdict(REQUIREMENTS_FOR_PHOENIX_SECRET)) for _ in count()) +def _passwords(_fake: Faker) -> Iterator[_Password]: + return (_fake.unique.password(**asdict(REQUIREMENTS_FOR_PHOENIX_SECRET)) for _ in count()) @pytest.fixture(scope="class") -def usernames(fake: Faker) -> Iterator[_Username]: - return (fake.unique.pystr() for _ in count()) - - -@dataclass(frozen=True) -class _Profile: - email: _Email - password: _Password - username: Optional[_Username] = None - - -@dataclass(frozen=True) -class _User: - gid: _GqlId - role: UserRoleInput - profile: _Profile - token: Optional[_Token] = None +def _usernames(_fake: Faker) -> Iterator[_Username]: + return (_fake.unique.pystr() for _ in count()) @pytest.fixture(scope="class") -def profiles( - emails: Iterator[_Email], - usernames: Iterator[_Username], - passwords: Iterator[_Password], +def _profiles( + _emails: Iterator[_Email], + _usernames: Iterator[_Password], + _passwords: Iterator[_Password], ) -> Iterator[_Profile]: - return starmap(_Profile, zip(emails, passwords, usernames)) - - -class _UserGenerator(Protocol): - def send(self, role: UserRoleInput) -> _User: ... + return starmap(_Profile, zip(_emails, _passwords, _usernames)) @pytest.fixture def _users( - profiles: Iterator[_Profile], - admin_token: _Token, - create_user: _CreateUser, - log_in: _LogIn, - fake: Faker, + _profiles: Iterator[_Profile], + _admin_token: _Token, + _fake: Faker, ) -> _UserGenerator: - def _() -> Generator[Optional[_User], UserRoleInput, None]: + def _() -> Generator[Optional[_LoggedInUser], UserRoleInput, None]: role = yield None - for profile in profiles: - gid = create_user(admin_token, **asdict(profile), role=role) + for profile in _profiles: + gid = _create_user(_admin_token, **asdict(profile), role=role) email, password = profile.email, profile.password - token, _ = log_in(password, email=email).__enter__() - role = yield _User(gid=gid, role=role, token=token, profile=profile) + tokens = _log_in(password, email=email) + role = yield _LoggedInUser(gid=gid, role=role, tokens=tokens, profile=profile) g = _() next(g) return cast(_UserGenerator, g) -class _GetNewUser(Protocol): - def __call__(self, role: UserRoleInput) -> _User: ... - - @pytest.fixture -def get_new_user( +def _get_new_user( _users: _UserGenerator, ) -> _GetNewUser: - def _(role: UserRoleInput) -> _User: + def _(role: UserRoleInput) -> _LoggedInUser: return _users.send(role) return _ @pytest.fixture -def admin_token( - admin_email: str, - secret: str, - log_in: _LogIn, +def _admin_token( + _admin_email: str, + _secret: str, ) -> Iterator[_Token]: - with log_in(secret, email=admin_email) as (token, _): + with _log_in(_secret, email=_admin_email) as (token, _): yield token @pytest.fixture(scope="module") -def admin_email() -> _Email: +def _admin_email() -> _Email: return "admin@localhost" -@pytest.fixture(scope="module") -def create_user( - httpx_client: Callable[[], httpx.Client], -) -> _CreateUser: - def _( - token: Optional[_Token], - /, - *, - email: _Email, - password: _Password, - role: UserRoleInput, - username: Optional[_Username] = None, - ) -> _GqlId: - args = [f'email:"{email}"', f'password:"{password}"', f"role:{role.value}"] - if username: - args.append(f'username:"{username}"') - out = "user{id email role{name}}" - query = "mutation{createUser(input:{" + ",".join(args) + "}){" + out + "}}" - resp = httpx_client().post( - urljoin(get_base_url(), "graphql"), - json=dict(query=query), - cookies={PHOENIX_ACCESS_TOKEN_COOKIE_NAME: token} if token else {}, - ) - resp_dict = _json(resp) - assert (user := resp_dict["data"]["createUser"]["user"]) - assert user["email"] == email - assert user["role"]["name"] == role.value - return cast(_GqlId, user["id"]) - - return _ - - -@pytest.fixture(scope="module") -def patch_user( - httpx_client: Callable[[], httpx.Client], -) -> _PatchUser: - def _( - token: Optional[_Token], - gid: _GqlId, - /, - *, - new_username: Optional[_Username] = None, - new_password: Optional[_Password] = None, - new_role: Optional[UserRoleInput] = None, - ) -> None: - args = [f'userId:"{gid}"'] - if new_password: - args.append(f'newPassword:"{new_password}"') - if new_username: - args.append(f'newUsername:"{new_username}"') - if new_role: - args.append(f"newRole:{new_role.value}") - out = "user{id username role{name}}" - query = "mutation{patchUser(input:{" + ",".join(args) + "}){" + out + "}}" - resp = httpx_client().post( - urljoin(get_base_url(), "graphql"), - json=dict(query=query), - cookies={PHOENIX_ACCESS_TOKEN_COOKIE_NAME: token} if token else {}, - ) - resp_dict = _json(resp) - assert (user := resp_dict["data"]["patchUser"]["user"]) - assert user["id"] == gid - if new_username: - assert user["username"] == new_username - if new_role: - assert user["role"]["name"] == new_role.value - - return _ - - -@pytest.fixture(scope="module") -def patch_viewer( - httpx_client: Callable[[], httpx.Client], -) -> _PatchViewer: - def _( - token: Optional[_Token], - current_password: Optional[_Password], - /, - *, - new_username: Optional[_Username] = None, - new_password: Optional[_Password] = None, - ) -> None: - args = [] - if new_password: - args.append(f'newPassword:"{new_password}"') - if current_password: - args.append(f'currentPassword:"{current_password}"') - if new_username: - args.append(f'newUsername:"{new_username}"') - out = "user{username}" - query = "mutation{patchViewer(input:{" + ",".join(args) + "}){" + out + "}}" - resp = httpx_client().post( - urljoin(get_base_url(), "graphql"), - json=dict(query=query), - cookies={PHOENIX_ACCESS_TOKEN_COOKIE_NAME: token} if token else {}, - ) - resp_dict = _json(resp) - assert (user := resp_dict["data"]["patchViewer"]["user"]) - if new_username: - assert user["username"] == new_username - - return _ - - -@pytest.fixture(scope="module") -def create_system_api_key( - httpx_client: Callable[[], httpx.Client], -) -> _CreateSystemApiKey: - def _( - token: Optional[_Token], - /, - *, - name: _Name, - expires_at: Optional[datetime] = None, - ) -> Tuple[_ApiKey, _GqlId]: - exp = f' expiresAt:"{expires_at.isoformat()}"' if expires_at else "" - args, out = (f'name:"{name}"' + exp), "jwt apiKey{id name expiresAt}" - query = "mutation{createSystemApiKey(input:{" + args + "}){" + out + "}}" - resp = httpx_client().post( - urljoin(get_base_url(), "graphql"), - json=dict(query=query), - cookies={PHOENIX_ACCESS_TOKEN_COOKIE_NAME: token} if token else {}, - ) - resp_dict = _json(resp) - assert (result := resp_dict["data"]["createSystemApiKey"]) - assert (api_key := result["apiKey"]) - assert api_key["name"] == name - exp_t = datetime.fromisoformat(api_key["expiresAt"]) if api_key["expiresAt"] else None - assert exp_t == expires_at - return cast(_ApiKey, result["jwt"]), cast(_GqlId, api_key["id"]) - - return _ - - -@pytest.fixture(scope="module") -def delete_system_api_key( - httpx_client: Callable[[], httpx.Client], -) -> _DeleteSystemApiKey: - def _( - token: Optional[_Token], - gid: _GqlId, - /, - ) -> None: - args, out = f'id:"{gid}"', "apiKeyId" - query = "mutation{deleteSystemApiKey(input:{" + args + "}){" + out + "}}" - resp = httpx_client().post( - urljoin(get_base_url(), "graphql"), - json=dict(query=query), - cookies={PHOENIX_ACCESS_TOKEN_COOKIE_NAME: token} if token else {}, - ) - resp_dict = _json(resp) - assert resp_dict["data"]["deleteSystemApiKey"]["apiKeyId"] == gid - - return _ - - -@pytest.fixture(scope="module") -def log_in( - httpx_client: Callable[[], httpx.Client], - log_out: _LogOut, -) -> _LogIn: - @contextmanager - def _(password: _Password, /, *, email: _Email) -> Iterator[Tuple[_AccessToken, _RefreshToken]]: - resp = httpx_client().post( - urljoin(get_base_url(), "/auth/login"), - json={"email": email, "password": password}, - ) - resp.raise_for_status() - assert (access_token := resp.cookies.get(PHOENIX_ACCESS_TOKEN_COOKIE_NAME)) - assert (refresh_token := resp.cookies.get(PHOENIX_REFRESH_TOKEN_COOKIE_NAME)) - yield access_token, refresh_token - log_out(access_token) +def _create_user( + token: Optional[_Token], + /, + *, + email: _Email, + password: _Password, + role: UserRoleInput, + username: Optional[_Username] = None, +) -> _GqlId: + args = [f'email:"{email}"', f'password:"{password}"', f"role:{role.value}"] + if username: + args.append(f'username:"{username}"') + out = "user{id email role{name}}" + query = "mutation{createUser(input:{" + ",".join(args) + "}){" + out + "}}" + resp = _httpx_client().post( + urljoin(get_base_url(), "graphql"), + json=dict(query=query), + cookies={PHOENIX_ACCESS_TOKEN_COOKIE_NAME: token} if token else {}, + ) + resp_dict = _json(resp) + assert (user := resp_dict["data"]["createUser"]["user"]) + assert user["email"] == email + assert user["role"]["name"] == role.value + return cast(_GqlId, user["id"]) + + +def _patch_user( + token: Optional[_Token], + gid: _GqlId, + /, + *, + new_username: Optional[_Username] = None, + new_password: Optional[_Password] = None, + new_role: Optional[UserRoleInput] = None, +) -> None: + args = [f'userId:"{gid}"'] + if new_password: + args.append(f'newPassword:"{new_password}"') + if new_username: + args.append(f'newUsername:"{new_username}"') + if new_role: + args.append(f"newRole:{new_role.value}") + out = "user{id username role{name}}" + query = "mutation{patchUser(input:{" + ",".join(args) + "}){" + out + "}}" + resp = _httpx_client().post( + urljoin(get_base_url(), "graphql"), + json=dict(query=query), + cookies={PHOENIX_ACCESS_TOKEN_COOKIE_NAME: token} if token else {}, + ) + resp_dict = _json(resp) + assert (user := resp_dict["data"]["patchUser"]["user"]) + assert user["id"] == gid + if new_username: + assert user["username"] == new_username + if new_role: + assert user["role"]["name"] == new_role.value + + +def _patch_viewer( + token: Optional[_Token], + current_password: Optional[_Password], + /, + *, + new_username: Optional[_Username] = None, + new_password: Optional[_Password] = None, +) -> None: + args = [] + if new_password: + args.append(f'newPassword:"{new_password}"') + if current_password: + args.append(f'currentPassword:"{current_password}"') + if new_username: + args.append(f'newUsername:"{new_username}"') + out = "user{username}" + query = "mutation{patchViewer(input:{" + ",".join(args) + "}){" + out + "}}" + resp = _httpx_client().post( + urljoin(get_base_url(), "graphql"), + json=dict(query=query), + cookies={PHOENIX_ACCESS_TOKEN_COOKIE_NAME: token} if token else {}, + ) + resp_dict = _json(resp) + assert (user := resp_dict["data"]["patchViewer"]["user"]) + if new_username: + assert user["username"] == new_username + + +def _create_system_api_key( + token: Optional[_Token], + /, + *, + name: _Name, + expires_at: Optional[datetime] = None, +) -> Tuple[_ApiKey, _GqlId]: + exp = f' expiresAt:"{expires_at.isoformat()}"' if expires_at else "" + args, out = (f'name:"{name}"' + exp), "jwt apiKey{id name expiresAt}" + query = "mutation{createSystemApiKey(input:{" + args + "}){" + out + "}}" + resp = _httpx_client().post( + urljoin(get_base_url(), "graphql"), + json=dict(query=query), + cookies={PHOENIX_ACCESS_TOKEN_COOKIE_NAME: token} if token else {}, + ) + resp_dict = _json(resp) + assert (result := resp_dict["data"]["createSystemApiKey"]) + assert (api_key := result["apiKey"]) + assert api_key["name"] == name + exp_t = datetime.fromisoformat(api_key["expiresAt"]) if api_key["expiresAt"] else None + assert exp_t == expires_at + return cast(_ApiKey, result["jwt"]), cast(_GqlId, api_key["id"]) + + +def _delete_system_api_key(token: Optional[_Token], gid: _GqlId, /) -> None: + args, out = f'id:"{gid}"', "apiKeyId" + query = "mutation{deleteSystemApiKey(input:{" + args + "}){" + out + "}}" + resp = _httpx_client().post( + urljoin(get_base_url(), "graphql"), + json=dict(query=query), + cookies={PHOENIX_ACCESS_TOKEN_COOKIE_NAME: token} if token else {}, + ) + resp_dict = _json(resp) + assert resp_dict["data"]["deleteSystemApiKey"]["apiKeyId"] == gid - return _ +def _log_in(password: _Password, /, *, email: _Email) -> _LoggedInTokens: + resp = _httpx_client().post( + urljoin(get_base_url(), "auth/login"), + json={"email": email, "password": password}, + ) + resp.raise_for_status() + assert (access_token := resp.cookies.get(PHOENIX_ACCESS_TOKEN_COOKIE_NAME)) + assert (refresh_token := resp.cookies.get(PHOENIX_REFRESH_TOKEN_COOKIE_NAME)) + return _LoggedInTokens(access_token, refresh_token) -@pytest.fixture(scope="module") -def log_out( - httpx_client: Callable[[], httpx.Client], -) -> _LogOut: - def _(token: _Token, /) -> None: - resp = httpx_client().post( - urljoin(get_base_url(), "/auth/logout"), - cookies={PHOENIX_ACCESS_TOKEN_COOKIE_NAME: token}, - ) - resp.raise_for_status() - return _ +def _log_out(token: _Token, /) -> None: + resp = _httpx_client().post( + urljoin(get_base_url(), "auth/logout"), + cookies={PHOENIX_ACCESS_TOKEN_COOKIE_NAME: token}, + ) + resp.raise_for_status() def _json(resp: httpx.Response) -> Dict[str, Any]: diff --git a/integration_tests/auth/test_auth.py b/integration_tests/auth/test_auth.py index 74be5550a9..0ef06935b3 100644 --- a/integration_tests/auth/test_auth.py +++ b/integration_tests/auth/test_auth.py @@ -2,176 +2,48 @@ from datetime import datetime, timedelta, timezone from functools import partial from itertools import product -from typing import ( - Any, - Callable, - ContextManager, - Dict, - Iterator, - Optional, - Protocol, - Tuple, -) +from typing import ContextManager, Iterator, Optional from urllib.parse import urljoin -import httpx import jwt import pytest from faker import Faker from httpx import HTTPStatusError -from opentelemetry.sdk.trace.export import SpanExporter, SpanExportResult +from opentelemetry.sdk.trace.export import SpanExportResult from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter -from opentelemetry.trace import Span -from phoenix.auth import ( - PHOENIX_ACCESS_TOKEN_COOKIE_NAME, - PHOENIX_REFRESH_TOKEN_COOKIE_NAME, -) +from phoenix.auth import PHOENIX_ACCESS_TOKEN_COOKIE_NAME, PHOENIX_REFRESH_TOKEN_COOKIE_NAME from phoenix.config import get_base_url from phoenix.server.api.exceptions import Unauthorized from phoenix.server.api.input_types.UserRoleInput import UserRoleInput -from typing_extensions import TypeAlias - -NOW = datetime.now(timezone.utc) - -_ProjectName: TypeAlias = str -_SpanName: TypeAlias = str -_Headers: TypeAlias = Dict[str, Any] -_Name: TypeAlias = str - -_Username: TypeAlias = str -_Email: TypeAlias = str -_Password: TypeAlias = str -_Token: TypeAlias = str -_AccessToken: TypeAlias = str -_RefreshToken: TypeAlias = str -_ApiKey: TypeAlias = str -_GqlId: TypeAlias = str - - -class _LogIn(Protocol): - def __call__( - self, - password: _Password, - /, - *, - email: _Email, - ) -> ContextManager[Tuple[_AccessToken, _RefreshToken]]: ... - - -class _LogOut(Protocol): - def __call__(self, token: _Token, /) -> None: ... - - -class _CreateUser(Protocol): - def __call__( - self, - token: Optional[_Token], - /, - *, - email: _Email, - password: _Password, - role: UserRoleInput, - username: Optional[_Username] = None, - ) -> _GqlId: ... - - -class _PatchUser(Protocol): - def __call__( - self, - token: Optional[_Token], - gid: _GqlId, - /, - *, - new_username: Optional[_Username] = None, - new_password: Optional[_Password] = None, - new_role: Optional[UserRoleInput] = None, - ) -> None: ... - - -class _PatchViewer(Protocol): - def __call__( - self, - token: Optional[_Token], - current_password: Optional[_Password], - /, - *, - new_username: Optional[_Username] = None, - new_password: Optional[_Password] = None, - ) -> None: ... - - -class _CreateSystemApiKey(Protocol): - def __call__( - self, - token: Optional[_Token], - /, - *, - name: _Name, - expires_at: Optional[datetime] = None, - ) -> Tuple[_ApiKey, _GqlId]: ... - - -class _DeleteSystemApiKey(Protocol): - def __call__( - self, - token: Optional[_Token], - gid: _GqlId, - /, - ) -> None: ... - - -class _SpanExporterFactory(Protocol): - def __call__( - self, - *, - headers: Optional[_Headers] = None, - ) -> SpanExporter: ... - - -class _StartSpan(Protocol): - def __call__( - self, - *, - project_name: _ProjectName, - span_name: _SpanName, - exporter: SpanExporter, - ) -> Span: ... - - -class _Profile(Protocol): - @property - def email(self) -> _Email: ... - @property - def password(self) -> _Password: ... - @property - def username(self) -> Optional[_Username]: ... - - -class _User(Protocol): - @property - def gid(self) -> _GqlId: ... - @property - def profile(self) -> _Profile: ... - @property - def role(self) -> UserRoleInput: ... - @property - def token(self) -> Optional[_Token]: ... +from ..conftest import _Headers, _httpx_client, _SpanExporterFactory, _start_span +from .conftest import ( + _create_system_api_key, + _create_user, + _delete_system_api_key, + _GetNewUser, + _GqlId, + _log_in, + _Password, + _patch_user, + _patch_viewer, + _Profile, + _Token, + _Username, +) -class _GetNewUser(Protocol): - def __call__(self, role: UserRoleInput) -> _User: ... +NOW = datetime.now(timezone.utc) class TestTokens: def test_log_in_tokens_should_change( self, - admin_email: str, - secret: str, - log_in: _LogIn, + _admin_email: str, + _secret: str, ) -> None: n, access_tokens, refresh_tokens = 2, set(), set() for _ in range(n): - with log_in(secret, email=admin_email) as (access_token, refresh_token): + with _log_in(_secret, email=_admin_email) as (access_token, refresh_token): access_tokens.add(access_token) refresh_tokens.add(refresh_token) assert len(access_tokens) == n @@ -196,43 +68,39 @@ def test_admin( email: str, use_secret: bool, expectation: ContextManager[Optional[Unauthorized]], - secret: str, - log_in: _LogIn, - create_system_api_key: _CreateSystemApiKey, - fake: Faker, - passwords: Iterator[_Password], + _secret: str, + _fake: Faker, + _passwords: Iterator[_Password], ) -> None: - password = secret if use_secret else next(passwords) + password = _secret if use_secret else next(_passwords) with pytest.raises(HTTPStatusError, match="401 Unauthorized"): - create_system_api_key(None, name=fake.unique.pystr()) + _create_system_api_key(None, name=_fake.unique.pystr()) with expectation: - with log_in(password, email=email) as (token, _): - create_system_api_key(token, name=fake.unique.pystr()) + with _log_in(password, email=email) as (token, _): + _create_system_api_key(token, name=_fake.unique.pystr()) with pytest.raises(HTTPStatusError, match="401 Unauthorized"): - create_system_api_key(token, name=fake.unique.pystr()) + _create_system_api_key(token, name=_fake.unique.pystr()) def test_end_to_end_credentials_flow( self, - admin_email: str, - secret: str, - httpx_client: Callable[[], httpx.Client], - create_system_api_key: _CreateSystemApiKey, - fake: Faker, + _admin_email: str, + _secret: str, + _fake: Faker, ) -> None: # user logs into first browser - resp = httpx_client().post( + resp = _httpx_client().post( urljoin(get_base_url(), "/auth/login"), - json={"email": admin_email, "password": secret}, + json={"email": _admin_email, "password": _secret}, ) resp.raise_for_status() assert (browser_0_access_token_0 := resp.cookies.get(PHOENIX_ACCESS_TOKEN_COOKIE_NAME)) assert (browser_0_refresh_token_0 := resp.cookies.get(PHOENIX_REFRESH_TOKEN_COOKIE_NAME)) # user creates api key in the first browser - create_system_api_key(browser_0_access_token_0, name="api-key-0") + _create_system_api_key(browser_0_access_token_0, name="api-key-0") # tokens are refreshed in the first browser - resp = httpx_client().post( + resp = _httpx_client().post( urljoin(get_base_url(), "/auth/refresh"), cookies={ PHOENIX_ACCESS_TOKEN_COOKIE_NAME: browser_0_access_token_0, @@ -244,10 +112,10 @@ def test_end_to_end_credentials_flow( assert (browser_0_refresh_token_1 := resp.cookies.get(PHOENIX_REFRESH_TOKEN_COOKIE_NAME)) # user creates api key in the first browser - create_system_api_key(browser_0_access_token_1, name="api-key-1") + _create_system_api_key(browser_0_access_token_1, name="api-key-1") # refresh token is good for one use only - resp = httpx_client().post( + resp = _httpx_client().post( urljoin(get_base_url(), "/auth/refresh"), cookies={ PHOENIX_ACCESS_TOKEN_COOKIE_NAME: browser_0_access_token_0, @@ -259,22 +127,22 @@ def test_end_to_end_credentials_flow( # original access token is invalid after refresh with pytest.raises(HTTPStatusError, match="401 Unauthorized"): - create_system_api_key(browser_0_access_token_0, name="api-key-2") + _create_system_api_key(browser_0_access_token_0, name="api-key-2") # user logs into second browser - resp = httpx_client().post( + resp = _httpx_client().post( urljoin(get_base_url(), "/auth/login"), - json={"email": admin_email, "password": secret}, + json={"email": _admin_email, "password": _secret}, ) resp.raise_for_status() assert (browser_1_access_token_0 := resp.cookies.get(PHOENIX_ACCESS_TOKEN_COOKIE_NAME)) assert resp.cookies.get(PHOENIX_REFRESH_TOKEN_COOKIE_NAME) # user creates api key in the second browser - create_system_api_key(browser_1_access_token_0, name="api-key-3") + _create_system_api_key(browser_1_access_token_0, name="api-key-3") # user logs out in first browser - resp = httpx_client().post( + resp = _httpx_client().post( urljoin(get_base_url(), "/auth/logout"), cookies={ PHOENIX_ACCESS_TOKEN_COOKIE_NAME: browser_0_access_token_1, @@ -285,9 +153,9 @@ def test_end_to_end_credentials_flow( # user is logged out of both browsers with pytest.raises(HTTPStatusError, match="401 Unauthorized"): - create_system_api_key(browser_0_access_token_1, name="api-key-4") + _create_system_api_key(browser_0_access_token_1, name="api-key-4") with pytest.raises(HTTPStatusError, match="401 Unauthorized"): - create_system_api_key(browser_1_access_token_0, name="api-key-5") + _create_system_api_key(browser_1_access_token_0, name="api-key-5") @pytest.mark.parametrize( "role,expectation", @@ -300,29 +168,26 @@ def test_create_user( self, role: UserRoleInput, expectation: ContextManager[Optional[Unauthorized]], - admin_email: str, - secret: str, - log_in: _LogIn, - create_user: _CreateUser, - create_system_api_key: _CreateSystemApiKey, - fake: Faker, - profiles: Iterator[_Profile], + _admin_email: str, + _secret: str, + _fake: Faker, + _profiles: Iterator[_Profile], ) -> None: - profile = next(profiles) + profile = next(_profiles) email = profile.email username = profile.username password = profile.password with pytest.raises(HTTPStatusError, match="401 Unauthorized"): - create_user(None, email=email, password=password, username=username, role=role) - with log_in(secret, email=admin_email) as (token, _): - create_user(token, email=email, password=password, username=username, role=role) - with log_in(password, email=email) as (token, _): + _create_user(None, email=email, password=password, username=username, role=role) + with _log_in(_secret, email=_admin_email) as (token, _): + _create_user(token, email=email, password=password, username=username, role=role) + with _log_in(password, email=email) as (token, _): with expectation: - create_system_api_key(token, name=fake.unique.pystr()) + _create_system_api_key(token, name=_fake.unique.pystr()) for _role in UserRoleInput: - _profile = next(profiles) + _profile = next(_profiles) with expectation: - create_user( + _create_user( token, email=_profile.email, username=_profile.username, @@ -334,56 +199,52 @@ def test_create_user( def test_user_can_change_password_for_self( self, role: UserRoleInput, - patch_viewer: _PatchViewer, - log_in: _LogIn, - get_new_user: _GetNewUser, - passwords: Iterator[_Password], + _get_new_user: _GetNewUser, + _passwords: Iterator[_Password], ) -> None: - user = get_new_user(role) + user = _get_new_user(role) email = user.profile.email password = user.profile.password - token = user.token - new_password = f"new_password_{next(passwords)}" + (token, *_) = user.tokens + new_password = f"new_password_{next(_passwords)}" assert new_password != password - wrong_password = next(passwords) + wrong_password = next(_passwords) assert wrong_password != password for _token, _password in product((None, token), (None, wrong_password, password)): if _token == token and _password == password: continue with pytest.raises(BaseException): - patch_viewer(_token, _password, new_password=new_password) - log_in(password, email=email).__enter__() - patch_viewer((old_token := token), (old_password := password), new_password=new_password) - another_password = f"another_password_{next(passwords)}" + _patch_viewer(_token, _password, new_password=new_password) + _log_in(password, email=email) + _patch_viewer((old_token := token), (old_password := password), new_password=new_password) + another_password = f"another_password_{next(_passwords)}" with pytest.raises(HTTPStatusError, match="401 Unauthorized"): - patch_viewer(old_token, new_password, new_password=another_password) + _patch_viewer(old_token, new_password, new_password=another_password) with pytest.raises(HTTPStatusError, match="401 Unauthorized"): - log_in(old_password, email=email).__enter__() - new_token, _ = log_in(new_password, email=email).__enter__() + _log_in(old_password, email=email) + new_token, _ = _log_in(new_password, email=email) with pytest.raises(BaseException): - patch_viewer(new_token, old_password, new_password=another_password) + _patch_viewer(new_token, old_password, new_password=another_password) @pytest.mark.parametrize("role", list(UserRoleInput)) def test_user_can_change_username_for_self( self, role: UserRoleInput, - patch_viewer: _PatchViewer, - log_in: _LogIn, - get_new_user: _GetNewUser, - usernames: Iterator[_Username], - passwords: Iterator[_Password], + _get_new_user: _GetNewUser, + _usernames: Iterator[_Username], + _passwords: Iterator[_Password], ) -> None: - user = get_new_user(role) - token, password = user.token, user.profile.password - new_username = f"new_username_{next(usernames)}" + user = _get_new_user(role) + (token, *_), password = user.tokens, user.profile.password + new_username = f"new_username_{next(_usernames)}" for _password in (None, password): with pytest.raises(HTTPStatusError, match="401 Unauthorized"): - patch_viewer(None, _password, new_username=new_username) - patch_viewer(token, None, new_username=new_username) - another_username = f"another_username_{next(usernames)}" - wrong_password = next(passwords) + _patch_viewer(None, _password, new_username=new_username) + _patch_viewer(token, None, new_username=new_username) + another_username = f"another_username_{next(_usernames)}" + wrong_password = next(_passwords) assert wrong_password != password - patch_viewer(token, wrong_password, new_username=another_username) + _patch_viewer(token, wrong_password, new_username=another_username) @pytest.mark.parametrize( "role,expectation", @@ -396,18 +257,16 @@ def test_only_admin_can_change_role_for_non_self( self, role: UserRoleInput, expectation: ContextManager[Optional[Unauthorized]], - patch_user: _PatchUser, - log_in: _LogIn, - get_new_user: _GetNewUser, + _get_new_user: _GetNewUser, ) -> None: - user = get_new_user(role) - non_self = get_new_user(UserRoleInput.MEMBER) + user = _get_new_user(role) + non_self = _get_new_user(UserRoleInput.MEMBER) assert user.gid != non_self.gid - token, gid = user.token, non_self.gid + (token, *_), gid = user.tokens, non_self.gid with pytest.raises(HTTPStatusError, match="401 Unauthorized"): - patch_user(None, gid, new_role=UserRoleInput.ADMIN) + _patch_user(None, gid, new_role=UserRoleInput.ADMIN) with expectation: - patch_user(token, gid, new_role=UserRoleInput.ADMIN) + _patch_user(token, gid, new_role=UserRoleInput.ADMIN) @pytest.mark.parametrize( "role,expectation", @@ -420,28 +279,26 @@ def test_only_admin_can_change_password_for_non_self( self, role: UserRoleInput, expectation: ContextManager[Optional[Unauthorized]], - patch_user: _PatchUser, - log_in: _LogIn, - get_new_user: _GetNewUser, - passwords: Iterator[_Password], + _get_new_user: _GetNewUser, + _passwords: Iterator[_Password], ) -> None: - user = get_new_user(role) - non_self = get_new_user(UserRoleInput.MEMBER) + user = _get_new_user(role) + non_self = _get_new_user(UserRoleInput.MEMBER) assert user.gid != non_self.gid old_password = non_self.profile.password - new_password = f"new_password_{next(passwords)}" + new_password = f"new_password_{next(_passwords)}" assert new_password != old_password - token, gid = user.token, non_self.gid + (token, *_), gid = user.tokens, non_self.gid with pytest.raises(HTTPStatusError, match="401 Unauthorized"): - patch_user(None, gid, new_password=new_password) + _patch_user(None, gid, new_password=new_password) with expectation as e: - patch_user(token, gid, new_password=new_password) + _patch_user(token, gid, new_password=new_password) if e: return email = non_self.profile.email with pytest.raises(HTTPStatusError, match="401 Unauthorized"): - log_in(old_password, email=email).__enter__() - log_in(new_password, email=email).__enter__() + _log_in(old_password, email=email) + _log_in(new_password, email=email) @pytest.mark.parametrize( "role,expectation", @@ -454,26 +311,24 @@ def test_only_admin_can_change_username_for_non_self( self, role: UserRoleInput, expectation: ContextManager[Optional[Unauthorized]], - patch_user: _PatchUser, - log_in: _LogIn, - get_new_user: _GetNewUser, - usernames: Iterator[_Username], + _get_new_user: _GetNewUser, + _usernames: Iterator[_Username], ) -> None: - user = get_new_user(role) - non_self = get_new_user(UserRoleInput.MEMBER) + user = _get_new_user(role) + non_self = _get_new_user(UserRoleInput.MEMBER) assert user.gid != non_self.gid old_username = non_self.profile.username - new_username = f"new_username_{next(usernames)}" + new_username = f"new_username_{next(_usernames)}" assert new_username != old_username - token, gid = user.token, non_self.gid + (token, *_), gid = user.tokens, non_self.gid with pytest.raises(HTTPStatusError, match="401 Unauthorized"): - patch_user(None, gid, new_username=new_username) + _patch_user(None, gid, new_username=new_username) with expectation: - patch_user(token, gid, new_username=new_username) + _patch_user(token, gid, new_username=new_username) -def create_user_key(httpx_client: Callable[[], httpx.Client], token: str) -> str: - create_user_key_mutation = """ +def _create_user_key(token: str) -> str: + _create_user_key_mutation = """ mutation ($input: CreateUserApiKeyInput!) { createUserApiKey(input: $input) { apiKey { @@ -482,10 +337,10 @@ def create_user_key(httpx_client: Callable[[], httpx.Client], token: str) -> str } } """ - resp = httpx_client().post( + resp = _httpx_client().post( urljoin(get_base_url(), "graphql"), json={ - "query": create_user_key_mutation, + "query": _create_user_key_mutation, "variables": { "input": { "name": "test", @@ -509,20 +364,17 @@ class TestApiKeys: def test_delete_user_api_key( self, - admin_email: str, - secret: str, - log_in: _LogIn, - create_user: _CreateUser, - httpx_client: Callable[[], httpx.Client], - passwords: Iterator[_Password], + _admin_email: str, + _secret: str, + _passwords: Iterator[_Password], ) -> None: member_email = "member@localhost.com" username = "member" - member_password = next(passwords) + member_password = next(_passwords) - with log_in(secret, email=admin_email) as (admin_token, _): - admin_api_key_id = create_user_key(httpx_client, admin_token) - create_user( + with _log_in(_secret, email=_admin_email) as (admin_token, _): + admin_api_key_id = _create_user_key(admin_token) + _create_user( admin_token, email=member_email, password=member_password, @@ -530,14 +382,14 @@ def test_delete_user_api_key( username=username, ) - with log_in( + with _log_in( member_password, email=member_email, ) as (member_token, _): - member_api_key_id = create_user_key(httpx_client, member_token) - member_api_key_id_2 = create_user_key(httpx_client, member_token) + member_api_key_id = _create_user_key(member_token) + member_api_key_id_2 = _create_user_key(member_token) # member can delete their own keys - resp = httpx_client().post( + resp = _httpx_client().post( urljoin(get_base_url(), "graphql"), json={ "query": self.DELETE_USER_KEY_MUTATION, @@ -552,7 +404,7 @@ def test_delete_user_api_key( resp.raise_for_status() assert resp.json().get("errors") is None # member can't delete other user's keys - resp = httpx_client().post( + resp = _httpx_client().post( urljoin(get_base_url(), "graphql"), json={ "query": self.DELETE_USER_KEY_MUTATION, @@ -567,7 +419,7 @@ def test_delete_user_api_key( assert len(errors := resp.json().get("errors")) == 1 assert errors[0]["message"] == "User not authorized to delete" # admin can delete their own key - resp = httpx_client().post( + resp = _httpx_client().post( urljoin(get_base_url(), "graphql"), json={ "query": self.DELETE_USER_KEY_MUTATION, @@ -582,7 +434,7 @@ def test_delete_user_api_key( resp.raise_for_status() assert resp.json().get("errors") is None # admin can delete other user's keys - resp = httpx_client().post( + resp = _httpx_client().post( urljoin(get_base_url(), "graphql"), json={ "query": self.DELETE_USER_KEY_MUTATION, @@ -613,30 +465,27 @@ def test_headers( with_headers: bool, expires_at: Optional[datetime], expected: SpanExportResult, - span_exporter: _SpanExporterFactory, - start_span: _StartSpan, - create_system_api_key: _CreateSystemApiKey, - delete_system_api_key: _DeleteSystemApiKey, - admin_token: _Token, - fake: Faker, + _span_exporter: _SpanExporterFactory, + _admin_token: _Token, + _fake: Faker, ) -> None: - headers: Optional[Dict[str, Any]] = None + headers: Optional[_Headers] = None gid: Optional[_GqlId] = None if with_headers: - system_api_key, gid = create_system_api_key( - admin_token, - name=fake.unique.pystr(), + system_api_key, gid = _create_system_api_key( + _admin_token, + name=_fake.unique.pystr(), expires_at=expires_at, ) headers = {"authorization": f"Bearer {system_api_key}"} - export = span_exporter(headers=headers).export - project_name, span_name = fake.unique.pystr(), fake.unique.pystr() + export = _span_exporter(headers=headers).export + project_name, span_name = _fake.unique.pystr(), _fake.unique.pystr() memory = InMemorySpanExporter() - start_span(project_name=project_name, span_name=span_name, exporter=memory).end() + _start_span(project_name=project_name, span_name=span_name, exporter=memory).end() spans = memory.get_finished_spans() assert len(spans) == 1 for _ in range(2): assert export(spans) is expected if gid is not None and expected is SpanExportResult.SUCCESS: - delete_system_api_key(admin_token, gid) + _delete_system_api_key(_admin_token, gid) assert export(spans) is SpanExportResult.FAILURE diff --git a/integration_tests/conftest.py b/integration_tests/conftest.py index c2c9e7692b..7ff39782db 100644 --- a/integration_tests/conftest.py +++ b/integration_tests/conftest.py @@ -1,11 +1,12 @@ +from __future__ import annotations + import os import sys from contextlib import ExitStack, contextmanager -from functools import partial from subprocess import PIPE, STDOUT from threading import Lock, Thread from time import sleep, time -from typing import Any, Callable, ContextManager, Dict, Iterator, List, Optional, Protocol, cast +from typing import Any, Dict, Iterator, List, Optional, Protocol from unittest import mock from urllib.parse import urljoin from urllib.request import urlopen @@ -43,10 +44,6 @@ _Headers: TypeAlias = Dict[str, Any] -class _GetGqlSpans(Protocol): - def __call__(self, *keys: str) -> Dict[_ProjectName, List[Dict[str, Any]]]: ... - - class _SpanExporterFactory(Protocol): def __call__( self, @@ -55,32 +52,13 @@ def __call__( ) -> SpanExporter: ... -class _GetTracer(Protocol): - def __call__( - self, - *, - project_name: _ProjectName, - exporter: SpanExporter, - ) -> Tracer: ... - - -class _StartSpan(Protocol): - def __call__( - self, - *, - project_name: _ProjectName, - span_name: _SpanName, - exporter: SpanExporter, - ) -> Span: ... - - @pytest.fixture(scope="class") -def fake() -> Faker: +def _fake() -> Faker: return Faker() @pytest.fixture(autouse=True, scope="class") -def env(tmp_path_factory: TempPathFactory) -> Iterator[None]: +def _env(tmp_path_factory: TempPathFactory) -> Iterator[None]: tmp = tmp_path_factory.getbasetemp() values = ( (ENV_PHOENIX_PORT, str(pick_unused_port())), @@ -101,168 +79,139 @@ def env(tmp_path_factory: TempPathFactory) -> Iterator[None]: ), ], ) -def sql_database_url(request: SubRequest) -> URL: +def _sql_database_url(request: SubRequest) -> URL: return make_url(request.param) @pytest.fixture(autouse=True, scope="class") -def env_phoenix_sql_database_url( - sql_database_url: URL, - fake: Faker, +def _env_phoenix_sql_database_url( + _sql_database_url: URL, + _fake: Faker, ) -> Iterator[None]: - values = [(ENV_PHOENIX_SQL_DATABASE_URL, sql_database_url.render_as_string())] + values = [(ENV_PHOENIX_SQL_DATABASE_URL, _sql_database_url.render_as_string())] with ExitStack() as stack: - if sql_database_url.get_backend_name().startswith("postgresql"): - schema = stack.enter_context(_random_schema(sql_database_url, fake)) + if _sql_database_url.get_backend_name().startswith("postgresql"): + schema = stack.enter_context(_random_schema(_sql_database_url, _fake)) values.append((ENV_PHOENIX_SQL_DATABASE_SCHEMA, schema)) stack.enter_context(mock.patch.dict(os.environ, values)) yield -@pytest.fixture -def get_gql_spans( - httpx_client: Callable[[], httpx.Client], -) -> _GetGqlSpans: - def _(*keys: str) -> Dict[_ProjectName, List[Dict[str, Any]]]: - out = "name spans{edges{node{" + " ".join(keys) + "}}}" - query = dict(query="query{projects{edges{node{" + out + "}}}}") - resp = httpx_client().post(urljoin(get_base_url(), "graphql"), json=query) - resp.raise_for_status() - resp_dict = resp.json() - assert not resp_dict.get("errors") - return { - project["node"]["name"]: [span["node"] for span in project["node"]["spans"]["edges"]] - for project in resp_dict["data"]["projects"]["edges"] - } - - return _ - - -@pytest.fixture(scope="session") -def http_span_exporter() -> _SpanExporterFactory: - def _( - *, - headers: Optional[_Headers] = None, - ) -> SpanExporter: - from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter - - endpoint = urljoin(get_base_url(), "v1/traces") - exporter = OTLPSpanExporter(endpoint=endpoint, headers=headers, timeout=1) - exporter._MAX_RETRY_TIMEOUT = 2 - return exporter - - return _ - - -@pytest.fixture(scope="session") -def grpc_span_exporter() -> _SpanExporterFactory: - def _( - *, - headers: Optional[_Headers] = None, - ) -> SpanExporter: - from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter - - host = get_env_host() - if host == "0.0.0.0": - host = "127.0.0.1" - endpoint = f"http://{host}:{get_env_grpc_port()}" - return OTLPSpanExporter(endpoint=endpoint, headers=headers, timeout=1) - - return _ - - @pytest.fixture(scope="session", params=["http", "grpc"]) -def span_exporter(request: SubRequest) -> _SpanExporterFactory: +def _span_exporter(request: SubRequest) -> _SpanExporterFactory: if request.param == "http": - return cast(_SpanExporterFactory, request.getfixturevalue("http_span_exporter")) + return _http_span_exporter if request.param == "grpc": - return cast(_SpanExporterFactory, request.getfixturevalue("grpc_span_exporter")) + return _grpc_span_exporter raise ValueError(f"Unknown exporter: {request.param}") -@pytest.fixture(scope="session") -def get_tracer() -> _GetTracer: - def _( - *, - project_name: str, - exporter: SpanExporter, - ) -> Tracer: - resource = Resource({ResourceAttributes.PROJECT_NAME: project_name}) - tracer_provider = TracerProvider(resource=resource) - tracer_provider.add_span_processor(SimpleSpanProcessor(exporter)) - return tracer_provider.get_tracer(__name__) - - return _ - - -@pytest.fixture(scope="session") -def start_span( - get_tracer: _GetTracer, -) -> _StartSpan: - def _( - *, - project_name: str, - span_name: str, - exporter: SpanExporter, - ) -> Span: - return get_tracer(project_name=project_name, exporter=exporter).start_span(span_name) - - return _ +def _get_gql_spans(*keys: str) -> Dict[_ProjectName, List[Dict[str, Any]]]: + out = "name spans{edges{node{" + " ".join(keys) + "}}}" + query = dict(query="query{projects{edges{node{" + out + "}}}}") + resp = _httpx_client().post(urljoin(get_base_url(), "graphql"), json=query) + resp.raise_for_status() + resp_dict = resp.json() + assert not resp_dict.get("errors") + return { + project["node"]["name"]: [span["node"] for span in project["node"]["spans"]["edges"]] + for project in resp_dict["data"]["projects"]["edges"] + } + + +def _http_span_exporter( + *, + headers: Optional[_Headers] = None, +) -> SpanExporter: + from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter + + endpoint = urljoin(get_base_url(), "v1/traces") + exporter = OTLPSpanExporter(endpoint=endpoint, headers=headers, timeout=1) + exporter._MAX_RETRY_TIMEOUT = 2 + return exporter + + +def _grpc_span_exporter( + *, + headers: Optional[_Headers] = None, +) -> SpanExporter: + from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter + + host = get_env_host() + if host == "0.0.0.0": + host = "127.0.0.1" + endpoint = f"http://{host}:{get_env_grpc_port()}" + return OTLPSpanExporter(endpoint=endpoint, headers=headers, timeout=1) + + +def _get_tracer( + *, + project_name: str, + exporter: SpanExporter, +) -> Tracer: + resource = Resource({ResourceAttributes.PROJECT_NAME: project_name}) + tracer_provider = TracerProvider(resource=resource) + tracer_provider.add_span_processor(SimpleSpanProcessor(exporter)) + return tracer_provider.get_tracer(__name__) + + +def _start_span( + *, + project_name: str, + span_name: str, + exporter: SpanExporter, +) -> Span: + return _get_tracer(project_name=project_name, exporter=exporter).start_span(span_name) + + +def _httpx_client() -> httpx.Client: + # Having no timeout is useful when stepping through the debugger on the server side. + return httpx.Client(timeout=None) -@pytest.fixture(scope="session") -def httpx_client() -> Callable[[], httpx.Client]: - # Having no timeout is useful when stepping through the debugger on the server side. - return partial(httpx.Client, timeout=None) - - -@pytest.fixture(scope="session") -def server() -> Callable[[], ContextManager[None]]: - @contextmanager - def _() -> Iterator[None]: - if get_env_database_connection_str().startswith("postgresql"): - # double-check for safety - assert get_env_database_schema() - command = f"{sys.executable} -m phoenix.server.main serve" - process = Popen(command.split(), stdout=PIPE, stderr=STDOUT, text=True, env=os.environ) - log: List[str] = [] - lock: Lock = Lock() - Thread(target=capture_stdout, args=(process, log, lock), daemon=True).start() - t = 60 - time_limit = time() + t - timed_out = False - url = urljoin(get_base_url(), "healthz") - while not timed_out and is_alive(process): - sleep(0.1) - try: - urlopen(url) - break - except BaseException: - timed_out = time() > time_limit +@contextmanager +def _server() -> Iterator[None]: + if get_env_database_connection_str().startswith("postgresql"): + # double-check for safety + assert get_env_database_schema() + command = f"{sys.executable} -m phoenix.server.main serve" + process = Popen(command.split(), stdout=PIPE, stderr=STDOUT, text=True, env=os.environ) + log: List[str] = [] + lock: Lock = Lock() + Thread(target=_capture_stdout, args=(process, log, lock), daemon=True).start() + t = 60 + time_limit = time() + t + timed_out = False + url = urljoin(get_base_url(), "healthz") + while not timed_out and _is_alive(process): + sleep(0.1) try: - if timed_out: - raise TimeoutError(f"Server did not start within {t} seconds.") - assert is_alive(process) - with lock: - for line in log: - print(line, end="") - log.clear() - yield - process.terminate() - process.wait(10) - finally: + urlopen(url) + break + except BaseException: + timed_out = time() > time_limit + try: + if timed_out: + raise TimeoutError(f"Server did not start within {t} seconds.") + assert _is_alive(process) + with lock: for line in log: print(line, end="") - - return _ + log.clear() + yield + process.terminate() + process.wait(10) + finally: + for line in log: + print(line, end="") -def is_alive(process: Popen) -> bool: +def _is_alive(process: Popen) -> bool: return process.is_running() and process.status() != STATUS_ZOMBIE -def capture_stdout(process: Popen, log: List[str], lock: Lock) -> None: - while is_alive(process): +def _capture_stdout(process: Popen, log: List[str], lock: Lock) -> None: + while _is_alive(process): line = process.stdout.readline() if line or (log and log[-1] != line): with lock: @@ -270,13 +219,13 @@ def capture_stdout(process: Popen, log: List[str], lock: Lock) -> None: @contextmanager -def _random_schema(url: URL, fake: Faker) -> Iterator[str]: +def _random_schema(url: URL, _fake: Faker) -> Iterator[str]: engine = create_engine(url.set(drivername="postgresql+psycopg")) try: engine.connect() except OperationalError as ex: pytest.skip(f"PostgreSQL unavailable: {ex}") - schema = fake.unique.pystr().lower() + schema = _fake.unique.pystr().lower() yield schema with engine.connect() as conn: conn.execute(text(f"DROP SCHEMA IF EXISTS {schema} CASCADE;")) diff --git a/integration_tests/mypy.ini b/integration_tests/mypy.ini index e66f0c832d..d31d7855e3 100644 --- a/integration_tests/mypy.ini +++ b/integration_tests/mypy.ini @@ -1,4 +1,3 @@ [mypy] strict = true -explicit_package_bases = true exclude = (^evals|^notebooks) diff --git a/integration_tests/py.typed b/integration_tests/py.typed new file mode 100644 index 0000000000..e69de29bb2 diff --git a/integration_tests/server/__init__.py b/integration_tests/server/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/integration_tests/server/test_launch_app.py b/integration_tests/server/test_launch_app.py index 57c8a18bbd..d18d674bf2 100644 --- a/integration_tests/server/test_launch_app.py +++ b/integration_tests/server/test_launch_app.py @@ -1,73 +1,35 @@ import os from time import sleep -from typing import Any, Callable, ContextManager, Dict, List, Optional, Protocol, Set +from typing import Set from faker import Faker -from opentelemetry.sdk.trace.export import SpanExporter -from opentelemetry.trace import Span, Tracer -from typing_extensions import TypeAlias -_ProjectName: TypeAlias = str -_SpanName: TypeAlias = str -_Headers: TypeAlias = Dict[str, Any] - - -class _GetGqlSpans(Protocol): - def __call__(self, *keys: str) -> Dict[_ProjectName, List[Dict[str, Any]]]: ... - - -class _SpanExporterFactory(Protocol): - def __call__( - self, - *, - headers: Optional[_Headers] = None, - ) -> SpanExporter: ... - - -class _GetTracer(Protocol): - def __call__( - self, - *, - project_name: _ProjectName, - exporter: SpanExporter, - ) -> Tracer: ... - - -class _StartSpan(Protocol): - def __call__( - self, - *, - project_name: _ProjectName, - span_name: _SpanName, - exporter: SpanExporter, - ) -> Span: ... +from ..conftest import ( + _get_gql_spans, + _grpc_span_exporter, + _http_span_exporter, + _server, + _start_span, +) class TestLaunchApp: - def test_send_spans( - self, - server: Callable[[], ContextManager[None]], - start_span: _StartSpan, - http_span_exporter: _SpanExporterFactory, - grpc_span_exporter: _SpanExporterFactory, - get_gql_spans: _GetGqlSpans, - fake: Faker, - ) -> None: + def test_send_spans(self, _fake: Faker) -> None: if (url := os.environ.get("PHOENIX_SQL_DATABASE_URL")) and ":memory:" in url: - # This test is not intended for a in-memory databases. + # This test is not intended for an in-memory databases. os.environ.pop("PHOENIX_SQL_DATABASE_URL", None) - project_name = fake.unique.pystr() + project_name = _fake.unique.pystr() span_names: Set[str] = set() for i in range(2): - with server(): - for j, span_exporter in enumerate([http_span_exporter, grpc_span_exporter]): - span_name = f"{i}_{j}_{fake.unique.pystr()}" + with _server(): + for j, exporter in enumerate([_http_span_exporter, _grpc_span_exporter]): + span_name = f"{i}_{j}_{_fake.unique.pystr()}" span_names.add(span_name) - start_span( + _start_span( project_name=project_name, span_name=span_name, - exporter=span_exporter(headers=None), + exporter=exporter(headers=None), ).end() sleep(2) - gql_span_names = set(span["name"] for span in get_gql_spans("name")[project_name]) + gql_span_names = set(span["name"] for span in _get_gql_spans("name")[project_name]) assert gql_span_names == span_names From d9cd19de078f90550ec5cabbe47f1b8eeb2646a9 Mon Sep 17 00:00:00 2001 From: Roger Yang Date: Fri, 6 Sep 2024 10:42:03 -0700 Subject: [PATCH 02/21] clean up --- integration_tests/_helpers.py | 170 ++++++++++++++ integration_tests/auth/_helpers.py | 224 ++++++++++++++++++ integration_tests/auth/conftest.py | 241 ++------------------ integration_tests/auth/test_auth.py | 6 +- integration_tests/conftest.py | 172 +------------- integration_tests/server/test_launch_app.py | 2 +- 6 files changed, 425 insertions(+), 390 deletions(-) create mode 100644 integration_tests/_helpers.py create mode 100644 integration_tests/auth/_helpers.py diff --git a/integration_tests/_helpers.py b/integration_tests/_helpers.py new file mode 100644 index 0000000000..269a262112 --- /dev/null +++ b/integration_tests/_helpers.py @@ -0,0 +1,170 @@ +from __future__ import annotations + +import os +import sys +from contextlib import contextmanager +from subprocess import PIPE, STDOUT +from threading import Lock, Thread +from time import sleep, time +from typing import Any, Dict, Iterator, List, Optional, Protocol +from urllib.parse import urljoin +from urllib.request import urlopen + +import httpx +import pytest +from faker import Faker +from openinference.semconv.resource import ResourceAttributes +from opentelemetry.sdk.resources import Resource +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import SimpleSpanProcessor, SpanExporter +from opentelemetry.trace import Span, Tracer +from phoenix.config import ( + get_base_url, + get_env_database_connection_str, + get_env_database_schema, + get_env_grpc_port, + get_env_host, +) +from psutil import STATUS_ZOMBIE, Popen +from sqlalchemy import URL, create_engine, text +from sqlalchemy.exc import OperationalError +from typing_extensions import TypeAlias + +_ProjectName: TypeAlias = str +_SpanName: TypeAlias = str +_Headers: TypeAlias = Dict[str, Any] + + +class _SpanExporterConstructor(Protocol): + def __call__( + self, + *, + headers: Optional[_Headers] = None, + ) -> SpanExporter: ... + + +def _get_gql_spans(*keys: str) -> Dict[_ProjectName, List[Dict[str, Any]]]: + out = "name spans{edges{node{" + " ".join(keys) + "}}}" + query = dict(query="query{projects{edges{node{" + out + "}}}}") + resp = _httpx_client().post(urljoin(get_base_url(), "graphql"), json=query) + resp.raise_for_status() + resp_dict = resp.json() + assert not resp_dict.get("errors") + return { + project["node"]["name"]: [span["node"] for span in project["node"]["spans"]["edges"]] + for project in resp_dict["data"]["projects"]["edges"] + } + + +def _http_span_exporter( + *, + headers: Optional[_Headers] = None, +) -> SpanExporter: + from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter + + endpoint = urljoin(get_base_url(), "v1/traces") + exporter = OTLPSpanExporter(endpoint=endpoint, headers=headers, timeout=1) + exporter._MAX_RETRY_TIMEOUT = 2 + return exporter + + +def _grpc_span_exporter( + *, + headers: Optional[_Headers] = None, +) -> SpanExporter: + from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter + + host = get_env_host() + if host == "0.0.0.0": + host = "127.0.0.1" + endpoint = f"http://{host}:{get_env_grpc_port()}" + return OTLPSpanExporter(endpoint=endpoint, headers=headers, timeout=1) + + +def _get_tracer( + *, + project_name: str, + exporter: SpanExporter, +) -> Tracer: + resource = Resource({ResourceAttributes.PROJECT_NAME: project_name}) + tracer_provider = TracerProvider(resource=resource) + tracer_provider.add_span_processor(SimpleSpanProcessor(exporter)) + return tracer_provider.get_tracer(__name__) + + +def _start_span( + *, + project_name: str, + span_name: str, + exporter: SpanExporter, +) -> Span: + return _get_tracer(project_name=project_name, exporter=exporter).start_span(span_name) + + +def _httpx_client() -> httpx.Client: + # Having no timeout is useful when stepping through the debugger on the server side. + return httpx.Client(timeout=None) + + +@contextmanager +def _server() -> Iterator[None]: + if get_env_database_connection_str().startswith("postgresql"): + # double-check for safety + assert get_env_database_schema() + command = f"{sys.executable} -m phoenix.server.main serve" + process = Popen(command.split(), stdout=PIPE, stderr=STDOUT, text=True, env=os.environ) + log: List[str] = [] + lock: Lock = Lock() + Thread(target=_capture_stdout, args=(process, log, lock), daemon=True).start() + t = 60 + time_limit = time() + t + timed_out = False + url = urljoin(get_base_url(), "healthz") + while not timed_out and _is_alive(process): + sleep(0.1) + try: + urlopen(url) + break + except BaseException: + timed_out = time() > time_limit + try: + if timed_out: + raise TimeoutError(f"Server did not start within {t} seconds.") + assert _is_alive(process) + with lock: + for line in log: + print(line, end="") + log.clear() + yield + process.terminate() + process.wait(10) + finally: + for line in log: + print(line, end="") + + +def _is_alive(process: Popen) -> bool: + return process.is_running() and process.status() != STATUS_ZOMBIE + + +def _capture_stdout(process: Popen, log: List[str], lock: Lock) -> None: + while _is_alive(process): + line = process.stdout.readline() + if line or (log and log[-1] != line): + with lock: + log.append(line) + + +@contextmanager +def _random_schema(url: URL, _fake: Faker) -> Iterator[str]: + engine = create_engine(url.set(drivername="postgresql+psycopg")) + try: + engine.connect() + except OperationalError as ex: + pytest.skip(f"PostgreSQL unavailable: {ex}") + schema = _fake.unique.pystr().lower() + yield schema + with engine.connect() as conn: + conn.execute(text(f"DROP SCHEMA IF EXISTS {schema} CASCADE;")) + conn.commit() + engine.dispose() diff --git a/integration_tests/auth/_helpers.py b/integration_tests/auth/_helpers.py new file mode 100644 index 0000000000..1cacbbfe40 --- /dev/null +++ b/integration_tests/auth/_helpers.py @@ -0,0 +1,224 @@ +from __future__ import annotations + +from dataclasses import dataclass +from datetime import datetime +from typing import Any, Dict, Iterator, Optional, Protocol, Tuple, cast +from urllib.parse import urljoin + +import httpx +from phoenix.auth import PHOENIX_ACCESS_TOKEN_COOKIE_NAME, PHOENIX_REFRESH_TOKEN_COOKIE_NAME +from phoenix.config import get_base_url +from phoenix.server.api.auth import IsAdmin, IsAuthenticated +from phoenix.server.api.exceptions import Unauthorized +from phoenix.server.api.input_types.UserRoleInput import UserRoleInput +from typing_extensions import TypeAlias + +from .._helpers import _httpx_client + +_Email: TypeAlias = str +_GqlId: TypeAlias = str +_Name: TypeAlias = str +_Password: TypeAlias = str +_Token: TypeAlias = str +_Username: TypeAlias = str +_AccessToken: TypeAlias = _Token +_ApiKey: TypeAlias = _Token +_RefreshToken: TypeAlias = _Token + + +@dataclass(frozen=True) +class _Profile: + email: _Email + password: _Password + username: Optional[_Username] = None + + +@dataclass(frozen=True) +class _User: + gid: _GqlId + role: UserRoleInput + profile: _Profile + + +@dataclass(frozen=True) +class _LoggedInTokens: + access: _AccessToken + refresh: _RefreshToken + + def log_out(self) -> None: + _log_out(self.access) + + def __iter__(self) -> Iterator[_Token]: + yield self.access + yield self.refresh + + def __enter__(self) -> _LoggedInTokens: + return self + + def __exit__(self, *args: Any, **kwargs: Any) -> None: + self.log_out() + + +@dataclass(frozen=True) +class _LoggedInUser(_User): + tokens: _LoggedInTokens + + +class _UserGenerator(Protocol): + def send(self, role: UserRoleInput) -> _LoggedInUser: ... + + +class _GetNewUser(Protocol): + def __call__(self, role: UserRoleInput) -> _LoggedInUser: ... + + +def _create_user( + token: Optional[_Token], + /, + *, + email: _Email, + password: _Password, + role: UserRoleInput, + username: Optional[_Username] = None, +) -> _GqlId: + args = [f'email:"{email}"', f'password:"{password}"', f"role:{role.value}"] + if username: + args.append(f'username:"{username}"') + out = "user{id email role{name}}" + query = "mutation{createUser(input:{" + ",".join(args) + "}){" + out + "}}" + resp = _httpx_client().post( + urljoin(get_base_url(), "graphql"), + json=dict(query=query), + cookies={PHOENIX_ACCESS_TOKEN_COOKIE_NAME: token} if token else {}, + ) + resp_dict = _json(resp) + assert (user := resp_dict["data"]["createUser"]["user"]) + assert user["email"] == email + assert user["role"]["name"] == role.value + return cast(_GqlId, user["id"]) + + +def _patch_user( + token: Optional[_Token], + gid: _GqlId, + /, + *, + new_username: Optional[_Username] = None, + new_password: Optional[_Password] = None, + new_role: Optional[UserRoleInput] = None, +) -> None: + args = [f'userId:"{gid}"'] + if new_password: + args.append(f'newPassword:"{new_password}"') + if new_username: + args.append(f'newUsername:"{new_username}"') + if new_role: + args.append(f"newRole:{new_role.value}") + out = "user{id username role{name}}" + query = "mutation{patchUser(input:{" + ",".join(args) + "}){" + out + "}}" + resp = _httpx_client().post( + urljoin(get_base_url(), "graphql"), + json=dict(query=query), + cookies={PHOENIX_ACCESS_TOKEN_COOKIE_NAME: token} if token else {}, + ) + resp_dict = _json(resp) + assert (user := resp_dict["data"]["patchUser"]["user"]) + assert user["id"] == gid + if new_username: + assert user["username"] == new_username + if new_role: + assert user["role"]["name"] == new_role.value + + +def _patch_viewer( + token: Optional[_Token], + current_password: Optional[_Password], + /, + *, + new_username: Optional[_Username] = None, + new_password: Optional[_Password] = None, +) -> None: + args = [] + if new_password: + args.append(f'newPassword:"{new_password}"') + if current_password: + args.append(f'currentPassword:"{current_password}"') + if new_username: + args.append(f'newUsername:"{new_username}"') + out = "user{username}" + query = "mutation{patchViewer(input:{" + ",".join(args) + "}){" + out + "}}" + resp = _httpx_client().post( + urljoin(get_base_url(), "graphql"), + json=dict(query=query), + cookies={PHOENIX_ACCESS_TOKEN_COOKIE_NAME: token} if token else {}, + ) + resp_dict = _json(resp) + assert (user := resp_dict["data"]["patchViewer"]["user"]) + if new_username: + assert user["username"] == new_username + + +def _create_system_api_key( + token: Optional[_Token], + /, + *, + name: _Name, + expires_at: Optional[datetime] = None, +) -> Tuple[_ApiKey, _GqlId]: + exp = f' expiresAt:"{expires_at.isoformat()}"' if expires_at else "" + args, out = (f'name:"{name}"' + exp), "jwt apiKey{id name expiresAt}" + query = "mutation{createSystemApiKey(input:{" + args + "}){" + out + "}}" + resp = _httpx_client().post( + urljoin(get_base_url(), "graphql"), + json=dict(query=query), + cookies={PHOENIX_ACCESS_TOKEN_COOKIE_NAME: token} if token else {}, + ) + resp_dict = _json(resp) + assert (result := resp_dict["data"]["createSystemApiKey"]) + assert (api_key := result["apiKey"]) + assert api_key["name"] == name + exp_t = datetime.fromisoformat(api_key["expiresAt"]) if api_key["expiresAt"] else None + assert exp_t == expires_at + return cast(_ApiKey, result["jwt"]), cast(_GqlId, api_key["id"]) + + +def _delete_system_api_key(token: Optional[_Token], gid: _GqlId, /) -> None: + args, out = f'id:"{gid}"', "apiKeyId" + query = "mutation{deleteSystemApiKey(input:{" + args + "}){" + out + "}}" + resp = _httpx_client().post( + urljoin(get_base_url(), "graphql"), + json=dict(query=query), + cookies={PHOENIX_ACCESS_TOKEN_COOKIE_NAME: token} if token else {}, + ) + resp_dict = _json(resp) + assert resp_dict["data"]["deleteSystemApiKey"]["apiKeyId"] == gid + + +def _log_in(password: _Password, /, *, email: _Email) -> _LoggedInTokens: + resp = _httpx_client().post( + urljoin(get_base_url(), "auth/login"), + json={"email": email, "password": password}, + ) + resp.raise_for_status() + assert (access_token := resp.cookies.get(PHOENIX_ACCESS_TOKEN_COOKIE_NAME)) + assert (refresh_token := resp.cookies.get(PHOENIX_REFRESH_TOKEN_COOKIE_NAME)) + return _LoggedInTokens(access_token, refresh_token) + + +def _log_out(token: _Token, /) -> None: + resp = _httpx_client().post( + urljoin(get_base_url(), "auth/logout"), + cookies={PHOENIX_ACCESS_TOKEN_COOKIE_NAME: token}, + ) + resp.raise_for_status() + + +def _json(resp: httpx.Response) -> Dict[str, Any]: + resp.raise_for_status() + assert (resp_dict := cast(Dict[str, Any], resp.json())) + if errers := resp_dict.get("errors"): + msg = errers[0]["message"] + if "not auth" in msg or IsAuthenticated.message in msg or IsAdmin.message in msg: + raise Unauthorized(msg) + raise RuntimeError(msg) + return resp_dict diff --git a/integration_tests/auth/conftest.py b/integration_tests/auth/conftest.py index bbb44a9174..8a477522e1 100644 --- a/integration_tests/auth/conftest.py +++ b/integration_tests/auth/conftest.py @@ -3,85 +3,30 @@ import os import secrets from contextlib import ExitStack -from dataclasses import asdict, dataclass -from datetime import datetime +from dataclasses import asdict from itertools import count, starmap -from typing import Any, Dict, Generator, Iterator, Optional, Protocol, Tuple, cast +from typing import Any, Generator, Iterator, Optional, cast from unittest import mock -from urllib.parse import urljoin -import httpx import pytest from faker import Faker -from phoenix.auth import ( - PHOENIX_ACCESS_TOKEN_COOKIE_NAME, - PHOENIX_REFRESH_TOKEN_COOKIE_NAME, - REQUIREMENTS_FOR_PHOENIX_SECRET, -) -from phoenix.config import ENV_PHOENIX_ENABLE_AUTH, ENV_PHOENIX_SECRET, get_base_url -from phoenix.server.api.auth import IsAdmin, IsAuthenticated -from phoenix.server.api.exceptions import Unauthorized +from phoenix.auth import REQUIREMENTS_FOR_PHOENIX_SECRET +from phoenix.config import ENV_PHOENIX_ENABLE_AUTH, ENV_PHOENIX_SECRET from phoenix.server.api.input_types.UserRoleInput import UserRoleInput -from typing_extensions import TypeAlias - -from integration_tests.conftest import _httpx_client, _server - -_Email: TypeAlias = str -_GqlId: TypeAlias = str -_Name: TypeAlias = str -_Password: TypeAlias = str -_Token: TypeAlias = str -_Username: TypeAlias = str - -_AccessToken: TypeAlias = _Token -_ApiKey: TypeAlias = _Token -_RefreshToken: TypeAlias = _Token - - -@dataclass(frozen=True) -class _Profile: - email: _Email - password: _Password - username: Optional[_Username] = None - - -@dataclass(frozen=True) -class _User: - gid: _GqlId - role: UserRoleInput - profile: _Profile - - -@dataclass(frozen=True) -class _LoggedInTokens: - access: _AccessToken - refresh: _RefreshToken - - def log_out(self) -> None: - _log_out(self.access) - def __iter__(self) -> Iterator[_Token]: - yield self.access - yield self.refresh - - def __enter__(self) -> _LoggedInTokens: - return self - - def __exit__(self, *args: Any, **kwargs: Any) -> None: - self.log_out() - - -@dataclass(frozen=True) -class _LoggedInUser(_User): - tokens: _LoggedInTokens - - -class _UserGenerator(Protocol): - def send(self, role: UserRoleInput) -> _LoggedInUser: ... - - -class _GetNewUser(Protocol): - def __call__(self, role: UserRoleInput) -> _LoggedInUser: ... +from .._helpers import _server +from ._helpers import ( + _create_user, + _Email, + _GetNewUser, + _log_in, + _LoggedInUser, + _Password, + _Profile, + _Token, + _UserGenerator, + _Username, +) @pytest.fixture(scope="class") @@ -169,155 +114,3 @@ def _admin_token( @pytest.fixture(scope="module") def _admin_email() -> _Email: return "admin@localhost" - - -def _create_user( - token: Optional[_Token], - /, - *, - email: _Email, - password: _Password, - role: UserRoleInput, - username: Optional[_Username] = None, -) -> _GqlId: - args = [f'email:"{email}"', f'password:"{password}"', f"role:{role.value}"] - if username: - args.append(f'username:"{username}"') - out = "user{id email role{name}}" - query = "mutation{createUser(input:{" + ",".join(args) + "}){" + out + "}}" - resp = _httpx_client().post( - urljoin(get_base_url(), "graphql"), - json=dict(query=query), - cookies={PHOENIX_ACCESS_TOKEN_COOKIE_NAME: token} if token else {}, - ) - resp_dict = _json(resp) - assert (user := resp_dict["data"]["createUser"]["user"]) - assert user["email"] == email - assert user["role"]["name"] == role.value - return cast(_GqlId, user["id"]) - - -def _patch_user( - token: Optional[_Token], - gid: _GqlId, - /, - *, - new_username: Optional[_Username] = None, - new_password: Optional[_Password] = None, - new_role: Optional[UserRoleInput] = None, -) -> None: - args = [f'userId:"{gid}"'] - if new_password: - args.append(f'newPassword:"{new_password}"') - if new_username: - args.append(f'newUsername:"{new_username}"') - if new_role: - args.append(f"newRole:{new_role.value}") - out = "user{id username role{name}}" - query = "mutation{patchUser(input:{" + ",".join(args) + "}){" + out + "}}" - resp = _httpx_client().post( - urljoin(get_base_url(), "graphql"), - json=dict(query=query), - cookies={PHOENIX_ACCESS_TOKEN_COOKIE_NAME: token} if token else {}, - ) - resp_dict = _json(resp) - assert (user := resp_dict["data"]["patchUser"]["user"]) - assert user["id"] == gid - if new_username: - assert user["username"] == new_username - if new_role: - assert user["role"]["name"] == new_role.value - - -def _patch_viewer( - token: Optional[_Token], - current_password: Optional[_Password], - /, - *, - new_username: Optional[_Username] = None, - new_password: Optional[_Password] = None, -) -> None: - args = [] - if new_password: - args.append(f'newPassword:"{new_password}"') - if current_password: - args.append(f'currentPassword:"{current_password}"') - if new_username: - args.append(f'newUsername:"{new_username}"') - out = "user{username}" - query = "mutation{patchViewer(input:{" + ",".join(args) + "}){" + out + "}}" - resp = _httpx_client().post( - urljoin(get_base_url(), "graphql"), - json=dict(query=query), - cookies={PHOENIX_ACCESS_TOKEN_COOKIE_NAME: token} if token else {}, - ) - resp_dict = _json(resp) - assert (user := resp_dict["data"]["patchViewer"]["user"]) - if new_username: - assert user["username"] == new_username - - -def _create_system_api_key( - token: Optional[_Token], - /, - *, - name: _Name, - expires_at: Optional[datetime] = None, -) -> Tuple[_ApiKey, _GqlId]: - exp = f' expiresAt:"{expires_at.isoformat()}"' if expires_at else "" - args, out = (f'name:"{name}"' + exp), "jwt apiKey{id name expiresAt}" - query = "mutation{createSystemApiKey(input:{" + args + "}){" + out + "}}" - resp = _httpx_client().post( - urljoin(get_base_url(), "graphql"), - json=dict(query=query), - cookies={PHOENIX_ACCESS_TOKEN_COOKIE_NAME: token} if token else {}, - ) - resp_dict = _json(resp) - assert (result := resp_dict["data"]["createSystemApiKey"]) - assert (api_key := result["apiKey"]) - assert api_key["name"] == name - exp_t = datetime.fromisoformat(api_key["expiresAt"]) if api_key["expiresAt"] else None - assert exp_t == expires_at - return cast(_ApiKey, result["jwt"]), cast(_GqlId, api_key["id"]) - - -def _delete_system_api_key(token: Optional[_Token], gid: _GqlId, /) -> None: - args, out = f'id:"{gid}"', "apiKeyId" - query = "mutation{deleteSystemApiKey(input:{" + args + "}){" + out + "}}" - resp = _httpx_client().post( - urljoin(get_base_url(), "graphql"), - json=dict(query=query), - cookies={PHOENIX_ACCESS_TOKEN_COOKIE_NAME: token} if token else {}, - ) - resp_dict = _json(resp) - assert resp_dict["data"]["deleteSystemApiKey"]["apiKeyId"] == gid - - -def _log_in(password: _Password, /, *, email: _Email) -> _LoggedInTokens: - resp = _httpx_client().post( - urljoin(get_base_url(), "auth/login"), - json={"email": email, "password": password}, - ) - resp.raise_for_status() - assert (access_token := resp.cookies.get(PHOENIX_ACCESS_TOKEN_COOKIE_NAME)) - assert (refresh_token := resp.cookies.get(PHOENIX_REFRESH_TOKEN_COOKIE_NAME)) - return _LoggedInTokens(access_token, refresh_token) - - -def _log_out(token: _Token, /) -> None: - resp = _httpx_client().post( - urljoin(get_base_url(), "auth/logout"), - cookies={PHOENIX_ACCESS_TOKEN_COOKIE_NAME: token}, - ) - resp.raise_for_status() - - -def _json(resp: httpx.Response) -> Dict[str, Any]: - resp.raise_for_status() - assert (resp_dict := cast(Dict[str, Any], resp.json())) - if errers := resp_dict.get("errors"): - msg = errers[0]["message"] - if "not auth" in msg or IsAuthenticated.message in msg or IsAdmin.message in msg: - raise Unauthorized(msg) - raise RuntimeError(msg) - return resp_dict diff --git a/integration_tests/auth/test_auth.py b/integration_tests/auth/test_auth.py index 0ef06935b3..b053a6bde2 100644 --- a/integration_tests/auth/test_auth.py +++ b/integration_tests/auth/test_auth.py @@ -16,8 +16,8 @@ from phoenix.server.api.exceptions import Unauthorized from phoenix.server.api.input_types.UserRoleInput import UserRoleInput -from ..conftest import _Headers, _httpx_client, _SpanExporterFactory, _start_span -from .conftest import ( +from .._helpers import _Headers, _httpx_client, _SpanExporterConstructor, _start_span +from ._helpers import ( _create_system_api_key, _create_user, _delete_system_api_key, @@ -465,7 +465,7 @@ def test_headers( with_headers: bool, expires_at: Optional[datetime], expected: SpanExportResult, - _span_exporter: _SpanExporterFactory, + _span_exporter: _SpanExporterConstructor, _admin_token: _Token, _fake: Faker, ) -> None: diff --git a/integration_tests/conftest.py b/integration_tests/conftest.py index 7ff39782db..5635cb49bb 100644 --- a/integration_tests/conftest.py +++ b/integration_tests/conftest.py @@ -1,55 +1,30 @@ from __future__ import annotations import os -import sys -from contextlib import ExitStack, contextmanager -from subprocess import PIPE, STDOUT -from threading import Lock, Thread -from time import sleep, time -from typing import Any, Dict, Iterator, List, Optional, Protocol +from contextlib import ExitStack +from typing import Iterator from unittest import mock -from urllib.parse import urljoin -from urllib.request import urlopen -import httpx import pytest from _pytest.fixtures import SubRequest from _pytest.tmpdir import TempPathFactory from faker import Faker -from openinference.semconv.resource import ResourceAttributes -from opentelemetry.sdk.resources import Resource -from opentelemetry.sdk.trace import TracerProvider -from opentelemetry.sdk.trace.export import SimpleSpanProcessor, SpanExporter -from opentelemetry.trace import Span, Tracer from phoenix.config import ( ENV_PHOENIX_GRPC_PORT, ENV_PHOENIX_PORT, ENV_PHOENIX_SQL_DATABASE_SCHEMA, ENV_PHOENIX_SQL_DATABASE_URL, ENV_PHOENIX_WORKING_DIR, - get_base_url, - get_env_database_connection_str, - get_env_database_schema, - get_env_grpc_port, - get_env_host, ) from portpicker import pick_unused_port # type: ignore[import-untyped] -from psutil import STATUS_ZOMBIE, Popen -from sqlalchemy import URL, create_engine, make_url, text -from sqlalchemy.exc import OperationalError -from typing_extensions import TypeAlias +from sqlalchemy import URL, make_url -_ProjectName: TypeAlias = str -_SpanName: TypeAlias = str -_Headers: TypeAlias = Dict[str, Any] - - -class _SpanExporterFactory(Protocol): - def __call__( - self, - *, - headers: Optional[_Headers] = None, - ) -> SpanExporter: ... +from ._helpers import ( + _grpc_span_exporter, + _http_span_exporter, + _random_schema, + _SpanExporterConstructor, +) @pytest.fixture(scope="class") @@ -98,136 +73,9 @@ def _env_phoenix_sql_database_url( @pytest.fixture(scope="session", params=["http", "grpc"]) -def _span_exporter(request: SubRequest) -> _SpanExporterFactory: +def _span_exporter(request: SubRequest) -> _SpanExporterConstructor: if request.param == "http": return _http_span_exporter if request.param == "grpc": return _grpc_span_exporter raise ValueError(f"Unknown exporter: {request.param}") - - -def _get_gql_spans(*keys: str) -> Dict[_ProjectName, List[Dict[str, Any]]]: - out = "name spans{edges{node{" + " ".join(keys) + "}}}" - query = dict(query="query{projects{edges{node{" + out + "}}}}") - resp = _httpx_client().post(urljoin(get_base_url(), "graphql"), json=query) - resp.raise_for_status() - resp_dict = resp.json() - assert not resp_dict.get("errors") - return { - project["node"]["name"]: [span["node"] for span in project["node"]["spans"]["edges"]] - for project in resp_dict["data"]["projects"]["edges"] - } - - -def _http_span_exporter( - *, - headers: Optional[_Headers] = None, -) -> SpanExporter: - from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter - - endpoint = urljoin(get_base_url(), "v1/traces") - exporter = OTLPSpanExporter(endpoint=endpoint, headers=headers, timeout=1) - exporter._MAX_RETRY_TIMEOUT = 2 - return exporter - - -def _grpc_span_exporter( - *, - headers: Optional[_Headers] = None, -) -> SpanExporter: - from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter - - host = get_env_host() - if host == "0.0.0.0": - host = "127.0.0.1" - endpoint = f"http://{host}:{get_env_grpc_port()}" - return OTLPSpanExporter(endpoint=endpoint, headers=headers, timeout=1) - - -def _get_tracer( - *, - project_name: str, - exporter: SpanExporter, -) -> Tracer: - resource = Resource({ResourceAttributes.PROJECT_NAME: project_name}) - tracer_provider = TracerProvider(resource=resource) - tracer_provider.add_span_processor(SimpleSpanProcessor(exporter)) - return tracer_provider.get_tracer(__name__) - - -def _start_span( - *, - project_name: str, - span_name: str, - exporter: SpanExporter, -) -> Span: - return _get_tracer(project_name=project_name, exporter=exporter).start_span(span_name) - - -def _httpx_client() -> httpx.Client: - # Having no timeout is useful when stepping through the debugger on the server side. - return httpx.Client(timeout=None) - - -@contextmanager -def _server() -> Iterator[None]: - if get_env_database_connection_str().startswith("postgresql"): - # double-check for safety - assert get_env_database_schema() - command = f"{sys.executable} -m phoenix.server.main serve" - process = Popen(command.split(), stdout=PIPE, stderr=STDOUT, text=True, env=os.environ) - log: List[str] = [] - lock: Lock = Lock() - Thread(target=_capture_stdout, args=(process, log, lock), daemon=True).start() - t = 60 - time_limit = time() + t - timed_out = False - url = urljoin(get_base_url(), "healthz") - while not timed_out and _is_alive(process): - sleep(0.1) - try: - urlopen(url) - break - except BaseException: - timed_out = time() > time_limit - try: - if timed_out: - raise TimeoutError(f"Server did not start within {t} seconds.") - assert _is_alive(process) - with lock: - for line in log: - print(line, end="") - log.clear() - yield - process.terminate() - process.wait(10) - finally: - for line in log: - print(line, end="") - - -def _is_alive(process: Popen) -> bool: - return process.is_running() and process.status() != STATUS_ZOMBIE - - -def _capture_stdout(process: Popen, log: List[str], lock: Lock) -> None: - while _is_alive(process): - line = process.stdout.readline() - if line or (log and log[-1] != line): - with lock: - log.append(line) - - -@contextmanager -def _random_schema(url: URL, _fake: Faker) -> Iterator[str]: - engine = create_engine(url.set(drivername="postgresql+psycopg")) - try: - engine.connect() - except OperationalError as ex: - pytest.skip(f"PostgreSQL unavailable: {ex}") - schema = _fake.unique.pystr().lower() - yield schema - with engine.connect() as conn: - conn.execute(text(f"DROP SCHEMA IF EXISTS {schema} CASCADE;")) - conn.commit() - engine.dispose() diff --git a/integration_tests/server/test_launch_app.py b/integration_tests/server/test_launch_app.py index d18d674bf2..14175f50ce 100644 --- a/integration_tests/server/test_launch_app.py +++ b/integration_tests/server/test_launch_app.py @@ -4,7 +4,7 @@ from faker import Faker -from ..conftest import ( +from .._helpers import ( _get_gql_spans, _grpc_span_exporter, _http_span_exporter, From f53a5c19633b19f051c3be9fdd9155cee2f64b04 Mon Sep 17 00:00:00 2001 From: Roger Yang Date: Fri, 6 Sep 2024 10:46:56 -0700 Subject: [PATCH 03/21] clean up --- integration_tests/_helpers.py | 215 +++++++++++++++++++++++++- integration_tests/auth/_helpers.py | 224 ---------------------------- integration_tests/auth/conftest.py | 4 +- integration_tests/auth/test_auth.py | 7 +- 4 files changed, 221 insertions(+), 229 deletions(-) delete mode 100644 integration_tests/auth/_helpers.py diff --git a/integration_tests/_helpers.py b/integration_tests/_helpers.py index 269a262112..25a979f287 100644 --- a/integration_tests/_helpers.py +++ b/integration_tests/_helpers.py @@ -3,10 +3,12 @@ import os import sys from contextlib import contextmanager +from dataclasses import dataclass +from datetime import datetime from subprocess import PIPE, STDOUT from threading import Lock, Thread from time import sleep, time -from typing import Any, Dict, Iterator, List, Optional, Protocol +from typing import Any, Dict, Iterator, List, Optional, Protocol, Tuple, cast from urllib.parse import urljoin from urllib.request import urlopen @@ -18,6 +20,7 @@ from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import SimpleSpanProcessor, SpanExporter from opentelemetry.trace import Span, Tracer +from phoenix.auth import PHOENIX_ACCESS_TOKEN_COOKIE_NAME, PHOENIX_REFRESH_TOKEN_COOKIE_NAME from phoenix.config import ( get_base_url, get_env_database_connection_str, @@ -25,16 +28,74 @@ get_env_grpc_port, get_env_host, ) +from phoenix.server.api.auth import IsAdmin, IsAuthenticated +from phoenix.server.api.exceptions import Unauthorized +from phoenix.server.api.input_types.UserRoleInput import UserRoleInput from psutil import STATUS_ZOMBIE, Popen from sqlalchemy import URL, create_engine, text from sqlalchemy.exc import OperationalError from typing_extensions import TypeAlias +_Email: TypeAlias = str +_GqlId: TypeAlias = str +_Name: TypeAlias = str +_Password: TypeAlias = str +_Token: TypeAlias = str +_Username: TypeAlias = str +_AccessToken: TypeAlias = _Token +_ApiKey: TypeAlias = _Token +_RefreshToken: TypeAlias = _Token _ProjectName: TypeAlias = str _SpanName: TypeAlias = str _Headers: TypeAlias = Dict[str, Any] +@dataclass(frozen=True) +class _Profile: + email: _Email + password: _Password + username: Optional[_Username] = None + + +@dataclass(frozen=True) +class _User: + gid: _GqlId + role: UserRoleInput + profile: _Profile + + +@dataclass(frozen=True) +class _LoggedInTokens: + access: _AccessToken + refresh: _RefreshToken + + def log_out(self) -> None: + _log_out(self.access) + + def __iter__(self) -> Iterator[_Token]: + yield self.access + yield self.refresh + + def __enter__(self) -> _LoggedInTokens: + return self + + def __exit__(self, *args: Any, **kwargs: Any) -> None: + self.log_out() + + +@dataclass(frozen=True) +class _LoggedInUser(_User): + tokens: _LoggedInTokens + + +class _UserGenerator(Protocol): + def send(self, role: UserRoleInput) -> _LoggedInUser: ... + + +class _GetNewUser(Protocol): + def __call__(self, role: UserRoleInput) -> _LoggedInUser: ... + + class _SpanExporterConstructor(Protocol): def __call__( self, @@ -168,3 +229,155 @@ def _random_schema(url: URL, _fake: Faker) -> Iterator[str]: conn.execute(text(f"DROP SCHEMA IF EXISTS {schema} CASCADE;")) conn.commit() engine.dispose() + + +def _create_user( + token: Optional[_Token], + /, + *, + email: _Email, + password: _Password, + role: UserRoleInput, + username: Optional[_Username] = None, +) -> _GqlId: + args = [f'email:"{email}"', f'password:"{password}"', f"role:{role.value}"] + if username: + args.append(f'username:"{username}"') + out = "user{id email role{name}}" + query = "mutation{createUser(input:{" + ",".join(args) + "}){" + out + "}}" + resp = _httpx_client().post( + urljoin(get_base_url(), "graphql"), + json=dict(query=query), + cookies={PHOENIX_ACCESS_TOKEN_COOKIE_NAME: token} if token else {}, + ) + resp_dict = _json(resp) + assert (user := resp_dict["data"]["createUser"]["user"]) + assert user["email"] == email + assert user["role"]["name"] == role.value + return cast(_GqlId, user["id"]) + + +def _patch_user( + token: Optional[_Token], + gid: _GqlId, + /, + *, + new_username: Optional[_Username] = None, + new_password: Optional[_Password] = None, + new_role: Optional[UserRoleInput] = None, +) -> None: + args = [f'userId:"{gid}"'] + if new_password: + args.append(f'newPassword:"{new_password}"') + if new_username: + args.append(f'newUsername:"{new_username}"') + if new_role: + args.append(f"newRole:{new_role.value}") + out = "user{id username role{name}}" + query = "mutation{patchUser(input:{" + ",".join(args) + "}){" + out + "}}" + resp = _httpx_client().post( + urljoin(get_base_url(), "graphql"), + json=dict(query=query), + cookies={PHOENIX_ACCESS_TOKEN_COOKIE_NAME: token} if token else {}, + ) + resp_dict = _json(resp) + assert (user := resp_dict["data"]["patchUser"]["user"]) + assert user["id"] == gid + if new_username: + assert user["username"] == new_username + if new_role: + assert user["role"]["name"] == new_role.value + + +def _patch_viewer( + token: Optional[_Token], + current_password: Optional[_Password], + /, + *, + new_username: Optional[_Username] = None, + new_password: Optional[_Password] = None, +) -> None: + args = [] + if new_password: + args.append(f'newPassword:"{new_password}"') + if current_password: + args.append(f'currentPassword:"{current_password}"') + if new_username: + args.append(f'newUsername:"{new_username}"') + out = "user{username}" + query = "mutation{patchViewer(input:{" + ",".join(args) + "}){" + out + "}}" + resp = _httpx_client().post( + urljoin(get_base_url(), "graphql"), + json=dict(query=query), + cookies={PHOENIX_ACCESS_TOKEN_COOKIE_NAME: token} if token else {}, + ) + resp_dict = _json(resp) + assert (user := resp_dict["data"]["patchViewer"]["user"]) + if new_username: + assert user["username"] == new_username + + +def _create_system_api_key( + token: Optional[_Token], + /, + *, + name: _Name, + expires_at: Optional[datetime] = None, +) -> Tuple[_ApiKey, _GqlId]: + exp = f' expiresAt:"{expires_at.isoformat()}"' if expires_at else "" + args, out = (f'name:"{name}"' + exp), "jwt apiKey{id name expiresAt}" + query = "mutation{createSystemApiKey(input:{" + args + "}){" + out + "}}" + resp = _httpx_client().post( + urljoin(get_base_url(), "graphql"), + json=dict(query=query), + cookies={PHOENIX_ACCESS_TOKEN_COOKIE_NAME: token} if token else {}, + ) + resp_dict = _json(resp) + assert (result := resp_dict["data"]["createSystemApiKey"]) + assert (api_key := result["apiKey"]) + assert api_key["name"] == name + exp_t = datetime.fromisoformat(api_key["expiresAt"]) if api_key["expiresAt"] else None + assert exp_t == expires_at + return cast(_ApiKey, result["jwt"]), cast(_GqlId, api_key["id"]) + + +def _delete_system_api_key(token: Optional[_Token], gid: _GqlId, /) -> None: + args, out = f'id:"{gid}"', "apiKeyId" + query = "mutation{deleteSystemApiKey(input:{" + args + "}){" + out + "}}" + resp = _httpx_client().post( + urljoin(get_base_url(), "graphql"), + json=dict(query=query), + cookies={PHOENIX_ACCESS_TOKEN_COOKIE_NAME: token} if token else {}, + ) + resp_dict = _json(resp) + assert resp_dict["data"]["deleteSystemApiKey"]["apiKeyId"] == gid + + +def _log_in(password: _Password, /, *, email: _Email) -> _LoggedInTokens: + resp = _httpx_client().post( + urljoin(get_base_url(), "auth/login"), + json={"email": email, "password": password}, + ) + resp.raise_for_status() + assert (access_token := resp.cookies.get(PHOENIX_ACCESS_TOKEN_COOKIE_NAME)) + assert (refresh_token := resp.cookies.get(PHOENIX_REFRESH_TOKEN_COOKIE_NAME)) + return _LoggedInTokens(access_token, refresh_token) + + +def _log_out(token: _Token, /) -> None: + resp = _httpx_client().post( + urljoin(get_base_url(), "auth/logout"), + cookies={PHOENIX_ACCESS_TOKEN_COOKIE_NAME: token}, + ) + resp.raise_for_status() + + +def _json(resp: httpx.Response) -> Dict[str, Any]: + resp.raise_for_status() + assert (resp_dict := cast(Dict[str, Any], resp.json())) + if errers := resp_dict.get("errors"): + msg = errers[0]["message"] + if "not auth" in msg or IsAuthenticated.message in msg or IsAdmin.message in msg: + raise Unauthorized(msg) + raise RuntimeError(msg) + return resp_dict diff --git a/integration_tests/auth/_helpers.py b/integration_tests/auth/_helpers.py deleted file mode 100644 index 1cacbbfe40..0000000000 --- a/integration_tests/auth/_helpers.py +++ /dev/null @@ -1,224 +0,0 @@ -from __future__ import annotations - -from dataclasses import dataclass -from datetime import datetime -from typing import Any, Dict, Iterator, Optional, Protocol, Tuple, cast -from urllib.parse import urljoin - -import httpx -from phoenix.auth import PHOENIX_ACCESS_TOKEN_COOKIE_NAME, PHOENIX_REFRESH_TOKEN_COOKIE_NAME -from phoenix.config import get_base_url -from phoenix.server.api.auth import IsAdmin, IsAuthenticated -from phoenix.server.api.exceptions import Unauthorized -from phoenix.server.api.input_types.UserRoleInput import UserRoleInput -from typing_extensions import TypeAlias - -from .._helpers import _httpx_client - -_Email: TypeAlias = str -_GqlId: TypeAlias = str -_Name: TypeAlias = str -_Password: TypeAlias = str -_Token: TypeAlias = str -_Username: TypeAlias = str -_AccessToken: TypeAlias = _Token -_ApiKey: TypeAlias = _Token -_RefreshToken: TypeAlias = _Token - - -@dataclass(frozen=True) -class _Profile: - email: _Email - password: _Password - username: Optional[_Username] = None - - -@dataclass(frozen=True) -class _User: - gid: _GqlId - role: UserRoleInput - profile: _Profile - - -@dataclass(frozen=True) -class _LoggedInTokens: - access: _AccessToken - refresh: _RefreshToken - - def log_out(self) -> None: - _log_out(self.access) - - def __iter__(self) -> Iterator[_Token]: - yield self.access - yield self.refresh - - def __enter__(self) -> _LoggedInTokens: - return self - - def __exit__(self, *args: Any, **kwargs: Any) -> None: - self.log_out() - - -@dataclass(frozen=True) -class _LoggedInUser(_User): - tokens: _LoggedInTokens - - -class _UserGenerator(Protocol): - def send(self, role: UserRoleInput) -> _LoggedInUser: ... - - -class _GetNewUser(Protocol): - def __call__(self, role: UserRoleInput) -> _LoggedInUser: ... - - -def _create_user( - token: Optional[_Token], - /, - *, - email: _Email, - password: _Password, - role: UserRoleInput, - username: Optional[_Username] = None, -) -> _GqlId: - args = [f'email:"{email}"', f'password:"{password}"', f"role:{role.value}"] - if username: - args.append(f'username:"{username}"') - out = "user{id email role{name}}" - query = "mutation{createUser(input:{" + ",".join(args) + "}){" + out + "}}" - resp = _httpx_client().post( - urljoin(get_base_url(), "graphql"), - json=dict(query=query), - cookies={PHOENIX_ACCESS_TOKEN_COOKIE_NAME: token} if token else {}, - ) - resp_dict = _json(resp) - assert (user := resp_dict["data"]["createUser"]["user"]) - assert user["email"] == email - assert user["role"]["name"] == role.value - return cast(_GqlId, user["id"]) - - -def _patch_user( - token: Optional[_Token], - gid: _GqlId, - /, - *, - new_username: Optional[_Username] = None, - new_password: Optional[_Password] = None, - new_role: Optional[UserRoleInput] = None, -) -> None: - args = [f'userId:"{gid}"'] - if new_password: - args.append(f'newPassword:"{new_password}"') - if new_username: - args.append(f'newUsername:"{new_username}"') - if new_role: - args.append(f"newRole:{new_role.value}") - out = "user{id username role{name}}" - query = "mutation{patchUser(input:{" + ",".join(args) + "}){" + out + "}}" - resp = _httpx_client().post( - urljoin(get_base_url(), "graphql"), - json=dict(query=query), - cookies={PHOENIX_ACCESS_TOKEN_COOKIE_NAME: token} if token else {}, - ) - resp_dict = _json(resp) - assert (user := resp_dict["data"]["patchUser"]["user"]) - assert user["id"] == gid - if new_username: - assert user["username"] == new_username - if new_role: - assert user["role"]["name"] == new_role.value - - -def _patch_viewer( - token: Optional[_Token], - current_password: Optional[_Password], - /, - *, - new_username: Optional[_Username] = None, - new_password: Optional[_Password] = None, -) -> None: - args = [] - if new_password: - args.append(f'newPassword:"{new_password}"') - if current_password: - args.append(f'currentPassword:"{current_password}"') - if new_username: - args.append(f'newUsername:"{new_username}"') - out = "user{username}" - query = "mutation{patchViewer(input:{" + ",".join(args) + "}){" + out + "}}" - resp = _httpx_client().post( - urljoin(get_base_url(), "graphql"), - json=dict(query=query), - cookies={PHOENIX_ACCESS_TOKEN_COOKIE_NAME: token} if token else {}, - ) - resp_dict = _json(resp) - assert (user := resp_dict["data"]["patchViewer"]["user"]) - if new_username: - assert user["username"] == new_username - - -def _create_system_api_key( - token: Optional[_Token], - /, - *, - name: _Name, - expires_at: Optional[datetime] = None, -) -> Tuple[_ApiKey, _GqlId]: - exp = f' expiresAt:"{expires_at.isoformat()}"' if expires_at else "" - args, out = (f'name:"{name}"' + exp), "jwt apiKey{id name expiresAt}" - query = "mutation{createSystemApiKey(input:{" + args + "}){" + out + "}}" - resp = _httpx_client().post( - urljoin(get_base_url(), "graphql"), - json=dict(query=query), - cookies={PHOENIX_ACCESS_TOKEN_COOKIE_NAME: token} if token else {}, - ) - resp_dict = _json(resp) - assert (result := resp_dict["data"]["createSystemApiKey"]) - assert (api_key := result["apiKey"]) - assert api_key["name"] == name - exp_t = datetime.fromisoformat(api_key["expiresAt"]) if api_key["expiresAt"] else None - assert exp_t == expires_at - return cast(_ApiKey, result["jwt"]), cast(_GqlId, api_key["id"]) - - -def _delete_system_api_key(token: Optional[_Token], gid: _GqlId, /) -> None: - args, out = f'id:"{gid}"', "apiKeyId" - query = "mutation{deleteSystemApiKey(input:{" + args + "}){" + out + "}}" - resp = _httpx_client().post( - urljoin(get_base_url(), "graphql"), - json=dict(query=query), - cookies={PHOENIX_ACCESS_TOKEN_COOKIE_NAME: token} if token else {}, - ) - resp_dict = _json(resp) - assert resp_dict["data"]["deleteSystemApiKey"]["apiKeyId"] == gid - - -def _log_in(password: _Password, /, *, email: _Email) -> _LoggedInTokens: - resp = _httpx_client().post( - urljoin(get_base_url(), "auth/login"), - json={"email": email, "password": password}, - ) - resp.raise_for_status() - assert (access_token := resp.cookies.get(PHOENIX_ACCESS_TOKEN_COOKIE_NAME)) - assert (refresh_token := resp.cookies.get(PHOENIX_REFRESH_TOKEN_COOKIE_NAME)) - return _LoggedInTokens(access_token, refresh_token) - - -def _log_out(token: _Token, /) -> None: - resp = _httpx_client().post( - urljoin(get_base_url(), "auth/logout"), - cookies={PHOENIX_ACCESS_TOKEN_COOKIE_NAME: token}, - ) - resp.raise_for_status() - - -def _json(resp: httpx.Response) -> Dict[str, Any]: - resp.raise_for_status() - assert (resp_dict := cast(Dict[str, Any], resp.json())) - if errers := resp_dict.get("errors"): - msg = errers[0]["message"] - if "not auth" in msg or IsAuthenticated.message in msg or IsAdmin.message in msg: - raise Unauthorized(msg) - raise RuntimeError(msg) - return resp_dict diff --git a/integration_tests/auth/conftest.py b/integration_tests/auth/conftest.py index 8a477522e1..32ef3f55b3 100644 --- a/integration_tests/auth/conftest.py +++ b/integration_tests/auth/conftest.py @@ -14,8 +14,7 @@ from phoenix.config import ENV_PHOENIX_ENABLE_AUTH, ENV_PHOENIX_SECRET from phoenix.server.api.input_types.UserRoleInput import UserRoleInput -from .._helpers import _server -from ._helpers import ( +from .._helpers import ( _create_user, _Email, _GetNewUser, @@ -23,6 +22,7 @@ _LoggedInUser, _Password, _Profile, + _server, _Token, _UserGenerator, _Username, diff --git a/integration_tests/auth/test_auth.py b/integration_tests/auth/test_auth.py index b053a6bde2..f6e4120cbe 100644 --- a/integration_tests/auth/test_auth.py +++ b/integration_tests/auth/test_auth.py @@ -16,18 +16,21 @@ from phoenix.server.api.exceptions import Unauthorized from phoenix.server.api.input_types.UserRoleInput import UserRoleInput -from .._helpers import _Headers, _httpx_client, _SpanExporterConstructor, _start_span -from ._helpers import ( +from .._helpers import ( _create_system_api_key, _create_user, _delete_system_api_key, _GetNewUser, _GqlId, + _Headers, + _httpx_client, _log_in, _Password, _patch_user, _patch_viewer, _Profile, + _SpanExporterConstructor, + _start_span, _Token, _Username, ) From 06ad909d34ecae88e9f6212a28dca1ff7ca953ce Mon Sep 17 00:00:00 2001 From: Roger Yang Date: Fri, 6 Sep 2024 10:48:30 -0700 Subject: [PATCH 04/21] clean up --- integration_tests/__init__.py | 1 - integration_tests/auth/conftest.py | 2 -- integration_tests/conftest.py | 2 -- integration_tests/py.typed | 0 4 files changed, 5 deletions(-) delete mode 100644 integration_tests/py.typed diff --git a/integration_tests/__init__.py b/integration_tests/__init__.py index 9d48db4f9f..e69de29bb2 100644 --- a/integration_tests/__init__.py +++ b/integration_tests/__init__.py @@ -1 +0,0 @@ -from __future__ import annotations diff --git a/integration_tests/auth/conftest.py b/integration_tests/auth/conftest.py index 32ef3f55b3..cd86e2ea13 100644 --- a/integration_tests/auth/conftest.py +++ b/integration_tests/auth/conftest.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import os import secrets from contextlib import ExitStack diff --git a/integration_tests/conftest.py b/integration_tests/conftest.py index 5635cb49bb..40c0c2b35d 100644 --- a/integration_tests/conftest.py +++ b/integration_tests/conftest.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import os from contextlib import ExitStack from typing import Iterator diff --git a/integration_tests/py.typed b/integration_tests/py.typed deleted file mode 100644 index e69de29bb2..0000000000 From 53b5ace8b9edf3025baa41e1adcfd987e53e6f59 Mon Sep 17 00:00:00 2001 From: Roger Yang Date: Fri, 6 Sep 2024 14:25:03 -0700 Subject: [PATCH 05/21] clean up --- integration_tests/_helpers.py | 104 +++++++++++++++++++--------------- 1 file changed, 58 insertions(+), 46 deletions(-) diff --git a/integration_tests/_helpers.py b/integration_tests/_helpers.py index 25a979f287..1549202b75 100644 --- a/integration_tests/_helpers.py +++ b/integration_tests/_helpers.py @@ -8,7 +8,7 @@ from subprocess import PIPE, STDOUT from threading import Lock, Thread from time import sleep, time -from typing import Any, Dict, Iterator, List, Optional, Protocol, Tuple, cast +from typing import Any, Dict, Iterator, List, Mapping, Optional, Protocol, Tuple, cast from urllib.parse import urljoin from urllib.request import urlopen @@ -162,9 +162,17 @@ def _start_span( return _get_tracer(project_name=project_name, exporter=exporter).start_span(span_name) -def _httpx_client() -> httpx.Client: +def _httpx_client( + access_token: Optional[_Token] = None, + refresh_token: Optional[_Token] = None, + cookies: Optional[Dict[str, Any]] = None, +) -> httpx.Client: + if access_token: + cookies = {**(cookies or {}), PHOENIX_ACCESS_TOKEN_COOKIE_NAME: access_token} + if refresh_token: + cookies = {**(cookies or {}), PHOENIX_REFRESH_TOKEN_COOKIE_NAME: refresh_token} # Having no timeout is useful when stepping through the debugger on the server side. - return httpx.Client(timeout=None) + return httpx.Client(timeout=None, cookies=cookies) @contextmanager @@ -208,7 +216,11 @@ def _is_alive(process: Popen) -> bool: return process.is_running() and process.status() != STATUS_ZOMBIE -def _capture_stdout(process: Popen, log: List[str], lock: Lock) -> None: +def _capture_stdout( + process: Popen, + log: List[str], + lock: Lock, +) -> None: while _is_alive(process): line = process.stdout.readline() if line or (log and log[-1] != line): @@ -231,8 +243,22 @@ def _random_schema(url: URL, _fake: Faker) -> Iterator[str]: engine.dispose() +def _gql( + access_token: Optional[_Token], + /, + *, + query: str, + variables: Optional[Mapping[str, Any]] = None, +) -> Dict[str, Any]: + resp = _httpx_client(access_token).post( + urljoin(get_base_url(), "graphql"), + json=dict(query=query, variables=dict(variables or {})), + ) + return _json(resp) + + def _create_user( - token: Optional[_Token], + access_token: Optional[_Token], /, *, email: _Email, @@ -245,12 +271,7 @@ def _create_user( args.append(f'username:"{username}"') out = "user{id email role{name}}" query = "mutation{createUser(input:{" + ",".join(args) + "}){" + out + "}}" - resp = _httpx_client().post( - urljoin(get_base_url(), "graphql"), - json=dict(query=query), - cookies={PHOENIX_ACCESS_TOKEN_COOKIE_NAME: token} if token else {}, - ) - resp_dict = _json(resp) + resp_dict = _gql(access_token, query=query) assert (user := resp_dict["data"]["createUser"]["user"]) assert user["email"] == email assert user["role"]["name"] == role.value @@ -258,7 +279,7 @@ def _create_user( def _patch_user( - token: Optional[_Token], + access_token: Optional[_Token], gid: _GqlId, /, *, @@ -275,12 +296,7 @@ def _patch_user( args.append(f"newRole:{new_role.value}") out = "user{id username role{name}}" query = "mutation{patchUser(input:{" + ",".join(args) + "}){" + out + "}}" - resp = _httpx_client().post( - urljoin(get_base_url(), "graphql"), - json=dict(query=query), - cookies={PHOENIX_ACCESS_TOKEN_COOKIE_NAME: token} if token else {}, - ) - resp_dict = _json(resp) + resp_dict = _gql(access_token, query=query) assert (user := resp_dict["data"]["patchUser"]["user"]) assert user["id"] == gid if new_username: @@ -290,7 +306,7 @@ def _patch_user( def _patch_viewer( - token: Optional[_Token], + access_token: Optional[_Token], current_password: Optional[_Password], /, *, @@ -306,19 +322,14 @@ def _patch_viewer( args.append(f'newUsername:"{new_username}"') out = "user{username}" query = "mutation{patchViewer(input:{" + ",".join(args) + "}){" + out + "}}" - resp = _httpx_client().post( - urljoin(get_base_url(), "graphql"), - json=dict(query=query), - cookies={PHOENIX_ACCESS_TOKEN_COOKIE_NAME: token} if token else {}, - ) - resp_dict = _json(resp) + resp_dict = _gql(access_token, query=query) assert (user := resp_dict["data"]["patchViewer"]["user"]) if new_username: assert user["username"] == new_username def _create_system_api_key( - token: Optional[_Token], + access_token: Optional[_Token], /, *, name: _Name, @@ -327,12 +338,7 @@ def _create_system_api_key( exp = f' expiresAt:"{expires_at.isoformat()}"' if expires_at else "" args, out = (f'name:"{name}"' + exp), "jwt apiKey{id name expiresAt}" query = "mutation{createSystemApiKey(input:{" + args + "}){" + out + "}}" - resp = _httpx_client().post( - urljoin(get_base_url(), "graphql"), - json=dict(query=query), - cookies={PHOENIX_ACCESS_TOKEN_COOKIE_NAME: token} if token else {}, - ) - resp_dict = _json(resp) + resp_dict = _gql(access_token, query=query) assert (result := resp_dict["data"]["createSystemApiKey"]) assert (api_key := result["apiKey"]) assert api_key["name"] == name @@ -341,19 +347,23 @@ def _create_system_api_key( return cast(_ApiKey, result["jwt"]), cast(_GqlId, api_key["id"]) -def _delete_system_api_key(token: Optional[_Token], gid: _GqlId, /) -> None: +def _delete_system_api_key( + access_token: Optional[_Token], + gid: _GqlId, + /, +) -> None: args, out = f'id:"{gid}"', "apiKeyId" query = "mutation{deleteSystemApiKey(input:{" + args + "}){" + out + "}}" - resp = _httpx_client().post( - urljoin(get_base_url(), "graphql"), - json=dict(query=query), - cookies={PHOENIX_ACCESS_TOKEN_COOKIE_NAME: token} if token else {}, - ) - resp_dict = _json(resp) + resp_dict = _gql(access_token, query=query) assert resp_dict["data"]["deleteSystemApiKey"]["apiKeyId"] == gid -def _log_in(password: _Password, /, *, email: _Email) -> _LoggedInTokens: +def _log_in( + password: _Password, + /, + *, + email: _Email, +) -> _LoggedInTokens: resp = _httpx_client().post( urljoin(get_base_url(), "auth/login"), json={"email": email, "password": password}, @@ -364,15 +374,17 @@ def _log_in(password: _Password, /, *, email: _Email) -> _LoggedInTokens: return _LoggedInTokens(access_token, refresh_token) -def _log_out(token: _Token, /) -> None: - resp = _httpx_client().post( - urljoin(get_base_url(), "auth/logout"), - cookies={PHOENIX_ACCESS_TOKEN_COOKIE_NAME: token}, - ) +def _log_out( + access_token: _Token, + /, +) -> None: + resp = _httpx_client(access_token).post(urljoin(get_base_url(), "auth/logout")) resp.raise_for_status() -def _json(resp: httpx.Response) -> Dict[str, Any]: +def _json( + resp: httpx.Response, +) -> Dict[str, Any]: resp.raise_for_status() assert (resp_dict := cast(Dict[str, Any], resp.json())) if errers := resp_dict.get("errors"): From 1f40e843f620232ae904fe764ffbda1618579929 Mon Sep 17 00:00:00 2001 From: Roger Yang Date: Fri, 6 Sep 2024 15:49:55 -0700 Subject: [PATCH 06/21] clean up --- integration_tests/_helpers.py | 115 +++++++++++++++++----------- integration_tests/auth/conftest.py | 6 +- integration_tests/auth/test_auth.py | 76 +++++++----------- 3 files changed, 103 insertions(+), 94 deletions(-) diff --git a/integration_tests/_helpers.py b/integration_tests/_helpers.py index 1549202b75..803bf9bf94 100644 --- a/integration_tests/_helpers.py +++ b/integration_tests/_helpers.py @@ -2,13 +2,25 @@ import os import sys +from abc import ABC from contextlib import contextmanager from dataclasses import dataclass from datetime import datetime from subprocess import PIPE, STDOUT from threading import Lock, Thread from time import sleep, time -from typing import Any, Dict, Iterator, List, Mapping, Optional, Protocol, Tuple, cast +from typing import ( + Any, + Dict, + Iterator, + List, + Mapping, + NamedTuple, + Optional, + Protocol, + Tuple, + cast, +) from urllib.parse import urljoin from urllib.request import urlopen @@ -36,18 +48,21 @@ from sqlalchemy.exc import OperationalError from typing_extensions import TypeAlias -_Email: TypeAlias = str -_GqlId: TypeAlias = str -_Name: TypeAlias = str -_Password: TypeAlias = str -_Token: TypeAlias = str -_Username: TypeAlias = str -_AccessToken: TypeAlias = _Token -_ApiKey: TypeAlias = _Token -_RefreshToken: TypeAlias = _Token _ProjectName: TypeAlias = str _SpanName: TypeAlias = str _Headers: TypeAlias = Dict[str, Any] +_Name: TypeAlias = str + + +class _String(str, ABC): + def __new__(cls, string: Optional[str] = None) -> _String: + assert string + return super().__new__(cls, string) + + +_Email: TypeAlias = str +_Password: TypeAlias = str +_Username: TypeAlias = str @dataclass(frozen=True) @@ -57,6 +72,9 @@ class _Profile: username: Optional[_Username] = None +class _GqlId(_String): ... + + @dataclass(frozen=True) class _User: gid: _GqlId @@ -64,18 +82,25 @@ class _User: profile: _Profile -@dataclass(frozen=True) -class _LoggedInTokens: +class _Token(_String, ABC): ... + + +class _AccessToken(_Token): ... + + +class _ApiKey(_Token): ... + + +class _RefreshToken(_Token): ... + + +class _LoggedInTokens(NamedTuple): access: _AccessToken refresh: _RefreshToken def log_out(self) -> None: _log_out(self.access) - def __iter__(self) -> Iterator[_Token]: - yield self.access - yield self.refresh - def __enter__(self) -> _LoggedInTokens: return self @@ -104,19 +129,6 @@ def __call__( ) -> SpanExporter: ... -def _get_gql_spans(*keys: str) -> Dict[_ProjectName, List[Dict[str, Any]]]: - out = "name spans{edges{node{" + " ".join(keys) + "}}}" - query = dict(query="query{projects{edges{node{" + out + "}}}}") - resp = _httpx_client().post(urljoin(get_base_url(), "graphql"), json=query) - resp.raise_for_status() - resp_dict = resp.json() - assert not resp_dict.get("errors") - return { - project["node"]["name"]: [span["node"] for span in project["node"]["spans"]["edges"]] - for project in resp_dict["data"]["projects"]["edges"] - } - - def _http_span_exporter( *, headers: Optional[_Headers] = None, @@ -163,8 +175,8 @@ def _start_span( def _httpx_client( - access_token: Optional[_Token] = None, - refresh_token: Optional[_Token] = None, + access_token: Optional[_AccessToken] = None, + refresh_token: Optional[_RefreshToken] = None, cookies: Optional[Dict[str, Any]] = None, ) -> httpx.Client: if access_token: @@ -244,7 +256,7 @@ def _random_schema(url: URL, _fake: Faker) -> Iterator[str]: def _gql( - access_token: Optional[_Token], + access_token: Optional[_AccessToken] = None, /, *, query: str, @@ -257,8 +269,19 @@ def _gql( return _json(resp) +def _get_gql_spans(*keys: str) -> Dict[_ProjectName, List[Dict[str, Any]]]: + out = "name spans{edges{node{" + " ".join(keys) + "}}}" + query = "query{projects{edges{node{" + out + "}}}}" + resp_dict = _gql(query=query) + assert not resp_dict.get("errors") + return { + project["node"]["name"]: [span["node"] for span in project["node"]["spans"]["edges"]] + for project in resp_dict["data"]["projects"]["edges"] + } + + def _create_user( - access_token: Optional[_Token], + access_token: Optional[_AccessToken] = None, /, *, email: _Email, @@ -275,12 +298,12 @@ def _create_user( assert (user := resp_dict["data"]["createUser"]["user"]) assert user["email"] == email assert user["role"]["name"] == role.value - return cast(_GqlId, user["id"]) + return _GqlId(user["id"]) def _patch_user( - access_token: Optional[_Token], gid: _GqlId, + access_token: Optional[_AccessToken] = None, /, *, new_username: Optional[_Username] = None, @@ -306,8 +329,8 @@ def _patch_user( def _patch_viewer( - access_token: Optional[_Token], - current_password: Optional[_Password], + access_token: Optional[_AccessToken] = None, + current_password: Optional[_Password] = None, /, *, new_username: Optional[_Username] = None, @@ -329,7 +352,7 @@ def _patch_viewer( def _create_system_api_key( - access_token: Optional[_Token], + access_token: Optional[_AccessToken] = None, /, *, name: _Name, @@ -344,12 +367,14 @@ def _create_system_api_key( assert api_key["name"] == name exp_t = datetime.fromisoformat(api_key["expiresAt"]) if api_key["expiresAt"] else None assert exp_t == expires_at - return cast(_ApiKey, result["jwt"]), cast(_GqlId, api_key["id"]) + assert (jwt := result["jwt"]) + assert (id_ := api_key["id"]) + return _ApiKey(jwt), _GqlId(id_) def _delete_system_api_key( - access_token: Optional[_Token], gid: _GqlId, + access_token: Optional[_AccessToken] = None, /, ) -> None: args, out = f'id:"{gid}"', "apiKeyId" @@ -371,14 +396,18 @@ def _log_in( resp.raise_for_status() assert (access_token := resp.cookies.get(PHOENIX_ACCESS_TOKEN_COOKIE_NAME)) assert (refresh_token := resp.cookies.get(PHOENIX_REFRESH_TOKEN_COOKIE_NAME)) - return _LoggedInTokens(access_token, refresh_token) + return _LoggedInTokens(_AccessToken(access_token), _RefreshToken(refresh_token)) def _log_out( - access_token: _Token, + access_token: Optional[_AccessToken] = None, + refresh_token: Optional[_RefreshToken] = None, /, ) -> None: - resp = _httpx_client(access_token).post(urljoin(get_base_url(), "auth/logout")) + resp = _httpx_client( + access_token, + refresh_token, + ).post(urljoin(get_base_url(), "auth/logout")) resp.raise_for_status() diff --git a/integration_tests/auth/conftest.py b/integration_tests/auth/conftest.py index cd86e2ea13..3dd39c6b4c 100644 --- a/integration_tests/auth/conftest.py +++ b/integration_tests/auth/conftest.py @@ -13,6 +13,7 @@ from phoenix.server.api.input_types.UserRoleInput import UserRoleInput from .._helpers import ( + _AccessToken, _create_user, _Email, _GetNewUser, @@ -21,7 +22,6 @@ _Password, _Profile, _server, - _Token, _UserGenerator, _Username, ) @@ -74,7 +74,7 @@ def _profiles( @pytest.fixture def _users( _profiles: Iterator[_Profile], - _admin_token: _Token, + _admin_token: _AccessToken, _fake: Faker, ) -> _UserGenerator: def _() -> Generator[Optional[_LoggedInUser], UserRoleInput, None]: @@ -104,7 +104,7 @@ def _(role: UserRoleInput) -> _LoggedInUser: def _admin_token( _admin_email: str, _secret: str, -) -> Iterator[_Token]: +) -> Iterator[_AccessToken]: with _log_in(_secret, email=_admin_email) as (token, _): yield token diff --git a/integration_tests/auth/test_auth.py b/integration_tests/auth/test_auth.py index f6e4120cbe..f0afdb39fb 100644 --- a/integration_tests/auth/test_auth.py +++ b/integration_tests/auth/test_auth.py @@ -17,21 +17,24 @@ from phoenix.server.api.input_types.UserRoleInput import UserRoleInput from .._helpers import ( + _AccessToken, _create_system_api_key, _create_user, _delete_system_api_key, + _Email, _GetNewUser, _GqlId, _Headers, _httpx_client, _log_in, + _log_out, _Password, _patch_user, _patch_viewer, _Profile, + _RefreshToken, _SpanExporterConstructor, _start_span, - _Token, _Username, ) @@ -68,7 +71,7 @@ class TestUsers: ) def test_admin( self, - email: str, + email: _Email, use_secret: bool, expectation: ContextManager[Optional[Unauthorized]], _secret: str, @@ -86,44 +89,34 @@ def test_admin( def test_end_to_end_credentials_flow( self, - _admin_email: str, - _secret: str, + _admin_email: _Email, + _secret: _Password, _fake: Faker, ) -> None: # user logs into first browser - resp = _httpx_client().post( - urljoin(get_base_url(), "/auth/login"), - json={"email": _admin_email, "password": _secret}, - ) - resp.raise_for_status() - assert (browser_0_access_token_0 := resp.cookies.get(PHOENIX_ACCESS_TOKEN_COOKIE_NAME)) - assert (browser_0_refresh_token_0 := resp.cookies.get(PHOENIX_REFRESH_TOKEN_COOKIE_NAME)) + browser_0_access_token_0, browser_0_refresh_token_0 = _log_in(_secret, email=_admin_email) # user creates api key in the first browser _create_system_api_key(browser_0_access_token_0, name="api-key-0") # tokens are refreshed in the first browser - resp = _httpx_client().post( - urljoin(get_base_url(), "/auth/refresh"), - cookies={ - PHOENIX_ACCESS_TOKEN_COOKIE_NAME: browser_0_access_token_0, - PHOENIX_REFRESH_TOKEN_COOKIE_NAME: browser_0_refresh_token_0, - }, + resp = _httpx_client(browser_0_access_token_0, browser_0_refresh_token_0).post( + urljoin(get_base_url(), "auth/refresh"), ) resp.raise_for_status() - assert (browser_0_access_token_1 := resp.cookies.get(PHOENIX_ACCESS_TOKEN_COOKIE_NAME)) - assert (browser_0_refresh_token_1 := resp.cookies.get(PHOENIX_REFRESH_TOKEN_COOKIE_NAME)) + browser_0_access_token_1 = _AccessToken(resp.cookies.get(PHOENIX_ACCESS_TOKEN_COOKIE_NAME)) + browser_0_refresh_token_1 = _RefreshToken( + resp.cookies.get(PHOENIX_REFRESH_TOKEN_COOKIE_NAME) + ) + assert browser_0_access_token_1 + assert browser_0_refresh_token_1 # user creates api key in the first browser _create_system_api_key(browser_0_access_token_1, name="api-key-1") # refresh token is good for one use only - resp = _httpx_client().post( - urljoin(get_base_url(), "/auth/refresh"), - cookies={ - PHOENIX_ACCESS_TOKEN_COOKIE_NAME: browser_0_access_token_0, - PHOENIX_REFRESH_TOKEN_COOKIE_NAME: browser_0_refresh_token_0, - }, + resp = _httpx_client(browser_0_access_token_0, browser_0_refresh_token_0).post( + urljoin(get_base_url(), "auth/refresh"), ) with pytest.raises(HTTPStatusError): resp.raise_for_status() @@ -133,26 +126,13 @@ def test_end_to_end_credentials_flow( _create_system_api_key(browser_0_access_token_0, name="api-key-2") # user logs into second browser - resp = _httpx_client().post( - urljoin(get_base_url(), "/auth/login"), - json={"email": _admin_email, "password": _secret}, - ) - resp.raise_for_status() - assert (browser_1_access_token_0 := resp.cookies.get(PHOENIX_ACCESS_TOKEN_COOKIE_NAME)) - assert resp.cookies.get(PHOENIX_REFRESH_TOKEN_COOKIE_NAME) + browser_1_access_token_0, browser_1_refresh_token_0 = _log_in(_secret, email=_admin_email) # user creates api key in the second browser _create_system_api_key(browser_1_access_token_0, name="api-key-3") # user logs out in first browser - resp = _httpx_client().post( - urljoin(get_base_url(), "/auth/logout"), - cookies={ - PHOENIX_ACCESS_TOKEN_COOKIE_NAME: browser_0_access_token_1, - PHOENIX_REFRESH_TOKEN_COOKIE_NAME: browser_0_refresh_token_1, - }, - ) - resp.raise_for_status() + _log_out(browser_0_access_token_1, browser_0_refresh_token_1) # user is logged out of both browsers with pytest.raises(HTTPStatusError, match="401 Unauthorized"): @@ -267,9 +247,9 @@ def test_only_admin_can_change_role_for_non_self( assert user.gid != non_self.gid (token, *_), gid = user.tokens, non_self.gid with pytest.raises(HTTPStatusError, match="401 Unauthorized"): - _patch_user(None, gid, new_role=UserRoleInput.ADMIN) + _patch_user(gid, new_role=UserRoleInput.ADMIN) with expectation: - _patch_user(token, gid, new_role=UserRoleInput.ADMIN) + _patch_user(gid, token, new_role=UserRoleInput.ADMIN) @pytest.mark.parametrize( "role,expectation", @@ -293,9 +273,9 @@ def test_only_admin_can_change_password_for_non_self( assert new_password != old_password (token, *_), gid = user.tokens, non_self.gid with pytest.raises(HTTPStatusError, match="401 Unauthorized"): - _patch_user(None, gid, new_password=new_password) + _patch_user(gid, new_password=new_password) with expectation as e: - _patch_user(token, gid, new_password=new_password) + _patch_user(gid, token, new_password=new_password) if e: return email = non_self.profile.email @@ -325,9 +305,9 @@ def test_only_admin_can_change_username_for_non_self( assert new_username != old_username (token, *_), gid = user.tokens, non_self.gid with pytest.raises(HTTPStatusError, match="401 Unauthorized"): - _patch_user(None, gid, new_username=new_username) + _patch_user(gid, new_username=new_username) with expectation: - _patch_user(token, gid, new_username=new_username) + _patch_user(gid, token, new_username=new_username) def _create_user_key(token: str) -> str: @@ -469,7 +449,7 @@ def test_headers( expires_at: Optional[datetime], expected: SpanExportResult, _span_exporter: _SpanExporterConstructor, - _admin_token: _Token, + _admin_token: _AccessToken, _fake: Faker, ) -> None: headers: Optional[_Headers] = None @@ -490,5 +470,5 @@ def test_headers( for _ in range(2): assert export(spans) is expected if gid is not None and expected is SpanExportResult.SUCCESS: - _delete_system_api_key(_admin_token, gid) + _delete_system_api_key(gid, _admin_token) assert export(spans) is SpanExportResult.FAILURE From 8d5fc4acc67a721c6479d2c97bbbb4963d297ea2 Mon Sep 17 00:00:00 2001 From: Roger Yang Date: Fri, 6 Sep 2024 16:00:05 -0700 Subject: [PATCH 07/21] clean up --- integration_tests/_helpers.py | 1 + integration_tests/auth/conftest.py | 9 +++++---- integration_tests/auth/test_auth.py | 11 ++++++----- 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/integration_tests/_helpers.py b/integration_tests/_helpers.py index 803bf9bf94..8598ea523d 100644 --- a/integration_tests/_helpers.py +++ b/integration_tests/_helpers.py @@ -60,6 +60,7 @@ def __new__(cls, string: Optional[str] = None) -> _String: return super().__new__(cls, string) +_Secret: TypeAlias = str _Email: TypeAlias = str _Password: TypeAlias = str _Username: TypeAlias = str diff --git a/integration_tests/auth/conftest.py b/integration_tests/auth/conftest.py index 3dd39c6b4c..454070b0de 100644 --- a/integration_tests/auth/conftest.py +++ b/integration_tests/auth/conftest.py @@ -21,6 +21,7 @@ _LoggedInUser, _Password, _Profile, + _Secret, _server, _UserGenerator, _Username, @@ -28,13 +29,13 @@ @pytest.fixture(scope="class") -def _secret() -> str: +def _secret() -> _Secret: return secrets.token_hex(32) @pytest.fixture(autouse=True, scope="class") def _app( - _secret: str, + _secret: _Secret, _env_phoenix_sql_database_url: Any, ) -> Iterator[None]: values = ( @@ -102,8 +103,8 @@ def _(role: UserRoleInput) -> _LoggedInUser: @pytest.fixture def _admin_token( - _admin_email: str, - _secret: str, + _admin_email: _Email, + _secret: _Secret, ) -> Iterator[_AccessToken]: with _log_in(_secret, email=_admin_email) as (token, _): yield token diff --git a/integration_tests/auth/test_auth.py b/integration_tests/auth/test_auth.py index f0afdb39fb..e725a48189 100644 --- a/integration_tests/auth/test_auth.py +++ b/integration_tests/auth/test_auth.py @@ -33,6 +33,7 @@ _patch_viewer, _Profile, _RefreshToken, + _Secret, _SpanExporterConstructor, _start_span, _Username, @@ -44,8 +45,8 @@ class TestTokens: def test_log_in_tokens_should_change( self, - _admin_email: str, - _secret: str, + _admin_email: _Email, + _secret: _Secret, ) -> None: n, access_tokens, refresh_tokens = 2, set(), set() for _ in range(n): @@ -74,7 +75,7 @@ def test_admin( email: _Email, use_secret: bool, expectation: ContextManager[Optional[Unauthorized]], - _secret: str, + _secret: _Secret, _fake: Faker, _passwords: Iterator[_Password], ) -> None: @@ -151,8 +152,8 @@ def test_create_user( self, role: UserRoleInput, expectation: ContextManager[Optional[Unauthorized]], - _admin_email: str, - _secret: str, + _admin_email: _Email, + _secret: _Secret, _fake: Faker, _profiles: Iterator[_Profile], ) -> None: From d885f3ff3d2b2bc56b24ec17a547694040ad3bc2 Mon Sep 17 00:00:00 2001 From: Roger Yang Date: Fri, 6 Sep 2024 18:44:28 -0700 Subject: [PATCH 08/21] clean up --- integration_tests/_helpers.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/integration_tests/_helpers.py b/integration_tests/_helpers.py index 8598ea523d..7a6abe2c51 100644 --- a/integration_tests/_helpers.py +++ b/integration_tests/_helpers.py @@ -86,27 +86,26 @@ class _User: class _Token(_String, ABC): ... -class _AccessToken(_Token): ... - - class _ApiKey(_Token): ... class _RefreshToken(_Token): ... +class _AccessToken(_Token): + def log_out(self) -> None: + _log_out(self) + + class _LoggedInTokens(NamedTuple): access: _AccessToken refresh: _RefreshToken - def log_out(self) -> None: - _log_out(self.access) - def __enter__(self) -> _LoggedInTokens: return self def __exit__(self, *args: Any, **kwargs: Any) -> None: - self.log_out() + self.access.log_out() @dataclass(frozen=True) From 421c7fe77efbd5aaecd8ba441bff28f316069fa2 Mon Sep 17 00:00:00 2001 From: Roger Yang Date: Fri, 6 Sep 2024 19:39:20 -0700 Subject: [PATCH 09/21] clean up --- integration_tests/_helpers.py | 64 +++++++++++---- integration_tests/auth/conftest.py | 30 +++---- integration_tests/auth/test_auth.py | 116 +++++++++++----------------- src/phoenix/auth.py | 2 + src/phoenix/db/facilitator.py | 12 ++- 5 files changed, 115 insertions(+), 109 deletions(-) diff --git a/integration_tests/_helpers.py b/integration_tests/_helpers.py index 7a6abe2c51..119493d851 100644 --- a/integration_tests/_helpers.py +++ b/integration_tests/_helpers.py @@ -6,6 +6,7 @@ from contextlib import contextmanager from dataclasses import dataclass from datetime import datetime +from functools import cached_property from subprocess import PIPE, STDOUT from threading import Lock, Thread from time import sleep, time @@ -27,6 +28,7 @@ import httpx import pytest from faker import Faker +from httpx import HTTPStatusError from openinference.semconv.resource import ResourceAttributes from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.trace import TracerProvider @@ -53,26 +55,24 @@ _Headers: TypeAlias = Dict[str, Any] _Name: TypeAlias = str - -class _String(str, ABC): - def __new__(cls, string: Optional[str] = None) -> _String: - assert string - return super().__new__(cls, string) - - _Secret: TypeAlias = str _Email: TypeAlias = str _Password: TypeAlias = str _Username: TypeAlias = str -@dataclass(frozen=True) -class _Profile: +class _Profile(NamedTuple): email: _Email password: _Password username: Optional[_Username] = None +class _String(str, ABC): + def __new__(cls, string: Optional[str] = None) -> _String: + assert string + return super().__new__(cls, string) + + class _GqlId(_String): ... @@ -82,6 +82,21 @@ class _User: role: UserRoleInput profile: _Profile + def log_in(self) -> _LoggedInTokens: + return _log_in(self.password, email=self.email) + + @cached_property + def password(self) -> _Password: + return self.profile.password + + @cached_property + def email(self) -> _Email: + return self.profile.email + + @cached_property + def username(self) -> Optional[_Username]: + return self.profile.username + class _Token(_String, ABC): ... @@ -89,7 +104,13 @@ class _Token(_String, ABC): ... class _ApiKey(_Token): ... -class _RefreshToken(_Token): ... +class _RefreshToken(_Token): + def __call__(self) -> _LoggedInTokens: + resp = _httpx_client(refresh_token=self).post(urljoin(get_base_url(), "auth/refresh")) + resp.raise_for_status() + access_token = _AccessToken(resp.cookies.get(PHOENIX_ACCESS_TOKEN_COOKIE_NAME)) + refresh_token = _RefreshToken(resp.cookies.get(PHOENIX_REFRESH_TOKEN_COOKIE_NAME)) + return _LoggedInTokens(access_token, refresh_token) class _AccessToken(_Token): @@ -112,6 +133,12 @@ def __exit__(self, *args: Any, **kwargs: Any) -> None: class _LoggedInUser(_User): tokens: _LoggedInTokens + def log_out(self) -> None: + self.tokens.access.log_out() + + def refresh(self) -> _LoggedInTokens: + return self.tokens.refresh() + class _UserGenerator(Protocol): def send(self, role: UserRoleInput) -> _LoggedInUser: ... @@ -241,13 +268,13 @@ def _capture_stdout( @contextmanager -def _random_schema(url: URL, _fake: Faker) -> Iterator[str]: +def _random_schema(url: URL, fake: Faker) -> Iterator[str]: engine = create_engine(url.set(drivername="postgresql+psycopg")) try: engine.connect() except OperationalError as ex: pytest.skip(f"PostgreSQL unavailable: {ex}") - schema = _fake.unique.pystr().lower() + schema = fake.unique.pystr().lower() yield schema with engine.connect() as conn: conn.execute(text(f"DROP SCHEMA IF EXISTS {schema} CASCADE;")) @@ -284,11 +311,12 @@ def _create_user( access_token: Optional[_AccessToken] = None, /, *, - email: _Email, - password: _Password, role: UserRoleInput, - username: Optional[_Username] = None, + profile: _Profile, ) -> _GqlId: + email = profile.email + password = profile.password + username = profile.username args = [f'email:"{email}"', f'password:"{password}"', f"role:{role.value}"] if username: args.append(f'username:"{username}"') @@ -391,7 +419,7 @@ def _log_in( ) -> _LoggedInTokens: resp = _httpx_client().post( urljoin(get_base_url(), "auth/login"), - json={"email": email, "password": password}, + json=dict(email=email, password=password), ) resp.raise_for_status() assert (access_token := resp.cookies.get(PHOENIX_ACCESS_TOKEN_COOKIE_NAME)) @@ -422,3 +450,7 @@ def _json( raise Unauthorized(msg) raise RuntimeError(msg) return resp_dict + + +_EXPECTATION_401 = pytest.raises(HTTPStatusError, match="401 Unauthorized") +_EXPECTATION_403 = pytest.raises(HTTPStatusError, match="403 Forbidden") diff --git a/integration_tests/auth/conftest.py b/integration_tests/auth/conftest.py index 5690bbad74..c29c4cb95b 100644 --- a/integration_tests/auth/conftest.py +++ b/integration_tests/auth/conftest.py @@ -8,7 +8,11 @@ import pytest from faker import Faker -from phoenix.auth import REQUIREMENTS_FOR_PHOENIX_SECRET +from phoenix.auth import ( + DEFAULT_ADMIN_EMAIL, + DEFAULT_ADMIN_PASSWORD, + REQUIREMENTS_FOR_PHOENIX_SECRET, +) from phoenix.config import ENV_PHOENIX_ENABLE_AUTH, ENV_PHOENIX_SECRET from phoenix.server.api.input_types.UserRoleInput import UserRoleInput @@ -66,8 +70,8 @@ def _usernames(_fake: Faker) -> Iterator[_Username]: @pytest.fixture(scope="class") def _profiles( _emails: Iterator[_Email], - _usernames: Iterator[_Password], _passwords: Iterator[_Password], + _usernames: Iterator[_Username], ) -> Iterator[_Profile]: return starmap(_Profile, zip(_emails, _passwords, _usernames)) @@ -76,13 +80,12 @@ def _profiles( def _users( _profiles: Iterator[_Profile], _admin_token: _AccessToken, - _fake: Faker, ) -> _UserGenerator: def _() -> Generator[Optional[_LoggedInUser], UserRoleInput, None]: role = yield None for profile in _profiles: - gid = _create_user(_admin_token, **asdict(profile), role=role) - email, password = profile.email, profile.password + gid = _create_user(_admin_token, profile=profile, role=role) + password, email = profile.password, profile.email tokens = _log_in(password, email=email) role = yield _LoggedInUser(gid=gid, role=role, tokens=tokens, profile=profile) @@ -102,19 +105,6 @@ def _(role: UserRoleInput) -> _LoggedInUser: @pytest.fixture -def _admin_token( - _admin_email: _Email, - _admin_password: _Password, -) -> Iterator[_AccessToken]: - with _log_in(_admin_password, email=_admin_email) as (token, _): +def _admin_token() -> Iterator[_AccessToken]: + with _log_in(DEFAULT_ADMIN_PASSWORD, email=DEFAULT_ADMIN_EMAIL) as (token, _): yield token - - -@pytest.fixture(scope="module") -def _admin_email() -> _Email: - return "admin@localhost" - - -@pytest.fixture(scope="module") -def _admin_password() -> _Password: - return "admin" diff --git a/integration_tests/auth/test_auth.py b/integration_tests/auth/test_auth.py index aa0527dc49..3d53ac111a 100644 --- a/integration_tests/auth/test_auth.py +++ b/integration_tests/auth/test_auth.py @@ -11,12 +11,18 @@ from httpx import HTTPStatusError from opentelemetry.sdk.trace.export import SpanExportResult from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter -from phoenix.auth import PHOENIX_ACCESS_TOKEN_COOKIE_NAME, PHOENIX_REFRESH_TOKEN_COOKIE_NAME +from phoenix.auth import ( + DEFAULT_ADMIN_EMAIL, + DEFAULT_ADMIN_PASSWORD, + PHOENIX_ACCESS_TOKEN_COOKIE_NAME, + PHOENIX_REFRESH_TOKEN_COOKIE_NAME, +) from phoenix.config import get_base_url from phoenix.server.api.exceptions import Unauthorized from phoenix.server.api.input_types.UserRoleInput import UserRoleInput from .._helpers import ( + _EXPECTATION_401, _AccessToken, _create_system_api_key, _create_user, @@ -42,14 +48,11 @@ class TestTokens: - def test_log_in_tokens_should_change( - self, - _admin_email: _Email, - _admin_password: _Password, - ) -> None: + def test_log_in_tokens_should_change(self) -> None: + password, email = DEFAULT_ADMIN_PASSWORD, DEFAULT_ADMIN_EMAIL n, access_tokens, refresh_tokens = 2, set(), set() for _ in range(n): - with _log_in(_admin_password, email=_admin_email) as (access_token, refresh_token): + with _log_in(password, email=email) as (access_token, refresh_token): access_tokens.add(access_token) refresh_tokens.add(refresh_token) assert len(access_tokens) == n @@ -64,9 +67,9 @@ class TestUsers: "email,use_admin_password,expectation", [ ("admin@localhost", True, nullcontext()), - ("admin@localhost", False, pytest.raises(HTTPStatusError, match="401 Unauthorized")), - ("system@localhost", True, pytest.raises(HTTPStatusError, match="401 Unauthorized")), - ("admin", True, pytest.raises(HTTPStatusError, match="401 Unauthorized")), + ("admin@localhost", False, _EXPECTATION_401), + ("system@localhost", True, _EXPECTATION_401), + ("admin", True, _EXPECTATION_401), ], ) def test_admin( @@ -74,29 +77,22 @@ def test_admin( email: _Email, use_admin_password: bool, expectation: ContextManager[Optional[Unauthorized]], - _admin_password: _Password, _fake: Faker, _passwords: Iterator[_Password], ) -> None: - password = _admin_password if use_admin_password else next(_passwords) - with pytest.raises(HTTPStatusError, match="401 Unauthorized"): + password = DEFAULT_ADMIN_PASSWORD if use_admin_password else next(_passwords) + with _EXPECTATION_401: _create_system_api_key(None, name=_fake.unique.pystr()) with expectation: with _log_in(password, email=email) as (token, _): _create_system_api_key(token, name=_fake.unique.pystr()) - with pytest.raises(HTTPStatusError, match="401 Unauthorized"): + with _EXPECTATION_401: _create_system_api_key(token, name=_fake.unique.pystr()) - def test_end_to_end_credentials_flow( - self, - _admin_email: _Email, - _admin_password: _Password, - _fake: Faker, - ) -> None: + def test_end_to_end_credentials_flow(self, _fake: Faker) -> None: + password, email = DEFAULT_ADMIN_PASSWORD, DEFAULT_ADMIN_EMAIL # user logs into first browser - browser_0_access_token_0, browser_0_refresh_token_0 = _log_in( - _admin_password, email=_admin_email - ) + browser_0_access_token_0, browser_0_refresh_token_0 = _log_in(password, email=email) # user creates api key in the first browser _create_system_api_key(browser_0_access_token_0, name="api-key-0") @@ -124,13 +120,11 @@ def test_end_to_end_credentials_flow( resp.raise_for_status() # original access token is invalid after refresh - with pytest.raises(HTTPStatusError, match="401 Unauthorized"): + with _EXPECTATION_401: _create_system_api_key(browser_0_access_token_0, name="api-key-2") # user logs into second browser - browser_1_access_token_0, browser_1_refresh_token_0 = _log_in( - _admin_password, email=_admin_email - ) + browser_1_access_token_0, browser_1_refresh_token_0 = _log_in(password, email=email) # user creates api key in the second browser _create_system_api_key(browser_1_access_token_0, name="api-key-3") @@ -139,9 +133,9 @@ def test_end_to_end_credentials_flow( _log_out(browser_0_access_token_1, browser_0_refresh_token_1) # user is logged out of both browsers - with pytest.raises(HTTPStatusError, match="401 Unauthorized"): + with _EXPECTATION_401: _create_system_api_key(browser_0_access_token_1, name="api-key-4") - with pytest.raises(HTTPStatusError, match="401 Unauthorized"): + with _EXPECTATION_401: _create_system_api_key(browser_1_access_token_0, name="api-key-5") @pytest.mark.parametrize( @@ -155,32 +149,21 @@ def test_create_user( self, role: UserRoleInput, expectation: ContextManager[Optional[Unauthorized]], - _admin_email: _Email, - _admin_password: _Password, _fake: Faker, _profiles: Iterator[_Profile], ) -> None: profile = next(_profiles) - email = profile.email - username = profile.username - password = profile.password - with pytest.raises(HTTPStatusError, match="401 Unauthorized"): - _create_user(email=email, password=password, username=username, role=role) - with _log_in(_admin_password, email=_admin_email) as (token, _): - _create_user(token, email=email, password=password, username=username, role=role) + with _EXPECTATION_401: + _create_user(profile=profile, role=role) + with _log_in(DEFAULT_ADMIN_PASSWORD, email=DEFAULT_ADMIN_EMAIL) as (token, _): + _create_user(token, profile=profile, role=role) + password, email = profile.password, profile.email with _log_in(password, email=email) as (token, _): with expectation: _create_system_api_key(token, name=_fake.unique.pystr()) for _role in UserRoleInput: - _profile = next(_profiles) with expectation: - _create_user( - token, - email=_profile.email, - username=_profile.username, - password=_profile.password, - role=_role, - ) + _create_user(token, profile=next(_profiles), role=_role) @pytest.mark.parametrize("role", list(UserRoleInput)) def test_user_can_change_password_for_self( @@ -190,8 +173,8 @@ def test_user_can_change_password_for_self( _passwords: Iterator[_Password], ) -> None: user = _get_new_user(role) - email = user.profile.email - password = user.profile.password + email = user.email + password = user.password (token, *_) = user.tokens new_password = f"new_password_{next(_passwords)}" assert new_password != password @@ -205,11 +188,11 @@ def test_user_can_change_password_for_self( _log_in(password, email=email) _patch_viewer((old_token := token), (old_password := password), new_password=new_password) another_password = f"another_password_{next(_passwords)}" - with pytest.raises(HTTPStatusError, match="401 Unauthorized"): + with _EXPECTATION_401: _patch_viewer(old_token, new_password, new_password=another_password) - with pytest.raises(HTTPStatusError, match="401 Unauthorized"): + with _EXPECTATION_401: _log_in(old_password, email=email) - new_token, _ = _log_in(new_password, email=email) + new_token, *_ = _log_in(new_password, email=email) with pytest.raises(BaseException): _patch_viewer(new_token, old_password, new_password=another_password) @@ -222,10 +205,10 @@ def test_user_can_change_username_for_self( _passwords: Iterator[_Password], ) -> None: user = _get_new_user(role) - (token, *_), password = user.tokens, user.profile.password + (token, *_), password = user.tokens, user.password new_username = f"new_username_{next(_usernames)}" for _password in (None, password): - with pytest.raises(HTTPStatusError, match="401 Unauthorized"): + with _EXPECTATION_401: _patch_viewer(None, _password, new_username=new_username) _patch_viewer(token, None, new_username=new_username) another_username = f"another_username_{next(_usernames)}" @@ -250,7 +233,7 @@ def test_only_admin_can_change_role_for_non_self( non_self = _get_new_user(UserRoleInput.MEMBER) assert user.gid != non_self.gid (token, *_), gid = user.tokens, non_self.gid - with pytest.raises(HTTPStatusError, match="401 Unauthorized"): + with _EXPECTATION_401: _patch_user(gid, new_role=UserRoleInput.ADMIN) with expectation: _patch_user(gid, token, new_role=UserRoleInput.ADMIN) @@ -272,18 +255,18 @@ def test_only_admin_can_change_password_for_non_self( user = _get_new_user(role) non_self = _get_new_user(UserRoleInput.MEMBER) assert user.gid != non_self.gid - old_password = non_self.profile.password + old_password = non_self.password new_password = f"new_password_{next(_passwords)}" assert new_password != old_password (token, *_), gid = user.tokens, non_self.gid - with pytest.raises(HTTPStatusError, match="401 Unauthorized"): + with _EXPECTATION_401: _patch_user(gid, new_password=new_password) with expectation as e: _patch_user(gid, token, new_password=new_password) if e: return - email = non_self.profile.email - with pytest.raises(HTTPStatusError, match="401 Unauthorized"): + email = non_self.email + with _EXPECTATION_401: _log_in(old_password, email=email) _log_in(new_password, email=email) @@ -304,11 +287,11 @@ def test_only_admin_can_change_username_for_non_self( user = _get_new_user(role) non_self = _get_new_user(UserRoleInput.MEMBER) assert user.gid != non_self.gid - old_username = non_self.profile.username + old_username = non_self.username new_username = f"new_username_{next(_usernames)}" assert new_username != old_username (token, *_), gid = user.tokens, non_self.gid - with pytest.raises(HTTPStatusError, match="401 Unauthorized"): + with _EXPECTATION_401: _patch_user(gid, new_username=new_username) with expectation: _patch_user(gid, token, new_username=new_username) @@ -349,24 +332,17 @@ class TestApiKeys: } """ - def test_delete_user_api_key( - self, - _admin_email: _Email, - _admin_password: _Password, - _passwords: Iterator[_Password], - ) -> None: + def test_delete_user_api_key(self, _passwords: Iterator[_Password]) -> None: member_email = "member@localhost.com" username = "member" member_password = next(_passwords) - with _log_in(_admin_password, email=_admin_email) as (admin_token, _): + with _log_in(DEFAULT_ADMIN_PASSWORD, email=DEFAULT_ADMIN_EMAIL) as (admin_token, _): admin_api_key_id = _create_user_key(admin_token) _create_user( admin_token, - email=member_email, - password=member_password, role=UserRoleInput.MEMBER, - username=username, + profile=_Profile(email=member_email, password=member_password, username=username), ) with _log_in( diff --git a/src/phoenix/auth.py b/src/phoenix/auth.py index 369656c485..b9ab3b15ab 100644 --- a/src/phoenix/auth.py +++ b/src/phoenix/auth.py @@ -176,6 +176,8 @@ def validate( raise ValueError(err_text) +DEFAULT_ADMIN_USERNAME = "admin" +DEFAULT_ADMIN_EMAIL = "admin@localhost" DEFAULT_ADMIN_PASSWORD = "admin" DEFAULT_SECRET_LENGTH = 32 """The default length of a secret key in bytes.""" diff --git a/src/phoenix/db/facilitator.py b/src/phoenix/db/facilitator.py index a21fbcccad..0e61861442 100644 --- a/src/phoenix/db/facilitator.py +++ b/src/phoenix/db/facilitator.py @@ -14,7 +14,13 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.sql.functions import coalesce -from phoenix.auth import DEFAULT_ADMIN_PASSWORD, DEFAULT_SECRET_LENGTH, compute_password_hash +from phoenix.auth import ( + DEFAULT_ADMIN_EMAIL, + DEFAULT_ADMIN_PASSWORD, + DEFAULT_ADMIN_USERNAME, + DEFAULT_SECRET_LENGTH, + compute_password_hash, +) from phoenix.config import ENABLE_AUTH from phoenix.db import models from phoenix.db.enums import COLUMN_ENUMS, AuthMethod, UserRole @@ -93,8 +99,8 @@ async def _ensure_user_roles(session: AsyncSession) -> None: ) is not None: admin_user = models.User( user_role_id=admin_role_id, - username="admin", - email="admin@localhost", + username=DEFAULT_ADMIN_USERNAME, + email=DEFAULT_ADMIN_EMAIL, auth_method=AuthMethod.LOCAL.value, reset_password=True, ) From 65a46516b8f3683c981f106f9b21d06c285e92d6 Mon Sep 17 00:00:00 2001 From: Roger Yang Date: Sun, 8 Sep 2024 15:10:38 -0700 Subject: [PATCH 10/21] add pytest-randomly --- dev-requirements.txt | 1 + integration_tests/_helpers.py | 7 ++++--- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/dev-requirements.txt b/dev-requirements.txt index 9c61eb2d78..b6fc3abd42 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -2,5 +2,6 @@ mypy==1.11.2 ruff==0.6.3 pytest==8.3.2 pytest-xdist==3.6.1 +pytest-randomly==3.15.0 pytest-asyncio==0.23.8 uvloop; platform_system != 'Windows' diff --git a/integration_tests/_helpers.py b/integration_tests/_helpers.py index a6ebdf5564..a8837e545b 100644 --- a/integration_tests/_helpers.py +++ b/integration_tests/_helpers.py @@ -1,6 +1,7 @@ from __future__ import annotations import os +import secrets import sys from abc import ABC, abstractmethod from contextlib import contextmanager, nullcontext @@ -526,9 +527,9 @@ def _random_schema(url: URL, fake: Faker) -> Iterator[str]: engine = create_engine(url.set(drivername="postgresql+psycopg")) try: engine.connect() - except OperationalError as ex: - pytest.skip(f"PostgreSQL unavailable: {ex}") - schema = fake.unique.pystr().lower() + except OperationalError as exc: + pytest.skip(f"PostgreSQL unavailable: {exc}") + schema = "_" + secrets.token_hex(15) yield schema with engine.connect() as conn: conn.execute(text(f"DROP SCHEMA IF EXISTS {schema} CASCADE;")) From 537076314be344346fca121bb5c0eeaaddea23d0 Mon Sep 17 00:00:00 2001 From: Roger Yang Date: Sun, 8 Sep 2024 15:18:09 -0700 Subject: [PATCH 11/21] clean up --- integration_tests/auth/test_auth.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/integration_tests/auth/test_auth.py b/integration_tests/auth/test_auth.py index 763bfeebc1..d00561ce6d 100644 --- a/integration_tests/auth/test_auth.py +++ b/integration_tests/auth/test_auth.py @@ -184,9 +184,9 @@ def test_end_to_end_credentials_flow( with _EXPECTATION_401: doers[0][0].create_api_key() - # user logs into second doers + # user logs into second browser doers[1][0] = u.log_in() - # user creates api key in the second doers + # user creates api key in the second browser doers[1][0].create_api_key() # user logs out in first browser From b26960b86ed1e11b79ec975dc93923732de5c7c5 Mon Sep 17 00:00:00 2001 From: Roger Yang Date: Sun, 8 Sep 2024 15:30:34 -0700 Subject: [PATCH 12/21] clean up --- dev-requirements.txt | 1 - integration_tests/requirements.txt | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/dev-requirements.txt b/dev-requirements.txt index b6fc3abd42..9c61eb2d78 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -2,6 +2,5 @@ mypy==1.11.2 ruff==0.6.3 pytest==8.3.2 pytest-xdist==3.6.1 -pytest-randomly==3.15.0 pytest-asyncio==0.23.8 uvloop; platform_system != 'Windows' diff --git a/integration_tests/requirements.txt b/integration_tests/requirements.txt index b4c0803482..681b6a7077 100644 --- a/integration_tests/requirements.txt +++ b/integration_tests/requirements.txt @@ -4,4 +4,5 @@ openinference-semantic-conventions opentelemetry-sdk portpicker psutil +pytest-randomly types-psutil From 6d3ffdc5dc9ea85fc076016d507a2569757e717d Mon Sep 17 00:00:00 2001 From: Roger Yang Date: Sun, 8 Sep 2024 15:40:16 -0700 Subject: [PATCH 13/21] clean up --- integration_tests/auth/test_auth.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/integration_tests/auth/test_auth.py b/integration_tests/auth/test_auth.py index d00561ce6d..c9a591b9f1 100644 --- a/integration_tests/auth/test_auth.py +++ b/integration_tests/auth/test_auth.py @@ -58,7 +58,7 @@ class TestLogIn: @pytest.mark.parametrize("role_or_user", [_MEMBER, _ADMIN, _CZAR]) - def test_get_user_can_log_in( + def test_can_log_in( self, role_or_user: _RoleOrUser, _get_user: _GetUser, @@ -67,7 +67,7 @@ def test_get_user_can_log_in( u.log_in() @pytest.mark.parametrize("role_or_user", [_MEMBER, _ADMIN, _CZAR]) - def test_get_user_can_log_in_more_than_once_simultaneously( + def test_can_log_in_more_than_once_simultaneously( self, role_or_user: _RoleOrUser, _get_user: _GetUser, @@ -77,7 +77,7 @@ def test_get_user_can_log_in_more_than_once_simultaneously( u.log_in() @pytest.mark.parametrize("role_or_user", [_MEMBER, _ADMIN, _CZAR]) - def test_get_user_cannot_log_in_with_empty_password( + def test_cannot_log_in_with_empty_password( self, role_or_user: _RoleOrUser, _get_user: _GetUser, @@ -87,7 +87,7 @@ def test_get_user_cannot_log_in_with_empty_password( _log_in("", email=u.email) @pytest.mark.parametrize("role_or_user", [_MEMBER, _ADMIN, _CZAR]) - def test_get_user_cannot_log_in_with_wrong_password( + def test_cannot_log_in_with_wrong_password( self, role_or_user: _RoleOrUser, _get_user: _GetUser, @@ -101,7 +101,7 @@ def test_get_user_cannot_log_in_with_wrong_password( class TestLogOut: @pytest.mark.parametrize("role_or_user", [_MEMBER, _ADMIN, _CZAR]) - def test_get_user_can_log_out( + def test_can_log_out( self, role_or_user: _RoleOrUser, _get_user: _GetUser, From 6781fed200cbbc5ed427de47c266e4922b14f116 Mon Sep 17 00:00:00 2001 From: Roger Yang Date: Sun, 8 Sep 2024 15:58:33 -0700 Subject: [PATCH 14/21] clean up --- integration_tests/_helpers.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/integration_tests/_helpers.py b/integration_tests/_helpers.py index a8837e545b..85ab0812a1 100644 --- a/integration_tests/_helpers.py +++ b/integration_tests/_helpers.py @@ -34,7 +34,6 @@ import httpx import pytest -from faker import Faker from httpx import HTTPStatusError from openinference.semconv.resource import ResourceAttributes from opentelemetry.sdk.resources import Resource @@ -88,10 +87,6 @@ class _Profile: class _String(str, ABC): def __new__(cls, obj: Any) -> _String: - """ - - :rtype: object - """ assert obj is not None return super().__new__(cls, str(obj)) @@ -506,7 +501,9 @@ def _server() -> Iterator[None]: print(line, end="") -def _is_alive(process: Popen) -> bool: +def _is_alive( + process: Popen, +) -> bool: return process.is_running() and process.status() != STATUS_ZOMBIE @@ -523,13 +520,15 @@ def _capture_stdout( @contextmanager -def _random_schema(url: URL, fake: Faker) -> Iterator[str]: +def _random_schema( + url: URL, +) -> Iterator[str]: engine = create_engine(url.set(drivername="postgresql+psycopg")) try: engine.connect() except OperationalError as exc: pytest.skip(f"PostgreSQL unavailable: {exc}") - schema = "_" + secrets.token_hex(15) + schema = f"_{secrets.token_hex(15)}" yield schema with engine.connect() as conn: conn.execute(text(f"DROP SCHEMA IF EXISTS {schema} CASCADE;")) @@ -711,7 +710,12 @@ def _delete_api_key( assert resp_dict["data"][field]["apiKeyId"] == gid -def _log_in(password: _Password, /, *, email: _Email) -> _LoggedInTokens: +def _log_in( + password: _Password, + /, + *, + email: _Email, +) -> _LoggedInTokens: json_ = dict(email=email, password=password) resp = _httpx_client().post("auth/login", json=json_) resp.raise_for_status() From d12b3c3e8f542e6d288325779fafabc26b873ebd Mon Sep 17 00:00:00 2001 From: Roger Yang Date: Sun, 8 Sep 2024 16:37:13 -0700 Subject: [PATCH 15/21] clean up --- integration_tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integration_tests/conftest.py b/integration_tests/conftest.py index e6f8cda5c7..4b98c5af2b 100644 --- a/integration_tests/conftest.py +++ b/integration_tests/conftest.py @@ -89,7 +89,7 @@ def _env_phoenix_sql_database_url( values = [(ENV_PHOENIX_SQL_DATABASE_URL, _sql_database_url.render_as_string())] with ExitStack() as stack: if _sql_database_url.get_backend_name().startswith("postgresql"): - schema = stack.enter_context(_random_schema(_sql_database_url, _fake)) + schema = stack.enter_context(_random_schema(_sql_database_url)) values.append((ENV_PHOENIX_SQL_DATABASE_SCHEMA, schema)) stack.enter_context(mock.patch.dict(os.environ, values)) yield From f71444bd1a2b3d458f23c171b9e33ba340be4518 Mon Sep 17 00:00:00 2001 From: Roger Yang Date: Mon, 9 Sep 2024 09:14:26 -0700 Subject: [PATCH 16/21] :feat: role based access control for gql queries --- integration_tests/_helpers.py | 11 +++- integration_tests/auth/test_auth.py | 64 +++++++++++++++++++ src/phoenix/server/api/README.md | 20 ++++++ src/phoenix/server/api/auth.py | 38 ++--------- src/phoenix/server/api/context.py | 10 ++- .../server/api/mutations/api_key_mutations.py | 47 ++------------ .../server/api/mutations/dataset_mutations.py | 16 ++--- .../api/mutations/experiment_mutations.py | 4 +- .../api/mutations/export_events_mutations.py | 6 +- .../server/api/mutations/project_mutations.py | 6 +- .../mutations/span_annotations_mutations.py | 8 +-- .../mutations/trace_annotations_mutations.py | 8 +-- .../server/api/mutations/user_mutations.py | 36 ++--------- src/phoenix/server/api/queries.py | 13 ++-- src/phoenix/server/app.py | 1 + src/phoenix/server/bearer_auth.py | 15 ++++- 16 files changed, 162 insertions(+), 141 deletions(-) create mode 100644 src/phoenix/server/api/README.md diff --git a/integration_tests/_helpers.py b/integration_tests/_helpers.py index 85ab0812a1..aead1a72c3 100644 --- a/integration_tests/_helpers.py +++ b/integration_tests/_helpers.py @@ -55,7 +55,7 @@ get_env_grpc_port, get_env_host, ) -from phoenix.server.api.auth import IsAdmin, IsAuthenticated +from phoenix.server.api.auth import IsAdmin from phoenix.server.api.exceptions import Unauthorized from phoenix.server.api.input_types.UserRoleInput import UserRoleInput from psutil import STATUS_ZOMBIE, Popen @@ -130,6 +130,13 @@ def email(self) -> _Email: def username(self) -> Optional[_Username]: return self.profile.username + def gql( + self, + query: str, + variables: Optional[Mapping[str, Any]] = None, + ) -> Dict[str, Any]: + return _gql(self, query=query, variables=variables) + def create_user( self, role: UserRoleInput = _MEMBER, @@ -739,7 +746,7 @@ def _json( assert (resp_dict := cast(Dict[str, Any], resp.json())) if errers := resp_dict.get("errors"): msg = errers[0]["message"] - if "not auth" in msg or IsAuthenticated.message in msg or IsAdmin.message in msg: + if "not auth" in msg or IsAdmin.message in msg: raise Unauthorized(msg) raise RuntimeError(msg) return resp_dict diff --git a/integration_tests/auth/test_auth.py b/integration_tests/auth/test_auth.py index c9a591b9f1..615cc00236 100644 --- a/integration_tests/auth/test_auth.py +++ b/integration_tests/auth/test_auth.py @@ -580,6 +580,70 @@ def test_only_admin_can_delete_system_api_key( doer.delete_api_key(api_key) +class TestGraphQLQuery: + @pytest.mark.parametrize( + "role_or_user,expectation", + [ + (_MEMBER, _DENIED), + (_ADMIN, _OK), + (_CZAR, _OK), + ], + ) + @pytest.mark.parametrize( + "query", + [ + "query{users{edges{node{id}}}}", + "query{userApiKeys{id}}", + "query{systemApiKeys{id}}", + ], + ) + def test_only_admin_can_list_users_and_api_keys( + self, + role_or_user: _RoleOrUser, + query: str, + expectation: _OK_OR_DENIED, + _get_user: _GetUser, + ) -> None: + u = _get_user(role_or_user) + doer = u.log_in() + with expectation: + doer.gql(query) + + @pytest.mark.parametrize("role_or_user", [_MEMBER, _ADMIN, _CZAR]) + def test_can_query_user_node_for_self( + self, + role_or_user: _RoleOrUser, + _get_user: _GetUser, + ) -> None: + u = _get_user(role_or_user) + doer = u.log_in() + query = 'query{node(id:"' + u.gid + '"){__typename}}' + doer.gql(query) + + @pytest.mark.parametrize( + "role_or_user,expectation", + [ + (_MEMBER, _DENIED), + (_ADMIN, _OK), + (_CZAR, _OK), + ], + ) + @pytest.mark.parametrize("role", list(UserRoleInput)) + def test_only_admin_can_query_user_node_for_non_self( + self, + role_or_user: _RoleOrUser, + role: UserRoleInput, + expectation: _OK_OR_DENIED, + _get_user: _GetUser, + ) -> None: + u = _get_user(role_or_user) + doer = u.log_in() + non_self = _get_user(role) + query = 'query{node(id:"' + non_self.gid + '"){__typename}}' + with expectation: + doer.gql(query) + + class TestSpanExporters: @pytest.mark.parametrize( "with_headers,expires_at,expected", diff --git a/src/phoenix/server/api/README.md b/src/phoenix/server/api/README.md new file mode 100644 index 0000000000..d96a6829cb --- /dev/null +++ b/src/phoenix/server/api/README.md @@ -0,0 +1,20 @@ +# Permission Matrix for GraphQL API + +## Mutations + +| Action | Admin | Member | +|:---------------------------|:-----:|:------:| +| Create System API Keys | Yes | No | +| Create User API Keys | Yes | Yes | +| Delete System API Keys | Yes | No | +| Delete Any User's API Keys | Yes | No | +| Delete Own User API Keys | Yes | Yes | + +## Queries + +| Action | Admin | Member | +|:-------------------------------------|:-----:|:------:| +| List All System API Keys | Yes | No | +| List All User API Keys | Yes | Yes | +| List All Users | Yes | No | +| Fetch Other User's Info, e.g. emails | Yes | No | diff --git a/src/phoenix/server/api/auth.py b/src/phoenix/server/api/auth.py index f24424e1bb..68a788dda9 100644 --- a/src/phoenix/server/api/auth.py +++ b/src/phoenix/server/api/auth.py @@ -4,7 +4,6 @@ from strawberry import Info from strawberry.permission import BasePermission -from phoenix.db import enums from phoenix.server.api.exceptions import Unauthorized from phoenix.server.bearer_auth import PhoenixUser @@ -21,40 +20,13 @@ def has_permission(self, source: Any, info: Info, **kwargs: Any) -> bool: return not info.context.read_only -class IsAuthenticated(Authorization): - message = "User is not authenticated" - - def has_permission(self, source: Any, info: Info, **kwargs: Any) -> bool: - if info.context.token_store is None: - return True - try: - user = info.context.request.user - except AttributeError: - return False - return isinstance(user, PhoenixUser) and user.is_authenticated +MSG_ADMIN_ONLY = "Only admin can perform this action" class IsAdmin(Authorization): - message = "Only admin can perform this action" + message = MSG_ADMIN_ONLY def has_permission(self, source: Any, info: Info, **kwargs: Any) -> bool: - if info.context.token_store is None: - return False - try: - user = info.context.request.user - except AttributeError: - return False - return ( - isinstance(user, PhoenixUser) - and user.is_authenticated - and user.claims is not None - and user.claims.attributes is not None - and user.claims.attributes.user_role == enums.UserRole.ADMIN - ) - - -class HasSecret(BasePermission): - message = "Application secret is not set" - - def has_permission(self, source: Any, info: Info, **kwargs: Any) -> bool: - return info.context.secret is not None + if not info.context.auth_enabled: + return True + return isinstance((user := info.context.user), PhoenixUser) and user.is_admin diff --git a/src/phoenix/server/api/context.py b/src/phoenix/server/api/context.py index a98e637a55..8ed113e57d 100644 --- a/src/phoenix/server/api/context.py +++ b/src/phoenix/server/api/context.py @@ -1,8 +1,8 @@ from asyncio import get_running_loop from dataclasses import dataclass -from functools import partial +from functools import cached_property, partial from pathlib import Path -from typing import Any, Optional +from typing import Any, Optional, cast from starlette.requests import Request as StarletteRequest from starlette.responses import Response as StarletteResponse @@ -42,6 +42,7 @@ UserRolesDataLoader, UsersDataLoader, ) +from phoenix.server.bearer_auth import PhoenixUser from phoenix.server.dml_event import DmlEvent from phoenix.server.types import ( CanGetLastUpdatedAt, @@ -96,6 +97,7 @@ class Context(BaseContext): event_queue: CanPutItem[DmlEvent] = _NoOp() corpus: Optional[Model] = None read_only: bool = False + auth_enabled: bool = False secret: Optional[str] = None token_store: Optional[TokenStore] = None @@ -146,3 +148,7 @@ async def log_out(self, user_id: int) -> None: response = self.get_response() response.delete_cookie(PHOENIX_REFRESH_TOKEN_COOKIE_NAME) response.delete_cookie(PHOENIX_ACCESS_TOKEN_COOKIE_NAME) + + @cached_property + def user(self) -> PhoenixUser: + return cast(PhoenixUser, self.get_request().user) diff --git a/src/phoenix/server/api/mutations/api_key_mutations.py b/src/phoenix/server/api/mutations/api_key_mutations.py index 97738a5b90..d7877c9418 100644 --- a/src/phoenix/server/api/mutations/api_key_mutations.py +++ b/src/phoenix/server/api/mutations/api_key_mutations.py @@ -8,8 +8,7 @@ from strawberry.types import Info from phoenix.db import enums, models -from phoenix.db.models import ApiKey as OrmApiKey -from phoenix.server.api.auth import HasSecret, IsAdmin, IsAuthenticated, IsNotReadOnly +from phoenix.server.api.auth import IsAdmin, IsNotReadOnly from phoenix.server.api.context import Context from phoenix.server.api.exceptions import Unauthorized from phoenix.server.api.queries import Query @@ -59,32 +58,9 @@ class DeleteApiKeyMutationPayload: query: Query -def can_delete_user_key(info: Info[Context, None], key: OrmApiKey) -> bool: - try: - user = info.context.request.user # type: ignore - except AttributeError: - return False - return ( - isinstance(user, PhoenixUser) - and user.claims is not None - and user.claims.attributes is not None - and ( - user.claims.attributes.user_role == enums.UserRole.ADMIN - or int(user.identity) == key.user_id - ) - ) - - @strawberry.type class ApiKeyMutationMixin: - @strawberry.mutation( - permission_classes=[ - IsNotReadOnly, - HasSecret, - IsAuthenticated, - IsAdmin, - ] - ) # type: ignore + @strawberry.mutation(permission_classes=[IsNotReadOnly, IsAdmin]) # type: ignore async def create_system_api_key( self, info: Info[Context, None], input: CreateApiKeyInput ) -> CreateSystemApiKeyMutationPayload: @@ -125,13 +101,7 @@ async def create_system_api_key( query=Query(), ) - @strawberry.mutation( - permission_classes=[ - IsNotReadOnly, - HasSecret, - IsAuthenticated, - ] - ) # type: ignore + @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore async def create_user_api_key( self, info: Info[Context, None], input: CreateUserApiKeyInput ) -> CreateUserApiKeyMutationPayload: @@ -166,7 +136,7 @@ async def create_user_api_key( query=Query(), ) - @strawberry.mutation(permission_classes=[HasSecret, IsAuthenticated, IsAdmin, IsNotReadOnly]) # type: ignore + @strawberry.mutation(permission_classes=[IsNotReadOnly, IsAdmin]) # type: ignore async def delete_system_api_key( self, info: Info[Context, None], input: DeleteApiKeyInput ) -> DeleteApiKeyMutationPayload: @@ -177,12 +147,7 @@ async def delete_system_api_key( await token_store.revoke(ApiKeyId(api_key_id)) return DeleteApiKeyMutationPayload(apiKeyId=input.id, query=Query()) - @strawberry.mutation( - permission_classes=[ - HasSecret, - IsAuthenticated, - ] - ) # type: ignore + @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore async def delete_user_api_key( self, info: Info[Context, None], input: DeleteApiKeyInput ) -> DeleteApiKeyMutationPayload: @@ -196,7 +161,7 @@ async def delete_user_api_key( ) if api_key is None: raise ValueError(f"API key with id {input.id} not found") - if not can_delete_user_key(info, api_key): + if int((user := info.context.user).identity) != api_key.user_id and not user.is_admin: raise Unauthorized("User not authorized to delete") await token_store.revoke(ApiKeyId(api_key_id)) return DeleteApiKeyMutationPayload(apiKeyId=input.id, query=Query()) diff --git a/src/phoenix/server/api/mutations/dataset_mutations.py b/src/phoenix/server/api/mutations/dataset_mutations.py index baea741b4c..ddd4ab4fde 100644 --- a/src/phoenix/server/api/mutations/dataset_mutations.py +++ b/src/phoenix/server/api/mutations/dataset_mutations.py @@ -12,7 +12,7 @@ from phoenix.db import models from phoenix.db.helpers import get_eval_trace_ids_for_datasets, get_project_names_for_datasets -from phoenix.server.api.auth import IsAuthenticated, IsNotReadOnly +from phoenix.server.api.auth import IsNotReadOnly from phoenix.server.api.context import Context from phoenix.server.api.exceptions import BadRequest, NotFound from phoenix.server.api.helpers.dataset_helpers import ( @@ -44,7 +44,7 @@ class DatasetMutationPayload: @strawberry.type class DatasetMutationMixin: - @strawberry.mutation(permission_classes=[IsNotReadOnly, IsAuthenticated]) # type: ignore + @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore async def create_dataset( self, info: Info[Context, None], @@ -67,7 +67,7 @@ async def create_dataset( info.context.event_queue.put(DatasetInsertEvent((dataset.id,))) return DatasetMutationPayload(dataset=to_gql_dataset(dataset)) - @strawberry.mutation(permission_classes=[IsNotReadOnly, IsAuthenticated]) # type: ignore + @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore async def patch_dataset( self, info: Info[Context, None], @@ -96,7 +96,7 @@ async def patch_dataset( info.context.event_queue.put(DatasetInsertEvent((dataset.id,))) return DatasetMutationPayload(dataset=to_gql_dataset(dataset)) - @strawberry.mutation(permission_classes=[IsNotReadOnly, IsAuthenticated]) # type: ignore + @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore async def add_spans_to_dataset( self, info: Info[Context, None], @@ -225,7 +225,7 @@ async def add_spans_to_dataset( info.context.event_queue.put(DatasetInsertEvent((dataset.id,))) return DatasetMutationPayload(dataset=to_gql_dataset(dataset)) - @strawberry.mutation(permission_classes=[IsNotReadOnly, IsAuthenticated]) # type: ignore + @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore async def add_examples_to_dataset( self, info: Info[Context, None], input: AddExamplesToDatasetInput ) -> DatasetMutationPayload: @@ -351,7 +351,7 @@ async def add_examples_to_dataset( info.context.event_queue.put(DatasetInsertEvent((dataset.id,))) return DatasetMutationPayload(dataset=to_gql_dataset(dataset)) - @strawberry.mutation(permission_classes=[IsNotReadOnly, IsAuthenticated]) # type: ignore + @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore async def delete_dataset( self, info: Info[Context, None], @@ -382,7 +382,7 @@ async def delete_dataset( info.context.event_queue.put(DatasetDeleteEvent((dataset.id,))) return DatasetMutationPayload(dataset=to_gql_dataset(dataset)) - @strawberry.mutation(permission_classes=[IsNotReadOnly, IsAuthenticated]) # type: ignore + @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore async def patch_dataset_examples( self, info: Info[Context, None], @@ -474,7 +474,7 @@ async def patch_dataset_examples( info.context.event_queue.put(DatasetInsertEvent((dataset.id,))) return DatasetMutationPayload(dataset=to_gql_dataset(dataset)) - @strawberry.mutation(permission_classes=[IsNotReadOnly, IsAuthenticated]) # type: ignore + @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore async def delete_dataset_examples( self, info: Info[Context, None], input: DeleteDatasetExamplesInput ) -> DatasetMutationPayload: diff --git a/src/phoenix/server/api/mutations/experiment_mutations.py b/src/phoenix/server/api/mutations/experiment_mutations.py index 1372cdad3a..8ebfce4c6c 100644 --- a/src/phoenix/server/api/mutations/experiment_mutations.py +++ b/src/phoenix/server/api/mutations/experiment_mutations.py @@ -8,7 +8,7 @@ from phoenix.db import models from phoenix.db.helpers import get_eval_trace_ids_for_experiments, get_project_names_for_experiments -from phoenix.server.api.auth import IsAuthenticated, IsNotReadOnly +from phoenix.server.api.auth import IsNotReadOnly from phoenix.server.api.context import Context from phoenix.server.api.exceptions import CustomGraphQLError from phoenix.server.api.input_types.DeleteExperimentsInput import DeleteExperimentsInput @@ -25,7 +25,7 @@ class ExperimentMutationPayload: @strawberry.type class ExperimentMutationMixin: - @strawberry.mutation(permission_classes=[IsNotReadOnly, IsAuthenticated]) # type: ignore + @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore async def delete_experiments( self, info: Info[Context, None], diff --git a/src/phoenix/server/api/mutations/export_events_mutations.py b/src/phoenix/server/api/mutations/export_events_mutations.py index c051af65a8..57ad4cb04d 100644 --- a/src/phoenix/server/api/mutations/export_events_mutations.py +++ b/src/phoenix/server/api/mutations/export_events_mutations.py @@ -8,7 +8,7 @@ from strawberry.types import Info import phoenix.core.model_schema as ms -from phoenix.server.api.auth import IsAuthenticated, IsNotReadOnly +from phoenix.server.api.auth import IsNotReadOnly from phoenix.server.api.context import Context from phoenix.server.api.input_types.ClusterInput import ClusterInput from phoenix.server.api.types.Event import parse_event_ids_by_inferences_role, unpack_event_id @@ -19,7 +19,7 @@ @strawberry.type class ExportEventsMutationMixin: @strawberry.mutation( - permission_classes=[IsNotReadOnly, IsAuthenticated], + permission_classes=[IsNotReadOnly], description=( "Given a list of event ids, export the corresponding data subset in Parquet format." " File name is optional, but if specified, should be without file extension. By default" @@ -51,7 +51,7 @@ async def export_events( return ExportedFile(file_name=file_name) @strawberry.mutation( - permission_classes=[IsNotReadOnly, IsAuthenticated], + permission_classes=[IsNotReadOnly], description=( "Given a list of clusters, export the corresponding data subset in Parquet format." " File name is optional, but if specified, should be without file extension. By default" diff --git a/src/phoenix/server/api/mutations/project_mutations.py b/src/phoenix/server/api/mutations/project_mutations.py index aa51b49b86..30d38620f1 100644 --- a/src/phoenix/server/api/mutations/project_mutations.py +++ b/src/phoenix/server/api/mutations/project_mutations.py @@ -6,7 +6,7 @@ from phoenix.config import DEFAULT_PROJECT_NAME from phoenix.db import models -from phoenix.server.api.auth import IsAuthenticated, IsNotReadOnly +from phoenix.server.api.auth import IsNotReadOnly from phoenix.server.api.context import Context from phoenix.server.api.input_types.ClearProjectInput import ClearProjectInput from phoenix.server.api.queries import Query @@ -16,7 +16,7 @@ @strawberry.type class ProjectMutationMixin: - @strawberry.mutation(permission_classes=[IsNotReadOnly, IsAuthenticated]) # type: ignore + @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore async def delete_project(self, info: Info[Context, None], id: GlobalID) -> Query: project_id = from_global_id_with_expected_type(global_id=id, expected_type_name="Project") async with info.context.db() as session: @@ -33,7 +33,7 @@ async def delete_project(self, info: Info[Context, None], id: GlobalID) -> Query info.context.event_queue.put(ProjectDeleteEvent((project_id,))) return Query() - @strawberry.mutation(permission_classes=[IsNotReadOnly, IsAuthenticated]) # type: ignore + @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore async def clear_project(self, info: Info[Context, None], input: ClearProjectInput) -> Query: project_id = from_global_id_with_expected_type( global_id=input.id, expected_type_name="Project" diff --git a/src/phoenix/server/api/mutations/span_annotations_mutations.py b/src/phoenix/server/api/mutations/span_annotations_mutations.py index 95c38c3ba1..f007f73d41 100644 --- a/src/phoenix/server/api/mutations/span_annotations_mutations.py +++ b/src/phoenix/server/api/mutations/span_annotations_mutations.py @@ -6,7 +6,7 @@ from strawberry.types import Info from phoenix.db import models -from phoenix.server.api.auth import IsAuthenticated, IsNotReadOnly +from phoenix.server.api.auth import IsNotReadOnly from phoenix.server.api.context import Context from phoenix.server.api.input_types.CreateSpanAnnotationInput import CreateSpanAnnotationInput from phoenix.server.api.input_types.DeleteAnnotationsInput import DeleteAnnotationsInput @@ -25,7 +25,7 @@ class SpanAnnotationMutationPayload: @strawberry.type class SpanAnnotationMutationMixin: - @strawberry.mutation(permission_classes=[IsNotReadOnly, IsAuthenticated]) # type: ignore + @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore async def create_span_annotations( self, info: Info[Context, None], input: List[CreateSpanAnnotationInput] ) -> SpanAnnotationMutationPayload: @@ -59,7 +59,7 @@ async def create_span_annotations( query=Query(), ) - @strawberry.mutation(permission_classes=[IsNotReadOnly, IsAuthenticated]) # type: ignore + @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore async def patch_span_annotations( self, info: Info[Context, None], input: List[PatchAnnotationInput] ) -> SpanAnnotationMutationPayload: @@ -99,7 +99,7 @@ async def patch_span_annotations( info.context.event_queue.put(SpanAnnotationInsertEvent((span_annotation.id,))) return SpanAnnotationMutationPayload(span_annotations=patched_annotations, query=Query()) - @strawberry.mutation(permission_classes=[IsNotReadOnly, IsAuthenticated]) # type: ignore + @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore async def delete_span_annotations( self, info: Info[Context, None], input: DeleteAnnotationsInput ) -> SpanAnnotationMutationPayload: diff --git a/src/phoenix/server/api/mutations/trace_annotations_mutations.py b/src/phoenix/server/api/mutations/trace_annotations_mutations.py index 3fccc94b29..2aeaca77e1 100644 --- a/src/phoenix/server/api/mutations/trace_annotations_mutations.py +++ b/src/phoenix/server/api/mutations/trace_annotations_mutations.py @@ -6,7 +6,7 @@ from strawberry.types import Info from phoenix.db import models -from phoenix.server.api.auth import IsAuthenticated, IsNotReadOnly +from phoenix.server.api.auth import IsNotReadOnly from phoenix.server.api.context import Context from phoenix.server.api.input_types.CreateTraceAnnotationInput import CreateTraceAnnotationInput from phoenix.server.api.input_types.DeleteAnnotationsInput import DeleteAnnotationsInput @@ -25,7 +25,7 @@ class TraceAnnotationMutationPayload: @strawberry.type class TraceAnnotationMutationMixin: - @strawberry.mutation(permission_classes=[IsNotReadOnly, IsAuthenticated]) # type: ignore + @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore async def create_trace_annotations( self, info: Info[Context, None], input: List[CreateTraceAnnotationInput] ) -> TraceAnnotationMutationPayload: @@ -59,7 +59,7 @@ async def create_trace_annotations( query=Query(), ) - @strawberry.mutation(permission_classes=[IsNotReadOnly, IsAuthenticated]) # type: ignore + @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore async def patch_trace_annotations( self, info: Info[Context, None], input: List[PatchAnnotationInput] ) -> TraceAnnotationMutationPayload: @@ -98,7 +98,7 @@ async def patch_trace_annotations( info.context.event_queue.put(TraceAnnotationInsertEvent((trace_annotation.id,))) return TraceAnnotationMutationPayload(trace_annotations=patched_annotations, query=Query()) - @strawberry.mutation(permission_classes=[IsNotReadOnly, IsAuthenticated]) # type: ignore + @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore async def delete_trace_annotations( self, info: Info[Context, None], input: DeleteAnnotationsInput ) -> TraceAnnotationMutationPayload: diff --git a/src/phoenix/server/api/mutations/user_mutations.py b/src/phoenix/server/api/mutations/user_mutations.py index 8cd6f668a0..7b0c822c7f 100644 --- a/src/phoenix/server/api/mutations/user_mutations.py +++ b/src/phoenix/server/api/mutations/user_mutations.py @@ -20,7 +20,7 @@ validate_password_format, ) from phoenix.db import enums, models -from phoenix.server.api.auth import HasSecret, IsAdmin, IsAuthenticated, IsNotReadOnly +from phoenix.server.api.auth import IsAdmin, IsNotReadOnly from phoenix.server.api.context import Context from phoenix.server.api.exceptions import Conflict, NotFound from phoenix.server.api.input_types.UserRoleInput import UserRoleInput @@ -78,14 +78,7 @@ class UserMutationPayload: @strawberry.type class UserMutationMixin: - @strawberry.mutation( - permission_classes=[ - IsNotReadOnly, - HasSecret, - IsAuthenticated, - IsAdmin, - ] - ) # type: ignore + @strawberry.mutation(permission_classes=[IsNotReadOnly, IsAdmin]) # type: ignore async def create_user( self, info: Info[Context, None], @@ -117,14 +110,7 @@ async def create_user( raise ValueError(_user_operation_error_message(error)) return UserMutationPayload(user=to_gql_user(user)) - @strawberry.mutation( - permission_classes=[ - IsNotReadOnly, - HasSecret, - IsAuthenticated, - IsAdmin, - ] - ) # type: ignore + @strawberry.mutation(permission_classes=[IsNotReadOnly, IsAdmin]) # type: ignore async def patch_user( self, info: Info[Context, None], @@ -165,13 +151,7 @@ async def patch_user( await info.context.log_out(user.id) return UserMutationPayload(user=to_gql_user(user)) - @strawberry.mutation( - permission_classes=[ - IsNotReadOnly, - HasSecret, - IsAuthenticated, - ] - ) # type: ignore + @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore async def patch_viewer( self, info: Info[Context, None], @@ -209,13 +189,7 @@ async def patch_viewer( await info.context.log_out(user.id) return UserMutationPayload(user=to_gql_user(user)) - @strawberry.mutation( - permission_classes=[ - IsNotReadOnly, - IsAuthenticated, - IsAdmin, - ] - ) # type: ignore + @strawberry.mutation(permission_classes=[IsNotReadOnly, IsAdmin]) # type: ignore async def delete_users( self, info: Info[Context, None], diff --git a/src/phoenix/server/api/queries.py b/src/phoenix/server/api/queries.py index ef5b0f21ae..0c52856db8 100644 --- a/src/phoenix/server/api/queries.py +++ b/src/phoenix/server/api/queries.py @@ -33,8 +33,9 @@ Trace as OrmTrace, ) from phoenix.pointcloud.clustering import Hdbscan +from phoenix.server.api.auth import MSG_ADMIN_ONLY, IsAdmin from phoenix.server.api.context import Context -from phoenix.server.api.exceptions import NotFound +from phoenix.server.api.exceptions import NotFound, Unauthorized from phoenix.server.api.helpers import ensure_list from phoenix.server.api.input_types.ClusterInput import ClusterInput from phoenix.server.api.input_types.Coordinates import ( @@ -77,7 +78,7 @@ @strawberry.type class Query: - @strawberry.field + @strawberry.field(permission_classes=[IsAdmin]) # type: ignore async def users( self, info: Info[Context, None], @@ -121,9 +122,8 @@ async def user_roles( for role in roles ] - @strawberry.field + @strawberry.field(permission_classes=[IsAdmin]) # type: ignore async def user_api_keys(self, info: Info[Context, None]) -> List[UserApiKey]: - # TODO(auth): add access control stmt = ( select(models.ApiKey) .join(models.User) @@ -134,9 +134,8 @@ async def user_api_keys(self, info: Info[Context, None]) -> List[UserApiKey]: api_keys = await session.scalars(stmt) return [to_gql_api_key(api_key) for api_key in api_keys] - @strawberry.field + @strawberry.field(permission_classes=[IsAdmin]) # type: ignore async def system_api_keys(self, info: Info[Context, None]) -> List[SystemApiKey]: - # TODO(auth): add access control stmt = ( select(models.ApiKey) .join(models.User) @@ -468,6 +467,8 @@ async def node(self, id: GlobalID, info: Info[Context, None]) -> Node: raise NotFound(f"Unknown experiment run: {id}") return to_gql_experiment_run(run) elif type_name == User.__name__: + if int((user := info.context.user).identity) != node_id and not user.is_admin: + raise Unauthorized(MSG_ADMIN_ONLY) async with info.context.db() as session: if not ( user := await session.scalar( diff --git a/src/phoenix/server/app.py b/src/phoenix/server/app.py index 9a71aa1c32..b6f476aad4 100644 --- a/src/phoenix/server/app.py +++ b/src/phoenix/server/app.py @@ -539,6 +539,7 @@ def get_context() -> Context: ), cache_for_dataloaders=cache_for_dataloaders, read_only=read_only, + auth_enabled=authentication_enabled, secret=secret, token_store=token_store, ) diff --git a/src/phoenix/server/bearer_auth.py b/src/phoenix/server/bearer_auth.py index 613a44fb9c..4a986f12da 100644 --- a/src/phoenix/server/bearer_auth.py +++ b/src/phoenix/server/bearer_auth.py @@ -1,4 +1,5 @@ from abc import ABC +from functools import cached_property from typing import Any, Awaitable, Callable, Optional, Tuple import grpc @@ -16,6 +17,7 @@ ClaimSetStatus, Token, ) +from phoenix.db import enums from phoenix.server.types import AccessTokenClaims, ApiKeyClaims, UserClaimSet, UserId @@ -50,12 +52,21 @@ class PhoenixUser(BaseUser): def __init__(self, user_id: UserId, claims: UserClaimSet) -> None: self._user_id = user_id self.claims = claims + assert claims.attributes + self._is_admin = ( + claims.status is ClaimSetStatus.VALID + and claims.attributes.user_role == enums.UserRole.ADMIN + ) - @property + @cached_property + def is_admin(self) -> bool: + return self._is_admin + + @cached_property def identity(self) -> UserId: return self._user_id - @property + @cached_property def is_authenticated(self) -> bool: return True From 364236977428bac0d119cd265b8ed57ed3efe71a Mon Sep 17 00:00:00 2001 From: Roger Yang Date: Mon, 9 Sep 2024 09:18:02 -0700 Subject: [PATCH 17/21] clean up --- src/phoenix/server/api/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/phoenix/server/api/README.md b/src/phoenix/server/api/README.md index d96a6829cb..5b3b3798e8 100644 --- a/src/phoenix/server/api/README.md +++ b/src/phoenix/server/api/README.md @@ -15,6 +15,6 @@ | Action | Admin | Member | |:-------------------------------------|:-----:|:------:| | List All System API Keys | Yes | No | -| List All User API Keys | Yes | Yes | +| List All User API Keys | Yes | No | | List All Users | Yes | No | | Fetch Other User's Info, e.g. emails | Yes | No | From 81a6840fc3a97ce7d38aa9d2371e050d0917fb01 Mon Sep 17 00:00:00 2001 From: Roger Yang Date: Mon, 9 Sep 2024 11:52:20 -0700 Subject: [PATCH 18/21] clean up --- integration_tests/auth/test_auth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integration_tests/auth/test_auth.py b/integration_tests/auth/test_auth.py index 38dbc9f0b9..ce21492dfb 100644 --- a/integration_tests/auth/test_auth.py +++ b/integration_tests/auth/test_auth.py @@ -676,5 +676,5 @@ def test_headers( for _ in range(2): assert export(spans) is expected if api_key and expected is SpanExportResult.SUCCESS: - _DEFAULT_ADMIN.log_in().delete_api_key(api_key) + _DEFAULT_ADMIN.delete_api_key(api_key) assert export(spans) is SpanExportResult.FAILURE From f76bdd549a8b4ccd7e2e1e1808535996adaa1065 Mon Sep 17 00:00:00 2001 From: Roger Yang Date: Mon, 9 Sep 2024 12:41:36 -0700 Subject: [PATCH 19/21] clean up --- src/phoenix/server/api/README.md | 8 ++++++++ src/phoenix/server/api/auth.py | 2 +- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/src/phoenix/server/api/README.md b/src/phoenix/server/api/README.md index 5b3b3798e8..443eb0566e 100644 --- a/src/phoenix/server/api/README.md +++ b/src/phoenix/server/api/README.md @@ -4,6 +4,14 @@ | Action | Admin | Member | |:---------------------------|:-----:|:------:| +| Create User | Yes | No | +| Delete User | Yes | No | +| Change Own Password | Yes | Yes | +| Change Other's Password | Yes | No | +| Change Own Username | Yes | Yes | +| Change Other's Username | Yes | No | +| Change Own Email | No | No | +| Change Other's Email | No | No | | Create System API Keys | Yes | No | | Create User API Keys | Yes | Yes | | Delete System API Keys | Yes | No | diff --git a/src/phoenix/server/api/auth.py b/src/phoenix/server/api/auth.py index 68a788dda9..2e937dd0c2 100644 --- a/src/phoenix/server/api/auth.py +++ b/src/phoenix/server/api/auth.py @@ -28,5 +28,5 @@ class IsAdmin(Authorization): def has_permission(self, source: Any, info: Info, **kwargs: Any) -> bool: if not info.context.auth_enabled: - return True + return False return isinstance((user := info.context.user), PhoenixUser) and user.is_admin From 6c1624dd83cb6c75b0c09d2a26ad31696f12abe3 Mon Sep 17 00:00:00 2001 From: Roger Yang Date: Mon, 9 Sep 2024 12:43:02 -0700 Subject: [PATCH 20/21] clean up --- integration_tests/auth/test_auth.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/integration_tests/auth/test_auth.py b/integration_tests/auth/test_auth.py index ce21492dfb..359e78c516 100644 --- a/integration_tests/auth/test_auth.py +++ b/integration_tests/auth/test_auth.py @@ -605,9 +605,9 @@ def test_only_admin_can_list_users_and_api_keys( _get_user: _GetUser, ) -> None: u = _get_user(role_or_user) - doer = u.log_in() + logged_in_user = u.log_in() with expectation: - doer.gql(query) + logged_in_user.gql(query) @pytest.mark.parametrize("role_or_user", [_MEMBER, _ADMIN, _DEFAULT_ADMIN]) def test_can_query_user_node_for_self( @@ -616,9 +616,9 @@ def test_can_query_user_node_for_self( _get_user: _GetUser, ) -> None: u = _get_user(role_or_user) - doer = u.log_in() + logged_in_user = u.log_in() query = 'query{node(id:"' + u.gid + '"){__typename}}' - doer.gql(query) + logged_in_user.gql(query) @pytest.mark.parametrize( "role_or_user,expectation", @@ -637,11 +637,11 @@ def test_only_admin_can_query_user_node_for_non_self( _get_user: _GetUser, ) -> None: u = _get_user(role_or_user) - doer = u.log_in() + logged_in_user = u.log_in() non_self = _get_user(role) query = 'query{node(id:"' + non_self.gid + '"){__typename}}' with expectation: - doer.gql(query) + logged_in_user.gql(query) class TestSpanExporters: From 5d6acdf60dacac9d931c0ac4607fc61b1726b21a Mon Sep 17 00:00:00 2001 From: Roger Yang Date: Mon, 9 Sep 2024 12:45:38 -0700 Subject: [PATCH 21/21] clean up --- src/phoenix/server/api/README.md | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/src/phoenix/server/api/README.md b/src/phoenix/server/api/README.md index 443eb0566e..a646a42c71 100644 --- a/src/phoenix/server/api/README.md +++ b/src/phoenix/server/api/README.md @@ -2,21 +2,21 @@ ## Mutations -| Action | Admin | Member | -|:---------------------------|:-----:|:------:| -| Create User | Yes | No | -| Delete User | Yes | No | -| Change Own Password | Yes | Yes | -| Change Other's Password | Yes | No | -| Change Own Username | Yes | Yes | -| Change Other's Username | Yes | No | -| Change Own Email | No | No | -| Change Other's Email | No | No | -| Create System API Keys | Yes | No | -| Create User API Keys | Yes | Yes | -| Delete System API Keys | Yes | No | -| Delete Any User's API Keys | Yes | No | -| Delete Own User API Keys | Yes | Yes | +| Action | Admin | Member | +|:-----------------------------|:-----:|:------:| +| Create User | Yes | No | +| Delete User | Yes | No | +| Change Own Password | Yes | Yes | +| Change Other's Password | Yes | No | +| Change Own Username | Yes | Yes | +| Change Other's Username | Yes | No | +| Change Own Email | No | No | +| Change Other's Email | No | No | +| Create System API Keys | Yes | No | +| Delete System API Keys | Yes | No | +| Create Own User API Keys | Yes | Yes | +| Delete Own User API Keys | Yes | Yes | +| Delete Other's User API Keys | Yes | No | ## Queries