Skip to content

Commit

Permalink
BulkInsterters insert immediately in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
anticorrelator committed Sep 17, 2024
1 parent 26b64ac commit e14a126
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 27 deletions.
4 changes: 3 additions & 1 deletion src/phoenix/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,7 +621,9 @@ def create_app(
shutdown_callbacks: Iterable[_Callback] = (),
secret: Optional[str] = None,
scaffolder_config: Optional[ScaffolderConfig] = None,
bulk_inserter_factory: Optional[Callable[..., BulkInserter]] = None,
) -> FastAPI:
bulk_inserter_factory = bulk_inserter_factory or BulkInserter
startup_callbacks_list: List[_Callback] = list(startup_callbacks)
shutdown_callbacks_list: List[_Callback] = list(shutdown_callbacks)
startup_callbacks_list.append(Facilitator(db))
Expand Down Expand Up @@ -654,7 +656,7 @@ def create_app(
cache_for_dataloaders=cache_for_dataloaders,
last_updated_at=last_updated_at,
)
bulk_inserter = BulkInserter(
bulk_inserter = bulk_inserter_factory(
db,
enable_prometheus=enable_prometheus,
event_queue=dml_event_handler,
Expand Down
75 changes: 56 additions & 19 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,20 @@
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
from starlette.types import ASGIApp

import phoenix.trace.v1 as pb
from phoenix.config import EXPORT_DIR
from phoenix.core.model_schema_adapter import create_model_from_inferences
from phoenix.db import models
from phoenix.db.bulk_inserter import BulkInserter
from phoenix.db.engines import aio_postgresql_engine, aio_sqlite_engine
from phoenix.db.insertion.helpers import DataManipulation
from phoenix.inferences.inferences import EMPTY_INFERENCES
from phoenix.pointcloud.umap_parameters import get_umap_parameters
from phoenix.server.app import _db, create_app
from phoenix.server.grpc_server import GrpcServer
from phoenix.server.types import BatchedCaller, DbSessionFactory
from phoenix.session.client import Client
from phoenix.trace.schemas import Span


def pytest_addoption(parser: Parser) -> None:
Expand Down Expand Up @@ -185,7 +188,6 @@ async def app(
) -> AsyncIterator[ASGIApp]:
async with contextlib.AsyncExitStack() as stack:
await stack.enter_async_context(patch_batched_caller())
await stack.enter_async_context(patch_bulk_inserter())
await stack.enter_async_context(patch_grpc_server())
app = create_app(
db=db,
Expand All @@ -194,6 +196,7 @@ async def app(
export_path=EXPORT_DIR,
umap_params=get_umap_parameters(None),
serve_ui=False,
bulk_inserter_factory=TestBulkInserter,
)
manager = await stack.enter_async_context(LifespanManager(app))
yield manager.app
Expand All @@ -203,31 +206,38 @@ async def app(
def httpx_clients(
app: ASGIApp,
) -> Tuple[httpx.Client, httpx.AsyncClient]:
class Transport(httpx.BaseTransport, httpx.AsyncBaseTransport):
def __init__(self, transport: httpx.ASGITransport) -> None:
class Transport(httpx.BaseTransport):
def __init__(self, app, asgi_transport):
import nest_asyncio

nest_asyncio.apply()

self.transport = transport
self.app = app
self.asgi_transport = asgi_transport

def handle_request(self, request: Request) -> Response:
return asyncio.run(self.handle_async_request(request))
response = asyncio.run(self.asgi_transport.handle_async_request(request))

async def handle_async_request(self, request: Request) -> Response:
response = await self.transport.handle_async_request(request)
async def read_stream():
content = b""
async for chunk in response.stream:
content += chunk
return content

content = asyncio.run(read_stream())
return Response(
status_code=response.status_code,
headers=response.headers,
content=b"".join([_ async for _ in response.stream]),
content=content,
request=request,
)

transport = Transport(httpx.ASGITransport(app))
asgi_transport = httpx.ASGITransport(app=app)
transport = Transport(httpx.ASGITransport(app), asgi_transport=asgi_transport)
base_url = "http://test"
return (
httpx.Client(transport=transport, base_url=base_url),
httpx.AsyncClient(transport=transport, base_url=base_url),
httpx.AsyncClient(transport=asgi_transport, base_url=base_url),
)


Expand Down Expand Up @@ -266,15 +276,42 @@ async def patch_grpc_server() -> AsyncIterator[None]:
setattr(cls, name, original)


@contextlib.asynccontextmanager
async def patch_bulk_inserter() -> AsyncIterator[None]:
cls = BulkInserter
original = cls.__init__
name = original.__name__
changes = {"sleep": 0.001, "retry_delay_sec": 0.001, "retry_allowance": 1000}
setattr(cls, name, lambda *_, **__: original(*_, **{**__, **changes}))
yield
setattr(cls, name, original)
class TestBulkInserter(BulkInserter):
async def __aenter__(
self,
) -> Tuple[
Callable[[Any], Awaitable[None]],
Callable[[Span, str], Awaitable[None]],
Callable[[pb.Evaluation], Awaitable[None]],
Callable[[DataManipulation], None],
]:
# Return the overridden methods
return (
self._enqueue_immediate,
self._queue_span_immediate,
self._queue_evaluation_immediate,
self._enqueue_operation_immediate,
)

async def __aexit__(self, *args: Any) -> None:
# No background tasks to cancel
pass

async def _enqueue_immediate(self, *items: Any) -> None:
# Process items immediately
await self._queue_inserters.enqueue(*items)
async for event in self._queue_inserters.insert():
self._event_queue.put(event)

async def _enqueue_operation_immediate(self, operation: DataManipulation) -> None:
async with self._db() as session:
await operation(session)

async def _queue_span_immediate(self, span: Span, project_name: str) -> None:
await self._insert_spans([(span, project_name)])

async def _queue_evaluation_immediate(self, evaluation: pb.Evaluation) -> None:
await self._insert_evaluations([evaluation])


@contextlib.asynccontextmanager
Expand Down
15 changes: 8 additions & 7 deletions tests/datasets/test_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from unittest.mock import patch

import httpx
import pytest
from sqlalchemy import select
from strawberry.relay import GlobalID

Expand Down Expand Up @@ -36,8 +37,8 @@ async def test_run_experiment(
simple_dataset: Any,
dialect: str,
) -> None:
# if dialect == "postgresql":
# pytest.xfail("TODO: Convert this to an integration test")
if dialect == "postgresql":
pytest.xfail("TODO: Convert this to an integration test")

async with db() as session:
nonexistent_experiment = (await session.execute(select(models.Experiment))).scalar()
Expand Down Expand Up @@ -102,7 +103,7 @@ def experiment_task(_) -> Dict[str, str]:

# Wait until all evaluations are complete
async def wait_for_evaluations():
timeout = 30
timeout = 15
interval = 0.5
total_wait = 0
while total_wait < timeout:
Expand Down Expand Up @@ -180,8 +181,8 @@ async def test_run_experiment_with_llm_eval(
simple_dataset: Any,
dialect: str,
) -> None:
# if dialect == "postgresql":
# pytest.xfail("This test fails on PostgreSQL")
if dialect == "postgresql":
pytest.xfail("This test fails on PostgreSQL")

async with db() as session:
nonexistent_experiment = (await session.execute(select(models.Experiment))).scalar()
Expand Down Expand Up @@ -292,8 +293,8 @@ async def test_run_evaluation(
simple_dataset_with_one_experiment_run: Any,
dialect: str,
) -> None:
# if dialect == "postgresql":
# pytest.xfail("This test fails on PostgreSQL")
if dialect == "postgresql":
pytest.xfail("This test fails on PostgreSQL")

experiment = Experiment(
id=str(GlobalID("Experiment", "0")),
Expand Down

0 comments on commit e14a126

Please sign in to comment.