From 334a9a94bad2e1384c443f5996516a443edcd5b1 Mon Sep 17 00:00:00 2001 From: Roger Yang <80478925+RogerHYang@users.noreply.github.com> Date: Tue, 30 Jul 2024 15:59:51 -0700 Subject: [PATCH] fix: use outerjoin for evals filter (#4066) --- src/phoenix/trace/dsl/filter.py | 8 ++--- tests/conftest.py | 2 ++ tests/trace/dsl/conftest.py | 54 +++++++++++++++++++++++++++------ tests/trace/dsl/test_query.py | 37 +++++++++++++++++++++- 4 files changed, 85 insertions(+), 16 deletions(-) diff --git a/src/phoenix/trace/dsl/filter.py b/src/phoenix/trace/dsl/filter.py index 020a6ff003..03e6d33643 100644 --- a/src/phoenix/trace/dsl/filter.py +++ b/src/phoenix/trace/dsl/filter.py @@ -56,11 +56,7 @@ def __post_init__(self) -> None: table = aliased(models.SpanAnnotation, name=table_alias) object.__setattr__(self, "_label_attribute_alias", label_attribute_alias) object.__setattr__(self, "_score_attribute_alias", score_attribute_alias) - object.__setattr__( - self, - "table", - table, - ) + object.__setattr__(self, "table", table) @property def attributes(self) -> typing.Iterator[typing.Tuple[str, Mapped[typing.Any]]]: @@ -236,7 +232,7 @@ def _join_aliased_relations(self, stmt: Select[typing.Any]) -> Select[typing.Any for eval_alias in self._aliased_annotation_relations: eval_name = eval_alias.name AliasedSpanAnnotation = eval_alias.table - stmt = stmt.join( + stmt = stmt.outerjoin( AliasedSpanAnnotation, onclause=( sqlalchemy.and_( diff --git a/tests/conftest.py b/tests/conftest.py index 1f2cfcdd33..7ff3a5c193 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -96,6 +96,7 @@ async def postgresql_engine(postgresql_url: URL) -> AsyncIterator[AsyncEngine]: engine = aio_postgresql_engine(postgresql_url, migrate=False) async with engine.begin() as conn: await conn.run_sync(models.Base.metadata.create_all) + engine.echo = True yield engine await engine.dispose() @@ -110,6 +111,7 @@ async def sqlite_engine() -> AsyncIterator[AsyncEngine]: engine = aio_sqlite_engine(make_url("sqlite+aiosqlite://"), migrate=False, shared_cache=False) async with engine.begin() as conn: await conn.run_sync(models.Base.metadata.create_all) + engine.echo = True yield engine await engine.dispose() diff --git a/tests/trace/dsl/conftest.py b/tests/trace/dsl/conftest.py index 5a40bb93fe..d7f6d1389d 100644 --- a/tests/trace/dsl/conftest.py +++ b/tests/trace/dsl/conftest.py @@ -13,7 +13,7 @@ async def default_project(db: DbSessionFactory) -> None: project_row_id = await session.scalar( insert(models.Project).values(name=DEFAULT_PROJECT_NAME).returning(models.Project.id) ) - trace_row_id = await session.scalar( + trace_rowid = await session.scalar( insert(models.Trace) .values( trace_id="0123", @@ -26,7 +26,7 @@ async def default_project(db: DbSessionFactory) -> None: await session.execute( insert(models.Span) .values( - trace_rowid=trace_row_id, + trace_rowid=trace_rowid, span_id="2345", parent_id=None, name="root span", @@ -49,7 +49,7 @@ async def default_project(db: DbSessionFactory) -> None: await session.execute( insert(models.Span) .values( - trace_rowid=trace_row_id, + trace_rowid=trace_rowid, span_id="4567", parent_id="2345", name="retriever span", @@ -85,7 +85,7 @@ async def abc_project(db: DbSessionFactory) -> None: project_row_id = await session.scalar( insert(models.Project).values(name="abc").returning(models.Project.id) ) - trace_row_id = await session.scalar( + trace_rowid = await session.scalar( insert(models.Trace) .values( trace_id="012", @@ -98,7 +98,7 @@ async def abc_project(db: DbSessionFactory) -> None: await session.execute( insert(models.Span) .values( - trace_rowid=trace_row_id, + trace_rowid=trace_rowid, span_id="234", parent_id="123", name="root span", @@ -118,10 +118,10 @@ async def abc_project(db: DbSessionFactory) -> None: ) .returning(models.Span.id) ) - await session.execute( + span_rowid = await session.scalar( insert(models.Span) .values( - trace_rowid=trace_row_id, + trace_rowid=trace_rowid, span_id="345", parent_id="234", name="embedding span", @@ -155,9 +155,18 @@ async def abc_project(db: DbSessionFactory) -> None: .returning(models.Span.id) ) await session.execute( + insert(models.SpanAnnotation).values( + span_rowid=span_rowid, + annotator_kind="LLM", + name="0", + score=0, + metadata_={}, + ) + ) + span_rowid = await session.scalar( insert(models.Span) .values( - trace_rowid=trace_row_id, + trace_rowid=trace_rowid, span_id="456", parent_id="234", name="retriever span", @@ -187,9 +196,27 @@ async def abc_project(db: DbSessionFactory) -> None: .returning(models.Span.id) ) await session.execute( + insert(models.SpanAnnotation).values( + span_rowid=span_rowid, + annotator_kind="LLM", + name="0", + score=1, + metadata_={}, + ) + ) + await session.execute( + insert(models.SpanAnnotation).values( + span_rowid=span_rowid, + annotator_kind="LLM", + name="1", + label="1", + metadata_={}, + ) + ) + span_rowid = await session.scalar( insert(models.Span) .values( - trace_rowid=trace_row_id, + trace_rowid=trace_rowid, span_id="567", parent_id="234", name="llm span", @@ -214,3 +241,12 @@ async def abc_project(db: DbSessionFactory) -> None: ) .returning(models.Span.id) ) + await session.execute( + insert(models.SpanAnnotation).values( + span_rowid=span_rowid, + annotator_kind="LLM", + name="1", + label="0", + metadata_={}, + ) + ) diff --git a/tests/trace/dsl/test_query.py b/tests/trace/dsl/test_query.py index 61a308b173..e09cdd3f96 100644 --- a/tests/trace/dsl/test_query.py +++ b/tests/trace/dsl/test_query.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Any +from typing import Any, List import pandas as pd import pytest @@ -891,6 +891,41 @@ async def test_filter_on_trace_id_multiple( ) +@pytest.mark.parametrize( + "condition,expected", + [ + ["evals['0'].score is not None", ["345", "456"]], + ["evals['0'].score is None", ["234", "567"]], + ["evals['0'].score == 0", ["345"]], + ["evals['0'].score != 0", ["456"]], + ["evals['0'].score != 0 or evals['0'].score is None", ["234", "456", "567"]], + ["evals['1'].label is not None", ["456", "567"]], + ["evals['1'].label is None", ["234", "345"]], + ["evals['1'].label == '1'", ["456"]], + ["evals['1'].label != '1'", ["567"]], + ["evals['1'].label != '1' or evals['1'].label is None", ["234", "345", "567"]], + ["evals['0'].score is not None or evals['1'].label is not None", ["345", "456", "567"]], + ["evals['0'].score is None or evals['1'].label is None", ["234", "345", "567"]], + ["evals['0'].score == 0 or evals['1'].label == '1'", ["345", "456"]], + ["evals['0'].score != 0 or evals['1'].label != '1'", ["456", "567"]], + ["evals['0'].score is not None or evals['1'].label is None", ["234", "345", "456"]], + ["evals['0'].score is None or evals['1'].label is not None", ["234", "456", "567"]], + ["evals['0'].score == 0 or evals['1'].label != '1'", ["345", "567"]], + ["evals['0'].score != 0 or evals['1'].label == '1'", ["456"]], + ], +) +async def test_filter_on_span_annotation( + db: DbSessionFactory, + abc_project: Any, + condition: str, + expected: List[str], +) -> None: + sq = SpanQuery().select("span_id").where(condition) + async with db() as session: + actual = await session.run_sync(sq, project_name="abc") + assert sorted(actual.index) == expected + + async def test_explode_embeddings_no_select( db: DbSessionFactory, default_project: Any,