fix: use outerjoin for evals filter #4066

Merged: 4 commits, Jul 30, 2024
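Context for the diff below: the span filter previously joined the aliased SpanAnnotation table with an INNER JOIN, so any span without a matching annotation was dropped before the filter predicate ever ran, and conditions such as evals['0'].score is None could never match. Switching to an OUTER JOIN keeps those spans with NULL annotation columns. A minimal, self-contained SQLAlchemy sketch of the difference follows; the table and column names are illustrative stand-ins, not the project's models.

from typing import Optional

import sqlalchemy as sa
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


# Illustrative stand-ins for the real models; names are assumptions.
class Base(DeclarativeBase):
    pass


class Span(Base):
    __tablename__ = "spans"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str]


class SpanAnnotation(Base):
    __tablename__ = "span_annotations"
    id: Mapped[int] = mapped_column(primary_key=True)
    span_rowid: Mapped[int] = mapped_column(sa.ForeignKey("spans.id"))
    name: Mapped[str]
    score: Mapped[Optional[float]]


engine = sa.create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all(
        [
            Span(id=1, name="root span"),  # has no eval named "0"
            Span(id=2, name="retriever span"),  # has an eval named "0"
            SpanAnnotation(span_rowid=2, name="0", score=1.0),
        ]
    )
    session.commit()

    on = sa.and_(SpanAnnotation.span_rowid == Span.id, SpanAnnotation.name == "0")

    # INNER JOIN: the un-annotated span vanishes before any filter runs.
    inner = session.scalars(sa.select(Span.name).join(SpanAnnotation, on)).all()

    # LEFT OUTER JOIN: the un-annotated span survives with NULL annotation
    # columns, so a "score is None" predicate has something to match.
    outer = session.scalars(
        sa.select(Span.name).outerjoin(SpanAnnotation, on).where(SpanAnnotation.score.is_(None))
    ).all()

    print(inner)  # ['retriever span']
    print(outer)  # ['root span']

Under the inner join the root span never reaches the WHERE clause; under the outer join it appears with a NULL score, which is exactly what the "is None" conditions in the new tests rely on.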
src/phoenix/trace/dsl/filter.py: 2 additions & 6 deletions
@@ -56,11 +56,7 @@ def __post_init__(self) -> None:
         table = aliased(models.SpanAnnotation, name=table_alias)
         object.__setattr__(self, "_label_attribute_alias", label_attribute_alias)
         object.__setattr__(self, "_score_attribute_alias", score_attribute_alias)
-        object.__setattr__(
-            self,
-            "table",
-            table,
-        )
+        object.__setattr__(self, "table", table)

     @property
     def attributes(self) -> typing.Iterator[typing.Tuple[str, Mapped[typing.Any]]]:
@@ -236,7 +232,7 @@ def _join_aliased_relations(self, stmt: Select[typing.Any]) -> Select[typing.Any]:
         for eval_alias in self._aliased_annotation_relations:
             eval_name = eval_alias.name
             AliasedSpanAnnotation = eval_alias.table
-            stmt = stmt.join(
+            stmt = stmt.outerjoin(
                 AliasedSpanAnnotation,
                 onclause=(
                     sqlalchemy.and_(
tests/conftest.py: 2 additions & 0 deletions
@@ -94,6 +94,7 @@ async def postgresql_engine(postgresql_url: URL) -> AsyncIterator[AsyncEngine]:
     engine = aio_postgresql_engine(postgresql_url, migrate=False)
     async with engine.begin() as conn:
         await conn.run_sync(models.Base.metadata.create_all)
+    engine.echo = True
     yield engine
     await engine.dispose()

@@ -108,6 +109,7 @@ async def sqlite_engine() -> AsyncIterator[AsyncEngine]:
     engine = aio_sqlite_engine(make_url("sqlite+aiosqlite://"), migrate=False, shared_cache=False)
     async with engine.begin() as conn:
         await conn.run_sync(models.Base.metadata.create_all)
+    engine.echo = True
     yield engine
     await engine.dispose()

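Side note on the engine.echo = True lines added above: setting echo on a SQLAlchemy engine turns on statement logging, which makes the JOIN a filter actually emits visible in test output. A standalone synchronous sketch, not the project's async fixtures:

from sqlalchemy import create_engine, text

# echo=True routes every emitted SQL statement to the "sqlalchemy.engine"
# logger (stderr by default), so generated JOINs show up in test logs.
engine = create_engine("sqlite://", echo=True)
with engine.connect() as conn:
    conn.execute(text("SELECT 1"))  # the statement is echoed before execution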
tests/trace/dsl/conftest.py: 45 additions & 9 deletions
@@ -13,7 +13,7 @@ async def default_project(db: DbSessionFactory) -> None:
         project_row_id = await session.scalar(
             insert(models.Project).values(name=DEFAULT_PROJECT_NAME).returning(models.Project.id)
         )
-        trace_row_id = await session.scalar(
+        trace_rowid = await session.scalar(
             insert(models.Trace)
             .values(
                 trace_id="0123",
@@ -26,7 +26,7 @@
         await session.execute(
             insert(models.Span)
             .values(
-                trace_rowid=trace_row_id,
+                trace_rowid=trace_rowid,
                 span_id="2345",
                 parent_id=None,
                 name="root span",
@@ -49,7 +49,7 @@
         await session.execute(
             insert(models.Span)
             .values(
-                trace_rowid=trace_row_id,
+                trace_rowid=trace_rowid,
                 span_id="4567",
                 parent_id="2345",
                 name="retriever span",
@@ -85,7 +85,7 @@ async def abc_project(db: DbSessionFactory) -> None:
         project_row_id = await session.scalar(
             insert(models.Project).values(name="abc").returning(models.Project.id)
         )
-        trace_row_id = await session.scalar(
+        trace_rowid = await session.scalar(
             insert(models.Trace)
             .values(
                 trace_id="012",
@@ -98,7 +98,7 @@
         await session.execute(
             insert(models.Span)
             .values(
-                trace_rowid=trace_row_id,
+                trace_rowid=trace_rowid,
                 span_id="234",
                 parent_id="123",
                 name="root span",
@@ -118,10 +118,10 @@
             )
             .returning(models.Span.id)
         )
-        await session.execute(
+        span_rowid = await session.scalar(
             insert(models.Span)
             .values(
-                trace_rowid=trace_row_id,
+                trace_rowid=trace_rowid,
                 span_id="345",
                 parent_id="234",
                 name="embedding span",
@@ -155,9 +155,18 @@
             .returning(models.Span.id)
         )
         await session.execute(
+            insert(models.SpanAnnotation).values(
+                span_rowid=span_rowid,
+                annotator_kind="LLM",
+                name="0",
+                score=0,
+                metadata_={},
+            )
+        )
+        span_rowid = await session.scalar(
             insert(models.Span)
             .values(
-                trace_rowid=trace_row_id,
+                trace_rowid=trace_rowid,
                 span_id="456",
                 parent_id="234",
                 name="retriever span",
@@ -187,9 +196,27 @@
             .returning(models.Span.id)
         )
         await session.execute(
+            insert(models.SpanAnnotation).values(
+                span_rowid=span_rowid,
+                annotator_kind="LLM",
+                name="0",
+                score=1,
+                metadata_={},
+            )
+        )
+        await session.execute(
+            insert(models.SpanAnnotation).values(
+                span_rowid=span_rowid,
+                annotator_kind="LLM",
+                name="1",
+                label="1",
+                metadata_={},
+            )
+        )
+        span_rowid = await session.scalar(
             insert(models.Span)
             .values(
-                trace_rowid=trace_row_id,
+                trace_rowid=trace_rowid,
                 span_id="567",
                 parent_id="234",
                 name="llm span",
@@ -214,3 +241,12 @@
             )
             .returning(models.Span.id)
         )
+        await session.execute(
+            insert(models.SpanAnnotation).values(
+                span_rowid=span_rowid,
+                annotator_kind="LLM",
+                name="1",
+                label="0",
+                metadata_={},
+            )
+        )
tests/trace/dsl/test_query.py: 68 additions & 1 deletion
@@ -1,5 +1,5 @@
 from datetime import datetime
-from typing import Any
+from typing import Any, List

 import pandas as pd
 import pytest
@@ -347,6 +347,26 @@ async def test_filter_for_none(
         check_index_type=False,
     )

+    sq = (
+        SpanQuery()
+        .select("name")
+        .where(
+            "parent_id is not None",
+        )
+    )
+    expected = pd.DataFrame(
+        {
+            "context.span_id": ["234", "345", "456", "567"],
+            "name": ["root span", "embedding span", "retriever span", "llm span"],
+        }
+    ).set_index("context.span_id")
+    async with db() as session:
+        actual = await session.run_sync(sq, project_name="abc")
+    assert_frame_equal(
+        actual.sort_index().sort_index(axis=1),
+        expected.sort_index().sort_index(axis=1),
+    )
+

 async def test_filter_for_not_none(
     db: DbSessionFactory,
@@ -373,6 +393,26 @@ async def test_filter_for_not_none(
         expected.sort_index().sort_index(axis=1),
     )

+    sq = (
+        SpanQuery()
+        .select("name")
+        .where(
+            "output.value is None",
+        )
+    )
+    expected = pd.DataFrame(
+        {
+            "context.span_id": ["345", "456", "567"],
+            "name": ["embedding span", "retriever span", "llm span"],
+        }
+    ).set_index("context.span_id")
+    async with db() as session:
+        actual = await session.run_sync(sq, project_name="abc")
+    assert_frame_equal(
+        actual.sort_index().sort_index(axis=1),
+        expected.sort_index().sort_index(axis=1),
+    )
+

 async def test_filter_for_substring_case_sensitive_not_glob_not_like(
     db: DbSessionFactory,
@@ -891,6 +931,33 @@ async def test_filter_on_trace_id_multiple(
     )


+@pytest.mark.parametrize(
+    "condition,expected",
+    [
+        ["evals['0'].score is not None", ["345", "456"]],
+        ["evals['0'].score is None", ["234", "567"]],
+        ["evals['0'].score == 0", ["345"]],
+        ["evals['0'].score != 0", ["456"]],
+        ["evals['0'].score != 0 or evals['0'].score is None", ["234", "456", "567"]],
+        ["evals['1'].label is not None", ["456", "567"]],
+        ["evals['1'].label is None", ["234", "345"]],
+        ["evals['1'].label == '1'", ["456"]],
+        ["evals['1'].label != '1'", ["567"]],
+        ["evals['1'].label != '1' or evals['1'].label is None", ["234", "345", "567"]],
+    ],
+)
+async def test_filter_on_span_annotation(
+    db: DbSessionFactory,
+    abc_project: Any,
+    condition: str,
+    expected: List[str],
+) -> None:
+    sq = SpanQuery().select("span_id").where(condition)
+    async with db() as session:
+        actual = await session.run_sync(sq, project_name="abc")
+    assert sorted(actual.index) == expected
+

 async def test_explode_embeddings_no_select(
     db: DbSessionFactory,
     default_project: Any,
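For reference, the conditions exercised by test_filter_on_span_annotation are the same kind of expressions a user passes through the span-query DSL. A hedged usage sketch, assuming the documented phoenix.Client().query_spans API:

import phoenix as px
from phoenix.trace.dsl import SpanQuery

# Spans that have no eval named "0", which is only retrievable now that the
# filter uses an outer join against the annotations table.
query = SpanQuery().select("name").where("evals['0'].score is None")
df = px.Client().query_spans(query)  # pandas DataFrame indexed by context.span_id
print(df)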