# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
import math
import os
import logging
import re
from typing import Dict, List, Union, TypeVar, cast
from typing_extensions import overload, override
from azure.ai.evaluation._evaluators._common import PromptyEvaluatorBase
from azure.ai.evaluation._common.utils import remove_optional_singletons, parse_quality_evaluator_reason_score
from azure.ai.evaluation._exceptions import ErrorBlame, ErrorCategory, ErrorTarget, EvaluationException
from azure.ai.evaluation._common.constants import PROMPT_BASED_REASON_EVALUATORS

logger = logging.getLogger(__name__)

T_EvalValue = TypeVar("T_EvalValue")

class ToolCallAccuracyEvaluator(PromptyEvaluatorBase[Union[str, float]]):
    """The Tool Call Accuracy evaluator assesses how accurately an AI uses tools by examining:
        - Relevance to the conversation
        - Parameter correctness according to tool definitions
        - Parameter value extraction from the conversation
        - Potential usefulness of the tool call

    The evaluator uses a binary scoring system (0 or 1):
        - Score 0: The tool call is irrelevant or contains information not in the conversation/definition
        - Score 1: The tool call is relevant with properly extracted parameters from the conversation

    This evaluation focuses on measuring whether tool calls meaningfully contribute to addressing
    user needs while properly following tool definitions and using information present in the
    conversation history.

    :param model_config: Configuration for the Azure OpenAI model.
    :type model_config: Union[~azure.ai.evaluation.AzureOpenAIModelConfiguration,
        ~azure.ai.evaluation.OpenAIModelConfiguration]

    .. admonition:: Example:

        .. literalinclude:: ../samples/evaluation_samples_evaluate.py
            :start-after: [START tool_call_accuracy_evaluator]
            :end-before: [END tool_call_accuracy_evaluator]
            :language: python
            :dedent: 8
            :caption: Initialize and call a ToolCallAccuracyEvaluator.

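    .. admonition:: Output format (illustrative):

        A sketch of the aggregated result shape, based on the keys this evaluator emits
        (``tool_call_accuracy``, ``per_tool_call_details``, ``tool_call_accurate``); the
        values shown are hypothetical:

        .. code-block:: python

            {
                "tool_call_accuracy": 1.0,
                "per_tool_call_details": [
                    {
                        "tool_call_accurate": True,
                        "tool_call_accurate_reason": "The call is relevant and its parameters come from the conversation.",
                        "tool_call_id": "call_001",
                    }
                ],
            }
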
    .. note::

        To align with our support of a diverse set of models, an output key without the `gpt_` prefix has been added.
        To maintain backwards compatibility, the old key with the `gpt_` prefix is still present in the output;
        however, it is recommended to use the new key moving forward, as the old key will be deprecated in the future.
    """

    _PROMPTY_FILE = "tool_call_accuracy.prompty"
    _RESULT_KEY = "tool_call_accurate"
    _AGGREGATE_RESULT_KEY = "tool_call_accuracy"

    id = "id"
    """Evaluator identifier, experimental and to be used only with evaluation in cloud."""

    @override
    def __init__(self, model_config):
        current_dir = os.path.dirname(__file__)
        prompty_path = os.path.join(current_dir, self._PROMPTY_FILE)
        super().__init__(model_config=model_config, prompty_file=prompty_path, result_key=self._RESULT_KEY)

    # The types below are the closest match from Agent 1.0 available, since the Agent 2.0 Python classes do not exist yet.
    @overload
    def __call__(
        self,
        *,
        query: Union[str, List["Message"]],  # Chat history up to, but not including, the message that contains the tool call being evaluated.
        tool_definitions: Union["FunctionToolDefinition", List["FunctionToolDefinition"]],  # Definitions of the tools whose calls are being evaluated.
        tool_calls: Union["FunctionToolCall", List["FunctionToolCall"]] = None,
        response: Union[str, List["Message"]] = None,
    ) -> Dict[str, Union[str, float]]:
        """
        Evaluate tool call accuracy. Accepts a query, tool definitions, and tool calls for evaluation.

        :keyword query: Query or chat history up to the message that contains the tool call being evaluated.
        :paramtype query: Union[str, List[Message]]
        :keyword tool_definitions: List of tool definitions whose calls are being evaluated.
        :paramtype tool_definitions: Union[FunctionToolDefinition, List[FunctionToolDefinition]]
        :keyword tool_calls: Optional list of tool calls to evaluate. If not provided, the response must be
            provided and must contain tool call(s).
        :paramtype tool_calls: Union[FunctionToolCall, List[FunctionToolCall]]
        :keyword response: Optional response to be evaluated alongside the tool calls.
            If a response is provided and tool_calls is not, all tool calls in the response are evaluated.
            If both are provided, only the tool calls in tool_calls are evaluated; any extra tool calls in
            the response are not evaluated, but the response is still used to extract tool calls that are
            needed for evaluating a given tool call. Providing the response is recommended when tool calls
            depend on the output of a previous tool call.
        :paramtype response: Union[str, List[Message]]
        :return: The tool selection evaluation results.
        :rtype: Dict[str, Union[str, float]]
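
        .. admonition:: Example (illustrative):

            A minimal sketch of a call with inline dictionaries, matching the shapes this
            evaluator parses (the ``type``/``tool_call``/``function`` nesting for tool calls
            and the ``name`` key on tool definitions). The weather tool, its parameters, and
            ``model_config`` are hypothetical placeholders:

            .. code-block:: python

                evaluator = ToolCallAccuracyEvaluator(model_config=model_config)
                result = evaluator(
                    query="What is the weather in Seattle?",
                    tool_calls={
                        "type": "tool_call",
                        "tool_call": {
                            "id": "call_001",
                            "type": "function",
                            "function": {"name": "fetch_weather", "arguments": {"location": "Seattle"}},
                        },
                    },
                    tool_definitions={
                        "name": "fetch_weather",
                        "description": "Fetches weather information for a location.",
                        "parameters": {
                            "type": "object",
                            "properties": {"location": {"type": "string"}},
                        },
                    },
                )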
        """

    def _convert_kwargs_to_eval_input(self, **kwargs):
        """Convert an arbitrary input into a list of inputs for evaluators.
        It is assumed that evaluators generally make use of their inputs in one of two ways.
        Either they receive a collection of keyname inputs that are all single values
        (like a query and response), or they receive a conversation, which is a list of dictionary
        values.

        The self._singleton_inputs list assigned during initialization is used to find and extract
        singleton keywords, and self._allow_conversation_input is used to determine if a conversation
        is a valid input.

        If both conversations and singletons are allowed, the function will raise an exception if both
        are provided.

        This function must be overridden by child classes if they need both a conversation and
        other inputs to be passed in.

        :keyword kwargs: The inputs to convert.
        :type kwargs: Dict
        :return: A list of arbitrary values that are valid inputs for this evaluator's do_eval function.
        :rtype: List
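
        For reference, each element of the returned list is a dictionary with ``query``,
        ``tool_call``, and ``tool_definition`` keys; a sketch with hypothetical values:

        .. code-block:: python

            {
                "query": "What is the weather in Seattle?",
                "tool_call": {
                    "type": "tool_call",
                    "tool_call": {
                        "id": "call_001",
                        "type": "function",
                        "function": {"name": "fetch_weather", "arguments": {"location": "Seattle"}},
                    },
                },
                "tool_definition": [{"name": "fetch_weather"}],
            }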
        """
        # TODO: add a warning that only tool calls of type "function" are supported.
        # Collect inputs
        tool_calls = kwargs.get("tool_calls", None)
        tool_definitions = kwargs.get("tool_definitions")
        query = kwargs.get("query", None)
        response = kwargs.get("response", None)

        if response is None and tool_calls is None:
            raise EvaluationException(
                message="Either response or tool_calls must be provided.",
                blame=ErrorBlame.USER_ERROR,
                category=ErrorCategory.MISSING_FIELD,
                target=ErrorTarget.TOOL_CALL_ACCURACY_EVALUATOR,
            )

        if tool_definitions is None:
            raise EvaluationException(
                message="Tool definitions must be provided.",
                blame=ErrorBlame.USER_ERROR,
                category=ErrorCategory.MISSING_FIELD,
                target=ErrorTarget.TOOL_CALL_ACCURACY_EVALUATOR,
            )

        # TODO: Support classes that represent tool calls, messages, etc. once client-side definitions are available.
        if tool_calls is None:
            # Extract tool calls from the response if they were not provided explicitly.
            tool_calls = []
            if isinstance(response, list):
                for message in response:
                    if message.get("role") == "assistant":
                        # Keep only function-type tool calls from the assistant message content.
                        tool_calls.extend(
                            [
                                content
                                for content in message.get("content")
                                if content.get("type") == "tool_call"
                                and content.get("tool_call").get("type") == "function"
                            ]
                        )
            else:
                raise EvaluationException(
                    message="response does not have tool calls. Either provide tool_calls or a response with tool calls.",
                    blame=ErrorBlame.USER_ERROR,
                    category=ErrorCategory.MISSING_FIELD,
                    target=ErrorTarget.TOOL_CALL_ACCURACY_EVALUATOR,
                )

        if not isinstance(tool_calls, list):
            tool_calls = [tool_calls]

        if not isinstance(tool_definitions, list):
            tool_definitions = [tool_definitions]

        eval_inputs = []
        # TODO: When evaluating an agent tool that depends on the output of a previous tool call,
        # we need to provide the output of the previous tool call as part of the messages.
        for tool_call in tool_calls:
            # TODO: assuming a dict here, but it can be a class once client-side definitions are available.
            if (
                isinstance(tool_call, dict)
                and tool_call.get("type") == "tool_call"
                and tool_call.get("tool_call").get("type") == "function"
            ):
                function_name = tool_call.get("tool_call").get("function").get("name")
                # Match the call against its definition by function name.
                tool_definition = [tool for tool in tool_definitions if tool.get("name") == function_name]
                if len(tool_definition) == 0:
                    raise EvaluationException(
                        message="Tool definition not found",
                        blame=ErrorBlame.USER_ERROR,
                        category=ErrorCategory.INVALID_VALUE,
                        target=ErrorTarget.TOOL_CALL_ACCURACY_EVALUATOR,
                    )
                eval_inputs.append({"query": query, "tool_call": tool_call, "tool_definition": tool_definition})

        return eval_inputs

    @override
    async def _do_eval(self, eval_input: Dict) -> Dict[str, Union[float, str]]:  # type: ignore[override]
        """Do a tool call accuracy evaluation.

        :param eval_input: The input to the evaluator. Expected to contain
            whatever inputs are needed for the _flow method, including context
            and other fields depending on the child class.
        :type eval_input: Dict
        :return: The evaluation result.
        :rtype: Dict
        """
        llm_output = await self._flow(timeout=self._LLM_CALL_TIMEOUT, **eval_input)

        score = math.nan
        if llm_output:
            score, reason = parse_quality_evaluator_reason_score(llm_output, valid_score_range="[0-1]")
            return {
                self._result_key: bool(float(score)),
                f"{self._result_key}_reason": reason,
                "tool_call_id": eval_input.get("tool_call").get("tool_call").get("id"),
            }
        return {self._result_key: float(score)}

    async def _real_call(self, **kwargs):
        """The asynchronous call where the real end-to-end evaluation logic is performed.

        :keyword kwargs: The inputs to evaluate.
        :type kwargs: Dict
        :return: The evaluation result.
        :rtype: Union[DoEvalResult[T_EvalValue], AggregateResult[T_EvalValue]]
        """
        # Convert inputs into a list of evaluable inputs.
        eval_input_list = self._convert_kwargs_to_eval_input(**kwargs)
        per_turn_results = []
        # Evaluate all inputs.
        for eval_input in eval_input_list:
            per_turn_results.append(await self._do_eval(eval_input))

        return self._aggregate_results(per_turn_results=per_turn_results)

    def _aggregate_results(self, per_turn_results):
        """Aggregate the evaluation results of each conversation turn into a single result.

        The exact implementation might need to vary slightly depending on the results produced.
        The default behavior is to average all number-based outputs.

        :param per_turn_results: List of evaluation results for each turn in the conversation.
        :type per_turn_results: List[Dict]
        :return: A dictionary containing the aggregated results, with the overall tool call accuracy
            stored as a top-level numeric value and the original per-tool-call results (including
            non-numeric values) located under the "per_tool_call_details" key.
        :rtype: AggregateResult[T_EvalValue]
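
        A small worked example with two hypothetical per-call results, one judged accurate and
        one not, which yields an aggregate score of 0.5:

        .. code-block:: python

            per_turn_results = [
                {"tool_call_accurate": True, "tool_call_id": "call_001"},
                {"tool_call_accurate": False, "tool_call_id": "call_002"},
            ]
            # _aggregate_results(per_turn_results) returns:
            # {
            #     "tool_call_accuracy": 0.5,
            #     "per_tool_call_details": per_turn_results,
            # }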
        """

        aggregated: Dict[str, Union[float, Dict[str, List[T_EvalValue]]]] = {}
        evaluation_per_turn: Dict[str, List[T_EvalValue]] = {}

        # Go over each turn, and rotate the results into a
        # metric: List[values] format for the evals_per_turn dictionary.

        score = sum(
            1 if per_turn_result.get(self._result_key) else 0 for per_turn_result in per_turn_results
        ) / len(per_turn_results)
        aggregated[self._AGGREGATE_RESULT_KEY] = score

        # for turn in per_turn_results:
        #     for metric, value in turn.items():
        #         if metric not in evaluation_per_turn:
        #             evaluation_per_turn[metric] = []
        #         evaluation_per_turn[metric].append(value)
        #
        # # Find and average all numeric values
        # for metric, values in evaluation_per_turn.items():
        #     if all(isinstance(value, (int, float)) for value in values):
        #         aggregated[metric] = self._conversation_aggregation_function(cast(List[Union[int, float]], values))
        # # Slap the per-turn results back in.
        aggregated["per_tool_call_details"] = per_turn_results
        return aggregated

    @override
    def __call__(  # pylint: disable=docstring-missing-param
        self,
        *args,
        **kwargs,
    ):
        """
        Evaluate tool call accuracy. Accepts a query, tool definitions, and tool calls for evaluation.

        :keyword query: Query or chat history up to the message that contains the tool call being evaluated.
        :paramtype query: Union[str, List[ThreadMessage]]
        :keyword tool_definitions: List of tool definitions whose calls are being evaluated.
        :paramtype tool_definitions: List[ToolDefinition]
        :keyword tool_calls: Optional list of tool calls to evaluate. If not provided, the response must be
            provided and must contain tool call(s).
        :paramtype tool_calls: List[ToolCall]
        :keyword response: Optional response to be evaluated alongside the tool calls.
            If a response is provided and tool_calls is not, all tool calls in the response are evaluated.
            If both are provided, only the tool calls in tool_calls are evaluated; any extra tool calls in
            the response are not evaluated, but the response is still used to extract tool calls that are
            needed for evaluating a given tool call. Providing the response is recommended when tool calls
            depend on the output of a previous tool call.
        :paramtype response: Union[str, List[ThreadMessage]]
        :return: The tool selection evaluation results.
        :rtype: Dict[str, Union[str, float]]
        """
        return super().__call__(*args, **kwargs)