Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

[WoW] Update teachers to use DialogTeacher #4284

Merged
merged 9 commits into from
Mar 30, 2022
117 changes: 51 additions & 66 deletions parlai/tasks/wizard_of_wikipedia/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@
"""

from __future__ import annotations
from typing import Iterable, Optional, Tuple
from typing import Iterable, Optional, Tuple, Dict, Any
from parlai.core.message import Message
from parlai.core.metrics import AverageMetric, normalize_answer, F1Metric
from parlai.core.params import ParlaiParser
from parlai.core.opt import Opt
import copy
from parlai.core.teachers import FixedDialogTeacher, MultiTaskTeacher
from parlai.core.teachers import DialogTeacher, MultiTaskTeacher
from parlai.utils.io import PathManager
from parlai.utils import logging
from parlai.utils.misc import warn_once
Expand Down Expand Up @@ -186,7 +186,7 @@ def _build_rare_word_f1(datapath: str) -> RareWordF1Calculator:
return RareWordF1Calculator(all_text, top_p=0.5)


class WizardOfWikipediaTeacher(FixedDialogTeacher):
class WizardOfWikipediaTeacher(DialogTeacher):
"""
The default teacher; essentially reads the json file and outputs the raw data.

Expand All @@ -211,40 +211,34 @@ class WizardOfWikipediaTeacher(FixedDialogTeacher):
"""

def __init__(self, opt, shared=None):
super().__init__(opt, shared)
self.opt = opt
task = opt.get('task', 'wizard_of_wikipedia:WizardOfWikipedia:random_split')
split = task.split(':')
split = split[2] if len(split) == 3 else 'random_split'
opt['task'] = 'wizard_of_wikipedia'
if shared and 'data' in shared:
self.data = shared['data']
else:
self.data_path = _path(opt, split=split)
self._setup_data()
self.num_exs = sum(len(d['dialog']) for d in self.data)
self.reset()

def _setup_data(self):
print('loading: ' + self.data_path)
with PathManager.open(self.data_path) as f:
self.data = json.load(f)

def num_episodes(self):
return len(self.data)

def num_examples(self):
return self.num_exs

def get(self, episode_idx, entry_idx=0):
d = self.data[episode_idx]
dialog_entry = d['dialog'][entry_idx]
episode_done = entry_idx == len(d['dialog']) - 1
opt['datafile'] = _path(opt, split=split)
super().__init__(opt, shared)

def setup_data(self, datafile):
logging.info(f'loading {datafile}')
with PathManager.open(datafile) as f:
self.raw_data = json.load(f)
for episode_idx in range(len(self.raw_data)):
for entry_idx in range(self.len_episode(episode_idx)):
ex = self._format_example(episode_idx, entry_idx)
ex.pop('episode_done', '')
if 'label_candidates' in ex and not ex['label_candidates']:
ex.pop('label_candidates')
yield ex, entry_idx == 0

def _format_example(self, episode_idx: int, entry_idx: int) -> Message:
episode = self.raw_data[episode_idx]
dialog_entry = episode['dialog'][entry_idx]
episode_done = entry_idx == len(episode['dialog']) - 1
action = Message(
{
'wizard_eval': d['wizard_eval'],
'chosen_topic': d['chosen_topic'],
'chosen_topic_passage': d['chosen_topic_passage'],
'wizard_eval': episode['wizard_eval'],
'chosen_topic': episode['chosen_topic'],
'chosen_topic_passage': episode['chosen_topic_passage'],
'text': dialog_entry['text'],
'retrieved_topics': dialog_entry['retrieved_topics'],
'retrieved_passages': dialog_entry['retrieved_passages'],
Expand All @@ -261,6 +255,14 @@ def share(self):
shared['data'] = self.data
return shared

def len_episode(self, ep: int) -> int:
"""
Length of an episode.

Optionally overrideable.
"""
return len(self.raw_data[ep])


###############################################################
# #
Expand Down Expand Up @@ -292,17 +294,16 @@ class WizardDialogKnowledgeTeacher(WizardOfWikipediaTeacher):

def __init__(self, opt, shared=None):
self.add_missing_turns = opt.get('add_missing_turns', 'none')
super().__init__(opt, shared)
self.label_type = opt.get('label_type', 'response')
self.include_knowledge = opt.get('include_knowledge', True)
self.include_checked_sentence = opt.get('include_checked_sentence', False)
self.knowledge_separator = opt.get('include_knowledge_separator', False)
self.chosen_topic_delimiter = opt.get('chosen_topic_delimiter', '\n')
self.num_exs = sum(self.len_episode(i) for i in range(len(self.data)))
if shared and 'rare_word_f1' in shared:
self.rare_word_f1 = shared['rare_word_f1']
elif self.label_type == 'response':
self.rare_word_f1 = _build_rare_word_f1(opt['datapath'])
super().__init__(opt, shared)

@classmethod
def add_cmdline_args(
Expand Down Expand Up @@ -367,7 +368,7 @@ def share(self):
return shared

def len_episode(self, ep):
d = self.data[ep]
d = self.raw_data[ep]
wizard_first = 'Wizard' in d['dialog'][0]['speaker']
if wizard_first:
if self.add_missing_turns == 'none':
Expand All @@ -384,11 +385,8 @@ def len_episode(self, ep):
return len_ep
return len(d['dialog']) // 2

def num_examples(self):
return self.num_exs

def get(self, episode_idx, entry_idx=0):
d = self.data[episode_idx]
def _format_example(self, episode_idx, entry_idx=0):
d = self.raw_data[episode_idx]
episode_done = entry_idx == (self.len_episode(episode_idx) - 1)

wizard_first = 'Wizard' in d['dialog'][0]['speaker']
Expand Down Expand Up @@ -511,7 +509,7 @@ def custom_evaluation(
model_response['text'], [teacher_action['checked_sentence']]
),
)
if labels:
if labels and hasattr(self, 'rare_word_f1'):
self.metrics.add(
'rare_word_f1',
self.rare_word_f1.compute(model_response['text'], labels),
Expand Down Expand Up @@ -565,10 +563,9 @@ class BasicdialogTeacher(WizardOfWikipediaTeacher):

def __init__(self, opt, shared=None):
self.add_missing_turns = opt.get('add_missing_turns', 'none')
super().__init__(opt, shared)
self.speaker_label = opt.get('speaker_label', 'both')
self.add_topic = opt.get('add_topic', False)
self.num_exs = sum(self.len_episode(i) for i in range(len(self.data)))
super().__init__(opt, shared)

@classmethod
def add_cmdline_args(
Expand Down Expand Up @@ -600,11 +597,8 @@ def add_cmdline_args(
)
return parser

def num_examples(self):
return self.num_exs

def len_episode(self, ep):
d = self.data[ep]
d = self.raw_data[ep]
first_speaker = d['dialog'][0]['speaker'].lower()
if self.speaker_label != 'both' and self.speaker_label in first_speaker:
if self.add_missing_turns == 'none':
Expand All @@ -621,8 +615,8 @@ def len_episode(self, ep):
return len_ep
return len(d['dialog']) // 2

def get(self, episode_idx, entry_idx=0):
d = self.data[episode_idx]
def _format_example(self, episode_idx, entry_idx=0):
d = self.raw_data[episode_idx]
episode_done = entry_idx == (self.len_episode(episode_idx) - 1)

idx = entry_idx * 2
Expand Down Expand Up @@ -696,12 +690,12 @@ class GeneratorTeacher(WizardDialogKnowledgeTeacher):
def __init__(self, opt, shared=None):
opt['label_type'] = 'response'
opt['include_checked_sentence'] = True
super().__init__(opt, shared)
self.knowledge_separator = opt.get('include_knowledge_separator', True)
self.only_checked_knowledge = opt.get('only_checked_knowledge', False)
self.prepend_gold_knowledge = opt.get('prepend_gold_knowledge')
self.gold_knowledge_delimiter = opt.get('gold_knowledge_delimiter', '\n')
self.dropout = opt.get('ignorant_dropout', 0.0)
super().__init__(opt, shared)

@classmethod
def add_cmdline_args(
Expand Down Expand Up @@ -740,8 +734,8 @@ def add_cmdline_args(
def getID(self):
return "WizTeacher"

def get(self, episode_idx, entry_idx=0):
a = super().get(episode_idx, entry_idx)
def _format_example(self, episode_idx, entry_idx=0):
a = super()._format_example(episode_idx, entry_idx)
# zero out the label candidates?
if 'knowledge' not in a:
# just a batch padding item
Expand Down Expand Up @@ -799,7 +793,6 @@ class WikiPageTitleTeacher(WizardDialogKnowledgeTeacher):
def __init__(self, opt, shared=None):
self.opt = copy.deepcopy(opt)
self.opt['label_type'] = 'response'
super().__init__(self.opt, shared=shared)
self.id = 'WikiPageTitleTeacher'
self._conv_history_len = self.opt['conversation_history_length']
if not (self._conv_history_len > 0 or self._conv_history_len == -1):
Expand All @@ -809,6 +802,7 @@ def __init__(self, opt, shared=None):
)
self._conv_history_len = -1
self._skip_no_title = self.opt['skip_no_title']
super().__init__(self.opt, shared=shared)
if not shared:
self._preprocess_data()
else:
Expand Down Expand Up @@ -863,7 +857,7 @@ def _preprocess_data(self):
dialog_history = []
ex_idx = 0
while True:
a = super().get(episode_idx, ex_idx)
a = super()._format_example(episode_idx, ex_idx)
text_parts = a['text'].split('\n')
if ex_idx == 0:
# throwing away chosen_topic
Expand All @@ -884,13 +878,7 @@ def _preprocess_data(self):
)
self.titles_data = data

def num_episodes(self):
return len(self.titles_data)

def num_examples(self):
return self.num_episodes()

def get(self, episode_idx, entry_idx=0):
def _format_example(self, episode_idx, entry_idx=0):
return self.titles_data[episode_idx]


Expand Down Expand Up @@ -945,7 +933,7 @@ def __init__(self, opt, shared=None):
# get number of examples
self.num_exs = 0
for ep in range(self.num_episodes()):
d = self.data[ep]
d = self.raw_data[ep]
for entry in d['dialog']:
if (
entry.get('checked_sentence', None) is not None
Expand Down Expand Up @@ -1139,9 +1127,6 @@ def get_span(self, one, two):
)
return max_span

def num_examples(self):
return self.num_exs

def length_episode(self, dialog):
len_ep = 0
idxs = []
Expand All @@ -1158,8 +1143,8 @@ def length_episode(self, dialog):

return len_ep, idxs

def get(self, episode_idx, entry_idx=0):
d = self.data[episode_idx]
def _format_example(self, episode_idx, entry_idx=0):
d = self.raw_data[episode_idx]
len_ep, idxs = self.length_episode(d)
idx = idxs[entry_idx]

Expand Down