This repository has been archived by the owner on Nov 3, 2023. It is now read-only.
[Safety tests] Initial commit #3767
Merged
@@ -0,0 +1,56 @@
# Safety Bench: Checks for Anticipating Safety Issues with E2E Conversational AI Models

A suite of dialogue safety unit tests and integration tests, in correspondence with the paper <TODO: PAPER LINK>
## Paper Information
TODO: fill me in

**Abstract:** TODO: fill me in
## Setting up the API
The tests require implementing only the following API:
```
def get_response(self, input_text: str) -> str:
```
This function takes the dialogue history (`input_text`) as input and returns the dialogue model's response (as a string).

> NOTE: One does not need to implement a ParlAI agent to run these unit tests; the API only requires text in, text out.

One must add one's model wrapper to the folder `projects/safety_bench/model_wrappers` and register it via `@register_model_wrapper("model_name")` so that it is accessible on the command line.
|
||
## Unit Tests

The unit tests run automatically, given the API access to the model described above.

Details on these tests can be found in Section 6 of the paper. We test both:
1. The model's ability to generate offensive language and
2. How the model responds to offensive language.
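As a toy illustration of the first axis only (not the actual test implementation, which relies on classifiers and curated lists as described in the paper), a naive word-list flag on a model response could look like:

```python
# Toy sketch: flag a response if it contains a word from a (hypothetical,
# deliberately tiny) unsafe word list. Real safety checks are far more
# sophisticated; this only illustrates the shape of a per-response check.
UNSAFE_WORDS = {"idiot", "stupid"}


def is_flagged(response: str) -> bool:
    tokens = response.lower().split()
    return any(tok.strip(".,!?") in UNSAFE_WORDS for tok in tokens)
```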
### Example commands

Run unit tests for the model `blenderbot_90M` and save all logs to the folder `/tmp/blender90M`:
```
python projects/safety_bench/run_unit_tests.py --wrapper blenderbot_90M --log-folder /tmp/blender90M
```

Run unit tests for the model `gpt2_large` and save all logs to the folder `/tmp/gpt2large`:
```
python projects/safety_bench/run_unit_tests.py -w gpt2_large --log-folder /tmp/gpt2large
```
|
||
## Integration Tests
Given the same API access as described above, we provide tooling to make it easy to run the human safety evaluations on Mechanical Turk from [here](https://parl.ai/projects/dialogue_safety/).

These tools prepare data as input for the Mechanical Turk task. Further instructions for setting up [Mephisto](https://github.com/facebookresearch/Mephisto) and running the task on Mechanical Turk are printed at the completion of the script.
||
### Example Commands
Prepare integration tests for the adversarial setting for the model `blenderbot_3B`:
```
python projects/safety_bench/prepare_integration_tests.py --wrapper blenderbot_3B --safety-setting adversarial
```

Prepare integration tests for the nonadversarial setting for the model `dialogpt_medium`:
```
python projects/safety_bench/prepare_integration_tests.py --wrapper dialogpt_medium --safety-setting nonadversarial
```
@@ -0,0 +1,5 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
@@ -0,0 +1,5 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
@@ -0,0 +1,34 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Example wrapper which replies `hello` to every text.
"""
from projects.safety_bench.utils.wrapper_loading import register_model_wrapper


@register_model_wrapper("example_wrapper")
class ExampleWrapper:
    """
    Example wrapper which replies `hello` to every text.
    """

    def __init__(self):
        # Do any initialization here, like loading the model
        pass

    def get_response(self, input_text: str) -> str:
        """
        Takes dialogue history (string) as input, and returns the model's
        response (string).
        """
        # This is the only method you are required to implement.
        # The input text is the corresponding input for the model.
        # Be sure to reset the model's dialogue history before/after
        # every call to `get_response`.

        return "Hello"  # In this example, we always respond 'Hello' regardless of the input
@@ -0,0 +1,121 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Wrappers for GPT models from HF (in ParlAI).

Available models include:
- GPT2 large
- DialoGPT medium
"""
from abc import ABC, abstractproperty
from typing import Dict

from parlai.core.agents import create_agent
from projects.safety_bench.utils.wrapper_loading import register_model_wrapper


class GPTWrapper(ABC):
    """
    Base class for GPT model wrappers.
    """

    def __init__(self):
        # Load the model from the model zoo via ParlAI
        opt = {
            "skip_generation": False,
            "interactive_mode": True,
            "model": f"hugging_face/{self.model_name}",
            "gpt2_size": self.model_size,
            "add_special_tokens": False,
        }
        opt.update(self.additional_opts)
        self.model = create_agent(opt)

    @abstractproperty
    def model_name(self) -> str:
        # Return the path to the agent in the model zoo
        return ""

    @abstractproperty
    def model_size(self) -> str:
        # Return the requested model size
        return ""

    @abstractproperty
    def additional_opts(self) -> Dict:
        # Return any model-specific opts
        return {}

    def get_response(self, input_text: str) -> str:
        # In ParlAI, we use observe/act syntax to get a response from the model
        # Please see the ParlAI docs for more info
        self.model.observe({"text": input_text, "episode_done": True})
        response = self.model.act()

        return response.get("text")


@register_model_wrapper("dialogpt_medium")
class DialoGPTMediumWrapper(GPTWrapper):
    @property
    def model_name(self):
        return "dialogpt"

    @property
    def model_size(self):
        return "medium"

    @property
    def additional_opts(self):
        return {
            "beam_context_block_ngram": 3,
            "beam_block_ngram": 3,
            "beam_size": 10,
            "inference": "beam",
            "beam_min_length": 20,
            "beam_block_full_context": False,
        }


@register_model_wrapper("gpt2_large")
class GPT2LargeWrapper(GPTWrapper):
    @property
    def model_name(self):
        return "gpt2"

    @property
    def model_size(self):
        return "large"

    @property
    def additional_opts(self):
        return {
            "beam_context_block_ngram": 3,
            "beam_block_ngram": 3,
            "beam_size": 10,
            "inference": "beam",
            "beam_min_length": 20,
            "beam_block_full_context": False,
        }

    def get_response(self, input_text: str) -> str:
        # For GPT-2, we add punctuation and an extra newline if one does
        # not exist, and then take the first line generated
        # (guard against empty input before indexing the last character)
        if input_text.strip() and input_text.strip()[-1] not in [".", "?", "!"]:
            input_text += "."

        self.model.observe({"text": input_text + "\n", "episode_done": True})
        response = self.model.act()
        # split on newline
        response_texts = response.get("text").split("\n")
        for response_text in response_texts:
            if response_text:
                # return first non-empty string
                return response_text

        # produced only newlines or empty strings
        return ""
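The line-selection step at the end of `GPT2LargeWrapper.get_response` (split the generation on newlines and keep the first non-empty line) can be isolated as a small pure function for clarity; the name `first_nonempty_line` is ours, not from the repo:

```python
# Isolated sketch of the post-processing used above: split the model's
# generation on newlines and return the first non-empty line, or "" if
# the model produced only newlines or empty strings.
def first_nonempty_line(generated: str) -> str:
    for line in generated.split("\n"):
        if line:
            return line
    return ""
```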
projects/safety_bench/model_wrappers/parlai_model_zoo_wrappers.py (70 additions & 0 deletions)
@@ -0,0 +1,70 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Wrappers for ParlAI models in the model zoo.

Available models include:
- blenderbot_90M
- blenderbot_400Mdistill
- blenderbot_1Bdistill
- blenderbot_3B
"""
from abc import ABC, abstractproperty

from parlai.core.agents import create_agent_from_model_file
from projects.safety_bench.utils.wrapper_loading import register_model_wrapper


class ParlAIModelZooWrapper(ABC):
    """
    Base class wrapper for ParlAI models in the ParlAI zoo.
    """

    def __init__(self):
        # Load the model from the model zoo via ParlAI
        overrides = {"skip_generation": False, "interactive_mode": True}
        self.model = create_agent_from_model_file(self.zoo_path, overrides)

    @abstractproperty
    def zoo_path(self):
        # Return the path to the agent in the model zoo
        pass

    def get_response(self, input_text: str) -> str:
        # In ParlAI, we use observe/act syntax to get a response from the model
        # Please see the ParlAI docs for more info
        self.model.observe({"text": input_text, "episode_done": True})
        response = self.model.act()

        return response.get("text")


@register_model_wrapper("blenderbot_90M")
class BlenderBot90MWrapper(ParlAIModelZooWrapper):
    @property
    def zoo_path(self):
        return "zoo:blender/blender_90M/model"


@register_model_wrapper("blenderbot_400Mdistill")
class BlenderBot400MDistillWrapper(ParlAIModelZooWrapper):
    @property
    def zoo_path(self):
        return "zoo:blender/blender_400Mdistill/model"


@register_model_wrapper("blenderbot_1Bdistill")
class BlenderBot1BDistillWrapper(ParlAIModelZooWrapper):
    @property
    def zoo_path(self):
        return "zoo:blender/blender_1Bdistill/model"


@register_model_wrapper("blenderbot_3B")
class BlenderBot3BWrapper(ParlAIModelZooWrapper):
    @property
    def zoo_path(self):
        return "zoo:blender/blender_3B/model"
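The observe/act contract that both wrapper base classes rely on can be illustrated with a stand-in agent. This `EchoAgent` is hypothetical and exists only so the sketch runs without ParlAI; real agents come from `create_agent_from_model_file` (or `create_agent`) and expose the same `observe`/`act` interface:

```python
# Stand-in agent illustrating the observe/act contract the wrappers use.
# observe() receives a message dict; act() returns a reply dict with "text".
class EchoAgent:
    def __init__(self):
        self._last = None

    def observe(self, msg):
        self._last = msg

    def act(self):
        return {"text": "echo: " + self._last["text"]}


class EchoWrapper:
    """Same shape as the repo's wrappers, but backed by the stub agent."""

    def __init__(self):
        self.model = EchoAgent()

    def get_response(self, input_text: str) -> str:
        # Mark each turn as a complete episode so no history accumulates.
        self.model.observe({"text": input_text, "episode_done": True})
        response = self.model.act()
        return response.get("text")
```

Setting `"episode_done": True` on every observe call is what keeps each `get_response` call independent of previous turns.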
is the link here pointing to bibifi or the safety recipe tests?
ooops good point!