Add torchmetrics' own implementation of Rouge score metrics #443

Merged Aug 17, 2021: 43 commits merged from the own_rouge branch into master.
The diff below shows changes from 35 of the 43 commits.

Commits (43)
6079376
Apply some changes to function/text/rouge.py
stancld Aug 11, 2021
b699140
Make Rouge-N working
stancld Aug 11, 2021
e6f604b
Add RougeL score calculation
stancld Aug 11, 2021
76e8581
Add some docs + enable using Porter stemmer
stancld Aug 11, 2021
b49f0cf
Enable RougeLSum calculation
stancld Aug 11, 2021
271e403
Add a few references and clean some parts
stancld Aug 11, 2021
41b6530
Fix flake8 issues
stancld Aug 11, 2021
10de32f
[WIP] Update tests (need to fix) + clean some unnecessary code
stancld Aug 11, 2021
6a9a77c
Fix some tests
stancld Aug 11, 2021
f171f1c
Fix a typo
stancld Aug 11, 2021
7c0fc40
Fix some remaining issues
stancld Aug 11, 2021
8ed0699
Return decimal_places argument to ROUGEScore and prepare depreciation…
stancld Aug 11, 2021
aacd0da
Fix some minor issues after the morning review
stancld Aug 12, 2021
542e525
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 12, 2021
393198e
Fix some flake8 and mypy issues
stancld Aug 12, 2021
d1a800d
Fix some issues based on test results
stancld Aug 12, 2021
88d034b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 12, 2021
52da686
Use 0 in (x,y) instead of x == 0 or y == 0
stancld Aug 12, 2021
a78e33c
* Fix test issues and _rouge_score_update method
stancld Aug 12, 2021
68ee7f2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 12, 2021
ba905d6
Fix docstring for _rouge_score_update
stancld Aug 12, 2021
caf60d4
Another fix for doc
stancld Aug 12, 2021
fad5b81
Replace dangerous default dict() values
stancld Aug 12, 2021
e86e3ba
Replace _RougeScore class with dict
stancld Aug 12, 2021
c1b2506
Use import nltk only conditionally when needed
stancld Aug 12, 2021
56e1bb8
Update tests using rouge-score package to be more generic
stancld Aug 12, 2021
22e2fee
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 12, 2021
a67b0c6
Make some style changes to test_rouge.py
stancld Aug 12, 2021
74feff6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 12, 2021
514833e
Fix a styling typo
stancld Aug 12, 2021
9f326d0
Update error messages
stancld Aug 12, 2021
8a40825
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 12, 2021
79148d0
Add condition for importing nltk in ROUGEScore class
stancld Aug 13, 2021
b0c74ec
Apply suggestions from code review
Borda Aug 15, 2021
b2ff757
Merge branch 'master' into own_rouge
Borda Aug 15, 2021
acefe3d
Apply suggestions from code review
Borda Aug 16, 2021
2fc348e
fix warn
Borda Aug 16, 2021
eb30b5e
Merge branch 'master' into own_rouge
Borda Aug 16, 2021
28a142f
warn
SkafteNicki Aug 16, 2021
cae0c95
Update CHANGELOG.md
Borda Aug 16, 2021
b5ec3b8
Merge branch 'master' into own_rouge
mergify[bot] Aug 17, 2021
9b9aacc
Merge branch 'master' into own_rouge
mergify[bot] Aug 17, 2021
dcc2898
Merge branch 'master' into own_rouge
mergify[bot] Aug 17, 2021
3 changes: 3 additions & 0 deletions requirements/test.txt
@@ -25,3 +25,6 @@ mir_eval>=0.6
#pesq @ https://github.com/ludlows/python-pesq/archive/refs/heads/master.zip
#SRMRpy @ https://github.com/jfsantos/SRMRpy/archive/refs/heads/master.zip
speechmetrics @ https://github.com/aliutkus/speechmetrics/archive/refs/heads/master.zip

# text
rouge-score>=0.0.4
1 change: 0 additions & 1 deletion requirements/text.txt
@@ -1,4 +1,3 @@
jiwer>=2.2.0
nltk>=3.6
rouge-score>=0.0.4
bert-score==0.3.10
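Taken together, the two requirement changes above make rouge-score a test-only reference dependency, while nltk stays the runtime text dependency and, per the commit messages, is imported only when actually needed. Below is a minimal, hypothetical sketch of such a lazy-import guard; the flag computation and the _build_stemmer helper are illustrative and not the actual torchmetrics code, though the error message matches the one asserted in the tests further down.

# Hypothetical sketch of a lazy nltk import guard, in the spirit of the
# _NLTK_AVAILABLE flag used by the tests below. Helper names are illustrative,
# not the actual torchmetrics implementation.
import importlib.util

_NLTK_AVAILABLE = importlib.util.find_spec("nltk") is not None


def _build_stemmer(use_stemmer: bool):
    """Return a Porter stemmer only when requested and nltk is installed."""
    if not use_stemmer:
        return None
    if not _NLTK_AVAILABLE:
        raise ValueError(
            "ROUGE metric requires that nltk is installed."
            " Either as `pip install torchmetrics[text]` or `pip install nltk`"
        )
    from nltk.stem.porter import PorterStemmer  # lazy import keeps nltk optional

    return PorterStemmer()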
188 changes: 89 additions & 99 deletions tests/text/test_rouge.py
@@ -16,7 +16,6 @@

import pytest
import torch
from torch import tensor

from torchmetrics.functional.text.rouge import rouge_score
from torchmetrics.text.rouge import ROUGEScore
@@ -30,16 +29,13 @@

ROUGE_KEYS = ("rouge1", "rouge2", "rougeL", "rougeLsum")

PRECISION = 0
RECALL = 1
F_MEASURE = 2

SINGLE_SENTENCE_EXAMPLE_PREDS = "The quick brown fox jumps over the lazy dog"
SINGLE_SENTENCE_EXAMPLE_TARGET = "The quick brown dog jumps on the log."

PREDS = "My name is John".split()
TARGETS = "Is your name John".split()


BATCHES_RS_PREDS = [SINGLE_SENTENCE_EXAMPLE_PREDS]
BATCHES_RS_PREDS.extend(PREDS)
BATCHES_RS_TARGETS = [SINGLE_SENTENCE_EXAMPLE_TARGET]
@@ -55,145 +51,139 @@ def _compute_rouge_score(preds: List[str], targets: List[str], use_stemmer: bool
scorer = RougeScorer(ROUGE_KEYS, use_stemmer=use_stemmer)
aggregator = BootstrapAggregator()
for pred, target in zip(preds, targets):
aggregator.add_scores(scorer.score(pred, target))
aggregator.add_scores(scorer.score(target, pred))
return aggregator.aggregate()


@pytest.mark.skipif(not (_NLTK_AVAILABLE or _ROUGE_SCORE_AVAILABLE), reason="test requires nltk and rouge-score")
@pytest.mark.skipif(not _NLTK_AVAILABLE, reason="test requires nltk")
@pytest.mark.parametrize(
["pl_rouge_metric_key", "rouge_score_key", "metric", "decimal_places", "use_stemmer", "newline_sep"],
["pl_rouge_metric_key", "use_stemmer"],
[
pytest.param("rouge1_precision", "rouge1", PRECISION, 1, True, True),
pytest.param("rouge1_recall", "rouge1", RECALL, 2, True, False),
pytest.param("rouge1_fmeasure", "rouge1", F_MEASURE, 3, False, True),
pytest.param("rouge2_precision", "rouge2", PRECISION, 4, False, False),
pytest.param("rouge2_recall", "rouge2", RECALL, 5, True, True),
pytest.param("rouge2_fmeasure", "rouge2", F_MEASURE, 6, True, False),
pytest.param("rougeL_precision", "rougeL", PRECISION, 6, False, True),
pytest.param("rougeL_recall", "rougeL", RECALL, 5, False, False),
pytest.param("rougeL_fmeasure", "rougeL", F_MEASURE, 3, True, True),
pytest.param("rougeLsum_precision", "rougeLsum", PRECISION, 2, True, False),
pytest.param("rougeLsum_recall", "rougeLsum", RECALL, 1, False, True),
pytest.param("rougeLsum_fmeasure", "rougeLsum", F_MEASURE, 8, False, False),
pytest.param("rouge1_precision", True),
pytest.param("rouge1_recall", True),
pytest.param("rouge1_fmeasure", False),
pytest.param("rouge2_precision", False),
pytest.param("rouge2_recall", True),
pytest.param("rouge2_fmeasure", True),
pytest.param("rougeL_precision", False),
pytest.param("rougeL_recall", False),
pytest.param("rougeL_fmeasure", True),
pytest.param("rougeLsum_precision", True),
pytest.param("rougeLsum_recall", False),
pytest.param("rougeLsum_fmeasure", False),
],
)
def test_rouge_metric_functional_single_sentence(
pl_rouge_metric_key, rouge_score_key, metric, decimal_places, use_stemmer, newline_sep
):
scorer = RougeScorer(ROUGE_KEYS)
rs_scores = scorer.score(SINGLE_SENTENCE_EXAMPLE_PREDS, SINGLE_SENTENCE_EXAMPLE_TARGET)
rs_output = round(rs_scores[rouge_score_key][metric], decimal_places)
def test_rouge_metric_functional_single_sentence(pl_rouge_metric_key, use_stemmer):
rouge_level, metric = pl_rouge_metric_key.split("_")

scorer = RougeScorer(ROUGE_KEYS, use_stemmer=use_stemmer)
rs_scores = scorer.score(SINGLE_SENTENCE_EXAMPLE_TARGET, SINGLE_SENTENCE_EXAMPLE_PREDS)
rs_result = torch.tensor(getattr(rs_scores[rouge_level], metric), dtype=torch.float32)

pl_output = rouge_score(
[SINGLE_SENTENCE_EXAMPLE_PREDS],
[SINGLE_SENTENCE_EXAMPLE_TARGET],
newline_sep=newline_sep,
use_stemmer=use_stemmer,
decimal_places=decimal_places,
)
pl_output = rouge_score([SINGLE_SENTENCE_EXAMPLE_PREDS], [SINGLE_SENTENCE_EXAMPLE_TARGET], use_stemmer=use_stemmer)

assert torch.allclose(pl_output[pl_rouge_metric_key], tensor(rs_output, dtype=torch.float32))
assert torch.allclose(pl_output[pl_rouge_metric_key], rs_result)


@pytest.mark.skipif(not (_NLTK_AVAILABLE or _ROUGE_SCORE_AVAILABLE), reason="test requires nltk and rouge-score")
@pytest.mark.skipif(not _NLTK_AVAILABLE, reason="test requires nltk")
@pytest.mark.parametrize(
["pl_rouge_metric_key", "rouge_score_key", "metric", "decimal_places", "use_stemmer", "newline_sep"],
["pl_rouge_metric_key", "use_stemmer"],
[
pytest.param("rouge1_precision", "rouge1", PRECISION, 1, True, True),
pytest.param("rouge1_recall", "rouge1", RECALL, 2, True, False),
pytest.param("rouge1_fmeasure", "rouge1", F_MEASURE, 3, False, True),
pytest.param("rouge2_precision", "rouge2", PRECISION, 4, False, False),
pytest.param("rouge2_recall", "rouge2", RECALL, 5, True, True),
pytest.param("rouge2_fmeasure", "rouge2", F_MEASURE, 6, True, False),
pytest.param("rougeL_precision", "rougeL", PRECISION, 6, False, True),
pytest.param("rougeL_recall", "rougeL", RECALL, 5, False, False),
pytest.param("rougeL_fmeasure", "rougeL", F_MEASURE, 3, True, True),
pytest.param("rougeLsum_precision", "rougeLsum", PRECISION, 2, True, False),
pytest.param("rougeLsum_recall", "rougeLsum", RECALL, 1, False, True),
pytest.param("rougeLsum_fmeasure", "rougeLsum", F_MEASURE, 8, False, False),
pytest.param("rouge1_precision", True),
pytest.param("rouge1_recall", True),
pytest.param("rouge1_fmeasure", False),
pytest.param("rouge2_precision", False),
pytest.param("rouge2_recall", True),
pytest.param("rouge2_fmeasure", True),
pytest.param("rougeL_precision", False),
pytest.param("rougeL_recall", False),
pytest.param("rougeL_fmeasure", True),
pytest.param("rougeLsum_precision", True),
pytest.param("rougeLsum_recall", False),
pytest.param("rougeLsum_fmeasure", False),
],
)
def test_rouge_metric_functional(
pl_rouge_metric_key, rouge_score_key, metric, decimal_places, use_stemmer, newline_sep
):
def test_rouge_metric_functional(pl_rouge_metric_key, use_stemmer):
rouge_level, metric = pl_rouge_metric_key.split("_")

rs_scores = _compute_rouge_score(PREDS, TARGETS, use_stemmer=use_stemmer)
rs_output = round(rs_scores[rouge_score_key].mid[metric], decimal_places)
rs_result = torch.tensor(getattr(rs_scores[rouge_level].mid, metric), dtype=torch.float32)

pl_output = rouge_score(
PREDS, TARGETS, newline_sep=newline_sep, use_stemmer=use_stemmer, decimal_places=decimal_places
)
pl_output = rouge_score(PREDS, TARGETS, use_stemmer=use_stemmer)

assert torch.allclose(pl_output[pl_rouge_metric_key], tensor(rs_output, dtype=torch.float32))
assert torch.allclose(pl_output[pl_rouge_metric_key], rs_result)


@pytest.mark.skipif(not (_NLTK_AVAILABLE or _ROUGE_SCORE_AVAILABLE), reason="test requires nltk and rouge-score")
@pytest.mark.skipif(not _NLTK_AVAILABLE, reason="test requires nltk")
@pytest.mark.parametrize(
["pl_rouge_metric_key", "rouge_score_key", "metric", "decimal_places", "use_stemmer", "newline_sep"],
["pl_rouge_metric_key", "use_stemmer"],
[
pytest.param("rouge1_precision", "rouge1", PRECISION, 1, True, True),
pytest.param("rouge1_recall", "rouge1", RECALL, 2, True, False),
pytest.param("rouge1_fmeasure", "rouge1", F_MEASURE, 3, False, True),
pytest.param("rouge2_precision", "rouge2", PRECISION, 4, False, False),
pytest.param("rouge2_recall", "rouge2", RECALL, 5, True, True),
pytest.param("rouge2_fmeasure", "rouge2", F_MEASURE, 6, True, False),
pytest.param("rougeL_precision", "rougeL", PRECISION, 6, False, True),
pytest.param("rougeL_recall", "rougeL", RECALL, 5, False, False),
pytest.param("rougeL_fmeasure", "rougeL", F_MEASURE, 3, True, True),
pytest.param("rougeLsum_precision", "rougeLsum", PRECISION, 2, True, False),
pytest.param("rougeLsum_recall", "rougeLsum", RECALL, 1, False, True),
pytest.param("rougeLsum_fmeasure", "rougeLsum", F_MEASURE, 8, False, False),
pytest.param("rouge1_precision", True),
pytest.param("rouge1_recall", True),
pytest.param("rouge1_fmeasure", False),
pytest.param("rouge2_precision", False),
pytest.param("rouge2_recall", True),
pytest.param("rouge2_fmeasure", True),
pytest.param("rougeL_precision", False),
pytest.param("rougeL_recall", False),
pytest.param("rougeL_fmeasure", True),
pytest.param("rougeLsum_precision", True),
pytest.param("rougeLsum_recall", False),
pytest.param("rougeLsum_fmeasure", False),
],
)
def test_rouge_metric_class(pl_rouge_metric_key, rouge_score_key, metric, decimal_places, use_stemmer, newline_sep):
scorer = RougeScorer(ROUGE_KEYS)
rs_scores = scorer.score(SINGLE_SENTENCE_EXAMPLE_PREDS, SINGLE_SENTENCE_EXAMPLE_TARGET)
rs_output = round(rs_scores[rouge_score_key][metric], decimal_places)
def test_rouge_metric_class(pl_rouge_metric_key, use_stemmer):
rouge_level, metric = pl_rouge_metric_key.split("_")

scorer = RougeScorer(ROUGE_KEYS, use_stemmer=use_stemmer)
rs_scores = scorer.score(SINGLE_SENTENCE_EXAMPLE_TARGET, SINGLE_SENTENCE_EXAMPLE_PREDS)
rs_result = torch.tensor(getattr(rs_scores[rouge_level], metric), dtype=torch.float32)

rouge = ROUGEScore(newline_sep=newline_sep, use_stemmer=use_stemmer, decimal_places=decimal_places)
rouge = ROUGEScore(use_stemmer=use_stemmer)
pl_output = rouge([SINGLE_SENTENCE_EXAMPLE_PREDS], [SINGLE_SENTENCE_EXAMPLE_TARGET])

assert torch.allclose(pl_output[pl_rouge_metric_key], tensor(rs_output, dtype=torch.float32))
assert torch.allclose(pl_output[pl_rouge_metric_key], rs_result)


@pytest.mark.skipif(not (_NLTK_AVAILABLE or _ROUGE_SCORE_AVAILABLE), reason="test requires nltk and rouge-score")
@pytest.mark.skipif(not _NLTK_AVAILABLE, reason="test requires nltk")
@pytest.mark.parametrize(
["pl_rouge_metric_key", "rouge_score_key", "metric", "decimal_places", "use_stemmer", "newline_sep"],
["pl_rouge_metric_key", "use_stemmer"],
[
pytest.param("rouge1_precision", "rouge1", PRECISION, 1, True, True),
pytest.param("rouge1_recall", "rouge1", RECALL, 2, True, False),
pytest.param("rouge1_fmeasure", "rouge1", F_MEASURE, 3, False, True),
pytest.param("rouge2_precision", "rouge2", PRECISION, 4, False, False),
pytest.param("rouge2_recall", "rouge2", RECALL, 5, True, True),
pytest.param("rouge2_fmeasure", "rouge2", F_MEASURE, 6, True, False),
pytest.param("rougeL_precision", "rougeL", PRECISION, 6, False, True),
pytest.param("rougeL_recall", "rougeL", RECALL, 5, False, False),
pytest.param("rougeL_fmeasure", "rougeL", F_MEASURE, 3, True, True),
pytest.param("rougeLsum_precision", "rougeLsum", PRECISION, 2, True, False),
pytest.param("rougeLsum_recall", "rougeLsum", RECALL, 1, False, True),
pytest.param("rougeLsum_fmeasure", "rougeLsum", F_MEASURE, 8, False, False),
pytest.param("rouge1_precision", True),
pytest.param("rouge1_recall", True),
pytest.param("rouge1_fmeasure", False),
pytest.param("rouge2_precision", False),
pytest.param("rouge2_recall", True),
pytest.param("rouge2_fmeasure", True),
pytest.param("rougeL_precision", False),
pytest.param("rougeL_recall", False),
pytest.param("rougeL_fmeasure", True),
pytest.param("rougeLsum_precision", True),
pytest.param("rougeLsum_recall", False),
pytest.param("rougeLsum_fmeasure", False),
],
)
def test_rouge_metric_class_batches(
pl_rouge_metric_key, rouge_score_key, metric, decimal_places, use_stemmer, newline_sep
):
def test_rouge_metric_class_batches(pl_rouge_metric_key, use_stemmer):
rouge_level, metric = pl_rouge_metric_key.split("_")

rs_scores = _compute_rouge_score(BATCHES_RS_PREDS, BATCHES_RS_TARGETS, use_stemmer=use_stemmer)
rs_output = round(rs_scores[rouge_score_key].mid[metric], decimal_places)
rs_result = torch.tensor(getattr(rs_scores[rouge_level].mid, metric), dtype=torch.float32)

rouge = ROUGEScore(newline_sep=newline_sep, use_stemmer=use_stemmer, decimal_places=decimal_places)
rouge = ROUGEScore(use_stemmer=use_stemmer)
for batch in BATCHES:
rouge.update(batch["preds"], batch["targets"])
pl_output = rouge.compute()

assert torch.allclose(pl_output[pl_rouge_metric_key], tensor(rs_output, dtype=torch.float32))
assert torch.allclose(pl_output[pl_rouge_metric_key], rs_result)


def test_rouge_metric_raises_errors_and_warnings():
"""Test that expected warnings and errors are raised."""
if not (_NLTK_AVAILABLE and _ROUGE_SCORE_AVAILABLE):
if not _NLTK_AVAILABLE:
with pytest.raises(
ValueError,
match="ROUGE metric requires that both nltk and rouge-score is installed."
"Either as `pip install torchmetrics[text]` or `pip install nltk rouge-score`",
match="ROUGE metric requires that nltk is installed."
"Either as `pip install torchmetrics[text]` or `pip install nltk`",
):
ROUGEScore()

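For orientation, here is a minimal usage sketch mirroring the calls these tests exercise, covering the functional rouge_score and the ROUGEScore module. It assumes a torchmetrics build that includes this PR and an installed nltk; the example sentences are the single-sentence fixtures from the test file, and the result keys shown are among those parametrized above.

# Minimal usage sketch based on the calls exercised in the tests above;
# assumes a torchmetrics build containing this PR and nltk installed.
from torchmetrics.functional.text.rouge import rouge_score
from torchmetrics.text.rouge import ROUGEScore

preds = "The quick brown fox jumps over the lazy dog"
target = "The quick brown dog jumps on the log."

# Functional interface: returns a dict of tensors keyed like "rouge1_fmeasure".
scores = rouge_score([preds], [target], use_stemmer=True)
print(scores["rouge1_fmeasure"], scores["rougeL_precision"])

# Module interface: accumulate over batches with update(), then compute().
rouge = ROUGEScore(use_stemmer=True)
rouge.update([preds], [target])
print(rouge.compute()["rougeLsum_recall"])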