-
Notifications
You must be signed in to change notification settings - Fork 422
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Create cer.py Character Error Rate logic updated * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update cer.py * Create cer.py * Update cer.py * Update cer.py * Update cer.py * removed unused imports * removed unused imports * Update torchmetrics/functional/text/cer.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update torchmetrics/text/cer.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update CHANGELOG.md * Update cer.py * Update cer.py * adding bindings and other changes * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * issue fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Create test_cer.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update test_cer.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * issue fix * Update cer.py * Apply suggestions from code review * move to alphabetical order * Update test_cer.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * minor issue fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update test_cer.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update test_cer.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update test_cer.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update test_cer.py * improve tests * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: Nicki Skafte Detlefsen <skaftenicki@gmail.com>
- Loading branch information
1 parent
f100130
commit 3664c7b
Showing
10 changed files
with
311 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
from typing import Callable, List, Union | ||
|
||
import pytest | ||
|
||
from tests.text.helpers import INPUT_ORDER, TextTester | ||
from torchmetrics.functional.text.cer import char_error_rate | ||
from torchmetrics.text.cer import CharErrorRate | ||
from torchmetrics.utilities.imports import _JIWER_AVAILABLE | ||
|
||
if _JIWER_AVAILABLE: | ||
from jiwer import compute_measures | ||
else: | ||
compute_measures = Callable | ||
|
||
BATCHES_1 = {"preds": [["hello world"], ["what a day"]], "targets": [["hello world"], ["what a wonderful day"]]} | ||
|
||
BATCHES_2 = { | ||
"preds": [ | ||
["i like python", "what you mean or swallow"], | ||
["hello duck", "i like python"], | ||
], | ||
"targets": [ | ||
["i like monthy python", "what do you mean, african or european swallow"], | ||
["hello world", "i like monthy python"], | ||
], | ||
} | ||
|
||
|
||
def compare_fn(prediction: Union[str, List[str]], reference: Union[str, List[str]]): | ||
"""compute cer as wer where we just split each word by character.""" | ||
# we also need to count spaces, so these need to be mapped to some not so common character | ||
prediction = map(lambda s: s.replace(" ", "@"), prediction) | ||
reference = map(lambda s: s.replace(" ", "@"), reference) | ||
# split into individual characters | ||
prediction = [char for seq in prediction for char in seq] | ||
reference = [char for seq in reference for char in seq] | ||
return compute_measures(reference, prediction)["wer"] | ||
|
||
|
||
@pytest.mark.skipif(not _JIWER_AVAILABLE, reason="test requires jiwer") | ||
@pytest.mark.parametrize( | ||
["preds", "targets"], | ||
[ | ||
pytest.param(BATCHES_1["preds"], BATCHES_1["targets"]), | ||
pytest.param(BATCHES_2["preds"], BATCHES_2["targets"]), | ||
], | ||
) | ||
class TestCharErrorRate(TextTester): | ||
"""test class for character error rate.""" | ||
|
||
@pytest.mark.parametrize("ddp", [False, True]) | ||
@pytest.mark.parametrize("dist_sync_on_step", [False, True]) | ||
def test_cer_class(self, ddp, dist_sync_on_step, preds, targets): | ||
"""test modular version of cer.""" | ||
self.run_class_metric_test( | ||
ddp=ddp, | ||
preds=preds, | ||
targets=targets, | ||
metric_class=CharErrorRate, | ||
sk_metric=compare_fn, | ||
dist_sync_on_step=dist_sync_on_step, | ||
input_order=INPUT_ORDER.PREDS_FIRST, | ||
) | ||
|
||
def test_cer_functional(self, preds, targets): | ||
"""test functional version of cer.""" | ||
self.run_functional_metric_test( | ||
preds, | ||
targets, | ||
metric_functional=char_error_rate, | ||
sk_metric=compare_fn, | ||
input_order=INPUT_ORDER.PREDS_FIRST, | ||
) | ||
|
||
def test_cer_differentiability(self, preds, targets): | ||
"""test differentiability of cer metric.""" | ||
self.run_differentiability_test( | ||
preds=preds, | ||
targets=targets, | ||
metric_module=CharErrorRate, | ||
metric_functional=char_error_rate, | ||
input_order=INPUT_ORDER.PREDS_FIRST, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from typing import List, Tuple, Union | ||
|
||
import torch | ||
from torch import Tensor, tensor | ||
|
||
|
||
def _edit_distance(prediction_tokens: List[str], reference_tokens: List[str]) -> int: | ||
"""Standard dynamic programming algorithm to compute the edit distance. | ||
Args: | ||
prediction_tokens: A tokenized predicted sentence | ||
reference_tokens: A tokenized reference sentence | ||
Returns: | ||
(int) Edit distance between the predicted sentence and the reference sentence | ||
""" | ||
dp = [[0] * (len(reference_tokens) + 1) for _ in range(len(prediction_tokens) + 1)] | ||
for i in range(len(prediction_tokens) + 1): | ||
dp[i][0] = i | ||
for j in range(len(reference_tokens) + 1): | ||
dp[0][j] = j | ||
for i in range(1, len(prediction_tokens) + 1): | ||
for j in range(1, len(reference_tokens) + 1): | ||
if prediction_tokens[i - 1] == reference_tokens[j - 1]: | ||
dp[i][j] = dp[i - 1][j - 1] | ||
else: | ||
dp[i][j] = min(dp[i - 1][j], dp[i][j - 1], dp[i - 1][j - 1]) + 1 | ||
return dp[-1][-1] | ||
|
||
|
||
def _cer_update( | ||
predictions: Union[str, List[str]], | ||
references: Union[str, List[str]], | ||
) -> Tuple[Tensor, Tensor]: | ||
"""Update the cer score with the current set of references and predictions. | ||
Args: | ||
predictions: Transcription(s) to score as a string or list of strings | ||
references: Reference(s) for each speech input as a string or list of strings | ||
Returns: | ||
(Tensor) Number of edit operations to get from the reference to the prediction, summed over all samples | ||
(Tensor) Number of character over all references | ||
""" | ||
if isinstance(predictions, str): | ||
predictions = [predictions] | ||
if isinstance(references, str): | ||
references = [references] | ||
errors = tensor(0, dtype=torch.float) | ||
total = tensor(0, dtype=torch.float) | ||
for prediction, reference in zip(predictions, references): | ||
prediction_tokens = prediction | ||
reference_tokens = reference | ||
errors += _edit_distance(list(prediction_tokens), list(reference_tokens)) | ||
total += len(reference_tokens) | ||
return errors, total | ||
|
||
|
||
def _cer_compute(errors: Tensor, total: Tensor) -> Tensor: | ||
"""Compute the Character error rate. | ||
Args: | ||
errors: Number of edit operations to get from the reference to the prediction, summed over all samples | ||
total: Number of characters over all references | ||
Returns: | ||
(Tensor) Character error rate | ||
""" | ||
return errors / total | ||
|
||
|
||
def char_error_rate( | ||
predictions: Union[str, List[str]], | ||
references: Union[str, List[str]], | ||
) -> Tensor: | ||
"""character error rate is a common metric of the performance of an automatic speech recognition system. This | ||
value indicates the percentage of characters that were incorrectly predicted. The lower the value, the better the | ||
performance of the ASR system with a CER of 0 being a perfect score. | ||
Args: | ||
predictions: Transcription(s) to score as a string or list of strings | ||
references: Reference(s) for each speech input as a string or list of strings | ||
Returns: | ||
(Tensor) Character error rate | ||
Examples: | ||
>>> predictions = ["this is the prediction", "there is an other sample"] | ||
>>> references = ["this is the reference", "there is another one"] | ||
>>> char_error_rate(predictions=predictions, references=references) | ||
tensor(0.3415) | ||
""" | ||
errors, total = _cer_update(predictions, references) | ||
return _cer_compute(errors, total) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from typing import Any, Callable, List, Optional, Union | ||
|
||
import torch | ||
from torch import Tensor, tensor | ||
|
||
from torchmetrics.functional.text.cer import _cer_compute, _cer_update | ||
from torchmetrics.metric import Metric | ||
|
||
|
||
class CharErrorRate(Metric): | ||
r""" | ||
Character error rate (CharErrorRate_) is a metric of the performance of an automatic speech recognition | ||
(ASR) system. This value indicates the percentage of characters that were incorrectly predicted. | ||
The lower the value, the better the performance of the ASR system with a CharErrorRate of 0 being | ||
a perfect score. | ||
Character error rate can then be computed as: | ||
.. math:: | ||
CharErrorRate = \frac{S + D + I}{N} = \frac{S + D + I}{S + D + C} | ||
where: | ||
- S is the number of substitutions, | ||
- D is the number of deletions, | ||
- I is the number of insertions, | ||
- C is the number of correct characters, | ||
- N is the number of characters in the reference (N=S+D+C). | ||
Compute CharErrorRate score of transcribed segments against references. | ||
Args: | ||
compute_on_step: | ||
Forward only calls ``update()`` and return None if this is set to False. default: True | ||
dist_sync_on_step: | ||
Synchronize metric state across processes at each ``forward()`` | ||
before returning the value at the step. default: False | ||
process_group: | ||
Specify the process group on which synchronization is called. default: None (which selects the entire world) | ||
dist_sync_fn: | ||
Callback that performs the allgather operation on the metric state. When ``None``, DDP | ||
will be used to perform the allgather | ||
Returns: | ||
(Tensor) Character error rate | ||
Examples: | ||
>>> predictions = ["this is the prediction", "there is an other sample"] | ||
>>> references = ["this is the reference", "there is another one"] | ||
>>> metric = CharErrorRate() | ||
>>> metric(predictions, references) | ||
tensor(0.3415) | ||
""" | ||
is_differentiable = False | ||
higher_is_better = False | ||
error: Tensor | ||
total: Tensor | ||
|
||
def __init__( | ||
self, | ||
compute_on_step: bool = True, | ||
dist_sync_on_step: bool = False, | ||
process_group: Optional[Any] = None, | ||
dist_sync_fn: Callable = None, | ||
): | ||
super().__init__( | ||
compute_on_step=compute_on_step, | ||
dist_sync_on_step=dist_sync_on_step, | ||
process_group=process_group, | ||
dist_sync_fn=dist_sync_fn, | ||
) | ||
self.add_state("errors", tensor(0, dtype=torch.float), dist_reduce_fx="sum") | ||
self.add_state("total", tensor(0, dtype=torch.float), dist_reduce_fx="sum") | ||
|
||
def update(self, predictions: Union[str, List[str]], references: Union[str, List[str]]) -> None: # type: ignore | ||
"""Store references/predictions for computing Character Error Rate scores. | ||
Args: | ||
predictions: Transcription(s) to score as a string or list of strings | ||
references: Reference(s) for each speech input as a string or list of strings | ||
""" | ||
errors, total = _cer_update(predictions, references) | ||
self.errors += errors | ||
self.total += total | ||
|
||
def compute(self) -> Tensor: | ||
"""Calculate the character error rate. | ||
Returns: | ||
(Tensor) Character error rate | ||
""" | ||
return _cer_compute(self.errors, self.total) |