Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

[WizInt OSS] data compiler using Mephisto data browser #4034

Merged
merged 1 commit into from
Sep 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions parlai/crowdsourcing/projects/wizard_of_internet/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,13 @@ You need to have a functional search server running, and sets its address in `se
This server responds to the search requests sent by the worker who takes *wizard* role during this task:
It receieves a json with two keys: `q` and `n`, which are a string that is the search query, and an integer that is the number of pages to return, respectively.
It sends its response also as a json under a key named `response` which has a list of documents retrieved for the received search query. Each document is a mapping (dictionary) of *string->string* with at least 3 fields: `url`, `title`, and `content` (see [SearchEngineRetriever](https://github.com/facebookresearch/ParlAI/blob/70ee4a2c63008774fc9e66a8392847554920a14d/parlai/agents/rag/retrieve_api.py#L73) for more info on how this task interacts with the search server).

## Creating the dataset

Having collected data from crowdsourcing task, you may use `compile_resullts.py` to create your dataset, as a json file.
For example, if you called your task `wizard-of-internet` (you set this name in the config file that you ran with your task from `hydra_config`),
the following code creates your dataset as a json file in the directory specified by `--output-folder` flag:

```.python
python compile_results.py --task-name wizard-of-internet --output-folder=/dataset/wizard-internet
```
186 changes: 186 additions & 0 deletions parlai/crowdsourcing/projects/wizard_of_internet/compile_results.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Compiles the final dataset (a json file) from the this Mephisto crowdsourcing task.

Example use:

python compile_results.py --task-name wizard-of-internet --output-folder=/dataset/wizard-internet

"""

from typing import Dict, Union
import parlai.utils.logging as logging
from parlai.crowdsourcing.projects.wizard_of_internet.wizard_internet_blueprint import ( # noqa: F401
WIZARD_INTERNET_PARLAICHAT_BLUEPRINT,
)
from parlai.crowdsourcing.utils.analysis import AbstractResultsCompiler
from mephisto.abstractions.blueprint import AgentState

# Roles of the task
WIZARD = 'Wizard'
APPRENTICE = 'Apprentice'
SEARCH_AGENT = 'SearchAgent'

# Constant names with many reuse
CONTENTS = 'contents'
SELECTED_CONTENTS = 'selected_contents'
DIALOG_HISTORY_KEY = 'dialog_history'


def is_persona_form_response(message):
k1 = 'task_data'
k2 = 'form_responses'
return message and k1 in message and k2 in message[k1]


def get_human_sender(message):
sender = message['id']
if sender in (WIZARD, APPRENTICE, SEARCH_AGENT):
return sender


def remove_keys_from_dict(dictionary_data, keys):
for rm_key in keys:
del dictionary_data[rm_key]
return dictionary_data


def chat_interruption(message):
for k in ('requested_finish', 'MEPHISTO_is_submit'):
if k in message and message[k]:
return True
return False


class ChatMessage:
"""
Container for keeping the content of one interaction between agents.
"""

def __init__(self, message_dict: dict) -> None:
self._message = message_dict
self._sender = get_human_sender(self._message)

def _format_action(self, receiver):
return f'{self._sender} => {receiver}'

def get_action(self):
if self._sender == APPRENTICE:
return self._format_action(WIZARD)

if self._sender == SEARCH_AGENT:
return self._format_action(WIZARD)

# Must be WIZARD
k = 'is_search_query'
if k in self._message and self._message[k]:
return self._format_action(SEARCH_AGENT)
else:
return self._format_action(APPRENTICE)

def get_text(self):
if self._sender == SEARCH_AGENT:
return ''
return self._message.get('text', '')

def get_context(self):
if self._sender == APPRENTICE:
return {}

context_data = self._message.get('task_data', None)

if self._sender == WIZARD:
if not context_data:
return {}
else:
if 'form_responses' in context_data:
return {'Persona': context_data['form_responses'][0]['response']}
return {
CONTENTS: context_data.get('text_candidates', ''),
SELECTED_CONTENTS: context_data.get('selected_text_candaidtes', ''),
}
# Must be SEARCH_AGENT
return {CONTENTS: context_data['search_results']}

def compile_message(self):
d = dict()
d['action'] = self.get_action()
d['text'] = self.get_text()
d['context'] = self.get_context()
return d


class WizardOfInternetResultsCompiler(AbstractResultsCompiler):
"""
Compiles the results of Wizard of Internet crowdsourcing task into a json dataset.
"""

def is_unit_acceptable(self, unit_data):
# Depending on the situation (in practice) we may be able to salvage incomplete data.
# Here, we only keep completed and approved ones (discarding the rest).
return unit_data['status'] in (
AgentState.STATUS_ACCEPTED,
AgentState.STATUS_APPROVED,
)

def format_chat_data(self, dialog_history):
data_dict = {'apprentice_persona': '', DIALOG_HISTORY_KEY: []}

for message in dialog_history:
if message['id'] == 'PersonaAgent':
data_dict['apprentice_persona'] = message['task_data'][
'apprentice_persona'
]
continue

if is_persona_form_response(message) or not get_human_sender(message):
continue

if chat_interruption(message):
# There was interruption (could be a clean Finish)
# Ignoring the rest of messages
break

else:
compiled_message = ChatMessage(message).compile_message()
data_dict[DIALOG_HISTORY_KEY].append(compiled_message)

return data_dict

def compile_results(self) -> Dict[str, Dict[str, Union[dict, str]]]:

logging.info('Retrieving task data from Mephisto.')
task_units_data = self.get_task_data()
logging.info(f'Data for {len(task_units_data)} units loaded successfully.')

results = dict()
for work_unit in task_units_data:
assignment_id = work_unit['assignment_id']

data = work_unit['data']
agent_name = data['agent_name']
if agent_name == 'Wizard':
# Only collecting Wizard side data. Because it contains
# data that is needed from the apprentice side too.
formatted_chat_data = self.format_chat_data(data['messages'])
results[assignment_id] = formatted_chat_data

logging.info(f'{len(results)} dialogues compiled.')
return results


if __name__ == '__main__':
parser_ = WizardOfInternetResultsCompiler.setup_args()
args = parser_.parse_args()
opt = {
'task_name': args.task_name,
'results_format': 'json',
'output_folder': args.output_folder,
}
wizard_data_compiler = WizardOfInternetResultsCompiler(opt)
wizard_data_compiler.compile_and_save_results()