Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Commit

Permalink
Fix model chat speakers and personas (#4273)
Browse files Browse the repository at this point in the history
* Layer changes

* Fixes

* Work on CI checks

* Fixes

* Fixes
  • Loading branch information
EricMichaelSmith authored Jan 4, 2022
1 parent e686a02 commit 6df9361
Show file tree
Hide file tree
Showing 11 changed files with 205 additions and 145 deletions.
17 changes: 11 additions & 6 deletions parlai/crowdsourcing/tasks/model_chat/frontend/main.js
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,13 @@ function MainApp() {
renderSidePane={({ mephistoContext: { taskConfig }, appContext: { taskContext } }) => (
<DefaultTaskDescription
chatTitle={taskConfig.task_title}
taskDescriptionHtml={taskConfig.left_pane_text}
taskDescriptionHtml={
taskConfig.left_pane_text.replace(
"[persona_string_1]", taskContext.human_persona_string_1,
).replace(
"[persona_string_2]", taskContext.human_persona_string_2,
)
}
>
{(taskContext.hasOwnProperty('image_src') && taskContext['image_src']) ? (
<div>
Expand All @@ -45,24 +51,23 @@ function MainApp() {
</DefaultTaskDescription>
)}
renderTextResponse={
({
mephistoContext: { taskConfig },
({
mephistoContext: { taskConfig },
appContext: { appSettings },
onMessageSend,
active,

}) => (
<ResponseComponent
<ResponseComponent
appSettings={appSettings}
taskConfig={taskConfig}
active={active}
onMessageSend={onMessageSend}
/>
)
)
}
/>
);
}

ReactDOM.render(<MainApp />, document.getElementById("app"));

3 changes: 1 addition & 2 deletions parlai/crowdsourcing/tasks/model_chat/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

from parlai.crowdsourcing.tasks.model_chat.impl import run_task
from parlai.crowdsourcing.utils.mturk import MTurkRunScriptConfig
import parlai.crowdsourcing.tasks.model_chat.worlds as world_module


TASK_DIRECTORY = os.path.dirname(os.path.abspath(__file__))
Expand All @@ -40,7 +39,7 @@ class ScriptConfig(MTurkRunScriptConfig):

@hydra.main(config_path="hydra_configs", config_name="scriptconfig")
def main(cfg: DictConfig) -> None:
run_task(cfg=cfg, task_directory=TASK_DIRECTORY, world_module=world_module)
run_task(cfg=cfg, task_directory=TASK_DIRECTORY)


if __name__ == "__main__":
Expand Down
17 changes: 10 additions & 7 deletions parlai/crowdsourcing/tasks/model_chat/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,13 +381,16 @@ def _check_final_chat_data(
for actual_message, expected_message in zip(
actual_value[key_inner], expected_value_inner
):
self.assertEqual(
{k: v for k, v in actual_message.items() if k != 'message_id'},
{
k: v
for k, v in expected_message.items()
if k != 'message_id'
},
clean_actual_message = {
k: v for k, v in actual_message.items() if k != 'message_id'
}
clean_expected_message = {
k: v for k, v in expected_message.items() if k != 'message_id'
}
self.assertDictEqual(
clean_actual_message,
clean_expected_message,
f'The following dictionaries are different: {clean_actual_message} and {clean_expected_message}',
)
elif key_inner == 'task_description':
for (key_inner2, expected_value_inner2) in expected_value_inner.items():
Expand Down
23 changes: 17 additions & 6 deletions parlai/crowdsourcing/tasks/model_chat/worlds.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,9 @@ def __init__(self, opt, agent, bot):
self.task_turn_idx = 0
self.num_turns = num_turns

self.agent.agent_id = 'Speaker 1'
self.bot.agent_id = 'Speaker 2'

self.dialog = []
self.tag = f'conversation_id {agent.mephisto_agent.db_id}'
self.task_type = 'sandbox' if opt['is_sandbox'] else 'live'
Expand Down Expand Up @@ -447,7 +450,7 @@ def _run_initial_turn(self) -> None:
# Display the previous two utterances
human_first_msg = {
'episode_done': False,
'id': self.agent.id,
'id': self.agent.agent_id,
'text': self.context_info['person1_seed_utterance'],
'fake_start': True,
'agent_idx': 0,
Expand All @@ -456,7 +459,7 @@ def _run_initial_turn(self) -> None:
human_first_msg[k] = v
bot_first_msg = {
'episode_done': False,
'id': self.bot.id,
'id': self.bot.agent_id,
'text': self.context_info['person2_seed_utterance'],
'fake_start': True,
'agent_idx': 1,
Expand All @@ -472,12 +475,20 @@ def _run_initial_turn(self) -> None:

elif self.opt['conversation_start_mode'] == 'hi':
print('[Displaying "Hi!" only as per Meena task.]')
if self.personas is not None:
human_persona_strings = [s.strip() for s in self.personas[0]]
else:
human_persona_strings = ['', '']
human_first_msg = {
'episode_done': False,
'id': self.agent.id,
'id': self.agent.agent_id,
'text': 'Hi!',
'fake_start': True,
'agent_idx': 0,
'task_data': {
'human_persona_string_1': human_persona_strings[0],
'human_persona_string_2': human_persona_strings[1],
},
}
for k, v in control_msg.items():
human_first_msg[k] = v
Expand All @@ -487,7 +498,9 @@ def _run_initial_turn(self) -> None:
self.bot.observe(validate(human_first_msg))

first_bot_act = self.bot.act()
first_bot_act = Compatibility.maybe_fix_act(first_bot_act)
first_bot_act = Compatibility.backward_compatible_force_set(
first_bot_act, 'id', self.bot.agent_id
)

self.agent.observe(validate(first_bot_act))

Expand Down Expand Up @@ -630,8 +643,6 @@ def make_world(opt, agents):
statistics_condition = opt['statistics_condition']
context_generator = opt['context_generator']

agents[0].agent_id = "Worker"

# Get context: personas, previous utterances, etc.
if context_generator is not None:
context_info = context_generator.get_context()
Expand Down
3 changes: 1 addition & 2 deletions parlai/crowdsourcing/tasks/model_chat/worlds_image_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def _run_initial_turn(self) -> None:
bot_first_act_raw = Message(
Compatibility.maybe_fix_act(bot_first_act_raw)
).json_safe_payload()
bot_first_act_raw['id'] = self.bot.agent_id
self.agent.observe(validate(bot_first_act_raw))
bot_first_act = {
'episode_done': False,
Expand Down Expand Up @@ -134,8 +135,6 @@ def shutdown(self):

def make_world(opt, agents):

agents[0].agent_id = "Worker"

# We are showing an image to the worker and bot, so grab the image path and other
# context info
image_idx, model_name, no_more_work = opt['image_stack'].get_next_image(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,20 @@
"dialog": [
{
"episode_done": false,
"id": "Worker",
"id": "Speaker 1",
"text": "Hi!",
"fake_start": true,
"agent_idx": 0,
"message_id": "4f72800c-e1ff-4411-8ee5-286336c27859"
"message_id": "4f72800c-e1ff-4411-8ee5-286336c27859",
"task_data": {
"human_persona_string_1": "",
"human_persona_string_2": ""
}
},
{
"agent_idx": 1,
"text": "I don't know.",
"id": "FixedResponseAgent",
"id": "Speaker 2",
"problem_data": {
"bucket_0": false,
"bucket_1": false,
Expand All @@ -29,12 +33,12 @@
{
"agent_idx": 0,
"text": "What are you nervous about?",
"id": "Worker"
"id": "Speaker 1"
},
{
"agent_idx": 1,
"text": "I don't know.",
"id": "FixedResponseAgent",
"id": "Speaker 2",
"problem_data": {
"bucket_0": false,
"bucket_1": false,
Expand All @@ -47,12 +51,12 @@
{
"agent_idx": 0,
"text": "Do you have any plans for the weekend?",
"id": "Worker"
"id": "Speaker 1"
},
{
"agent_idx": 1,
"text": "I don't know.",
"id": "FixedResponseAgent",
"id": "Speaker 2",
"problem_data": {
"bucket_0": false,
"bucket_1": false,
Expand All @@ -65,12 +69,12 @@
{
"agent_idx": 0,
"text": "Yeah that sounds great! I like to bike and try new restaurants.",
"id": "Worker"
"id": "Speaker 1"
},
{
"agent_idx": 1,
"text": "I don't know.",
"id": "FixedResponseAgent",
"id": "Speaker 2",
"problem_data": {
"bucket_0": false,
"bucket_1": false,
Expand All @@ -83,12 +87,12 @@
{
"agent_idx": 0,
"text": "Oh, Italian food is great. I also love Thai and Indian.",
"id": "Worker"
"id": "Speaker 1"
},
{
"agent_idx": 1,
"text": "I don't know.",
"id": "FixedResponseAgent",
"id": "Speaker 2",
"problem_data": {
"bucket_0": false,
"bucket_1": false,
Expand All @@ -101,12 +105,12 @@
{
"agent_idx": 0,
"text": "Hmmm - anything with peanuts? Or I like when they have spicy licorice-like herbs.",
"id": "Worker"
"id": "Speaker 1"
},
{
"agent_idx": 1,
"text": "I don't know.",
"id": "FixedResponseAgent",
"id": "Speaker 2",
"problem_data": {
"bucket_0": false,
"bucket_1": false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,69 +16,69 @@
},
{
"episode_done": false,
"id": "TransresnetAgent",
"id": "Speaker 2",
"text": "I must learn that bird's name!",
"agent_idx": 1
},
{
"agent_idx": 0,
"text": "Response 1",
"id": "Worker"
"id": "Speaker 1"
},
{
"agent_idx": 1,
"text": "My, aren't you a pretty bird?",
"id": "TransresnetAgent"
"id": "Speaker 2"
},
{
"agent_idx": 0,
"text": "Response 2",
"id": "Worker"
"id": "Speaker 1"
},
{
"agent_idx": 1,
"text": "My, aren't you a pretty bird?",
"id": "TransresnetAgent"
"id": "Speaker 2"
},
{
"agent_idx": 0,
"text": "Response 3",
"id": "Worker"
"id": "Speaker 1"
},
{
"agent_idx": 1,
"text": "My, aren't you a pretty bird?",
"id": "TransresnetAgent"
"id": "Speaker 2"
},
{
"agent_idx": 0,
"text": "Response 4",
"id": "Worker"
"id": "Speaker 1"
},
{
"agent_idx": 1,
"text": "My, aren't you a pretty bird?",
"id": "TransresnetAgent"
"id": "Speaker 2"
},
{
"agent_idx": 0,
"text": "Response 5",
"id": "Worker"
"id": "Speaker 1"
},
{
"agent_idx": 1,
"text": "My, aren't you a pretty bird?",
"id": "TransresnetAgent"
"id": "Speaker 2"
},
{
"agent_idx": 0,
"text": "Response 6",
"id": "Worker"
"id": "Speaker 1"
},
{
"agent_idx": 1,
"text": "My, aren't you a pretty bird?",
"id": "TransresnetAgent",
"id": "Speaker 2",
"final_rating": 0
}
],
Expand Down
Loading

0 comments on commit 6df9361

Please sign in to comment.