diff --git a/src/phoenix/server/api/helpers/playground_clients.py b/src/phoenix/server/api/helpers/playground_clients.py index 6d481c68b2..76a96f6205 100644 --- a/src/phoenix/server/api/helpers/playground_clients.py +++ b/src/phoenix/server/api/helpers/playground_clients.py @@ -132,7 +132,7 @@ async def wrapper(*args: Any, **kwargs: Any) -> GenericType: request_start_time = time.time() maybe_coroutine = fn(*args, **kwargs) if inspect.isawaitable(maybe_coroutine): - return await maybe_coroutine # type: ignore + return await maybe_coroutine # type: ignore[no-any-return] else: return maybe_coroutine except self._rate_limit_error: @@ -144,10 +144,11 @@ async def wrapper(*args: Any, **kwargs: Any) -> GenericType: try: request_start_time = time.time() await self._throttler.async_wait_until_ready() - if inspect.iscoroutinefunction(fn): - return await fn(*args, **kwargs) # type: ignore + maybe_coroutine = fn(*args, **kwargs) + if inspect.isawaitable(maybe_coroutine): + return await maybe_coroutine # type: ignore[no-any-return] else: - return fn(*args, **kwargs) + return maybe_coroutine except self._rate_limit_error: self._throttler.on_rate_limit_error( request_start_time, verbose=self._verbose diff --git a/src/phoenix/server/api/subscriptions.py b/src/phoenix/server/api/subscriptions.py index 1c34611675..72b6a3a1fc 100644 --- a/src/phoenix/server/api/subscriptions.py +++ b/src/phoenix/server/api/subscriptions.py @@ -1,11 +1,11 @@ import asyncio import logging -from asyncio import FIRST_COMPLETED, Queue, QueueEmpty, Task, create_task, wait, wait_for from collections.abc import AsyncIterator, Iterator from datetime import datetime, timedelta, timezone from typing import ( Any, AsyncGenerator, + Coroutine, Iterable, Mapping, Optional, @@ -287,7 +287,7 @@ async def chat_completion_over_dataset( experiment=to_gql_experiment(experiment) ) # eagerly yields experiment so it can be linked by consumers of the subscription - results: Queue[ChatCompletionResult] = Queue() + results: asyncio.Queue[ChatCompletionResult] = asyncio.Queue() not_started: list[tuple[DatasetExampleID, ChatStream]] = [ ( GlobalID(DatasetExample.__name__, str(revision.dataset_example_id)), @@ -303,7 +303,11 @@ async def chat_completion_over_dataset( for revision in revisions ] in_progress: list[ - tuple[Optional[DatasetExampleID], ChatStream, Task[ChatCompletionSubscriptionPayload]] + tuple[ + Optional[DatasetExampleID], + ChatStream, + asyncio.Task[ChatCompletionSubscriptionPayload], + ] ] = [] max_in_progress = 3 write_batch_size = 10 @@ -315,7 +319,9 @@ async def chat_completion_over_dataset( task = _create_task_with_timeout(stream) in_progress.append((ex_id, stream, task)) async_tasks_to_run = [task for _, _, task in in_progress] - completed_tasks, _ = await wait(async_tasks_to_run, return_when=FIRST_COMPLETED) + completed_tasks, _ = await asyncio.wait( + async_tasks_to_run, return_when=asyncio.FIRST_COMPLETED + ) for completed_task in completed_tasks: idx = [task for _, _, task in in_progress].index(completed_task) example_id, stream, _ = in_progress[idx] @@ -327,7 +333,7 @@ async def chat_completion_over_dataset( del in_progress[idx] # removes timed-out stream if example_id is not None: yield ChatCompletionSubscriptionError( - message="Timed out", dataset_example_id=example_id + message="Playground task timed out", dataset_example_id=example_id ) except Exception as error: del in_progress[idx] # removes failed stream @@ -368,7 +374,7 @@ async def _stream_chat_completion_over_dataset_example( input: ChatCompletionOverDatasetInput, llm_client: PlaygroundStreamingClient, revision: models.DatasetExampleRevision, - results: Queue[ChatCompletionResult], + results: asyncio.Queue[ChatCompletionResult], experiment_id: int, project_id: int, ) -> ChatStream: @@ -470,23 +476,52 @@ def _is_result_payloads_stream( def _create_task_with_timeout( iterable: AsyncIterator[GenericType], timeout_in_seconds: int = 90 -) -> Task[GenericType]: - return create_task(wait_for(_as_coroutine(iterable), timeout=timeout_in_seconds)) +) -> asyncio.Task[GenericType]: + return asyncio.create_task( + _wait_for( + _as_coroutine(iterable), + timeout=timeout_in_seconds, + timeout_message="Playground task timed out", + ) + ) + + +async def _wait_for( + coro: Coroutine[None, None, GenericType], + timeout: float, + timeout_message: Optional[str] = None, +) -> GenericType: + """ + A function that imitates asyncio.wait_for, but allows the task to be + cancelled with a custom message. + """ + task = asyncio.create_task(coro) + done, pending = await asyncio.wait([task], timeout=timeout) + assert len(done) + len(pending) == 1 + if done: + task = done.pop() + return task.result() + task = pending.pop() + task.cancel(msg=timeout_message) + try: + return await task + except asyncio.CancelledError: + raise asyncio.TimeoutError() -async def _drain(queue: Queue[GenericType]) -> list[GenericType]: +async def _drain(queue: asyncio.Queue[GenericType]) -> list[GenericType]: values: list[GenericType] = [] while not queue.empty(): values.append(await queue.get()) return values -def _drain_no_wait(queue: Queue[GenericType]) -> list[GenericType]: +def _drain_no_wait(queue: asyncio.Queue[GenericType]) -> list[GenericType]: values: list[GenericType] = [] while True: try: values.append(queue.get_nowait()) - except QueueEmpty: + except asyncio.QueueEmpty: break return values