Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Fix model chat speakers and personas #4273

Merged
merged 5 commits into from
Jan 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions parlai/crowdsourcing/tasks/model_chat/frontend/main.js
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,13 @@ function MainApp() {
renderSidePane={({ mephistoContext: { taskConfig }, appContext: { taskContext } }) => (
<DefaultTaskDescription
chatTitle={taskConfig.task_title}
taskDescriptionHtml={taskConfig.left_pane_text}
taskDescriptionHtml={
taskConfig.left_pane_text.replace(
"[persona_string_1]", taskContext.human_persona_string_1,
).replace(
"[persona_string_2]", taskContext.human_persona_string_2,
)
}
>
{(taskContext.hasOwnProperty('image_src') && taskContext['image_src']) ? (
<div>
Expand All @@ -45,24 +51,23 @@ function MainApp() {
</DefaultTaskDescription>
)}
renderTextResponse={
({
mephistoContext: { taskConfig },
({
mephistoContext: { taskConfig },
appContext: { appSettings },
onMessageSend,
active,

}) => (
<ResponseComponent
<ResponseComponent
appSettings={appSettings}
taskConfig={taskConfig}
active={active}
onMessageSend={onMessageSend}
/>
)
)
}
/>
);
}

ReactDOM.render(<MainApp />, document.getElementById("app"));

3 changes: 1 addition & 2 deletions parlai/crowdsourcing/tasks/model_chat/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

from parlai.crowdsourcing.tasks.model_chat.impl import run_task
from parlai.crowdsourcing.utils.mturk import MTurkRunScriptConfig
import parlai.crowdsourcing.tasks.model_chat.worlds as world_module


TASK_DIRECTORY = os.path.dirname(os.path.abspath(__file__))
Expand All @@ -40,7 +39,7 @@ class ScriptConfig(MTurkRunScriptConfig):

@hydra.main(config_path="hydra_configs", config_name="scriptconfig")
def main(cfg: DictConfig) -> None:
run_task(cfg=cfg, task_directory=TASK_DIRECTORY, world_module=world_module)
run_task(cfg=cfg, task_directory=TASK_DIRECTORY)


if __name__ == "__main__":
Expand Down
17 changes: 10 additions & 7 deletions parlai/crowdsourcing/tasks/model_chat/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,13 +381,16 @@ def _check_final_chat_data(
for actual_message, expected_message in zip(
actual_value[key_inner], expected_value_inner
):
self.assertEqual(
{k: v for k, v in actual_message.items() if k != 'message_id'},
{
k: v
for k, v in expected_message.items()
if k != 'message_id'
},
clean_actual_message = {
k: v for k, v in actual_message.items() if k != 'message_id'
}
clean_expected_message = {
k: v for k, v in expected_message.items() if k != 'message_id'
}
self.assertDictEqual(
clean_actual_message,
clean_expected_message,
f'The following dictionaries are different: {clean_actual_message} and {clean_expected_message}',
)
elif key_inner == 'task_description':
for (key_inner2, expected_value_inner2) in expected_value_inner.items():
Expand Down
23 changes: 17 additions & 6 deletions parlai/crowdsourcing/tasks/model_chat/worlds.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,9 @@ def __init__(self, opt, agent, bot):
self.task_turn_idx = 0
self.num_turns = num_turns

self.agent.agent_id = 'Speaker 1'
self.bot.agent_id = 'Speaker 2'

self.dialog = []
self.tag = f'conversation_id {agent.mephisto_agent.db_id}'
self.task_type = 'sandbox' if opt['is_sandbox'] else 'live'
Expand Down Expand Up @@ -488,7 +491,7 @@ def _run_initial_turn(self) -> None:
# Display the previous two utterances
human_first_msg = {
'episode_done': False,
'id': self.agent.id,
'id': self.agent.agent_id,
'text': self.context_info['person1_seed_utterance'],
'fake_start': True,
'agent_idx': 0,
Expand All @@ -497,7 +500,7 @@ def _run_initial_turn(self) -> None:
human_first_msg[k] = v
bot_first_msg = {
'episode_done': False,
'id': self.bot.id,
'id': self.bot.agent_id,
'text': self.context_info['person2_seed_utterance'],
'fake_start': True,
'agent_idx': 1,
Expand All @@ -513,12 +516,20 @@ def _run_initial_turn(self) -> None:

elif self.opt['conversation_start_mode'] == 'hi':
print('[Displaying "Hi!" only as per Meena task.]')
if self.personas is not None:
human_persona_strings = [s.strip() for s in self.personas[0]]
else:
human_persona_strings = ['', '']
human_first_msg = {
'episode_done': False,
'id': self.agent.id,
'id': self.agent.agent_id,
'text': 'Hi!',
'fake_start': True,
'agent_idx': 0,
'task_data': {
'human_persona_string_1': human_persona_strings[0],
'human_persona_string_2': human_persona_strings[1],
},
}
for k, v in control_msg.items():
human_first_msg[k] = v
Expand All @@ -528,7 +539,9 @@ def _run_initial_turn(self) -> None:
self.bot.observe(validate(human_first_msg))

first_bot_act = self.bot.act()
first_bot_act = Compatibility.maybe_fix_act(first_bot_act)
first_bot_act = Compatibility.backward_compatible_force_set(
first_bot_act, 'id', self.bot.agent_id
)

self.agent.observe(validate(first_bot_act))

Expand Down Expand Up @@ -671,8 +684,6 @@ def make_world(opt, agents):
statistics_condition = opt['statistics_condition']
context_generator = opt['context_generator']

agents[0].agent_id = "Worker"

# Get context: personas, previous utterances, etc.
if context_generator is not None:
context_info = context_generator.get_context()
Expand Down
3 changes: 1 addition & 2 deletions parlai/crowdsourcing/tasks/model_chat/worlds_image_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def _run_initial_turn(self) -> None:
bot_first_act_raw = Message(
Compatibility.maybe_fix_act(bot_first_act_raw)
).json_safe_payload()
bot_first_act_raw['id'] = self.bot.agent_id
self.agent.observe(validate(bot_first_act_raw))
bot_first_act = {
'episode_done': False,
Expand Down Expand Up @@ -134,8 +135,6 @@ def shutdown(self):

def make_world(opt, agents):

agents[0].agent_id = "Worker"

# We are showing an image to the worker and bot, so grab the image path and other
# context info
image_idx, model_name, no_more_work = opt['image_stack'].get_next_image(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,20 @@
"dialog": [
{
"episode_done": false,
"id": "Worker",
"id": "Speaker 1",
"text": "Hi!",
"fake_start": true,
"agent_idx": 0,
"message_id": "4f72800c-e1ff-4411-8ee5-286336c27859"
"message_id": "4f72800c-e1ff-4411-8ee5-286336c27859",
"task_data": {
"human_persona_string_1": "",
"human_persona_string_2": ""
}
},
{
"agent_idx": 1,
"text": "I don't know.",
"id": "FixedResponseAgent",
"id": "Speaker 2",
"problem_data": {
"bucket_0": false,
"bucket_1": false,
Expand All @@ -29,12 +33,12 @@
{
"agent_idx": 0,
"text": "What are you nervous about?",
"id": "Worker"
"id": "Speaker 1"
},
{
"agent_idx": 1,
"text": "I don't know.",
"id": "FixedResponseAgent",
"id": "Speaker 2",
"problem_data": {
"bucket_0": false,
"bucket_1": false,
Expand All @@ -47,12 +51,12 @@
{
"agent_idx": 0,
"text": "Do you have any plans for the weekend?",
"id": "Worker"
"id": "Speaker 1"
},
{
"agent_idx": 1,
"text": "I don't know.",
"id": "FixedResponseAgent",
"id": "Speaker 2",
"problem_data": {
"bucket_0": false,
"bucket_1": false,
Expand All @@ -65,12 +69,12 @@
{
"agent_idx": 0,
"text": "Yeah that sounds great! I like to bike and try new restaurants.",
"id": "Worker"
"id": "Speaker 1"
},
{
"agent_idx": 1,
"text": "I don't know.",
"id": "FixedResponseAgent",
"id": "Speaker 2",
"problem_data": {
"bucket_0": false,
"bucket_1": false,
Expand All @@ -83,12 +87,12 @@
{
"agent_idx": 0,
"text": "Oh, Italian food is great. I also love Thai and Indian.",
"id": "Worker"
"id": "Speaker 1"
},
{
"agent_idx": 1,
"text": "I don't know.",
"id": "FixedResponseAgent",
"id": "Speaker 2",
"problem_data": {
"bucket_0": false,
"bucket_1": false,
Expand All @@ -101,12 +105,12 @@
{
"agent_idx": 0,
"text": "Hmmm - anything with peanuts? Or I like when they have spicy licorice-like herbs.",
"id": "Worker"
"id": "Speaker 1"
},
{
"agent_idx": 1,
"text": "I don't know.",
"id": "FixedResponseAgent",
"id": "Speaker 2",
"problem_data": {
"bucket_0": false,
"bucket_1": false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,69 +16,69 @@
},
{
"episode_done": false,
"id": "TransresnetAgent",
"id": "Speaker 2",
"text": "I must learn that bird's name!",
"agent_idx": 1
},
{
"agent_idx": 0,
"text": "Response 1",
"id": "Worker"
"id": "Speaker 1"
},
{
"agent_idx": 1,
"text": "My, aren't you a pretty bird?",
"id": "TransresnetAgent"
"id": "Speaker 2"
},
{
"agent_idx": 0,
"text": "Response 2",
"id": "Worker"
"id": "Speaker 1"
},
{
"agent_idx": 1,
"text": "My, aren't you a pretty bird?",
"id": "TransresnetAgent"
"id": "Speaker 2"
},
{
"agent_idx": 0,
"text": "Response 3",
"id": "Worker"
"id": "Speaker 1"
},
{
"agent_idx": 1,
"text": "My, aren't you a pretty bird?",
"id": "TransresnetAgent"
"id": "Speaker 2"
},
{
"agent_idx": 0,
"text": "Response 4",
"id": "Worker"
"id": "Speaker 1"
},
{
"agent_idx": 1,
"text": "My, aren't you a pretty bird?",
"id": "TransresnetAgent"
"id": "Speaker 2"
},
{
"agent_idx": 0,
"text": "Response 5",
"id": "Worker"
"id": "Speaker 1"
},
{
"agent_idx": 1,
"text": "My, aren't you a pretty bird?",
"id": "TransresnetAgent"
"id": "Speaker 2"
},
{
"agent_idx": 0,
"text": "Response 6",
"id": "Worker"
"id": "Speaker 1"
},
{
"agent_idx": 1,
"text": "My, aren't you a pretty bird?",
"id": "TransresnetAgent",
"id": "Speaker 2",
"final_rating": 0
}
],
Expand Down
Loading