Add overloads for __call__ methods that accept query/response and conversation (Azure#38097)

* Add overloads for __call__ methods that take query/response and conversation

* remove callable type hint

* add docstrings/type hints

* fix a typo

* remove file

* remove a bad param

* add docs for relevance

* fix some missing type hints

* lint and run black

* merge with main

* fix some mypy errors, not all pylint

* fix black errors

* attempt to fix tests

* fix retrieval

* fix up tests and lint

* fix some docstrings to mark some things as optional
needuv authored and allenkim0129 committed Nov 5, 2024
1 parent 1c3ffe2 commit 0619608
Showing 18 changed files with 590 additions and 141 deletions.
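
In practice, the new overloads document two calling patterns. A minimal sketch of both, assuming an azure-ai-evaluation install; the endpoint, key, and deployment values below are placeholders and not part of this commit:

from azure.ai.evaluation import CoherenceEvaluator

# Placeholder model configuration; substitute your own Azure OpenAI values.
model_config = {
    "azure_endpoint": "https://<resource>.openai.azure.com",
    "api_key": "<api-key>",
    "azure_deployment": "<deployment>",
}

coherence = CoherenceEvaluator(model_config)

# Overload 1: a single query/response pair.
single = coherence(
    query="What is the capital of Japan?",
    response="The capital of Japan is Tokyo.",
)

# Overload 2: a multi-turn conversation under the "messages" key.
multi = coherence(
    conversation={
        "messages": [
            {"role": "user", "content": "What is the capital of Japan?"},
            {"role": "assistant", "content": "The capital of Japan is Tokyo."},
        ]
    }
)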
@@ -2,14 +2,15 @@
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # ---------------------------------------------------------
 import os
-from typing import Optional
+from typing import Dict, Union, List

-from typing_extensions import override
+from typing_extensions import overload, override

 from azure.ai.evaluation._evaluators._common import PromptyEvaluatorBase
+from azure.ai.evaluation._model_configurations import Conversation


-class CoherenceEvaluator(PromptyEvaluatorBase):
+class CoherenceEvaluator(PromptyEvaluatorBase[Union[str, float]]):
     """
     Initialize a coherence evaluator configured for a specific Azure OpenAI model.
@@ -49,13 +50,43 @@ def __init__(self, model_config):
         prompty_path = os.path.join(current_dir, self._PROMPTY_FILE)
         super().__init__(model_config=model_config, prompty_file=prompty_path, result_key=self._RESULT_KEY)

-    @override
+    @overload
     def __call__(
         self,
         *,
+        query: str,
+        response: str,
+    ) -> Dict[str, Union[str, float]]:
+        """Evaluate coherence for given input of query, response
+        :keyword query: The query to be evaluated.
+        :paramtype query: str
+        :keyword response: The response to be evaluated.
+        :paramtype response: str
+        :return: The coherence score.
+        :rtype: Dict[str, float]
+        """
+
+    @overload
+    def __call__(
+        self,
+        *,
-        query: Optional[str] = None,
-        response: Optional[str] = None,
-        conversation=None,
+        conversation: Conversation,
+    ) -> Dict[str, Union[float, Dict[str, List[Union[str, float]]]]]:
+        """Evaluate coherence for a conversation
+        :keyword conversation: The conversation to evaluate. Expected to contain a list of conversation turns under the
+            key "messages", and potentially a global context under the key "context". Conversation turns are expected
+            to be dictionaries with keys "content", "role", and possibly "context".
+        :paramtype conversation: Optional[~azure.ai.evaluation.Conversation]
+        :return: The coherence score.
+        :rtype: Dict[str, Union[float, Dict[str, List[float]]]]
+        """
+
+    @override
+    def __call__(  # pylint: disable=docstring-missing-param
+        self,
+        *args,
         **kwargs,
     ):
         """Evaluate coherence. Accepts either a query and response for a single evaluation,

@@ -73,4 +104,4 @@ def __call__(
         :return: The relevance score.
         :rtype: Union[Dict[str, float], Dict[str, Union[float, Dict[str, List[float]]]]]
         """
-        return super().__call__(query=query, response=response, conversation=conversation, **kwargs)
+        return super().__call__(*args, **kwargs)
@@ -7,7 +7,7 @@
 from typing import Any, Callable, Dict, Generic, List, TypedDict, TypeVar, Union, cast, final

 from promptflow._utils.async_utils import async_run_allowing_running_loop
-from typing_extensions import ParamSpec, TypeAlias
+from typing_extensions import ParamSpec, TypeAlias, get_overloads

 from azure.ai.evaluation._common.math import list_mean
 from azure.ai.evaluation._exceptions import ErrorBlame, ErrorCategory, ErrorTarget, EvaluationException
@@ -88,7 +88,11 @@ def __init__(
     # This needs to be overridden just to change the function header into something more informative,
     # and to be able to add a more specific docstring. The actual function contents should just be
     # super().__call__(<inputs>)
-    def __call__(self, **kwargs) -> Union[DoEvalResult[T_EvalValue], AggregateResult[T_EvalValue]]:
+    def __call__(  # pylint: disable=docstring-missing-param
+        self,
+        *args,
+        **kwargs,
+    ) -> Union[DoEvalResult[T_EvalValue], AggregateResult[T_EvalValue]]:
         """Evaluate a given input. This method serves as a wrapper and is meant to be overridden by child classes for
         one main reason - to overwrite the method headers and docstring to include additional inputs as needed.
         The actual behavior of this function shouldn't change beyond adding more inputs to the
@@ -127,11 +131,19 @@ def _derive_singleton_inputs(self) -> List[str]:
         :rtype: List[str]
         """

+        overloads = get_overloads(self.__call__)
+        if not overloads:
+            call_signatures = [inspect.signature(self.__call__)]
+        else:
+            call_signatures = [inspect.signature(overload) for overload in overloads]
-        call_signature = inspect.signature(self.__call__)
         singletons = []
-        for param in call_signature.parameters:
-            if param not in self._not_singleton_inputs:
-                singletons.append(param)
+        for call_signature in call_signatures:
+            params = call_signature.parameters
+            if any(not_singleton_input in params for not_singleton_input in self._not_singleton_inputs):
+                continue
+            # exclude self since it is not a singleton input
+            singletons.extend([p for p in params if p != "self"])
         return singletons

     def _derive_conversation_converter(self) -> Callable[[Dict], List[DerivedEvalInput]]:
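For context on the hunk above: typing_extensions.get_overloads returns the @overload-decorated declarations registered for a function, which is what lets _derive_singleton_inputs read named parameters from each overload instead of from the catch-all *args/**kwargs implementation. A minimal, self-contained sketch of the mechanism (the Example class is hypothetical, not from this commit):

import inspect

from typing_extensions import get_overloads, overload


class Example:
    @overload
    def __call__(self, *, query: str, response: str) -> dict: ...

    @overload
    def __call__(self, *, conversation: dict) -> dict: ...

    def __call__(self, *args, **kwargs):  # real implementation dispatches on kwargs
        return {}


# get_overloads returns the two declarations above, so their keyword-only
# parameters can be collected via inspect.signature:
for fn in get_overloads(Example.__call__):
    print(inspect.signature(fn))
# (self, *, query: str, response: str) -> dict
# (self, *, conversation: dict) -> dict

If a subclass registers no overloads (older evaluators), the new code falls back to inspect.signature(self.__call__), preserving the previous behavior.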
@@ -4,7 +4,7 @@

 import math
 import re
-from typing import Dict, Union
+from typing import Dict, TypeVar, Union

 from promptflow.core import AsyncPrompty
 from typing_extensions import override
@@ -18,8 +18,10 @@
 except ImportError:
     USER_AGENT = "None"

+T = TypeVar("T")
+

-class PromptyEvaluatorBase(EvaluatorBase[float]):
+class PromptyEvaluatorBase(EvaluatorBase[T]):
     """Base class for all evaluators that make use of context as an input. It's also assumed that such evaluators
     make use of a prompty file, and return their results as a dictionary, with a single key-value pair
     linking the result name to a float value (unless multi-turn evaluation occurs, in which case the
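Making the base class generic over T lets each subclass pin its per-result value type rather than inheriting the hard-coded float. A rough sketch of the pattern, using an illustrative stand-in rather than the real base class:

from typing import Dict, Generic, TypeVar, Union

T = TypeVar("T")


class ExampleEvaluatorBase(Generic[T]):
    """Illustrative stand-in; not the real EvaluatorBase."""

    def _package_result(self, name: str, value: T) -> Dict[str, T]:
        return {name: value}


# Purely numeric results:
class NumericEvaluator(ExampleEvaluatorBase[float]): ...


# Results that mix scores with string labels/reasons, matching declarations in
# this commit such as CoherenceEvaluator(PromptyEvaluatorBase[Union[str, float]]):
class MixedEvaluator(ExampleEvaluatorBase[Union[str, float]]): ...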
@@ -1,7 +1,7 @@
 # ---------------------------------------------------------
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # ---------------------------------------------------------
-from typing import Dict, Optional, Union
+from typing import Dict, TypeVar, Union

 from typing_extensions import override

@@ -18,7 +18,7 @@

 from . import EvaluatorBase

-T = Union[str, float]
+T = TypeVar("T")


 class RaiServiceEvaluatorBase(EvaluatorBase[T]):
@@ -50,12 +50,9 @@ def __init__(
         self._credential = credential

     @override
-    def __call__(
+    def __call__(  # pylint: disable=docstring-missing-param
         self,
-        *,
-        query: Optional[str] = None,
-        response: Optional[str] = None,
-        conversation=None,
+        *args,
         **kwargs,
     ):
         """Evaluate either a query and response or a conversation. Must supply either a query AND response,
@@ -71,7 +68,7 @@ def __call__(
         :paramtype conversation: Optional[~azure.ai.evaluation.Conversation]
         :rtype: Union[Dict[str, T], Dict[str, Union[float, Dict[str, List[T]]]]]
         """
-        return super().__call__(query=query, response=response, conversation=conversation, **kwargs)
+        return super().__call__(*args, **kwargs)

     @override
     async def _do_eval(self, eval_input: Dict) -> Dict[str, T]:
@@ -108,7 +105,7 @@ async def _do_eval(self, eval_input: Dict) -> Dict[str, T]:
             )
             input_data["context"] = context

-        return await evaluate_with_rai_service(
+        return await evaluate_with_rai_service(  # type: ignore
             metric_name=self._eval_metric,
             data=input_data,
             project_scope=self._azure_ai_project,
@@ -2,13 +2,14 @@
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # ---------------------------------------------------------
 from concurrent.futures import as_completed
-from typing import Callable, Dict, List, Optional, Union
+from typing import Callable, Dict, List, Union

 from promptflow.tracing import ThreadPoolExecutorWithContext as ThreadPoolExecutor
-from typing_extensions import override
+from typing_extensions import overload, override

 from azure.ai.evaluation._common._experimental import experimental
 from azure.ai.evaluation._evaluators._common import EvaluatorBase
+from azure.ai.evaluation._model_configurations import Conversation

 from ._hate_unfairness import HateUnfairnessEvaluator
 from ._self_harm import SelfHarmEvaluator
@@ -17,7 +18,7 @@


 @experimental
-class ContentSafetyEvaluator(EvaluatorBase):
+class ContentSafetyEvaluator(EvaluatorBase[Union[str, float]]):
     """
     Initialize a content safety evaluator configured to evaluate content safetry metrics for QA scenario.
Expand All @@ -31,7 +32,6 @@ class ContentSafetyEvaluator(EvaluatorBase):
     :param kwargs: Additional arguments to pass to the evaluator.
     :type kwargs: Any
     :return: A function that evaluates content-safety metrics for "question-answering" scenario.
-    :rtype: Callable

     **Usage**
@@ -79,13 +79,43 @@ def __init__(self, credential, azure_ai_project, eval_last_turn: bool = False, *
             HateUnfairnessEvaluator(credential, azure_ai_project),
         ]

-    @override
+    @overload
     def __call__(
         self,
         *,
+        query: str,
+        response: str,
+    ) -> Dict[str, Union[str, float]]:
+        """Evaluate a collection of content safety metrics for the given query/response pair
+        :keyword query: The query to be evaluated.
+        :paramtype query: str
+        :keyword response: The response to be evaluated.
+        :paramtype response: str
+        :return: The content safety scores.
+        :rtype: Dict[str, Union[str, float]]
+        """
+
+    @overload
+    def __call__(
+        self,
+        *,
-        query: Optional[str] = None,
-        response: Optional[str] = None,
-        conversation=None,
+        conversation: Conversation,
+    ) -> Dict[str, Union[float, Dict[str, List[Union[str, float]]]]]:
+        """Evaluate a collection of content safety metrics for a conversation
+        :keyword conversation: The conversation to evaluate. Expected to contain a list of conversation turns under the
+            key "messages", and potentially a global context under the key "context". Conversation turns are expected
+            to be dictionaries with keys "content", "role", and possibly "context".
+        :paramtype conversation: Optional[~azure.ai.evaluation.Conversation]
+        :return: The content safety scores.
+        :rtype: Dict[str, Union[float, Dict[str, List[Union[str, float]]]]]
+        """
+
+    @override
+    def __call__(  # pylint: disable=docstring-missing-param
+        self,
+        *args,
         **kwargs,
     ):
         """Evaluate a collection of content safety metrics for the given query/response pair or conversation.

@@ -100,9 +130,9 @@ def __call__(
             to be dictionaries with keys "content", "role", and possibly "context".
         :paramtype conversation: Optional[~azure.ai.evaluation.Conversation]
         :return: The evaluation result.
-        :rtype: Union[Dict[str, Union[str, float]], Dict[str, Union[str, float, Dict[str, List[Union[str, float]]]]]]
+        :rtype: Union[Dict[str, Union[str, float]], Dict[str, Union[float, Dict[str, List[Union[str, float]]]]]]
         """
-        return super().__call__(query=query, response=response, conversation=conversation, **kwargs)
+        return super().__call__(*args, **kwargs)

     @override
     async def _do_eval(self, eval_input: Dict) -> Dict[str, Union[str, float]]:
@@ -1,17 +1,18 @@
 # ---------------------------------------------------------
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # ---------------------------------------------------------
-from typing import Optional
+from typing import Dict, List, Union

-from typing_extensions import override
+from typing_extensions import overload, override

 from azure.ai.evaluation._common._experimental import experimental
 from azure.ai.evaluation._common.constants import EvaluationMetrics
 from azure.ai.evaluation._evaluators._common import RaiServiceEvaluatorBase
+from azure.ai.evaluation._model_configurations import Conversation


 @experimental
-class HateUnfairnessEvaluator(RaiServiceEvaluatorBase):
+class HateUnfairnessEvaluator(RaiServiceEvaluatorBase[Union[str, float]]):
     """
     Initialize a hate-unfairness evaluator for hate unfairness score.
@@ -58,27 +59,57 @@ def __init__(
             eval_last_turn=eval_last_turn,
         )

-    @override
+    @overload
     def __call__(
         self,
         *,
+        query: str,
+        response: str,
+    ) -> Dict[str, Union[str, float]]:
+        """Evaluate the given query/response pair for hateful content
+        :keyword query: The query to be evaluated.
+        :paramtype query: str
+        :keyword response: The response to be evaluated.
+        :paramtype response: str
+        :return: The hate score
+        :rtype: Dict[str, Union[str, float]]
+        """
+
+    @overload
+    def __call__(
+        self,
+        *,
-        query: Optional[str] = None,
-        response: Optional[str] = None,
-        conversation=None,
+        conversation: Conversation,
+    ) -> Dict[str, Union[float, Dict[str, List[Union[str, float]]]]]:
+        """Evaluate a conversation for hateful content
+        :keyword conversation: The conversation to evaluate. Expected to contain a list of conversation turns under the
+            key "messages", and potentially a global context under the key "context". Conversation turns are expected
+            to be dictionaries with keys "content", "role", and possibly "context".
+        :paramtype conversation: Optional[~azure.ai.evaluation.Conversation]
+        :return: The hate score
+        :rtype: Dict[str, Union[float, Dict[str, List[Union[str, float]]]]]
+        """
+
+    @override
+    def __call__(  # pylint: disable=docstring-missing-param
+        self,
+        *args,
         **kwargs,
     ):
         """
         Evaluate whether hateful content is present in your AI system's response.

         :keyword query: The query to be evaluated.
-        :paramtype query: str
+        :paramtype query: Optional[str]
         :keyword response: The response to be evaluated.
-        :paramtype response: str
+        :paramtype response: Optional[str]
         :keyword conversation: The conversation to evaluate. Expected to contain a list of conversation turns under the
             key "messages". Conversation turns are expected
             to be dictionaries with keys "content" and "role".
         :paramtype conversation: Optional[~azure.ai.evaluation.Conversation]
         :return: The fluency score.
-        :rtype: Union[Dict[str, Union[str, float]], Dict[str, Union[str, float, Dict[str, List[Union[str, float]]]]]]
+        :rtype: Union[Dict[str, Union[str, float]], Dict[str, Union[float, Dict[str, List[Union[str, float]]]]]]
         """
-        return super().__call__(query=query, response=response, conversation=conversation, **kwargs)
+        return super().__call__(*args, **kwargs)
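
All of the conversation overloads above document the same Conversation shape. A hypothetical payload matching that description (the contents are invented for illustration):

conversation = {
    "messages": [
        {"role": "user", "content": "What is the capital of France?"},
        {
            "role": "assistant",
            "content": "The capital of France is Paris.",
            "context": "Retrieved from a geography FAQ.",  # optional per-turn context
        },
    ],
    "context": "Optional global context shared by the whole exchange.",
}

# Each evaluator in this commit accepts it the same way:
# result = evaluator(conversation=conversation)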