Skip to content

Commit

Permalink
fix types
Browse files Browse the repository at this point in the history
  • Loading branch information
axiomofjoy committed Sep 19, 2024
1 parent c34170e commit 3783a4a
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 14 deletions.
9 changes: 9 additions & 0 deletions src/phoenix/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from typing_extensions import TypeVar

from phoenix.config import get_env_phoenix_use_secure_cookies
from phoenix.db.models import User as OrmUser

ResponseType = TypeVar("ResponseType", bound=Response)

Expand Down Expand Up @@ -68,6 +69,14 @@ def validate_password_format(password: str) -> None:
PASSWORD_REQUIREMENTS.validate(password)


def is_locally_authenticated(user: OrmUser) -> bool:
"""
Returns true if the user is authenticated locally, i.e., not through an
OAuth2 identity provider, and false otherwise.
"""
return user.oauth2_client_id is None and user.oauth2_user_id is None


def set_access_token_cookie(
*, response: ResponseType, access_token: str, max_age: timedelta
) -> ResponseType:
Expand Down
13 changes: 3 additions & 10 deletions src/phoenix/server/api/mutations/user_mutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
PASSWORD_REQUIREMENTS,
PHOENIX_ACCESS_TOKEN_COOKIE_NAME,
PHOENIX_REFRESH_TOKEN_COOKIE_NAME,
is_locally_authenticated,
validate_email_format,
validate_password_format,
)
Expand Down Expand Up @@ -137,7 +138,7 @@ async def patch_user(
raise NotFound(f"Role {input.new_role.value} not found")
user.user_role_id = user_role_id
if password := input.new_password:
if not _is_locally_authenticated_user(user):
if not is_locally_authenticated(user):
raise Conflict("Cannot modify password for non-local user")
validate_password_format(password)
user.password_salt = secrets.token_bytes(DEFAULT_SECRET_LENGTH)
Expand Down Expand Up @@ -170,7 +171,7 @@ async def patch_viewer(
raise NotFound("User not found")
stack.enter_context(session.no_autoflush)
if password := input.new_password:
if not _is_locally_authenticated_user(user):
if not is_locally_authenticated(user):
raise Conflict("Cannot modify password for non-local user")
if not (
current_password := input.current_password
Expand Down Expand Up @@ -325,14 +326,6 @@ def _select_user_by_id(user_id: int) -> Select[Tuple[models.User]]:
)


def _is_locally_authenticated_user(user: models.User) -> bool:
"""
Returns true if the user is authenticated locally, i.e., not through an
OAuth2 identity provider, and false otherwise.
"""
return user.oauth2_client_id is None and user.oauth2_user_id is None


def _user_operation_error_message(
error: IntegrityError,
operation: Literal["create", "modify"] = "create",
Expand Down
7 changes: 4 additions & 3 deletions src/phoenix/server/api/routers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,14 @@
delete_oauth2_nonce_cookie,
delete_oauth2_state_cookie,
delete_refresh_token_cookie,
is_locally_authenticated,
is_valid_password,
set_access_token_cookie,
set_refresh_token_cookie,
validate_password_format,
)
from phoenix.config import get_base_url
from phoenix.db import enums, models
from phoenix.db import models
from phoenix.server.bearer_auth import PhoenixUser, create_access_and_refresh_tokens
from phoenix.server.email.templates.types import PasswordResetTemplateBody
from phoenix.server.email.types import EmailSender
Expand Down Expand Up @@ -197,7 +198,7 @@ async def initiate_password_reset(request: Request) -> Response:
joinedload(models.User.password_reset_token).load_only(models.PasswordResetToken.id)
)
)
if user is None or user.auth_method != enums.AuthMethod.LOCAL.value:
if user is None or not is_locally_authenticated(user):
# Withold privileged information
return Response(status_code=HTTP_204_NO_CONTENT)
if user.password_reset_token:
Expand Down Expand Up @@ -229,7 +230,7 @@ async def reset_password(request: Request) -> Response:
assert (user_id := claims.subject)
async with request.app.state.db() as session:
user = await session.scalar(_select_active_user().filter_by(id=int(user_id)))
if user is None or user.auth_method != enums.AuthMethod.LOCAL.value:
if user is None or not is_locally_authenticated(user):
# Withold privileged information
return Response(status_code=HTTP_204_NO_CONTENT)
validate_password_format(password)
Expand Down
1 change: 1 addition & 0 deletions src/phoenix/server/api/types/AuthMethod.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@
@strawberry.enum
class AuthMethod(Enum):
LOCAL = "LOCAL"
OAUTH2 = "OAUTH2"
3 changes: 2 additions & 1 deletion src/phoenix/server/api/types/User.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from strawberry.relay import Node, NodeID
from strawberry.types import Info

from phoenix.auth import is_locally_authenticated
from phoenix.db import models
from phoenix.server.api.context import Context
from phoenix.server.api.exceptions import NotFound
Expand Down Expand Up @@ -53,5 +54,5 @@ def to_gql_user(user: models.User, api_keys: Optional[List[models.ApiKey]] = Non
email=user.email,
created_at=user.created_at,
user_role_id=user.user_role_id,
auth_method=AuthMethod("MEMBER"),
auth_method=AuthMethod.LOCAL if is_locally_authenticated(user) else AuthMethod.OAUTH2,
)

0 comments on commit 3783a4a

Please sign in to comment.