This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

[Metrics] Fairseq BLEU Re-Implementation #3518

Merged
merged 8 commits into from
Mar 16, 2021
Changes from 6 commits
117 changes: 110 additions & 7 deletions parlai/core/metrics.py
@@ -14,6 +14,7 @@
from collections import Counter
import functools
import datetime
import math
from typing import Union, List, Optional, Tuple, Set, Any, Dict, Counter as TCounter

import torch
@@ -458,7 +459,103 @@ def compute(guess: str, answers: List[str], k: int = 4) -> Optional[BleuMetric]:
return BleuMetric(score)


class FairseqBleuMetric(AverageMetric):
class FairseqBleuMetric(Metric):
"""
Re-implementation of
https://github.com/pytorch/fairseq/blob/master/fairseq/scoring/bleu.py.
"""

def __init__(
self,
pred: Union[torch.Tensor, List[int]],
ref: Union[torch.Tensor, List[int]],
pad_idx: int,
eos_idx: int,
unk_idx: int,
order: int,
):
try:
from fairseq import libbleu
from fairseq.scoring.bleu import BleuStat
import ctypes
except ImportError:
return

self.stat = BleuStat()
self.order = order

C = ctypes.cdll.LoadLibrary(libbleu.__file__)
C.bleu_zero_init(ctypes.byref(self.stat))

if not torch.is_tensor(pred):
pred = torch.LongTensor(pred)
if not torch.is_tensor(ref):
ref = torch.LongTensor(ref)

rref = ref.clone()
assert not rref.lt(0).any()
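# Replace <unk> tokens in the reference with an id that can never appear in the
# prediction, so unknown words are not counted as n-gram matches.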
rref[rref.eq(unk_idx)] = -999

rref = rref.contiguous().view(-1)
pred = pred.contiguous().view(-1)

C.bleu_add(
ctypes.byref(self.stat),
ctypes.c_size_t(rref.size(0)),
ctypes.c_void_p(rref.data_ptr()),
ctypes.c_size_t(pred.size(0)),
ctypes.c_void_p(pred.data_ptr()),
ctypes.c_int(pad_idx),
ctypes.c_int(eos_idx),
)
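# At this point self.stat holds the n-gram match/count statistics computed by
# fairseq's compiled libbleu extension for this (pred, ref) pair.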

@property
def macro_average(self) -> bool:
"""
Indicates whether this metric should be macro-averaged when globally reported.
"""
return True

def __add__(self, other: Optional["FairseqBleuMetric"]) -> "FairseqBleuMetric":
if other is None:
return self
self.stat.match1 += other.stat.match1
self.stat.match2 += other.stat.match2
self.stat.match3 += other.stat.match3
self.stat.match4 += other.stat.match4
self.stat.count1 += other.stat.count1
self.stat.count2 += other.stat.count2
self.stat.count3 += other.stat.count3
self.stat.count4 += other.stat.count4
self.stat.predlen += other.stat.predlen
self.stat.reflen += other.stat.reflen
return self
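
# Note (added for clarity): __add__ pools the raw n-gram match/count statistics,
# so summing per-example metrics and then calling value() yields corpus-level
# BLEU rather than an arithmetic mean of per-example BLEU scores.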

def precision(self):
def ratio(a, b):
return a / b if b > 0 else 0

return [
ratio(self.stat.match1, self.stat.count1),
ratio(self.stat.match2, self.stat.count2),
ratio(self.stat.match3, self.stat.count3),
ratio(self.stat.match4, self.stat.count4),
]

def brevity(self):
r = self.stat.reflen / self.stat.predlen
return min(1, math.exp(1 - r))

def value(self) -> float:
"""
Reimplementation of Fairseq's score.
"""
psum = sum(
math.log(p) if p > 0 else float("-Inf")
for p in self.precision()[: self.order]
)
return self.brevity() * math.exp(psum / self.order) * 100
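# Worked note (added for clarity; not part of the original code): with n-gram
# precisions p1..p_order and brevity penalty BP = min(1, exp(1 - reflen/predlen)),
# this returns BP * exp(sum(log p_i) / order) * 100, i.e. the geometric mean of
# the n-gram precisions scaled to 0-100; any zero precision drives the score to 0.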

@staticmethod
def compute_many(
guess: torch.Tensor, answers: torch.Tensor, pad_idx, end_idx, unk_idx
@@ -467,15 +564,21 @@ def compute_many(
Return BLEU-1..4 using fairseq and tokens.
"""
try:
from fairseq.scoring import bleu as fairseqbleu
from fairseq.scoring import bleu as fairseqbleu # noqa
except ImportError:
return None

scorer = fairseqbleu.Scorer(pad_idx, end_idx, unk_idx)
answers = answers.cpu().int()
guess = guess.cpu().int()
scorer.add(answers, guess)
return [FairseqBleuMetric(scorer.score(i) / 100.0) for i in range(1, 5)]
return [
FairseqBleuMetric(
guess.cpu().int(),
answers.cpu().int(),
pad_idx,
end_idx,
unk_idx,
order=i,
)
for i in range(1, 5)
]


class RougeMetric(AverageMetric):
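A minimal usage sketch of the new metric (illustration only, not part of the diff): it assumes fairseq is installed and uses made-up token ids with pad/eos/unk indices of 0/1/2; compute_many returns None when fairseq is unavailable.

import torch
from parlai.core.metrics import FairseqBleuMetric

guess = torch.LongTensor([5, 6, 7, 8, 1])        # hypothetical 1-D prediction, ending in eos
answers = torch.LongTensor([[5, 6, 9, 8, 1]])    # hypothetical 2-D reference batch (one row)

bleus = FairseqBleuMetric.compute_many(
    guess, answers, pad_idx=0, end_idx=1, unk_idx=2
)
if bleus is not None:  # None means fairseq is not installed
    for order, metric in enumerate(bleus, start=1):
        print(f'BLEU-{order}: {metric.value():.2f}')

Summing the returned metrics with + pools the underlying counts, which is how the new test further down reproduces fairseq's corpus-level score.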
2 changes: 1 addition & 1 deletion parlai/core/torch_generator_agent.py
@@ -801,7 +801,7 @@ def _compute_fairseq_bleu(self, batch: Batch, preds):
assert label_vec is not None, "label_vec must exist for fairseq bleu"
for i, t in enumerate(preds):
result = FairseqBleuMetric.compute_many(
t[1:],
t,
label_vec[i].unsqueeze(0),
pad_idx=self.NULL_IDX,
end_idx=self.END_IDX,
15 changes: 15 additions & 0 deletions parlai/utils/testing.py
@@ -63,6 +63,14 @@
DETECTRON_AVAILABLE = False


try:
import fairseq # noqa: F401

FAIRSEQ_AVAILABLE = True
except ImportError:
FAIRSEQ_AVAILABLE = False


def is_this_circleci():
"""
Return if we are currently running in CircleCI.
@@ -123,6 +131,13 @@ def skipUnlessDetectron(
return unittest.skipUnless(DETECTRON_AVAILABLE, reason)(testfn)


def skipUnlessFairseq(testfn, reason='fairseq not installed'):
"""
Decorate a test to skip unless fairseq is installed.
"""
return unittest.skipUnless(FAIRSEQ_AVAILABLE, reason)(testfn)


class retry(object):
"""
Decorator for flaky tests. Test is run up to ntries times, retrying on failure.
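For reference, the new decorator is used the same way as the existing skipUnless* helpers in this file; a hypothetical example (test names are made up):

import unittest
import parlai.utils.testing as testing_utils

@testing_utils.skipUnlessFairseq
class TestSomethingWithFairseq(unittest.TestCase):
    def test_uses_fairseq(self):
        import fairseq  # noqa: F401  # safe: the whole class is skipped without fairseq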
59 changes: 59 additions & 0 deletions tests/test_metrics.py
@@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

import unittest
from typing import Dict, List
import torch
import random

@@ -20,8 +21,10 @@
aggregate_named_reports,
InterDistinctMetric,
IntraDistinctMetric,
FairseqBleuMetric,
)
from parlai.core.torch_classifier_agent import ConfusionMatrixMetric, WeightedF1Metric
import parlai.utils.testing as testing_utils


class TestMetric(unittest.TestCase):
@@ -439,5 +442,61 @@ def test_intra_distinct(self):
self.assertAlmostEqual(m1 + m2, 3 / 5)


@testing_utils.skipUnlessFairseq
class TestFairseqBleuMetric(unittest.TestCase):
"""
We're just going to compare that scores from Fairseq's Bleu scorer are the same as
our scorer.
"""

def test_scorer(self):
import random

random.seed(42)

vocab_length = num_ex = 100
ex_length = 10
pad_idx = 0
eos_idx = 1
unk_idx = 2

try:
from fairseq.scoring.bleu import Scorer
from fairseq.scoring.bleu import BleuConfig

fairseq_metrics: Scorer = Scorer(
BleuConfig(pad=pad_idx, eos=eos_idx, unk=unk_idx)
)
except ImportError:
# BleuConfig only exists in newer fairseq versions; fall back to the old Scorer signature
fairseq_metrics: Scorer = Scorer(pad_idx, eos_idx, unk_idx)

parlai_metrics: Dict[int, List[FairseqBleuMetric]] = {k: [] for k in range(1, 5)}

for _ in range(num_ex):
guess = torch.LongTensor(random.sample(range(vocab_length), ex_length))
answer = torch.LongTensor(random.sample(range(vocab_length), ex_length))

parlai_bleu = FairseqBleuMetric.compute_many(
guess, answer.unsqueeze(0), pad_idx, eos_idx, unk_idx
)
for i, bleu in enumerate(parlai_bleu):
parlai_metrics[i + 1].append(bleu)
fairseq_metrics.add(answer.int(), guess.int())

parlai_bleus = {}
for k, v in parlai_metrics.items():
total = v[0]
for vv in v[1:]:
total = total + vv
parlai_bleus[k] = total

fairseq_bleus = {k: fairseq_metrics.score(order=k) for k in range(1, 5)}

assert all(
parlai_bleus[k] == fairseq_bleus[k] for k in range(1, 5)
), f'{parlai_bleus}\n{fairseq_bleus}'


if __name__ == '__main__':
unittest.main()
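Assuming fairseq is installed in the environment, the new comparison test can be run on its own with something like:

python -m pytest tests/test_metrics.py::TestFairseqBleuMetric -v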