From 78942cd67068b6f7992825ba68348ec1cb2f03ac Mon Sep 17 00:00:00 2001 From: Xander Song Date: Mon, 28 Oct 2024 15:52:29 -0700 Subject: [PATCH] fix(playground): ensure chat completion subscription unit tests run on all os and on postgres (#5205) Co-authored-by: Dustin Ngo --- tests/unit/conftest.py | 2 +- tests/unit/server/api/test_subscriptions.py | 35 +-- tests/unit/transport.py | 256 ++++++++++++++++++++ 3 files changed, 277 insertions(+), 16 deletions(-) create mode 100644 tests/unit/transport.py diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index ac3630a7db..4b05a563cb 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -20,7 +20,6 @@ from faker import Faker from httpx import AsyncByteStream, Request, Response from httpx_ws import AsyncWebSocketSession, aconnect_ws -from httpx_ws.transport import ASGIWebSocketTransport from psycopg import Connection from pytest_postgresql import factories from sqlalchemy import URL, make_url @@ -42,6 +41,7 @@ from phoenix.server.types import BatchedCaller, DbSessionFactory from phoenix.session.client import Client from phoenix.trace.schemas import Span +from tests.unit.transport import ASGIWebSocketTransport def pytest_terminal_summary( diff --git a/tests/unit/server/api/test_subscriptions.py b/tests/unit/server/api/test_subscriptions.py index 560ebae0de..760de2eaf8 100644 --- a/tests/unit/server/api/test_subscriptions.py +++ b/tests/unit/server/api/test_subscriptions.py @@ -1,10 +1,8 @@ import json -import sys from datetime import datetime from pathlib import Path from typing import Any -import pytest from openinference.semconv.trace import ( OpenInferenceMimeTypeValues, OpenInferenceSpanKindValues, @@ -53,14 +51,6 @@ def test_openai() -> None: return response -@pytest.mark.skipif( - sys.platform - in ( - "win32", - "linux", - ), # todo: support windows and linux https://github.com/Arize-ai/phoenix/issues/5126 - reason="subscriptions are not currently supported on windows or linux", -) class TestChatCompletionSubscription: QUERY = """ subscription ChatCompletionSubscription($input: ChatCompletionInput!) { @@ -188,6 +178,10 @@ async def test_openai_text_response_emits_expected_payloads_and_records_expected query=self.QUERY, variables={"spanId": span_id}, operation_name="SpanQuery" ) span = data["span"] + assert json.loads(attributes := span.pop("attributes")) == json.loads( + subscription_span.pop("attributes") + ) + attributes = dict(flatten(json.loads(attributes))) assert span == subscription_span # check attributes @@ -201,7 +195,6 @@ async def test_openai_text_response_emits_expected_payloads_and_records_expected assert span.pop("parentId") is None assert span.pop("spanKind") == "llm" assert (context := span.pop("context")).pop("spanId") - assert (attributes := dict(flatten(json.loads(span.pop("attributes"))))) assert context.pop("traceId") assert not context assert span.pop("metadata") is None @@ -317,6 +310,10 @@ async def test_openai_emits_expected_payloads_and_records_expected_span_on_error query=self.QUERY, variables={"spanId": span_id}, operation_name="SpanQuery" ) span = data["span"] + assert json.loads(attributes := span.pop("attributes")) == json.loads( + subscription_span.pop("attributes") + ) + attributes = dict(flatten(json.loads(attributes))) assert span == subscription_span # check attributes @@ -330,7 +327,6 @@ async def test_openai_emits_expected_payloads_and_records_expected_span_on_error assert span.pop("parentId") is None assert span.pop("spanKind") == "llm" assert (context := span.pop("context")).pop("spanId") - assert (attributes := dict(flatten(json.loads(span.pop("attributes"))))) assert context.pop("traceId") assert not context assert span.pop("metadata") is None @@ -457,6 +453,10 @@ async def test_openai_tool_call_response_emits_expected_payloads_and_records_exp query=self.QUERY, variables={"spanId": span_id}, operation_name="SpanQuery" ) span = data["span"] + assert json.loads(attributes := span.pop("attributes")) == json.loads( + subscription_span.pop("attributes") + ) + attributes = dict(flatten(json.loads(attributes))) assert span == subscription_span # check attributes @@ -470,7 +470,6 @@ async def test_openai_tool_call_response_emits_expected_payloads_and_records_exp assert span.pop("parentId") is None assert span.pop("spanKind") == "llm" assert (context := span.pop("context")).pop("spanId") - assert (attributes := dict(flatten(json.loads(span.pop("attributes"))))) assert context.pop("traceId") assert not context assert span.pop("metadata") is None @@ -607,6 +606,10 @@ async def test_openai_tool_call_messages_emits_expected_payloads_and_records_exp query=self.QUERY, variables={"spanId": span_id}, operation_name="SpanQuery" ) span = data["span"] + assert json.loads(attributes := span.pop("attributes")) == json.loads( + subscription_span.pop("attributes") + ) + attributes = dict(flatten(json.loads(attributes))) assert span == subscription_span # check attributes @@ -620,7 +623,6 @@ async def test_openai_tool_call_messages_emits_expected_payloads_and_records_exp assert span.pop("parentId") is None assert span.pop("spanKind") == "llm" assert (context := span.pop("context")).pop("spanId") - assert (attributes := dict(flatten(json.loads(span.pop("attributes"))))) assert context.pop("traceId") assert not context assert span.pop("metadata") is None @@ -750,6 +752,10 @@ async def test_anthropic_text_response_emits_expected_payloads_and_records_expec query=self.QUERY, variables={"spanId": span_id}, operation_name="SpanQuery" ) span = data["span"] + assert json.loads(attributes := span.pop("attributes")) == json.loads( + subscription_span.pop("attributes") + ) + attributes = dict(flatten(json.loads(attributes))) assert span == subscription_span # check attributes @@ -763,7 +769,6 @@ async def test_anthropic_text_response_emits_expected_payloads_and_records_expec assert span.pop("parentId") is None assert span.pop("spanKind") == "llm" assert (context := span.pop("context")).pop("spanId") - assert (attributes := dict(flatten(json.loads(span.pop("attributes"))))) assert context.pop("traceId") assert not context assert span.pop("metadata") is None diff --git a/tests/unit/transport.py b/tests/unit/transport.py new file mode 100644 index 0000000000..b93a998928 --- /dev/null +++ b/tests/unit/transport.py @@ -0,0 +1,256 @@ +""" +This file contains a copy of [httpx-ws](https://github.com/frankie567/httpx-ws), +which is published under an [MIT +license](https://github.com/frankie567/httpx-ws/blob/main/LICENSE). +Modifications have been made to better support the concurrency paradigm used in +our unit test suite. +""" + +import asyncio +import contextlib +import typing + +import wsproto +from httpcore import AsyncNetworkStream +from httpx import ASGITransport, AsyncByteStream, Request, Response +from wsproto.frame_protocol import CloseReason + +Scope = dict[str, typing.Any] +Message = dict[str, typing.Any] +Receive = typing.Callable[[], typing.Awaitable[Message]] +Send = typing.Callable[[Scope], typing.Coroutine[None, None, None]] +ASGIApp = typing.Callable[[Scope, Receive, Send], typing.Coroutine[None, None, None]] + + +class HTTPXWSException(Exception): + """ + Base exception class for HTTPX WS. + """ + + pass + + +class WebSocketDisconnect(HTTPXWSException): + """ + Raised when the server closed the WebSocket session. + + Args: + code: + The integer close code to indicate why the connection has closed. + reason: + Additional reasoning for why the connection has closed. + """ + + def __init__(self, code: int = 1000, reason: typing.Optional[str] = None) -> None: + self.code = code + self.reason = reason or "" + + +class ASGIWebSocketTransportError(Exception): + pass + + +class UnhandledASGIMessageType(ASGIWebSocketTransportError): + def __init__(self, message: Message) -> None: + self.message = message + + +class UnhandledWebSocketEvent(ASGIWebSocketTransportError): + def __init__(self, event: wsproto.events.Event) -> None: + self.event = event + + +class ASGIWebSocketAsyncNetworkStream(AsyncNetworkStream): + def __init__(self, app: ASGIApp, scope: Scope) -> None: + self.app = app + self.scope = scope + self._receive_queue: asyncio.Queue[Message] = asyncio.Queue() + self._send_queue: asyncio.Queue[Message] = asyncio.Queue() + self.connection = wsproto.WSConnection(wsproto.ConnectionType.SERVER) + self.connection.initiate_upgrade_connection(scope["headers"], scope["path"]) + self.tasks: list[asyncio.Task[None]] = [] + + async def __aenter__( + self, + ) -> tuple["ASGIWebSocketAsyncNetworkStream", bytes]: + self.exit_stack = contextlib.AsyncExitStack() + await self.exit_stack.__aenter__() + + # Start the _run coroutine as a task + self._run_task = asyncio.create_task(self._run()) + self.tasks.append(self._run_task) + self.exit_stack.push_async_callback(self._cancel_tasks) + + await self.send({"type": "websocket.connect"}) + message = await self.receive() + + if message["type"] == "websocket.close": + await self.aclose() + raise WebSocketDisconnect(message["code"], message.get("reason")) + + assert message["type"] == "websocket.accept" + return self, self._build_accept_response(message) + + async def __aexit__(self, *args: typing.Any) -> None: + await self.aclose() + await self.exit_stack.aclose() + + async def _cancel_tasks(self) -> None: + # Cancel all running tasks + for task in self.tasks: + task.cancel() + # Wait for tasks to be cancelled + await asyncio.gather(*self.tasks, return_exceptions=True) + + async def read(self, max_bytes: int, timeout: typing.Optional[float] = None) -> bytes: + message: Message = await self.receive() + message_type = message["type"] + + if message_type not in {"websocket.send", "websocket.close"}: + raise UnhandledASGIMessageType(message) + + event: wsproto.events.Event + if message_type == "websocket.send": + data_str: typing.Optional[str] = message.get("text") + if data_str is not None: + event = wsproto.events.TextMessage(data_str) + else: + data_bytes: typing.Optional[bytes] = message.get("bytes") + if data_bytes is not None: + event = wsproto.events.BytesMessage(data_bytes) + else: + # If neither text nor bytes are provided, raise an error + raise ValueError("websocket.send message missing 'text' or 'bytes'") + elif message_type == "websocket.close": + event = wsproto.events.CloseConnection(message["code"], message.get("reason")) + + return self.connection.send(event) + + async def write(self, buffer: bytes, timeout: typing.Optional[float] = None) -> None: + self.connection.receive_data(buffer) + for event in self.connection.events(): + if isinstance(event, wsproto.events.Request): + pass # Already handled in __init__ + elif isinstance(event, wsproto.events.CloseConnection): + await self.send( + { + "type": "websocket.close", + "code": event.code, + "reason": event.reason, + } + ) + elif isinstance(event, wsproto.events.TextMessage): + await self.send({"type": "websocket.receive", "text": event.data}) + elif isinstance(event, wsproto.events.BytesMessage): + await self.send({"type": "websocket.receive", "bytes": event.data}) + else: + raise UnhandledWebSocketEvent(event) + + async def aclose(self) -> None: + await self.send({"type": "websocket.close"}) + # Ensure tasks are cancelled and cleaned up + await self._cancel_tasks() + + async def send(self, message: Message) -> None: + await self._receive_queue.put(message) + + async def receive(self, timeout: typing.Optional[float] = None) -> Message: + try: + message = await asyncio.wait_for(self._send_queue.get(), timeout) + return message + except asyncio.TimeoutError: + raise TimeoutError("Timed out waiting for message") + + async def _run(self) -> None: + """ + The coroutine in which the websocket session runs. + """ + scope = self.scope + receive = self._asgi_receive + send = self._asgi_send + try: + await self.app(scope, receive, send) + except Exception as e: + message = { + "type": "websocket.close", + "code": CloseReason.INTERNAL_ERROR, + "reason": str(e), + } + await self._asgi_send(message) + + async def _asgi_receive(self) -> Message: + return await self._receive_queue.get() + + async def _asgi_send(self, message: Message) -> None: + await self._send_queue.put(message) + + def _build_accept_response(self, message: Message) -> bytes: + subprotocol = message.get("subprotocol", None) + headers = message.get("headers", []) + return self.connection.send( + wsproto.events.AcceptConnection( + subprotocol=subprotocol, + extra_headers=headers, + ) + ) + + +class ASGIWebSocketTransport(ASGITransport): + def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: + super().__init__(*args, **kwargs) + self.exit_stack: typing.Optional[contextlib.AsyncExitStack] = None + + async def handle_async_request(self, request: Request) -> Response: + scheme = request.url.scheme + headers = request.headers + + if scheme in {"ws", "wss"} or headers.get("upgrade") == "websocket": + subprotocols: list[str] = [] + if (subprotocols_header := headers.get("sec-websocket-protocol")) is not None: + subprotocols = subprotocols_header.split(",") + + scope = { + "type": "websocket", + "path": request.url.path, + "raw_path": request.url.raw_path, + "root_path": self.root_path, + "scheme": scheme, + "query_string": request.url.query, + "headers": [(k.lower(), v) for (k, v) in request.headers.raw], + "client": self.client, + "server": (request.url.host, request.url.port), + "subprotocols": subprotocols, + } + return await self._handle_ws_request(request, scope) + + return await super().handle_async_request(request) + + async def _handle_ws_request( + self, + request: Request, + scope: Scope, + ) -> Response: + assert isinstance(request.stream, AsyncByteStream) + + self.scope = scope + self.exit_stack = contextlib.AsyncExitStack() + stream, accept_response = await self.exit_stack.enter_async_context( + ASGIWebSocketAsyncNetworkStream(self.app, self.scope) # type: ignore[arg-type] + ) + + accept_response_lines = accept_response.decode("utf-8").splitlines() + headers = [ + typing.cast(tuple[str, str], line.split(": ", 1)) + for line in accept_response_lines[1:] + if line.strip() != "" + ] + + return Response( + status_code=101, + headers=headers, + extensions={"network_stream": stream}, + ) + + async def aclose(self) -> None: + if self.exit_stack: + await self.exit_stack.aclose()