diff --git a/requirements/unit-tests.txt b/requirements/unit-tests.txt
index 813f21f816..6f18910537 100644
--- a/requirements/unit-tests.txt
+++ b/requirements/unit-tests.txt
@@ -5,6 +5,7 @@ asgi-lifespan
 asyncpg
 grpc-interceptor[testing]
 httpx # For OpenAI testing
+httpx-ws
 litellm>=1.0.3
 nest-asyncio # for executor testing
 numpy
@@ -17,3 +18,4 @@ respx # For OpenAI testing
 tenacity
 tiktoken
 typing-extensions==4.7.0
+vcrpy
diff --git a/src/phoenix/server/api/subscriptions.py b/src/phoenix/server/api/subscriptions.py
index 49768f95cd..99bdc076df 100644
--- a/src/phoenix/server/api/subscriptions.py
+++ b/src/phoenix/server/api/subscriptions.py
@@ -1,8 +1,7 @@
 import json
 from abc import ABC, abstractmethod
 from collections import defaultdict
-from dataclasses import fields
-from datetime import datetime
+from datetime import datetime, timezone
 from enum import Enum
 from itertools import chain
 from typing import (
@@ -355,7 +354,6 @@ async def chat_completion(
     )
     tracer = tracer_provider.get_tracer(__name__)
     span_name = "ChatCompletion"
-
     with tracer.start_span(
         span_name,
         attributes=dict(
@@ -480,9 +478,11 @@ def _llm_token_counts(usage: "CompletionUsage") -> Iterator[Tuple[str, Any]]:


 def _input_value_and_mime_type(input: ChatCompletionInput) -> Iterator[Tuple[str, Any]]:
-    assert any(field.name == (api_key := "api_key") for field in fields(ChatCompletionInput))
+    assert (api_key := "api_key") in (input_data := jsonify(input))
+    input_data = {k: v for k, v in input_data.items() if k != api_key}
+    assert api_key not in input_data
     yield INPUT_MIME_TYPE, JSON
-    yield INPUT_VALUE, safe_json_dumps({k: v for k, v in jsonify(input).items() if k != api_key})
+    yield INPUT_VALUE, safe_json_dumps(input_data)


 def _output_value_and_mime_type(output: Any) -> Iterator[Tuple[str, Any]]:
@@ -530,7 +530,7 @@ def _datetime(*, epoch_nanoseconds: float) -> datetime:
     """
     Converts a Unix epoch timestamp in nanoseconds to a datetime.
     """
     epoch_seconds = epoch_nanoseconds / 1e9
-    return datetime.fromtimestamp(epoch_seconds)
+    return datetime.fromtimestamp(epoch_seconds, tz=timezone.utc)


 def _formatted_messages(
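Review note on `_datetime`: `datetime.fromtimestamp(epoch_seconds)` interprets the epoch in the server's local timezone, so chaining `.replace(tzinfo=timezone.utc)` onto it would mislabel the instant on any non-UTC host; passing `tz=timezone.utc` (as above) converts correctly. A minimal sketch of the difference, assuming a host whose local timezone is not UTC:

```python
from datetime import datetime, timezone

# 1729401696 seconds (the cassette's "created" value) expressed in nanoseconds.
epoch_nanoseconds = 1_729_401_696_000_000_000
epoch_seconds = epoch_nanoseconds / 1e9

# Wrong: local wall-clock time stamped as UTC (off by the local UTC offset).
mislabeled = datetime.fromtimestamp(epoch_seconds).replace(tzinfo=timezone.utc)
# Right: the true UTC instant.
correct = datetime.fromtimestamp(epoch_seconds, tz=timezone.utc)

print(mislabeled.isoformat(), correct.isoformat())
```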
""" epoch_seconds = epoch_nanoseconds / 1e9 - return datetime.fromtimestamp(epoch_seconds) + return datetime.fromtimestamp(epoch_seconds).replace(tzinfo=timezone.utc) def _formatted_messages( diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index a0eeda8eec..5b264ff8bf 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -11,12 +11,16 @@ AsyncIterator, Awaitable, Callable, + Dict, Iterator, List, Literal, + Optional, Set, Tuple, ) +from urllib.parse import urljoin +from uuid import uuid4 import httpx import pytest @@ -26,11 +30,14 @@ from asgi_lifespan import LifespanManager from faker import Faker from httpx import URL, Request, Response +from httpx_ws import AsyncWebSocketSession, aconnect_ws +from httpx_ws.transport import ASGIWebSocketTransport from psycopg import Connection from pytest_postgresql import factories from sqlalchemy import make_url from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession from starlette.types import ASGIApp +from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL import phoenix.trace.v1 as pb from phoenix.config import EXPORT_DIR @@ -221,8 +228,8 @@ async def read_stream(): request=request, ) - asgi_transport = httpx.ASGITransport(app=app) - transport = Transport(httpx.ASGITransport(app), asgi_transport=asgi_transport) + asgi_transport = ASGIWebSocketTransport(app=app) + transport = Transport(ASGIWebSocketTransport(app), asgi_transport=asgi_transport) base_url = "http://test" return ( httpx.Client(transport=transport, base_url=base_url), @@ -237,6 +244,11 @@ def httpx_client( return httpx_clients[1] +@pytest.fixture +def gql_client(httpx_client: httpx.AsyncClient) -> "AsyncGraphQLClient": + yield AsyncGraphQLClient(httpx_client) + + @pytest.fixture def px_client( httpx_clients: Tuple[httpx.Client, httpx.AsyncClient], @@ -341,3 +353,123 @@ def _(seen: Set[str]) -> Iterator[str]: yield span_id return _(set()) + + +class AsyncGraphQLClient: + """ + Async GraphQL client that can execute queries, mutations, and subscriptions. + """ + + def __init__( + self, httpx_client: httpx.AsyncClient, timeout_seconds: Optional[float] = 10 + ) -> None: + self._httpx_client = httpx_client + self._timeout_seconds = timeout_seconds + self._gql_url = urljoin(str(httpx_client.base_url), "/graphql") + + async def execute( + self, + query: str, + variables: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + ) -> Dict[str, Any]: + """ + Executes queries and mutations. + """ + response = await self._httpx_client.post( + self._gql_url, + json={ + "query": query, + **({"variables": variables} if variables is not None else {}), + **({"operationName": operation_name} if operation_name is not None else {}), + }, + ) + response.raise_for_status() + response_json = response.json() + if (errors := response_json.get("errors")) is not None: + raise RuntimeError(errors) + assert isinstance(data := response_json.get("data"), dict) + return data + + @contextlib.asynccontextmanager + async def subscription( + self, + query: str, + variables: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + ) -> "GraphQLSubscription": + """ + Starts a GraphQL subscription session. 
+ """ + async with aconnect_ws( + self._gql_url, + self._httpx_client, + subprotocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL], + ) as session: + await session.send_json({"type": "connection_init"}) + message = await session.receive_json(timeout=self._timeout_seconds) + if message.get("type") != "connection_ack": + raise RuntimeError("Websocket connection failed") + yield GraphQLSubscription( + session=session, + query=query, + variables=variables, + operation_name=operation_name, + timeout_seconds=self._timeout_seconds, + ) + + +class GraphQLSubscription: + """ + A session for a GraphQL subscription. + """ + + def __init__( + self, + *, + session: AsyncWebSocketSession, + query: str, + variables: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + timeout_seconds: Optional[str] = None, + ) -> None: + self._session = session + self._query = query + self._variables = variables + self._operation_name = operation_name + self._timeout_seconds = timeout_seconds + + async def stream( + self, + ) -> AsyncIterator[Dict[str, Any]]: + """ + Streams subscription payloads. + """ + connection_id = str(uuid4()) + await self._session.send_json( + { + "id": connection_id, + "type": "subscribe", + "payload": { + "query": self._query, + **({"variables": self._variables} if self._variables is not None else {}), + **( + {"operationName": self._operation_name} + if self._operation_name is not None + else {} + ), + }, + } + ) + while True: + message = await self._session.receive_json(timeout=self._timeout_seconds) + message_type = message.get("type") + assert message.get("id") == connection_id + if message_type == "complete": + break + elif message_type == "next": + yield message["payload"]["data"] + elif message_type == "error": + raise RuntimeError(message["payload"]) + else: + assert False, f"Unexpected message type: {message_type}" diff --git a/tests/unit/server/api/cassettes/test_subscriptions/TestChatCompletionSubscription.test_openai_text_response_emits_expected_payloads_and_records_expected_span[sqlite].yaml b/tests/unit/server/api/cassettes/test_subscriptions/TestChatCompletionSubscription.test_openai_text_response_emits_expected_payloads_and_records_expected_span[sqlite].yaml new file mode 100644 index 0000000000..694e18b59c --- /dev/null +++ b/tests/unit/server/api/cassettes/test_subscriptions/TestChatCompletionSubscription.test_openai_text_response_emits_expected_payloads_and_records_expected_span[sqlite].yaml @@ -0,0 +1,31 @@ +interactions: +- request: + body: '{"messages": [{"content": "Who won the World Cup in 2018? 
diff --git a/tests/unit/server/api/cassettes/test_subscriptions/TestChatCompletionSubscription.test_openai_text_response_emits_expected_payloads_and_records_expected_span[sqlite].yaml b/tests/unit/server/api/cassettes/test_subscriptions/TestChatCompletionSubscription.test_openai_text_response_emits_expected_payloads_and_records_expected_span[sqlite].yaml
new file mode 100644
index 0000000000..694e18b59c
--- /dev/null
+++ b/tests/unit/server/api/cassettes/test_subscriptions/TestChatCompletionSubscription.test_openai_text_response_emits_expected_payloads_and_records_expected_span[sqlite].yaml
@@ -0,0 +1,31 @@
+interactions:
+- request:
+    body: '{"messages": [{"content": "Who won the World Cup in 2018? Answer in one
+      word", "role": "user"}], "model": "gpt-4", "stream": true, "stream_options":
+      {"include_usage": true}, "temperature": 0.1}'
+    headers: {}
+    method: POST
+    uri: https://api.openai.com/v1/chat/completions
+  response:
+    body:
+      string: 'data: {"id":"chatcmpl-AKIfgHRmoJMmeKa7KgYrn0xMd2n63","object":"chat.completion.chunk","created":1729401696,"model":"gpt-4-0613","system_fingerprint":null,"choices":[{"index":0,"delta":{"role":"assistant","content":"","refusal":null},"logprobs":null,"finish_reason":null}],"usage":null}
+
+
+        data: {"id":"chatcmpl-AKIfgHRmoJMmeKa7KgYrn0xMd2n63","object":"chat.completion.chunk","created":1729401696,"model":"gpt-4-0613","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"France"},"logprobs":null,"finish_reason":null}],"usage":null}
+
+
+        data: {"id":"chatcmpl-AKIfgHRmoJMmeKa7KgYrn0xMd2n63","object":"chat.completion.chunk","created":1729401696,"model":"gpt-4-0613","system_fingerprint":null,"choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}],"usage":null}
+
+
+        data: {"id":"chatcmpl-AKIfgHRmoJMmeKa7KgYrn0xMd2n63","object":"chat.completion.chunk","created":1729401696,"model":"gpt-4-0613","system_fingerprint":null,"choices":[],"usage":{"prompt_tokens":21,"completion_tokens":1,"total_tokens":22,"prompt_tokens_details":{"cached_tokens":0},"completion_tokens_details":{"reasoning_tokens":0}}}
+
+
+        data: [DONE]
+
+
+        '
+    headers: {}
+    status:
+      code: 200
+      message: OK
+version: 1
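Review note: the cassette body is the raw server-sent-events stream (`data: ...` frames terminated by `data: [DONE]`), so replaying it exercises the full streaming code path without a live OpenAI call. A sketch of how a new cassette could be recorded with the same header scrubbing (cassette path illustrative; `record_mode="once"` records on the first run and replays afterwards):

```python
import vcr

recorder = vcr.VCR(
    record_mode="once",
    before_record_request=remove_all_vcr_request_headers,  # keep the API key out of the cassette
    before_record_response=remove_all_vcr_response_headers,  # drop response headers too
)
with recorder.use_cassette("cassettes/test_subscriptions/my_test.yaml"):
    ...  # run the subscription against the real API once
```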
diff --git a/tests/unit/server/api/test_subscriptions.py b/tests/unit/server/api/test_subscriptions.py
new file mode 100644
index 0000000000..1fc71afce2
--- /dev/null
+++ b/tests/unit/server/api/test_subscriptions.py
@@ -0,0 +1,268 @@
+import json
+import sys
+from pathlib import Path
+from typing import Any, Dict
+
+import pytest
+from openinference.semconv.trace import (
+    OpenInferenceMimeTypeValues,
+    OpenInferenceSpanKindValues,
+    SpanAttributes,
+)
+from vcr import use_cassette
+
+from phoenix.trace.attributes import flatten
+
+
+def remove_all_vcr_request_headers(request: Any) -> Any:
+    """
+    Removes all request headers.
+
+    Example:
+    ```
+    @pytest.mark.vcr(
+        before_record_request=remove_all_vcr_request_headers
+    )
+    def test_openai() -> None:
+        # make request to OpenAI
+    ```
+    """
+    request.headers.clear()
+    return request
+
+
+def remove_all_vcr_response_headers(response: Dict[str, Any]) -> Dict[str, Any]:
+    """
+    Removes all response headers.
+
+    Example:
+    ```
+    @pytest.mark.vcr(
+        before_record_response=remove_all_vcr_response_headers
+    )
+    def test_openai() -> None:
+        # make request to OpenAI
+    ```
+    """
+    response["headers"] = {}
+    return response
+
+
+@pytest.mark.skipif(
+    sys.platform
+    in (
+        "win32",
+        "linux",
+    ),  # todo: support windows and linux https://github.com/Arize-ai/phoenix/issues/5126
+    reason="subscriptions are not currently supported on windows or linux",
+)
+class TestChatCompletionSubscription:
+    QUERY = """
+      subscription ChatCompletionSubscription($input: ChatCompletionInput!) {
+        chatCompletion(input: $input) {
+          __typename
+          ... on TextChunk {
+            content
+          }
+          ... on ToolCallChunk {
+            id
+            function {
+              name
+              arguments
+            }
+          }
+          ... on FinishedChatCompletion {
+            span {
+              ...SpanFragment
+            }
+          }
+        }
+      }
+
+      query SpanQuery($spanId: GlobalID!) {
+        span: node(id: $spanId) {
+          ... on Span {
+            ...SpanFragment
+          }
+        }
+      }
+
+      fragment SpanFragment on Span {
+        id
+        name
+        statusCode
+        statusMessage
+        startTime
+        endTime
+        latencyMs
+        parentId
+        spanKind
+        context {
+          spanId
+          traceId
+        }
+        attributes
+        metadata
+        numDocuments
+        tokenCountTotal
+        tokenCountPrompt
+        tokenCountCompletion
+        input {
+          mimeType
+          value
+        }
+        output {
+          mimeType
+          value
+        }
+        events {
+          name
+          message
+          timestamp
+        }
+        cumulativeTokenCountTotal
+        cumulativeTokenCountPrompt
+        cumulativeTokenCountCompletion
+        propagatedStatusCode
+      }
+    """
+
+    async def test_openai_text_response_emits_expected_payloads_and_records_expected_span(
+        self,
+        gql_client: Any,
+        openai_api_key: str,
+    ) -> None:
+        variables = {
+            "input": {
+                "messages": [
+                    {
+                        "role": "USER",
+                        "content": "Who won the World Cup in 2018? Answer in one word",
+                    }
+                ],
+                "model": {"name": "gpt-4", "providerKey": "OPENAI"},
+                "invocationParameters": {
+                    "temperature": 0.1,
+                },
+            },
+        }
+        async with gql_client.subscription(
+            query=self.QUERY,
+            variables=variables,
+            operation_name="ChatCompletionSubscription",
+        ) as subscription:
+            with use_cassette(
+                Path(__file__).parent / "cassettes/test_subscriptions/"
+                "TestChatCompletionSubscription.test_openai_text_response_emits_expected_payloads_and_records_expected_span[sqlite].yaml",
+                decode_compressed_response=True,
+                before_record_request=remove_all_vcr_request_headers,
+                before_record_response=remove_all_vcr_response_headers,
+            ):
+                payloads = [payload["chatCompletion"] async for payload in subscription.stream()]
+
+        # check subscription payloads
+        assert payloads
+        assert (last_payload := payloads.pop())["__typename"] == "FinishedChatCompletion"
+        assert all(payload["__typename"] == "TextChunk" for payload in payloads)
+        response_text = "".join(payload["content"] for payload in payloads)
+        assert "france" in response_text.lower()
+        subscription_span = last_payload["span"]
+        span_id = subscription_span["id"]
+
+        # query for the span via the node interface to ensure that the span
+        # recorded in the db contains identical information as the span emitted
+        # by the subscription
+        data = await gql_client.execute(
+            query=self.QUERY, variables={"spanId": span_id}, operation_name="SpanQuery"
+        )
+        span = data["span"]
+        assert span == subscription_span
+
+        # check attributes
+        assert span.pop("id") == span_id
+        assert span.pop("name") == "ChatCompletion"
+        assert span.pop("statusCode") == "OK"
+        assert not span.pop("statusMessage")
+        assert span.pop("startTime")
+        assert span.pop("endTime")
+        assert isinstance(span.pop("latencyMs"), float)
+        assert span.pop("parentId") is None
+        assert span.pop("spanKind") == "llm"
+        assert (context := span.pop("context")).pop("spanId")
+        assert (attributes := dict(flatten(json.loads(span.pop("attributes")))))
+        assert context.pop("traceId")
+        assert not context
+        assert span.pop("metadata") is None
+        assert span.pop("numDocuments") is None
+        assert isinstance(token_count_total := span.pop("tokenCountTotal"), int)
+        assert isinstance(token_count_prompt := span.pop("tokenCountPrompt"), int)
+        assert isinstance(token_count_completion := span.pop("tokenCountCompletion"), int)
+        assert token_count_prompt > 0
+        assert token_count_completion > 0
+        assert token_count_total == token_count_prompt + token_count_completion
+        assert (input := span.pop("input")).pop("mimeType") == "json"
+        assert (input_value := input.pop("value"))
+        assert not input
+        assert "api_key" not in input_value
+        assert "apiKey" not in input_value
span.pop("output")).pop("mimeType") == "json" + assert output.pop("value") + assert not output + assert not span.pop("events") + assert isinstance( + cumulative_token_count_total := span.pop("cumulativeTokenCountTotal"), int + ) + assert isinstance( + cumulative_token_count_prompt := span.pop("cumulativeTokenCountPrompt"), int + ) + assert isinstance( + cumulative_token_count_completion := span.pop("cumulativeTokenCountCompletion"), int + ) + assert cumulative_token_count_total == token_count_total + assert cumulative_token_count_prompt == token_count_prompt + assert cumulative_token_count_completion == token_count_completion + assert span.pop("propagatedStatusCode") == "OK" + assert not span + + assert attributes.pop(OPENINFERENCE_SPAN_KIND) == LLM + assert attributes.pop(LLM_MODEL_NAME) == "gpt-4" + assert attributes.pop(LLM_INVOCATION_PARAMETERS) == json.dumps({"temperature": 0.1}) + assert attributes.pop(LLM_TOKEN_COUNT_TOTAL) == token_count_total + assert attributes.pop(LLM_TOKEN_COUNT_PROMPT) == token_count_prompt + assert attributes.pop(LLM_TOKEN_COUNT_COMPLETION) == token_count_completion + assert attributes.pop(INPUT_VALUE) + assert attributes.pop(INPUT_MIME_TYPE) == JSON + assert attributes.pop(OUTPUT_VALUE) + assert attributes.pop(OUTPUT_MIME_TYPE) == JSON + assert attributes.pop(LLM_INPUT_MESSAGES) == [ + { + "message": { + "role": "user", + "content": "Who won the World Cup in 2018? Answer in one word", + } + } + ] + assert attributes.pop(LLM_OUTPUT_MESSAGES) == [ + { + "message": { + "role": "assistant", + "content": response_text, + } + } + ] + assert not attributes + + +LLM = OpenInferenceSpanKindValues.LLM.value +JSON = OpenInferenceMimeTypeValues.JSON.value + +OPENINFERENCE_SPAN_KIND = SpanAttributes.OPENINFERENCE_SPAN_KIND +LLM_MODEL_NAME = SpanAttributes.LLM_MODEL_NAME +LLM_INVOCATION_PARAMETERS = SpanAttributes.LLM_INVOCATION_PARAMETERS +LLM_TOKEN_COUNT_TOTAL = SpanAttributes.LLM_TOKEN_COUNT_TOTAL +LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT +LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION +LLM_INPUT_MESSAGES = SpanAttributes.LLM_INPUT_MESSAGES +LLM_OUTPUT_MESSAGES = SpanAttributes.LLM_OUTPUT_MESSAGES +INPUT_VALUE = SpanAttributes.INPUT_VALUE +INPUT_MIME_TYPE = SpanAttributes.INPUT_MIME_TYPE +OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE +OUTPUT_MIME_TYPE = SpanAttributes.OUTPUT_MIME_TYPE