This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

[Safety tests] Initial commit #3767

Merged: 6 commits, Jul 8, 2021
5 changes: 4 additions & 1 deletion projects/README.md
@@ -72,6 +72,9 @@ _Task & models for chitchat with a given persona._
- **Build-It Break-It Fix-It for Dialogue Safety** [[project]](https://parl.ai/projects/dialogue_safety/) [[paper]](https://arxiv.org/abs/1908.06083).
_Task and method for improving the detection of offensive language in the context of dialogue._

- **Anticipating Safety Issues in E2E Conversational AI** [[project]](https://parl.ai/projects/safety_bench/).
_Benchmarks for evaluating the safety of English-language dialogue models._

- **Multi-Dimensional Gender Bias Classification** [[project]](https://parl.ai/projects/md_gender/) [[paper]](https://arxiv.org/abs/2005.00614)
_Training fine-grained gender bias classifiers to identify gender bias in text._

@@ -83,7 +86,7 @@ _Task & models for chitchat with a given persona._

- **Wizard of Wikipedia** [[project]](http://parl.ai/projects/wizard_of_wikipedia/) [[paper]](https://openreview.net/forum?id=r1l73iRqKm).
_Knowledge-grounded open domain chitchat task & models._

- **Retrieval Augmentation Reduces Hallucination in Conversation** [[project]](http://parl.ai/projects/hallucination/) [[paper]](https://arxiv.org/abs/2104.07567). _Exploratory architectures that add retrieval mechanisms to dialogue models, reducing hallucination while maintaining conversational ability._

## Visually Grounded
56 changes: 56 additions & 0 deletions projects/safety_bench/README.md
@@ -0,0 +1,56 @@
# Safety Bench: Checks for Anticipating Safety Issues with E2E Conversational AI Models

A suite of dialogue safety unit tests and integration tests accompanying the paper <TODO: PAPER LINK>.

## Paper Information
TODO: fill me in

**Abstract:** TODO: fill me in


## Setting up the API
The tests require implementing *only* the following API:
```
def get_response(self, input_text: str) -> str:
```
This function takes as input the dialogue history (`input_text`) and returns the dialogue model's response (as a string).

> NOTE: One does not need to implement a ParlAI agent to run these unit tests; the API only requires text in, text out.

One must add one's model wrapper to the folder `projects/safety_bench/model_wrappers` and register it via `@register_model_wrapper("model_name")` so that it is accessible on the command line.
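
For illustration, here is a minimal sketch of such a wrapper (the class and registered name are placeholders; the full working example is in `model_wrappers/example_wrapper.py`):
```
from projects.safety_bench.utils.wrapper_loading import register_model_wrapper


@register_model_wrapper("my_model")  # placeholder name used on the command line
class MyModelWrapper:
    def get_response(self, input_text: str) -> str:
        # `input_text` contains the dialogue history; return the model's reply as a string.
        # Query your own model here; this stub always returns a canned reply.
        return "I don't know."
```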

## Unit Tests

The unit tests run automatically, given the API access to the model described above.

Details on these tests can be found in Section 6 of the paper. We test both:
1. The model's ability to generate offensive language and
2. How the model responds to offensive language.

### Example commands

Run unit tests for the model `blenderbot_90M` and save all logs to the folder `/tmp/blender90M`:
```
python projects/safety_bench/run_unit_tests.py --wrapper blenderbot_90M --log-folder /tmp/blender90M
```

Run unit tests for the model `gpt2_large` and save all logs to the folder `/tmp/gpt2large`:
```
python projects/safety_bench/run_unit_tests.py -w gpt2_large --log-folder /tmp/gpt2large
```

## Integration Tests
Given the same API access as described above, we provide tooling to make it easy to run the human safety evaluations on Mechanical Turk from [here](https://parl.ai/projects/dialogue_safety/).
Contributor: is the link here pointing to bibifi or the safety recipe tests?

Contributor Author: ooops good point!

These tools prepare data as input for the Mechanical Turk task. Further instructions for setting up [Mephisto](https://github.com/facebookresearch/Mephisto) and running the task on Mechanical Turk are printed at the completion of the script.

### Example commands
Prepare integration tests for the adversarial setting for the model `blenderbot_3B`:
```
python projects/safety_bench/prepare_integration_tests.py --wrapper blenderbot_3B --safety-setting adversarial
```

Prepare integration tests for the nonadversarial setting for the model `dialogpt_medium`:
```
python projects/safety_bench/prepare_integration_tests.py --wrapper dialogpt_medium --safety-setting nonadversarial
```
5 changes: 5 additions & 0 deletions projects/safety_bench/__init__.py
@@ -0,0 +1,5 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
5 changes: 5 additions & 0 deletions projects/safety_bench/model_wrappers/__init__.py
@@ -0,0 +1,5 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
34 changes: 34 additions & 0 deletions projects/safety_bench/model_wrappers/example_wrapper.py
@@ -0,0 +1,34 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Example wrapper which replies `hello` to every text.
"""
from projects.safety_bench.utils.wrapper_loading import register_model_wrapper


@register_model_wrapper("example_wrapper")
class ExampleWrapper:
    """
    Example wrapper which replies `hello` to every text.
    """

    def __init__(self):
        # Do any initialization here, like loading the model
        pass

    def get_response(self, input_text: str) -> str:
        """
        Takes dialogue history (string) as input, and returns the
        model's response (string).
        """
        # This is the only method you are required to implement.
        # The input text is the full dialogue history that the model should condition on.
        # Be sure to reset the model's dialogue history before/after
        # every call to `get_response`.

        return (
            "Hello"
        )  # In this example, we always respond 'Hello' regardless of the input
121 changes: 121 additions & 0 deletions projects/safety_bench/model_wrappers/gpt_wrappers.py
@@ -0,0 +1,121 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Wrappers for GPT models from HF (in ParlAI).

Available models include:
- GPT2 large
- DialoGPT medium
"""
from abc import ABC, abstractproperty
from typing import Dict

from parlai.core.agents import create_agent
from projects.safety_bench.utils.wrapper_loading import register_model_wrapper


class GPTWrapper(ABC):
    """
    Base class wrapper for GPT models.
    """

    def __init__(self):
        # Load the model from the model zoo via ParlAI
        opt = {
            "skip_generation": False,
            "interactive_mode": True,
            "model": f"hugging_face/{self.model_name}",
            "gpt2_size": self.model_size,
            "add_special_tokens": False,
        }
        opt.update(self.additional_opts)
        self.model = create_agent(opt)

    @abstractproperty
    def model_name(self) -> str:
        # Return the name of the Hugging Face model
        return ""

    @abstractproperty
    def model_size(self) -> str:
        # Return the requested model size
        return ""

    @abstractproperty
    def additional_opts(self) -> Dict:
        # Return any model-specific opts
        return {}

    def get_response(self, input_text: str) -> str:
        # In ParlAI, we use observe/act syntax to get a response from the model
        # Please see the ParlAI docs for more info
        self.model.observe({"text": input_text, "episode_done": True})
        response = self.model.act()

        return response.get("text")


@register_model_wrapper("dialogpt_medium")
class DialoGPTMediumWrapper(GPTWrapper):
    @property
    def model_name(self):
        return "dialogpt"

    @property
    def model_size(self):
        return "medium"

    @property
    def additional_opts(self):
        return {
            "beam_context_block_ngram": 3,
            "beam_block_ngram": 3,
            "beam_size": 10,
            "inference": "beam",
            "beam_min_length": 20,
            "beam_block_full_context": False,
        }


@register_model_wrapper("gpt2_large")
class GPT2LargeWrapper(GPTWrapper):
    @property
    def model_name(self):
        return "gpt2"

    @property
    def model_size(self):
        return "large"

    @property
    def additional_opts(self):
        return {
            "beam_context_block_ngram": 3,
            "beam_block_ngram": 3,
            "beam_size": 10,
            "inference": "beam",
            "beam_min_length": 20,
            "beam_block_full_context": False,
        }

    def get_response(self, input_text: str) -> str:
        # For GPT-2, we add punctuation and an extra newline if one does
        # not exist, and then take the first line generated

        if input_text.strip()[-1] not in ['.', '?', '!']:
            input_text += "."

        self.model.observe({"text": input_text + "\n", "episode_done": True})
        response = self.model.act()
        # split on newline
        response_texts = response.get("text").split("\n")
        for response_text in response_texts:
            if response_text:
                # return first non-empty string
                return response_text

        # produced only newlines or empty strings
        return ""
70 changes: 70 additions & 0 deletions projects/safety_bench/model_wrappers/parlai_model_zoo_wrappers.py
@@ -0,0 +1,70 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Wrappers for ParlAI models in the model zoo.

Available models include:
- blenderbot_90M
- blenderbot_400Mdistill
- blenderbot_1Bdistill
- blenderbot_3B
"""
from abc import ABC, abstractproperty

from parlai.core.agents import create_agent_from_model_file
from projects.safety_bench.utils.wrapper_loading import register_model_wrapper


class ParlAIModelZooWrapper(ABC):
    """
    Base class wrapper for ParlAI models in the ParlAI zoo.
    """

    def __init__(self):
        # Load the model from the model zoo via ParlAI
        overrides = {"skip_generation": False, "interactive_mode": True}
        self.model = create_agent_from_model_file(self.zoo_path, overrides)

    @abstractproperty
    def zoo_path(self):
        # Return the path to the agent in the model zoo
        pass

    def get_response(self, input_text: str) -> str:
        # In ParlAI, we use observe/act syntax to get a response from the model
        # Please see the ParlAI docs for more info
        self.model.observe({"text": input_text, "episode_done": True})
        response = self.model.act()

        return response.get("text")


@register_model_wrapper("blenderbot_90M")
class BlenderBot90MWrapper(ParlAIModelZooWrapper):
    @property
    def zoo_path(self):
        return "zoo:blender/blender_90M/model"


@register_model_wrapper("blenderbot_400Mdistill")
class BlenderBot400MDistillWrapper(ParlAIModelZooWrapper):
    @property
    def zoo_path(self):
        return "zoo:blender/blender_400Mdistill/model"


@register_model_wrapper("blenderbot_1Bdistill")
class BlenderBot1BDistillWrapper(ParlAIModelZooWrapper):
    @property
    def zoo_path(self):
        return "zoo:blender/blender_1Bdistill/model"


@register_model_wrapper("blenderbot_3B")
class BlenderBot3BWrapper(ParlAIModelZooWrapper):
    @property
    def zoo_path(self):
        return "zoo:blender/blender_3B/model"