diff --git a/parlai/core/metrics.py b/parlai/core/metrics.py
index 475b491fd33..0f0047471d9 100644
--- a/parlai/core/metrics.py
+++ b/parlai/core/metrics.py
@@ -14,6 +14,7 @@
 from collections import Counter
 import functools
 import datetime
+import math
 from typing import Union, List, Optional, Tuple, Set, Any, Dict, Counter as TCounter
 
 import torch
@@ -458,7 +459,106 @@ def compute(guess: str, answers: List[str], k: int = 4) -> Optional[BleuMetric]:
         return BleuMetric(score)
 
 
-class FairseqBleuMetric(AverageMetric):
+class FairseqBleuMetric(Metric):
+    """
+    Re-implementation of
+    https://github.com/pytorch/fairseq/blob/master/fairseq/scoring/bleu.py.
+    """
+
+    def __init__(
+        self,
+        pred: Union[torch.Tensor, List[int]],
+        ref: Union[torch.Tensor, List[int]],
+        pad_idx: int,
+        eos_idx: int,
+        unk_idx: int,
+        order: int,
+    ):
+        try:
+            from fairseq import libbleu
+            from fairseq.scoring.bleu import BleuStat
+            import ctypes
+        except ImportError:
+            return
+
+        self.stat = BleuStat()
+        self.order = order
+
+        C = ctypes.cdll.LoadLibrary(libbleu.__file__)
+        C.bleu_zero_init(ctypes.byref(self.stat))
+
+        if not torch.is_tensor(pred):
+            pred = torch.LongTensor(pred)
+        if not torch.is_tensor(ref):
+            ref = torch.LongTensor(ref)
+
+        rref = ref.clone()
+        assert not rref.lt(0).any()
+        rref[rref.eq(unk_idx)] = -999
+
+        rref = rref.contiguous().view(-1)
+        pred = pred.contiguous().view(-1)
+
+        C.bleu_add(
+            ctypes.byref(self.stat),
+            ctypes.c_size_t(rref.size(0)),
+            ctypes.c_void_p(rref.data_ptr()),
+            ctypes.c_size_t(pred.size(0)),
+            ctypes.c_void_p(pred.data_ptr()),
+            ctypes.c_int(pad_idx),
+            ctypes.c_int(eos_idx),
+        )
+
+    @property
+    def macro_average(self) -> bool:
+        """
+        Indicates whether this metric should be macro-averaged when globally reported.
+        """
+        return True
+
+    def __add__(self, other: Optional[FairseqBleuMetric]) -> FairseqBleuMetric:
+        if other is None:
+            return self
+        self.stat.match1 += other.stat.match1
+        self.stat.match2 += other.stat.match2
+        self.stat.match3 += other.stat.match3
+        self.stat.match4 += other.stat.match4
+        self.stat.count1 += other.stat.count1
+        self.stat.count2 += other.stat.count2
+        self.stat.count3 += other.stat.count3
+        self.stat.count4 += other.stat.count4
+        self.stat.predlen += other.stat.predlen
+        self.stat.reflen += other.stat.reflen
+        return self
+
+    def _ratio(self, a: int, b: int) -> float:
+        """
+        Safe division.
+        """
+        return a / b if b > 0 else 0
+
+    def _precision(self):
+        return [
+            self._ratio(self.stat.match1, self.stat.count1),
+            self._ratio(self.stat.match2, self.stat.count2),
+            self._ratio(self.stat.match3, self.stat.count3),
+            self._ratio(self.stat.match4, self.stat.count4),
+        ]
+
+    def _brevity(self):
+        r = self.stat.reflen / self.stat.predlen
+        return min(1, math.exp(1 - r))
+
+    def value(self) -> float:
+        """
+        Reimplementation of Fairseq's score.
+        """
+        psum = sum(
+            math.log(p) if p > 0 else float("-Inf")
+            for p in self._precision()[: self.order]
+        )
+        return self._brevity() * math.exp(psum / self.order) * 100
+
     @staticmethod
     def compute_many(
         guess: torch.Tensor, answers: torch.Tensor, pad_idx, end_idx, unk_idx
@@ -467,15 +567,21 @@ def compute_many(
         Return BLEU-1..4 using fairseq and tokens.
""" try: - from fairseq.scoring import bleu as fairseqbleu + from fairseq.scoring import bleu as fairseqbleu # noqa except ImportError: return None - scorer = fairseqbleu.Scorer(pad_idx, end_idx, unk_idx) - answers = answers.cpu().int() - guess = guess.cpu().int() - scorer.add(answers, guess) - return [FairseqBleuMetric(scorer.score(i) / 100.0) for i in range(1, 5)] + return [ + FairseqBleuMetric( + guess.cpu().int(), + answers.cpu().int(), + pad_idx, + end_idx, + unk_idx, + order=i, + ) + for i in range(1, 5) + ] class RougeMetric(AverageMetric): diff --git a/parlai/core/torch_generator_agent.py b/parlai/core/torch_generator_agent.py index 1aaa342cc2f..75209bce14f 100644 --- a/parlai/core/torch_generator_agent.py +++ b/parlai/core/torch_generator_agent.py @@ -801,7 +801,7 @@ def _compute_fairseq_bleu(self, batch: Batch, preds): assert label_vec is not None, "label_vec must exist for fairseq bleu" for i, t in enumerate(preds): result = FairseqBleuMetric.compute_many( - t[1:], + t, label_vec[i].unsqueeze(0), pad_idx=self.NULL_IDX, end_idx=self.END_IDX, diff --git a/parlai/utils/testing.py b/parlai/utils/testing.py index 4048ede3053..7adfce0d909 100644 --- a/parlai/utils/testing.py +++ b/parlai/utils/testing.py @@ -63,6 +63,14 @@ DETECTRON_AVAILABLE = False +try: + import fairseq # noqa: F401 + + FAIRSEQ_AVAILABLE = True +except ImportError: + FAIRSEQ_AVAILABLE = False + + def is_this_circleci(): """ Return if we are currently running in CircleCI. @@ -123,6 +131,13 @@ def skipUnlessDetectron( return unittest.skipUnless(DETECTRON_AVAILABLE, reason)(testfn) +def skipUnlessFairseq(testfn, reason='fairseq not installed'): + """ + Decorate a test to skip unless fairseq is installed. + """ + return unittest.skipUnless(FAIRSEQ_AVAILABLE, reason)(testfn) + + class retry(object): """ Decorator for flaky tests. Test is run up to ntries times, retrying on failure. diff --git a/tests/test_metrics.py b/tests/test_metrics.py index de0e61930a4..c0818c18b37 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. import unittest +from typing import Dict import torch import random @@ -20,8 +21,10 @@ aggregate_named_reports, InterDistinctMetric, IntraDistinctMetric, + FairseqBleuMetric, ) from parlai.core.torch_classifier_agent import ConfusionMatrixMetric, WeightedF1Metric +import parlai.utils.testing as testing_utils class TestMetric(unittest.TestCase): @@ -439,5 +442,59 @@ def test_intra_distinct(self): self.assertAlmostEqual(m1 + m2, 3 / 5) +@testing_utils.skipUnlessFairseq +class TestFairseqBleuMetric(unittest.TestCase): + """ + We're just going to compare that scores from Fairseq's Bleu scorer are the same as + our scorer. 
+ """ + + def test_scorer(self): + import random + + vocab_length = num_ex = 100 + ex_length = 10 + pad_idx = 0 + eos_idx = 1 + unk_idx = 2 + + try: + from fairseq.scoring.bleu import Scorer + from fairseq.scoring.bleu import BleuConfig + + fairseq_metrics: Scorer = Scorer( + BleuConfig(pad=pad_idx, eos=eos_idx, unk=unk_idx) + ) + except ImportError: + # Bleuconfig is a recent version of fairseq + fairseq_metrics: Scorer = Scorer(pad_idx, eos_idx, unk_idx) + + parlai_metrics: Dict[int, FairseqBleuMetric] = {k: [] for k in range(1, 5)} + + for _ in range(num_ex): + guess = torch.LongTensor(random.sample(range(vocab_length), ex_length)) + answer = torch.LongTensor(random.sample(range(vocab_length), ex_length)) + + parlai_bleu = FairseqBleuMetric.compute_many( + guess, answer.unsqueeze(0), pad_idx, eos_idx, unk_idx + ) + for i, bleu in enumerate(parlai_bleu): + parlai_metrics[i + 1].append(bleu) + fairseq_metrics.add(answer.int(), guess.int()) + + parlai_bleus = {} + for k, v in parlai_metrics.items(): + total = v[0] + for vv in v[1:]: + total = total + vv + parlai_bleus[k] = total + + fairseq_bleus = {k: fairseq_metrics.score(order=k) for k in range(1, 5)} + + assert all( + parlai_bleus[k] == fairseq_bleus[k] for k in range(1, 5) + ), f'{parlai_bleus}\n{fairseq_bleus}' + + if __name__ == '__main__': unittest.main()