
Remove dependency on jiwer for WER #446

Merged: 8 commits merged into Lightning-AI:master on Aug 18, 2021
Conversation

@kingyiusuen (Contributor) commented Aug 14, 2021

Before submitting

  • Was this discussed/approved via a GitHub issue? (no need for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?

What does this PR do?

This PR is part of the effort in #433 to implement our own metrics and minimize third-party dependencies. It adds an in-house implementation of the word error rate (WER) metric.
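
For context, the core of the metric is a word-level edit distance divided by the total number of reference words. A minimal sketch of that idea (illustrative only, not necessarily the exact code added in this PR; the helper names are made up):

from typing import List


def _edit_distance(prediction_tokens: List[str], reference_tokens: List[str]) -> int:
    """Word-level Levenshtein distance via dynamic programming."""
    dp = [[0] * (len(reference_tokens) + 1) for _ in range(len(prediction_tokens) + 1)]
    for i in range(len(prediction_tokens) + 1):
        dp[i][0] = i
    for j in range(len(reference_tokens) + 1):
        dp[0][j] = j
    for i in range(1, len(prediction_tokens) + 1):
        for j in range(1, len(reference_tokens) + 1):
            if prediction_tokens[i - 1] == reference_tokens[j - 1]:
                dp[i][j] = dp[i - 1][j - 1]
            else:
                dp[i][j] = min(dp[i - 1][j - 1], dp[i][j - 1], dp[i - 1][j]) + 1
    return dp[-1][-1]


def wer(predictions: List[str], references: List[str]) -> float:
    """WER = total word-level edit distance / total number of reference words."""
    errors, total = 0, 0
    for pred, ref in zip(predictions, references):
        pred_tokens, ref_tokens = pred.split(), ref.split()
        errors += _edit_distance(pred_tokens, ref_tokens)
        total += len(ref_tokens)
    return errors / total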

This PR also removes the concatenate_texts argument, as it does not appear to be used by any other metric.

Furthermore, this PR swaps the order of references and predictions in the function signature of the functional version of WER. As far as I can tell, all other metrics (including the class-based version of WER) accept model predictions as their first argument and the ground truths as their second, so this change keeps the API consistent. (But it breaks backward compatibility; any suggestions?)
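
A minimal usage sketch of the proposed argument order (assuming the functional metric remains importable as torchmetrics.functional.wer after this change):

from torchmetrics.functional import wer

predictions = ["there is an other sample"]
references = ["there is another one"]

# Predictions first, ground truths second, matching the rest of the API.
score = wer(predictions, references)
print(score)  # 3 word-level errors over 4 reference words, i.e. a WER of 0.75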

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

@codecov (bot) commented Aug 14, 2021

Codecov Report

Merging #446 (bbcce5d) into master (db281f7) will decrease coverage by 0.02%.
The diff coverage is 96.07%.


@@            Coverage Diff             @@
##           master     #446      +/-   ##
==========================================
- Coverage   95.92%   95.90%   -0.03%     
==========================================
  Files         130      130              
  Lines        4274     4301      +27     
==========================================
+ Hits         4100     4125      +25     
- Misses        174      176       +2     
Flag Coverage Δ
Linux 75.14% <92.15%> (+0.11%) ⬆️
Windows 75.14% <92.15%> (+0.11%) ⬆️
cpu 95.90% <96.07%> (+<0.01%) ⬆️
gpu ?
macOS 95.90% <96.07%> (+<0.01%) ⬆️
pytest 95.90% <96.07%> (-0.03%) ⬇️
python3.6 95.18% <96.07%> (+<0.01%) ⬆️
python3.8 95.25% <96.07%> (-0.65%) ⬇️
python3.9 95.83% <96.07%> (+<0.01%) ⬆️
torch1.3.1 95.18% <96.07%> (+<0.01%) ⬆️
torch1.4.0 95.25% <96.07%> (+<0.01%) ⬆️
torch1.9.0 95.83% <96.07%> (+<0.01%) ⬆️

Flags with carried forward coverage won't be shown.

Impacted Files Coverage Δ
torchmetrics/text/wer.py 95.23% <93.33%> (-4.77%) ⬇️
torchmetrics/functional/text/wer.py 97.22% <97.22%> (+3.47%) ⬆️
torchmetrics/metric.py 95.48% <0.00%> (-0.31%) ⬇️

Continue to review full report at Codecov.

Legend:
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update db281f7...bbcce5d.

@kingyiusuen (Contributor, Author):

Flags with carried forward coverage won't be shown.
Impacted Files Coverage Δ
torchmetrics/utilities/imports.py 82.50% <ø> (-0.43%) ⬇️
torchmetrics/text/wer.py 95.23% <93.33%> (-4.77%) ⬇️
torchmetrics/functional/text/wer.py 97.22% <97.14%> (+3.47%) ⬆️
torchmetrics/average.py 84.61% <0.00%> (-11.54%) ⬇️

Not sure why the test coverage of torchmetrics/average.py changed. I didn't even modify this file.

@Borda added the "enhancement" (New feature or request) label on Aug 15, 2021
@Borda added this to the v0.5 milestone on Aug 15, 2021
Review comment on tests/text/test_wer.py (outdated, resolved)
@SeanNaren (Contributor):

This is awesome :) Any chance we could get a quick speed comparison between the existing implementation on master and this implementation? Just to ensure we're not regressing in terms of how long it takes to compute the measurement.

@Borda (Member) commented Aug 16, 2021

Not sure why the test coverage of torchmetrics/average.py changed. I didn't even modify this file.

These can be interim results from before the GPU coverage is also aggregated; don't worry about sections which you did not touch 🐰

@SkafteNicki modified the milestones: v0.5, v0.6 on Aug 16, 2021
@kingyiusuen (Contributor, Author) commented Aug 16, 2021

This is awesome :) Any chance we could get a quick speed comparison between the existing implementation on master and this implementation? Just to ensure we're not regressing in terms of how long it takes to compute the measurement.

Sure. @SeanNaren, do you have a code example of how the speed comparison can be done?

@SkafteNicki (Member):

@kingyiusuen could you make a variation of the following script, which I used to measure an optimization I did to the accuracy metric?

import torch
import torchmetrics
from time import time
import numpy as np

# Metric and dummy inputs to benchmark
accuracy_train = torchmetrics.Accuracy(num_classes=5, average='none')
x = torch.rand(1, 5, 3, 28, 28)
y = torch.randint(0, 5, (1, 3, 28, 28))

N_reps = 50

def run(device):
    # Move the metric and inputs to the target device, then time N_reps metric updates
    accuracy_train.to(device)
    x_ = x.to(device)
    y_ = y.to(device)
    times = []
    for _ in range(N_reps):
        start = time()
        accuracy_train(x_, y_)
        times.append(time() - start)
    times = np.array(times)
    print(f"Timing {device}: {np.mean(times)}+-{np.std(times)}")

run("cpu")
run("cuda")

Just adjust it to the metric that you want to use, then run it first on the master branch and then on this branch, and note down the times.
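
For WER, a possible adaptation might look like the sketch below (assumptions: the class-based metric is available as torchmetrics.WER, it takes predictions first and references second, and the timing harness simply mirrors the accuracy script above):

import numpy as np
from time import time

import torchmetrics

# Assumed class name torchmetrics.WER and string inputs; adjust to the actual API.
wer_train = torchmetrics.WER()
predictions = ["hello world", "the quick brown fox jumps over the lazy dog"] * 64
references = ["hello duck", "the quick brown fox jumped over a lazy dog"] * 64

N_reps = 50

def run(device):
    # Move the metric to the target device, then time N_reps metric updates
    wer_train.to(device)
    times = []
    for _ in range(N_reps):
        start = time()
        wer_train(predictions, references)
        times.append(time() - start)
    times = np.array(times)
    print(f"Timing {device}: {np.mean(times)}+-{np.std(times)}")

run("cpu")
run("cuda")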

@SkafteNicki (Member) left a review comment:

Some comments, else LGTM

Review comments (outdated, resolved) on torchmetrics/functional/text/wer.py (×3) and torchmetrics/text/wer.py (×3)
@mergify bot added the "ready" label on Aug 17, 2021
@kingyiusuen (Contributor, Author):

@SkafteNicki The existing implementation is a little faster; I think it is because jiwer uses C to calculate the edit distance.

New implementation:
Timing cpu: 0.00024170875549316405+-0.0002145938151226467
Timing cuda: 0.0003341794013977051+-0.00012605400813783062

Existing implementation:
Timing cpu: 0.00012010812759399414+-9.325457291475573e-05
I got an error when I ran the existing implementation on GPU.

<ipython-input-3-75ab6f7af532> in run(device)
     11 
     12 def run(device):
---> 13     wer_train.to(device)
     14     times = []
     15     for _ in range(N_reps):

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in to(self, *args, **kwargs)
    850             return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)
    851 
--> 852         return self._apply(convert)
    853 
    854     def register_backward_hook(

/content/metrics/torchmetrics/metric.py in _apply(self, fn)
    427                 setattr(this, key, fn(current_val))
    428             elif isinstance(current_val, Sequence):
--> 429                 setattr(this, key, [fn(cur_v) for cur_v in current_val])
    430             else:
    431                 raise TypeError(

/content/metrics/torchmetrics/metric.py in <listcomp>(.0)
    427                 setattr(this, key, fn(current_val))
    428             elif isinstance(current_val, Sequence):
--> 429                 setattr(this, key, [fn(cur_v) for cur_v in current_val])
    430             else:
    431                 raise TypeError(

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in convert(t)
    848                 return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None,
    849                             non_blocking, memory_format=convert_to_format)
--> 850             return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)
    851 
    852         return self._apply(convert)

AttributeError: 'list' object has no attribute 'to'
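
For reference, a minimal sketch of why .to(device) fails on master (an assumption based on the traceback: the jiwer-based metric keeps raw strings in its state as plain Python lists, which Metric._apply cannot move to a device; the class below is illustrative, not the actual implementation, and assumes a torchmetrics version from around this time):

from typing import List

from torchmetrics import Metric


class StringStateWER(Metric):
    """Illustrative metric whose state is lists of raw strings, like the jiwer-based WER."""

    def __init__(self):
        super().__init__()
        self.add_state("predictions", default=[], dist_reduce_fx=None)
        self.add_state("references", default=[], dist_reduce_fx=None)

    def update(self, preds: List[str], target: List[str]) -> None:
        self.predictions.append(preds)  # each element is a plain Python list of strings
        self.references.append(target)

    def compute(self) -> float:
        return 0.0  # placeholder


metric = StringStateWER()
metric.update(["hello world"], ["hello there"])

# Metric._apply iterates over list states and calls .to() on every element;
# a Python list of strings has no .to(), hence the AttributeError above.
metric.to("cpu")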

@SkafteNicki (Member):

Hi @kingyiusuen,
The time increase is not something to worry about, especially since the old implementation did not support GPU.

@Borda modified the milestones: v0.6, v0.5 on Aug 18, 2021
@Borda enabled auto-merge (squash) on Aug 18, 2021, 09:37
@Borda merged commit 689b218 into Lightning-AI:master on Aug 18, 2021
Borda pushed a commit that referenced this pull request Aug 18, 2021
* Remove dependency on jiwer for WER

* Add deprecation warning for concatenate_texts and add documentation

* Use jiwer as reference in tests

* Apply suggestions from review

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: SkafteNicki <skaftenicki@gmail.com>
(cherry picked from commit 689b218)