Skip to content

Commit

Permalink
Move eval rai call to shared package (#3659)
Browse files Browse the repository at this point in the history
Moves the code formally located in _content_safety_common to
promptflow.evals.common, since new evaluators will soon need them.
  • Loading branch information
MilesHolland authored Aug 16, 2024
1 parent 89aa66f commit 51aef6f
Show file tree
Hide file tree
Showing 13 changed files with 45 additions and 44 deletions.
16 changes: 16 additions & 0 deletions src/promptflow-evals/promptflow/evals/_common/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

# To minimize relative imports in our evaluators, the scope of this package also includes anything
# that would have otherwise been a relative import scoped to single evaluator directories.

from . import constants
from .rai_service import evaluate_with_rai_service
from .utils import get_harm_severity_level

__all__ = [
"get_harm_severity_level",
"evaluate_with_rai_service",
"constants",
]
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# ---------------------------------------------------------

from ._content_safety import ContentSafetyEvaluator
from ._content_safety_base import ContentSafetyEvaluatorBase
from ._content_safety_chat import ContentSafetyChatEvaluator
from ._hate_unfairness import HateUnfairnessEvaluator
from ._self_harm import SelfHarmEvaluator
Expand All @@ -16,4 +17,5 @@
"HateUnfairnessEvaluator",
"ContentSafetyEvaluator",
"ContentSafetyChatEvaluator",
"ContentSafetyEvaluatorBase",
]
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,8 @@

from abc import ABC

try:
from .constants import EvaluationMetrics
from .evaluate_with_rai_service import evaluate_with_rai_service
except ImportError:
from constants import EvaluationMetrics
from evaluate_with_rai_service import evaluate_with_rai_service
from promptflow.evals._common.constants import EvaluationMetrics
from promptflow.evals._common.rai_service import evaluate_with_rai_service


class ContentSafetyEvaluatorBase(ABC):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
from promptflow._utils.async_utils import async_run_allowing_running_loop
from promptflow.evals._common.constants import EvaluationMetrics

try:
from .common import ContentSafetyEvaluatorBase
from .common.constants import EvaluationMetrics
from ._content_safety_base import ContentSafetyEvaluatorBase
except ImportError:
from common import ContentSafetyEvaluatorBase
from common.constants import EvaluationMetrics
from _content_safety_base import ContentSafetyEvaluatorBase


class _AsyncHateUnfairnessEvaluator(ContentSafetyEvaluatorBase):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
from promptflow._utils.async_utils import async_run_allowing_running_loop
from promptflow.evals._common.constants import EvaluationMetrics

try:
from .common import ContentSafetyEvaluatorBase
from .common.constants import EvaluationMetrics
from ._content_safety_base import ContentSafetyEvaluatorBase
except ImportError:
from common import ContentSafetyEvaluatorBase
from common.constants import EvaluationMetrics
from _content_safety_base import ContentSafetyEvaluatorBase


class _AsyncSelfHarmEvaluator(ContentSafetyEvaluatorBase):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
from promptflow._utils.async_utils import async_run_allowing_running_loop
from promptflow.evals._common.constants import EvaluationMetrics

try:
from .common import ContentSafetyEvaluatorBase
from .common.constants import EvaluationMetrics
from ._content_safety_base import ContentSafetyEvaluatorBase
except ImportError:
from common import ContentSafetyEvaluatorBase
from common.constants import EvaluationMetrics
from _content_safety_base import ContentSafetyEvaluatorBase


class _AsyncSexualEvaluator(ContentSafetyEvaluatorBase):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
from promptflow._utils.async_utils import async_run_allowing_running_loop
from promptflow.evals._common.constants import EvaluationMetrics

try:
from .common import ContentSafetyEvaluatorBase
from .common.constants import EvaluationMetrics
from ._content_safety_base import ContentSafetyEvaluatorBase
except ImportError:
from common import ContentSafetyEvaluatorBase
from common.constants import EvaluationMetrics
from _content_safety_base import ContentSafetyEvaluatorBase


class _AsyncViolenceEvaluator(ContentSafetyEvaluatorBase):
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,8 @@ def test_content_safety_service_unavailable(self, project_scope, azure_cred):
# and the actual request made.
# Using not is_replay() because is_live doesn't apply to recording mode?
if not is_replay():
# Warning, live testing fails due to unstable region.
# We need a use a new region.
project_scope["project_name"] = "pf-evals-ws-westus2"

with pytest.raises(Exception) as exc_info:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,8 @@
import pytest
from azure.identity import DefaultAzureCredential

from promptflow.evals.evaluators._content_safety.common.constants import (
EvaluationMetrics,
HarmSeverityLevel,
RAIService,
)
from promptflow.evals.evaluators._content_safety.common.evaluate_with_rai_service import (
from promptflow.evals._common.constants import EvaluationMetrics, HarmSeverityLevel, RAIService
from promptflow.evals._common.rai_service import (
_get_service_discovery_url,
ensure_service_availability,
evaluate_with_rai_service,
Expand Down Expand Up @@ -107,8 +103,8 @@ async def test_fetch_or_reuse_token(self, mock_token, mock_expired_token):
assert res == 100

@patch("httpx.AsyncClient.get", return_value=httpx.Response(200, json={"result": "stuff"}))
@patch("promptflow.evals.evaluators._content_safety.common.constants.RAIService.TIMEOUT", 1)
@patch("promptflow.evals.evaluators._content_safety.common.constants.RAIService.SLEEP_TIME", 1.2)
@patch("promptflow.evals._common.constants.RAIService.TIMEOUT", 1)
@patch("promptflow.evals._common.constants.RAIService.SLEEP_TIME", 1.2)
@pytest.mark.usefixtures("mock_token")
@pytest.mark.asyncio
async def test_fetch_result(self, client_mock, mock_token):
Expand Down Expand Up @@ -232,7 +228,7 @@ async def test_get_service_discovery_url(self, client_mock):
return_value=httpx.Response(200, json={"properties": {"discoveryUrl": "https://www.url.com:123/thePath"}}),
)
@patch(
"promptflow.evals.evaluators._content_safety.common.evaluate_with_rai_service._get_service_discovery_url",
"promptflow.evals._common.rai_service._get_service_discovery_url",
return_value="https://www.url.com:123",
)
async def test_get_rai_svc_url(self, client_mock, discovery_mock):
Expand All @@ -250,27 +246,27 @@ async def test_get_rai_svc_url(self, client_mock, discovery_mock):

@pytest.mark.asyncio
@patch(
"promptflow.evals.evaluators._content_safety.common.evaluate_with_rai_service.fetch_or_reuse_token",
"promptflow.evals._common.rai_service.fetch_or_reuse_token",
return_value="dummy-token",
)
@patch(
"promptflow.evals.evaluators._content_safety.common.evaluate_with_rai_service.get_rai_svc_url",
"promptflow.evals._common.rai_service.get_rai_svc_url",
return_value="www.rai_url.com",
)
@patch(
"promptflow.evals.evaluators._content_safety.common.evaluate_with_rai_service.ensure_service_availability",
"promptflow.evals._common.rai_service.ensure_service_availability",
return_value=None,
)
@patch(
"promptflow.evals.evaluators._content_safety.common.evaluate_with_rai_service.submit_request",
"promptflow.evals._common.rai_service.submit_request",
return_value="op_id",
)
@patch(
"promptflow.evals.evaluators._content_safety.common.evaluate_with_rai_service.fetch_result",
"promptflow.evals._common.rai_service.fetch_result",
return_value="response_object",
)
@patch(
"promptflow.evals.evaluators._content_safety.common.evaluate_with_rai_service.parse_response",
"promptflow.evals._common.rai_service.parse_response",
return_value="wow-that's-a-lot-of-patches",
)
@patch("azure.identity.DefaultAzureCredential")
Expand Down

0 comments on commit 51aef6f

Please sign in to comment.