1414
1515from __future__ import annotations
1616
17+ import logging
18+ from typing import ClassVar
1719from typing import Optional
1820
1921from google .genai import types as genai_types
22+ from pydantic import ValidationError
2023from typing_extensions import override
2124
2225from .eval_case import get_all_tool_calls
2629from .eval_metrics import MetricInfo
2730from .eval_metrics import MetricValueInfo
2831from .eval_metrics import PrebuiltMetrics
32+ from .eval_metrics import ToolTrajectoryCriterion
2933from .evaluator import EvalStatus
3034from .evaluator import EvaluationResult
3135from .evaluator import Evaluator
3236from .evaluator import PerInvocationResult
3337
38+ logger = logging .getLogger ("google_adk." + __name__ )
39+
3440
3541class TrajectoryEvaluator (Evaluator ):
36- """Evaluates tool use trajectories for accuracy."""
42+ """Evaluates tool use trajectories for accuracy.
43+
44+ This evaluator compares the sequence of tools called by the agent against a
45+ list of expected calls and computes an average score based on one of the match
46+ types: `EXACT`, `IN_ORDER`, or `ANY_ORDER`.
47+
48+ For each invocation being evaluated, this evaluator compares the list of
49+ tool calls produced by the agent with the list of expected tool calls using
50+ one of three match types. If the tool calls match based on the selected match
51+ type, a score of 1.0 is awarded for that invocation, otherwise the score is
52+ 0.0. The final value is the average of these scores across all
53+ invocations in the eval case.
54+
55+ The comparison can be done using one of the following match types:
56+ - `EXACT`: Requires a perfect match between the actual and expected tool
57+ calls, with no extra or missing tool calls.
58+ - `IN_ORDER`: Requires all tool calls from the expected list to be present
59+ in the actual list, in the same order, but allows for other tool calls
60+ to appear in between.
61+ - `ANY_ORDER`: Requires all tool calls from the expected list to be
62+ present in the actual list, in any order, and allows for other tool
63+ calls to appear in between.
64+ """
65+
66+ criterion_type : ClassVar [type [ToolTrajectoryCriterion ]] = (
67+ ToolTrajectoryCriterion
68+ )
3769
3870 def __init__ (
3971 self ,
@@ -46,10 +78,25 @@ def __init__(
4678 " specified."
4779 )
4880
49- if eval_metric :
50- threshold = eval_metric .threshold
51-
52- self ._threshold = threshold
81+ if eval_metric and eval_metric .criterion :
82+ try :
83+ criterion = TrajectoryEvaluator .criterion_type .model_validate (
84+ eval_metric .criterion .model_dump ()
85+ )
86+ self ._threshold = criterion .threshold
87+ self ._match_type = criterion .match_type
88+ except ValidationError as e :
89+ expected_criterion_type_error = ValueError (
90+ f"`{ eval_metric .metric_name } ` metric expects a criterion of type"
91+ f" `{ TrajectoryEvaluator .criterion_type } `."
92+ )
93+ raise expected_criterion_type_error from e
94+ elif eval_metric :
95+ self ._threshold = eval_metric .threshold
96+ self ._match_type = ToolTrajectoryCriterion .MatchType .EXACT
97+ else :
98+ self ._threshold = threshold
99+ self ._match_type = ToolTrajectoryCriterion .MatchType .EXACT
53100
54101 @staticmethod
55102 def get_metric_info () -> MetricInfo :
@@ -82,14 +129,7 @@ def evaluate_invocations(
82129 per_invocation_results = []
83130
84131 for actual , expected in zip (actual_invocations , expected_invocations ):
85- actual_tool_uses = get_all_tool_calls (actual .intermediate_data )
86- expected_tool_uses = get_all_tool_calls (expected .intermediate_data )
87-
88- tool_use_accuracy = (
89- 1.0
90- if self ._are_tool_calls_equal (actual_tool_uses , expected_tool_uses )
91- else 0.0
92- )
132+ tool_use_accuracy = self ._calculate_tool_use_accuracy (actual , expected )
93133 per_invocation_results .append (
94134 PerInvocationResult (
95135 actual_invocation = actual ,
@@ -111,11 +151,128 @@ def evaluate_invocations(
111151
112152 return EvaluationResult ()
113153
114- def _are_tool_calls_equal (
def _calculate_tool_use_accuracy(
    self,
    actual_invocation: Invocation,
    expected_invocation: Invocation,
) -> float:
  """Scores the tool trajectory of one invocation as either 1.0 or 0.0.

  The tool calls are extracted from the `intermediate_data` of both
  invocations and compared with the matcher that corresponds to
  `self._match_type`.

  Args:
    actual_invocation: The invocation whose tool calls are being judged.
    expected_invocation: The reference invocation to compare against.

  Returns:
    1.0 when the tool calls match under the configured match type,
    0.0 otherwise.

  Raises:
    ValueError: If `self._match_type` is none of EXACT, IN_ORDER or
      ANY_ORDER.
  """
  observed_calls = get_all_tool_calls(actual_invocation.intermediate_data)
  reference_calls = get_all_tool_calls(expected_invocation.intermediate_data)

  # Select the comparison routine first; the call itself happens once below.
  match_types = ToolTrajectoryCriterion.MatchType
  if self._match_type == match_types.EXACT:
    matcher = self._are_tool_calls_exact_match
  elif self._match_type == match_types.IN_ORDER:
    matcher = self._are_tool_calls_in_order_match
  elif self._match_type == match_types.ANY_ORDER:
    matcher = self._are_tool_calls_any_order_match
  else:
    raise ValueError(f"Unsupported match type { self ._match_type } ")

  if matcher(observed_calls, reference_calls):
    return 1.0
  return 0.0
182+
183+ def _are_tool_calls_in_order_match (
184+ self ,
185+ actual_tool_calls : list [genai_types .FunctionCall ],
186+ expected_tool_calls : list [genai_types .FunctionCall ],
187+ ) -> bool :
188+ """Checks if expected tool calls appear in actual tool calls in order.
189+
190+ This method implements IN_ORDER match type. It allows for additional
191+ tool calls in actual_tool_calls, as long as all expected tool calls are
192+ present in the same order.
193+
194+ Args:
195+ actual_tool_calls: A list of tool calls that actually happened.
196+ expected_tool_calls: A list of tool calls that were expected to happen.
197+
198+ Returns:
199+ True if actual tool calls match expected tool calls in order,
200+ False otherwise.
201+ """
202+ if not expected_tool_calls :
203+ return True
204+ if not actual_tool_calls and expected_tool_calls :
205+ return False
206+
207+ expected_it = iter (expected_tool_calls )
208+ try :
209+ current_expected = next (expected_it )
210+ for actual in actual_tool_calls :
211+ if (
212+ actual .name == current_expected .name
213+ and actual .args == current_expected .args
214+ ):
215+ current_expected = next (expected_it )
216+ except StopIteration :
217+ return True
218+
219+ return False
220+
221+ def _are_tool_calls_any_order_match (
115222 self ,
116223 actual_tool_calls : list [genai_types .FunctionCall ],
117224 expected_tool_calls : list [genai_types .FunctionCall ],
118225 ) -> bool :
226+ """Checks if expected tool calls appear in actual tool calls in any order.
227+
228+ This method implements ANY_ORDER match type. It allows for additional
229+ tool calls in actual_tool_calls, as long as all expected tool calls are
230+ present.
231+
232+ Args:
233+ actual_tool_calls: A list of tool calls that actually happened.
234+ expected_tool_calls: A list of tool calls that were expected to happen.
235+
236+ Returns:
237+ True if actual tool calls contain all expected tool calls,
238+ False otherwise.
239+ """
240+ if not expected_tool_calls :
241+ return True
242+ if not actual_tool_calls and expected_tool_calls :
243+ return False
244+
245+ actual_tool_calls_copy = list (actual_tool_calls )
246+ for expected in expected_tool_calls :
247+ found = False
248+ for i , actual in enumerate (actual_tool_calls_copy ):
249+ if actual .name == expected .name and actual .args == expected .args :
250+ actual_tool_calls_copy .pop (i )
251+ found = True
252+ break
253+ if not found :
254+ return False
255+ return True
256+
257+ def _are_tool_calls_exact_match (
258+ self ,
259+ actual_tool_calls : list [genai_types .FunctionCall ],
260+ expected_tool_calls : list [genai_types .FunctionCall ],
261+ ) -> bool :
262+ """Checks if actual tool calls exactly match expected tool calls.
263+
264+ This method implements EXACT match type. It requires that
265+ actual_tool_calls and expected_tool_calls have the same tool calls in
266+ the same order, with no extra or missing tool calls.
267+
268+ Args:
269+ actual_tool_calls: A list of tool calls that actually happened.
270+ expected_tool_calls: A list of tool calls that were expected to happen.
271+
272+ Returns:
273+ True if actual tool calls exactly match expected tool calls,
274+ False otherwise.
275+ """
119276 if len (actual_tool_calls ) != len (expected_tool_calls ):
120277 return False
121278
0 commit comments