Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TER #646

Merged
merged 19 commits into from
Dec 6, 2021
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `WordInfoLost` and `WordInfoPreserved` ([#630](https://github.com/PyTorchLightning/metrics/pull/630))
- `SQuAD` ([#623](https://github.com/PyTorchLightning/metrics/pull/623))
- `CHRFScore` ([#641](https://github.com/PyTorchLightning/metrics/pull/641))
- `TER` ([#646](https://github.com/PyTorchLightning/metrics/pull/646))


- Add a default VSCode devcontainer configuration ([#621](https://github.com/PyTorchLightning/metrics/pull/621))
Expand Down
1 change: 1 addition & 0 deletions docs/source/links.rst
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,4 @@
.. _SQuAD Metric: https://arxiv.org/pdf/1606.05250.pdf
.. _chrF score: https://aclanthology.org/W15-3049.pdf
.. _chrF++ score: https://aclanthology.org/W17-4770.pdf
.. _TER: https://aclanthology.org/2006.amta-papers.25.pdf
7 changes: 6 additions & 1 deletion docs/source/references/functional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -471,13 +471,18 @@ sacre_bleu_score [func]
.. autofunction:: torchmetrics.functional.sacre_bleu_score
:noindex:


squad [func]
~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.squad
:noindex:

ter [func]
~~~~~~~~~~

.. autofunction:: torchmetrics.functional.ter
:noindex:

wer [func]
~~~~~~~~~~

Expand Down
7 changes: 6 additions & 1 deletion docs/source/references/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -652,13 +652,18 @@ SacreBLEUScore
.. autoclass:: torchmetrics.SacreBLEUScore
:noindex:


SQuAD
~~~~~

.. autoclass:: torchmetrics.SQuAD
:noindex:

TER
~~~

.. autoclass:: torchmetrics.TER
:noindex:

WER
~~~

Expand Down
185 changes: 185 additions & 0 deletions tests/text/test_ter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
from functools import partial
from typing import Sequence

import pytest
from torch import Tensor, tensor

from tests.text.helpers import INPUT_ORDER, TextTester
from torchmetrics.functional.text.ter import ter
from torchmetrics.text.ter import TER
from torchmetrics.utilities.imports import _SACREBLEU_AVAILABLE

if _SACREBLEU_AVAILABLE:
from sacrebleu.metrics import TER as SacreTER

# example taken from
# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.corpus_bleu and adjusted
# EXAMPLE 1
HYPOTHESIS_A = "It is a guide to action which ensures that the military always obeys the commands of the party"
REFERENCE_1A = "It is a guide to action that ensures that the military will forever heed Party commands"
REFERENCE_2A = "It is a guiding principle which makes the military forces always being under the command of the Party"

# EXAMPLE 2
HYPOTHESIS_B = "he read The Book because he was interested in World history"
REFERENCE_1B = "he was interested in world history because he read the book"
REFERENCE_2B = "It is the practical guide for the army always to heed the directions of the party"

# EXAMPLE 3 (add intentionally whitespaces)
HYPOTHESIS_C = "the cat the cat on the mat "
REFERENCE_1C = "the cat is on the mat "
REFERENCE_2C = "there is a cat on the mat"

# Two update batches, each with two samples; every sample has two references.
# Shape: (num_batches, num_samples_per_batch, num_references_per_sample).
TUPLE_OF_REFERENCES = (
    ((REFERENCE_1A, REFERENCE_2A), (REFERENCE_1B, REFERENCE_2B)),
    ((REFERENCE_1B, REFERENCE_2B), (REFERENCE_1C, REFERENCE_2C)),
)
# Matching hypotheses: (num_batches, num_samples_per_batch).
TUPLE_OF_HYPOTHESES = ((HYPOTHESIS_A, HYPOTHESIS_B), (HYPOTHESIS_B, HYPOTHESIS_C))

BATCHES = {"preds": TUPLE_OF_HYPOTHESES, "targets": TUPLE_OF_REFERENCES}


def sacrebleu_ter_fn(
    targets: Sequence[Sequence[str]],
    preds: Sequence[str],
    normalized: bool,
    no_punct: bool,
    asian_support: bool,
    case_sensitive: bool,
) -> Tensor:
    """Compute the reference TER score with sacrebleu, rescaled to the [0, 1] range.

    Args:
        targets: Reference corpus of shape (num_samples, num_references).
        preds: Hypothesis corpus, one string per sample.
        normalized: Whether to apply sacrebleu's tokenization normalization.
        no_punct: Whether to strip punctuation before scoring.
        asian_support: Whether to enable support for Asian character processing.
        case_sensitive: Whether scoring distinguishes letter case.

    Returns:
        Corpus-level TER score as a 0-dim tensor, divided by 100 to match
        torchmetrics' 0-1 scale (sacrebleu reports percentages).
    """
    sacrebleu_ter = SacreTER(
        normalized=normalized, no_punct=no_punct, asian_support=asian_support, case_sensitive=case_sensitive
    )
    # Sacrebleu TER expects the transposed format: (num_references, num_samples)
    targets = [[target[i] for target in targets] for i in range(len(targets[0]))]
    sacrebleu_ter = sacrebleu_ter.corpus_score(preds, targets).score / 100
    return tensor(sacrebleu_ter)


@pytest.mark.parametrize(
    ["normalize", "no_punctuation", "asian_support", "lowercase"],
    [
        (False, False, False, False),
        (True, False, False, False),
        (False, True, False, False),
        (False, False, True, False),
        (False, False, False, True),
        (True, True, True, True),
    ],
)
@pytest.mark.parametrize(
    ["preds", "targets"],
    [(BATCHES["preds"], BATCHES["targets"])],
)
@pytest.mark.skipif(not _SACREBLEU_AVAILABLE, reason="test requires sacrebleu")
class TestTER(TextTester):
    """Check the torchmetrics TER implementation against sacrebleu's reference TER.

    Each flag combination is exercised for the modular metric (with/without DDP and
    ``dist_sync_on_step``), the functional metric, and differentiability.
    """

    @pytest.mark.parametrize("ddp", [False, True])
    @pytest.mark.parametrize("dist_sync_on_step", [False, True])
    def test_ter_class(
        self, ddp, dist_sync_on_step, preds, targets, normalize, no_punctuation, asian_support, lowercase
    ):
        # Fix: renamed from `test_chrf_score_class` — copy-paste residue from the CHRF tests.
        """Modular `TER` metric must agree with sacrebleu's corpus score."""
        metric_args = {
            "normalize": normalize,
            "no_punctuation": no_punctuation,
            "asian_support": asian_support,
            "lowercase": lowercase,
        }
        # sacrebleu uses slightly different flag names/polarity than torchmetrics.
        sacrebleu_metric = partial(
            sacrebleu_ter_fn,
            normalized=normalize,
            no_punct=no_punctuation,
            asian_support=asian_support,
            case_sensitive=not lowercase,
        )

        self.run_class_metric_test(
            ddp=ddp,
            preds=preds,
            targets=targets,
            metric_class=TER,
            sk_metric=sacrebleu_metric,
            dist_sync_on_step=dist_sync_on_step,
            metric_args=metric_args,
            input_order=INPUT_ORDER.TARGETS_FIRST,
        )

    def test_ter_score_functional(self, preds, targets, normalize, no_punctuation, asian_support, lowercase):
        """Functional `ter` must agree with sacrebleu's corpus score."""
        metric_args = {
            "normalize": normalize,
            "no_punctuation": no_punctuation,
            "asian_support": asian_support,
            "lowercase": lowercase,
        }
        sacrebleu_metric = partial(
            sacrebleu_ter_fn,
            normalized=normalize,
            no_punct=no_punctuation,
            asian_support=asian_support,
            case_sensitive=not lowercase,
        )

        self.run_functional_metric_test(
            preds,
            targets,
            metric_functional=ter,
            sk_metric=sacrebleu_metric,
            metric_args=metric_args,
            input_order=INPUT_ORDER.TARGETS_FIRST,
        )

    def test_ter_differentiability(self, preds, targets, normalize, no_punctuation, asian_support, lowercase):
        # Fix: renamed from `test_chrf_score_differentiability` — copy-paste residue from the CHRF tests.
        """TER must satisfy the differentiability contract of `Metric`."""
        metric_args = {
            "normalize": normalize,
            "no_punctuation": no_punctuation,
            "asian_support": asian_support,
            "lowercase": lowercase,
        }

        self.run_differentiability_test(
            preds=preds,
            targets=targets,
            metric_module=TER,
            metric_functional=ter,
            metric_args=metric_args,
            input_order=INPUT_ORDER.TARGETS_FIRST,
        )


def test_ter_empty_functional():
    """Functional TER over an empty hypothesis/reference corpus is zero."""
    assert ter([[]], []) == tensor(0.0)


def test_ter_empty_class():
    """Modular TER over an empty hypothesis/reference corpus is zero."""
    metric = TER()
    assert metric([[]], []) == tensor(0.0)


def test_ter_empty_with_non_empty_hyp_functional():
    """Functional TER with a hypothesis but no non-empty reference is zero."""
    assert ter([[]], ["python"]) == tensor(0.0)


def test_ter_empty_with_non_empty_hyp_class():
    """Modular TER with a hypothesis but no non-empty reference is zero."""
    metric = TER()
    assert metric([[]], ["python"]) == tensor(0.0)


def test_ter_return_sentence_level_score_functional():
    """Functional `ter` with `return_sentence_level_score=True` yields a tensor of per-sentence scores."""
    hyp = [HYPOTHESIS_B]
    ref = [[REFERENCE_1B, REFERENCE_2B]]
    _, sentence_ter = ter(ref, hyp, return_sentence_level_score=True)
    # Fix: the original called `isinstance(...)` without `assert`, so the check was a no-op.
    assert isinstance(sentence_ter, Tensor)


def test_ter_return_sentence_level_class():
    """Modular `TER(return_sentence_level_score=True)` yields a tensor of per-sentence scores."""
    ter_metric = TER(return_sentence_level_score=True)
    hyp = [HYPOTHESIS_B]
    ref = [[REFERENCE_1B, REFERENCE_2B]]
    _, sentence_ter = ter_metric(ref, hyp)
    # Fix: the original called `isinstance(...)` without `assert`, so the check was a no-op.
    assert isinstance(sentence_ter, Tensor)
2 changes: 2 additions & 0 deletions torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
RetrievalRPrecision,
)
from torchmetrics.text import ( # noqa: E402
TER,
WER,
BLEUScore,
CharErrorRate,
Expand Down Expand Up @@ -144,6 +145,7 @@
"StatScores",
"SumMetric",
"SymmetricMeanAbsolutePercentageError",
"TER",
"WER",
"CharErrorRate",
"MatchErrorRate",
Expand Down
2 changes: 2 additions & 0 deletions torchmetrics/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
from torchmetrics.functional.text.rouge import rouge_score
from torchmetrics.functional.text.sacre_bleu import sacre_bleu_score
from torchmetrics.functional.text.squad import squad
from torchmetrics.functional.text.ter import ter
from torchmetrics.functional.text.wer import wer
from torchmetrics.functional.text.wil import word_information_lost
from torchmetrics.functional.text.wip import word_information_preserved
Expand Down Expand Up @@ -136,6 +137,7 @@
"ssim",
"stat_scores",
"symmetric_mean_absolute_percentage_error",
"ter",
"wer",
"char_error_rate",
"match_error_rate",
Expand Down
1 change: 1 addition & 0 deletions torchmetrics/functional/text/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from torchmetrics.functional.text.mer import match_error_rate # noqa: F401
from torchmetrics.functional.text.sacre_bleu import sacre_bleu_score # noqa: F401
from torchmetrics.functional.text.squad import squad # noqa: F401
from torchmetrics.functional.text.ter import ter # noqa: F401
from torchmetrics.functional.text.wer import wer # noqa: F401
from torchmetrics.functional.text.wil import word_information_lost # noqa: F401
from torchmetrics.functional.text.wip import word_information_preserved # noqa: F401
15 changes: 3 additions & 12 deletions torchmetrics/functional/text/chrf.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@
import torch
from torch import Tensor, tensor

from torchmetrics.functional.text.helper import _validate_inputs

_EPS_SMOOTHING = tensor(1e-16)
# Taken from https://github.com/mjpost/sacrebleu/blob/master/sacrebleu/metrics/chrf.py
_PUNCTUATIONS = set("!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~")
Expand Down Expand Up @@ -485,18 +487,7 @@ def _chrf_score_update(
ValueError:
If length of `reference_corpus` and `hypothesis_corpus` differs.
"""
if isinstance(hypothesis_corpus, str):
hypothesis_corpus = [hypothesis_corpus]

# Ensure reference corpus is properly of a type Sequence[Sequence[str]]
if all(isinstance(ref, str) for ref in reference_corpus):
if len(hypothesis_corpus) == 1:
reference_corpus = [reference_corpus] # type: ignore
else:
reference_corpus = [[ref] for ref in reference_corpus] # type: ignore

if hypothesis_corpus and all(ref for ref in reference_corpus) and len(reference_corpus) != len(hypothesis_corpus):
raise ValueError(f"Corpus has different size {len(reference_corpus)} != {len(hypothesis_corpus)}")
reference_corpus, hypothesis_corpus = _validate_inputs(reference_corpus, hypothesis_corpus)

for (references, hypothesis) in zip(reference_corpus, hypothesis_corpus):
(
Expand Down
Loading