diff --git a/parlai/crowdsourcing/tasks/model_chat/bot_agent.py b/parlai/crowdsourcing/tasks/model_chat/bot_agent.py index 642d83df2df..baf5ac1e800 100644 --- a/parlai/crowdsourcing/tasks/model_chat/bot_agent.py +++ b/parlai/crowdsourcing/tasks/model_chat/bot_agent.py @@ -44,12 +44,7 @@ def act(self, timeout=None): act_out = self.model_agent.act() else: act_out = self.model_agent.act() - act_out = Message(act_out) - # Wrap as a Message for compatibility with older ParlAI models - - if 'metrics' in act_out: - # Metrics can't be serialized when saving results as a JSON - del act_out['metrics'] + act_out = Message(act_out).json_safe_payload() if 'dict_lower' in self.opt and not self.opt['dict_lower']: # model is cased so we don't want to normalize the reply like below @@ -57,7 +52,7 @@ def act(self, timeout=None): else: final_message_text = normalize_reply(act_out['text']) - act_out.force_set('text', final_message_text) + act_out['text'] = final_message_text assert ('episode_done' not in act_out) or (not act_out['episode_done']) self.turn_idx += 1 return {**act_out, 'episode_done': False} diff --git a/parlai/crowdsourcing/tasks/model_chat/worlds.py b/parlai/crowdsourcing/tasks/model_chat/worlds.py index 7b03ddca3e8..32d7d19b4b5 100644 --- a/parlai/crowdsourcing/tasks/model_chat/worlds.py +++ b/parlai/crowdsourcing/tasks/model_chat/worlds.py @@ -12,8 +12,9 @@ import numpy as np -from parlai.core.worlds import validate from parlai.core.agents import create_agent_from_shared +from parlai.core.message import Message +from parlai.core.worlds import validate from parlai.crowdsourcing.utils.acceptability import AcceptabilityChecker from parlai.crowdsourcing.utils.worlds import CrowdOnboardWorld, CrowdTaskWorld from parlai.crowdsourcing.tasks.model_chat.bot_agent import TurkLikeAgent @@ -224,10 +225,9 @@ def parley(self): for idx, agent in enumerate([self.agent, self.bot]): if not self.chat_done: acts[idx] = agent.act(timeout=self.max_resp_time) - acts[idx] = Compatibility.maybe_fix_act(acts[idx]) - if 'metrics' in acts[idx]: - del acts[idx]['metrics'] - # Metrics can't be saved to JSON and are not needed here + acts[idx] = Message( + Compatibility.maybe_fix_act(acts[idx]) + ).json_safe_payload() print( f'Got act for agent idx {idx}, act was: {acts[idx]} and self.task_turn_idx: {self.task_turn_idx}.' ) diff --git a/parlai/crowdsourcing/tasks/model_chat/worlds_image_chat.py b/parlai/crowdsourcing/tasks/model_chat/worlds_image_chat.py index f874fdb3c5f..b96931d81bb 100644 --- a/parlai/crowdsourcing/tasks/model_chat/worlds_image_chat.py +++ b/parlai/crowdsourcing/tasks/model_chat/worlds_image_chat.py @@ -75,10 +75,9 @@ def _run_initial_turn(self) -> None: # Have the bot respond bot_first_act_raw = self.bot.act() - bot_first_act_raw = Compatibility.maybe_fix_act(bot_first_act_raw) - if 'metrics' in bot_first_act_raw: - del bot_first_act_raw['metrics'] - # Metrics can't be saved to JSON and are not needed here + bot_first_act_raw = Message( + Compatibility.maybe_fix_act(bot_first_act_raw) + ).json_safe_payload() self.agent.observe(validate(bot_first_act_raw)) bot_first_act = { 'episode_done': False, @@ -91,7 +90,7 @@ def _run_initial_turn(self) -> None: self.dialog.append(image_act) self.dialog.append(bot_first_act) - def _postprocess_acts(self, acts: List[Message], agent_idx: int): + def _postprocess_acts(self, acts: List[dict], agent_idx: int): """ Show the bot the image again on every turn. """ @@ -100,7 +99,7 @@ def _postprocess_acts(self, acts: List[Message], agent_idx: int): # image-related fields needed by the model for key, value in self.image_act.items(): if key not in ['episode_done', 'id', 'text', 'agent_idx']: - acts[agent_idx].force_set(key, value) + acts[agent_idx][key] = value def get_final_chat_data(self) -> Dict[str, Any]: """