diff --git a/src/phoenix/auth.py b/src/phoenix/auth.py index 8affa9613a..8d43cf6d78 100644 --- a/src/phoenix/auth.py +++ b/src/phoenix/auth.py @@ -11,6 +11,7 @@ from typing_extensions import TypeVar from phoenix.config import get_env_phoenix_use_secure_cookies +from phoenix.db.models import User as OrmUser ResponseType = TypeVar("ResponseType", bound=Response) @@ -68,6 +69,14 @@ def validate_password_format(password: str) -> None: PASSWORD_REQUIREMENTS.validate(password) +def is_locally_authenticated(user: OrmUser) -> bool: + """ + Returns true if the user is authenticated locally, i.e., not through an + OAuth2 identity provider, and false otherwise. + """ + return user.oauth2_client_id is None and user.oauth2_user_id is None + + def set_access_token_cookie( *, response: ResponseType, access_token: str, max_age: timedelta ) -> ResponseType: diff --git a/src/phoenix/server/api/mutations/user_mutations.py b/src/phoenix/server/api/mutations/user_mutations.py index 76d48aab12..af0c3ee702 100644 --- a/src/phoenix/server/api/mutations/user_mutations.py +++ b/src/phoenix/server/api/mutations/user_mutations.py @@ -18,6 +18,7 @@ PASSWORD_REQUIREMENTS, PHOENIX_ACCESS_TOKEN_COOKIE_NAME, PHOENIX_REFRESH_TOKEN_COOKIE_NAME, + is_locally_authenticated, validate_email_format, validate_password_format, ) @@ -137,7 +138,7 @@ async def patch_user( raise NotFound(f"Role {input.new_role.value} not found") user.user_role_id = user_role_id if password := input.new_password: - if not _is_locally_authenticated_user(user): + if not is_locally_authenticated(user): raise Conflict("Cannot modify password for non-local user") validate_password_format(password) user.password_salt = secrets.token_bytes(DEFAULT_SECRET_LENGTH) @@ -170,7 +171,7 @@ async def patch_viewer( raise NotFound("User not found") stack.enter_context(session.no_autoflush) if password := input.new_password: - if not _is_locally_authenticated_user(user): + if not is_locally_authenticated(user): raise Conflict("Cannot modify password for non-local user") if not ( current_password := input.current_password @@ -325,14 +326,6 @@ def _select_user_by_id(user_id: int) -> Select[Tuple[models.User]]: ) -def _is_locally_authenticated_user(user: models.User) -> bool: - """ - Returns true if the user is authenticated locally, i.e., not through an - OAuth2 identity provider, and false otherwise. - """ - return user.oauth2_client_id is None and user.oauth2_user_id is None - - def _user_operation_error_message( error: IntegrityError, operation: Literal["create", "modify"] = "create", diff --git a/src/phoenix/server/api/routers/auth.py b/src/phoenix/server/api/routers/auth.py index 2d08cd8373..35181565d2 100644 --- a/src/phoenix/server/api/routers/auth.py +++ b/src/phoenix/server/api/routers/auth.py @@ -25,13 +25,14 @@ delete_oauth2_nonce_cookie, delete_oauth2_state_cookie, delete_refresh_token_cookie, + is_locally_authenticated, is_valid_password, set_access_token_cookie, set_refresh_token_cookie, validate_password_format, ) from phoenix.config import get_base_url -from phoenix.db import enums, models +from phoenix.db import models from phoenix.server.bearer_auth import PhoenixUser, create_access_and_refresh_tokens from phoenix.server.email.templates.types import PasswordResetTemplateBody from phoenix.server.email.types import EmailSender @@ -197,7 +198,7 @@ async def initiate_password_reset(request: Request) -> Response: joinedload(models.User.password_reset_token).load_only(models.PasswordResetToken.id) ) ) - if user is None or user.auth_method != enums.AuthMethod.LOCAL.value: + if user is None or not is_locally_authenticated(user): # Withold privileged information return Response(status_code=HTTP_204_NO_CONTENT) if user.password_reset_token: @@ -229,7 +230,7 @@ async def reset_password(request: Request) -> Response: assert (user_id := claims.subject) async with request.app.state.db() as session: user = await session.scalar(_select_active_user().filter_by(id=int(user_id))) - if user is None or user.auth_method != enums.AuthMethod.LOCAL.value: + if user is None or not is_locally_authenticated(user): # Withold privileged information return Response(status_code=HTTP_204_NO_CONTENT) validate_password_format(password) diff --git a/src/phoenix/server/api/types/AuthMethod.py b/src/phoenix/server/api/types/AuthMethod.py index 011140e035..f3c77e9b51 100644 --- a/src/phoenix/server/api/types/AuthMethod.py +++ b/src/phoenix/server/api/types/AuthMethod.py @@ -6,3 +6,4 @@ @strawberry.enum class AuthMethod(Enum): LOCAL = "LOCAL" + OAUTH2 = "OAUTH2" diff --git a/src/phoenix/server/api/types/User.py b/src/phoenix/server/api/types/User.py index eddf677e62..05adae4066 100644 --- a/src/phoenix/server/api/types/User.py +++ b/src/phoenix/server/api/types/User.py @@ -7,6 +7,7 @@ from strawberry.relay import Node, NodeID from strawberry.types import Info +from phoenix.auth import is_locally_authenticated from phoenix.db import models from phoenix.server.api.context import Context from phoenix.server.api.exceptions import NotFound @@ -53,5 +54,5 @@ def to_gql_user(user: models.User, api_keys: Optional[List[models.ApiKey]] = Non email=user.email, created_at=user.created_at, user_role_id=user.user_role_id, - auth_method=AuthMethod("MEMBER"), + auth_method=AuthMethod.LOCAL if is_locally_authenticated(user) else AuthMethod.OAUTH2, )