Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add overloads for __call__ methods that accept query/response and conversation #38097

Merged
merged 19 commits into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
import os
from typing import Optional
from typing import Dict, Union, List, Optional

from typing_extensions import override
from typing_extensions import overload, override

from azure.ai.evaluation._evaluators._common import PromptyEvaluatorBase
from azure.ai.evaluation._model_configurations import Conversation


class CoherenceEvaluator(PromptyEvaluatorBase):
Expand Down Expand Up @@ -49,6 +50,42 @@ def __init__(self, model_config):
prompty_path = os.path.join(current_dir, self._PROMPTY_FILE)
super().__init__(model_config=model_config, prompty_file=prompty_path, result_key=self._RESULT_KEY)

@overload
needuv marked this conversation as resolved.
Show resolved Hide resolved
def __call__(
self,
*,
query: str,
response: str,
) -> Dict[str, float]:
"""Evaluate coherence for given input of query, response

:keyword query: The query to be evaluated.
:paramtype query: str
:keyword response: The response to be evaluated.
:paramtype response: str
:return: The coherence score.
:rtype: Dict[str, float]
"""
needuv marked this conversation as resolved.
Show resolved Hide resolved
...
needuv marked this conversation as resolved.
Show resolved Hide resolved

@overload
def __call__(
self,
*,
conversation: Conversation,
**kwargs,
) -> Dict[str, Union[float, Dict[str, List[float]]]]:
"""Evaluate coherence for a conversation

:keyword conversation: The conversation to evaluate. Expected to contain a list of conversation turns under the
key "messages", and potentially a global context under the key "context". Conversation turns are expected
to be dictionaries with keys "content", "role", and possibly "context".
:paramtype conversation: Optional[~azure.ai.evaluation.Conversation]
:return: The coherence score.
:rtype: Dict[str, Union[float, Dict[str, List[float]]]]
"""
...
needuv marked this conversation as resolved.
Show resolved Hide resolved

@override
def __call__(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
from typing import Callable, Dict, List, Optional, Union

from promptflow.tracing import ThreadPoolExecutorWithContext as ThreadPoolExecutor
from typing_extensions import override
from typing_extensions import overload, override

from azure.ai.evaluation._common._experimental import experimental
from azure.ai.evaluation._evaluators._common import EvaluatorBase
from azure.ai.evaluation._model_configurations import Conversation

from ._hate_unfairness import HateUnfairnessEvaluator
from ._self_harm import SelfHarmEvaluator
Expand All @@ -31,7 +32,6 @@ class ContentSafetyEvaluator(EvaluatorBase):
:param kwargs: Additional arguments to pass to the evaluator.
:type kwargs: Any
:return: A function that evaluates content-safety metrics for "question-answering" scenario.
:rtype: Callable

**Usage**

Expand Down Expand Up @@ -78,6 +78,42 @@ def __init__(self, credential, azure_ai_project, eval_last_turn: bool = False, *
HateUnfairnessEvaluator(credential, azure_ai_project),
]

@overload
def __call__(
self,
*,
query: str,
response: str,
) -> Dict[str, Union[str, float]]:
"""Evaluate a collection of content safety metrics for the given query/response pair

:keyword query: The query to be evaluated.
:paramtype query: str
:keyword response: The response to be evaluated.
:paramtype response: str
:return: The content safety scores.
:rtype: Dict[str, Union[str, float]]
"""
...

@overload
def __call__(
self,
*,
conversation: Conversation,
**kwargs,
) -> Dict[str, Union[str, float, Dict[str, List[Union[str, float]]]]]:
"""Evaluate a collection of content safety metrics for a conversation

:keyword conversation: The conversation to evaluate. Expected to contain a list of conversation turns under the
key "messages", and potentially a global context under the key "context". Conversation turns are expected
to be dictionaries with keys "content", "role", and possibly "context".
:paramtype conversation: Optional[~azure.ai.evaluation.Conversation]
:return: The content safety scores.
:rtype: Dict[str, Union[str, float, Dict[str, List[Union[str, float]]]]]
needuv marked this conversation as resolved.
Show resolved Hide resolved
"""
...

@override
def __call__(
self,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
from typing import Optional, Union
from typing import Dict, List, Optional, Union

from typing_extensions import override
from typing_extensions import overload, override

from azure.ai.evaluation._common._experimental import experimental
from azure.ai.evaluation._common.constants import EvaluationMetrics
from azure.ai.evaluation._evaluators._common import RaiServiceEvaluatorBase
from azure.ai.evaluation._model_configurations import Conversation


@experimental
Expand Down Expand Up @@ -57,6 +58,42 @@ def __init__(
credential=credential,
eval_last_turn=eval_last_turn,
)

@overload
def __call__(
self,
*,
query: str,
response: str,
) -> Dict[str, Union[str, float]]:
"""Evaluate the given query/response pair for hateful content

:keyword query: The query to be evaluated.
:paramtype query: str
:keyword response: The response to be evaluated.
:paramtype response: str
:return: The hate score
:rtype: Dict[str, Union[str, float]]
"""
...
needuv marked this conversation as resolved.
Show resolved Hide resolved

@overload
def __call__(
self,
*,
conversation: Conversation,
**kwargs,
) -> Dict[str, Union[str, float, Dict[str, List[Union[str, float]]]]]:
"""Evaluate a conversation for hateful content

:keyword conversation: The conversation to evaluate. Expected to contain a list of conversation turns under the
key "messages", and potentially a global context under the key "context". Conversation turns are expected
to be dictionaries with keys "content", "role", and possibly "context".
:paramtype conversation: Optional[~azure.ai.evaluation.Conversation]
:return: The hate score
:rtype: Dict[str, Union[str, float, Dict[str, List[Union[str, float]]]]]
"""
...
needuv marked this conversation as resolved.
Show resolved Hide resolved

@override
def __call__(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
from typing import Optional, Union
from typing import Dict, List, Optional, Union

from typing_extensions import override
from typing_extensions import overload, override

from azure.ai.evaluation._common._experimental import experimental
from azure.ai.evaluation._common.constants import EvaluationMetrics
from azure.ai.evaluation._evaluators._common import RaiServiceEvaluatorBase
from azure.ai.evaluation._model_configurations import Conversation


@experimental
Expand Down Expand Up @@ -58,6 +59,42 @@ def __init__(
eval_last_turn=eval_last_turn,
)

@overload
def __call__(
self,
*,
query: str,
response: str,
) -> Dict[str, Union[str, float]]:
"""Evaluate a given query/response pair for self-harm content

:keyword query: The query to be evaluated.
:paramtype query: str
:keyword response: The response to be evaluated.
:paramtype response: str
:return: The self-harm score
:rtype: Dict[str, Union[str, float]]
"""
...
needuv marked this conversation as resolved.
Show resolved Hide resolved

@overload
def __call__(
self,
*,
conversation: Conversation,
**kwargs,
) -> Dict[str, Union[str, float, Dict[str, List[Union[str, float]]]]]:
"""Evaluate a conversation for self-harm content

:keyword conversation: The conversation to evaluate. Expected to contain a list of conversation turns under the
key "messages", and potentially a global context under the key "context". Conversation turns are expected
to be dictionaries with keys "content", "role", and possibly "context".
:paramtype conversation: Optional[~azure.ai.evaluation.Conversation]
:return: The self-harm score
:rtype: Dict[str, Union[str, float, Dict[str, List[Union[str, float]]]]]
"""
...
needuv marked this conversation as resolved.
Show resolved Hide resolved

@override
def __call__(
self,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
from typing import Optional, Union
from typing import Dict, List, Optional, Union

from typing_extensions import override
from typing_extensions import overload, override

from azure.ai.evaluation._common._experimental import experimental
from azure.ai.evaluation._common.constants import EvaluationMetrics
from azure.ai.evaluation._evaluators._common import RaiServiceEvaluatorBase
from azure.ai.evaluation._model_configurations import Conversation


@experimental
Expand Down Expand Up @@ -58,6 +59,42 @@ def __init__(
eval_last_turn=eval_last_turn,
)

@overload
def __call__(
self,
*,
query: str,
response: str,
) -> Dict[str, Union[str, float]]:
"""Evaluate a given query/response pair for sexual content

:keyword query: The query to be evaluated.
:paramtype query: str
:keyword response: The response to be evaluated.
:paramtype response: str
:return: The sexual score
:rtype: Dict[str, Union[str, float]]
"""
...
needuv marked this conversation as resolved.
Show resolved Hide resolved

@overload
def __call__(
self,
*,
conversation: Conversation,
**kwargs,
) -> Dict[str, Union[str, float, Dict[str, List[Union[str, float]]]]]:
"""Evaluate a conversation for sexual content

:keyword conversation: The conversation to evaluate. Expected to contain a list of conversation turns under the
key "messages", and potentially a global context under the key "context". Conversation turns are expected
to be dictionaries with keys "content", "role", and possibly "context".
:paramtype conversation: Optional[~azure.ai.evaluation.Conversation]
:return: The sexual score
:rtype: Dict[str, Union[str, float, Dict[str, List[Union[str, float]]]]]
"""
...

@override
def __call__(
self,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
from typing import Optional, Union
from typing import Dict, List, Optional, Union

from typing_extensions import override
from typing_extensions import overload, override

from azure.ai.evaluation._common._experimental import experimental
from azure.ai.evaluation._common.constants import EvaluationMetrics
from azure.ai.evaluation._evaluators._common import RaiServiceEvaluatorBase
from azure.ai.evaluation._model_configurations import Conversation


@experimental
Expand Down Expand Up @@ -58,6 +59,42 @@ def __init__(
eval_last_turn=eval_last_turn,
)

@overload
def __call__(
self,
*,
query: str,
response: str,
) -> Dict[str, Union[str, float]]:
"""Evaluate a given query/response pair for violent content

:keyword query: The query to be evaluated.
:paramtype query: str
:keyword response: The response to be evaluated.
:paramtype response: str
:return: The content safety score.
:rtype: Dict[str, Union[str, float]]
"""
...

@overload
def __call__(
self,
*,
conversation: Conversation,
**kwargs,
) -> Dict[str, Union[str, float, Dict[str, List[Union[str, float]]]]]:
"""Evaluate a conversation for violent content

:keyword conversation: The conversation to evaluate. Expected to contain a list of conversation turns under the
key "messages", and potentially a global context under the key "context". Conversation turns are expected
to be dictionaries with keys "content", "role", and possibly "context".
:paramtype conversation: Optional[~azure.ai.evaluation.Conversation]
:return: The violence score.
:rtype: Dict[str, Union[str, float, Dict[str, List[Union[str, float]]]]]
"""
...

@override
def __call__(
self,
Expand All @@ -81,4 +118,5 @@ def __call__(
:return: The fluency score.
:rtype: Union[Dict[str, Union[str, float]], Dict[str, Union[str, float, Dict[str, List[Union[str, float]]]]]]
"""

return super().__call__(query=query, response=response, conversation=conversation, **kwargs)
Loading
Loading