Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: role based access control for gql queries #4554

Merged
merged 24 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions integration_tests/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
get_env_grpc_port,
get_env_host,
)
from phoenix.server.api.auth import IsAdmin, IsAuthenticated
from phoenix.server.api.auth import IsAdmin
from phoenix.server.api.exceptions import Unauthorized
from phoenix.server.api.input_types.UserRoleInput import UserRoleInput
from psutil import STATUS_ZOMBIE, Popen
Expand Down Expand Up @@ -130,6 +130,13 @@ def email(self) -> _Email:
def username(self) -> Optional[_Username]:
return self.profile.username

def gql(
self,
query: str,
variables: Optional[Mapping[str, Any]] = None,
) -> Dict[str, Any]:
return _gql(self, query=query, variables=variables)
axiomofjoy marked this conversation as resolved.
Show resolved Hide resolved

def create_user(
self,
role: UserRoleInput = _MEMBER,
Expand Down Expand Up @@ -739,7 +746,7 @@ def _json(
assert (resp_dict := cast(Dict[str, Any], resp.json()))
if errers := resp_dict.get("errors"):
msg = errers[0]["message"]
if "not auth" in msg or IsAuthenticated.message in msg or IsAdmin.message in msg:
if "not auth" in msg or IsAdmin.message in msg:
raise Unauthorized(msg)
raise RuntimeError(msg)
return resp_dict
Expand Down
98 changes: 81 additions & 17 deletions integration_tests/auth/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,11 @@ def test_can_log_out(
_get_user: _GetUser,
) -> None:
u = _get_user(role_or_user)
doers = [u.log_in() for _ in range(2)]
for logged_in_user in doers:
logged_in_users = [u.log_in() for _ in range(2)]
for logged_in_user in logged_in_users:
logged_in_user.create_api_key()
doers[0].log_out()
for logged_in_user in doers:
logged_in_users[0].log_out()
for logged_in_user in logged_in_users:
with _EXPECTATION_401:
logged_in_user.create_api_key()

Expand Down Expand Up @@ -167,35 +167,35 @@ def test_end_to_end_credentials_flow(
_get_user: _GetUser,
) -> None:
u = _get_user(role_or_user)
doers: DefaultDict[int, Dict[int, _LoggedInUser]] = defaultdict(dict)
logged_in_users: DefaultDict[int, Dict[int, _LoggedInUser]] = defaultdict(dict)

# user logs into first browser
doers[0][0] = u.log_in()
logged_in_users[0][0] = u.log_in()
# user creates api key in the first browser
doers[0][0].create_api_key()
logged_in_users[0][0].create_api_key()
# tokens are refreshed in the first browser
doers[0][1] = doers[0][0].refresh()
logged_in_users[0][1] = logged_in_users[0][0].refresh()
# user creates api key in the first browser
doers[0][1].create_api_key()
logged_in_users[0][1].create_api_key()
# refresh token is good for one use only
with pytest.raises(HTTPStatusError):
doers[0][0].refresh()
logged_in_users[0][0].refresh()
# original access token is invalid after refresh
with _EXPECTATION_401:
doers[0][0].create_api_key()
logged_in_users[0][0].create_api_key()

# user logs into second browser
doers[1][0] = u.log_in()
logged_in_users[1][0] = u.log_in()
# user creates api key in the second browser
doers[1][0].create_api_key()
logged_in_users[1][0].create_api_key()

# user logs out in first browser
doers[0][1].log_out()
logged_in_users[0][1].log_out()
# user is logged out of both browsers
with _EXPECTATION_401:
doers[0][1].create_api_key()
logged_in_users[0][1].create_api_key()
with _EXPECTATION_401:
doers[1][0].create_api_key()
logged_in_users[1][0].create_api_key()


class TestCreateUser:
Expand Down Expand Up @@ -580,6 +580,70 @@ def test_only_admin_can_delete_system_api_key(
logged_in_user.delete_api_key(api_key)


class TestGraphQLQuery:
@pytest.mark.parametrize(
"role_or_user,expectation",
[
(_MEMBER, _DENIED),
(_ADMIN, _OK),
(_DEFAULT_ADMIN, _OK),
],
)
@pytest.mark.parametrize(
"query",
[
"query{users{edges{node{id}}}}",
"query{userApiKeys{id}}",
"query{systemApiKeys{id}}",
],
)
def test_only_admin_can_list_users_and_api_keys(
self,
role_or_user: _RoleOrUser,
query: str,
expectation: _OK_OR_DENIED,
_get_user: _GetUser,
) -> None:
u = _get_user(role_or_user)
logged_in_user = u.log_in()
with expectation:
logged_in_user.gql(query)

@pytest.mark.parametrize("role_or_user", [_MEMBER, _ADMIN, _DEFAULT_ADMIN])
def test_can_query_user_node_for_self(
self,
role_or_user: _RoleOrUser,
_get_user: _GetUser,
) -> None:
u = _get_user(role_or_user)
logged_in_user = u.log_in()
query = 'query{node(id:"' + u.gid + '"){__typename}}'
logged_in_user.gql(query)

@pytest.mark.parametrize(
"role_or_user,expectation",
[
(_MEMBER, _DENIED),
(_ADMIN, _OK),
(_DEFAULT_ADMIN, _OK),
],
)
@pytest.mark.parametrize("role", list(UserRoleInput))
def test_only_admin_can_query_user_node_for_non_self(
self,
role_or_user: _RoleOrUser,
role: UserRoleInput,
expectation: _OK_OR_DENIED,
_get_user: _GetUser,
) -> None:
u = _get_user(role_or_user)
logged_in_user = u.log_in()
non_self = _get_user(role)
query = 'query{node(id:"' + non_self.gid + '"){__typename}}'
with expectation:
logged_in_user.gql(query)


class TestSpanExporters:
@pytest.mark.parametrize(
"with_headers,expires_at,expected",
Expand Down Expand Up @@ -612,5 +676,5 @@ def test_headers(
for _ in range(2):
assert export(spans) is expected
if api_key and expected is SpanExportResult.SUCCESS:
_DEFAULT_ADMIN.log_in().delete_api_key(api_key)
_DEFAULT_ADMIN.delete_api_key(api_key)
assert export(spans) is SpanExportResult.FAILURE
28 changes: 28 additions & 0 deletions src/phoenix/server/api/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Permission Matrix for GraphQL API

## Mutations

| Action | Admin | Member |
|:-----------------------------|:-----:|:------:|
| Create User | Yes | No |
| Delete User | Yes | No |
| Change Own Password | Yes | Yes |
| Change Other's Password | Yes | No |
| Change Own Username | Yes | Yes |
| Change Other's Username | Yes | No |
| Change Own Email | No | No |
| Change Other's Email | No | No |
| Create System API Keys | Yes | No |
| Delete System API Keys | Yes | No |
| Create Own User API Keys | Yes | Yes |
| Delete Own User API Keys | Yes | Yes |
| Delete Other's User API Keys | Yes | No |

## Queries

| Action | Admin | Member |
|:-------------------------------------|:-----:|:------:|
| List All System API Keys | Yes | No |
| List All User API Keys | Yes | No |
| List All Users | Yes | No |
| Fetch Other User's Info, e.g. emails | Yes | No |
36 changes: 4 additions & 32 deletions src/phoenix/server/api/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from strawberry import Info
from strawberry.permission import BasePermission

from phoenix.db import enums
from phoenix.server.api.exceptions import Unauthorized
from phoenix.server.bearer_auth import PhoenixUser

Expand All @@ -21,40 +20,13 @@ def has_permission(self, source: Any, info: Info, **kwargs: Any) -> bool:
return not info.context.read_only


class IsAuthenticated(Authorization):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is now redundant after /graphql itself has been secured

message = "User is not authenticated"

def has_permission(self, source: Any, info: Info, **kwargs: Any) -> bool:
if info.context.token_store is None:
return True
try:
user = info.context.request.user
except AttributeError:
return False
return isinstance(user, PhoenixUser) and user.is_authenticated
MSG_ADMIN_ONLY = "Only admin can perform this action"


class IsAdmin(Authorization):
message = "Only admin can perform this action"
message = MSG_ADMIN_ONLY

def has_permission(self, source: Any, info: Info, **kwargs: Any) -> bool:
if info.context.token_store is None:
return False
try:
user = info.context.request.user
except AttributeError:
if not info.context.auth_enabled:
return False
return (
isinstance(user, PhoenixUser)
and user.is_authenticated
and user.claims is not None
and user.claims.attributes is not None
and user.claims.attributes.user_role == enums.UserRole.ADMIN
)


class HasSecret(BasePermission):
message = "Application secret is not set"

def has_permission(self, source: Any, info: Info, **kwargs: Any) -> bool:
return info.context.secret is not None
return isinstance((user := info.context.user), PhoenixUser) and user.is_admin
10 changes: 8 additions & 2 deletions src/phoenix/server/api/context.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from asyncio import get_running_loop
from dataclasses import dataclass
from functools import partial
from functools import cached_property, partial
from pathlib import Path
from typing import Any, Optional
from typing import Any, Optional, cast

from starlette.requests import Request as StarletteRequest
from starlette.responses import Response as StarletteResponse
Expand Down Expand Up @@ -42,6 +42,7 @@
UserRolesDataLoader,
UsersDataLoader,
)
from phoenix.server.bearer_auth import PhoenixUser
from phoenix.server.dml_event import DmlEvent
from phoenix.server.types import (
CanGetLastUpdatedAt,
Expand Down Expand Up @@ -96,6 +97,7 @@ class Context(BaseContext):
event_queue: CanPutItem[DmlEvent] = _NoOp()
corpus: Optional[Model] = None
read_only: bool = False
auth_enabled: bool = False
secret: Optional[str] = None
token_store: Optional[TokenStore] = None

Expand Down Expand Up @@ -146,3 +148,7 @@ async def log_out(self, user_id: int) -> None:
response = self.get_response()
response.delete_cookie(PHOENIX_REFRESH_TOKEN_COOKIE_NAME)
response.delete_cookie(PHOENIX_ACCESS_TOKEN_COOKIE_NAME)

@cached_property
axiomofjoy marked this conversation as resolved.
Show resolved Hide resolved
def user(self) -> PhoenixUser:
return cast(PhoenixUser, self.get_request().user)
47 changes: 6 additions & 41 deletions src/phoenix/server/api/mutations/api_key_mutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
from strawberry.types import Info

from phoenix.db import enums, models
from phoenix.db.models import ApiKey as OrmApiKey
from phoenix.server.api.auth import HasSecret, IsAdmin, IsAuthenticated, IsNotReadOnly
from phoenix.server.api.auth import IsAdmin, IsNotReadOnly
from phoenix.server.api.context import Context
from phoenix.server.api.exceptions import Unauthorized
from phoenix.server.api.queries import Query
Expand Down Expand Up @@ -59,32 +58,9 @@ class DeleteApiKeyMutationPayload:
query: Query


def can_delete_user_key(info: Info[Context, None], key: OrmApiKey) -> bool:
try:
user = info.context.request.user # type: ignore
except AttributeError:
return False
return (
isinstance(user, PhoenixUser)
and user.claims is not None
and user.claims.attributes is not None
and (
user.claims.attributes.user_role == enums.UserRole.ADMIN
or int(user.identity) == key.user_id
)
)


@strawberry.type
class ApiKeyMutationMixin:
@strawberry.mutation(
permission_classes=[
IsNotReadOnly,
HasSecret,
IsAuthenticated,
IsAdmin,
]
) # type: ignore
@strawberry.mutation(permission_classes=[IsNotReadOnly, IsAdmin]) # type: ignore
async def create_system_api_key(
self, info: Info[Context, None], input: CreateApiKeyInput
) -> CreateSystemApiKeyMutationPayload:
Expand Down Expand Up @@ -125,13 +101,7 @@ async def create_system_api_key(
query=Query(),
)

@strawberry.mutation(
permission_classes=[
IsNotReadOnly,
HasSecret,
IsAuthenticated,
]
) # type: ignore
@strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore
async def create_user_api_key(
self, info: Info[Context, None], input: CreateUserApiKeyInput
) -> CreateUserApiKeyMutationPayload:
Expand Down Expand Up @@ -166,7 +136,7 @@ async def create_user_api_key(
query=Query(),
)

@strawberry.mutation(permission_classes=[HasSecret, IsAuthenticated, IsAdmin, IsNotReadOnly]) # type: ignore
@strawberry.mutation(permission_classes=[IsNotReadOnly, IsAdmin]) # type: ignore
async def delete_system_api_key(
self, info: Info[Context, None], input: DeleteApiKeyInput
) -> DeleteApiKeyMutationPayload:
Expand All @@ -177,12 +147,7 @@ async def delete_system_api_key(
await token_store.revoke(ApiKeyId(api_key_id))
return DeleteApiKeyMutationPayload(apiKeyId=input.id, query=Query())

@strawberry.mutation(
permission_classes=[
HasSecret,
IsAuthenticated,
]
) # type: ignore
@strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore
async def delete_user_api_key(
self, info: Info[Context, None], input: DeleteApiKeyInput
) -> DeleteApiKeyMutationPayload:
Expand All @@ -196,7 +161,7 @@ async def delete_user_api_key(
)
if api_key is None:
raise ValueError(f"API key with id {input.id} not found")
if not can_delete_user_key(info, api_key):
if int((user := info.context.user).identity) != api_key.user_id and not user.is_admin:
raise Unauthorized("User not authorized to delete")
await token_store.revoke(ApiKeyId(api_key_id))
return DeleteApiKeyMutationPayload(apiKeyId=input.id, query=Query())
Loading
Loading