From cdd44432cb9a6a610f8ee2fe8aedad2422c4caa3 Mon Sep 17 00:00:00 2001 From: Roger Yang Date: Mon, 29 Jul 2024 08:58:09 -0700 Subject: [PATCH 1/4] fix: use outerjoin for evals filter --- src/phoenix/trace/dsl/filter.py | 8 +--- tests/conftest.py | 2 + tests/trace/dsl/conftest.py | 54 +++++++++++++++++++++----- tests/trace/dsl/test_query.py | 69 ++++++++++++++++++++++++++++++++- 4 files changed, 117 insertions(+), 16 deletions(-) diff --git a/src/phoenix/trace/dsl/filter.py b/src/phoenix/trace/dsl/filter.py index 020a6ff003..03e6d33643 100644 --- a/src/phoenix/trace/dsl/filter.py +++ b/src/phoenix/trace/dsl/filter.py @@ -56,11 +56,7 @@ def __post_init__(self) -> None: table = aliased(models.SpanAnnotation, name=table_alias) object.__setattr__(self, "_label_attribute_alias", label_attribute_alias) object.__setattr__(self, "_score_attribute_alias", score_attribute_alias) - object.__setattr__( - self, - "table", - table, - ) + object.__setattr__(self, "table", table) @property def attributes(self) -> typing.Iterator[typing.Tuple[str, Mapped[typing.Any]]]: @@ -236,7 +232,7 @@ def _join_aliased_relations(self, stmt: Select[typing.Any]) -> Select[typing.Any for eval_alias in self._aliased_annotation_relations: eval_name = eval_alias.name AliasedSpanAnnotation = eval_alias.table - stmt = stmt.join( + stmt = stmt.outerjoin( AliasedSpanAnnotation, onclause=( sqlalchemy.and_( diff --git a/tests/conftest.py b/tests/conftest.py index e66ef2154d..6343e2794f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -94,6 +94,7 @@ async def postgresql_engine(postgresql_url: URL) -> AsyncIterator[AsyncEngine]: engine = aio_postgresql_engine(postgresql_url, migrate=False) async with engine.begin() as conn: await conn.run_sync(models.Base.metadata.create_all) + engine.echo = True yield engine await engine.dispose() @@ -108,6 +109,7 @@ async def sqlite_engine() -> AsyncIterator[AsyncEngine]: engine = aio_sqlite_engine(make_url("sqlite+aiosqlite://"), migrate=False, shared_cache=False) 
async with engine.begin() as conn: await conn.run_sync(models.Base.metadata.create_all) + engine.echo = True yield engine await engine.dispose() diff --git a/tests/trace/dsl/conftest.py b/tests/trace/dsl/conftest.py index 5a40bb93fe..d7f6d1389d 100644 --- a/tests/trace/dsl/conftest.py +++ b/tests/trace/dsl/conftest.py @@ -13,7 +13,7 @@ async def default_project(db: DbSessionFactory) -> None: project_row_id = await session.scalar( insert(models.Project).values(name=DEFAULT_PROJECT_NAME).returning(models.Project.id) ) - trace_row_id = await session.scalar( + trace_rowid = await session.scalar( insert(models.Trace) .values( trace_id="0123", @@ -26,7 +26,7 @@ async def default_project(db: DbSessionFactory) -> None: await session.execute( insert(models.Span) .values( - trace_rowid=trace_row_id, + trace_rowid=trace_rowid, span_id="2345", parent_id=None, name="root span", @@ -49,7 +49,7 @@ async def default_project(db: DbSessionFactory) -> None: await session.execute( insert(models.Span) .values( - trace_rowid=trace_row_id, + trace_rowid=trace_rowid, span_id="4567", parent_id="2345", name="retriever span", @@ -85,7 +85,7 @@ async def abc_project(db: DbSessionFactory) -> None: project_row_id = await session.scalar( insert(models.Project).values(name="abc").returning(models.Project.id) ) - trace_row_id = await session.scalar( + trace_rowid = await session.scalar( insert(models.Trace) .values( trace_id="012", @@ -98,7 +98,7 @@ async def abc_project(db: DbSessionFactory) -> None: await session.execute( insert(models.Span) .values( - trace_rowid=trace_row_id, + trace_rowid=trace_rowid, span_id="234", parent_id="123", name="root span", @@ -118,10 +118,10 @@ async def abc_project(db: DbSessionFactory) -> None: ) .returning(models.Span.id) ) - await session.execute( + span_rowid = await session.scalar( insert(models.Span) .values( - trace_rowid=trace_row_id, + trace_rowid=trace_rowid, span_id="345", parent_id="234", name="embedding span", @@ -155,9 +155,18 @@ async def 
abc_project(db: DbSessionFactory) -> None: .returning(models.Span.id) ) await session.execute( + insert(models.SpanAnnotation).values( + span_rowid=span_rowid, + annotator_kind="LLM", + name="0", + score=0, + metadata_={}, + ) + ) + span_rowid = await session.scalar( insert(models.Span) .values( - trace_rowid=trace_row_id, + trace_rowid=trace_rowid, span_id="456", parent_id="234", name="retriever span", @@ -187,9 +196,27 @@ async def abc_project(db: DbSessionFactory) -> None: .returning(models.Span.id) ) await session.execute( + insert(models.SpanAnnotation).values( + span_rowid=span_rowid, + annotator_kind="LLM", + name="0", + score=1, + metadata_={}, + ) + ) + await session.execute( + insert(models.SpanAnnotation).values( + span_rowid=span_rowid, + annotator_kind="LLM", + name="1", + label="1", + metadata_={}, + ) + ) + span_rowid = await session.scalar( insert(models.Span) .values( - trace_rowid=trace_row_id, + trace_rowid=trace_rowid, span_id="567", parent_id="234", name="llm span", @@ -214,3 +241,12 @@ async def abc_project(db: DbSessionFactory) -> None: ) .returning(models.Span.id) ) + await session.execute( + insert(models.SpanAnnotation).values( + span_rowid=span_rowid, + annotator_kind="LLM", + name="1", + label="0", + metadata_={}, + ) + ) diff --git a/tests/trace/dsl/test_query.py b/tests/trace/dsl/test_query.py index 61a308b173..a74cb14251 100644 --- a/tests/trace/dsl/test_query.py +++ b/tests/trace/dsl/test_query.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Any +from typing import Any, List import pandas as pd import pytest @@ -347,6 +347,26 @@ async def test_filter_for_none( check_index_type=False, ) + sq = ( + SpanQuery() + .select("name") + .where( + "parent_id is not None", + ) + ) + expected = pd.DataFrame( + { + "context.span_id": ["234", "345", "456", "567"], + "name": ["root span", "embedding span", "retriever span", "llm span"], + } + ).set_index("context.span_id") + async with db() as session: + actual = await 
session.run_sync(sq, project_name="abc") + assert_frame_equal( + actual.sort_index().sort_index(axis=1), + expected.sort_index().sort_index(axis=1), + ) + async def test_filter_for_not_none( db: DbSessionFactory, @@ -373,6 +393,26 @@ async def test_filter_for_not_none( expected.sort_index().sort_index(axis=1), ) + sq = ( + SpanQuery() + .select("name") + .where( + "output.value is None", + ) + ) + expected = pd.DataFrame( + { + "context.span_id": ["345", "456", "567"], + "name": ["embedding span", "retriever span", "llm span"], + } + ).set_index("context.span_id") + async with db() as session: + actual = await session.run_sync(sq, project_name="abc") + assert_frame_equal( + actual.sort_index().sort_index(axis=1), + expected.sort_index().sort_index(axis=1), + ) + async def test_filter_for_substring_case_sensitive_not_glob_not_like( db: DbSessionFactory, @@ -891,6 +931,33 @@ async def test_filter_on_trace_id_multiple( ) +@pytest.mark.parametrize( + "condition,expected", + [ + ["evals['0'].score is not None", ["345", "456"]], + ["evals['0'].score is None", ["234", "567"]], + ["evals['0'].score == 0", ["345"]], + ["evals['0'].score != 0", ["456"]], + ["evals['0'].score != 0 or evals['0'].score is None", ["234", "456", "567"]], + ["evals['1'].label is not None", ["456", "567"]], + ["evals['1'].label is None", ["234", "345"]], + ["evals['1'].label == '1'", ["456"]], + ["evals['1'].label != '1'", ["567"]], + ["evals['1'].label != '1' or evals['1'].label is None", ["234", "345", "567"]], + ], +) +async def test_filter_on_span_annotation( + db: DbSessionFactory, + abc_project: Any, + condition: str, + expected: List[str], +) -> None: + sq = SpanQuery().select("span_id").where(condition) + async with db() as session: + actual = await session.run_sync(sq, project_name="abc") + assert sorted(actual.index) == expected + + async def test_explode_embeddings_no_select( db: DbSessionFactory, default_project: Any, From 41656ca8ac416854eb5d9a0ef7d7a952200a6932 Mon Sep 17 00:00:00 
2001 From: Roger Yang <80478925+RogerHYang@users.noreply.github.com> Date: Tue, 30 Jul 2024 10:57:01 -0700 Subject: [PATCH 2/4] remove inadvertent diffs --- tests/trace/dsl/test_query.py | 40 ----------------------------------- 1 file changed, 40 deletions(-) diff --git a/tests/trace/dsl/test_query.py b/tests/trace/dsl/test_query.py index a74cb14251..7a4c2b99ef 100644 --- a/tests/trace/dsl/test_query.py +++ b/tests/trace/dsl/test_query.py @@ -347,26 +347,6 @@ async def test_filter_for_none( check_index_type=False, ) - sq = ( - SpanQuery() - .select("name") - .where( - "parent_id is not None", - ) - ) - expected = pd.DataFrame( - { - "context.span_id": ["234", "345", "456", "567"], - "name": ["root span", "embedding span", "retriever span", "llm span"], - } - ).set_index("context.span_id") - async with db() as session: - actual = await session.run_sync(sq, project_name="abc") - assert_frame_equal( - actual.sort_index().sort_index(axis=1), - expected.sort_index().sort_index(axis=1), - ) - async def test_filter_for_not_none( db: DbSessionFactory, @@ -373,26 +373,6 @@ async def test_filter_for_not_none( expected.sort_index().sort_index(axis=1), ) - sq = ( - SpanQuery() - .select("name") - .where( - "output.value is None", - ) - ) - expected = pd.DataFrame( - { - "context.span_id": ["345", "456", "567"], - "name": ["embedding span", "retriever span", "llm span"], - } - ).set_index("context.span_id") - async with db() as session: - actual = await session.run_sync(sq, project_name="abc") - assert_frame_equal( - actual.sort_index().sort_index(axis=1), - expected.sort_index().sort_index(axis=1), - ) - async def test_filter_for_substring_case_sensitive_not_glob_not_like( db: DbSessionFactory, From 47aa4e0c9d46ee936120bce5ee2a524a242f10a1 Mon Sep 17 00:00:00 2001 From: Roger Yang Date: Tue, 30 Jul 2024 12:04:31 -0700 Subject: [PATCH 3/4] clean up --- tests/trace/dsl/test_query.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/trace/dsl/test_query.py 
b/tests/trace/dsl/test_query.py index 7a4c2b99ef..d6ff89b293 100644 --- a/tests/trace/dsl/test_query.py +++ b/tests/trace/dsl/test_query.py @@ -904,6 +904,14 @@ async def test_filter_on_trace_id_multiple( ["evals['1'].label == '1'", ["456"]], ["evals['1'].label != '1'", ["567"]], ["evals['1'].label != '1' or evals['1'].label is None", ["234", "345", "567"]], + ["evals['0'].score is not None or evals['1'].label is not None", ["345", "456", "567"]], + ["evals['0'].score is None or evals['1'].label is None", ["234", "345", "567"]], + ["evals['0'].score == 0 or evals['1'].label == 1", ["345", "456"]], + ["evals['0'].score != 0 or evals['1'].label != 1", ["456", "567"]], + ["evals['0'].score is not None or evals['1'].label is None", ["234", "345", "456"]], + ["evals['0'].score is None or evals['1'].label is not None", ["234", "456", "567"]], + ["evals['0'].score == 0 or evals['1'].label != 1", ["345", "567"]], + ["evals['0'].score != 0 or evals['1'].label == 1", ["456"]], ], ) async def test_filter_on_span_annotation( From 27451aa72368d68b0b8ad83170881994dd533563 Mon Sep 17 00:00:00 2001 From: Roger Yang Date: Tue, 30 Jul 2024 15:45:05 -0700 Subject: [PATCH 4/4] fix typo --- tests/trace/dsl/test_query.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/trace/dsl/test_query.py b/tests/trace/dsl/test_query.py index d6ff89b293..e09cdd3f96 100644 --- a/tests/trace/dsl/test_query.py +++ b/tests/trace/dsl/test_query.py @@ -906,12 +906,12 @@ async def test_filter_on_trace_id_multiple( ["evals['1'].label != '1' or evals['1'].label is None", ["234", "345", "567"]], ["evals['0'].score is not None or evals['1'].label is not None", ["345", "456", "567"]], ["evals['0'].score is None or evals['1'].label is None", ["234", "345", "567"]], - ["evals['0'].score == 0 or evals['1'].label == 1", ["345", "456"]], - ["evals['0'].score != 0 or evals['1'].label != 1", ["456", "567"]], + ["evals['0'].score == 0 or evals['1'].label == '1'", ["345", "456"]], + 
["evals['0'].score != 0 or evals['1'].label != '1'", ["456", "567"]], ["evals['0'].score is not None or evals['1'].label is None", ["234", "345", "456"]], ["evals['0'].score is None or evals['1'].label is not None", ["234", "456", "567"]], - ["evals['0'].score == 0 or evals['1'].label != 1", ["345", "567"]], - ["evals['0'].score != 0 or evals['1'].label == 1", ["456"]], + ["evals['0'].score == 0 or evals['1'].label != '1'", ["345", "567"]], + ["evals['0'].score != 0 or evals['1'].label == '1'", ["456"]], ], ) async def test_filter_on_span_annotation(