
Commit ac9c6c5

Tool Call Accuracy Evaluator (#40068)
* Tool Call Accuracy Evaluator
* Review comments
* Updating score key and output structure
* Updating prompt
* Renaming parameter
1 parent 7656cf2 commit ac9c6c5

9 files changed: +430 -4 lines

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/__init__.py

+4

@@ -27,6 +27,8 @@
 from ._evaluators._xpia import IndirectAttackEvaluator
 from ._evaluators._code_vulnerability import CodeVulnerabilityEvaluator
 from ._evaluators._ungrounded_attributes import UngroundedAttributesEvaluator
+from ._evaluators._isa import ISAEvaluator
+from ._evaluators._tool_call_accuracy import ToolCallAccuracyEvaluator
 from ._model_configurations import (
     AzureAIProject,
     AzureOpenAIModelConfiguration,
@@ -69,4 +71,6 @@
     "EvaluationResult",
     "CodeVulnerabilityEvaluator",
     "UngroundedAttributesEvaluator",
+    "ISAEvaluator",
+    "ToolCallAccuracyEvaluator",
 ]
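
With these exports in place, the evaluator can be constructed straight from the package root. A minimal sketch, assuming placeholder Azure OpenAI settings (the endpoint, key, deployment, and API version below are illustrative, not part of this commit):

from azure.ai.evaluation import AzureOpenAIModelConfiguration, ToolCallAccuracyEvaluator

# Placeholder configuration values for illustration only.
model_config = AzureOpenAIModelConfiguration(
    azure_endpoint="https://<your-resource>.openai.azure.com",
    api_key="<api-key>",
    azure_deployment="<deployment-name>",
    api_version="<api-version>",
)

tool_call_accuracy = ToolCallAccuracyEvaluator(model_config=model_config)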

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_common/constants.py

+1 -1

@@ -6,7 +6,7 @@
 from azure.core import CaseInsensitiveEnumMeta
 
 
-PROMPT_BASED_REASON_EVALUATORS = ["coherence", "relevance", "retrieval", "groundedness", "fluency"]
+PROMPT_BASED_REASON_EVALUATORS = ["coherence", "relevance", "retrieval", "groundedness", "fluency", "tool_call_accurate"]
 
 
 class CommonConstants:

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_common/utils.py

+2 -2

@@ -275,7 +275,7 @@ def validate_annotation(v: object, annotation: Union[str, type, object]) -> bool
     return cast(T_TypedDict, o)
 
 
-def parse_quality_evaluator_reason_score(llm_output: str) -> Tuple[float, str]:
+def parse_quality_evaluator_reason_score(llm_output: str, valid_score_range: str = "[1-5]") -> Tuple[float, str]:
     """Parse the output of prompt-based quality evaluators that return a score and reason.
 
     Current supported evaluators:
@@ -294,7 +294,7 @@ def parse_quality_evaluator_reason_score(llm_output: str) -> Tuple[float, str]:
     reason = ""
     if llm_output:
         try:
-            score_pattern = r"<S2>\D*?([1-5]).*?</S2>"
+            score_pattern = rf"<S2>\D*?({valid_score_range}).*?</S2>"
             reason_pattern = r"<S1>(.*?)</S1>"
             score_match = re.findall(score_pattern, llm_output, re.DOTALL)
             reason_match = re.findall(reason_pattern, llm_output, re.DOTALL)
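
To see what the new valid_score_range parameter changes, here is a small self-contained sketch of the same parsing logic. The sample llm_output string is hypothetical, but the regular expressions mirror the ones above; the tool call accuracy evaluator passes "[0-1]" instead of the default "[1-5]" so that binary scores are accepted:

import re

# Hypothetical model output in the <S1> reason / <S2> score format parsed above.
llm_output = "<S1>The parameters are all grounded in the conversation.</S1><S2>1</S2>"

valid_score_range = "[0-1]"
score_pattern = rf"<S2>\D*?({valid_score_range}).*?</S2>"
reason_pattern = r"<S1>(.*?)</S1>"

score_match = re.findall(score_pattern, llm_output, re.DOTALL)
reason_match = re.findall(reason_pattern, llm_output, re.DOTALL)
print(score_match, reason_match)
# ['1'] ['The parameters are all grounded in the conversation.']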

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_common/_base_eval.py

+9 -1

@@ -494,7 +494,8 @@ def __init__(self, real_call): # DO NOT ADD TYPEHINT PROMPT FLOW WILL SCREAM AT
     # Since we want this to be relatively call-agnostic, we just account for every input that any children
     # are known to throw at this, mash them into kwargs, and then pass them into the real call.
     async def __call__(
-        self, *, query=None, response=None, context=None, conversation=None, ground_truth=None, **kwargs
+        self, *, query=None, response=None, context=None, conversation=None, ground_truth=None,
+        tool_call=None, tool_definitions=None, messages=None, **kwargs
     ):
         if conversation is not None:
             kwargs["conversation"] = conversation
@@ -506,4 +507,11 @@ async def __call__(
             kwargs["context"] = context
         if ground_truth is not None:
             kwargs["ground_truth"] = ground_truth
+        if tool_call is not None:
+            kwargs["tool_call"] = tool_call
+        if tool_definitions is not None:
+            kwargs["tool_definitions"] = tool_definitions
+        if messages is not None:
+            kwargs["messages"] = messages
+
         return await self._real_call(**kwargs)
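
The wrapper keeps its existing pattern: each new keyword is forwarded only when supplied, so evaluators that do not accept tool_call, tool_definitions, or messages are unaffected. A stripped-down, runnable sketch of that forwarding behavior (the wrapper class name and the echo callback are illustrative stand-ins, not the SDK's actual names):

import asyncio

class ForwardingWrapper:
    """Illustrative stand-in for the SDK's async wrapper class."""

    def __init__(self, real_call):
        self._real_call = real_call

    async def __call__(self, *, query=None, tool_call=None, tool_definitions=None, messages=None, **kwargs):
        # Forward only the keywords that were actually provided, mirroring the diff above.
        for name, value in (("query", query), ("tool_call", tool_call),
                            ("tool_definitions", tool_definitions), ("messages", messages)):
            if value is not None:
                kwargs[name] = value
        return await self._real_call(**kwargs)

async def echo(**kwargs):
    return kwargs

print(asyncio.run(ForwardingWrapper(echo)(query="q", tool_call={"type": "tool_call"})))
# {'query': 'q', 'tool_call': {'type': 'tool_call'}}
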
New file (+9 lines):

# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

from ._tool_call_accuracy import ToolCallAccuracyEvaluator

__all__ = [
    "ToolCallAccuracyEvaluator",
]
New file (+289 lines):

# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
import math
import os
import logging
import re
from typing import Dict, List, Union, TypeVar, cast
from typing_extensions import overload, override
from azure.ai.evaluation._evaluators._common import PromptyEvaluatorBase
from azure.ai.evaluation._common.utils import remove_optional_singletons, parse_quality_evaluator_reason_score
from azure.ai.evaluation._exceptions import ErrorBlame, ErrorCategory, ErrorTarget, EvaluationException
from azure.ai.evaluation._common.constants import PROMPT_BASED_REASON_EVALUATORS

logger = logging.getLogger(__name__)

T_EvalValue = TypeVar("T_EvalValue")

class ToolCallAccuracyEvaluator(PromptyEvaluatorBase[Union[str, float]]):
    """The Tool Call Accuracy evaluator assesses how accurately an AI uses tools by examining:
    - Relevance to the conversation
    - Parameter correctness according to tool definitions
    - Parameter value extraction from the conversation
    - Potential usefulness of the tool call

    The evaluator uses a binary scoring system (0 or 1):
    - Score 0: The tool call is irrelevant or contains information not in the conversation/definition
    - Score 1: The tool call is relevant with properly extracted parameters from the conversation

    This evaluation focuses on measuring whether tool calls meaningfully contribute to addressing
    user needs while properly following tool definitions and using information present in the
    conversation history.

    :param model_config: Configuration for the Azure OpenAI model.
    :type model_config: Union[~azure.ai.evaluation.AzureOpenAIModelConfiguration,
        ~azure.ai.evaluation.OpenAIModelConfiguration]

    .. admonition:: Example:

        .. literalinclude:: ../samples/evaluation_samples_evaluate.py
            :start-after: [START tool_call_accuracy_evaluator]
            :end-before: [END tool_call_accuracy_evaluator]
            :language: python
            :dedent: 8
            :caption: Initialize and call a ToolCallAccuracyEvaluator.

    .. note::

        To align with our support of a diverse set of models, an output key without the `gpt_` prefix has been added.
        To maintain backwards compatibility, the old key with the `gpt_` prefix is still present in the output;
        however, it is recommended to use the new key moving forward, as the old key will be deprecated in the future.
    """

    _PROMPTY_FILE = "tool_call_accuracy.prompty"
    _RESULT_KEY = "tool_call_accurate"
    _AGGREGATE_RESULT_KEY = "tool_call_accuracy"

    id = "id"
    """Evaluator identifier, experimental and to be used only with evaluation in cloud."""

    @override
    def __init__(self, model_config):
        current_dir = os.path.dirname(__file__)
        prompty_path = os.path.join(current_dir, self._PROMPTY_FILE)
        super().__init__(model_config=model_config, prompty_file=prompty_path, result_key=self._RESULT_KEY)

    # Types are the closest I could find from Agent 1.0, since the Agent 2.0 Python classes do not exist.
    @overload
    def __call__(
        self,
        *,
        query: Union[str, List["Message"]],  # Chat history up to, but not including, the message that contains the tool call being evaluated.
        tool_definitions: Union["FunctionToolDefinition", List["FunctionToolDefinition"]],  # Definition of the tool(s) whose call is being evaluated.
        tool_calls: Union["FunctionToolCall", List["FunctionToolCall"]] = None,
        response: Union[str, List["Message"]] = None
    ) -> Dict[str, Union[str, float]]:
        """
        Evaluate tool call accuracy. Accepts a query, tool definitions, and tool calls for evaluation.

        :keyword query: Query or chat history up to the message that contains the tool call being evaluated.
        :paramtype query: Union[str, List[Message]]
        :keyword tool_definitions: List of tool definitions whose calls are being evaluated.
        :paramtype tool_definitions: Union[FunctionToolDefinition, List[FunctionToolDefinition]]
        :keyword tool_calls: Optional list of tool calls to evaluate. If not provided, response must be provided
            and must contain tool call(s).
        :paramtype tool_calls: Union[FunctionToolCall, List[FunctionToolCall]]
        :keyword response: Optional response to be evaluated alongside the tool calls.
            If response is provided and tool_calls is not, all tool calls in the response are evaluated.
            If both are provided, only the tool calls in tool_calls are evaluated; extra tool calls in the
            response are not evaluated, but the response is used to extract any tool calls needed to evaluate
            a given tool call. Recommended when a tool call depends on the output of a previous tool call.
        :paramtype response: Union[str, List[Message]]
        :return: The tool call accuracy evaluation results.
        :rtype: Dict[str, Union[str, float]]
        """

    def _convert_kwargs_to_eval_input(self, **kwargs):
        """Convert an arbitrary input into a list of inputs for evaluators.
        It is assumed that evaluators generally make use of their inputs in one of two ways.
        Either they receive a collection of keyword inputs that are all single values
        (like a query and response), or they receive a conversation that is a list of dictionary
        values.

        The self._singleton_inputs list assigned during initialization is used to find and extract
        singleton keywords, and self._allow_conversation_input is used to determine if a conversation
        is a valid input.

        If both conversations and singletons are allowed, the function will raise an exception if both
        are supplied.

        This function must be overridden by child classes if they need both a conversation and
        other inputs to be passed in.

        :keyword kwargs: The inputs to convert.
        :type kwargs: Dict
        :return: A list of arbitrary values that are valid inputs for this evaluator's do_eval function.
        :rtype: List
        """
        # TODO: add a warning that only tool calls of type "function" are supported.
        # Collect inputs
        tool_calls = kwargs.get("tool_calls", None)
        tool_definitions = kwargs.get("tool_definitions")
        query = kwargs.get("query", None)
        response = kwargs.get("response", None)

        if response is None and tool_calls is None:
            raise EvaluationException(
                message="Either response or tool_calls must be provided.",
                blame=ErrorBlame.USER_ERROR,
                category=ErrorCategory.MISSING_FIELD,
                target=ErrorTarget.TOOL_CALL_ACCURACY_EVALUATOR,
            )

        if tool_definitions is None:
            raise EvaluationException(
                message="Tool definitions must be provided.",
                blame=ErrorBlame.USER_ERROR,
                category=ErrorCategory.MISSING_FIELD,
                target=ErrorTarget.TOOL_CALL_ACCURACY_EVALUATOR,
            )

        # TODO: support classes that represent tool calls, messages, etc. once client-side definitions are available.
        if tool_calls is None:
            # Extract tool calls from the response if they were not provided explicitly.
            tool_calls = []
            if isinstance(response, list):
                for message in response:
                    if message.get("role") == "assistant":
                        tool_calls.extend([content for content in message.get("content")
                                           if content.get("type") == "tool_call" and content.get("tool_call").get("type") == "function"])
            else:
                raise EvaluationException(
                    message="response does not have tool calls. Either provide tool_calls or response with tool calls.",
                    blame=ErrorBlame.USER_ERROR,
                    category=ErrorCategory.MISSING_FIELD,
                    target=ErrorTarget.TOOL_CALL_ACCURACY_EVALUATOR,
                )

        if not isinstance(tool_calls, list):
            tool_calls = [tool_calls]

        if not isinstance(tool_definitions, list):
            tool_definitions = [tool_definitions]

        eval_inputs = []
        # TODO: when evaluating an agent tool that depends on the output of a previous tool call,
        # the output of the previous tool call needs to be provided as part of messages.
        for tool_call in tool_calls:
            if isinstance(tool_call, dict) and tool_call.get("type") == "tool_call" and tool_call.get("tool_call").get("type") == "function":  # TODO: assuming a dict here, but it can be a class.
                function_name = tool_call.get("tool_call").get("function").get("name")
                tool_definition = [tool for tool in tool_definitions if tool.get("name") == function_name]
                if not tool_definition:
                    raise EvaluationException(
                        message="Tool definition not found",
                        blame=ErrorBlame.USER_ERROR,
                        category=ErrorCategory.INVALID_VALUE,
                        target=ErrorTarget.TOOL_CALL_ACCURACY_EVALUATOR,
                    )
                eval_inputs.append({"query": query, "tool_call": tool_call, "tool_definition": tool_definition})

        return eval_inputs

    @override
    async def _do_eval(self, eval_input: Dict) -> Dict[str, Union[float, str]]:  # type: ignore[override]
        """Do a tool call accuracy evaluation.

        :param eval_input: The input to the evaluator. Expected to contain
        whatever inputs are needed for the _flow method, including context
        and other fields depending on the child class.
        :type eval_input: Dict
        :return: The evaluation result.
        :rtype: Dict
        """
        llm_output = await self._flow(timeout=self._LLM_CALL_TIMEOUT, **eval_input)

        score = math.nan
        if llm_output:
            score, reason = parse_quality_evaluator_reason_score(llm_output, valid_score_range="[0-1]")
            return {
                self._result_key: bool(float(score)),
                f"{self._result_key}_reason": reason,
                "tool_call_id": eval_input.get("tool_call").get("tool_call").get("id"),
            }
        return {self._result_key: float(score)}

    async def _real_call(self, **kwargs):
        """The asynchronous call where real end-to-end evaluation logic is performed.

        :keyword kwargs: The inputs to evaluate.
        :type kwargs: Dict
        :return: The evaluation result.
        :rtype: Union[DoEvalResult[T_EvalValue], AggregateResult[T_EvalValue]]
        """
        # Convert inputs into a list of evaluable inputs.
        eval_input_list = self._convert_kwargs_to_eval_input(**kwargs)
        per_turn_results = []
        # Evaluate all inputs.
        for eval_input in eval_input_list:
            per_turn_results.append(await self._do_eval(eval_input))

        return self._aggregate_results(per_turn_results=per_turn_results)

    def _aggregate_results(self, per_turn_results):
        """Aggregate the evaluation results of each conversation turn into a single result.

        Exact implementation might need to vary slightly depending on the results produced.
        Default behavior is to average all number-based outputs.

        :param per_turn_results: List of evaluation results for each turn in the conversation.
        :type per_turn_results: List[Dict]
        :return: A dictionary containing aggregated results, with numeric metrics having their
        means as top-level values in the dictionary, and all original
        values (including non-numerics) located under the "evaluation_per_turn" key,
        with each sub-key being a metric and each sub-value being the list of that metric's
        per-turn values.
        :rtype: AggregateResult[T_EvalValue]
        """

        aggregated: Dict[str, Union[float, Dict[str, List[T_EvalValue]]]] = {}
        evaluation_per_turn: Dict[str, List[T_EvalValue]] = {}

        # Go over each turn, and rotate the results into a
        # metric: List[values] format for the evals_per_turn dictionary.

        score = sum([1 if per_turn_result.get(self._result_key) else 0 for per_turn_result in per_turn_results]) / len(per_turn_results)
        aggregated[self._AGGREGATE_RESULT_KEY] = score

        # for turn in per_turn_results:
        #     for metric, value in turn.items():
        #         if metric not in evaluation_per_turn:
        #             evaluation_per_turn[metric] = []
        #         evaluation_per_turn[metric].append(value)
        #
        # # Find and average all numeric values
        # for metric, values in evaluation_per_turn.items():
        #     if all(isinstance(value, (int, float)) for value in values):
        #         aggregated[metric] = self._conversation_aggregation_function(cast(List[Union[int, float]], values))
        # # Slap the per-turn results back in.
        aggregated["per_tool_call_details"] = per_turn_results
        return aggregated

    @override
    def __call__(  # pylint: disable=docstring-missing-param
        self,
        *args,
        **kwargs,
    ):
        """
        Evaluate tool call accuracy. Accepts a query, tool definitions, and tool calls for evaluation.

        :keyword query: Query or chat history up to the message that contains the tool call being evaluated.
        :paramtype query: Union[str, List[ThreadMessage]]
        :keyword tool_definitions: List of tool definitions whose calls are being evaluated.
        :paramtype tool_definitions: List[ToolDefinition]
        :keyword tool_calls: Optional list of tool calls to evaluate. If not provided, response must be provided
            and must contain tool call(s).
        :paramtype tool_calls: List[ToolCall]
        :keyword response: Optional response to be evaluated alongside the tool calls.
            If response is provided and tool_calls is not, all tool calls in the response are evaluated.
            If both are provided, only the tool calls in tool_calls are evaluated; extra tool calls in the
            response are not evaluated, but the response is used to extract any tool calls needed to evaluate
            a given tool call. Recommended when a tool call depends on the output of a previous tool call.
        :paramtype response: Union[str, List[ThreadMessage]]
        :return: The tool call accuracy evaluation results.
        :rtype: Dict[str, Union[str, float]]
        """
        return super().__call__(*args, **kwargs)
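
For reference, a hedged end-to-end usage sketch built from the shapes this file reads. The weather tool, its arguments, and the printed values are invented for illustration; the dictionary keys follow _convert_kwargs_to_eval_input and _aggregate_results above, and tool_call_accuracy is the evaluator instance constructed in the earlier sketch:

result = tool_call_accuracy(
    query="What is the weather in Seattle today?",
    tool_calls=[
        {
            "type": "tool_call",
            "tool_call": {
                "id": "call_1",
                "type": "function",
                "function": {"name": "get_weather", "arguments": {"location": "Seattle"}},
            },
        }
    ],
    tool_definitions=[
        {
            "name": "get_weather",
            "description": "Returns the current weather for a location.",
            "parameters": {"type": "object", "properties": {"location": {"type": "string"}}},
        }
    ],
)

print(result)
# Expected shape (values illustrative):
# {
#     "tool_call_accuracy": 1.0,
#     "per_tool_call_details": [
#         {
#             "tool_call_accurate": True,
#             "tool_call_accurate_reason": "...",
#             "tool_call_id": "call_1",
#         }
#     ],
# }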
