diff --git a/parlai/tasks/style_gen/agents.py b/parlai/tasks/style_gen/agents.py index 2016560d3af..9c8eb799768 100644 --- a/parlai/tasks/style_gen/agents.py +++ b/parlai/tasks/style_gen/agents.py @@ -6,8 +6,9 @@ import os +from parlai.core.message import Message from parlai.core.opt import Opt -from parlai.core.teachers import FixedDialogTeacher, ParlAIDialogTeacher +from parlai.core.teachers import ParlAIDialogTeacher from parlai.tasks.style_gen.build import ( build_personality_list, build_style_labeled_datasets, @@ -105,17 +106,10 @@ class PrevCurrUttStyleTeacher(AbstractWrapperTeacher): will be flattened into one example each. """ - def __init__(self, opt: Opt, shared=None): - super().__init__(opt, shared) - assert isinstance(self.task, FixedDialogTeacher) - - def act(self): + def _edit_action(self, act: Message) -> Message: """ - Act on the previous observation. + Edit the fields of the action manually. """ - act = self.task.get_orig_action() - - # Edit the fields of the action manually if 'labels' in act: labels = act['labels'] if len(labels) != 1: @@ -129,5 +123,4 @@ def act(self): else: assert 'text' not in act and act['episode_done'] is True act.force_set('episode_done', True) # Clear the dialogue history - - return self.task.process_action(act) + return act diff --git a/parlai/tasks/wrapper/agents.py b/parlai/tasks/wrapper/agents.py index 0b13f2db906..f52fec41d6a 100644 --- a/parlai/tasks/wrapper/agents.py +++ b/parlai/tasks/wrapper/agents.py @@ -18,17 +18,26 @@ import copy -from abc import ABC, abstractmethod +from abc import ABC from parlai.core.agents import create_agent_from_shared +from parlai.core.message import Message from parlai.core.opt import Opt -from parlai.core.teachers import create_task_agent_from_taskname, Teacher +from parlai.core.teachers import ( + create_task_agent_from_taskname, + FixedDialogTeacher, + Teacher, +) from parlai.utils.misc import warn_once class AbstractWrapperTeacher(Teacher, ABC): """ - Abstract teacher that will wrap around another teacher and allow for manipulating - the fields returned by the inner teacher. + Abstract teacher that wraps around another teacher. + + This teacher allows for manipulating the fields returned by the inner teacher, in + the abstract self._edit_action() method that is called during self.act(). The inner + teacher must subclass FixedDialogTeacher in order to make use of that teacher's + .get_orig_action() and .process_action() methods. """ @classmethod @@ -62,15 +71,32 @@ def __init__(self, opt: Opt, shared=None): opt_singletask = copy.deepcopy(opt) opt_singletask['task'] = opt['wrapper_task'] self.task = create_task_agent_from_taskname(opt_singletask)[0] + assert isinstance(self.task, FixedDialogTeacher) - @abstractmethod def act(self): """ Act on the previous observation. - Typically, self.task.act() will need to be called in this method. + Normally, the inner teacher would call .get_orig_action() and .process_action(); + here, we insert an ._edit_action() method in between these two methods in order + to allow for arbitrary manipulation of the action before it is registered and + processed further by the inner teacher. + """ + orig_action = self.task.get_orig_action() + edited_action = self._edit_action(orig_action) + processed_action = self.task.process_action(edited_action) + return processed_action + + def _edit_action(self, act: Message) -> Message: + """ + Edit and return the input action. + + The input action typically comes from the inner teacher's .get_orig_action() + method. """ - raise NotImplementedError('Abstract class: user must implement act() method') + raise NotImplementedError( + 'Abstract class: user must implement the _edit_action() method' + ) def num_examples(self): """ @@ -145,20 +171,19 @@ class LabelToTextTeacher(AbstractWrapperTeacher): def __init__(self, opt: Opt, shared=None): super().__init__(opt, shared) - def act(self): + def _edit_action(self, act: Message) -> Message: """ - Act on the previous observation. + Edit the fields of the action manually. """ - act = self.task.act() - new_act = copy.deepcopy(act) - if 'labels' in act or 'eval_labels' in act: - labels_type = 'labels' if 'labels' in act else 'eval_labels' - labels = act[labels_type] + if 'labels' in act: + labels = act['labels'] if len(labels) != 1: - raise ValueError('LabelToTextTeacher can only be used with one label!') - new_act.force_set('text', labels[0]) - new_act.force_set(labels_type, ['']) + raise ValueError( + f'{type(self).__name__} can only be used with one label!' + ) + act.force_set('text', labels[0]) + act.force_set('labels', ['']) else: assert 'text' not in act and act['episode_done'] is True - new_act.force_set('episode_done', True) # Clear the dialogue history - return new_act + act.force_set('episode_done', True) # Clear the dialogue history + return act