Skip to content

Commit e2d3b2d

Browse files
ankursharmascopybara-github
authored andcommitted
feat: Added support for InOrder and AnyOrder match in ToolTrajectoryAvgScore Metric
Co-authored-by: Ankur Sharma <ankusharma@google.com> PiperOrigin-RevId: 831413968
1 parent b2c8ba5 commit e2d3b2d

File tree

3 files changed

+467
-15
lines changed

3 files changed

+467
-15
lines changed

src/google/adk/evaluation/eval_metrics.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,76 @@ class HallucinationsCriterion(BaseCriterion):
150150
)
151151

152152

153+
class ToolTrajectoryCriterion(BaseCriterion):
154+
"""Criterion to use when evaluating agent's tool trajectories with a reference one."""
155+
156+
class MatchType(Enum):
157+
"""The type of Match between actual and expected tool call trajectories."""
158+
159+
EXACT = 0
160+
"""Requires a perfect match between the actual and expected tool calls."""
161+
162+
IN_ORDER = 1
163+
"""Requires the actual tool calls to be in the same order as expected tools,
164+
with allowance for extra tool calls to have happened.
165+
166+
This criteria is useful in assuring if certain key actions/tool calls
167+
occur and in certain order, leaving some scope for other tools calls to
168+
happen as well.
169+
170+
Example 1: Set of actual vs expected tool calls that satisfies the criteria:
171+
172+
Expected tools calls: [T1, T2, T3]
173+
Actual tool calls: [T1, T1.1, T2, T2.1, T2.2, T3, T3.1]
174+
175+
This satisfies, as the tools T1, T2 and T3 happened in the "Actual" and in
176+
the same order.
177+
178+
Example 2: Set of actual vs expected tool calls that don't satisfy the
179+
criteria:
180+
181+
Expected tools calls: [T1, T2, T3, T4]
182+
Actual tool calls: [T1, T1.1, T2, T2.1, T2.2, T3, T3.1]
183+
184+
While the tool calls T1, T2 and T3 happened in the "Actual" and in
185+
the same order as "Expected", but the the tool calls T4 is missing.
186+
"""
187+
188+
ANY_ORDER = 2
189+
"""Requires the actual tool calls to be in the any order as expected tools,
190+
with allowance for extra tool calls to have happened.
191+
192+
This criteria is helpful for cases where multiple tool calls about the same
193+
concept occur, like your agent issues 5 search queries. You don't really
194+
care the order in which the search queries are issues, till they occur.
195+
196+
Example 1: Set of actual vs expected tool calls that satisfies the criteria:
197+
198+
Expected tools calls: [T1, T2, T3]
199+
Actual tool calls: [T2, T2.1, T1, T1.1, T1.2, T3, T3.1]
200+
201+
This satisfies, as the tools T1, T2 and T3 happened in the "Actual" and
202+
are also present in expected. Note that the order is different.
203+
204+
Example 2: Set of actual vs expected tool calls that don't satisfy the
205+
criteria:
206+
207+
Expected tools calls: [T1, T2, T3, T4]
208+
Actual tool calls: [T1, T1.1, T2, T2.1, T2.2, T3, T3.1]
209+
210+
While the tool calls T1, T2 and T3 happened in the "Actual" and in
211+
the same order as "Expected", but the the tool calls T4 is missing.
212+
"""
213+
214+
match_type: MatchType = Field(
215+
default=MatchType.EXACT,
216+
description=(
217+
"The type of Match between actual and expected tool call"
218+
" trajectories."
219+
),
220+
)
221+
222+
153223
class EvalMetric(EvalBaseModel):
154224
"""A metric used to evaluate a particular aspect of an eval case."""
155225

src/google/adk/evaluation/trajectory_evaluator.py

Lines changed: 171 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,12 @@
1414

1515
from __future__ import annotations
1616

17+
import logging
18+
from typing import ClassVar
1719
from typing import Optional
1820

1921
from google.genai import types as genai_types
22+
from pydantic import ValidationError
2023
from typing_extensions import override
2124

2225
from .eval_case import get_all_tool_calls
@@ -26,14 +29,43 @@
2629
from .eval_metrics import MetricInfo
2730
from .eval_metrics import MetricValueInfo
2831
from .eval_metrics import PrebuiltMetrics
32+
from .eval_metrics import ToolTrajectoryCriterion
2933
from .evaluator import EvalStatus
3034
from .evaluator import EvaluationResult
3135
from .evaluator import Evaluator
3236
from .evaluator import PerInvocationResult
3337

38+
logger = logging.getLogger("google_adk." + __name__)
39+
3440

3541
class TrajectoryEvaluator(Evaluator):
36-
"""Evaluates tool use trajectories for accuracy."""
42+
"""Evaluates tool use trajectories for accuracy.
43+
44+
This evaluator compares the sequence of tools called by the agent against a
45+
list of expected calls and computes an average score based on one of the match
46+
types: `EXACT`, `IN_ORDER`, or `ANY_ORDER`.
47+
48+
For each invocation being evaluated, this evaluator compares the list of
49+
tool calls produced by the agent with the list of expected tool calls using
50+
one of three match types. If the tool calls match based on the selected match
51+
type, a score of 1.0 is awarded for that invocation, otherwise the score is
52+
0.0. The final value is the average of these scores across all
53+
invocations in the eval case.
54+
55+
The comparison can be done using one of following match types:
56+
- `EXACT`: Requires a perfect match between the actual and expected tool
57+
calls, with no extra or missing tool calls.
58+
- `IN_ORDER`: Requires all tool calls from the expected list to be present
59+
in the actual list, in the same order, but allows for other tool calls
60+
to appear in between.
61+
- `ANY_ORDER`: Requires all tool calls from the expected list to be
62+
present in the actual list, in any order, and allows for other tool
63+
calls to appear in between.
64+
"""
65+
66+
criterion_type: ClassVar[type[ToolTrajectoryCriterion]] = (
67+
ToolTrajectoryCriterion
68+
)
3769

3870
def __init__(
3971
self,
@@ -46,10 +78,25 @@ def __init__(
4678
" specified."
4779
)
4880

49-
if eval_metric:
50-
threshold = eval_metric.threshold
51-
52-
self._threshold = threshold
81+
if eval_metric and eval_metric.criterion:
82+
try:
83+
criterion = TrajectoryEvaluator.criterion_type.model_validate(
84+
eval_metric.criterion.model_dump()
85+
)
86+
self._threshold = criterion.threshold
87+
self._match_type = criterion.match_type
88+
except ValidationError as e:
89+
expected_criterion_type_error = ValueError(
90+
f"`{eval_metric.metric_name}` metric expects a criterion of type"
91+
f" `{TrajectoryEvaluator.criterion_type}`."
92+
)
93+
raise expected_criterion_type_error from e
94+
elif eval_metric:
95+
self._threshold = eval_metric.threshold
96+
self._match_type = ToolTrajectoryCriterion.MatchType.EXACT
97+
else:
98+
self._threshold = threshold
99+
self._match_type = ToolTrajectoryCriterion.MatchType.EXACT
53100

54101
@staticmethod
55102
def get_metric_info() -> MetricInfo:
@@ -82,14 +129,7 @@ def evaluate_invocations(
82129
per_invocation_results = []
83130

84131
for actual, expected in zip(actual_invocations, expected_invocations):
85-
actual_tool_uses = get_all_tool_calls(actual.intermediate_data)
86-
expected_tool_uses = get_all_tool_calls(expected.intermediate_data)
87-
88-
tool_use_accuracy = (
89-
1.0
90-
if self._are_tool_calls_equal(actual_tool_uses, expected_tool_uses)
91-
else 0.0
92-
)
132+
tool_use_accuracy = self._calculate_tool_use_accuracy(actual, expected)
93133
per_invocation_results.append(
94134
PerInvocationResult(
95135
actual_invocation=actual,
@@ -111,11 +151,128 @@ def evaluate_invocations(
111151

112152
return EvaluationResult()
113153

114-
def _are_tool_calls_equal(
154+
def _calculate_tool_use_accuracy(
155+
self,
156+
actual_invocation: Invocation,
157+
expected_invocation: Invocation,
158+
) -> float:
159+
"""Calculates tool use accuracy for a single invocation."""
160+
actual_tool_uses = get_all_tool_calls(actual_invocation.intermediate_data)
161+
expected_tool_uses = get_all_tool_calls(
162+
expected_invocation.intermediate_data
163+
)
164+
165+
tool_use_match_status = False
166+
if self._match_type == ToolTrajectoryCriterion.MatchType.EXACT:
167+
tool_use_match_status = self._are_tool_calls_exact_match(
168+
actual_tool_uses, expected_tool_uses
169+
)
170+
elif self._match_type == ToolTrajectoryCriterion.MatchType.IN_ORDER:
171+
tool_use_match_status = self._are_tool_calls_in_order_match(
172+
actual_tool_uses, expected_tool_uses
173+
)
174+
elif self._match_type == ToolTrajectoryCriterion.MatchType.ANY_ORDER:
175+
tool_use_match_status = self._are_tool_calls_any_order_match(
176+
actual_tool_uses, expected_tool_uses
177+
)
178+
else:
179+
raise ValueError(f"Unsupported match type {self._match_type}")
180+
181+
return 1.0 if tool_use_match_status else 0.0
182+
183+
def _are_tool_calls_in_order_match(
184+
self,
185+
actual_tool_calls: list[genai_types.FunctionCall],
186+
expected_tool_calls: list[genai_types.FunctionCall],
187+
) -> bool:
188+
"""Checks if expected tool calls appear in actual tool calls in order.
189+
190+
This method implements IN_ORDER match type. It allows for additional
191+
tool calls in actual_tool_calls, as long as all expected tool calls are
192+
present in the same order.
193+
194+
Args:
195+
actual_tool_calls: A list of tool calls that actually happened.
196+
expected_tool_calls: A list of tool calls that were expected to happen.
197+
198+
Returns:
199+
True if actual tool calls match expected tool calls in order,
200+
False otherwise.
201+
"""
202+
if not expected_tool_calls:
203+
return True
204+
if not actual_tool_calls and expected_tool_calls:
205+
return False
206+
207+
expected_it = iter(expected_tool_calls)
208+
try:
209+
current_expected = next(expected_it)
210+
for actual in actual_tool_calls:
211+
if (
212+
actual.name == current_expected.name
213+
and actual.args == current_expected.args
214+
):
215+
current_expected = next(expected_it)
216+
except StopIteration:
217+
return True
218+
219+
return False
220+
221+
def _are_tool_calls_any_order_match(
115222
self,
116223
actual_tool_calls: list[genai_types.FunctionCall],
117224
expected_tool_calls: list[genai_types.FunctionCall],
118225
) -> bool:
226+
"""Checks if expected tool calls appear in actual tool calls in any order.
227+
228+
This method implements ANY_ORDER match type. It allows for additional
229+
tool calls in actual_tool_calls, as long as all expected tool calls are
230+
present.
231+
232+
Args:
233+
actual_tool_calls: A list of tool calls that actually happened.
234+
expected_tool_calls: A list of tool calls that were expected to happen.
235+
236+
Returns:
237+
True if actual tool calls contain all expected tool calls,
238+
False otherwise.
239+
"""
240+
if not expected_tool_calls:
241+
return True
242+
if not actual_tool_calls and expected_tool_calls:
243+
return False
244+
245+
actual_tool_calls_copy = list(actual_tool_calls)
246+
for expected in expected_tool_calls:
247+
found = False
248+
for i, actual in enumerate(actual_tool_calls_copy):
249+
if actual.name == expected.name and actual.args == expected.args:
250+
actual_tool_calls_copy.pop(i)
251+
found = True
252+
break
253+
if not found:
254+
return False
255+
return True
256+
257+
def _are_tool_calls_exact_match(
258+
self,
259+
actual_tool_calls: list[genai_types.FunctionCall],
260+
expected_tool_calls: list[genai_types.FunctionCall],
261+
) -> bool:
262+
"""Checks if actual tool calls exactly match expected tool calls.
263+
264+
This method implements EXACT match type. It requires that
265+
actual_tool_calls and expected_tool_calls have the same tool calls in
266+
the same order, with no extra or missing tool calls.
267+
268+
Args:
269+
actual_tool_calls: A list of tool calls that actually happened.
270+
expected_tool_calls: A list of tool calls that were expected to happen.
271+
272+
Returns:
273+
True if actual tool calls exactly match expected tool calls,
274+
False otherwise.
275+
"""
119276
if len(actual_tool_calls) != len(expected_tool_calls):
120277
return False
121278

0 commit comments

Comments
 (0)