fix: use outerjoin for evals filter (#4066)
RogerHYang authored Jul 30, 2024
1 parent 13af3b5 commit 334a9a9
Showing 4 changed files with 85 additions and 16 deletions.
8 changes: 2 additions & 6 deletions src/phoenix/trace/dsl/filter.py
@@ -56,11 +56,7 @@ def __post_init__(self) -> None:
table = aliased(models.SpanAnnotation, name=table_alias)
object.__setattr__(self, "_label_attribute_alias", label_attribute_alias)
object.__setattr__(self, "_score_attribute_alias", score_attribute_alias)
object.__setattr__(
self,
"table",
table,
)
object.__setattr__(self, "table", table)

@property
def attributes(self) -> typing.Iterator[typing.Tuple[str, Mapped[typing.Any]]]:
@@ -236,7 +232,7 @@ def _join_aliased_relations(self, stmt: Select[typing.Any]) -> Select[typing.Any]:
for eval_alias in self._aliased_annotation_relations:
eval_name = eval_alias.name
AliasedSpanAnnotation = eval_alias.table
stmt = stmt.join(
stmt = stmt.outerjoin(
AliasedSpanAnnotation,
onclause=(
sqlalchemy.and_(
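
A note on the change above: the behavioral fix is the switch from join to outerjoin in _join_aliased_relations. With an inner join, spans that have no SpanAnnotation row for the referenced eval name are dropped before the filter predicate is evaluated, so conditions such as evals['0'].score is None can never match them; a left outer join keeps those spans with NULL annotation columns. A minimal standalone sketch of the difference, using toy tables rather than Phoenix's actual models (all names below are illustrative):

from sqlalchemy import Column, ForeignKey, Integer, String, and_, select
from sqlalchemy.orm import aliased, declarative_base

Base = declarative_base()


class Span(Base):
    __tablename__ = "spans"
    id = Column(Integer, primary_key=True)
    span_id = Column(String)


class SpanAnnotation(Base):
    __tablename__ = "span_annotations"
    id = Column(Integer, primary_key=True)
    span_rowid = Column(Integer, ForeignKey("spans.id"))
    name = Column(String)
    score = Column(Integer)


# One alias per referenced eval name, mirroring aliased(models.SpanAnnotation, ...)
EvalZero = aliased(SpanAnnotation, name="span_annotation_0")

# Inner join: spans with no annotation named "0" are removed before the WHERE
# clause runs, so `score IS NULL` can never select an un-annotated span.
inner = (
    select(Span.span_id)
    .join(EvalZero, onclause=and_(EvalZero.span_rowid == Span.id, EvalZero.name == "0"))
    .where(EvalZero.score.is_(None))
)

# Left outer join (this commit): un-annotated spans are kept with NULL
# annotation columns, so the same predicate now matches them.
outer = (
    select(Span.span_id)
    .outerjoin(EvalZero, onclause=and_(EvalZero.span_rowid == Span.id, EvalZero.name == "0"))
    .where(EvalZero.score.is_(None))
)

print(inner.compile(), outer.compile(), sep="\n\n")
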
2 changes: 2 additions & 0 deletions tests/conftest.py
@@ -96,6 +96,7 @@ async def postgresql_engine(postgresql_url: URL) -> AsyncIterator[AsyncEngine]:
engine = aio_postgresql_engine(postgresql_url, migrate=False)
async with engine.begin() as conn:
await conn.run_sync(models.Base.metadata.create_all)
engine.echo = True
yield engine
await engine.dispose()

@@ -110,6 +111,7 @@ async def sqlite_engine() -> AsyncIterator[AsyncEngine]:
engine = aio_sqlite_engine(make_url("sqlite+aiosqlite://"), migrate=False, shared_cache=False)
async with engine.begin() as conn:
await conn.run_sync(models.Base.metadata.create_all)
engine.echo = True
yield engine
await engine.dispose()

54 changes: 45 additions & 9 deletions tests/trace/dsl/conftest.py
@@ -13,7 +13,7 @@ async def default_project(db: DbSessionFactory) -> None:
project_row_id = await session.scalar(
insert(models.Project).values(name=DEFAULT_PROJECT_NAME).returning(models.Project.id)
)
trace_row_id = await session.scalar(
trace_rowid = await session.scalar(
insert(models.Trace)
.values(
trace_id="0123",
@@ -26,7 +26,7 @@ async def default_project(db: DbSessionFactory) -> None:
await session.execute(
insert(models.Span)
.values(
trace_rowid=trace_row_id,
trace_rowid=trace_rowid,
span_id="2345",
parent_id=None,
name="root span",
@@ -49,7 +49,7 @@ async def default_project(db: DbSessionFactory) -> None:
await session.execute(
insert(models.Span)
.values(
trace_rowid=trace_row_id,
trace_rowid=trace_rowid,
span_id="4567",
parent_id="2345",
name="retriever span",
@@ -85,7 +85,7 @@ async def abc_project(db: DbSessionFactory) -> None:
project_row_id = await session.scalar(
insert(models.Project).values(name="abc").returning(models.Project.id)
)
trace_row_id = await session.scalar(
trace_rowid = await session.scalar(
insert(models.Trace)
.values(
trace_id="012",
@@ -98,7 +98,7 @@ async def abc_project(db: DbSessionFactory) -> None:
await session.execute(
insert(models.Span)
.values(
trace_rowid=trace_row_id,
trace_rowid=trace_rowid,
span_id="234",
parent_id="123",
name="root span",
@@ -118,10 +118,10 @@ async def abc_project(db: DbSessionFactory) -> None:
)
.returning(models.Span.id)
)
await session.execute(
span_rowid = await session.scalar(
insert(models.Span)
.values(
trace_rowid=trace_row_id,
trace_rowid=trace_rowid,
span_id="345",
parent_id="234",
name="embedding span",
@@ -155,9 +155,18 @@ async def abc_project(db: DbSessionFactory) -> None:
.returning(models.Span.id)
)
await session.execute(
insert(models.SpanAnnotation).values(
span_rowid=span_rowid,
annotator_kind="LLM",
name="0",
score=0,
metadata_={},
)
)
span_rowid = await session.scalar(
insert(models.Span)
.values(
trace_rowid=trace_row_id,
trace_rowid=trace_rowid,
span_id="456",
parent_id="234",
name="retriever span",
@@ -187,9 +196,27 @@ async def abc_project(db: DbSessionFactory) -> None:
.returning(models.Span.id)
)
await session.execute(
insert(models.SpanAnnotation).values(
span_rowid=span_rowid,
annotator_kind="LLM",
name="0",
score=1,
metadata_={},
)
)
await session.execute(
insert(models.SpanAnnotation).values(
span_rowid=span_rowid,
annotator_kind="LLM",
name="1",
label="1",
metadata_={},
)
)
span_rowid = await session.scalar(
insert(models.Span)
.values(
trace_rowid=trace_row_id,
trace_rowid=trace_rowid,
span_id="567",
parent_id="234",
name="llm span",
@@ -214,3 +241,12 @@ async def abc_project(db: DbSessionFactory) -> None:
)
.returning(models.Span.id)
)
await session.execute(
insert(models.SpanAnnotation).values(
span_rowid=span_rowid,
annotator_kind="LLM",
name="1",
label="0",
metadata_={},
)
)
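
For reference, the annotations inserted above give each span in the abc project a distinct eval profile; the root span "234" deliberately receives none, so the is None conditions in the new test have spans to match. A summary derived only from the inserts in this file:

# span_id -> evals attached by this fixture (eval names "0" and "1")
ABC_PROJECT_EVALS = {
    "234": {},                                        # root span: no annotations
    "345": {"0": {"score": 0}},                       # embedding span
    "456": {"0": {"score": 1}, "1": {"label": "1"}},  # retriever span
    "567": {"1": {"label": "0"}},                     # llm span
}
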
37 changes: 36 additions & 1 deletion tests/trace/dsl/test_query.py
@@ -1,5 +1,5 @@
from datetime import datetime
from typing import Any
from typing import Any, List

import pandas as pd
import pytest
@@ -891,6 +891,41 @@ async def test_filter_on_trace_id_multiple(
)


@pytest.mark.parametrize(
"condition,expected",
[
["evals['0'].score is not None", ["345", "456"]],
["evals['0'].score is None", ["234", "567"]],
["evals['0'].score == 0", ["345"]],
["evals['0'].score != 0", ["456"]],
["evals['0'].score != 0 or evals['0'].score is None", ["234", "456", "567"]],
["evals['1'].label is not None", ["456", "567"]],
["evals['1'].label is None", ["234", "345"]],
["evals['1'].label == '1'", ["456"]],
["evals['1'].label != '1'", ["567"]],
["evals['1'].label != '1' or evals['1'].label is None", ["234", "345", "567"]],
["evals['0'].score is not None or evals['1'].label is not None", ["345", "456", "567"]],
["evals['0'].score is None or evals['1'].label is None", ["234", "345", "567"]],
["evals['0'].score == 0 or evals['1'].label == '1'", ["345", "456"]],
["evals['0'].score != 0 or evals['1'].label != '1'", ["456", "567"]],
["evals['0'].score is not None or evals['1'].label is None", ["234", "345", "456"]],
["evals['0'].score is None or evals['1'].label is not None", ["234", "456", "567"]],
["evals['0'].score == 0 or evals['1'].label != '1'", ["345", "567"]],
["evals['0'].score != 0 or evals['1'].label == '1'", ["456"]],
],
)
async def test_filter_on_span_annotation(
db: DbSessionFactory,
abc_project: Any,
condition: str,
expected: List[str],
) -> None:
sq = SpanQuery().select("span_id").where(condition)
async with db() as session:
actual = await session.run_sync(sq, project_name="abc")
assert sorted(actual.index) == expected


async def test_explode_embeddings_no_select(
db: DbSessionFactory,
default_project: Any,
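
The new test exercises the outer join through SpanQuery filters against the abc fixture project. A hypothetical usage sketch outside the test, where the SpanQuery import path and the db session-factory wiring are assumptions rather than part of this diff:

from phoenix.trace.dsl import SpanQuery  # import path assumed

sq = SpanQuery().select("span_id").where("evals['0'].score is None")


async def spans_missing_eval_0(db):  # `db` stands in for the DbSessionFactory fixture
    async with db() as session:
        # With the outer join, spans lacking an eval named "0" are returned
        # ("234" and "567" in the abc fixture); the previous inner join could
        # not return them for this condition.
        return await session.run_sync(sq, project_name="abc")
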
