Skip to content

Commit

Permalink
fix: inadvertent cookie deletion when changing user password via Patc…
Browse files Browse the repository at this point in the history
…hUser (#4637)
  • Loading branch information
RogerHYang committed Sep 21, 2024
1 parent a253f64 commit 7077cc2
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 16 deletions.
32 changes: 21 additions & 11 deletions integration_tests/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@

import httpx
import pytest
from httpx import HTTPStatusError
from httpx import Headers, HTTPStatusError
from openinference.semconv.resource import ResourceAttributes
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import ReadableSpan, TracerProvider
Expand Down Expand Up @@ -135,7 +135,7 @@ def gql(
self,
query: str,
variables: Optional[Mapping[str, Any]] = None,
) -> Dict[str, Any]:
) -> Tuple[Dict[str, Any], Headers]:
return _gql(self, query=query, variables=variables)

def create_user(
Expand Down Expand Up @@ -641,10 +641,10 @@ def _gql(
*,
query: str,
variables: Optional[Mapping[str, Any]] = None,
) -> Dict[str, Any]:
) -> Tuple[Dict[str, Any], Headers]:
json_ = dict(query=query, variables=dict(variables or {}))
resp = _httpx_client(auth).post("graphql", json=json_)
return _json(resp)
return _json(resp), resp.headers


def _get_gql_spans(
Expand All @@ -654,8 +654,9 @@ def _get_gql_spans(
) -> Dict[_ProjectName, List[Dict[str, Any]]]:
out = "name spans{edges{node{" + " ".join(fields) + "}}}"
query = "query{projects{edges{node{" + out + "}}}}"
resp_dict = _gql(auth, query=query)
resp_dict, headers = _gql(auth, query=query)
assert not resp_dict.get("errors")
assert not headers.get("set-cookie")
return {
project["node"]["name"]: [span["node"] for span in project["node"]["spans"]["edges"]]
for project in resp_dict["data"]["projects"]["edges"]
Expand All @@ -677,10 +678,11 @@ def _create_user(
args.append(f'username:"{username}"')
out = "user{id email role{name}}"
query = "mutation{createUser(input:{" + ",".join(args) + "}){" + out + "}}"
resp_dict = _gql(auth, query=query)
resp_dict, headers = _gql(auth, query=query)
assert (user := resp_dict["data"]["createUser"]["user"])
assert user["email"] == email
assert user["role"]["name"] == role.value
assert not headers.get("set-cookie")
return _User(_GqlId(user["id"]), role, profile)


Expand All @@ -692,7 +694,8 @@ def _delete_users(
) -> None:
user_ids = [u.gid if isinstance(u, _User) else u for u in users]
query = "mutation($userIds:[GlobalID!]!){deleteUsers(input:{userIds:$userIds})}"
_gql(auth, query=query, variables=dict(userIds=user_ids))
_, headers = _gql(auth, query=query, variables=dict(userIds=user_ids))
assert not headers.get("set-cookie")


def _patch_user_gid(
Expand All @@ -713,14 +716,15 @@ def _patch_user_gid(
args.append(f"newRole:{new_role.value}")
out = "user{id username role{name}}"
query = "mutation{patchUser(input:{" + ",".join(args) + "}){" + out + "}}"
resp_dict = _gql(auth, query=query)
resp_dict, headers = _gql(auth, query=query)
assert (data := resp_dict["data"]["patchUser"])
assert (result := data["user"])
assert result["id"] == gid
if new_username:
assert result["username"] == new_username
if new_role:
assert result["role"]["name"] == new_role.value
assert not headers.get("set-cookie")


def _patch_user(
Expand Down Expand Up @@ -765,11 +769,15 @@ def _patch_viewer(
args.append(f'newUsername:"{new_username}"')
out = "user{username}"
query = "mutation{patchViewer(input:{" + ",".join(args) + "}){" + out + "}}"
resp_dict = _gql(auth, query=query)
resp_dict, headers = _gql(auth, query=query)
assert (data := resp_dict["data"]["patchViewer"])
assert (user := data["user"])
if new_username:
assert user["username"] == new_username
if new_password:
assert headers.get("set-cookie")
else:
assert not headers.get("set-cookie")


def _create_api_key(
Expand All @@ -786,12 +794,13 @@ def _create_api_key(
args, out = (f'name:"{name}"' + exp), "jwt apiKey{id name expiresAt}"
field = f"create{kind}ApiKey"
query = "mutation{" + field + "(input:{" + args + "}){" + out + "}}"
resp_dict = _gql(auth, query=query)
resp_dict, headers = _gql(auth, query=query)
assert (data := resp_dict["data"][field])
assert (key := data["apiKey"])
assert key["name"] == name
exp_t = datetime.fromisoformat(key["expiresAt"]) if key["expiresAt"] else None
assert exp_t == expires_at
assert not headers.get("set-cookie")
return _ApiKey(data["jwt"], _GqlId(key["id"]), kind)


Expand All @@ -805,8 +814,9 @@ def _delete_api_key(
gid = api_key.gid
args, out = f'id:"{gid}"', "apiKeyId"
query = "mutation{" + field + "(input:{" + args + "}){" + out + "}}"
resp_dict = _gql(auth, query=query)
resp_dict, headers = _gql(auth, query=query)
assert resp_dict["data"][field]["apiKeyId"] == gid
assert not headers.get("set-cookie")


def _log_in(
Expand Down
5 changes: 0 additions & 5 deletions src/phoenix/server/api/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
from strawberry.fastapi import BaseContext

from phoenix.auth import (
PHOENIX_ACCESS_TOKEN_COOKIE_NAME,
PHOENIX_REFRESH_TOKEN_COOKIE_NAME,
compute_password_hash,
)
from phoenix.core.model_schema import Model
Expand Down Expand Up @@ -145,9 +143,6 @@ async def hash_password(password: str, salt: bytes) -> bytes:
async def log_out(self, user_id: int) -> None:
assert self.token_store is not None
await self.token_store.log_out(UserId(user_id))
response = self.get_response()
response.delete_cookie(PHOENIX_REFRESH_TOKEN_COOKIE_NAME)
response.delete_cookie(PHOENIX_ACCESS_TOKEN_COOKIE_NAME)

@cached_property
def user(self) -> PhoenixUser:
Expand Down
5 changes: 5 additions & 0 deletions src/phoenix/server/api/mutations/user_mutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
DEFAULT_ADMIN_USERNAME,
DEFAULT_SECRET_LENGTH,
PASSWORD_REQUIREMENTS,
PHOENIX_ACCESS_TOKEN_COOKIE_NAME,
PHOENIX_REFRESH_TOKEN_COOKIE_NAME,
validate_email_format,
validate_password_format,
)
Expand Down Expand Up @@ -188,6 +190,9 @@ async def patch_viewer(
assert user
if input.new_password:
await info.context.log_out(user.id)
response = info.context.get_response()
response.delete_cookie(PHOENIX_REFRESH_TOKEN_COOKIE_NAME)
response.delete_cookie(PHOENIX_ACCESS_TOKEN_COOKIE_NAME)
return UserMutationPayload(user=to_gql_user(user))

@strawberry.mutation(permission_classes=[IsNotReadOnly, IsAdmin]) # type: ignore
Expand Down

0 comments on commit 7077cc2

Please sign in to comment.