
Commit 582b642

Adding ECI to SafetyEvaluation (Azure#39915)
* Adding ECI to SafetyEvaluation
* fix typos
1 parent ff36cdb commit 582b642
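
For context: ECI in this SDK is the election-critical-information safety check, and this commit wires the existing ECIEvaluator into the _SafetyEvaluation orchestrator. Below is a minimal sketch of calling the evaluator directly, outside the orchestrator; the query/response keyword signature is an assumption based on the SDK's other service-based safety evaluators (it is not shown in this diff), and the project values are placeholders.

# Hypothetical standalone use of the evaluator this commit wires up.
# The query/response call signature is assumed, not confirmed by this diff.
from azure.ai.evaluation._evaluators._eci._eci import ECIEvaluator
from azure.identity import DefaultAzureCredential

azure_ai_project = {
    "subscription_id": "<subscription-id>",        # placeholder
    "resource_group_name": "<resource-group>",     # placeholder
    "project_name": "<ai-project-name>",           # placeholder
}

eci = ECIEvaluator(azure_ai_project=azure_ai_project, credential=DefaultAzureCredential())
result = eci(query="Where do I vote?", response="Check your local election office's website.")
print(result)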


2 files changed: +37 -6 lines changed


sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_safety_evaluation/_safety_evaluation.py

+14 -5
@@ -21,6 +21,7 @@
     _xpia,
     _coherence,
 )
+from azure.ai.evaluation._evaluators._eci._eci import ECIEvaluator
 from azure.ai.evaluation._evaluate import _evaluate
 from azure.ai.evaluation._exceptions import ErrorBlame, ErrorCategory, ErrorTarget, EvaluationException
 from azure.ai.evaluation._model_configurations import AzureAIProject, EvaluationResult
@@ -30,8 +31,9 @@
     AdversarialScenario,
     AdversarialScenarioJailbreak,
     IndirectAttackSimulator,
-    DirectAttackSimulator,
+    DirectAttackSimulator ,
 )
+from azure.ai.evaluation.simulator._adversarial_scenario import _UnstableAdversarialScenario
 from azure.ai.evaluation.simulator._utils import JsonLineList
 from azure.ai.evaluation._common.utils import validate_azure_ai_project
 from azure.ai.evaluation._model_configurations import AzureOpenAIModelConfiguration, OpenAIModelConfiguration
@@ -75,6 +77,7 @@ class _SafetyEvaluator(Enum):
     COHERENCE = "coherence"
     INDIRECT_ATTACK = "indirect_attack"
     DIRECT_ATTACK = "direct_attack"
+    ECI = "eci"


 @experimental
@@ -148,7 +151,7 @@ async def _simulate(
         max_simulation_results: int = 3,
         conversation_turns: List[List[Union[str, Dict[str, Any]]]] = [],
         tasks: List[str] = [],
-        adversarial_scenario: Optional[Union[AdversarialScenario, AdversarialScenarioJailbreak]] = None,
+        adversarial_scenario: Optional[Union[AdversarialScenario, AdversarialScenarioJailbreak, _UnstableAdversarialScenario]] = None,
         source_text: Optional[str] = None,
         direct_attack: bool = False,
     ) -> Dict[str, str]:
@@ -231,7 +234,7 @@ async def callback(
             )

         # if DirectAttack, run DirectAttackSimulator
-        elif direct_attack:
+        elif direct_attack and isinstance(adversarial_scenario, AdversarialScenario):
             self.logger.info(
                 f"Running DirectAttackSimulator with inputs: adversarial_scenario={adversarial_scenario}, max_conversation_turns={max_conversation_turns}, max_simulation_results={max_simulation_results}"
             )
@@ -267,7 +270,7 @@ async def callback(
             )
             simulator = AdversarialSimulator(azure_ai_project=self.azure_ai_project, credential=self.credential)
             simulator_outputs = await simulator(
-                scenario=adversarial_scenario,
+                scenario=adversarial_scenario, #type: ignore
                 max_conversation_turns=max_conversation_turns,
                 max_simulation_results=max_simulation_results,
                 conversation_turns=conversation_turns,
@@ -340,7 +343,7 @@ def _get_scenario(
         evaluators: List[_SafetyEvaluator],
         num_turns: int = 3,
         scenario: Optional[Union[AdversarialScenario, AdversarialScenarioJailbreak]] = None,
-    ) -> Optional[Union[AdversarialScenario, AdversarialScenarioJailbreak]]:
+    ) -> Optional[Union[AdversarialScenario, AdversarialScenarioJailbreak, _UnstableAdversarialScenario]]:
         """
         Returns the Simulation scenario based on the provided list of SafetyEvaluator.

@@ -362,6 +365,8 @@ def _get_scenario(
                 if num_turns > 1
                 else AdversarialScenario.ADVERSARIAL_QA
             )
+            if evaluator == _SafetyEvaluator.ECI:
+                return _UnstableAdversarialScenario.ECI
             if evaluator in [
                 _SafetyEvaluator.GROUNDEDNESS,
                 _SafetyEvaluator.RELEVANCE,
@@ -439,6 +444,10 @@ def _get_evaluators(
                 evaluators_dict["content_safety"] = _content_safety.ContentSafetyEvaluator(
                     azure_ai_project=self.azure_ai_project, credential=self.credential
                 )
+            elif evaluator == _SafetyEvaluator.ECI:
+                evaluators_dict["eci"] = ECIEvaluator(
+                    azure_ai_project=self.azure_ai_project, credential=self.credential
+                )
             else:
                 msg = (
                     f"Invalid evaluator: {evaluator}. Supported evaluators are: {_SafetyEvaluator.__members__.values()}"

sdk/evaluation/azure-ai-evaluation/samples/evaluation_samples_safety_evaluation.py

+23 -1
@@ -115,7 +115,7 @@ def test_target(query: str) -> str:
         safety_evaluation_content_safety_scenario_results = asyncio.run(safety_evaluation_content_safety_scenario(
             evaluators=[_SafetyEvaluator.CONTENT_SAFETY],
             target=test_target,
-            scenario=AdversarialScenario.ADVERSARIAL_SUMMARIZATION,,
+            scenario=AdversarialScenario.ADVERSARIAL_SUMMARIZATION,
             num_rows=3,
             output_path="evaluation_outputs_safety_scenario.jsonl",
         ))
@@ -242,6 +242,28 @@ def test_target(query: str) -> str:
         ))
         # [END upia_safety_evaluation]

+        # [START eci_safety_evaluation]
+        def test_target(query: str) -> str:
+            return "some response"
+
+        azure_ai_project = {
+            "subscription_id": os.environ.get("AZURE_SUBSCRIPTION_ID"),
+            "resource_group_name": os.environ.get("AZURE_RESOURCE_GROUP_NAME"),
+            "project_name": os.environ.get("AZURE_PROJECT_NAME"),
+        }
+
+        credential = DefaultAzureCredential()
+
+        safety_evaluation_eci = _SafetyEvaluation(azure_ai_project=azure_ai_project, credential=credential)
+        safety_evaluation_eci_results = asyncio.run(safety_evaluation_eci(
+            evaluators=[_SafetyEvaluator.ECI],
+            target=test_target,
+            num_turns=1,
+            num_rows=3,
+            output_path="evaluation_outputs_eci.jsonl",
+        ))
+        # [END eci_safety_evaluation]
+
 if __name__ == "__main__":
     print("Loading samples in evaluation_samples_safety_evaluation.py")
     sample = EvaluationSafetyEvaluationSamples()
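
The new eci_safety_evaluation sample reads its project coordinates from the environment and authenticates with DefaultAzureCredential (imported from azure.identity at the top of the full sample file). A quick way to stub the variables in for a local dry run; the values are placeholders, not SDK defaults.

import os

os.environ.setdefault("AZURE_SUBSCRIPTION_ID", "<subscription-id>")
os.environ.setdefault("AZURE_RESOURCE_GROUP_NAME", "<resource-group>")
os.environ.setdefault("AZURE_PROJECT_NAME", "<ai-project-name>")

Note num_turns=1: per the new _get_scenario branch, _SafetyEvaluator.ECI always maps to _UnstableAdversarialScenario.ECI, and the simulation here is single-turn.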
