diff --git a/parlai/tasks/dialogue_nli/agents.py b/parlai/tasks/dialogue_nli/agents.py index 854d0ee825d..d7bf7190aac 100644 --- a/parlai/tasks/dialogue_nli/agents.py +++ b/parlai/tasks/dialogue_nli/agents.py @@ -13,6 +13,7 @@ import json import os +from parlai.core.message import Message from parlai.core.teachers import FixedDialogTeacher from .build import build from parlai.tasks.multinli.agents import convert_to_dialogData @@ -134,7 +135,7 @@ def get(self, episode_idx, entry_idx=0): binary_classes=self.binary_classes, ) new_entry = {k: entry[k] for k in ENTRY_FIELDS if k in entry} - return new_entry + return Message(new_entry) class ExtrasTeacher(DialogueNliTeacher): diff --git a/parlai/tasks/empathetic_dialogues/agents.py b/parlai/tasks/empathetic_dialogues/agents.py index 8371683730b..a17c10b8134 100644 --- a/parlai/tasks/empathetic_dialogues/agents.py +++ b/parlai/tasks/empathetic_dialogues/agents.py @@ -13,6 +13,7 @@ import numpy as np from parlai.utils.io import PathManager +from parlai.core.message import Message from parlai.core.teachers import FixedDialogTeacher from .build import build @@ -220,18 +221,21 @@ def get(self, episode_idx, entry_idx=0): ep = self.data[episode_idx] ep_i = ep[entry_idx] episode_done = entry_idx >= (len(ep) - 1) - action = { - 'situation': ep_i[3], - 'emotion': ep_i[2], - 'text': ep_i[0], - 'labels': [ep_i[1]], - 'prepend_ctx': ep_i[6], - 'prepend_cand': ep_i[7], - 'deepmoji_ctx': ep_i[4], - 'deepmoji_cand': ep_i[5], - 'episode_done': episode_done, - 'label_candidates': ep_i[8], - } + action = Message( + { + 'situation': ep_i[3], + 'emotion': ep_i[2], + 'text': ep_i[0], + 'labels': [ep_i[1]], + 'prepend_ctx': ep_i[6], + 'prepend_cand': ep_i[7], + 'deepmoji_ctx': ep_i[4], + 'deepmoji_cand': ep_i[5], + 'episode_done': episode_done, + 'label_candidates': ep_i[8], + } + ) + return action def share(self): @@ -268,7 +272,7 @@ def get(self, episode_idx, entry_idx=0): ex = self.data[episode_idx] episode_done = True - return {'labels': [ex[2]], 'text': ex[3], 'episode_done': episode_done} + return Message({'labels': [ex[2]], 'text': ex[3], 'episode_done': episode_done}) class DefaultTeacher(EmpatheticDialoguesTeacher):