Skip to content

Commit

Permalink
feat: Extend evals DSL to accept 'annotations' symbol (#3939)
Browse files Browse the repository at this point in the history
* Sketch out basic Annotations lexer and parser

* Convert to AST Parser

* Extend existing evals DSL

* Add cast for type checker

* Address PR feedback
  • Loading branch information
anticorrelator authored Jul 23, 2024
1 parent ef7cf61 commit 659b674
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 36 deletions.
61 changes: 38 additions & 23 deletions src/phoenix/trace/dsl/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from dataclasses import dataclass, field
from difflib import SequenceMatcher
from itertools import chain
from random import randint
from types import MappingProxyType
from uuid import uuid4

import sqlalchemy
from sqlalchemy.orm import Mapped, aliased
Expand All @@ -22,11 +22,14 @@
)


EvalAttribute: TypeAlias = typing.Literal["label", "score"]
EvalExpression: TypeAlias = str
EvalName: TypeAlias = str
AnnotationType: TypeAlias = typing.Literal["annotations", "evals"]
AnnotationAttribute: TypeAlias = typing.Literal["label", "score"]
AnnotationExpression: TypeAlias = str
AnnotationName: TypeAlias = str

EVAL_EXPRESSION_PATTERN = re.compile(r"""\b(evals\[(".*?"|'.*?')\][.](label|score))\b""")
EVAL_EXPRESSION_PATTERN = re.compile(
r"""\b((annotations|evals)\[(".*?"|'.*?')\][.](label|score))\b"""
)


@dataclass(frozen=True)
Expand All @@ -46,9 +49,10 @@ class AliasedAnnotationRelation:

def __post_init__(self) -> None:
table_alias = f"span_annotation_{self.index}"
alias_id = f"{randint(0, 10**6):06d}" # prevent conflicts with user-defined attributes
alias_id = uuid4().hex
label_attribute_alias = f"{table_alias}_label_{alias_id}"
score_attribute_alias = f"{table_alias}_score_{alias_id}"

table = aliased(models.SpanAnnotation, name=table_alias)
object.__setattr__(self, "_label_attribute_alias", label_attribute_alias)
object.__setattr__(self, "_score_attribute_alias", score_attribute_alias)
Expand All @@ -67,7 +71,7 @@ def attributes(self) -> typing.Iterator[typing.Tuple[str, Mapped[typing.Any]]]:
yield self._label_attribute_alias, self.table.label
yield self._score_attribute_alias, self.table.score

def attribute_alias(self, attribute: EvalAttribute) -> str:
def attribute_alias(self, attribute: AnnotationAttribute) -> str:
"""
Returns an alias for the given attribute (i.e., column).
"""
Expand Down Expand Up @@ -579,7 +583,7 @@ def _validate_expression(
_is_subscript(node, "metadata") or _is_subscript(node, "attributes")
) and _get_attribute_keys_list(node) is not None:
continue
elif _is_eval(node) and _get_subscript_key(node) is not None:
elif _is_annotation(node) and _get_subscript_key(node) is not None:
# e.g. `evals["name"]`
if not (eval_name := _get_subscript_key(node)) or (
valid_eval_names is not None and eval_name not in valid_eval_names
Expand All @@ -601,7 +605,7 @@ def _validate_expression(
else ""
)
continue
elif isinstance(node, ast.Attribute) and _is_eval(node.value):
elif isinstance(node, ast.Attribute) and _is_annotation(node.value):
# e.g. `evals["name"].score`
if (attr := node.attr) not in valid_eval_attributes:
source_segment = typing.cast(str, ast.get_source_segment(source, node))
Expand Down Expand Up @@ -669,12 +673,12 @@ def _as_attribute(
)


def _is_eval(node: typing.Any) -> TypeGuard[ast.Subscript]:
def _is_annotation(node: typing.Any) -> TypeGuard[ast.Subscript]:
# e.g. `evals["name"]`
return (
isinstance(node, ast.Subscript)
and isinstance(value := node.value, ast.Name)
and value.id == "evals"
and value.id in ["evals", "annotations"]
)


Expand Down Expand Up @@ -817,34 +821,45 @@ def _apply_eval_aliasing(
span_annotation_0_label_123 == 'correct' or span_annotation_0_score_456 < 0.5
```
"""
eval_aliases: typing.Dict[EvalName, AliasedAnnotationRelation] = {}
for eval_expression, eval_name, eval_attribute in _parse_eval_expressions_and_names(source):
if (eval_alias := eval_aliases.get(eval_name)) is None:
eval_alias = AliasedAnnotationRelation(index=len(eval_aliases), name=eval_name)
eval_aliases[eval_name] = eval_alias
alias_name = eval_alias.attribute_alias(eval_attribute)
source = source.replace(eval_expression, alias_name)
eval_aliases: typing.Dict[AnnotationName, AliasedAnnotationRelation] = {}
for (
annotation_expression,
annotation_type,
annotation_name,
annotation_attribute,
) in _parse_annotation_expressions_and_names(source):
if (eval_alias := eval_aliases.get(annotation_name)) is None:
eval_alias = AliasedAnnotationRelation(index=len(eval_aliases), name=annotation_name)
eval_aliases[annotation_name] = eval_alias
alias_name = eval_alias.attribute_alias(annotation_attribute)
source = source.replace(annotation_expression, alias_name)
return source, tuple(eval_aliases.values())


def _parse_eval_expressions_and_names(
def _parse_annotation_expressions_and_names(
source: str,
) -> typing.Iterator[typing.Tuple[EvalExpression, EvalName, EvalAttribute]]:
) -> typing.Iterator[
typing.Tuple[AnnotationExpression, AnnotationType, AnnotationName, AnnotationAttribute]
]:
"""
Parses filter conditions for evaluation expressions of the form:
```
evals["<eval-name>"].<attribute>
    annotations["<eval-name>"].<attribute>
```
"""
for match in EVAL_EXPRESSION_PATTERN.finditer(source):
(
eval_expression,
annotation_expression,
annotation_type,
quoted_eval_name,
evaluation_attribute_name,
) = match.groups()
annotation_type = typing.cast(AnnotationType, annotation_type)
yield (
eval_expression,
annotation_expression,
annotation_type,
quoted_eval_name[1:-1],
typing.cast(EvalAttribute, evaluation_attribute_name),
typing.cast(AnnotationAttribute, evaluation_attribute_name),
)
36 changes: 23 additions & 13 deletions tests/trace/dsl/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import sys
from typing import List, Optional
from unittest.mock import patch
from uuid import UUID

import phoenix.trace.dsl.filter
import pytest
Expand Down Expand Up @@ -152,8 +153,8 @@ async def test_filter_translated(
) -> None:
with patch.object(
phoenix.trace.dsl.filter,
"randint",
return_value=0,
"uuid4",
return_value=UUID(hex="00000000000000000000000000000000"),
):
f = SpanFilter(expression)
assert unparse(f.translated).strip() == expected
Expand All @@ -166,37 +167,37 @@ async def test_filter_translated(
[
pytest.param(
"""evals["Q&A Correctness"].label is not None""",
"span_annotation_0_label_000000 is not None",
"span_annotation_0_label_00000000000000000000000000000000 is not None",
id="double-quoted-eval-name",
),
pytest.param(
"""evals['Q&A Correctness'].label is not None""",
"span_annotation_0_label_000000 is not None",
"span_annotation_0_label_00000000000000000000000000000000 is not None",
id="single-quoted-eval-name",
),
pytest.param(
"""evals[""].label is not None""",
"span_annotation_0_label_000000 is not None",
"span_annotation_0_label_00000000000000000000000000000000 is not None",
id="empty-eval-name",
),
pytest.param(
"""evals['Hallucination'].label == 'correct' or evals['Hallucination'].score < 0.5""", # noqa E501
"span_annotation_0_label_000000 == 'correct' or span_annotation_0_score_000000 < 0.5", # noqa E501
"span_annotation_0_label_00000000000000000000000000000000 == 'correct' or span_annotation_0_score_00000000000000000000000000000000 < 0.5", # noqa E501
id="repeated-single-quoted-eval-name",
),
pytest.param(
"""evals["Hallucination"].label == 'correct' or evals["Hallucination"].score < 0.5""", # noqa E501
"span_annotation_0_label_000000 == 'correct' or span_annotation_0_score_000000 < 0.5", # noqa E501
"span_annotation_0_label_00000000000000000000000000000000 == 'correct' or span_annotation_0_score_00000000000000000000000000000000 < 0.5", # noqa E501
id="repeated-double-quoted-eval-name",
),
pytest.param(
"""evals['Hallucination'].label == 'correct' or evals["Hallucination"].score < 0.5""", # noqa E501
"span_annotation_0_label_000000 == 'correct' or span_annotation_0_score_000000 < 0.5", # noqa E501
"span_annotation_0_label_00000000000000000000000000000000 == 'correct' or span_annotation_0_score_00000000000000000000000000000000 < 0.5", # noqa E501
id="repeated-mixed-quoted-eval-name",
),
pytest.param(
"""evals['Q&A Correctness'].label == 'correct' and evals["Hallucination"].score < 0.5""", # noqa E501
"span_annotation_0_label_000000 == 'correct' and span_annotation_1_score_000000 < 0.5", # noqa E501
"span_annotation_0_label_00000000000000000000000000000000 == 'correct' and span_annotation_1_score_00000000000000000000000000000000 < 0.5", # noqa E501
id="distinct-mixed-quoted-eval-names",
),
pytest.param(
Expand All @@ -206,7 +207,7 @@ async def test_filter_translated(
),
pytest.param(
"""evals["Hallucination"].label == 'correct' orevals["Hallucination"].score < 0.5""", # noqa E501
"""span_annotation_0_label_000000 == 'correct' orevals["Hallucination"].score < 0.5""", # noqa E501
"""span_annotation_0_label_00000000000000000000000000000000 == 'correct' orevals["Hallucination"].score < 0.5""", # noqa E501
id="no-word-boundary-on-the-left",
),
pytest.param(
Expand All @@ -216,17 +217,26 @@ async def test_filter_translated(
),
pytest.param(
"""0.5 <evals["Hallucination"].score""", # noqa E501
"""0.5 <span_annotation_0_score_000000""", # noqa E501
"""0.5 <span_annotation_0_score_00000000000000000000000000000000""", # noqa E501
id="left-word-boundary-without-space",
),
pytest.param(
"""evals["Hallucination"].score< 0.5""", # noqa E501
"""span_annotation_0_score_000000< 0.5""", # noqa E501
"""span_annotation_0_score_00000000000000000000000000000000< 0.5""", # noqa E501
id="right-word-boundary-without-space",
),
pytest.param(
"""annotations["Q&A Correctness"].label is not None""",
"span_annotation_0_label_00000000000000000000000000000000 is not None",
id="double-quoted-annotation-name",
),
],
)
def test_apply_eval_aliasing(filter_condition: str, expected: str) -> None:
with patch.object(phoenix.trace.dsl.filter, "randint", return_value=0):
with patch.object(
phoenix.trace.dsl.filter,
"uuid4",
return_value=UUID(hex="00000000000000000000000000000000"),
):
aliased, _ = _apply_eval_aliasing(filter_condition)
assert aliased == expected

0 comments on commit 659b674

Please sign in to comment.