diff --git a/parlai/tasks/wizard_of_wikipedia/agents.py b/parlai/tasks/wizard_of_wikipedia/agents.py index 054d4a66ae3..3e173949d27 100644 --- a/parlai/tasks/wizard_of_wikipedia/agents.py +++ b/parlai/tasks/wizard_of_wikipedia/agents.py @@ -16,7 +16,7 @@ """ from __future__ import annotations -from typing import List, Optional, Tuple +from typing import Iterable, Optional, Tuple from parlai.core.message import Message from parlai.core.metrics import AverageMetric, normalize_answer, F1Metric from parlai.core.params import ParlaiParser @@ -141,7 +141,9 @@ def _filter(freq_dist, cutoff: int, text: str) -> str: words = normalize_answer(text).split() return " ".join([w for w in words if freq_dist.get(w, cutoff) < cutoff]) - def compute(self, guess: str, answers: List[str]) -> F1Metric: + def compute(self, guess: str, answers: Iterable[str]) -> F1Metric: + if guess is None or answers is None: + return F1Metric(0, 0) guess = RareWordF1Calculator._filter(self._freq_dist, self._cutoff_count, guess) answers = [ RareWordF1Calculator._filter(self._freq_dist, self._cutoff_count, a) @@ -462,10 +464,11 @@ def custom_evaluation( model_response['text'], [teacher_action['checked_sentence']] ), ) - self.metrics.add( - 'rare_word_f1', - self.rare_word_f1.compute(model_response['text'], labels), - ) + if labels: + self.metrics.add( + 'rare_word_f1', + self.rare_word_f1.compute(model_response['text'], labels), + ) elif ( self.label_type == 'chosen_sent' and TOKEN_KNOWLEDGE in model_response['text']