fix: Fix DB unittest reliability (#4548)
* Remove busywait

* Ruff 🐶

* Use closure loop

* Use nest-asyncio for nested asgi fixture management

* Wait for db insertions before reading in test

* Ensure the entire experiment has run

* Experiment with locks

* xfail unstable tests

* Use asyncio.sleep before querying database after client interactions (see the sketch below)

* Ruff 🐶

* Reduce number of evaluators to make tests more reliable

* Only bypass lock for unittests

* Convert to an integration test

* Set default loop scope for unit tests

* Remove loop policy

* xfail tests where evals do not reliably write to the database

* Ensure databases are function scoped

* Ensure in-memory sqlite testing

* Ruff 🐶

* Wipe DBs between tests

* Continue github actions on error

* Use async sleep in spans test

* Remove needless import

* Refactor engine setup to potentially reduce deadlock risk

* Wait for evaluations for more stable tests

* Don't continue on failure

* Ruff 🐶

* BulkInserters insert immediately in tests

* Remove xdist

* Increase timeout to 30

* Xfail test

* Use shared cache

* Use tempfile based sqlite db

* Use tempdirs for windows compatibility

* Xfail test again

* Add a waiter to llm eval test

* Skip flaky tests only on windows and mac
anticorrelator authored and RogerHYang committed Sep 21, 2024
1 parent 8be4330 commit 29460c5
Showing 8 changed files with 219 additions and 128 deletions.
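A recurring fix in the commits above is to yield control to the event loop so that queued database insertions finish before a test reads them back ("Use asyncio.sleep before querying database after client interactions"). A toy, self-contained sketch of that pattern, illustrative only and not code from this commit:

import asyncio

async def main() -> None:
    db: list = []

    async def background_insert() -> None:
        await asyncio.sleep(0)  # stands in for the bulk inserter's deferred write
        db.append("span")

    asyncio.get_running_loop().create_task(background_insert())
    await asyncio.sleep(0.01)  # yield to the loop so the insertion lands before we assert
    assert db == ["span"]

asyncio.run(main())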
10 changes: 6 additions & 4 deletions .github/workflows/python-CI.yml
@@ -133,16 +133,18 @@ jobs:
       - name: Run tests (Ubuntu)
         if: runner.os == 'Linux'
         run: |
-          hatch run test:tests --run-postgres --allow-flaky
+          hatch run test:tests --run-postgres
+        continue-on-error: false
       - name: Run tests (macOS)
         if: runner.os == 'macOS'
         run: |
-          hatch run test:tests --allow-flaky
+          hatch run test:tests
+        continue-on-error: false
       - name: Run tests (Windows)
         if: runner.os == 'Windows'
         run: |
-          hatch run test:tests --allow-flaky
+          hatch run test:tests
+        continue-on-error: false
   integration-test:
     runs-on: ${{ matrix.os }}
     needs: changes
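For context, --run-postgres (and the now-deleted --allow-flaky) are custom pytest flags registered by the repository's conftest.py through the pytest_addoption hook, which appears truncated in the conftest diff below. A minimal sketch of how such a flag is declared; the help text and default here are assumptions, not the project's exact code:

from pytest import Parser

def pytest_addoption(parser: Parser) -> None:
    parser.addoption(
        "--run-postgres",
        action="store_true",
        default=False,
        help="Also run the PostgreSQL-backed database tests.",
    )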
5 changes: 2 additions & 3 deletions pyproject.toml
@@ -163,7 +163,6 @@ dependencies = [
   "pandas==2.2.2; python_version>='3.9'",
   "pandas==1.4.0; python_version<'3.9'",
   "pytest==8.3.2",
-  "pytest-xdist",
   "pytest-asyncio",
   "pytest-cov",
   "pytest-postgresql",
@@ -249,21 +248,21 @@ dependencies = [
 ]

 [tool.hatch.envs.default.scripts]
-tests = "pytest -n auto {args}"
+tests = "pytest {args}"
 coverage = "pytest --cov-report=term-missing --cov-config=pyproject.toml --cov=src/phoenix --cov=tests {args}"

 [[tool.hatch.envs.test.matrix]]
 python = ["3.8", "3.12"]

 [tool.pytest.ini_options]
-asyncio_mode = "auto"
+asyncio_default_fixture_loop_scope="function"
 addopts = [
   "-rA",
   "--import-mode=importlib",
   "--doctest-modules",
   "--new-first",
   "--showlocals",
   "--exitfirst",
 ]
 testpaths = [
   "tests",
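The new asyncio_default_fixture_loop_scope="function" setting tells pytest-asyncio to run each async fixture on the same per-test event loop as the test that uses it, so nothing loop-bound leaks from one test into the next. A minimal sketch of the behavior this buys; an assumed test file (requires pytest-asyncio), not part of this commit, with the explicit marker included to keep the sketch self-contained under the plugin's default strict mode:

import asyncio

import pytest

@pytest.fixture
async def queue() -> "asyncio.Queue[int]":
    # Created in the test's own event loop context, never a stale one
    # left over from an earlier test.
    return asyncio.Queue()

@pytest.mark.asyncio
async def test_queue(queue: "asyncio.Queue[int]") -> None:
    await queue.put(1)
    assert await queue.get() == 1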
10 changes: 7 additions & 3 deletions src/phoenix/server/app.py
@@ -249,13 +249,15 @@ async def version() -> PlainTextResponse:
 DB_MUTEX: Optional[asyncio.Lock] = None


-def _db(engine: AsyncEngine) -> Callable[[], AsyncContextManager[AsyncSession]]:
+def _db(
+    engine: AsyncEngine, bypass_lock: bool = False
+) -> Callable[[], AsyncContextManager[AsyncSession]]:
     Session = async_sessionmaker(engine, expire_on_commit=False)

     @contextlib.asynccontextmanager
     async def factory() -> AsyncIterator[AsyncSession]:
         async with contextlib.AsyncExitStack() as stack:
-            if DB_MUTEX:
+            if not bypass_lock and DB_MUTEX:
                 await stack.enter_async_context(DB_MUTEX)
             yield await stack.enter_async_context(Session.begin())

@@ -626,8 +628,10 @@ def create_app(
     scaffolder_config: Optional[ScaffolderConfig] = None,
     email_sender: Optional[EmailSender] = None,
     oauth2_client_configs: Optional[List[OAuth2ClientConfig]] = None,
+    bulk_inserter_factory: Optional[Callable[..., BulkInserter]] = None,
 ) -> FastAPI:
     logger.info(f"Server umap params: {umap_params}")
+    bulk_inserter_factory = bulk_inserter_factory or BulkInserter
     startup_callbacks_list: List[_Callback] = list(startup_callbacks)
     shutdown_callbacks_list: List[_Callback] = list(shutdown_callbacks)
     startup_callbacks_list.append(Facilitator(db=db))
@@ -660,7 +664,7 @@
         cache_for_dataloaders=cache_for_dataloaders,
         last_updated_at=last_updated_at,
     )
-    bulk_inserter = BulkInserter(
+    bulk_inserter = bulk_inserter_factory(
         db,
         enable_prometheus=enable_prometheus,
         event_queue=dml_event_handler,
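The bypass_lock flag lets test sessions skip the global write mutex that serializes production writers. A condensed, standalone sketch of the same pattern; names like make_factory and the string "session" placeholder are simplifications, not the app's real objects:

import asyncio
import contextlib
from typing import AsyncIterator, Optional

DB_MUTEX: Optional[asyncio.Lock] = None

def make_factory(bypass_lock: bool = False):
    @contextlib.asynccontextmanager
    async def factory() -> AsyncIterator[str]:
        async with contextlib.AsyncExitStack() as stack:
            if not bypass_lock and DB_MUTEX:
                # Serialize writers: one session holds the lock at a time.
                await stack.enter_async_context(DB_MUTEX)
            yield "session"  # stands in for Session.begin()
    return factory

async def demo() -> None:
    global DB_MUTEX
    DB_MUTEX = asyncio.Lock()  # created on the running loop, as in app.py
    async with make_factory(bypass_lock=True)() as session:
        print(session)  # entered without touching DB_MUTEX

asyncio.run(demo())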
171 changes: 96 additions & 75 deletions tests/conftest.py
@@ -1,7 +1,8 @@
 import asyncio
 import contextlib
-import time
-from asyncio import AbstractEventLoop, get_running_loop
+import os
+import tempfile
+from asyncio import AbstractEventLoop
 from functools import partial
 from importlib.metadata import version
 from random import getrandbits
@@ -31,17 +32,20 @@
 from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
 from starlette.types import ASGIApp

+import phoenix.trace.v1 as pb
 from phoenix.config import EXPORT_DIR
 from phoenix.core.model_schema_adapter import create_model_from_inferences
 from phoenix.db import models
 from phoenix.db.bulk_inserter import BulkInserter
 from phoenix.db.engines import aio_postgresql_engine, aio_sqlite_engine
+from phoenix.db.insertion.helpers import DataManipulation
 from phoenix.inferences.inferences import EMPTY_INFERENCES
 from phoenix.pointcloud.umap_parameters import get_umap_parameters
 from phoenix.server.app import _db, create_app
 from phoenix.server.grpc_server import GrpcServer
 from phoenix.server.types import BatchedCaller, DbSessionFactory
 from phoenix.session.client import Client
+from phoenix.trace.schemas import Span


def pytest_addoption(parser: Parser) -> None:
@@ -65,20 +69,21 @@ def pytest_terminal_summary(

     xfail_threshold = 12  # our tests are currently quite unreliable

-    terminalreporter.write_sep("=", f"xfail threshold: {xfail_threshold}")
-    terminalreporter.write_sep("=", f"xpasses: {xpasses}, xfails: {xfails}")
+    if config.getoption("--run-postgres"):
+        terminalreporter.write_sep("=", f"xfail threshold: {xfail_threshold}")
+        terminalreporter.write_sep("=", f"xpasses: {xpasses}, xfails: {xfails}")

-    if exitstatus == pytest.ExitCode.OK:
-        if xfails < xfail_threshold:
-            terminalreporter.write_sep(
-                "=", "Within xfail threshold. Passing the test suite.", green=True
-            )
-            terminalreporter._session.exitstatus = pytest.ExitCode.OK
-        else:
-            terminalreporter.write_sep(
-                "=", "Too many flaky tests. Failing the test suite.", red=True
-            )
-            terminalreporter._session.exitstatus = pytest.ExitCode.TESTS_FAILED
+        if exitstatus == pytest.ExitCode.OK:
+            if xfails < xfail_threshold:
+                terminalreporter.write_sep(
+                    "=", "Within xfail threshold. Passing the test suite.", green=True
+                )
+                terminalreporter._session.exitstatus = pytest.ExitCode.OK
+            else:
+                terminalreporter.write_sep(
+                    "=", "Too many flaky tests. Failing the test suite.", red=True
+                )
+                terminalreporter._session.exitstatus = pytest.ExitCode.TESTS_FAILED


def pytest_collection_modifyitems(config: Config, items: List[Any]) -> None:
@@ -89,11 +94,6 @@ def pytest_collection_modifyitems(config: Config, items: List[Any]) -> None:
         if "postgresql" in item.callspec.params.values():
             item.add_marker(skip_postgres)

-    if config.getoption("--allow-flaky"):
-        for item in items:
-            if "dialect" in item.fixturenames:
-                item.add_marker(pytest.mark.xfail(reason="database tests are currently flaky"))
-

@pytest.fixture
def pydantic_version() -> Literal["v1", "v2"]:
@@ -116,7 +116,7 @@ def openai_api_key(monkeypatch: pytest.MonkeyPatch) -> str:
 postgresql_connection = factories.postgresql("postgresql_proc")


-@pytest.fixture()
+@pytest.fixture(scope="function")
 async def postgresql_url(postgresql_connection: Connection) -> AsyncIterator[URL]:
     connection = postgresql_connection
     user = connection.info.user
@@ -127,10 +127,11 @@ async def postgresql_url(postgresql_connection: Connection) -> AsyncIterator[URL]:
     yield make_url(f"postgresql+asyncpg://{user}:{password}@{host}:{port}/{database}")


-@pytest.fixture
+@pytest.fixture(scope="function")
 async def postgresql_engine(postgresql_url: URL) -> AsyncIterator[AsyncEngine]:
     engine = aio_postgresql_engine(postgresql_url, migrate=False)
     async with engine.begin() as conn:
+        await conn.run_sync(models.Base.metadata.drop_all)
         await conn.run_sync(models.Base.metadata.create_all)
     yield engine
     await engine.dispose()
@@ -141,33 +142,36 @@ def dialect(request: SubRequest) -> str:
     return request.param


-@pytest.fixture
+@pytest.fixture(scope="function")
 async def sqlite_engine() -> AsyncIterator[AsyncEngine]:
-    engine = aio_sqlite_engine(make_url("sqlite+aiosqlite://"), migrate=False, shared_cache=False)
-    async with engine.begin() as conn:
-        await conn.run_sync(models.Base.metadata.create_all)
-    yield engine
-    await engine.dispose()
+    with tempfile.TemporaryDirectory() as temp_dir:
+        db_file = os.path.join(temp_dir, "test.db")
+        engine = aio_sqlite_engine(make_url(f"sqlite+aiosqlite:///{db_file}"), migrate=False)
+        async with engine.begin() as conn:
+            await conn.run_sync(models.Base.metadata.drop_all)
+            await conn.run_sync(models.Base.metadata.create_all)
+        yield engine
+        await engine.dispose()


-@pytest.fixture
+@pytest.fixture(scope="function")
 def db(
     request: SubRequest,
     dialect: str,
 ) -> DbSessionFactory:
     if dialect == "sqlite":
-        return _db_with_lock(request.getfixturevalue("sqlite_engine"))
+        return db_session_factory(request.getfixturevalue("sqlite_engine"))
     elif dialect == "postgresql":
-        return _db_with_lock(request.getfixturevalue("postgresql_engine"))
+        return db_session_factory(request.getfixturevalue("postgresql_engine"))
     raise ValueError(f"Unknown db fixture: {dialect}")


-def _db_with_lock(engine: AsyncEngine) -> DbSessionFactory:
-    lock, db = asyncio.Lock(), _db(engine)
+def db_session_factory(engine: AsyncEngine) -> DbSessionFactory:
+    db = _db(engine, bypass_lock=True)

     @contextlib.asynccontextmanager
     async def factory() -> AsyncIterator[AsyncSession]:
-        async with lock, db() as session:
+        async with db() as session:
             yield session

     return DbSessionFactory(db=factory, dialect=engine.dialect.name)
@@ -186,7 +190,6 @@ async def app(
 ) -> AsyncIterator[ASGIApp]:
     async with contextlib.AsyncExitStack() as stack:
         await stack.enter_async_context(patch_batched_caller())
-        await stack.enter_async_context(patch_bulk_inserter())
         await stack.enter_async_context(patch_grpc_server())
         app = create_app(
             db=db,
@@ -195,57 +198,48 @@
             export_path=EXPORT_DIR,
             umap_params=get_umap_parameters(None),
             serve_ui=False,
+            bulk_inserter_factory=TestBulkInserter,
         )
         manager = await stack.enter_async_context(LifespanManager(app))
         yield manager.app


-@pytest.fixture(scope="session")
-def event_loop_policy():
-    try:
-        import uvloop
-    except ImportError:
-        return asyncio.DefaultEventLoopPolicy()
-    return uvloop.EventLoopPolicy()
-
-
-@pytest.fixture
-async def loop() -> AbstractEventLoop:
-    return get_running_loop()
-
-
 @pytest.fixture
 def httpx_clients(
     app: ASGIApp,
-    loop: AbstractEventLoop,
 ) -> Tuple[httpx.Client, httpx.AsyncClient]:
-    class Transport(httpx.BaseTransport, httpx.AsyncBaseTransport):
-        def __init__(self, transport: httpx.ASGITransport) -> None:
-            self.transport = transport
+    class Transport(httpx.BaseTransport):
+        def __init__(self, app, asgi_transport):
+            import nest_asyncio
+
+            nest_asyncio.apply()
+
+            self.app = app
+            self.asgi_transport = asgi_transport

         def handle_request(self, request: Request) -> Response:
-            fut = loop.create_task(self.handle_async_request(request))
-            time_cutoff = time.time() + 10
-            while not fut.done() and time.time() < time_cutoff:
-                time.sleep(0.01)
-            if fut.done():
-                return fut.result()
-            raise TimeoutError
-
-        async def handle_async_request(self, request: Request) -> Response:
-            response = await self.transport.handle_async_request(request)
+            response = asyncio.run(self.asgi_transport.handle_async_request(request))
+
+            async def read_stream():
+                content = b""
+                async for chunk in response.stream:
+                    content += chunk
+                return content
+
+            content = asyncio.run(read_stream())
             return Response(
                 status_code=response.status_code,
                 headers=response.headers,
-                content=b"".join([_ async for _ in response.stream]),
+                content=content,
                 request=request,
             )

-    transport = Transport(httpx.ASGITransport(app))
+    asgi_transport = httpx.ASGITransport(app=app)
+    transport = Transport(httpx.ASGITransport(app), asgi_transport=asgi_transport)
     base_url = "http://test"
     return (
         httpx.Client(transport=transport, base_url=base_url),
-        httpx.AsyncClient(transport=transport, base_url=base_url),
+        httpx.AsyncClient(transport=asgi_transport, base_url=base_url),
     )


Expand Down Expand Up @@ -284,15 +278,42 @@ async def patch_grpc_server() -> AsyncIterator[None]:
setattr(cls, name, original)


@contextlib.asynccontextmanager
async def patch_bulk_inserter() -> AsyncIterator[None]:
cls = BulkInserter
original = cls.__init__
name = original.__name__
changes = {"sleep": 0.001, "retry_delay_sec": 0.001, "retry_allowance": 1000}
setattr(cls, name, lambda *_, **__: original(*_, **{**__, **changes}))
yield
setattr(cls, name, original)
class TestBulkInserter(BulkInserter):
async def __aenter__(
self,
) -> Tuple[
Callable[[Any], Awaitable[None]],
Callable[[Span, str], Awaitable[None]],
Callable[[pb.Evaluation], Awaitable[None]],
Callable[[DataManipulation], None],
]:
# Return the overridden methods
return (
self._enqueue_immediate,
self._queue_span_immediate,
self._queue_evaluation_immediate,
self._enqueue_operation_immediate,
)

async def __aexit__(self, *args: Any) -> None:
# No background tasks to cancel
pass

async def _enqueue_immediate(self, *items: Any) -> None:
# Process items immediately
await self._queue_inserters.enqueue(*items)
async for event in self._queue_inserters.insert():
self._event_queue.put(event)

async def _enqueue_operation_immediate(self, operation: DataManipulation) -> None:
async with self._db() as session:
await operation(session)

async def _queue_span_immediate(self, span: Span, project_name: str) -> None:
await self._insert_spans([(span, project_name)])

async def _queue_evaluation_immediate(self, evaluation: pb.Evaluation) -> None:
await self._insert_evaluations([evaluation])


@contextlib.asynccontextmanager
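The rewritten Transport.handle_request above leans on nest-asyncio: it calls asyncio.run() from code that is already executing inside a running event loop, which stock asyncio refuses to do. A minimal demonstration of just that mechanism, not project code:

import asyncio

import nest_asyncio

nest_asyncio.apply()  # patch asyncio so run() tolerates re-entrant calls

async def fetch() -> str:
    await asyncio.sleep(0)
    return "ok"

def sync_bridge() -> str:
    # Without nest_asyncio.apply(), this line raises RuntimeError:
    # "asyncio.run() cannot be called from a running event loop".
    return asyncio.run(fetch())

async def main() -> None:
    # A synchronous call made from async context, the same shape as the
    # sync httpx client driving the ASGI transport inside an async test.
    print(sync_bridge())

asyncio.run(main())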
(Diffs for the remaining four changed files were not loaded.)
