Skip to content

Commit

Permalink
test(playground): add api test for chat completion subscription (#5125)
Browse files Browse the repository at this point in the history
  • Loading branch information
axiomofjoy authored Oct 21, 2024
1 parent a2de578 commit 3783a27
Show file tree
Hide file tree
Showing 5 changed files with 441 additions and 8 deletions.
2 changes: 2 additions & 0 deletions requirements/unit-tests.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ asgi-lifespan
asyncpg
grpc-interceptor[testing]
httpx # For OpenAI testing
httpx-ws
litellm>=1.0.3
nest-asyncio # for executor testing
numpy
Expand All @@ -17,3 +18,4 @@ respx # For OpenAI testing
tenacity
tiktoken
typing-extensions==4.7.0
vcrpy
12 changes: 6 additions & 6 deletions src/phoenix/server/api/subscriptions.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import json
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import fields
from datetime import datetime
from datetime import datetime, timezone
from enum import Enum
from itertools import chain
from typing import (
Expand Down Expand Up @@ -355,7 +354,6 @@ async def chat_completion(
)
tracer = tracer_provider.get_tracer(__name__)
span_name = "ChatCompletion"

with tracer.start_span(
span_name,
attributes=dict(
Expand Down Expand Up @@ -480,9 +478,11 @@ def _llm_token_counts(usage: "CompletionUsage") -> Iterator[Tuple[str, Any]]:


def _input_value_and_mime_type(input: ChatCompletionInput) -> Iterator[Tuple[str, Any]]:
assert any(field.name == (api_key := "api_key") for field in fields(ChatCompletionInput))
assert (api_key := "api_key") in (input_data := jsonify(input))
input_data = {k: v for k, v in input_data.items() if k != api_key}
assert api_key not in input_data
yield INPUT_MIME_TYPE, JSON
yield INPUT_VALUE, safe_json_dumps({k: v for k, v in jsonify(input).items() if k != api_key})
yield INPUT_VALUE, safe_json_dumps(input_data)


def _output_value_and_mime_type(output: Any) -> Iterator[Tuple[str, Any]]:
Expand Down Expand Up @@ -530,7 +530,7 @@ def _datetime(*, epoch_nanoseconds: float) -> datetime:
Converts a Unix epoch timestamp in nanoseconds to a datetime.
"""
epoch_seconds = epoch_nanoseconds / 1e9
return datetime.fromtimestamp(epoch_seconds)
return datetime.fromtimestamp(epoch_seconds).replace(tzinfo=timezone.utc)


def _formatted_messages(
Expand Down
136 changes: 134 additions & 2 deletions tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,16 @@
AsyncIterator,
Awaitable,
Callable,
Dict,
Iterator,
List,
Literal,
Optional,
Set,
Tuple,
)
from urllib.parse import urljoin
from uuid import uuid4

import httpx
import pytest
Expand All @@ -26,11 +30,14 @@
from asgi_lifespan import LifespanManager
from faker import Faker
from httpx import URL, Request, Response
from httpx_ws import AsyncWebSocketSession, aconnect_ws
from httpx_ws.transport import ASGIWebSocketTransport
from psycopg import Connection
from pytest_postgresql import factories
from sqlalchemy import make_url
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
from starlette.types import ASGIApp
from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL

import phoenix.trace.v1 as pb
from phoenix.config import EXPORT_DIR
Expand Down Expand Up @@ -221,8 +228,8 @@ async def read_stream():
request=request,
)

asgi_transport = httpx.ASGITransport(app=app)
transport = Transport(httpx.ASGITransport(app), asgi_transport=asgi_transport)
asgi_transport = ASGIWebSocketTransport(app=app)
transport = Transport(ASGIWebSocketTransport(app), asgi_transport=asgi_transport)
base_url = "http://test"
return (
httpx.Client(transport=transport, base_url=base_url),
Expand All @@ -237,6 +244,11 @@ def httpx_client(
return httpx_clients[1]


@pytest.fixture
def gql_client(httpx_client: httpx.AsyncClient) -> Iterator["AsyncGraphQLClient"]:
    """
    A GraphQL client wrapping the async HTTPX test client.

    Note: this fixture uses ``yield``, so its return annotation is
    ``Iterator[...]`` (the original annotated the yielded type directly,
    which is incorrect for a generator fixture).
    """
    yield AsyncGraphQLClient(httpx_client)

@pytest.fixture
def px_client(
httpx_clients: Tuple[httpx.Client, httpx.AsyncClient],
Expand Down Expand Up @@ -341,3 +353,123 @@ def _(seen: Set[str]) -> Iterator[str]:
yield span_id

return _(set())


class AsyncGraphQLClient:
    """
    Async GraphQL client that can execute queries, mutations, and subscriptions
    against the test server's ``/graphql`` endpoint.
    """

    def __init__(
        self, httpx_client: httpx.AsyncClient, timeout_seconds: Optional[float] = 10
    ) -> None:
        """
        Args:
            httpx_client: async client whose ``base_url`` points at the app
                under test; also reused as the websocket transport.
            timeout_seconds: per-message receive timeout for websocket
                traffic; ``None`` waits indefinitely.
        """
        self._httpx_client = httpx_client
        self._timeout_seconds = timeout_seconds
        self._gql_url = urljoin(str(httpx_client.base_url), "/graphql")

    async def execute(
        self,
        query: str,
        variables: Optional[Dict[str, Any]] = None,
        operation_name: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Executes queries and mutations over HTTP POST.

        Returns:
            The ``data`` member of the GraphQL response.

        Raises:
            httpx.HTTPStatusError: on a non-2xx HTTP response.
            RuntimeError: if the response carries GraphQL ``errors``.
        """
        response = await self._httpx_client.post(
            self._gql_url,
            json={
                "query": query,
                # Omit optional members entirely rather than sending nulls.
                **({"variables": variables} if variables is not None else {}),
                **({"operationName": operation_name} if operation_name is not None else {}),
            },
        )
        response.raise_for_status()
        response_json = response.json()
        if (errors := response_json.get("errors")) is not None:
            raise RuntimeError(errors)
        assert isinstance(data := response_json.get("data"), dict)
        return data

    @contextlib.asynccontextmanager
    async def subscription(
        self,
        query: str,
        variables: Optional[Dict[str, Any]] = None,
        operation_name: Optional[str] = None,
    ) -> AsyncIterator["GraphQLSubscription"]:
        """
        Starts a GraphQL subscription session over a websocket using the
        graphql-transport-ws subprotocol.

        Note: because this coroutine is wrapped in
        ``contextlib.asynccontextmanager``, its return annotation is the
        ``AsyncIterator`` of yielded values (the original annotated the
        yielded type directly, which is incorrect for a generator).

        Raises:
            RuntimeError: if the server does not acknowledge the connection.
        """
        async with aconnect_ws(
            self._gql_url,
            self._httpx_client,
            subprotocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL],
        ) as session:
            # graphql-transport-ws handshake: connection_init -> connection_ack
            await session.send_json({"type": "connection_init"})
            message = await session.receive_json(timeout=self._timeout_seconds)
            if message.get("type") != "connection_ack":
                raise RuntimeError("Websocket connection failed")
            yield GraphQLSubscription(
                session=session,
                query=query,
                variables=variables,
                operation_name=operation_name,
                timeout_seconds=self._timeout_seconds,
            )


class GraphQLSubscription:
    """
    A session for a GraphQL subscription over an established
    graphql-transport-ws websocket connection.
    """

    def __init__(
        self,
        *,
        session: AsyncWebSocketSession,
        query: str,
        variables: Optional[Dict[str, Any]] = None,
        operation_name: Optional[str] = None,
        timeout_seconds: Optional[float] = None,  # was mistyped Optional[str]; holds seconds
    ) -> None:
        """
        Args:
            session: an already-acknowledged websocket session.
            query: the subscription document to execute.
            variables: optional GraphQL variables.
            operation_name: optional operation to select from the document.
            timeout_seconds: per-message receive timeout; ``None`` waits
                indefinitely.
        """
        self._session = session
        self._query = query
        self._variables = variables
        self._operation_name = operation_name
        self._timeout_seconds = timeout_seconds

    async def stream(
        self,
    ) -> AsyncIterator[Dict[str, Any]]:
        """
        Streams subscription payloads until the server sends ``complete``.

        Yields:
            The ``data`` member of each ``next`` message.

        Raises:
            RuntimeError: if the server sends an ``error`` message.
        """
        # Client-chosen id correlating all messages of this operation.
        operation_id = str(uuid4())
        await self._session.send_json(
            {
                "id": operation_id,
                "type": "subscribe",
                "payload": {
                    "query": self._query,
                    # Omit optional members entirely rather than sending nulls.
                    **({"variables": self._variables} if self._variables is not None else {}),
                    **(
                        {"operationName": self._operation_name}
                        if self._operation_name is not None
                        else {}
                    ),
                },
            }
        )
        while True:
            message = await self._session.receive_json(timeout=self._timeout_seconds)
            message_type = message.get("type")
            # NOTE(review): assumes the server never interleaves protocol-level
            # messages without an id (e.g. ping) — confirm against the server's
            # graphql-transport-ws implementation.
            assert message.get("id") == operation_id
            if message_type == "complete":
                break
            elif message_type == "next":
                yield message["payload"]["data"]
            elif message_type == "error":
                raise RuntimeError(message["payload"])
            else:
                assert False, f"Unexpected message type: {message_type}"
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
interactions:
- request:
body: '{"messages": [{"content": "Who won the World Cup in 2018? Answer in one
word", "role": "user"}], "model": "gpt-4", "stream": true, "stream_options":
{"include_usage": true}, "temperature": 0.1}'
headers: {}
method: POST
uri: https://api.openai.com/v1/chat/completions
response:
body:
string: 'data: {"id":"chatcmpl-AKIfgHRmoJMmeKa7KgYrn0xMd2n63","object":"chat.completion.chunk","created":1729401696,"model":"gpt-4-0613","system_fingerprint":null,"choices":[{"index":0,"delta":{"role":"assistant","content":"","refusal":null},"logprobs":null,"finish_reason":null}],"usage":null}
data: {"id":"chatcmpl-AKIfgHRmoJMmeKa7KgYrn0xMd2n63","object":"chat.completion.chunk","created":1729401696,"model":"gpt-4-0613","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"France"},"logprobs":null,"finish_reason":null}],"usage":null}
data: {"id":"chatcmpl-AKIfgHRmoJMmeKa7KgYrn0xMd2n63","object":"chat.completion.chunk","created":1729401696,"model":"gpt-4-0613","system_fingerprint":null,"choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}],"usage":null}
data: {"id":"chatcmpl-AKIfgHRmoJMmeKa7KgYrn0xMd2n63","object":"chat.completion.chunk","created":1729401696,"model":"gpt-4-0613","system_fingerprint":null,"choices":[],"usage":{"prompt_tokens":21,"completion_tokens":1,"total_tokens":22,"prompt_tokens_details":{"cached_tokens":0},"completion_tokens_details":{"reasoning_tokens":0}}}
data: [DONE]
'
headers: {}
status:
code: 200
message: OK
version: 1
Loading

0 comments on commit 3783a27

Please sign in to comment.