Unify the input order for text (NLG) metrics - BLEU, SacreBLEU, TER, CHRF #696

Merged · 11 commits · Jan 4, 2022
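In short, after this PR every text metric takes predictions first and targets second. A minimal usage sketch with the functional BLEU API (inputs taken from the updated docstring example in the diff below; the same order now applies to `sacre_bleu_score`, `chrf_score`, and `ter`):

```python
from torchmetrics.functional import bleu_score

preds = ["the cat is on the mat"]
targets = [["there is a cat on the mat", "a cat is on the mat"]]

# Unified order: predictions first, targets second.
print(bleu_score(preds, targets))  # tensor(0.7598)
```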
5 changes: 4 additions & 1 deletion CHANGELOG.md
@@ -42,7 +42,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Metrics having third party dependencies removed from global import ([#463](https://github.com/PyTorchLightning/metrics/pull/463))


- `BLEUScore` now expects untokenized input to stay consistent with all the other text metrics ([#640](https://github.com/PyTorchLightning/metrics/pull/640))
- Untokenized input for `BLEUScore` stays consistent with all the other text metrics ([#640](https://github.com/PyTorchLightning/metrics/pull/640))


- Reordered arguments for `TER`, `BLEUScore`, `SacreBLEUScore`, `CHRFScore`: these metrics now expect predictions first and targets second ([#696](https://github.com/PyTorchLightning/metrics/pull/696))


- Renamed `torchmetrics.collections` to `torchmetrics.metrics_collections` to avoid clashing with system's `collections` package ([#695](https://github.com/PyTorchLightning/metrics/pull/695))
16 changes: 8 additions & 8 deletions tests/text/test_bleu.py
@@ -27,7 +27,7 @@
smooth_func = SmoothingFunction().method2


def _compute_bleu_metric_nltk(list_of_references, hypotheses, weights, smoothing_function, **kwargs):
def _compute_bleu_metric_nltk(hypotheses, list_of_references, weights, smoothing_function, **kwargs):
hypotheses_ = [hypothesis.split() for hypothesis in hypotheses]
list_of_references_ = [[line.split() for line in ref] for ref in list_of_references]
return corpus_bleu(
@@ -67,7 +67,7 @@ def test_bleu_score_class(self, ddp, dist_sync_on_step, preds, targets, weights,
sk_metric=compute_bleu_metric_nltk,
dist_sync_on_step=dist_sync_on_step,
metric_args=metric_args,
input_order=INPUT_ORDER.TARGETS_FIRST,
input_order=INPUT_ORDER.PREDS_FIRST,
)

def test_bleu_score_functional(self, preds, targets, weights, n_gram, smooth_func, smooth):
@@ -80,7 +80,7 @@ def test_bleu_score_functional(self, preds, targets, weights, n_gram, smooth_fun
metric_functional=bleu_score,
sk_metric=compute_bleu_metric_nltk,
metric_args=metric_args,
input_order=INPUT_ORDER.TARGETS_FIRST,
input_order=INPUT_ORDER.PREDS_FIRST,
)

def test_bleu_score_differentiability(self, preds, targets, weights, n_gram, smooth_func, smooth):
@@ -92,31 +92,31 @@ def test_bleu_score_differentiability(self, preds, targets, weights, n_gram, smo
metric_module=BLEUScore,
metric_functional=bleu_score,
metric_args=metric_args,
input_order=INPUT_ORDER.TARGETS_FIRST,
input_order=INPUT_ORDER.PREDS_FIRST,
)


def test_bleu_empty_functional():
hyp = [[]]
ref = [[[]]]
assert bleu_score(ref, hyp) == tensor(0.0)
assert bleu_score(hyp, ref) == tensor(0.0)


def test_no_4_gram_functional():
hyps = ["My full pytorch-lightning"]
refs = [["My full pytorch-lightning test", "Completely Different"]]
assert bleu_score(refs, hyps) == tensor(0.0)
assert bleu_score(hyps, refs) == tensor(0.0)


def test_bleu_empty_class():
bleu = BLEUScore()
hyp = [[]]
ref = [[[]]]
assert bleu(ref, hyp) == tensor(0.0)
assert bleu(hyp, ref) == tensor(0.0)


def test_no_4_gram_class():
bleu = BLEUScore()
hyps = ["My full pytorch-lightning"]
refs = [["My full pytorch-lightning test", "Completely Different"]]
assert bleu(refs, hyps) == tensor(0.0)
assert bleu(hyps, refs) == tensor(0.0)
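For context, `INPUT_ORDER` comes from the repository's test helpers and tells the shared tester which argument order to use when calling the metric and the reference implementation. A minimal sketch of the idea (hypothetical names, not the actual helper code):

```python
from enum import Enum

class INPUT_ORDER(Enum):
    PREDS_FIRST = "preds_first"
    TARGETS_FIRST = "targets_first"

def call_in_order(metric_fn, preds, targets, input_order):
    # Dispatch the argument order so one test body can exercise either convention.
    if input_order is INPUT_ORDER.PREDS_FIRST:
        return metric_fn(preds, targets)
    return metric_fn(targets, preds)
```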
16 changes: 8 additions & 8 deletions tests/text/test_chrf.py
@@ -15,8 +15,8 @@


def sacrebleu_chrf_fn(
targets: Sequence[Sequence[str]],
preds: Sequence[str],
targets: Sequence[Sequence[str]],
char_order: int,
word_order: int,
lowercase: bool,
@@ -71,7 +71,7 @@ def test_chrf_score_class(
sk_metric=nltk_metric,
dist_sync_on_step=dist_sync_on_step,
metric_args=metric_args,
input_order=INPUT_ORDER.TARGETS_FIRST,
input_order=INPUT_ORDER.PREDS_FIRST,
)

def test_chrf_score_functional(self, preds, targets, char_order, word_order, lowercase, whitespace):
@@ -91,7 +91,7 @@ def test_chrf_score_functional(self, preds, targets, char_order, word_order, low
metric_functional=chrf_score,
sk_metric=nltk_metric,
metric_args=metric_args,
input_order=INPUT_ORDER.TARGETS_FIRST,
input_order=INPUT_ORDER.PREDS_FIRST,
)

def test_chrf_score_differentiability(self, preds, targets, char_order, word_order, lowercase, whitespace):
@@ -108,33 +108,33 @@ def test_chrf_score_differentiability(self, preds, targets, char_order, word_ord
metric_module=CHRFScore,
metric_functional=chrf_score,
metric_args=metric_args,
input_order=INPUT_ORDER.TARGETS_FIRST,
input_order=INPUT_ORDER.PREDS_FIRST,
)


def test_chrf_empty_functional():
hyp = []
ref = [[]]
assert chrf_score(ref, hyp) == tensor(0.0)
assert chrf_score(hyp, ref) == tensor(0.0)


def test_chrf_empty_class():
chrf = CHRFScore()
hyp = []
ref = [[]]
assert chrf(ref, hyp) == tensor(0.0)
assert chrf(hyp, ref) == tensor(0.0)


def test_chrf_return_sentence_level_score_functional():
hyp = _inputs_single_sentence_multiple_references.preds
ref = _inputs_single_sentence_multiple_references.targets
_, chrf_sentence_score = chrf_score(ref, hyp, return_sentence_level_score=True)
_, chrf_sentence_score = chrf_score(hyp, ref, return_sentence_level_score=True)
assert isinstance(chrf_sentence_score, Tensor)


def test_chrf_return_sentence_level_class():
chrf = CHRFScore(return_sentence_level_score=True)
hyp = _inputs_single_sentence_multiple_references.preds
ref = _inputs_single_sentence_multiple_references.targets
_, chrf_sentence_score = chrf(ref, hyp)
_, chrf_sentence_score = chrf(hyp, ref)
assert isinstance(chrf_sentence_score, Tensor)
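A usage sketch of the sentence-level API exercised above, under the new argument order (inputs illustrative):

```python
from torchmetrics.functional import chrf_score

preds = ["the cat is on the mat"]
targets = [["there is a cat on the mat", "a cat is on the mat"]]

# With return_sentence_level_score=True the functional returns a pair:
# the corpus-level score and a tensor of per-sentence scores.
corpus_chrf, sentence_chrf = chrf_score(preds, targets, return_sentence_level_score=True)
```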
8 changes: 4 additions & 4 deletions tests/text/test_sacre_bleu.py
@@ -31,7 +31,7 @@
TOKENIZERS = ("none", "13a", "zh", "intl", "char")


def sacrebleu_fn(targets: Sequence[Sequence[str]], preds: Sequence[str], tokenize: str, lowercase: bool) -> Tensor:
def sacrebleu_fn(preds: Sequence[str], targets: Sequence[Sequence[str]], tokenize: str, lowercase: bool) -> Tensor:
sacrebleu_fn = BLEU(tokenize=tokenize, lowercase=lowercase)
    # sacrebleu expects targets transposed: one list per reference slot, spanning all predictions (see the sketch after this file)
targets = [[target[i] for target in targets] for i in range(len(targets[0]))]
@@ -61,7 +61,7 @@ def test_bleu_score_class(self, ddp, dist_sync_on_step, preds, targets, tokenize
sk_metric=original_sacrebleu,
dist_sync_on_step=dist_sync_on_step,
metric_args=metric_args,
input_order=INPUT_ORDER.TARGETS_FIRST,
input_order=INPUT_ORDER.PREDS_FIRST,
)

def test_bleu_score_functional(self, preds, targets, tokenize, lowercase):
@@ -74,7 +74,7 @@ def test_bleu_score_functional(self, preds, targets, tokenize, lowercase):
metric_functional=sacre_bleu_score,
sk_metric=original_sacrebleu,
metric_args=metric_args,
input_order=INPUT_ORDER.TARGETS_FIRST,
input_order=INPUT_ORDER.PREDS_FIRST,
)

def test_bleu_score_differentiability(self, preds, targets, tokenize, lowercase):
@@ -86,5 +86,5 @@ def test_bleu_score_differentiability(self, preds, targets, tokenize, lowercase)
metric_module=SacreBLEUScore,
metric_functional=sacre_bleu_score,
metric_args=metric_args,
input_order=INPUT_ORDER.TARGETS_FIRST,
input_order=INPUT_ORDER.PREDS_FIRST,
)
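The transposition inside `sacrebleu_fn` above deserves a spelled-out example, since it is easy to misread. A minimal sketch with concrete data (two predictions, two references each):

```python
# torchmetrics layout: targets[i] is the list of references for prediction i.
targets = [["ref 1a", "ref 1b"], ["ref 2a", "ref 2b"]]

# sacrebleu's corpus_score expects the transpose: one stream per reference slot,
# where transposed[j][i] is the j-th reference for the i-th prediction.
transposed = [[target[i] for target in targets] for i in range(len(targets[0]))]
assert transposed == [["ref 1a", "ref 2a"], ["ref 1b", "ref 2b"]]
```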
20 changes: 10 additions & 10 deletions tests/text/test_ter.py
@@ -15,8 +15,8 @@


def sacrebleu_ter_fn(
targets: Sequence[Sequence[str]],
preds: Sequence[str],
targets: Sequence[Sequence[str]],
normalized: bool,
no_punct: bool,
asian_support: bool,
@@ -75,7 +75,7 @@ def test_chrf_score_class(
sk_metric=nltk_metric,
dist_sync_on_step=dist_sync_on_step,
metric_args=metric_args,
input_order=INPUT_ORDER.TARGETS_FIRST,
input_order=INPUT_ORDER.PREDS_FIRST,
)

def test_ter_score_functional(self, preds, targets, normalize, no_punctuation, asian_support, lowercase):
@@ -99,7 +99,7 @@ def test_ter_score_functional(self, preds, targets, normalize, no_punctuation, a
metric_functional=ter,
sk_metric=nltk_metric,
metric_args=metric_args,
input_order=INPUT_ORDER.TARGETS_FIRST,
input_order=INPUT_ORDER.PREDS_FIRST,
)

def test_chrf_score_differentiability(self, preds, targets, normalize, no_punctuation, asian_support, lowercase):
@@ -116,46 +116,46 @@ def test_chrf_score_differentiability(self, preds, targets, normalize, no_punctu
metric_module=TER,
metric_functional=ter,
metric_args=metric_args,
input_order=INPUT_ORDER.TARGETS_FIRST,
input_order=INPUT_ORDER.PREDS_FIRST,
)


def test_ter_empty_functional():
hyp = []
ref = [[]]
assert ter(ref, hyp) == tensor(0.0)
assert ter(hyp, ref) == tensor(0.0)


def test_ter_empty_class():
ter_metric = TER()
hyp = []
ref = [[]]
assert ter_metric(ref, hyp) == tensor(0.0)
assert ter_metric(hyp, ref) == tensor(0.0)


def test_ter_empty_with_non_empty_hyp_functional():
hyp = ["python"]
ref = [[]]
assert ter(ref, hyp) == tensor(0.0)
assert ter(hyp, ref) == tensor(0.0)


def test_ter_empty_with_non_empty_hyp_class():
ter_metric = TER()
hyp = ["python"]
ref = [[]]
assert ter_metric(ref, hyp) == tensor(0.0)
assert ter_metric(hyp, ref) == tensor(0.0)


def test_ter_return_sentence_level_score_functional():
hyp = _inputs_single_sentence_multiple_references.preds
ref = _inputs_single_sentence_multiple_references.targets
_, sentence_ter = ter(ref, hyp, return_sentence_level_score=True)
_, sentence_ter = ter(hyp, ref, return_sentence_level_score=True)
assert isinstance(sentence_ter, Tensor)


def test_ter_return_sentence_level_class():
ter_metric = TER(return_sentence_level_score=True)
hyp = _inputs_single_sentence_multiple_references.preds
ref = _inputs_single_sentence_multiple_references.targets
_, sentence_ter = ter_metric(ref, hyp)
_, sentence_ter = ter_metric(hyp, ref)
assert isinstance(sentence_ter, Tensor)
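As with CHRF, a usage sketch of TER under the new order (assuming the functional import used by these tests; inputs illustrative):

```python
from torchmetrics.functional import ter

preds = ["the cat is on the mat"]
targets = [["there is a cat on the mat", "a cat is on the mat"]]

# Corpus-level TER, predictions first.
score = ter(preds, targets)

# Optionally also return per-sentence scores.
corpus_ter, sentence_ter = ter(preds, targets, return_sentence_level_score=True)
```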
28 changes: 19 additions & 9 deletions torchmetrics/functional/text/bleu.py
@@ -16,6 +16,7 @@
# Authors: torchtext authors and @sluks
# Date: 2020-07-18
# Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score
import warnings
from collections import Counter
from typing import Callable, Sequence, Tuple, Union

@@ -57,8 +58,8 @@ def _tokenize_fn(sentence: str) -> Sequence[str]:


def _bleu_score_update(
reference_corpus: Sequence[Sequence[str]],
translate_corpus: Sequence[str],
reference_corpus: Sequence[Sequence[str]],
numerator: Tensor,
denominator: Tensor,
trans_len: Tensor,
@@ -69,11 +70,11 @@ def _bleu_score_update(
"""Updates and returns variables required to compute the BLEU score.

Args:
reference_corpus: An iterable of iterables of reference corpus
translate_corpus: An iterable of machine translated corpus
reference_corpus: An iterable of iterables of reference corpus
numerator: Numerator of precision score (true positives)
denominator: Denominator of precision score (true positives + false positives)
trans_len: count of words in a candidate translation
trans_len: count of words in a candidate prediction
ref_len: count of words in a reference translation
n_gram: gram value ranged 1 to 4
tokenizer: A function that turns sentence into list of words
@@ -106,7 +107,12 @@


def _bleu_score_compute(
trans_len: Tensor, ref_len: Tensor, numerator: Tensor, denominator: Tensor, n_gram: int = 4, smooth: bool = False
trans_len: Tensor,
ref_len: Tensor,
numerator: Tensor,
denominator: Tensor,
n_gram: int = 4,
smooth: bool = False,
) -> Tensor:
"""Computes the BLEU score.

@@ -140,18 +146,18 @@


def bleu_score(
reference_corpus: Sequence[Union[str, Sequence[str]]],
translate_corpus: Union[str, Sequence[str]],
reference_corpus: Sequence[Union[str, Sequence[str]]],
n_gram: int = 4,
smooth: bool = False,
) -> Tensor:
"""Calculate `BLEU score`_ of machine translated text with one or more references.

Args:
reference_corpus:
An iterable of iterables of reference corpus
translate_corpus:
An iterable of machine translated corpus
reference_corpus:
An iterable of iterables of reference corpus
n_gram:
Gram value ranged from 1 to 4 (Default 4)
smooth:
@@ -164,7 +170,7 @@ def bleu_score(
>>> from torchmetrics.functional import bleu_score
>>> translate_corpus = ['the cat is on the mat']
>>> reference_corpus = [['there is a cat on the mat', 'a cat is on the mat']]
>>> bleu_score(reference_corpus, translate_corpus)
>>> bleu_score(translate_corpus, reference_corpus)
tensor(0.7598)

References:
Expand All @@ -174,6 +180,10 @@ def bleu_score(
[2] Automatic Evaluation of Machine Translation Quality Using Longest Common Subsequence
and Skip-Bigram Statistics by Chin-Yew Lin and Franz Josef Och `Machine Translation Evolution`_
"""
warnings.warn(
"Input order of targets and preds were changed to predictions firsts and targets second in v0.7."
" Warning will be removed in v0.8."
)
translate_corpus_ = [translate_corpus] if isinstance(translate_corpus, str) else translate_corpus
reference_corpus_ = [
[reference_text] if isinstance(reference_text, str) else reference_text for reference_text in reference_corpus
@@ -188,7 +198,7 @@
ref_len = tensor(0, dtype=torch.float)

trans_len, ref_len = _bleu_score_update(
reference_corpus_, translate_corpus_, numerator, denominator, trans_len, ref_len, n_gram, _tokenize_fn
translate_corpus_, reference_corpus_, numerator, denominator, trans_len, ref_len, n_gram, _tokenize_fn
)

return _bleu_score_compute(trans_len, ref_len, numerator, denominator, n_gram, smooth)
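To see where the `0.7598` in the docstring example comes from, here is the BLEU-4 arithmetic by hand (no smoothing; the brevity penalty is 1 because the closest reference length, 6, equals the prediction length):

```python
import math

# Clipped n-gram precisions of "the cat is on the mat" against the two references:
# unigrams 5/6 ("the" occurs twice but clips to 1), bigrams 4/5, trigrams 3/4, 4-grams 2/3.
precisions = [5 / 6, 4 / 5, 3 / 4, 2 / 3]

# Geometric mean with uniform weights 1/4, times brevity penalty 1.
bleu = math.exp(sum(0.25 * math.log(p) for p in precisions))
print(f"{bleu:.4f}")  # 0.7598
```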