Skip to content

Commit

Permalink
fix(playground): ensure chat completion subscription unit tests run o…
Browse files Browse the repository at this point in the history
…n all os and on postgres (#5205)

Co-authored-by: Dustin Ngo <dustin.ngo@gmail.com>
  • Loading branch information
axiomofjoy and anticorrelator authored Oct 28, 2024
1 parent cb55790 commit 78942cd
Show file tree
Hide file tree
Showing 3 changed files with 277 additions and 16 deletions.
2 changes: 1 addition & 1 deletion tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from faker import Faker
from httpx import AsyncByteStream, Request, Response
from httpx_ws import AsyncWebSocketSession, aconnect_ws
from httpx_ws.transport import ASGIWebSocketTransport
from psycopg import Connection
from pytest_postgresql import factories
from sqlalchemy import URL, make_url
Expand All @@ -42,6 +41,7 @@
from phoenix.server.types import BatchedCaller, DbSessionFactory
from phoenix.session.client import Client
from phoenix.trace.schemas import Span
from tests.unit.transport import ASGIWebSocketTransport


def pytest_terminal_summary(
Expand Down
35 changes: 20 additions & 15 deletions tests/unit/server/api/test_subscriptions.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import json
import sys
from datetime import datetime
from pathlib import Path
from typing import Any

import pytest
from openinference.semconv.trace import (
OpenInferenceMimeTypeValues,
OpenInferenceSpanKindValues,
Expand Down Expand Up @@ -53,14 +51,6 @@ def test_openai() -> None:
return response


@pytest.mark.skipif(
sys.platform
in (
"win32",
"linux",
), # todo: support windows and linux https://github.com/Arize-ai/phoenix/issues/5126
reason="subscriptions are not currently supported on windows or linux",
)
class TestChatCompletionSubscription:
QUERY = """
subscription ChatCompletionSubscription($input: ChatCompletionInput!) {
Expand Down Expand Up @@ -188,6 +178,10 @@ async def test_openai_text_response_emits_expected_payloads_and_records_expected
query=self.QUERY, variables={"spanId": span_id}, operation_name="SpanQuery"
)
span = data["span"]
assert json.loads(attributes := span.pop("attributes")) == json.loads(
subscription_span.pop("attributes")
)
attributes = dict(flatten(json.loads(attributes)))
assert span == subscription_span

# check attributes
Expand All @@ -201,7 +195,6 @@ async def test_openai_text_response_emits_expected_payloads_and_records_expected
assert span.pop("parentId") is None
assert span.pop("spanKind") == "llm"
assert (context := span.pop("context")).pop("spanId")
assert (attributes := dict(flatten(json.loads(span.pop("attributes")))))
assert context.pop("traceId")
assert not context
assert span.pop("metadata") is None
Expand Down Expand Up @@ -317,6 +310,10 @@ async def test_openai_emits_expected_payloads_and_records_expected_span_on_error
query=self.QUERY, variables={"spanId": span_id}, operation_name="SpanQuery"
)
span = data["span"]
assert json.loads(attributes := span.pop("attributes")) == json.loads(
subscription_span.pop("attributes")
)
attributes = dict(flatten(json.loads(attributes)))
assert span == subscription_span

# check attributes
Expand All @@ -330,7 +327,6 @@ async def test_openai_emits_expected_payloads_and_records_expected_span_on_error
assert span.pop("parentId") is None
assert span.pop("spanKind") == "llm"
assert (context := span.pop("context")).pop("spanId")
assert (attributes := dict(flatten(json.loads(span.pop("attributes")))))
assert context.pop("traceId")
assert not context
assert span.pop("metadata") is None
Expand Down Expand Up @@ -457,6 +453,10 @@ async def test_openai_tool_call_response_emits_expected_payloads_and_records_exp
query=self.QUERY, variables={"spanId": span_id}, operation_name="SpanQuery"
)
span = data["span"]
assert json.loads(attributes := span.pop("attributes")) == json.loads(
subscription_span.pop("attributes")
)
attributes = dict(flatten(json.loads(attributes)))
assert span == subscription_span

# check attributes
Expand All @@ -470,7 +470,6 @@ async def test_openai_tool_call_response_emits_expected_payloads_and_records_exp
assert span.pop("parentId") is None
assert span.pop("spanKind") == "llm"
assert (context := span.pop("context")).pop("spanId")
assert (attributes := dict(flatten(json.loads(span.pop("attributes")))))
assert context.pop("traceId")
assert not context
assert span.pop("metadata") is None
Expand Down Expand Up @@ -607,6 +606,10 @@ async def test_openai_tool_call_messages_emits_expected_payloads_and_records_exp
query=self.QUERY, variables={"spanId": span_id}, operation_name="SpanQuery"
)
span = data["span"]
assert json.loads(attributes := span.pop("attributes")) == json.loads(
subscription_span.pop("attributes")
)
attributes = dict(flatten(json.loads(attributes)))
assert span == subscription_span

# check attributes
Expand All @@ -620,7 +623,6 @@ async def test_openai_tool_call_messages_emits_expected_payloads_and_records_exp
assert span.pop("parentId") is None
assert span.pop("spanKind") == "llm"
assert (context := span.pop("context")).pop("spanId")
assert (attributes := dict(flatten(json.loads(span.pop("attributes")))))
assert context.pop("traceId")
assert not context
assert span.pop("metadata") is None
Expand Down Expand Up @@ -750,6 +752,10 @@ async def test_anthropic_text_response_emits_expected_payloads_and_records_expec
query=self.QUERY, variables={"spanId": span_id}, operation_name="SpanQuery"
)
span = data["span"]
assert json.loads(attributes := span.pop("attributes")) == json.loads(
subscription_span.pop("attributes")
)
attributes = dict(flatten(json.loads(attributes)))
assert span == subscription_span

# check attributes
Expand All @@ -763,7 +769,6 @@ async def test_anthropic_text_response_emits_expected_payloads_and_records_expec
assert span.pop("parentId") is None
assert span.pop("spanKind") == "llm"
assert (context := span.pop("context")).pop("spanId")
assert (attributes := dict(flatten(json.loads(span.pop("attributes")))))
assert context.pop("traceId")
assert not context
assert span.pop("metadata") is None
Expand Down
256 changes: 256 additions & 0 deletions tests/unit/transport.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,256 @@
"""
This file contains a copy of [httpx-ws](https://github.com/frankie567/httpx-ws),
which is published under an [MIT
license](https://github.com/frankie567/httpx-ws/blob/main/LICENSE).
Modifications have been made to better support the concurrency paradigm used in
our unit test suite.
"""

import asyncio
import contextlib
import typing

import wsproto
from httpcore import AsyncNetworkStream
from httpx import ASGITransport, AsyncByteStream, Request, Response
from wsproto.frame_protocol import CloseReason

Scope = dict[str, typing.Any]
Message = dict[str, typing.Any]
Receive = typing.Callable[[], typing.Awaitable[Message]]
Send = typing.Callable[[Scope], typing.Coroutine[None, None, None]]
ASGIApp = typing.Callable[[Scope, Receive, Send], typing.Coroutine[None, None, None]]


class HTTPXWSException(Exception):
"""
Base exception class for HTTPX WS.
"""

pass


class WebSocketDisconnect(HTTPXWSException):
"""
Raised when the server closed the WebSocket session.
Args:
code:
The integer close code to indicate why the connection has closed.
reason:
Additional reasoning for why the connection has closed.
"""

def __init__(self, code: int = 1000, reason: typing.Optional[str] = None) -> None:
self.code = code
self.reason = reason or ""


class ASGIWebSocketTransportError(Exception):
pass


class UnhandledASGIMessageType(ASGIWebSocketTransportError):
def __init__(self, message: Message) -> None:
self.message = message


class UnhandledWebSocketEvent(ASGIWebSocketTransportError):
def __init__(self, event: wsproto.events.Event) -> None:
self.event = event


class ASGIWebSocketAsyncNetworkStream(AsyncNetworkStream):
def __init__(self, app: ASGIApp, scope: Scope) -> None:
self.app = app
self.scope = scope
self._receive_queue: asyncio.Queue[Message] = asyncio.Queue()
self._send_queue: asyncio.Queue[Message] = asyncio.Queue()
self.connection = wsproto.WSConnection(wsproto.ConnectionType.SERVER)
self.connection.initiate_upgrade_connection(scope["headers"], scope["path"])
self.tasks: list[asyncio.Task[None]] = []

async def __aenter__(
self,
) -> tuple["ASGIWebSocketAsyncNetworkStream", bytes]:
self.exit_stack = contextlib.AsyncExitStack()
await self.exit_stack.__aenter__()

# Start the _run coroutine as a task
self._run_task = asyncio.create_task(self._run())
self.tasks.append(self._run_task)
self.exit_stack.push_async_callback(self._cancel_tasks)

await self.send({"type": "websocket.connect"})
message = await self.receive()

if message["type"] == "websocket.close":
await self.aclose()
raise WebSocketDisconnect(message["code"], message.get("reason"))

assert message["type"] == "websocket.accept"
return self, self._build_accept_response(message)

async def __aexit__(self, *args: typing.Any) -> None:
await self.aclose()
await self.exit_stack.aclose()

async def _cancel_tasks(self) -> None:
# Cancel all running tasks
for task in self.tasks:
task.cancel()
# Wait for tasks to be cancelled
await asyncio.gather(*self.tasks, return_exceptions=True)

async def read(self, max_bytes: int, timeout: typing.Optional[float] = None) -> bytes:
message: Message = await self.receive()
message_type = message["type"]

if message_type not in {"websocket.send", "websocket.close"}:
raise UnhandledASGIMessageType(message)

event: wsproto.events.Event
if message_type == "websocket.send":
data_str: typing.Optional[str] = message.get("text")
if data_str is not None:
event = wsproto.events.TextMessage(data_str)
else:
data_bytes: typing.Optional[bytes] = message.get("bytes")
if data_bytes is not None:
event = wsproto.events.BytesMessage(data_bytes)
else:
# If neither text nor bytes are provided, raise an error
raise ValueError("websocket.send message missing 'text' or 'bytes'")
elif message_type == "websocket.close":
event = wsproto.events.CloseConnection(message["code"], message.get("reason"))

return self.connection.send(event)

async def write(self, buffer: bytes, timeout: typing.Optional[float] = None) -> None:
self.connection.receive_data(buffer)
for event in self.connection.events():
if isinstance(event, wsproto.events.Request):
pass # Already handled in __init__
elif isinstance(event, wsproto.events.CloseConnection):
await self.send(
{
"type": "websocket.close",
"code": event.code,
"reason": event.reason,
}
)
elif isinstance(event, wsproto.events.TextMessage):
await self.send({"type": "websocket.receive", "text": event.data})
elif isinstance(event, wsproto.events.BytesMessage):
await self.send({"type": "websocket.receive", "bytes": event.data})
else:
raise UnhandledWebSocketEvent(event)

async def aclose(self) -> None:
await self.send({"type": "websocket.close"})
# Ensure tasks are cancelled and cleaned up
await self._cancel_tasks()

async def send(self, message: Message) -> None:
await self._receive_queue.put(message)

async def receive(self, timeout: typing.Optional[float] = None) -> Message:
try:
message = await asyncio.wait_for(self._send_queue.get(), timeout)
return message
except asyncio.TimeoutError:
raise TimeoutError("Timed out waiting for message")

async def _run(self) -> None:
"""
The coroutine in which the websocket session runs.
"""
scope = self.scope
receive = self._asgi_receive
send = self._asgi_send
try:
await self.app(scope, receive, send)
except Exception as e:
message = {
"type": "websocket.close",
"code": CloseReason.INTERNAL_ERROR,
"reason": str(e),
}
await self._asgi_send(message)

async def _asgi_receive(self) -> Message:
return await self._receive_queue.get()

async def _asgi_send(self, message: Message) -> None:
await self._send_queue.put(message)

def _build_accept_response(self, message: Message) -> bytes:
subprotocol = message.get("subprotocol", None)
headers = message.get("headers", [])
return self.connection.send(
wsproto.events.AcceptConnection(
subprotocol=subprotocol,
extra_headers=headers,
)
)


class ASGIWebSocketTransport(ASGITransport):
def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None:
super().__init__(*args, **kwargs)
self.exit_stack: typing.Optional[contextlib.AsyncExitStack] = None

async def handle_async_request(self, request: Request) -> Response:
scheme = request.url.scheme
headers = request.headers

if scheme in {"ws", "wss"} or headers.get("upgrade") == "websocket":
subprotocols: list[str] = []
if (subprotocols_header := headers.get("sec-websocket-protocol")) is not None:
subprotocols = subprotocols_header.split(",")

scope = {
"type": "websocket",
"path": request.url.path,
"raw_path": request.url.raw_path,
"root_path": self.root_path,
"scheme": scheme,
"query_string": request.url.query,
"headers": [(k.lower(), v) for (k, v) in request.headers.raw],
"client": self.client,
"server": (request.url.host, request.url.port),
"subprotocols": subprotocols,
}
return await self._handle_ws_request(request, scope)

return await super().handle_async_request(request)

async def _handle_ws_request(
self,
request: Request,
scope: Scope,
) -> Response:
assert isinstance(request.stream, AsyncByteStream)

self.scope = scope
self.exit_stack = contextlib.AsyncExitStack()
stream, accept_response = await self.exit_stack.enter_async_context(
ASGIWebSocketAsyncNetworkStream(self.app, self.scope) # type: ignore[arg-type]
)

accept_response_lines = accept_response.decode("utf-8").splitlines()
headers = [
typing.cast(tuple[str, str], line.split(": ", 1))
for line in accept_response_lines[1:]
if line.strip() != ""
]

return Response(
status_code=101,
headers=headers,
extensions={"network_stream": stream},
)

async def aclose(self) -> None:
if self.exit_stack:
await self.exit_stack.aclose()

0 comments on commit 78942cd

Please sign in to comment.