Skip to content

Commit

Permalink
feat: role based access control for gql queries (#4554)
Browse files Browse the repository at this point in the history
  • Loading branch information
RogerHYang authored Sep 9, 2024
1 parent ad5f9cf commit e079369
Show file tree
Hide file tree
Showing 16 changed files with 186 additions and 157 deletions.
11 changes: 9 additions & 2 deletions integration_tests/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
get_env_grpc_port,
get_env_host,
)
from phoenix.server.api.auth import IsAdmin, IsAuthenticated
from phoenix.server.api.auth import IsAdmin
from phoenix.server.api.exceptions import Unauthorized
from phoenix.server.api.input_types.UserRoleInput import UserRoleInput
from psutil import STATUS_ZOMBIE, Popen
Expand Down Expand Up @@ -130,6 +130,13 @@ def email(self) -> _Email:
def username(self) -> Optional[_Username]:
return self.profile.username

def gql(
self,
query: str,
variables: Optional[Mapping[str, Any]] = None,
) -> Dict[str, Any]:
return _gql(self, query=query, variables=variables)

def create_user(
self,
role: UserRoleInput = _MEMBER,
Expand Down Expand Up @@ -739,7 +746,7 @@ def _json(
assert (resp_dict := cast(Dict[str, Any], resp.json()))
if errers := resp_dict.get("errors"):
msg = errers[0]["message"]
if "not auth" in msg or IsAuthenticated.message in msg or IsAdmin.message in msg:
if "not auth" in msg or IsAdmin.message in msg:
raise Unauthorized(msg)
raise RuntimeError(msg)
return resp_dict
Expand Down
98 changes: 81 additions & 17 deletions integration_tests/auth/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,11 @@ def test_can_log_out(
_get_user: _GetUser,
) -> None:
u = _get_user(role_or_user)
doers = [u.log_in() for _ in range(2)]
for logged_in_user in doers:
logged_in_users = [u.log_in() for _ in range(2)]
for logged_in_user in logged_in_users:
logged_in_user.create_api_key()
doers[0].log_out()
for logged_in_user in doers:
logged_in_users[0].log_out()
for logged_in_user in logged_in_users:
with _EXPECTATION_401:
logged_in_user.create_api_key()

Expand Down Expand Up @@ -167,35 +167,35 @@ def test_end_to_end_credentials_flow(
_get_user: _GetUser,
) -> None:
u = _get_user(role_or_user)
doers: DefaultDict[int, Dict[int, _LoggedInUser]] = defaultdict(dict)
logged_in_users: DefaultDict[int, Dict[int, _LoggedInUser]] = defaultdict(dict)

# user logs into first browser
doers[0][0] = u.log_in()
logged_in_users[0][0] = u.log_in()
# user creates api key in the first browser
doers[0][0].create_api_key()
logged_in_users[0][0].create_api_key()
# tokens are refreshed in the first browser
doers[0][1] = doers[0][0].refresh()
logged_in_users[0][1] = logged_in_users[0][0].refresh()
# user creates api key in the first browser
doers[0][1].create_api_key()
logged_in_users[0][1].create_api_key()
# refresh token is good for one use only
with pytest.raises(HTTPStatusError):
doers[0][0].refresh()
logged_in_users[0][0].refresh()
# original access token is invalid after refresh
with _EXPECTATION_401:
doers[0][0].create_api_key()
logged_in_users[0][0].create_api_key()

# user logs into second browser
doers[1][0] = u.log_in()
logged_in_users[1][0] = u.log_in()
# user creates api key in the second browser
doers[1][0].create_api_key()
logged_in_users[1][0].create_api_key()

# user logs out in first browser
doers[0][1].log_out()
logged_in_users[0][1].log_out()
# user is logged out of both browsers
with _EXPECTATION_401:
doers[0][1].create_api_key()
logged_in_users[0][1].create_api_key()
with _EXPECTATION_401:
doers[1][0].create_api_key()
logged_in_users[1][0].create_api_key()


class TestCreateUser:
Expand Down Expand Up @@ -580,6 +580,70 @@ def test_only_admin_can_delete_system_api_key(
logged_in_user.delete_api_key(api_key)


class TestGraphQLQuery:
@pytest.mark.parametrize(
"role_or_user,expectation",
[
(_MEMBER, _DENIED),
(_ADMIN, _OK),
(_DEFAULT_ADMIN, _OK),
],
)
@pytest.mark.parametrize(
"query",
[
"query{users{edges{node{id}}}}",
"query{userApiKeys{id}}",
"query{systemApiKeys{id}}",
],
)
def test_only_admin_can_list_users_and_api_keys(
self,
role_or_user: _RoleOrUser,
query: str,
expectation: _OK_OR_DENIED,
_get_user: _GetUser,
) -> None:
u = _get_user(role_or_user)
logged_in_user = u.log_in()
with expectation:
logged_in_user.gql(query)

@pytest.mark.parametrize("role_or_user", [_MEMBER, _ADMIN, _DEFAULT_ADMIN])
def test_can_query_user_node_for_self(
self,
role_or_user: _RoleOrUser,
_get_user: _GetUser,
) -> None:
u = _get_user(role_or_user)
logged_in_user = u.log_in()
query = 'query{node(id:"' + u.gid + '"){__typename}}'
logged_in_user.gql(query)

@pytest.mark.parametrize(
"role_or_user,expectation",
[
(_MEMBER, _DENIED),
(_ADMIN, _OK),
(_DEFAULT_ADMIN, _OK),
],
)
@pytest.mark.parametrize("role", list(UserRoleInput))
def test_only_admin_can_query_user_node_for_non_self(
self,
role_or_user: _RoleOrUser,
role: UserRoleInput,
expectation: _OK_OR_DENIED,
_get_user: _GetUser,
) -> None:
u = _get_user(role_or_user)
logged_in_user = u.log_in()
non_self = _get_user(role)
query = 'query{node(id:"' + non_self.gid + '"){__typename}}'
with expectation:
logged_in_user.gql(query)


class TestSpanExporters:
@pytest.mark.parametrize(
"with_headers,expires_at,expected",
Expand Down Expand Up @@ -612,5 +676,5 @@ def test_headers(
for _ in range(2):
assert export(spans) is expected
if api_key and expected is SpanExportResult.SUCCESS:
_DEFAULT_ADMIN.log_in().delete_api_key(api_key)
_DEFAULT_ADMIN.delete_api_key(api_key)
assert export(spans) is SpanExportResult.FAILURE
28 changes: 28 additions & 0 deletions src/phoenix/server/api/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Permission Matrix for GraphQL API

## Mutations

| Action | Admin | Member |
|:-----------------------------|:-----:|:------:|
| Create User | Yes | No |
| Delete User | Yes | No |
| Change Own Password | Yes | Yes |
| Change Other's Password | Yes | No |
| Change Own Username | Yes | Yes |
| Change Other's Username | Yes | No |
| Change Own Email | No | No |
| Change Other's Email | No | No |
| Create System API Keys | Yes | No |
| Delete System API Keys | Yes | No |
| Create Own User API Keys | Yes | Yes |
| Delete Own User API Keys | Yes | Yes |
| Delete Other's User API Keys | Yes | No |

## Queries

| Action | Admin | Member |
|:-------------------------------------|:-----:|:------:|
| List All System API Keys | Yes | No |
| List All User API Keys | Yes | No |
| List All Users | Yes | No |
| Fetch Other User's Info, e.g. emails | Yes | No |
36 changes: 4 additions & 32 deletions src/phoenix/server/api/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from strawberry import Info
from strawberry.permission import BasePermission

from phoenix.db import enums
from phoenix.server.api.exceptions import Unauthorized
from phoenix.server.bearer_auth import PhoenixUser

Expand All @@ -21,40 +20,13 @@ def has_permission(self, source: Any, info: Info, **kwargs: Any) -> bool:
return not info.context.read_only


class IsAuthenticated(Authorization):
message = "User is not authenticated"

def has_permission(self, source: Any, info: Info, **kwargs: Any) -> bool:
if info.context.token_store is None:
return True
try:
user = info.context.request.user
except AttributeError:
return False
return isinstance(user, PhoenixUser) and user.is_authenticated
MSG_ADMIN_ONLY = "Only admin can perform this action"


class IsAdmin(Authorization):
message = "Only admin can perform this action"
message = MSG_ADMIN_ONLY

def has_permission(self, source: Any, info: Info, **kwargs: Any) -> bool:
if info.context.token_store is None:
return False
try:
user = info.context.request.user
except AttributeError:
if not info.context.auth_enabled:
return False
return (
isinstance(user, PhoenixUser)
and user.is_authenticated
and user.claims is not None
and user.claims.attributes is not None
and user.claims.attributes.user_role == enums.UserRole.ADMIN
)


class HasSecret(BasePermission):
message = "Application secret is not set"

def has_permission(self, source: Any, info: Info, **kwargs: Any) -> bool:
return info.context.secret is not None
return isinstance((user := info.context.user), PhoenixUser) and user.is_admin
10 changes: 8 additions & 2 deletions src/phoenix/server/api/context.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from asyncio import get_running_loop
from dataclasses import dataclass
from functools import partial
from functools import cached_property, partial
from pathlib import Path
from typing import Any, Optional
from typing import Any, Optional, cast

from starlette.requests import Request as StarletteRequest
from starlette.responses import Response as StarletteResponse
Expand Down Expand Up @@ -42,6 +42,7 @@
UserRolesDataLoader,
UsersDataLoader,
)
from phoenix.server.bearer_auth import PhoenixUser
from phoenix.server.dml_event import DmlEvent
from phoenix.server.types import (
CanGetLastUpdatedAt,
Expand Down Expand Up @@ -96,6 +97,7 @@ class Context(BaseContext):
event_queue: CanPutItem[DmlEvent] = _NoOp()
corpus: Optional[Model] = None
read_only: bool = False
auth_enabled: bool = False
secret: Optional[str] = None
token_store: Optional[TokenStore] = None

Expand Down Expand Up @@ -146,3 +148,7 @@ async def log_out(self, user_id: int) -> None:
response = self.get_response()
response.delete_cookie(PHOENIX_REFRESH_TOKEN_COOKIE_NAME)
response.delete_cookie(PHOENIX_ACCESS_TOKEN_COOKIE_NAME)

@cached_property
def user(self) -> PhoenixUser:
return cast(PhoenixUser, self.get_request().user)
47 changes: 6 additions & 41 deletions src/phoenix/server/api/mutations/api_key_mutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
from strawberry.types import Info

from phoenix.db import enums, models
from phoenix.db.models import ApiKey as OrmApiKey
from phoenix.server.api.auth import HasSecret, IsAdmin, IsAuthenticated, IsNotReadOnly
from phoenix.server.api.auth import IsAdmin, IsNotReadOnly
from phoenix.server.api.context import Context
from phoenix.server.api.exceptions import Unauthorized
from phoenix.server.api.queries import Query
Expand Down Expand Up @@ -59,32 +58,9 @@ class DeleteApiKeyMutationPayload:
query: Query


def can_delete_user_key(info: Info[Context, None], key: OrmApiKey) -> bool:
try:
user = info.context.request.user # type: ignore
except AttributeError:
return False
return (
isinstance(user, PhoenixUser)
and user.claims is not None
and user.claims.attributes is not None
and (
user.claims.attributes.user_role == enums.UserRole.ADMIN
or int(user.identity) == key.user_id
)
)


@strawberry.type
class ApiKeyMutationMixin:
@strawberry.mutation(
permission_classes=[
IsNotReadOnly,
HasSecret,
IsAuthenticated,
IsAdmin,
]
) # type: ignore
@strawberry.mutation(permission_classes=[IsNotReadOnly, IsAdmin]) # type: ignore
async def create_system_api_key(
self, info: Info[Context, None], input: CreateApiKeyInput
) -> CreateSystemApiKeyMutationPayload:
Expand Down Expand Up @@ -125,13 +101,7 @@ async def create_system_api_key(
query=Query(),
)

@strawberry.mutation(
permission_classes=[
IsNotReadOnly,
HasSecret,
IsAuthenticated,
]
) # type: ignore
@strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore
async def create_user_api_key(
self, info: Info[Context, None], input: CreateUserApiKeyInput
) -> CreateUserApiKeyMutationPayload:
Expand Down Expand Up @@ -166,7 +136,7 @@ async def create_user_api_key(
query=Query(),
)

@strawberry.mutation(permission_classes=[HasSecret, IsAuthenticated, IsAdmin, IsNotReadOnly]) # type: ignore
@strawberry.mutation(permission_classes=[IsNotReadOnly, IsAdmin]) # type: ignore
async def delete_system_api_key(
self, info: Info[Context, None], input: DeleteApiKeyInput
) -> DeleteApiKeyMutationPayload:
Expand All @@ -177,12 +147,7 @@ async def delete_system_api_key(
await token_store.revoke(ApiKeyId(api_key_id))
return DeleteApiKeyMutationPayload(apiKeyId=input.id, query=Query())

@strawberry.mutation(
permission_classes=[
HasSecret,
IsAuthenticated,
]
) # type: ignore
@strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore
async def delete_user_api_key(
self, info: Info[Context, None], input: DeleteApiKeyInput
) -> DeleteApiKeyMutationPayload:
Expand All @@ -196,7 +161,7 @@ async def delete_user_api_key(
)
if api_key is None:
raise ValueError(f"API key with id {input.id} not found")
if not can_delete_user_key(info, api_key):
if int((user := info.context.user).identity) != api_key.user_id and not user.is_admin:
raise Unauthorized("User not authorized to delete")
await token_store.revoke(ApiKeyId(api_key_id))
return DeleteApiKeyMutationPayload(apiKeyId=input.id, query=Query())
Loading

0 comments on commit e079369

Please sign in to comment.