Remove dependency on jiwer for WER #446

Merged 8 commits on Aug 18, 2021
Changes from 6 commits
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -22,6 +22,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Removed `rouge-score` as dependency for text package ([#443](https://github.com/PyTorchLightning/metrics/pull/443))


- Removed `jiwer` as dependency for text package ([#446](https://github.com/PyTorchLightning/metrics/pull/446))


### Fixed

- Fixed ranking of samples in `SpearmanCorrCoef` metric ([#448](https://github.com/PyTorchLightning/metrics/pull/448))
1 change: 1 addition & 0 deletions requirements/test.txt
@@ -27,4 +27,5 @@ mir_eval>=0.6
speechmetrics @ https://github.com/aliutkus/speechmetrics/archive/refs/heads/master.zip

# text
jiwer>=2.2.0
rouge-score>=0.0.4
1 change: 0 additions & 1 deletion requirements/text.txt
@@ -1,3 +1,2 @@
jiwer>=2.2.0
nltk>=3.6
bert-score==0.3.10
97 changes: 41 additions & 56 deletions tests/text/test_wer.py
@@ -8,76 +8,61 @@
from torchmetrics.functional.text.wer import wer
from torchmetrics.text.wer import WER

PREDICTION1 = "hello world"
REFERENCE1 = "hello world"

@pytest.mark.parametrize(
"hyp,ref,score",
[(["hello world"], ["hello world"], 0.0), (["Firwww"], ["hello world"], 1.0)],
)
@pytest.mark.skipif(not _JIWER_AVAILABLE, reason="test requires jiwer")
def test_wer_same(hyp, ref, score):
"""Test to ensure that the torchmetric WER matches reference scores."""
metric = WER()
metric.update(hyp, ref)
assert metric.compute() == score

PREDICTION2 = "what a day"
REFERENCE2 = "what a wonderful day"

@pytest.mark.parametrize(
"hyp,ref,expected_score,expected_incorrect,expected_total",
[
(["hello world"], ["hello world"], 0.0, 0, 2),
(["Firwww"], ["hello world"], 1.0, 2, 2),
],
)
@pytest.mark.skipif(not _JIWER_AVAILABLE, reason="test requires jiwer")
def test_wer_functional(ref, hyp, expected_score, expected_incorrect, expected_total):
"""Test to ensure that the torchmetric functional WER matches the jiwer reference."""
assert wer(ref, hyp) == expected_score
BATCH_PREDICTIONS = [PREDICTION1, PREDICTION2]
BATCH_REFERENCES = [REFERENCE1, REFERENCE2]


@pytest.mark.parametrize(
"hyp,ref",
[
(["hello world"], ["hello world"]),
(["Firwww"], ["hello world"]),
],
"prediction,reference",
[(PREDICTION1, REFERENCE1), (PREDICTION2, REFERENCE2)],
)
@pytest.mark.skipif(not _JIWER_AVAILABLE, reason="test requires jiwer")
def test_wer_reference_functional(hyp, ref):
"""Test to ensure that the torchmetric functional WER matches the jiwer reference."""
assert wer(ref, hyp) == compute_measures(ref, hyp)["wer"]
def test_wer_functional_single_sentence(prediction, reference):
"""Test functional with strings as inputs."""
pl_output = wer(prediction, reference)
jiwer_output = compute_measures(reference, prediction)["wer"]
assert pl_output == jiwer_output


@pytest.mark.skipif(not _JIWER_AVAILABLE, reason="test requires jiwer")
def test_wer_reference_functional_concatenate():
"""Test to ensure that the torchmetric functional WER matches the jiwer reference when concatenating."""
ref = ["hello world", "hello world"]
hyp = ["hello world", "Firwww"]
assert wer(ref, hyp) == compute_measures(ref, hyp)["wer"]
assert wer(hyp, ref, concatenate_texts=True) == compute_measures("".join(ref), "".join(hyp))["wer"]
def test_wer_functional_batch():
"""Test functional with a batch of sentences."""
pl_output = wer(BATCH_PREDICTIONS, BATCH_REFERENCES)
jiwer_output = compute_measures(BATCH_REFERENCES, BATCH_PREDICTIONS)["wer"]
assert pl_output == jiwer_output


@pytest.mark.parametrize(
"hyp,ref",
[
(["hello world"], ["hello world"]),
(["Firwww"], ["hello world"]),
],
"prediction,reference",
[(PREDICTION1, REFERENCE1), (PREDICTION2, REFERENCE2)],
)
@pytest.mark.skipif(not _JIWER_AVAILABLE, reason="test requires jiwer")
def test_wer_reference(hyp, ref):
"""Test to ensure that the torchmetric WER matches the jiwer reference."""
def test_wer_class_single_sentence(prediction, reference):
"""Test class with strings as inputs."""
metric = WER()
metric.update(hyp, ref)
assert metric.compute() == compute_measures(ref, hyp)["wer"]
metric.update(prediction, reference)
pl_output = metric.compute()
jiwer_output = compute_measures(reference, prediction)["wer"]
assert pl_output == jiwer_output


@pytest.mark.skipif(not _JIWER_AVAILABLE, reason="test requires jiwer")
def test_wer_reference_batch():
"""Test to ensure that the torchmetric WER matches the jiwer reference with accumulation."""
batches = [("hello world", "Firwww"), ("hello world", "hello world")]
def test_wer_class_batch():
"""Test class with a batch of sentences."""
metric = WER()
metric.update(BATCH_PREDICTIONS, BATCH_REFERENCES)
pl_output = metric.compute()
jiwer_output = compute_measures(BATCH_REFERENCES, BATCH_PREDICTIONS)["wer"]
assert pl_output == jiwer_output

for hyp, ref in batches:
metric.update(ref, hyp)
reference_score = compute_measures(truth=[x[0] for x in batches], hypothesis=[x[1] for x in batches])["wer"]
assert metric.compute() == reference_score

def test_wer_class_batches():
"""Test class with two batches of sentences."""
metric = WER()
for prediction, reference in zip(BATCH_PREDICTIONS, BATCH_REFERENCES):
metric.update(prediction, reference)
pl_output = metric.compute()
jiwer_output = compute_measures(BATCH_REFERENCES, BATCH_PREDICTIONS)["wer"]
assert pl_output == jiwer_output
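
The reworked tests keep jiwer purely as a reference implementation: the torchmetrics result is compared against `jiwer.compute_measures`, which is why jiwer moves from `requirements/text.txt` to `requirements/test.txt`. A minimal sketch of the comparison these tests make (assumes jiwer is installed in the test environment; the sample sentences are the constants defined above):

from jiwer import compute_measures
from torchmetrics.functional import wer

prediction = "what a day"
reference = "what a wonderful day"

# torchmetrics' WER should match jiwer's for the same pair:
# one missing word out of four reference words gives 0.25
pl_output = wer(prediction, reference)
jiwer_output = compute_measures(reference, prediction)["wer"]
assert float(pl_output) == jiwer_output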
105 changes: 80 additions & 25 deletions torchmetrics/functional/text/wer.py
@@ -12,48 +12,103 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Union
from typing import List, Tuple, Union
from warnings import warn

from torchmetrics.utilities.imports import _JIWER_AVAILABLE
import torch
from torch import Tensor, tensor

if _JIWER_AVAILABLE:
from jiwer import compute_measures

def _edit_distance(prediction_tokens: List[str], reference_tokens: List[str]) -> int:
"""Standard dynamic programming algorithm to compute the edit distance.

def wer(
Args:
prediction_tokens: A tokenized predicted sentence
reference_tokens: A tokenized reference sentence

Returns:
(int) Edit distance between the predicted sentence and the reference sentence
"""
dp = [[0] * (len(reference_tokens) + 1) for _ in range(len(prediction_tokens) + 1)]
for i in range(len(prediction_tokens) + 1):
dp[i][0] = i
for j in range(len(reference_tokens) + 1):
dp[0][j] = j
for i in range(1, len(prediction_tokens) + 1):
for j in range(1, len(reference_tokens) + 1):
if prediction_tokens[i - 1] == reference_tokens[j - 1]:
dp[i][j] = dp[i - 1][j - 1]
else:
dp[i][j] = min(dp[i - 1][j], dp[i][j - 1], dp[i - 1][j - 1]) + 1
return dp[-1][-1]


def _wer_update(
predictions: Union[str, List[str]],
references: Union[str, List[str]],
) -> Tuple[Tensor, Tensor]:
"""Update the wer score with the current set of references and predictions.

Args:
predictions: Transcription(s) to score as a string or list of strings
references: Reference(s) for each speech input as a string or list of strings

Returns:
(Tensor) Number of edit operations to get from the reference to the prediction, summed over all samples
(Tensor) Number of words over all references
"""
if isinstance(predictions, str):
predictions = [predictions]
if isinstance(references, str):
references = [references]
errors = tensor(0, dtype=torch.float)
total = tensor(0, dtype=torch.float)
for prediction, reference in zip(predictions, references):
prediction_tokens = prediction.split()
reference_tokens = reference.split()
errors += _edit_distance(prediction_tokens, reference_tokens)
total += len(reference_tokens)
return errors, total


def _wer_compute(errors: Tensor, total: Tensor) -> Tensor:
"""Compute the word error rate.

Args:
errors: Number of edit operations to get from the reference to the prediction, summed over all samples
total: Number of words over all references

Returns:
(Tensor) Word error rate
"""
return errors / total


def wer(
predictions: Union[str, List[str]],
references: Union[str, List[str]],
concatenate_texts: bool = False,
) -> float:
"""Word error rate (WER_) is a common metric of the performance of an automatic speech recognition system. This
) -> Tensor:
"""Word error rate (WER) is a common metric of the performance of an automatic speech recognition system. This
value indicates the percentage of words that were incorrectly predicted. The lower the value, the better the
performance of the ASR system with a WER of 0 being a perfect score.

Args:
references: List of references for each speech input.
predictions: List of transcriptions to score.
concatenate_texts: Whether to concatenate all input texts or compute WER iteratively.
predictions: Transcription(s) to score as a string or list of strings
references: Reference(s) for each speech input as a string or list of strings
concatenate_texts: Whether to concatenate all input texts or compute WER iteratively
This argument is deprecated in v0.6 and it will be removed in v0.7.

Returns:
(float): the word error rate, or if ``return_measures`` is True, we include the incorrect and total.
(Tensor) Word error rate

Examples:
>>> predictions = ["this is the prediction", "there is an other sample"]
>>> references = ["this is the reference", "there is another one"]
>>> wer(predictions=predictions, references=references)
0.5
tensor(0.5000)
"""
if not _JIWER_AVAILABLE:
raise ModuleNotFoundError(
"wer metric requires that jiwer is installed."
" Either install as `pip install torchmetrics[text]` or `pip install jiwer`"
)
if concatenate_texts:
return compute_measures(references, predictions)["wer"]
incorrect = 0
total = 0
for prediction, reference in zip(predictions, references):
measures = compute_measures(reference, prediction)
incorrect += measures["substitutions"] + measures["deletions"] + measures["insertions"]
total += measures["substitutions"] + measures["deletions"] + measures["hits"]
return incorrect / total
warn("`concatenate_texts` has been deprecated in v0.6 and it will be removed in v0.7", DeprecationWarning)
errors, total = _wer_update(predictions, references)
return _wer_compute(errors, total)
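
The new functional path replaces jiwer with a word-level edit distance: `_wer_update` splits each sentence on whitespace, sums the Levenshtein operations against the reference and the number of reference words, and `_wer_compute` divides the two. A short sketch of the public entry point, reusing the docstring example (illustrative only):

from torchmetrics.functional import wer

predictions = ["this is the prediction", "there is an other sample"]
references = ["this is the reference", "there is another one"]

# 1 substitution over 4 reference words in the first pair and 3 edits over 4 in the second,
# so the batch-level WER is (1 + 3) / (4 + 4) = 0.5
print(wer(predictions, references))  # tensor(0.5000)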
47 changes: 27 additions & 20 deletions torchmetrics/text/wer.py
@@ -13,15 +13,18 @@
# limitations under the License.

from typing import Any, Callable, List, Optional, Union
from warnings import warn

from torchmetrics.functional import wer
import torch
from torch import Tensor, tensor

from torchmetrics.functional.text.wer import _wer_compute, _wer_update
from torchmetrics.metric import Metric


class WER(Metric):
r"""
Word error rate (WER_) is a common metric of
the performance of an automatic speech recognition system.
Word error rate (WER) is a common metric of the performance of an automatic speech recognition system.
This value indicates the percentage of words that were incorrectly predicted.
The lower the value, the better the performance of the ASR system with a WER of 0 being a perfect score.
Word error rate can then be computed as:
@@ -30,7 +33,6 @@ class WER(Metric):
WER = \frac{S + D + I}{N} = \frac{S + D + I}{S + D + C}

where:

- S is the number of substitutions,
- D is the number of deletions,
- I is the number of insertions,
@@ -41,6 +43,7 @@

Args:
concatenate_texts: Whether to concatenate all input texts or compute WER iteratively.
This argument is deprecated in v0.6 and it will be removed in v0.7.
compute_on_step:
Forward only calls ``update()`` and return None if this is set to False. default: True
dist_sync_on_step:
@@ -53,17 +56,19 @@
will be used to perform the allgather

Returns:
(float): the word error rate
(Tensor) Word error rate

Examples:
>>> predictions = ["this is the prediction", "there is an other sample"]
>>> references = ["this is the reference", "there is another one"]
>>> metric = WER()
>>> metric(predictions, references)
0.5

tensor(0.5000)
"""

error: Tensor
total: Tensor

def __init__(
self,
concatenate_texts: bool = False,
@@ -78,24 +83,26 @@ def __init__(
process_group=process_group,
dist_sync_fn=dist_sync_fn,
)
self.concatenate_texts = concatenate_texts
self.add_state("predictions", [], dist_reduce_fx="cat")
self.add_state("references", [], dist_reduce_fx="cat")
if concatenate_texts:
warn("`concatenate_texts` has been deprecated in v0.6 and it will be removed in v0.7", DeprecationWarning)
self.add_state("errors", tensor(0, dtype=torch.float), dist_reduce_fx="sum")
self.add_state("total", tensor(0, dtype=torch.float), dist_reduce_fx="sum")

def update(self, predictions: Union[str, List[str]], references: Union[str, List[str]]) -> None: # type: ignore
"""Store predictions/references for computing Word Error Rate scores.
"""Store references/predictions for computing Word Error Rate scores.

Args:
predictions: List of transcriptions to score.
references: List of references for each speech input.
predictions: Transcription(s) to score as a string or list of strings
references: Reference(s) for each speech input as a string or list of strings
"""
self.predictions.append(predictions)
self.references.append(references)
errors, total = _wer_update(predictions, references)
self.errors += errors
self.total += total

def compute(self) -> float:
"""Calculate Word Error Rate scores.
def compute(self) -> Tensor:
"""Calculate the word error rate.

Return:
Float with WER Score.
Returns:
(Tensor) Word error rate
"""
return wer(self.references, self.predictions, concatenate_texts=self.concatenate_texts)
return _wer_compute(self.errors, self.total)
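
With the state reduced to two scalar tensors, `errors` and `total`, the module version accumulates across `update` calls and synchronizes with a plain sum across processes instead of gathering lists of strings. A brief usage sketch along the lines of the new tests (the printed value is hand-computed here, not taken from the diff):

from torchmetrics.text.wer import WER

metric = WER()
# First batch: identical sentences, so 0 edits over 2 reference words
metric.update("hello world", "hello world")
# Second batch: "wonderful" is missing from the prediction, 1 edit over 4 reference words
metric.update("what a day", "what a wonderful day")
# Accumulated WER: (0 + 1) / (2 + 4) = 1/6
print(metric.compute())  # tensor(0.1667)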