Unify preds, target input arguments for text metrics [2of2] cer, ter, wer, mer, rouge, squad #727

Merged · 19 commits · Jan 13, 2022
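This PR is the second half of unifying the text-metric call signature: every metric now takes `preds` first and the reference `target`(s) second. A minimal sketch of the resulting call pattern (the `torchmetrics.text.cer` import path is an assumption; the WER path appears in the diff below; printed values are illustrative):

```python
# Sketch of the unified (preds, target) order that these tests exercise.
# Assumes torchmetrics with its text extras installed.
from torchmetrics.text.cer import CharErrorRate
from torchmetrics.text.wer import WordErrorRate

preds = ["hello world"]
target = ["hello duck"]

print(CharErrorRate()(preds, target))  # character error rate as a 0-dim tensor
print(WordErrorRate()(preds, target))  # word error rate as a 0-dim tensor
```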
40 changes: 40 additions & 0 deletions tests/text/inputs.py
@@ -14,6 +14,7 @@
from collections import namedtuple

Input = namedtuple("Input", ["preds", "targets"])
SquadInput = namedtuple("SquadInput", ["preds", "targets", "exact_match", "f1"])

# example taken from
# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.corpus_bleu and adjusted
@@ -64,6 +65,45 @@

_inputs_error_rate_batch_size_2 = Input(**ERROR_RATES_BATCHES_2)

SAMPLE_1 = {
    "exact_match": 100.0,
    "f1": 100.0,
    "preds": {"prediction_text": "1976", "id": "id1"},
    "targets": {"answers": {"answer_start": [97], "text": ["1976"]}, "id": "id1"},
}

SAMPLE_2 = {
    "exact_match": 0.0,
    "f1": 0.0,
    "preds": {"prediction_text": "Hello", "id": "id2"},
    "targets": {"answers": {"answer_start": [97], "text": ["World"]}, "id": "id2"},
}

BATCH = {
    "exact_match": [100.0, 0.0],
    "f1": [100.0, 0.0],
    "preds": [
        {"prediction_text": "1976", "id": "id1"},
        {"prediction_text": "Hello", "id": "id2"},
    ],
    "targets": [
        {"answers": {"answer_start": [97], "text": ["1976"]}, "id": "id1"},
        {"answers": {"answer_start": [97], "text": ["World"]}, "id": "id2"},
    ],
}

_inputs_squad_exact_match = SquadInput(
    preds=SAMPLE_1["preds"], targets=SAMPLE_1["targets"], exact_match=SAMPLE_1["exact_match"], f1=SAMPLE_1["f1"]
)

_inputs_squad_exact_mismatch = SquadInput(
    preds=SAMPLE_2["preds"], targets=SAMPLE_2["targets"], exact_match=SAMPLE_2["exact_match"], f1=SAMPLE_2["f1"]
)

_inputs_squad_batch_match = SquadInput(
    preds=BATCH["preds"], targets=BATCH["targets"], exact_match=BATCH["exact_match"], f1=BATCH["f1"]
)

# single reference
TUPLE_OF_SINGLE_REFERENCES = (((REFERENCE_1A), (REFERENCE_1B)), ((REFERENCE_1B), (REFERENCE_1C)))
_inputs_single_reference = Input(preds=TUPLE_OF_HYPOTHESES, targets=TUPLE_OF_SINGLE_REFERENCES)
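The SQuAD fixtures above use the same dictionary layout the metric consumes: each prediction carries `prediction_text` and `id`, each target an `answers` dict and a matching `id`. A minimal sketch of feeding one fixture pair to the functional metric (assuming the unified `squad(preds, target)` signature this PR settles on; the expected scores are what the fixture's `exact_match`/`f1` fields encode):

```python
# Sketch: scoring a SAMPLE_1-style fixture with the functional SQuAD metric.
from torchmetrics.functional.text import squad

preds = [{"prediction_text": "1976", "id": "id1"}]
target = [{"answers": {"answer_start": [97], "text": ["1976"]}, "id": "id1"}]

scores = squad(preds, target)
# Expected: a dict of tensors, roughly {"exact_match": 100.0, "f1": 100.0} for this exact match.
print(scores["exact_match"], scores["f1"])
```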
4 changes: 2 additions & 2 deletions tests/text/test_cer.py
@@ -15,8 +15,8 @@
compute_measures = Callable


def compare_fn(prediction: Union[str, List[str]], reference: Union[str, List[str]]):
    return cer(reference, prediction)
def compare_fn(preds: Union[str, List[str]], target: Union[str, List[str]]):
    return cer(target, preds)


@pytest.mark.skipif(not _JIWER_AVAILABLE, reason="test requires jiwer")
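Note the argument order in `compare_fn` (and the analogous mer/wer helpers below): jiwer's reference functions take the ground truth first, while the unified torchmetrics API takes `preds` first, so the test helpers swap the arguments when delegating to jiwer. A minimal sketch of the two call orders side by side (the functional alias `char_error_rate` is an assumption, not shown in this diff):

```python
# Sketch: jiwer (truth first) vs. torchmetrics (preds first) for character error rate.
# Requires jiwer; both values should agree up to float/tensor wrapping.
from jiwer import cer

from torchmetrics.functional import char_error_rate

preds = ["hello world"]
target = ["hello duck"]

jiwer_value = cer(target, preds)           # jiwer: ground truth first
tm_value = char_error_rate(preds, target)  # torchmetrics: predictions first
print(jiwer_value, tm_value)
```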
4 changes: 2 additions & 2 deletions tests/text/test_mer.py
@@ -15,8 +15,8 @@
from torchmetrics.text.mer import MatchErrorRate


def _compute_mer_metric_jiwer(prediction: Union[str, List[str]], reference: Union[str, List[str]]):
    return compute_measures(reference, prediction)["mer"]
def _compute_mer_metric_jiwer(preds: Union[str, List[str]], target: Union[str, List[str]]):
    return compute_measures(target, preds)["mer"]


@pytest.mark.skipif(not _JIWER_AVAILABLE, reason="test requires jiwer")
14 changes: 7 additions & 7 deletions tests/text/test_rouge.py
@@ -35,27 +35,27 @@

def _compute_rouge_score(
    preds: Sequence[str],
    targets: Sequence[Sequence[str]],
    target: Sequence[Sequence[str]],
    use_stemmer: bool,
    rouge_level: str,
    metric: str,
    accumulate: str,
):
    """Evaluates rouge scores from rouge-score package for baseline evaluation."""
    if isinstance(targets, list) and all(isinstance(target, str) for target in targets):
        targets = [targets] if isinstance(preds, str) else [[target] for target in targets]
    if isinstance(target, list) and all(isinstance(tgt, str) for tgt in target):
        target = [target] if isinstance(preds, str) else [[tgt] for tgt in target]

    if isinstance(preds, str):
        preds = [preds]

    if isinstance(targets, str):
        targets = [[targets]]
    if isinstance(target, str):
        target = [[target]]

    scorer = RougeScorer(ROUGE_KEYS, use_stemmer=use_stemmer)
    aggregator = BootstrapAggregator()

    for target_raw, pred_raw in zip(targets, preds):
        list_results = [scorer.score(target, pred_raw) for target in target_raw]
    for target_raw, pred_raw in zip(target, preds):
        list_results = [scorer.score(tgt, pred_raw) for tgt in target_raw]
        aggregator_avg = BootstrapAggregator()

        if accumulate == "best":
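`_compute_rouge_score` is only the rouge-score baseline; the torchmetrics side is called with the same unified order. A minimal sketch (the `rouge_score` functional and its result keys are assumptions based on the public API, not shown in this diff):

```python
# Sketch: the torchmetrics ROUGE functional with the unified (preds, target) order.
# use_stemmer=True additionally needs nltk; the result keys shown are illustrative.
from torchmetrics.functional.text.rouge import rouge_score

preds = "My name is John"
target = "Is your name John"

result = rouge_score(preds, target, use_stemmer=True)
# result maps names like "rouge1_fmeasure", "rouge2_precision", "rougeL_recall" to tensors.
print(result["rouge1_fmeasure"])
```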
72 changes: 35 additions & 37 deletions tests/text/test_squad.py
@@ -6,42 +6,26 @@
import torch.multiprocessing as mp

from tests.helpers.testers import _assert_allclose, _assert_tensor
from tests.text.inputs import _inputs_squad_batch_match, _inputs_squad_exact_match, _inputs_squad_exact_mismatch
from torchmetrics.functional.text import squad
from torchmetrics.text.squad import SQuAD

SAMPLE_1 = {
    "exact_match": 100.0,
    "f1": 100.0,
    "predictions": {"prediction_text": "1976", "id": "id1"},
    "references": {"answers": {"answer_start": [97], "text": ["1976"]}, "id": "id1"},
}

SAMPLE_2 = {
    "exact_match": 0.0,
    "f1": 0.0,
    "predictions": {"prediction_text": "Hello", "id": "id2"},
    "references": {"answers": {"answer_start": [97], "text": ["World"]}, "id": "id2"},
}

BATCH = {
    "exact_match": [100.0, 0.0],
    "f1": [100.0, 0.0],
    "predictions": [
        {"prediction_text": "1976", "id": "id1"},
        {"prediction_text": "Hello", "id": "id2"},
    ],
    "references": [
        {"answers": {"answer_start": [97], "text": ["1976"]}, "id": "id1"},
        {"answers": {"answer_start": [97], "text": ["World"]}, "id": "id2"},
    ],
}


@pytest.mark.parametrize(
"preds,targets,exact_match,f1",
"preds, targets, exact_match, f1",
[
(SAMPLE_1["predictions"], SAMPLE_1["references"], SAMPLE_1["exact_match"], SAMPLE_1["exact_match"]),
(SAMPLE_2["predictions"], SAMPLE_2["references"], SAMPLE_2["exact_match"], SAMPLE_2["exact_match"]),
(
_inputs_squad_exact_match.preds,
_inputs_squad_exact_match.targets,
_inputs_squad_exact_match.exact_match,
_inputs_squad_exact_match.f1,
),
(
_inputs_squad_exact_mismatch.preds,
_inputs_squad_exact_mismatch.targets,
_inputs_squad_exact_mismatch.exact_match,
_inputs_squad_exact_mismatch.f1,
),
],
)
def test_score_fn(preds, targets, exact_match, f1):
@@ -54,14 +38,21 @@ def test_score_fn(preds, targets, exact_match, f1):


@pytest.mark.parametrize(
"preds,targets,exact_match,f1",
[(BATCH["predictions"], BATCH["references"], BATCH["exact_match"], BATCH["f1"])],
"preds, targets, exact_match, f1",
[
(
_inputs_squad_batch_match.preds,
_inputs_squad_batch_match.targets,
_inputs_squad_batch_match.exact_match,
_inputs_squad_batch_match.f1,
)
],
)
def test_accumulation(preds, targets, exact_match, f1):
"""Tests for metric works with accumulation."""
squad_metric = SQuAD()
for pred, target in zip(preds, targets):
squad_metric.update(preds=[pred], targets=[target])
squad_metric.update(preds=[pred], target=[target])
metrics_score = squad_metric.compute()

_assert_tensor(metrics_score["exact_match"])
@@ -70,13 +61,13 @@ def test_accumulation(preds, targets, exact_match, f1):
    _assert_allclose(metrics_score["f1"], torch.mean(torch.tensor(f1)))


def _squad_score_ddp(rank, world_size, pred, target, exact_match, f1):
def _squad_score_ddp(rank, world_size, pred, targets, exact_match, f1):
"""Define a DDP process for SQuAD metric."""
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"
dist.init_process_group("gloo", rank=rank, world_size=world_size)
squad_metric = SQuAD()
squad_metric.update(pred, target)
squad_metric.update(pred, targets)
metrics_score = squad_metric.compute()
_assert_tensor(metrics_score["exact_match"])
_assert_tensor(metrics_score["f1"])
@@ -91,8 +82,15 @@ def _test_score_ddp_fn(rank, world_size, preds, targets, exact_match, f1):


@pytest.mark.parametrize(
"preds,targets,exact_match,f1",
[(BATCH["predictions"], BATCH["references"], BATCH["exact_match"], BATCH["f1"])],
"preds, targets, exact_match, f1",
[
(
_inputs_squad_batch_match.preds,
_inputs_squad_batch_match.targets,
_inputs_squad_batch_match.exact_match,
_inputs_squad_batch_match.f1,
)
],
)
@pytest.mark.skipif(not dist.is_available(), reason="test requires torch distributed")
def test_score_ddp(preds, targets, exact_match, f1):
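The body of `test_score_ddp` is collapsed above; tests like it typically launch one process per rank via `torch.multiprocessing.spawn` and run the per-rank helper in each. A self-contained sketch of that spawn pattern (the `_worker` placeholder and `world_size=2` are assumptions, not taken from the hidden part of the diff):

```python
# Sketch: the spawn pattern a DDP metric test typically uses.
# _worker stands in for _test_score_ddp_fn; world_size=2 is an assumption.
import torch.multiprocessing as mp


def _worker(rank: int, world_size: int) -> None:
    # Each rank would set up the process group and update/compute the metric here.
    print(f"rank {rank} of {world_size}")


if __name__ == "__main__":
    world_size = 2
    mp.spawn(_worker, args=(world_size,), nprocs=world_size)
```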
42 changes: 21 additions & 21 deletions tests/text/test_ter.py
@@ -16,7 +16,7 @@

def sacrebleu_ter_fn(
    preds: Sequence[str],
    targets: Sequence[Sequence[str]],
    target: Sequence[Sequence[str]],
    normalized: bool,
    no_punct: bool,
    asian_support: bool,
@@ -26,8 +26,8 @@ def sacrebleu_ter_fn(
        normalized=normalized, no_punct=no_punct, asian_support=asian_support, case_sensitive=case_sensitive
    )
    # Sacrebleu CHRF expects different format of input
    targets = [[target[i] for target in targets] for i in range(len(targets[0]))]
    sacrebleu_ter = sacrebleu_ter.corpus_score(preds, targets).score / 100
    target = [[tgt[i] for tgt in target] for i in range(len(target[0]))]
    sacrebleu_ter = sacrebleu_ter.corpus_score(preds, target).score / 100
    return tensor(sacrebleu_ter)


@@ -118,41 +118,41 @@ def test_chrf_score_differentiability(self, preds, targets, normalize, no_punctu


def test_ter_empty_functional():
    hyp = []
    ref = [[]]
    assert translation_edit_rate(hyp, ref) == tensor(0.0)
    preds = []
    targets = [[]]
    assert translation_edit_rate(preds, targets) == tensor(0.0)


def test_ter_empty_class():
    ter_metric = TranslationEditRate()
    hyp = []
    ref = [[]]
    assert ter_metric(hyp, ref) == tensor(0.0)
    preds = []
    targets = [[]]
    assert ter_metric(preds, targets) == tensor(0.0)


def test_ter_empty_with_non_empty_hyp_functional():
    hyp = ["python"]
    ref = [[]]
    assert translation_edit_rate(hyp, ref) == tensor(0.0)
    preds = ["python"]
    targets = [[]]
    assert translation_edit_rate(preds, targets) == tensor(0.0)


def test_ter_empty_with_non_empty_hyp_class():
    ter_metric = TranslationEditRate()
    hyp = ["python"]
    ref = [[]]
    assert ter_metric(hyp, ref) == tensor(0.0)
    preds = ["python"]
    targets = [[]]
    assert ter_metric(preds, targets) == tensor(0.0)


def test_ter_return_sentence_level_score_functional():
    hyp = _inputs_single_sentence_multiple_references.preds
    ref = _inputs_single_sentence_multiple_references.targets
    _, sentence_ter = translation_edit_rate(hyp, ref, return_sentence_level_score=True)
    preds = _inputs_single_sentence_multiple_references.preds
    targets = _inputs_single_sentence_multiple_references.targets
    _, sentence_ter = translation_edit_rate(preds, targets, return_sentence_level_score=True)
    isinstance(sentence_ter, Tensor)


def test_ter_return_sentence_level_class():
    ter_metric = TranslationEditRate(return_sentence_level_score=True)
    hyp = _inputs_single_sentence_multiple_references.preds
    ref = _inputs_single_sentence_multiple_references.targets
    _, sentence_ter = ter_metric(hyp, ref)
    preds = _inputs_single_sentence_multiple_references.preds
    targets = _inputs_single_sentence_multiple_references.targets
    _, sentence_ter = ter_metric(preds, targets)
    isinstance(sentence_ter, Tensor)
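For reference, the unified call order for TER as the tests above use it: predictions first, a list of reference lists second. A minimal sketch using the top-level functional import (an assumption; the test file's own import is not shown in this diff):

```python
# Sketch: translation_edit_rate with the unified (preds, target) order.
# return_sentence_level_score mirrors the flag exercised by the last two tests.
from torchmetrics.functional import translation_edit_rate

preds = ["the cat is on the mat"]
target = [["there is a cat on the mat", "a cat is on the mat"]]

corpus_ter, sentence_ter = translation_edit_rate(preds, target, return_sentence_level_score=True)
print(corpus_ter, sentence_ter)
```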
4 changes: 2 additions & 2 deletions tests/text/test_wer.py
@@ -15,8 +15,8 @@
from torchmetrics.text.wer import WordErrorRate


def _compute_wer_metric_jiwer(prediction: Union[str, List[str]], reference: Union[str, List[str]]):
    return compute_measures(reference, prediction)["wer"]
def _compute_wer_metric_jiwer(preds: Union[str, List[str]], target: Union[str, List[str]]):
    return compute_measures(target, preds)["wer"]


@pytest.mark.skipif(not _JIWER_AVAILABLE, reason="test requires jiwer")