class RareWordF1Calculator:
    """
    Helper class for computing F1 with an emphasis on infrequent words.

    A word-frequency distribution is built from a reference corpus. Both the guess
    and the answers are then reduced to the words whose corpus count is strictly
    below a cutoff before F1 is computed on what remains.
    """

    def __init__(self, corpus: str, top_p: float = 0.5):
        """
        Build the reference word distribution and its cutoff count.

        :param corpus:
            reference text; it is passed through ``normalize_answer`` and split on
            whitespace before counting.
        :param top_p:
            fraction (must be < 1) of the cumulative word count, taken from the most
            frequent words downward, that is treated as "common".

        :raises ImportError: if nltk is not installed.
        :raises ValueError: if ``top_p`` is not less than 1.
        :raises RuntimeError: if no cutoff can be found (e.g. the corpus is empty).
        """
        try:
            import nltk
        except ImportError as err:
            # keep the original failure attached as the cause of the friendlier error
            raise ImportError('Please install nltk (e.g. pip install nltk).') from err
        words = normalize_answer(corpus).split()
        self._freq_dist = nltk.FreqDist(words)
        self._cutoff_count = RareWordF1Calculator._find_cutoff_count(
            self._freq_dist, top_p
        )

    @staticmethod
    def _find_cutoff_count(freq_dist, top_p: float) -> int:
        """
        Find the word occurrence count at which the cumulative occurrences, summed
        from the most common word downward, first exceed `top_p` of the overall word
        count.

        :param freq_dist:
            mapping of word -> count that provides ``values()`` and
            ``most_common()``.
        :param top_p:
            fraction of the total word count to cover; must be less than 1.

        :return: the count of the word at which the cumulative sum passes the target.

        :raises ValueError: if ``top_p`` is not less than 1.
        :raises RuntimeError: if the cumulative sum never exceeds the target.
        """
        # Explicit check instead of `assert`, so validation survives `python -O`.
        if not top_p < 1:
            raise ValueError(f"top_p must be less than 1, got {top_p}")
        target = sum(freq_dist.values()) * top_p
        cumul = 0
        for _, count in freq_dist.most_common():
            cumul += count
            if cumul > target:
                return count
        raise RuntimeError(f"Invalid top {top_p*100}% of the corpus distribution")

    @staticmethod
    def _filter(freq_dist, cutoff: int, text: str) -> str:
        """
        Keep only the words of `text` whose count in `freq_dist` is strictly less
        than `cutoff`.

        Words that are absent from `freq_dist` default to a count of `cutoff` and are
        therefore dropped as well.

        :return: the surviving words, in order, joined by single spaces.
        """
        words = normalize_answer(text).split()
        return " ".join([w for w in words if freq_dist.get(w, cutoff) < cutoff])

    def compute(self, guess: str, answers: List[str]) -> "F1Metric":
        """
        Compute F1 between `guess` and `answers` after both are reduced to rare words.

        :param guess: the model's text.
        :param answers: the reference texts.

        :return:
            ``F1Metric(0, 0)`` when no answer keeps any word after filtering,
            otherwise the result of ``F1Metric.compute`` on the filtered texts.
        """
        guess = RareWordF1Calculator._filter(self._freq_dist, self._cutoff_count, guess)
        answers = [
            RareWordF1Calculator._filter(self._freq_dist, self._cutoff_count, a)
            for a in answers
        ]
        if not any(len(a) for a in answers):
            # no rare words in labels, set denominator to zero
            return F1Metric(0, 0)
        return F1Metric.compute(guess, answers)


def _build_rare_word_f1(datapath: str) -> RareWordF1Calculator:
    """
    Build a RareWordF1Calculator from the text of every dialog message found in
    ``<datapath>/wizard_of_wikipedia/data.json``.

    :param datapath: root data directory that contains ``wizard_of_wikipedia``.

    :return: a calculator built with ``top_p=0.5``.
    """
    data_path = os.path.join(datapath, 'wizard_of_wikipedia', 'data.json')
    with PathManager.open(data_path) as f:
        data = json.load(f)
    # one space-joined string of all messages (trailing space kept as before)
    all_text = ' '.join(m['text'] for d in data for m in d['dialog']) + ' '
    return RareWordF1Calculator(all_text, top_p=0.5)
def share(self):
    """
    Extend the parent's shared state with the rare-word F1 calculator, when this
    teacher has one.
    """
    shared_state = super().share()
    try:
        calculator = self.rare_word_f1
    except AttributeError:
        # attribute was never set on this teacher, so there is nothing to add
        return shared_state
    shared_state['rare_word_f1'] = calculator
    return shared_state