Use argmax when topk=1 #419

Merged 7 commits on Aug 3, 2021

Conversation

SkafteNicki (Member) commented Aug 3, 2021

Before submitting

  • Was this discussed/approved via a GitHub issue? (not needed for typo fixes and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?

What does this PR do?

Fixes #417
torch.topk performs poorly on CPU. This PR replaces the call to torch.topk with torch.argmax when topk=1, which gives better performance (roughly 7x faster on CPU in the Accuracy timings below, about the same on GPU):

Using torch.topk:
Timing cpu: 0.03081380844116211+-0.006388481400645466
Timing cuda: 0.0037766695022583008+-0.0007294738042081203

Using torch.argmax:
Timing cpu: 0.0041289663314819335+-0.0004889978751252193
Timing cuda: 0.0038677167892456053+-0.002895336883318616

Script:

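import torch
import torchmetrics
from time import time
import numpy as np

# Per-class accuracy over 5 classes
accuracy_train = torchmetrics.Accuracy(num_classes=5, average='none')
# Predictions have shape (batch, classes, ...); targets hold integer class labels
x = torch.rand(1, 5, 3, 28, 28)
y = torch.randint(0, 5, (1, 3, 28, 28))

N_reps = 50

def run(device):
    # Move the metric and the inputs to the device being timed
    accuracy_train.to(device)
    x_ = x.to(device)
    y_ = y.to(device)
    times = []
    for _ in range(N_reps):
        start = time()
        accuracy_train(x_, y_)
        times.append(time() - start)
    # Report mean and standard deviation of the wall-clock time in seconds
    times = np.array(times)
    print(f"Timing {device}: {np.mean(times)}+-{np.std(times)}")

run("cpu")
run("cuda")  # requires a CUDA-capable device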
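
The swap described in this PR (torch.argmax in place of torch.topk when topk=1) can be sketched roughly as below. This is a minimal illustration, not the code merged here: the helper names `select_topk_sketch` and `select_topk_reference`, their signatures, and the 0/1 output convention are assumptions made for the example. It only checks that the two paths select the same entries on random input, where ties are practically absent.

```python
import torch


def select_topk_sketch(prob: torch.Tensor, topk: int = 1, dim: int = 1) -> torch.Tensor:
    """Return a 0/1 int tensor marking the `topk` largest entries along `dim`.

    Hypothetical helper for illustration only.
    """
    if topk == 1:
        # Fast path: argmax gives the index of the single largest entry.
        indices = prob.argmax(dim=dim, keepdim=True)
    else:
        # General path: topk returns (values, indices); only indices are needed.
        indices = prob.topk(k=topk, dim=dim).indices
    # Write 1 at the selected indices and leave 0 everywhere else.
    return torch.zeros_like(prob).scatter(dim, indices, 1.0).int()


def select_topk_reference(prob: torch.Tensor, topk: int = 1, dim: int = 1) -> torch.Tensor:
    """Same selection computed with torch.topk only (assumed baseline)."""
    indices = prob.topk(k=topk, dim=dim).indices
    return torch.zeros_like(prob).scatter(dim, indices, 1.0).int()


torch.manual_seed(0)
x = torch.rand(1, 5, 3, 28, 28)  # same input shape as the timing script
fast = select_topk_sketch(x, topk=1)
ref = select_topk_reference(x, topk=1)
same = torch.equal(fast, ref)  # both paths should mark the same entries
print(f"argmax path matches topk path: {same}")
```

When several entries share the maximum value, the two functions may pick different indices, so the equality check is only meaningful for inputs without ties.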

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

SkafteNicki added the enhancement label Aug 3, 2021

pep8speaks commented Aug 3, 2021

Hello @SkafteNicki! Thanks for updating this PR.

There are currently no PEP 8 issues detected in this Pull Request. Cheers! 🍻

Comment last updated at 2021-08-03 10:52:15 UTC

codecov bot commented Aug 3, 2021

Codecov Report

Merging #419 (e03878a) into master (21fe0ca) will decrease coverage by 0.07%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##           master     #419      +/-   ##
==========================================
- Coverage   96.08%   96.00%   -0.08%     
==========================================
  Files         126      126              
  Lines        4083     4085       +2     
==========================================
- Hits         3923     3922       -1     
- Misses        160      163       +3     
Flag         Coverage Δ
Linux        74.66% <100.00%> (+0.01%) ⬆️
Windows      74.66% <100.00%> (+0.01%) ⬆️
cpu          96.00% <100.00%> (+<0.01%) ⬆️
gpu          ?
macOS        96.00% <100.00%> (+<0.01%) ⬆️
pytest       96.00% <100.00%> (-0.08%) ⬇️
python3.6    95.25% <100.00%> (+<0.01%) ⬆️
python3.8    96.00% <100.00%> (+<0.01%) ⬆️
python3.9    95.91% <100.00%> (+<0.01%) ⬆️
torch1.3.1   95.25% <100.00%> (+<0.01%) ⬆️
torch1.4.0   95.32% <100.00%> (+<0.01%) ⬆️
torch1.9.0   95.91% <100.00%> (+<0.01%) ⬆️

Flags with carried forward coverage won't be shown.

Impacted Files                                   Coverage Δ
torchmetrics/utilities/data.py                   94.91% <100.00%> (+0.17%) ⬆️
torchmetrics/functional/regression/spearman.py   93.33% <0.00%> (-4.45%) ⬇️
torchmetrics/metric.py                           95.16% <0.00%> (-0.31%) ⬇️

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update 21fe0ca...e03878a.

Borda enabled auto-merge (squash) August 3, 2021 11:00

tchaton (Contributor) left a comment

LGTM !

mergify bot added the ready label Aug 3, 2021
Borda merged commit 5a2388c into master Aug 3, 2021
Borda deleted the cpu_topk_performance branch August 3, 2021 12:28
Borda added this to the v0.5 milestone Aug 18, 2021
Linked issue: Too slow on cpu.