fix: correct reference length calculation (#195)
Summary:
This PR fixes how the brevity penalty (specifically, the effective reference corpus length) is calculated in BLEU.

Previously, `len_reference` was calculated as `min([len(ref) for ref in references_tokenized])`. However, this is incorrect: according to the BLEU paper, the effective reference length should be the "best match length" (the reference length closest to the candidate's length), not the minimum reference length.

For more information, see [wikipedia - brevity penalty](https://en.wikipedia.org/wiki/BLEU#Brevity_penalty) and [nltk implementation](https://www.nltk.org/_modules/nltk/translate/bleu_score.html#closest_ref_length).
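For illustration, the "best match length" selection (with ties broken toward the shorter reference, as in the NLTK implementation) can be sketched as follows; the helper name is hypothetical:

```python
def closest_ref_length(references_tokenized: list[list[str]], len_candidate: int) -> int:
    """Pick the reference length closest to the candidate length.

    Ties are broken toward the shorter reference, mirroring
    nltk.translate.bleu_score.closest_ref_length.
    """
    return min(
        (len(ref) for ref in references_tokenized),
        key=lambda ref_len: (abs(ref_len - len_candidate), ref_len),
    )
```

For example, with reference lengths 6, 7, and 2 and a candidate of length 5, the old `min(...)` would yield 2, while the best match length is 6.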

Pull Request resolved: #195

Test Plan: Added another unit test to `test_bleu.py` and compared its results against `nltk.translate.bleu_score.corpus_bleu` to verify that the implementation is correct.

Reviewed By: galrotem

Differential Revision: D56846091

Pulled By: JKSenthil

fbshipit-source-id: 2bf1cd0ba169535a118222e60f4264259248f1fd
yuxqiu authored and facebook-github-bot committed May 14, 2024
1 parent cb6bc39 commit ea813d3
Showing 2 changed files with 33 additions and 1 deletion.
29 changes: 29 additions & 0 deletions tests/metrics/text/test_bleu.py
@@ -107,3 +107,32 @@ def test_bleu_multiple_examples_per_update(self) -> None:
             num_total_updates=2,
             num_processes=2,
         )
+
+    def test_bleu_brevity(self) -> None:
+        candidates = [["the squirrel is eating the nut"], ["the cat is on mat"]]
+        references = [
+            [
+                [
+                    "a squirrel is eating a nut",
+                    "the squirrel is eating a tasty nut",
+                    "hi",
+                ]
+            ],
+            [["there is a cat on the mat", "a cat is on the mat"]],
+        ]
+        self.run_class_implementation_tests(
+            metric=BLEUScore(n_gram=4),
+            state_names={
+                "input_len",
+                "target_len",
+                "matches_by_order",
+                "possible_matches_by_order",
+            },
+            update_kwargs={
+                "input": candidates,
+                "target": references,
+            },
+            compute_result=torch.tensor(0.41650065, dtype=torch.float64),
+            num_total_updates=2,
+            num_processes=2,
+        )
5 changes: 4 additions & 1 deletion torcheval/metrics/functional/text/bleu.py
@@ -88,7 +88,10 @@ def _bleu_score_update(
     references_tokenized = [ref.split() for ref in references]

     len_candidate = len(candidate_tokenized)
-    len_reference = min([len(ref) for ref in references_tokenized])
+    len_reference = min(
+        [len(ref) for ref in references_tokenized],
+        key=lambda ref_len: (abs(ref_len - len_candidate), ref_len),
+    )
     input_len += len_candidate
     target_len += len_reference
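The corrected `target_len` feeds the brevity penalty, which per the BLEU paper is 1 when the candidate corpus is longer than the effective reference length, and exp(1 - r/c) otherwise. A minimal sketch, reusing the `input_len`/`target_len` names from the diff (the function itself is hypothetical, not part of torcheval's public API):

```python
import math

def brevity_penalty(input_len: int, target_len: int) -> float:
    # c = total candidate length, r = effective reference corpus length.
    # BP = 1 if c > r, else exp(1 - r / c)   (Papineni et al., 2002)
    if input_len > target_len:
        return 1.0
    if input_len == 0:
        return 0.0  # guard against division by zero on empty input
    return math.exp(1.0 - target_len / input_len)
```

With the old minimum-length behavior, `target_len` was systematically too small, so the penalty fired less often than it should and BLEU was inflated for short candidates.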
