Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Ease-of-use improvements to wrapper teacher #3247

Merged
merged 4 commits into from
Nov 5, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 5 additions & 12 deletions parlai/tasks/style_gen/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@

import os

from parlai.core.message import Message
from parlai.core.opt import Opt
from parlai.core.teachers import FixedDialogTeacher, ParlAIDialogTeacher
from parlai.core.teachers import ParlAIDialogTeacher
from parlai.tasks.style_gen.build import (
build_personality_list,
build_style_labeled_datasets,
Expand Down Expand Up @@ -105,17 +106,10 @@ class PrevCurrUttStyleTeacher(AbstractWrapperTeacher):
will be flattened into one example each.
"""

def __init__(self, opt: Opt, shared=None):
super().__init__(opt, shared)
assert isinstance(self.task, FixedDialogTeacher)

def act(self):
def _edit_action(self, act: Message) -> Message:
"""
Act on the previous observation.
Edit the fields of the action manually.
"""
act = self.task.get_orig_action()

# Edit the fields of the action manually
if 'labels' in act:
labels = act['labels']
if len(labels) != 1:
Expand All @@ -129,5 +123,4 @@ def act(self):
else:
assert 'text' not in act and act['episode_done'] is True
act.force_set('episode_done', True) # Clear the dialogue history

return self.task.process_action(act)
return act
63 changes: 44 additions & 19 deletions parlai/tasks/wrapper/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,26 @@

import copy

from abc import ABC, abstractmethod
from abc import ABC
from parlai.core.agents import create_agent_from_shared
from parlai.core.message import Message
from parlai.core.opt import Opt
from parlai.core.teachers import create_task_agent_from_taskname, Teacher
from parlai.core.teachers import (
create_task_agent_from_taskname,
FixedDialogTeacher,
Teacher,
)
from parlai.utils.misc import warn_once


class AbstractWrapperTeacher(Teacher, ABC):
"""
Abstract teacher that will wrap around another teacher and allow for manipulating
the fields returned by the inner teacher.
Abstract teacher that wraps around another teacher.

This teacher allows for manipulating the fields returned by the inner teacher, in
the abstract self._edit_action() method that is called during self.act(). The inner
teacher must subclass FixedDialogTeacher in order to make use of that teacher's
.get_orig_action() and .process_action() methods.
"""

@classmethod
Expand Down Expand Up @@ -62,15 +71,32 @@ def __init__(self, opt: Opt, shared=None):
opt_singletask = copy.deepcopy(opt)
opt_singletask['task'] = opt['wrapper_task']
self.task = create_task_agent_from_taskname(opt_singletask)[0]
assert isinstance(self.task, FixedDialogTeacher)

@abstractmethod
def act(self):
"""
Act on the previous observation.

Typically, self.task.act() will need to be called in this method.
Normally, the inner teacher would call .get_orig_action() and .process_action();
here, we insert an ._edit_action() method in between these two methods in order
to allow for arbitrary manipulation of the action before it is registered and
processed further by the inner teacher.
"""
orig_action = self.task.get_orig_action()
edited_action = self._edit_action(orig_action)
processed_action = self.task.process_action(edited_action)
return processed_action

def _edit_action(self, act: Message) -> Message:
"""
Edit and return the input action.

The input action typically comes from the inner teacher's .get_orig_action()
method.
"""
raise NotImplementedError('Abstract class: user must implement act() method')
raise NotImplementedError(
'Abstract class: user must implement the _edit_action() method'
)

def num_examples(self):
"""
Expand Down Expand Up @@ -145,20 +171,19 @@ class LabelToTextTeacher(AbstractWrapperTeacher):
def __init__(self, opt: Opt, shared=None):
super().__init__(opt, shared)

def act(self):
def _edit_action(self, act: Message) -> Message:
"""
Act on the previous observation.
Edit the fields of the action manually.
"""
act = self.task.act()
new_act = copy.deepcopy(act)
if 'labels' in act or 'eval_labels' in act:
labels_type = 'labels' if 'labels' in act else 'eval_labels'
labels = act[labels_type]
if 'labels' in act:
labels = act['labels']
if len(labels) != 1:
raise ValueError('LabelToTextTeacher can only be used with one label!')
new_act.force_set('text', labels[0])
new_act.force_set(labels_type, [''])
raise ValueError(
f'{type(self).__name__} can only be used with one label!'
)
act.force_set('text', labels[0])
act.force_set('labels', [''])
else:
assert 'text' not in act and act['episode_done'] is True
new_act.force_set('episode_done', True) # Clear the dialogue history
return new_act
act.force_set('episode_done', True) # Clear the dialogue history
return act