Remove dependency on jiwer for WER #446
Conversation
Codecov Report
@@ Coverage Diff @@
## master #446 +/- ##
==========================================
- Coverage 95.92% 95.90% -0.03%
==========================================
Files 130 130
Lines 4274 4301 +27
==========================================
+ Hits 4100 4125 +25
- Misses 174 176 +2
Not sure why the test coverage of torchmetrics/average.py changed. I didn't even modify this file.
This is awesome :) Any chance we could get a quick speed comparison between the existing implementation on master and this implementation? Just to ensure we're not regressing in terms of how long it takes to compute the measurement.
These can be interim results for now, before the GPU coverage is also aggregated; do not worry about sections which you did not touch 🐰
Sure. @SeanNaren, do you have a code example of how the speed comparison can be done?
@kingyiusuen could you make some variation of the following script, which I used to measure an optimization I did to the accuracy metric?

import torch
import torchmetrics
from time import time
import numpy as np

# Per-class accuracy on random 5-class data.
accuracy_train = torchmetrics.Accuracy(num_classes=5, average='none')
x = torch.rand(1, 5, 3, 28, 28)
y = torch.randint(0, 5, (1, 3, 28, 28))
N_reps = 50

def run(device):
    # Time N_reps forward calls of the metric on the given device.
    accuracy_train.to(device)
    x_ = x.to(device)
    y_ = y.to(device)
    times = []
    for _ in range(N_reps):
        start = time()
        accuracy_train(x_, y_)
        times.append(time() - start)
    times = np.array(times)
    print(f"Timing {device}: {np.mean(times)}+-{np.std(times)}")

run("cpu")
run("cuda")

Just adjust it to the metric that you want to use, then try to run it first on the master branch and then on this branch, and note down the times.
Some comments, else LGTM
@SkafteNicki The existing implementation is a little faster; I think it is because jiwer uses C to calculate the edit distance.

New implementation: […]
Existing implementation: […]

Also, calling wer_train.to(device) in the benchmark script raises the following error:

<ipython-input-3-75ab6f7af532> in run(device)
     11
     12 def run(device):
---> 13     wer_train.to(device)
     14     times = []
     15     for _ in range(N_reps):

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in to(self, *args, **kwargs)
    850             return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)
    851
--> 852         return self._apply(convert)
    853
    854     def register_backward_hook(

/content/metrics/torchmetrics/metric.py in _apply(self, fn)
    427                 setattr(this, key, fn(current_val))
    428             elif isinstance(current_val, Sequence):
--> 429                 setattr(this, key, [fn(cur_v) for cur_v in current_val])
    430             else:
    431                 raise TypeError(

/content/metrics/torchmetrics/metric.py in <listcomp>(.0)
    427                 setattr(this, key, fn(current_val))
    428             elif isinstance(current_val, Sequence):
--> 429                 setattr(this, key, [fn(cur_v) for cur_v in current_val])
    430             else:
    431                 raise TypeError(

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in convert(t)
    848                 return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None,
    849                            non_blocking, memory_format=convert_to_format)
--> 850             return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)
    851
    852         return self._apply(convert)

AttributeError: 'list' object has no attribute 'to'
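For what it's worth, a minimal sketch of what seems to go wrong here (an assumption based on the traceback: the metric keeps its state as nested Python lists of strings, and Metric._apply maps the tensor-conversion function over every element of a Sequence state):

# Sketch only: mimics the failing list comprehension in Metric._apply.
current_val = [["hello", "world"], ["foo", "bar"]]  # list state, not tensors

def convert(t):
    return t.to("cpu")  # the function that Module.to applies to each state element

try:
    [convert(cur_v) for cur_v in current_val]
except AttributeError as err:
    print(err)  # 'list' object has no attribute 'to'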
Hi @kingyiusuen, |
* Remove dependency on jiwer for WER
* Add deprecation warning for concatenate_texts and add documentation
* Use jiwer as reference in tests
* Apply suggestions from review

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: SkafteNicki <skaftenicki@gmail.com>
(cherry picked from commit 689b218)
Before submitting
What does this PR do?
This PR is part of the effort in #433 to implement our own metrics and minimize third-party dependencies. It adds our own implementation of the word error rate (WER) metric.
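For context, the WER is the word-level Levenshtein (edit) distance between the predicted and reference transcripts, normalized by the total number of reference words. A minimal sketch of such an implementation in pure Python (illustrative only, not necessarily the exact code in this PR) might look like:

def _edit_distance(pred_words, ref_words):
    # Classic dynamic-programming Levenshtein distance over words.
    dp = [[0] * (len(ref_words) + 1) for _ in range(len(pred_words) + 1)]
    for i in range(len(pred_words) + 1):
        dp[i][0] = i
    for j in range(len(ref_words) + 1):
        dp[0][j] = j
    for i in range(1, len(pred_words) + 1):
        for j in range(1, len(ref_words) + 1):
            if pred_words[i - 1] == ref_words[j - 1]:
                dp[i][j] = dp[i - 1][j - 1]
            else:
                dp[i][j] = 1 + min(dp[i - 1][j - 1], dp[i - 1][j], dp[i][j - 1])
    return dp[-1][-1]

def wer(predictions, references):
    # Total word errors over total reference words.
    errors = sum(
        _edit_distance(p.split(), r.split()) for p, r in zip(predictions, references)
    )
    total = sum(len(r.split()) for r in references)
    return errors / total

print(wer(["hello world"], ["hello duck"]))  # 0.5: one substitution out of two words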
This PR also strips out the concatenate_texts argument, as it does not seem to be used by any other metric.

Furthermore, this PR changes the order of references and predictions in the function signature of the functional version of WER, i.e. wer(predictions, references) instead of wer(references, predictions). Currently, as far as I can tell, all metrics (including the class-based version of WER) accept model predictions as their first argument and the ground truths second, so this change keeps the API consistent. (But it breaks backward compatibility? Any suggestions?)

PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.
Did you have fun?
Make sure you had fun coding 🙃