Skip to content

Commit

Permalink
fix(playground): ensure playground timeout errors are displayed (#5486)
Browse files Browse the repository at this point in the history
Co-authored-by: Dustin Ngo <dustin@arize.com>
  • Loading branch information
axiomofjoy and anticorrelator authored Nov 22, 2024
1 parent 2d67b31 commit 38f8f56
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 15 deletions.
9 changes: 5 additions & 4 deletions src/phoenix/server/api/helpers/playground_clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ async def wrapper(*args: Any, **kwargs: Any) -> GenericType:
request_start_time = time.time()
maybe_coroutine = fn(*args, **kwargs)
if inspect.isawaitable(maybe_coroutine):
return await maybe_coroutine # type: ignore
return await maybe_coroutine # type: ignore[no-any-return]
else:
return maybe_coroutine
except self._rate_limit_error:
Expand All @@ -144,10 +144,11 @@ async def wrapper(*args: Any, **kwargs: Any) -> GenericType:
try:
request_start_time = time.time()
await self._throttler.async_wait_until_ready()
if inspect.iscoroutinefunction(fn):
return await fn(*args, **kwargs) # type: ignore
maybe_coroutine = fn(*args, **kwargs)
if inspect.isawaitable(maybe_coroutine):
return await maybe_coroutine # type: ignore[no-any-return]
else:
return fn(*args, **kwargs)
return maybe_coroutine
except self._rate_limit_error:
self._throttler.on_rate_limit_error(
request_start_time, verbose=self._verbose
Expand Down
57 changes: 46 additions & 11 deletions src/phoenix/server/api/subscriptions.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import asyncio
import logging
from asyncio import FIRST_COMPLETED, Queue, QueueEmpty, Task, create_task, wait, wait_for
from collections.abc import AsyncIterator, Iterator
from datetime import datetime, timedelta, timezone
from typing import (
Any,
AsyncGenerator,
Coroutine,
Iterable,
Mapping,
Optional,
Expand Down Expand Up @@ -287,7 +287,7 @@ async def chat_completion_over_dataset(
experiment=to_gql_experiment(experiment)
) # eagerly yields experiment so it can be linked by consumers of the subscription

results: Queue[ChatCompletionResult] = Queue()
results: asyncio.Queue[ChatCompletionResult] = asyncio.Queue()
not_started: list[tuple[DatasetExampleID, ChatStream]] = [
(
GlobalID(DatasetExample.__name__, str(revision.dataset_example_id)),
Expand All @@ -303,7 +303,11 @@ async def chat_completion_over_dataset(
for revision in revisions
]
in_progress: list[
tuple[Optional[DatasetExampleID], ChatStream, Task[ChatCompletionSubscriptionPayload]]
tuple[
Optional[DatasetExampleID],
ChatStream,
asyncio.Task[ChatCompletionSubscriptionPayload],
]
] = []
max_in_progress = 3
write_batch_size = 10
Expand All @@ -315,7 +319,9 @@ async def chat_completion_over_dataset(
task = _create_task_with_timeout(stream)
in_progress.append((ex_id, stream, task))
async_tasks_to_run = [task for _, _, task in in_progress]
completed_tasks, _ = await wait(async_tasks_to_run, return_when=FIRST_COMPLETED)
completed_tasks, _ = await asyncio.wait(
async_tasks_to_run, return_when=asyncio.FIRST_COMPLETED
)
for completed_task in completed_tasks:
idx = [task for _, _, task in in_progress].index(completed_task)
example_id, stream, _ = in_progress[idx]
Expand All @@ -327,7 +333,7 @@ async def chat_completion_over_dataset(
del in_progress[idx] # removes timed-out stream
if example_id is not None:
yield ChatCompletionSubscriptionError(
message="Timed out", dataset_example_id=example_id
message="Playground task timed out", dataset_example_id=example_id
)
except Exception as error:
del in_progress[idx] # removes failed stream
Expand Down Expand Up @@ -368,7 +374,7 @@ async def _stream_chat_completion_over_dataset_example(
input: ChatCompletionOverDatasetInput,
llm_client: PlaygroundStreamingClient,
revision: models.DatasetExampleRevision,
results: Queue[ChatCompletionResult],
results: asyncio.Queue[ChatCompletionResult],
experiment_id: int,
project_id: int,
) -> ChatStream:
Expand Down Expand Up @@ -470,23 +476,52 @@ def _is_result_payloads_stream(

def _create_task_with_timeout(
iterable: AsyncIterator[GenericType], timeout_in_seconds: int = 90
) -> Task[GenericType]:
return create_task(wait_for(_as_coroutine(iterable), timeout=timeout_in_seconds))
) -> asyncio.Task[GenericType]:
return asyncio.create_task(
_wait_for(
_as_coroutine(iterable),
timeout=timeout_in_seconds,
timeout_message="Playground task timed out",
)
)


async def _wait_for(
coro: Coroutine[None, None, GenericType],
timeout: float,
timeout_message: Optional[str] = None,
) -> GenericType:
"""
A function that imitates asyncio.wait_for, but allows the task to be
cancelled with a custom message.
"""
task = asyncio.create_task(coro)
done, pending = await asyncio.wait([task], timeout=timeout)
assert len(done) + len(pending) == 1
if done:
task = done.pop()
return task.result()
task = pending.pop()
task.cancel(msg=timeout_message)
try:
return await task
except asyncio.CancelledError:
raise asyncio.TimeoutError()


async def _drain(queue: Queue[GenericType]) -> list[GenericType]:
async def _drain(queue: asyncio.Queue[GenericType]) -> list[GenericType]:
values: list[GenericType] = []
while not queue.empty():
values.append(await queue.get())
return values


def _drain_no_wait(queue: Queue[GenericType]) -> list[GenericType]:
def _drain_no_wait(queue: asyncio.Queue[GenericType]) -> list[GenericType]:
values: list[GenericType] = []
while True:
try:
values.append(queue.get_nowait())
except QueueEmpty:
except asyncio.QueueEmpty:
break
return values

Expand Down

0 comments on commit 38f8f56

Please sign in to comment.