This repository has been archived by the owner on Nov 3, 2023. It is now read-only.
Open-sourcing SM-Turn and SM-Dialog from the human eval paper #4333
Merged
Changes shown from 11 of 17 commits.

Commits:
- 8f6fcf8 Add SM-Turn code (EricMichaelSmith)
- d1d93ef Add in new good chat data folder (EricMichaelSmith)
- 8f3d75a Model chat temp fix (EricMichaelSmith)
- 62b9a07 Fix duplicated title (EricMichaelSmith)
- 6743d0e Work on README (EricMichaelSmith)
- 5d05032 README revisions (EricMichaelSmith)
- 1ffa871 Revisions (EricMichaelSmith)
- e98c4f2 Create __init__.py (EricMichaelSmith)
- ce0a72b Create __init__.py (EricMichaelSmith)
- 482ac12 Merge branch 'main' into smturn (EricMichaelSmith)
- 67e4772 Fix Lint error (EricMichaelSmith)
- 67563bc Comment (EricMichaelSmith)
- d83b3f7 Break up line (EricMichaelSmith)
- db95e1a Add Bibtex (EricMichaelSmith)
- 9421793 Make comment strings more useful (EricMichaelSmith)
- d0c5691 Clarify remaining meta-comments (EricMichaelSmith)
- 0530893 Update tests (EricMichaelSmith)
parlai/crowdsourcing/projects/humaneval/single_model_eval/README.md (27 additions, 0 deletions)
@@ -0,0 +1,27 @@
# Crowdsourcing task for single-model per-turn and per-dialogue evaluations

Code to run human crowdworker evaluations on a single model, one of the evaluation techniques explored in [Smith et al., "Human Evaluation of Conversations is an Open Problem: comparing the sensitivity of various methods for evaluating dialogue agents" (2022)](https://arxiv.org/abs/2201.04723). This crowdsourcing task consists of a conversation between a Turker and a model. The task collects annotations of engagingness, humanness, and interestingness after every model response (SM-Turn in the paper), as well as final 1-to-5 Likert ratings of those same metrics at the end of the conversation (SM-Dialog in the paper).
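For orientation, here is a hedged sketch of roughly what one collected conversation looks like to the analysis script further down in this PR (`analysis/compile_results.py`). The field names (`id`, `text`, `task_data`, `problem_data_for_prior_message`, `final_rating`) are taken from that script; the bucket names and all values shown are invented placeholders.

```python
# Hedged sketch of the per-conversation message structure that
# analysis/compile_results.py expects. Field names come from that script;
# bucket names and values below are invented placeholders.
example_messages = [
    {'id': 'Speaker 1', 'text': 'Hi!'},  # the fixed opening human message
    {'id': 'Speaker 2', 'text': 'Hello! How has your day been so far?'},
    {
        'id': 'Speaker 1',
        'text': 'Pretty good, thanks!',
        'task_data': {
            # Per-turn annotation of the prior model response (SM-Turn)
            'problem_data_for_prior_message': {
                'engaging': True,
                'human': False,
                'interesting': True,
                'none': False,
            }
        },
    },
    # ... more alternating turns, each human turn annotating the prior model turn ...
    {
        'id': 'Speaker 1',
        'text': '',
        # End-of-conversation Likert ratings (SM-Dialog), '|'-separated, one per question
        'task_data': {'final_rating': '4|5|3'},
    },
]
```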
## Collecting evaluations

To launch HITs, run `python run.py`. All Hydra flags are as in the base human/model crowdsourcing task in [`parlai/crowdsourcing/tasks/model_chat/`](https://github.com/facebookresearch/ParlAI/tree/main/parlai/crowdsourcing/tasks/model_chat), of which this crowdsourcing task is a custom version.

To specify the set of models that you want to evaluate, pass in a custom YAML file with the `mephisto.blueprint.model_opt_path` flag. The example `task_config/model_opts.yaml` file specifies the set of models evaluated in the paper:
- `blender_3B` (**BlenderBot3B** in the paper): the 2.7-billion-parameter variant of the [BlenderBot 1.0 model](https://parl.ai/projects/recipes/)
- `blender_3B_beam_min_length_0` (**BlenderBot3B-M0**): BlenderBot3B is typically used with a minimum generation length of 20 tokens; this variant removes that minimum
- `blender_90M` (**BlenderBot90M**): the variant of BlenderBot 1.0 with 90 million parameters, trained on the same datasets as BlenderBot3B
- `reddit_3B` (**Reddit3B**): pretraining-only BlenderBot3B, without any fine-tuning on dialogue datasets

## Running analysis

Call `python analysis/compile_results.py` to analyze single-model evaluations collected with this crowdsourcing task. Required flags for this script are:
- `--task-name`: the Mephisto task name used when collecting evaluations
- `--output-folder`: the folder to save analysis output files to

Pass `--filter-uniform-hits` to filter out, as a quality check, any HITs for which the Turker's annotations were exactly the same on every turn of the conversation.

Features of this script include:
- Filtering out HITs with acceptability violations, and saving a file of all Turkers who violated acceptability checks
- Saving a file of all per-turn ratings (SM-Turn scores) and per-dialogue ratings (SM-Dialog scores) across all conversations
- Saving a file of the aggregate rates of selecting each annotation bucket across all turns (i.e. SM-Turn)
- Saving statistics about the distribution of Likert scores for each question asked at the end of each conversation (i.e. SM-Dialog)
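For illustration, here is a minimal sketch of driving this analysis from Python rather than the command line, mirroring the script's own `__main__` block; the task name and output folder below are placeholders, not values from this PR.

```python
# Hedged usage sketch, mirroring the __main__ block of compile_results.py.
# The task name and output folder are placeholders, not values from this PR.
from parlai.crowdsourcing.projects.humaneval.single_model_eval.analysis.compile_results import (
    ModelChatResultsCompiler,
)

parser = ModelChatResultsCompiler.setup_args()
args = parser.parse_args(
    [
        '--task-name', 'sm_single_model_eval',  # Mephisto task name used at collection time
        '--output-folder', '/tmp/sm_eval_analysis',  # where the analysis CSVs and results are written
        '--filter-uniform-hits',  # optional: drop HITs whose per-turn ratings never vary
    ]
)
ModelChatResultsCompiler(vars(args)).compile_and_save_results()
```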
parlai/crowdsourcing/projects/humaneval/single_model_eval/__init__.py (5 additions, 0 deletions)
@@ -0,0 +1,5 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
parlai/crowdsourcing/projects/humaneval/single_model_eval/analysis/__init__.py (5 additions, 0 deletions)
@@ -0,0 +1,5 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
parlai/crowdsourcing/projects/humaneval/single_model_eval/analysis/compile_results.py (341 additions, 0 deletions)
@@ -0,0 +1,341 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
from collections import defaultdict
from datetime import datetime
from typing import Any, Dict

import pandas as pd
from mephisto.data_model.worker import Worker

import parlai.utils.logging as logging
from parlai.crowdsourcing.tasks.model_chat.model_chat_blueprint import (
    BLUEPRINT_TYPE,
)  # For registering the blueprint
from parlai.crowdsourcing.utils.acceptability import AcceptabilityChecker
from parlai.crowdsourcing.utils.analysis import AbstractResultsCompiler

_ = BLUEPRINT_TYPE
# NOTE: BLUEPRINT_TYPE needs to be imported here to register the blueprint


class ModelChatResultsCompiler(AbstractResultsCompiler):
    """
    Compile and save results of human+model chats.

    Results will be saved on the level of specific conversations, as well as aggregated
    up to the level of each worker as a whole.
    """

    @classmethod
    def setup_args(cls):
        parser = super().setup_args()
        parser.add_argument(
            '--filter-uniform-hits',
            action='store_true',
            help='Filter out any HITs in which the worker\'s annotations were the exact same on each turn of the conversation',
        )
        return parser

    def __init__(self, opt: Dict[str, Any]):

        super().__init__(opt)

        # Input args
        os.makedirs(self.output_folder, exist_ok=True)
        # TODO: see if this can be moved to the superclass
        self.filter_uniform_hits = opt['filter_uniform_hits']

        # Save paths
        self.unacceptable_worker_ids_path = os.path.join(
            self.output_folder, 'unacceptable_worker_ids.txt'
        )
        self.annotation_selection_rate_path = os.path.join(
            self.output_folder, 'annotation_selection_rates.csv'
        )
        self.likert_score_stat_path = os.path.join(
            self.output_folder, 'likert_score_stats.csv'
        )

        self.acceptability_checker = AcceptabilityChecker()

    def get_results_path_base(self) -> str:
        return os.path.join(self.output_folder, 'results')
        # TODO: see if this can be moved to the superclass

    def compile_results(self) -> pd.DataFrame:

        # Load task data
        logging.info('Retrieving task data from Mephisto.')
        task_units_data = self.get_task_data()
        logging.info(f'Data for {len(task_units_data)} units loaded successfully.')

        num_convos_with_no_save_data = 0
        num_wrong_status_convos = 0
        num_complete_convos = 0

        unacceptable_task_units = []
        unacceptable_worker_ids = []
        conversation_idx = 0
        conversation_dfs = []

        for task_unit in task_units_data:

            worker_id = task_unit['worker_id']
            assignment_id = task_unit['assignment_id']

            # # Determining whether the task unit should be skipped

            # Extract out custom data
            if task_unit['data']['save_data'] is None:
                logging.info('Found a task unit with no save data! Skipping.')
                num_convos_with_no_save_data += 1
                continue
            elif task_unit['status'] not in ['completed', 'approved']:
                logging.info(
                    f'Found a HIT with the status "{task_unit["status"]}"! Skipping.'
                )
                num_wrong_status_convos += 1
                continue
            else:
                num_complete_convos += 1

            # Extract out useful data
            custom_data = task_unit['data']['save_data']['custom_data']
            mturk_worker_id = Worker.get(self.get_mephisto_db(), worker_id).worker_name
            task_start = datetime.utcfromtimestamp(task_unit['task_start'])
            task_end = datetime.utcfromtimestamp(task_unit['task_end'])
            info_dict = {
                ('worker_id', ''): worker_id,
                ('mturk_worker_id', ''): mturk_worker_id,
                ('unit_id', ''): task_unit['unit_id'],
                ('assignment_id', ''): assignment_id,
                ('conversation_idx', ''): conversation_idx,
                ('date', ''): task_start.strftime('%Y-%m-%d'),
                ('completion_time', ''): (task_end - task_start).total_seconds(),
            }

            # Check that the conversation consists of pairs of comments between
            # Speaker 1 and Speaker 2, with Speaker 1 speaking first
            assert 'final_rating' in task_unit['data']['messages'][-1]['task_data']
            convo_messages = [m for m in task_unit['data']['messages'][:-1]]
            # The final message is just a final rating
            assert all(
                [
                    message['id'] == ('Speaker 2' if message_idx % 2 else 'Speaker 1')
                    for message_idx, message in enumerate(convo_messages)
                ]
            )
            messages_1 = [m for m in convo_messages if m['id'] == 'Speaker 1']
            messages_2 = [m for m in convo_messages if m['id'] == 'Speaker 2']
            assert len(messages_1) + len(messages_2) == len(convo_messages)

            # Determine whether the HIT contains unacceptable messages. (We do this for
            # every HIT, even if acceptability violation info was already saved, because
            # the violation criteria may have changed since the HIT was collected.)
            utterances_1 = [m['text'] for m in messages_1]
            assert utterances_1[0] == 'Hi!', (
                'This script assumes that the first human message is "Hi!", which is '
                'set by default and cannot be changed by the crowdsourcing worker.'
            )
            acceptability_violations = self.acceptability_checker.check_messages(
                messages=utterances_1[1:],  # Don't use the initial "Hi!"
                is_worker_0=True,
                violation_types=self.acceptability_checker.ALL_VIOLATION_TYPES,
            )
            # Here, "worker 0" refers to Speaker 1, because we mix 0- and 1-indexing
            if acceptability_violations != '':
                logging.info(
                    f'Conversation fails acceptability checks with a violation of '
                    f'"{acceptability_violations}", given the following utterances: '
                    f'{utterances_1[1:]}. Skipping.'
                )
                unacceptable_task_units.append(task_unit)
                assert (
                    mturk_worker_id is not None
                ), "MTurk worker ID cannot be determined for this unacceptable conversation!"
                unacceptable_worker_ids.append(mturk_worker_id)
                continue

            # Ignore the conversation if ratings for all turns are the same, because
            # it's somewhat implausible that *all* turns in a conversation should garner
            # the same rating of engagingness, humanness, interestingness, or none.
            # (However, don't put these workers on the "unacceptable worker IDs" list,
            # to give them a little bit of the benefit of the doubt: i.e. maybe the
            # worker just didn't try hard enough to find which responses were more
            # engaging, etc. than others, but that doesn't mean that all of their HITs
            # across all evals are bad and should be removed.)
            if self.filter_uniform_hits:
                annotations = [
                    m['task_data']['problem_data_for_prior_message']
                    for m in task_unit['data']['messages']
                    if 'problem_data_for_prior_message' in m.get('task_data', {})
                ]
                hashable_annotations = [
                    tuple(a[key] for key in sorted(a.keys())) for a in annotations
                ]
                unique_annotations = set(hashable_annotations)
                if len(unique_annotations) < 1:
                    raise ValueError('No annotations found for this HIT!')
                elif len(unique_annotations) == 1:
                    logging.info(
                        f'All model responses in the conversation received the same '
                        f'annotation: {hashable_annotations[0]}. Skipping.'
                    )
                    unacceptable_task_units.append(task_unit)
                    continue

            single_turn_dicts = []

            # Compile personas and previous utterances
            text_parts = []
            if custom_data['personas'] is not None and len(custom_data['personas']) > 0:
                assert len(custom_data['personas']) == 2
                text_parts += [
                    'HUMAN PERSONA: ' + ' '.join(custom_data['personas'][0]),
                    'BOT PERSONA: ' + ' '.join(custom_data['personas'][1]),
                ]
            if (
                custom_data['additional_context'] is not None
                and len(custom_data['additional_context']) > 0
            ):
                text_parts.append(
                    'ADDITIONAL CONTEXT: ' + custom_data['additional_context']
                )
            single_turn_dicts.append(
                {**info_dict, ('context', ''): ' '.join(text_parts)}
            )

            # Loop over conversation turns
            turns_per_speaker = defaultdict(int)
            for message in task_unit['data']['messages']:
                if 'text' in message:

                    speaker_id = message['id']

                    # Add in annotation results, if they exist
                    if 'problem_data_for_prior_message' in message.get('task_data', {}):
                        bucket_data = {
                            ('annotation_bucket', bucket): value
                            for bucket, value in message['task_data'][
                                'problem_data_for_prior_message'
                            ].items()
                        }
                    else:
                        bucket_data = {}

                    # Add in results from the final rating(s), if they exist
                    if 'final_rating' in message.get('task_data', {}):
                        ratings = message['task_data']['final_rating'].split('|')
                        final_rating_data = {
                            ('final_rating', str(idx)): value
                            for idx, value in enumerate(ratings)
                        }
                    else:
                        final_rating_data = {}

                    turns_per_speaker[speaker_id] += 1

                    single_turn_dicts.append(
                        {
                            **info_dict,
                            ('speaker_id', ''): speaker_id,
                            ('speaker_turn_idx', ''): turns_per_speaker[speaker_id],
                            ('text', ''): message['text'].replace('\n', '__newline__'),
                            **bucket_data,
                            **final_rating_data,
                        }
                    )

            # Adding the full conversation to the list of conversations
            single_turn_series = [
                pd.Series(dict_).to_frame().transpose() for dict_ in single_turn_dicts
            ]
            single_convo_df = pd.concat(single_turn_series, axis=0, sort=False)
            conversation_dfs.append(single_convo_df)
            conversation_idx += 1

        logging.info(
            f'{num_convos_with_no_save_data:d} conversations found with no save data.'
        )
        logging.info(
            f'{num_wrong_status_convos:d} conversations found with the wrong status.'
        )
        logging.info(f'{num_complete_convos:d} complete conversations found:')
        logging.info(f'\t{len(unacceptable_task_units):d} unacceptable conversations.')
        logging.info(f'\t{len(conversation_dfs):d} acceptable conversations.')

        # # Compile full results

        if len(conversation_dfs) == 0:
            raise ValueError('No acceptable conversations found!')
        unordered_conversation_df = pd.concat(conversation_dfs, axis=0)
        initial_ordered_columns = list(info_dict.keys()) + [
            ('context', ''),
            ('speaker_id', ''),
            ('speaker_turn_idx', ''),
            ('text', ''),
        ]
        all_ordered_columns = initial_ordered_columns + [
            col
            for col in unordered_conversation_df.columns
            if col not in initial_ordered_columns
        ]
        conversation_df = unordered_conversation_df[all_ordered_columns]
        # TODO: is there a less hacky way than this, which relies on the most recent
        # value of `info_dict`, to put the columns back into the right order?

        # # Calculate stats

        logging.info(
            f'Saving MTurk IDs of workers with unacceptable conversations to '
            f'{self.unacceptable_worker_ids_path}.'
        )
        with open(self.unacceptable_worker_ids_path, 'w') as f:
            for worker_id in unacceptable_worker_ids:
                f.write(worker_id + '\n')

        # Calculate rates of selecting various annotation buckets
        annotation_bucket_df = conversation_df['annotation_bucket'].dropna(
            axis=0, how='any'
        )
        if annotation_bucket_df.isna().sum().sum() > 0:
            raise ValueError(
                'There is at least one row in which only partial annotation bucket data exists!'
            )
        annotation_selection_rate_df = annotation_bucket_df.mean().to_frame(
            'selection_rate'
        )
        annotation_selection_rate_df.to_csv(self.annotation_selection_rate_path)
        logging.info(
            f'Annotation bucket selection rates saved to {self.annotation_selection_rate_path}.'
        )
        output_strings = [
            f'{series.name}: {100*series["selection_rate"]:0.0f}%'
            for _, series in annotation_selection_rate_df.iterrows()
        ]
        logging.info('Annotation bucket selection rates:\n' + '\n'.join(output_strings))

        # Calculate Likert score stats
        final_rating_df = conversation_df['final_rating'].dropna(axis=0, how='any')
        if final_rating_df.isna().sum().sum() > 0:
            raise ValueError(
                'There is at least one row in which only partial final rating data exists!'
            )
        likert_score_stat_df = final_rating_df.astype(int).describe()
        likert_score_stat_df.to_csv(self.likert_score_stat_path)
        logging.info(f'Likert score statistics saved to {self.likert_score_stat_path}.')
        logging.info(f'Mean Likert scores:\n{likert_score_stat_df.loc["mean"]}')

        return conversation_df


if __name__ == '__main__':
    parser_ = ModelChatResultsCompiler.setup_args()
    args_ = parser_.parse_args()
    ModelChatResultsCompiler(vars(args_)).compile_and_save_results()
Nit: Let's stick to one # for comments?
Yeah good point - so, my convention has often been to use two hash symbols for a new section of code: i.e., this comment is indicating that the following 100-ish lines deal with checking whether this task should be skipped. But if this isn't clear to others, then maybe something more obvious should be used. Is there a particular notation that you use for this?
As an aside, often if I find a chunk of code is complex enough to warrant something like that, it's complex enough to move into a nicely-named helper function.
I understand that in this scripting setup, this chunk also has some external effects (changing convos counts, doing some data extraction, etc). But in that case, those are critical to its function, and thus the whole code block is doing more than "determining whether the task unit should be skipped". Overall feels a little like an anti-pattern to me.
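To make the suggestion concrete, here is one hedged sketch of pulling the skip checks into a helper. The helper name and the return-a-reason-or-None contract are illustrative assumptions rather than code from this PR, and a real refactor would still need to update the per-category skip counters.

```python
# Hypothetical sketch of the suggested refactor: the helper name and the
# "return a skip reason or None" contract are assumptions, not code from this PR.
from typing import Any, Dict, Optional


def get_skip_reason(task_unit: Dict[str, Any]) -> Optional[str]:
    """Return a human-readable reason to skip this task unit, or None to keep it."""
    if task_unit['data']['save_data'] is None:
        return 'no save data'
    if task_unit['status'] not in ['completed', 'approved']:
        return f'wrong status "{task_unit["status"]}"'
    return None


# The main loop could then read:
#     skip_reason = get_skip_reason(task_unit)
#     if skip_reason is not None:
#         logging.info(f'Skipping task unit: {skip_reason}.')
#         continue
```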
Yes - @JackUrb definitely agreed that the original phrasing of that comment was incomplete. I've just removed that comment entirely to avoid confusion. Thanks!