Skip to content

Commit

Permalink
MER - Match Error Rate (#619)
Browse files Browse the repository at this point in the history
* Metric - MatchErrorRate

Match Error Rate logic

* Update mer.py

* update

* Update torchmetrics/functional/text/mer.py

* Update torchmetrics/functional/text/mer.py

* Update torchmetrics/functional/text/mer.py

* Update torchmetrics/functional/text/mer.py

* Update torchmetrics/text/mer.py

* Update torchmetrics/text/mer.py

* Update torchmetrics/text/mer.py

* Update torchmetrics/text/mer.py

* Update torchmetrics/text/mer.py

* Update mer.py

* Update mer.py

* Update torchmetrics/text/mer.py

* update

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Nicki Skafte Detlefsen <skaftenicki@gmail.com>
  • Loading branch information
4 people authored Nov 15, 2021
1 parent b3f4727 commit b31599b
Show file tree
Hide file tree
Showing 10 changed files with 314 additions and 1 deletion.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [unreleased] - 2021-MM-DD

### Added
- Added NLP metrics:
- `MatchErrorRate` ([#619](https://github.com/PyTorchLightning/metrics/pull/619))


### Changed
Expand Down
6 changes: 6 additions & 0 deletions docs/source/references/functional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,12 @@ char_error_rate [func]
.. autofunction:: torchmetrics.functional.char_error_rate
:noindex:

match_error_rate [func]
~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.match_error_rate
:noindex:

rouge_score [func]
~~~~~~~~~~~~~~~~~~

Expand Down
6 changes: 6 additions & 0 deletions docs/source/references/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -622,6 +622,12 @@ CharErrorRate
.. autoclass:: torchmetrics.CharErrorRate
:noindex:

MatchErrorRate
~~~~~~~~~~~~~~

.. autoclass:: torchmetrics.MatchErrorRate
:noindex:

ROUGEScore
~~~~~~~~~~

Expand Down
75 changes: 75 additions & 0 deletions tests/text/test_mer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from typing import Callable, List, Union

import pytest

from tests.text.helpers import INPUT_ORDER, TextTester
from torchmetrics.utilities.imports import _JIWER_AVAILABLE

if _JIWER_AVAILABLE:
from jiwer import compute_measures
else:
compute_measures = Callable

from torchmetrics.functional.text.mer import match_error_rate
from torchmetrics.text.mer import MatchErrorRate

BATCHES_1 = {"preds": [["hello world"], ["what a day"]], "targets": [["hello world"], ["what a wonderful day"]]}

BATCHES_2 = {
"preds": [
["i like python", "what you mean or swallow"],
["hello duck", "i like python"],
],
"targets": [
["i like monthy python", "what do you mean, african or european swallow"],
["hello world", "i like monthy python"],
],
}


def _compute_mer_metric_jiwer(prediction: Union[str, List[str]], reference: Union[str, List[str]]):
return compute_measures(reference, prediction)["mer"]


@pytest.mark.skipif(not _JIWER_AVAILABLE, reason="test requires jiwer")
@pytest.mark.parametrize(
["preds", "targets"],
[
pytest.param(BATCHES_1["preds"], BATCHES_1["targets"]),
pytest.param(BATCHES_2["preds"], BATCHES_2["targets"]),
],
)
class TestMatchErrorRate(TextTester):
@pytest.mark.parametrize("ddp", [False, True])
@pytest.mark.parametrize("dist_sync_on_step", [False, True])
def test_mer_class(self, ddp, dist_sync_on_step, preds, targets):

self.run_class_metric_test(
ddp=ddp,
preds=preds,
targets=targets,
metric_class=MatchErrorRate,
sk_metric=_compute_mer_metric_jiwer,
dist_sync_on_step=dist_sync_on_step,
input_order=INPUT_ORDER.PREDS_FIRST,
)

def test_mer_functional(self, preds, targets):

self.run_functional_metric_test(
preds,
targets,
metric_functional=match_error_rate,
sk_metric=_compute_mer_metric_jiwer,
input_order=INPUT_ORDER.PREDS_FIRST,
)

def test_mer_differentiability(self, preds, targets):

self.run_differentiability_test(
preds=preds,
targets=targets,
metric_module=MatchErrorRate,
metric_functional=match_error_rate,
input_order=INPUT_ORDER.PREDS_FIRST,
)
11 changes: 10 additions & 1 deletion torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,15 @@
RetrievalRecall,
RetrievalRPrecision,
)
from torchmetrics.text import WER, BERTScore, BLEUScore, CharErrorRate, ROUGEScore, SacreBLEUScore # noqa: E402
from torchmetrics.text import ( # noqa: E402
WER,
BERTScore,
BLEUScore,
CharErrorRate,
MatchErrorRate,
ROUGEScore,
SacreBLEUScore,
)
from torchmetrics.wrappers import BootStrapper, MetricTracker, MultioutputWrapper # noqa: E402

__all__ = [
Expand Down Expand Up @@ -142,4 +150,5 @@
"SymmetricMeanAbsolutePercentageError",
"WER",
"CharErrorRate",
"MatchErrorRate",
]
2 changes: 2 additions & 0 deletions torchmetrics/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
from torchmetrics.functional.text.bert import bert_score
from torchmetrics.functional.text.bleu import bleu_score
from torchmetrics.functional.text.cer import char_error_rate
from torchmetrics.functional.text.mer import match_error_rate
from torchmetrics.functional.text.rouge import rouge_score
from torchmetrics.functional.text.sacre_bleu import sacre_bleu_score
from torchmetrics.functional.text.wer import wer
Expand Down Expand Up @@ -137,4 +138,5 @@
"symmetric_mean_absolute_percentage_error",
"wer",
"char_error_rate",
"match_error_rate",
]
1 change: 1 addition & 0 deletions torchmetrics/functional/text/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,6 @@

from torchmetrics.functional.text.bleu import bleu_score # noqa: F401
from torchmetrics.functional.text.cer import char_error_rate # noqa: F401
from torchmetrics.functional.text.mer import match_error_rate # noqa: F401
from torchmetrics.functional.text.sacre_bleu import sacre_bleu_score # noqa: F401
from torchmetrics.functional.text.wer import wer # noqa: F401
109 changes: 109 additions & 0 deletions torchmetrics/functional/text/mer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Tuple, Union

import torch
from torch import Tensor, tensor


def _edit_distance(prediction_tokens: List[str], reference_tokens: List[str]) -> int:
"""Standard dynamic programming algorithm to compute the edit distance.
Args:
prediction_tokens: A tokenized predicted sentence
reference_tokens: A tokenized reference sentence
Returns:
Editing distance between the predicted sentence and the reference sentence
"""
dp = [[0] * (len(reference_tokens) + 1) for _ in range(len(prediction_tokens) + 1)]
dp[:][0] = list(range(len(prediction_tokens) + 1))
dp[0][:] = list(range(len(reference_tokens) + 1))
for i in range(1, len(prediction_tokens) + 1):
for j in range(1, len(reference_tokens) + 1):
if prediction_tokens[i - 1] == reference_tokens[j - 1]:
dp[i][j] = dp[i - 1][j - 1]
else:
dp[i][j] = min(dp[i - 1][j], dp[i][j - 1], dp[i - 1][j - 1]) + 1
return dp[-1][-1]


def _mer_update(
predictions: Union[str, List[str]],
references: Union[str, List[str]],
) -> Tuple[Tensor, Tensor]:
"""Update the mer score with the current set of references and predictions.
Args:
predictions: Transcription(s) to score as a string or list of strings
references: Reference(s) for each speech input as a string or list of strings
Returns:
Number of edit operations to get from the reference to the prediction, summed over all samples
Number of words over all references
"""
if isinstance(predictions, str):
predictions = [predictions]
if isinstance(references, str):
references = [references]
errors = tensor(0, dtype=torch.float)
total = tensor(0, dtype=torch.float)
for prediction, reference in zip(predictions, references):
prediction_tokens = prediction.split()
reference_tokens = reference.split()
errors += _edit_distance(prediction_tokens, reference_tokens)
total += max(len(reference_tokens), len(prediction_tokens))

return errors, total


def _mer_compute(errors: Tensor, total: Tensor) -> Tensor:
"""Compute the match error rate.
Args:
errors: Number of edit operations to get from the reference to the prediction, summed over all samples
total: Number of words over all references
Returns:
(Tensor) Match error rate
"""
return errors / total


def match_error_rate(
predictions: Union[str, List[str]],
references: Union[str, List[str]],
) -> Tensor:
"""Match error rate is a metric of the performance of an automatic speech recognition system. This value
indicates the percentage of words that were incorrectly predicted and inserted. The lower the value, the better
the performance of the ASR system with a MatchErrorRate of 0 being a perfect score.
Args:
predictions: Transcription(s) to score as a string or list of strings
references: Reference(s) for each speech input as a string or list of strings
Returns:
Match error rate score
Examples:
>>> predictions = ["this is the prediction", "there is an other sample"]
>>> references = ["this is the reference", "there is another one"]
>>> match_error_rate(predictions=predictions, references=references)
tensor(0.4444)
"""

errors, total = _mer_update(predictions, references)
return _mer_compute(errors, total)
1 change: 1 addition & 0 deletions torchmetrics/text/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from torchmetrics.text.bert import BERTScore # noqa: F401
from torchmetrics.text.bleu import BLEUScore # noqa: F401
from torchmetrics.text.cer import CharErrorRate # noqa: F401
from torchmetrics.text.mer import MatchErrorRate # noqa: F401
from torchmetrics.text.rouge import ROUGEScore # noqa: F401
from torchmetrics.text.sacre_bleu import SacreBLEUScore # noqa: F401
from torchmetrics.text.wer import WER # noqa: F401
102 changes: 102 additions & 0 deletions torchmetrics/text/mer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable, List, Optional, Union

import torch
from torch import Tensor, tensor

from torchmetrics.functional.text.mer import _mer_compute, _mer_update
from torchmetrics.metric import Metric


class MatchErrorRate(Metric):
r"""
Match error rate (MatchErrorRate_) is a common metric of the performance of an automatic speech recognition system.
This value indicates the percentage of words that were incorrectly predicted and inserted.
The lower the value, the better the performance of the ASR system with a MatchErrorRate of 0 being a perfect score.
Match error rate can then be computed as:
.. math::
mer = \frac{S + D + I}{N + I} = \frac{S + D + I}{S + D + C + I}
where:
- S is the number of substitutions,
- D is the number of deletions,
- I is the number of insertions,
- C is the number of correct words,
- N is the number of words in the reference (N=S+D+C).
Args:
compute_on_step:
Forward only calls ``update()`` and return None if this is set to False.
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step.
process_group:
Specify the process group on which synchronization is called. default: None (which selects the entire world)
dist_sync_fn:
Callback that performs the allgather operation on the metric state. When ``None``, DDP
will be used to perform the allgather
Returns:
Match error rate score
Examples:
>>> predictions = ["this is the prediction", "there is an other sample"]
>>> references = ["this is the reference", "there is another one"]
>>> metric = MatchErrorRate()
>>> metric(predictions, references)
tensor(0.4444)
"""
is_differentiable = False
higher_is_better = False
error: Tensor
total: Tensor

def __init__(
self,
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None,
):
super().__init__(
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
dist_sync_fn=dist_sync_fn,
)
self.add_state("errors", tensor(0, dtype=torch.float), dist_reduce_fx="sum")
self.add_state("total", tensor(0, dtype=torch.float), dist_reduce_fx="sum")

def update(self, predictions: Union[str, List[str]], references: Union[str, List[str]]) -> None: # type: ignore
"""Store references/predictions for computing Match Error Rate scores.
Args:
predictions: Transcription(s) to score as a string or list of strings
references: Reference(s) for each speech input as a string or list of strings
"""
errors, total = _mer_update(predictions, references)
self.errors += errors
self.total += total

def compute(self) -> Tensor:
"""Calculate the Match error rate.
Returns:
Match error rate
"""
return _mer_compute(self.errors, self.total)

0 comments on commit b31599b

Please sign in to comment.