Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Use standard tool for cleaning up metrics #3726

Merged
merged 2 commits into from
Jun 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 2 additions & 7 deletions parlai/crowdsourcing/tasks/model_chat/bot_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,20 +44,15 @@ def act(self, timeout=None):
act_out = self.model_agent.act()
else:
act_out = self.model_agent.act()
act_out = Message(act_out)
# Wrap as a Message for compatibility with older ParlAI models

if 'metrics' in act_out:
# Metrics can't be serialized when saving results as a JSON
del act_out['metrics']
act_out = Message(act_out).json_safe_payload()

if 'dict_lower' in self.opt and not self.opt['dict_lower']:
# model is cased so we don't want to normalize the reply like below
final_message_text = act_out['text']
else:
final_message_text = normalize_reply(act_out['text'])

act_out.force_set('text', final_message_text)
act_out['text'] = final_message_text
assert ('episode_done' not in act_out) or (not act_out['episode_done'])
self.turn_idx += 1
return {**act_out, 'episode_done': False}
Expand Down
10 changes: 5 additions & 5 deletions parlai/crowdsourcing/tasks/model_chat/worlds.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@

import numpy as np

from parlai.core.worlds import validate
from parlai.core.agents import create_agent_from_shared
from parlai.core.message import Message
from parlai.core.worlds import validate
from parlai.crowdsourcing.utils.acceptability import AcceptabilityChecker
from parlai.crowdsourcing.utils.worlds import CrowdOnboardWorld, CrowdTaskWorld
from parlai.crowdsourcing.tasks.model_chat.bot_agent import TurkLikeAgent
Expand Down Expand Up @@ -224,10 +225,9 @@ def parley(self):
for idx, agent in enumerate([self.agent, self.bot]):
if not self.chat_done:
acts[idx] = agent.act(timeout=self.max_resp_time)
acts[idx] = Compatibility.maybe_fix_act(acts[idx])
if 'metrics' in acts[idx]:
del acts[idx]['metrics']
# Metrics can't be saved to JSON and are not needed here
acts[idx] = Message(
Compatibility.maybe_fix_act(acts[idx])
).json_safe_payload()
print(
f'Got act for agent idx {idx}, act was: {acts[idx]} and self.task_turn_idx: {self.task_turn_idx}.'
)
Expand Down
11 changes: 5 additions & 6 deletions parlai/crowdsourcing/tasks/model_chat/worlds_image_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,9 @@ def _run_initial_turn(self) -> None:

# Have the bot respond
bot_first_act_raw = self.bot.act()
bot_first_act_raw = Compatibility.maybe_fix_act(bot_first_act_raw)
if 'metrics' in bot_first_act_raw:
del bot_first_act_raw['metrics']
# Metrics can't be saved to JSON and are not needed here
bot_first_act_raw = Message(
Compatibility.maybe_fix_act(bot_first_act_raw)
).json_safe_payload()
self.agent.observe(validate(bot_first_act_raw))
bot_first_act = {
'episode_done': False,
Expand All @@ -91,7 +90,7 @@ def _run_initial_turn(self) -> None:
self.dialog.append(image_act)
self.dialog.append(bot_first_act)

def _postprocess_acts(self, acts: List[Message], agent_idx: int):
def _postprocess_acts(self, acts: List[dict], agent_idx: int):
"""
Show the bot the image again on every turn.
"""
Expand All @@ -100,7 +99,7 @@ def _postprocess_acts(self, acts: List[Message], agent_idx: int):
# image-related fields needed by the model
for key, value in self.image_act.items():
if key not in ['episode_done', 'id', 'text', 'agent_idx']:
acts[agent_idx].force_set(key, value)
acts[agent_idx][key] = value

def get_final_chat_data(self) -> Dict[str, Any]:
"""
Expand Down