
Adding WER metric #383

Merged: 63 commits, Jul 24, 2021
Changes from 19 commits

Commits
f26a1bb
Update functional.rst
gagan3012 Jul 16, 2021
8c63e5b
Update modules.rst
gagan3012 Jul 16, 2021
9bc95ad
Update test.txt
gagan3012 Jul 16, 2021
de17c62
Create text.txt
gagan3012 Jul 16, 2021
7ee4abf
Update setup.py
gagan3012 Jul 16, 2021
90d7791
Update __init__.py
gagan3012 Jul 16, 2021
cbb9e6b
Update __init__.py
gagan3012 Jul 16, 2021
843326a
Update functional.rst
gagan3012 Jul 16, 2021
546b796
Create __init__.py
gagan3012 Jul 16, 2021
f4e7296
Create wer.py
gagan3012 Jul 16, 2021
c8fa0ee
Create __init__.py
gagan3012 Jul 16, 2021
3b07ff8
Create wer.py
gagan3012 Jul 16, 2021
c740905
Merge branch 'feature/wer' of https://github.com/gagan3012/metrics in…
gagan3012 Jul 16, 2021
05ef82e
Create text_wer.py
gagan3012 Jul 16, 2021
f0b4e80
Create __init__.py
gagan3012 Jul 16, 2021
93ad65e
Create test_wer.py
gagan3012 Jul 16, 2021
f5dc82c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 16, 2021
1ee8656
Update wer.py
gagan3012 Jul 16, 2021
fab3778
Merge branch 'master' into feature/wer
SkafteNicki Jul 17, 2021
2697b16
Requested changes
gagan3012 Jul 18, 2021
b697a89
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 18, 2021
da2ad70
Update wer.py
gagan3012 Jul 18, 2021
67fcd0b
Merge branch 'feature/wer' of https://github.com/gagan3012/metrics in…
gagan3012 Jul 18, 2021
05d7fac
Update test_wer.py
gagan3012 Jul 18, 2021
05f225b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 18, 2021
001c27f
Update wer.py
gagan3012 Jul 19, 2021
4da4400
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 19, 2021
04e8ca7
Update torchmetrics/functional/text/wer.py
gagan3012 Jul 19, 2021
79c8988
Update torchmetrics/text/wer.py
gagan3012 Jul 19, 2021
770528b
Fixes
gagan3012 Jul 19, 2021
04b9df9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 19, 2021
edd14a7
test updates
gagan3012 Jul 19, 2021
acdc481
Better explanation for tests, skip if JIWER not available
Jul 20, 2021
5cc099e
Merge remote-tracking branch 'upstream/master' into feature/wer
Jul 20, 2021
727dd12
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 20, 2021
5bd7b74
Combine functionality to rely on functional wer
Jul 21, 2021
7977731
Cleanup docs
Jul 21, 2021
09d96c0
Add indent
Jul 21, 2021
f9bf667
Add doc
Jul 21, 2021
484c51c
Add spaces
Jul 21, 2021
93fa477
Add space
Jul 21, 2021
c848b15
Apply suggestions from code review
Borda Jul 22, 2021
7e79377
Update torchmetrics/__init__.py
gagan3012 Jul 22, 2021
69fcb0c
Update torchmetrics/functional/text/wer.py
gagan3012 Jul 22, 2021
002b177
Update torchmetrics/functional/text/wer.py
gagan3012 Jul 22, 2021
317e3e8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 22, 2021
4dabca7
Adding extra arguments to metrics
gagan3012 Jul 22, 2021
198d8e4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 22, 2021
ca8809e
Update torchmetrics/functional/text/wer.py
Jul 23, 2021
5a54a7d
Cleanups
Jul 23, 2021
1c92d5b
Update torchmetrics/text/wer.py
gagan3012 Jul 23, 2021
e04b076
Update torchmetrics/functional/text/wer.py
gagan3012 Jul 23, 2021
9f41a49
Update torchmetrics/text/wer.py
gagan3012 Jul 23, 2021
603b105
Update torchmetrics/text/wer.py
gagan3012 Jul 23, 2021
a2e9a15
Update CHANGELOG.md
gagan3012 Jul 23, 2021
05d2d52
Merge branch 'master' into feature/wer
mergify[bot] Jul 24, 2021
8abdbb5
Apply suggestions from code review
Borda Jul 24, 2021
a085347
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 24, 2021
6ebb283
flake8
Borda Jul 24, 2021
9244fd2
Merge branch 'master' into feature/wer
mergify[bot] Jul 24, 2021
5446b2f
Merge branch 'master' into feature/wer
mergify[bot] Jul 24, 2021
0df5cb6
Merge branch 'master' into feature/wer
mergify[bot] Jul 24, 2021
7e22476
Merge branch 'master' into feature/wer
mergify[bot] Jul 24, 2021
11 changes: 11 additions & 0 deletions docs/source/references/functional.rst
@@ -301,6 +301,17 @@ bleu_score [func]
.. autofunction:: torchmetrics.functional.bleu_score
:noindex:

****
Text
****

wer [func]
~~~~~~~~~~

.. autofunction:: torchmetrics.functional.wer
:noindex:


********
Pairwise
********
9 changes: 9 additions & 0 deletions docs/source/references/modules.rst
@@ -507,6 +507,15 @@ RetrievalNormalizedDCG
.. autoclass:: torchmetrics.RetrievalNormalizedDCG
:noindex:

************
Text Metrics
************

WER
~~~

.. autoclass:: torchmetrics.WER
:noindex:

********
Wrappers
1 change: 1 addition & 0 deletions requirements/test.txt
@@ -19,6 +19,7 @@ nltk>=3.6

# add extra requirements
-r image.txt
-r text.txt

# audio
pypesq
1 change: 1 addition & 0 deletions requirements/text.txt
@@ -0,0 +1 @@
jiwer
1 change: 1 addition & 0 deletions setup.py
@@ -27,6 +27,7 @@ def _load_py_module(fname, pkg="torchmetrics"):
def _prepare_extras():
    extras = {
        'image': setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name='image.txt'),
        'text': setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name='text.txt'),
    }
    return extras

11 changes: 11 additions & 0 deletions tests/functional/text_wer.py
@@ -0,0 +1,11 @@
import pytest

from torchmetrics.functional.text.wer import wer


@pytest.mark.parametrize(
    "hyp,ref,score",
    [("hello world", "hello world", 0.0), ("hello world", "Firwww", 1.0)],
)
def test_wer_same(hyp, ref, score):
    assert wer(ref, hyp) == score
Empty file added tests/text/__init__.py
Empty file.
13 changes: 13 additions & 0 deletions tests/text/test_wer.py
@@ -0,0 +1,13 @@
import pytest

from torchmetrics.text.wer import WER


@pytest.mark.parametrize(
    "hyp,ref,score",
    [("hello world", "hello world", 0.0), ("hello world", "Firwww", 1.0)],
)
def test_wer_same(hyp, ref, score):
    metric = WER()
    metric.update(hyp, ref)
    assert metric.compute() == score
1 change: 1 addition & 0 deletions torchmetrics/__init__.py
@@ -60,4 +60,5 @@
RetrievalPrecision,
RetrievalRecall,
)
from torchmetrics.text import WER # noqa: F401 E402
from torchmetrics.wrappers import BootStrapper # noqa: F401 E402
1 change: 1 addition & 0 deletions torchmetrics/functional/__init__.py
@@ -58,3 +58,4 @@
from torchmetrics.functional.retrieval.recall import retrieval_recall # noqa: F401
from torchmetrics.functional.retrieval.reciprocal_rank import retrieval_reciprocal_rank # noqa: F401
from torchmetrics.functional.self_supervised import embedding_similarity # noqa: F401
from torchmetrics.functional.text.wer import wer # noqa: F401
1 change: 1 addition & 0 deletions torchmetrics/functional/text/__init__.py
@@ -0,0 +1 @@
from torchmetrics.functional.text.wer import wer # noqa: F401
29 changes: 29 additions & 0 deletions torchmetrics/functional/text/wer.py
@@ -0,0 +1,29 @@
from typing import Any

from jiwer import compute_measures


def wer(target: Any, preds: Any, concatenate_texts: bool = False) -> float:
    """Compute the word error rate (WER) of transcribed segments against references.

    Args:
        target: List of reference transcripts, one per speech input.
        preds: List of hypothesis transcripts to score.
        concatenate_texts: Whether to concatenate all input texts instead of computing WER iteratively.

    Returns:
        (float): the word error rate

    Examples:
        >>> predictions = ["this is the prediction", "there is an other sample"]
        >>> references = ["this is the reference", "there is another one"]
        >>> wer_score = wer(preds=predictions, target=references)
        >>> print(wer_score)
        0.5
    """
    if concatenate_texts:
        return compute_measures(target, preds)["wer"]
    incorrect = 0
    total = 0
    for prediction, reference in zip(preds, target):
        measures = compute_measures(reference, prediction)
        incorrect += measures["substitutions"] + measures["deletions"] + measures["insertions"]
        total += measures["substitutions"] + measures["deletions"] + measures["hits"]
    return incorrect / total
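To make the arithmetic in the loop above concrete: the minimum word-level edit distance between reference and hypothesis equals S + D + I, and dividing by the reference length N = S + D + C yields the WER. Below is a dependency-free sketch of per-sentence WER; the `word_error_rate` helper is illustrative only and is not part of this PR or of jiwer:

```python
# Minimal word-level WER sketch: the Levenshtein distance over words counts
# substitutions (S), deletions (D) and insertions (I); dividing by the
# reference length N gives WER. NOT the jiwer implementation.

def word_error_rate(reference: str, hypothesis: str) -> float:
    ref = reference.split()
    hyp = hypothesis.split()
    # dp[i][j] = minimum edit distance between ref[:i] and hyp[:j]
    dp = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
    for i in range(len(ref) + 1):
        dp[i][0] = i  # delete every reference word
    for j in range(len(hyp) + 1):
        dp[0][j] = j  # insert every hypothesis word
    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            dp[i][j] = min(
                dp[i - 1][j] + 1,        # deletion
                dp[i][j - 1] + 1,        # insertion
                dp[i - 1][j - 1] + cost,  # substitution or hit
            )
    return dp[len(ref)][len(hyp)] / len(ref)

print(word_error_rate("this is the reference", "this is the prediction"))  # 0.25
```

The same pairs used in the tests above score 0.0 for identical strings and 1.0 when every reference word is wrong, matching the parametrized cases.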
1 change: 1 addition & 0 deletions torchmetrics/text/__init__.py
@@ -0,0 +1 @@
from torchmetrics.text.wer import WER # noqa: F401
55 changes: 55 additions & 0 deletions torchmetrics/text/wer.py
@@ -0,0 +1,55 @@
from typing import Any

from jiwer import compute_measures

from torchmetrics.metric import Metric


class WER(Metric):
    """
    Word error rate (WER) is a common metric of the performance of an automatic speech recognition system.
    WER's output is always a number between 0 and 1.
    This value indicates the percentage of words that were incorrectly predicted.
    The lower the value, the better the performance of the ASR system, with a WER of 0 being a perfect score.
    Word error rate can then be computed as:

        WER = (S + D + I) / N = (S + D + I) / (S + D + C)

    where:
        S is the number of substitutions,
        D is the number of deletions,
        I is the number of insertions,
        C is the number of correct words,
        N is the number of words in the reference (N = S + D + C).

    Compute WER score of transcribed segments against references.

    Args:
        concatenate_texts: Whether to concatenate all input texts instead of computing WER iteratively.

    Returns:
        (float): the word error rate

    Examples:
        >>> predictions = ["this is the prediction", "there is an other sample"]
        >>> references = ["this is the reference", "there is another one"]
        >>> metric = WER()
        >>> metric.update(preds=predictions, target=references)
        >>> wer_score = metric.compute()
        >>> print(wer_score)
        0.5
    """

    def __init__(self, concatenate_texts: bool = False):
        super().__init__()
        self.concatenate_texts = concatenate_texts
        self.add_state("preds", default=[], dist_reduce_fx=None)
        self.add_state("target", default=[], dist_reduce_fx=None)

    def update(self, preds: Any, target: Any) -> None:
        self.preds.append(preds)
        self.target.append(target)

    def compute(self) -> float:
        if self.concatenate_texts:
            return compute_measures(self.target, self.preds)["wer"]
        incorrect = 0
        total = 0
        for prediction, reference in zip(self.preds, self.target):
            measures = compute_measures(reference, prediction)
            incorrect += measures["substitutions"] + measures["deletions"] + measures["insertions"]
            total += measures["substitutions"] + measures["deletions"] + measures["hits"]
        return incorrect / total
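The accumulate-then-compute pattern of the `WER` module above can be illustrated with a dependency-free stand-in: errors and reference-word counts are accumulated across `update` calls, and the ratio is taken once in `compute`. The names `SimpleWER` and `_edit_distance` are hypothetical, and jiwer's alignment is replaced here by a plain word-level edit distance; this is a sketch, not the torchmetrics implementation:

```python
# Dependency-free sketch of the update()/compute() accumulation pattern
# used by the WER module above. jiwer's measures are replaced by a
# rolling-array word-level Levenshtein distance (= S + D + I).

def _edit_distance(ref: list, hyp: list) -> int:
    dp = list(range(len(hyp) + 1))  # dp[j] = distance for the current row
    for i in range(1, len(ref) + 1):
        prev, dp[0] = dp[0], i  # prev holds the diagonal dp[i-1][j-1]
        for j in range(1, len(hyp) + 1):
            cur = dp[j]
            dp[j] = min(
                dp[j] + 1,      # deletion
                dp[j - 1] + 1,  # insertion
                prev + (ref[i - 1] != hyp[j - 1]),  # substitution / hit
            )
            prev = cur
    return dp[len(hyp)]


class SimpleWER:
    def __init__(self):
        self.errors = 0  # accumulated S + D + I
        self.total = 0   # accumulated reference words N

    def update(self, preds, target):
        for hyp, ref in zip(preds, target):
            ref_words, hyp_words = ref.split(), hyp.split()
            self.errors += _edit_distance(ref_words, hyp_words)
            self.total += len(ref_words)

    def compute(self) -> float:
        return self.errors / self.total


metric = SimpleWER()
metric.update(["this is the prediction", "there is an other sample"],
              ["this is the reference", "there is another one"])
print(metric.compute())  # 0.5
```

Accumulating raw counts rather than per-batch ratios is what makes the metric composable: two updates followed by one `compute` give the same result as one combined update, which is the property the `Metric` base class relies on for distributed synchronization.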