Skip to content

Commit

Permalink
Fix RougeL/RougeLSum implementation (#944)
Browse files Browse the repository at this point in the history
* Add failing tests showing broken rouge-L/rouge-Lsum implementation
* Fix RougeL/RougeLsum scoreing
* Add type ignore for creating lcs union
* Add reference link
  • Loading branch information
stancld authored Apr 11, 2022
1 parent 717622e commit 0d56376
Show file tree
Hide file tree
Showing 3 changed files with 177 additions and 29 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Fixed `BestScore` on GPU ([#912](https://github.com/PyTorchLightning/metrics/pull/912))

- Fixed Lsum computation for `ROUGEScore` ([#944](https://github.com/PyTorchLightning/metrics/pull/944))


## [0.7.3] - 2022-03-23

Expand Down
67 changes: 56 additions & 11 deletions tests/text/test_rouge.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,15 @@

import re
from functools import partial
from typing import Callable, Sequence
from typing import Callable, Sequence, Union

import numpy as np
import pytest
import torch
from torch import Tensor
from typing_extensions import Literal

from tests.text.helpers import TextTester
from tests.text.inputs import _inputs_multiple_references, _inputs_single_sentence_single_reference
from tests.text.inputs import Input, _inputs_multiple_references, _inputs_single_sentence_single_reference
from torchmetrics.functional.text.rouge import rouge_score
from torchmetrics.text.rouge import ROUGEScore
from torchmetrics.utilities.imports import _NLTK_AVAILABLE, _ROUGE_SCORE_AVAILABLE
Expand All @@ -35,25 +36,36 @@
ROUGE_KEYS = ("rouge1", "rouge2", "rougeL", "rougeLsum")


# Some randomly adjusted input from CNN/DailyMail dataset which brakes the test
_preds = "A lawyer says him .\nMoschetto, 54 and prosecutors say .\nAuthority abc Moschetto ."
_target = "A trainer said her and Moschetto, 54s or weapons say . \nAuthorities Moschetto of ."
_inputs_summarization = Input(preds=_preds, targets=_target)


def _compute_rouge_score(
preds: Sequence[str],
target: Sequence[Sequence[str]],
preds: Union[str, Sequence[str]],
target: Union[str, Sequence[Union[str, Sequence[str]]]],
use_stemmer: bool,
rouge_level: str,
metric: str,
accumulate: str,
):
accumulate: Literal = ["avg", "best", None],
) -> Tensor:
"""Evaluates rouge scores from rouge-score package for baseline evaluation."""
if isinstance(target, list) and all(isinstance(tgt, str) for tgt in target):
target = [target] if isinstance(preds, str) else [[tgt] for tgt in target]

if isinstance(preds, str):
if isinstance(preds, str) and accumulate:
preds = [preds]

if isinstance(target, str):
if isinstance(target, str) and accumulate:
target = [[target]]

scorer = RougeScorer(ROUGE_KEYS, use_stemmer=use_stemmer)
if not accumulate:
rs_scores = scorer.score(target, preds)
rs_result = getattr(rs_scores[rouge_level], metric)
return torch.tensor(rs_result, dtype=torch.float)

aggregator = BootstrapAggregator()

for target_raw, pred_raw in zip(target, preds):
Expand All @@ -75,7 +87,7 @@ def _compute_rouge_score(

rs_scores = aggregator.aggregate()
rs_result = getattr(rs_scores[rouge_level].mid, metric)
return rs_result
return torch.tensor(rs_result, dtype=torch.float)


@pytest.mark.skipif(not _NLTK_AVAILABLE, reason="test requires nltk")
Expand Down Expand Up @@ -208,4 +220,37 @@ def test_rouge_metric_normalizer_tokenizer(pl_rouge_metric_key):
)
metrics_score = scorer.compute()

np.isclose(metrics_score[rouge_level + "_" + metric], original_score, atol=1e-8, equal_nan=True)
assert torch.isclose(metrics_score[rouge_level + "_" + metric], original_score)


@pytest.mark.parametrize(
"pl_rouge_metric_key",
[
"rougeL_precision",
"rougeL_recall",
"rougeL_fmeasure",
"rougeLsum_precision",
"rougeLsum_recall",
"rougeLsum_fmeasure",
],
)
@pytest.mark.parametrize("use_stemmer", [False, True])
def test_rouge_lsum_score(pl_rouge_metric_key, use_stemmer):
"""Specific tests to verify the correctness of Rouge-L and Rouge-LSum metric."""
rouge_level, metric = pl_rouge_metric_key.split("_")
original_score = _compute_rouge_score(
preds=_inputs_summarization.preds,
target=_inputs_summarization.targets,
rouge_level=rouge_level,
metric=metric,
accumulate=None,
use_stemmer=use_stemmer,
)

metrics_score = rouge_score(
_inputs_summarization.preds,
_inputs_summarization.targets,
rouge_keys=rouge_level,
use_stemmer=use_stemmer,
)
assert torch.isclose(metrics_score[rouge_level + "_" + metric], original_score)
137 changes: 119 additions & 18 deletions torchmetrics/functional/text/rouge.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,16 @@
ALLOWED_ACCUMULATE_VALUES = ("avg", "best")


def _add_newline_to_end_of_each_sentence(x: str) -> str:
"""This was added to get rougeLsum scores matching published rougeL scores for BART and PEGASUS."""
def _split_sentence(x: str) -> Sequence[str]:
"""The sentence is split to get rougeLsum scores matching published rougeL scores for BART and PEGASUS."""
if not _NLTK_AVAILABLE:
raise ModuleNotFoundError("ROUGE-Lsum calculation requires that `nltk` is installed. Use `pip install nltk`.")
import nltk

nltk.download("punkt", quiet=True, force=False)

re.sub("<n>", "", x) # remove pegasus newline char
return "\n".join(nltk.sent_tokenize(x))
return nltk.sent_tokenize(x)


def _compute_metrics(hits_or_lcs: int, pred_len: int, target_len: int) -> Dict[str, Tensor]:
Expand All @@ -72,7 +72,9 @@ def _compute_metrics(hits_or_lcs: int, pred_len: int, target_len: int) -> Dict[s
return dict(precision=tensor(precision), recall=tensor(recall), fmeasure=tensor(fmeasure))


def _lcs(pred_tokens: Sequence[str], target_tokens: Sequence[str]) -> int:
def _lcs(
pred_tokens: Sequence[str], target_tokens: Sequence[str], return_full_table: bool = False
) -> Union[int, Sequence[Sequence[int]]]:
"""Common DP algorithm to compute the length of the longest common subsequence.
Args:
Expand All @@ -88,9 +90,66 @@ def _lcs(pred_tokens: Sequence[str], target_tokens: Sequence[str]) -> int:
lcs[i][j] = lcs[i - 1][j - 1] + 1
else:
lcs[i][j] = max(lcs[i - 1][j], lcs[i][j - 1])
if return_full_table:
return lcs
return lcs[-1][-1]


def _backtracked_lcs(
lcs_table: Sequence[Sequence[int]], pred_tokens: Sequence[str], target_tokens: Sequence[str]
) -> Sequence[int]:
"""Backtrack LCS table.
Args:
lcs_table:
A table containing information for the calculation of the longest common subsequence.
pred_tokens:
A tokenized predicted sentence.
target_tokens:
A tokenized target sentence.
"""
i = len(pred_tokens)
j = len(target_tokens)
backtracked_lcs: List[int] = []
while i > 0 and j > 0:
if pred_tokens[i - 1] == target_tokens[j - 1]:
backtracked_lcs.insert(0, j - 1)
i -= 1
j -= 1
elif lcs_table[j][i - 1] > lcs_table[j - 1][i]:
i -= 1
else:
j -= 1
return backtracked_lcs


def _union_lcs(pred_tokens_list: Sequence[Sequence[str]], target_tokens: Sequence[str]) -> Sequence[str]:
"""Find union LCS between a target sentence and iterable of predicted tokens.
Args:
pred_tokens_list:
A tokenized predicted sentence split by '\n'.
target_tokens:
A tokenized single part of target sentence split by '\n'.
Return:
"""

def lcs_ind(pred_tokens: Sequence[str], target_tokens: Sequence[str]) -> Sequence[int]:
"""Returns one of the longest of longest common subsequence via backtracked lcs table."""
lcs_table: Sequence[Sequence[int]] = _lcs(pred_tokens, target_tokens, return_full_table=True) # type: ignore
backtracked_lcs_table = _backtracked_lcs(lcs_table, pred_tokens, target_tokens)
return backtracked_lcs_table

def find_union(lcs_tables: Sequence[Sequence[int]]) -> Sequence[int]:
"""Find union LCS given a list of LCS."""
return sorted(list(set().union(*lcs_tables))) # type: ignore

lcs_tables = [lcs_ind(pred_tokens, target_tokens) for pred_tokens in pred_tokens_list]
union_lcs = [target_tokens[i] for i in find_union(lcs_tables)]
return union_lcs


def _normalize_and_tokenize_text(
text: str,
stemmer: Optional[Any] = None,
Expand Down Expand Up @@ -160,7 +219,7 @@ def _create_ngrams(tokens: Sequence[str], n: int) -> Counter:


def _rouge_l_score(pred: Sequence[str], target: Sequence[str]) -> Dict[str, Tensor]:
"""This computes precision, recall and F1 score for the Rouge-L or Rouge-LSum metric.
"""This computes precision, recall and F1 score for the Rouge-L metric.
Args:
pred:
Expand All @@ -172,10 +231,52 @@ def _rouge_l_score(pred: Sequence[str], target: Sequence[str]) -> Dict[str, Tens
if 0 in (pred_len, target_len):
return dict(precision=tensor(0.0), recall=tensor(0.0), fmeasure=tensor(0.0))

lcs = _lcs(pred, target)
lcs: int = _lcs(pred, target) # type: ignore
return _compute_metrics(lcs, pred_len, target_len)


def _rouge_lsum_score(pred: Sequence[Sequence[str]], target: Sequence[Sequence[str]]) -> Dict[str, Tensor]:
"""This computes precision, recall and F1 score for the Rouge-LSum metric. More information can be found in Section
3.2 of the referenced paper [1]. This implementation follow the official implementation from:
https://github.com/google-research/google-research/blob/master/rouge/rouge_scorer.py
Args:
pred:
An iterable of predicted sentence split by '\n'.
target:
An iterable target sentence split by '\n'.
References
[1] ROUGE: A Package for Automatic Evaluation of Summaries by Chin-Yew Lin. https://aclanthology.org/W04-1013/
"""
pred_len = sum(map(len, pred))
target_len = sum(map(len, target))
if 0 in (pred_len, target_len):
return dict(precision=tensor(0.0), recall=tensor(0.0), fmeasure=tensor(0.0))

# Get token counts
def _get_token_counts(sentences: Sequence[Sequence[str]]) -> Counter:
ngrams: Counter = Counter()
for sentence in sentences:
ngrams.update(sentence)
return ngrams

pred_tokens_count = _get_token_counts(pred)
target_tokens_count = _get_token_counts(target)

# Calculate hits
hits = 0
for tgt in target:
lcs = _union_lcs(pred, tgt)
for token in lcs:
if pred_tokens_count[token] > 0 and target_tokens_count[token] > 0:
hits += 1
pred_tokens_count[token] -= 1
target_tokens_count[token] -= 1

return _compute_metrics(hits, pred_len, target_len)


def _rouge_score_update(
preds: Sequence[str],
target: Sequence[Sequence[str]],
Expand Down Expand Up @@ -239,27 +340,27 @@ def _rouge_score_update(
result_avg: Dict[Union[int, str], List[Dict[str, Tensor]]] = {rouge_key: [] for rouge_key in rouge_keys_values}
list_results = []
pred = _normalize_and_tokenize_text(pred_raw, stemmer, normalizer, tokenizer)
pred_lsum = _normalize_and_tokenize_text(
_add_newline_to_end_of_each_sentence(pred_raw), stemmer, normalizer, tokenizer
)
pred_lsum = [
_normalize_and_tokenize_text(pred_sentence, stemmer, normalizer, tokenizer)
for pred_sentence in _split_sentence(pred_raw)
]

for target_raw_inner in target_raw:
tgt = _normalize_and_tokenize_text(target_raw_inner, stemmer, normalizer, tokenizer)

if "Lsum" in rouge_keys_values:
# rougeLsum expects "\n" separated sentences within a summary
target_lsum = _normalize_and_tokenize_text(
_add_newline_to_end_of_each_sentence(target_raw_inner), stemmer, normalizer, tokenizer
)
target_lsum = [
_normalize_and_tokenize_text(tgt_sentence, stemmer, normalizer, tokenizer)
for tgt_sentence in _split_sentence(target_raw_inner)
]

for rouge_key in rouge_keys_values:
if isinstance(rouge_key, int):
score = _rouge_n_score(pred, tgt, rouge_key)
else:
score = _rouge_l_score(
pred if rouge_key != "Lsum" else pred_lsum,
tgt if rouge_key != "Lsum" else target_lsum,
)
elif rouge_key == "L":
score = _rouge_l_score(pred, tgt)
elif rouge_key == "Lsum":
score = _rouge_lsum_score(pred_lsum, target_lsum)
result_inner[rouge_key] = score
result_avg[rouge_key].append(score)
list_results.append(result_inner.copy())
Expand Down

0 comments on commit 0d56376

Please sign in to comment.