
Device and dtype properties #462

Merged
merged 18 commits into from
Aug 26, 2021
Conversation

SkafteNicki
Member

@SkafteNicki SkafteNicki commented Aug 18, 2021

Before submitting

  • Was this discussed/approved via a Github issue? (no need for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?

What does this PR do?

Fixes #455
The issue describes how the Bootstrapper metric currently does not work on GPU. Trying to fix this made me realize that we do not have an easy way of getting the device and dtype of a metric. This PR implements the logic from the DeviceDtypeMixin class, taken from lightning, in the core Metric class:
https://github.com/PyTorchLightning/pytorch-lightning/blob/38ceb8943ef9b858abead1fbba43ea9a9b4cd93b/pytorch_lightning/core/mixins/device_dtype_mixin.py

Additionally, the issue is solved using the new property :]
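For illustration, the device/dtype tracking described above can be sketched roughly like this. This is a hypothetical minimal version, not the actual torchmetrics or lightning code; it relies on torch's private `torch._C._nn._parse_to` helper to resolve the overloads of `to`, as the lightning mixin does:

```python
import torch
from torch import nn


class DeviceDtypeModule(nn.Module):
    """Hypothetical sketch of a module that tracks its own device and
    dtype, in the spirit of lightning's DeviceDtypeMixin."""

    def __init__(self) -> None:
        super().__init__()
        self._device = torch.device("cpu")
        self._dtype = torch.get_default_dtype()

    @property
    def device(self) -> torch.device:
        return self._device

    @property
    def dtype(self) -> torch.dtype:
        return self._dtype

    def to(self, *args, **kwargs):
        # torch's private helper resolves the many overloads of ``to``
        device, dtype, *_ = torch._C._nn._parse_to(*args, **kwargs)
        if device is not None:
            self._device = device
        if dtype is not None:
            self._dtype = dtype
        return super().to(*args, **kwargs)

    def half(self):
        self._dtype = torch.half
        return super().half()

    def cpu(self):
        self._device = torch.device("cpu")
        return super().cpu()
```

With this, `metric.device` and `metric.dtype` stay in sync after calls like `metric.to(torch.float16)` or `metric.half()`.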

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

@SkafteNicki SkafteNicki added the bug / fix and enhancement labels Aug 18, 2021
@SkafteNicki SkafteNicki added this to the v0.6 milestone Aug 18, 2021
@mergify mergify bot removed the has conflicts label Aug 18, 2021
@codecov

codecov bot commented Aug 18, 2021

Codecov Report

Merging #462 (eede3fe) into master (94a158c) will decrease coverage by 0%.
The diff coverage is 84%.

@@          Coverage Diff          @@
##           master   #462   +/-   ##
=====================================
- Coverage      96%    96%   -0%     
=====================================
  Files         130    130           
  Lines        4301   4341   +40     
=====================================
+ Hits         4126   4159   +33     
- Misses        175    182    +7     

@mergify mergify bot added the ready label Aug 18, 2021
Contributor

@ananthsub ananthsub left a comment


Does this mean metrics used in a lightning module where mixed-precision training is used would be converted to fp16 precision as well? Is that always desirable? Do people want to compute metrics in fp32 while doing the rest of model training in fp16?

@SkafteNicki
Member Author

@ananthsub the PR does not actually introduce that kind of change; it has always been the case in TM that if you cast your metric to fp16, the metric states are also cast (since we have overridden the self._apply method):
https://github.com/PyTorchLightning/metrics/blob/689b2189c6f2aff3968d94d8e5fcfdb85dc5b98a/torchmetrics/metric.py#L414-L441
This PR just makes sure that when half(), cpu(), cuda(), to(...) is called, we have local properties that track this.

As for whether people want fp32 metrics when doing mixed-precision training, I am not sure; I doubt the extra precision matters much to most users during training. However, when it comes to testing, it is very clear to me that users should be using fp32.
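The state-casting behavior referenced above (custom state tensors following `.half()`/`.to()` because `_apply` is overridden) can be sketched as follows. This is a simplified, hypothetical illustration, not the real `Metric._apply`:

```python
import torch
from torch import nn


class StatefulMetricSketch(nn.Module):
    """Hypothetical sketch: metric states are plain tensors (neither
    parameters nor buffers), so _apply is overridden to move/cast them
    together with the rest of the module."""

    def __init__(self) -> None:
        super().__init__()
        # a "metric state" kept outside parameters and buffers
        self.total = torch.tensor(0.0)
        self._state_names = ["total"]

    def _apply(self, fn):
        # let nn.Module handle parameters/buffers first ...
        this = super()._apply(fn)
        # ... then apply the same conversion to the custom states
        for name in self._state_names:
            setattr(this, name, fn(getattr(this, name)))
        return this
```

Because `half()`, `float()`, `cpu()`, `cuda()` and `to(...)` all funnel through `_apply`, the custom state follows every cast and move.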

@Borda Borda requested a review from ananthsub August 19, 2021 08:19
CHANGELOG.md Outdated Show resolved Hide resolved
CHANGELOG.md Outdated Show resolved Hide resolved
torchmetrics/metric.py Show resolved Hide resolved
torchmetrics/metric.py Show resolved Hide resolved
torchmetrics/text/bert.py Show resolved Hide resolved
@maximsch2
Contributor

On the mixed-precision case: I agree with Ananth that this is potentially a concern (especially for metrics with accumulations, since fp16 will overflow at ~64k, so having a 100k-sample dataset and doing fp16 training will break even simple metrics like Accuracy), but it is not a new thing introduced by this diff. Let's file an issue and track it? I'm assuming people will get nan/inf as the metric result in case of fp16 overflow, so at least we are not going to silently screw them up.
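The accumulation concern can be demonstrated with plain torch tensors, independent of torchmetrics. Note that counting actually degrades well before the ~64k overflow point, because fp16 cannot represent every integer above 2048:

```python
import torch

# Counting samples in fp16 silently stalls: above 2048 the spacing between
# representable fp16 values exceeds 1, so adding 1 rounds back to the same value.
count = torch.tensor(0.0, dtype=torch.float16)
for _ in range(3000):
    count += 1.0
print(count.item())  # 2048.0, not 3000.0

# Sums overflow to inf once they pass fp16's maximum (~65504)
total = torch.tensor(60000.0, dtype=torch.float16)
total += torch.tensor(10000.0, dtype=torch.float16)
print(total.item())  # inf
```

So an fp16 accumulator can be silently wrong (stalled counts) even in cases where it does not visibly overflow to inf.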

It would be worth having an example in the docs of how to do mixed-precision metric calculation once this question comes up, though.
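A hypothetical sketch of what such a docs example might show: keep the accumulation state in fp32 and upcast incoming predictions, even when the forward pass runs under autocast (`accumulate` is an illustrative helper, not a torchmetrics API):

```python
import torch


def accumulate(state: torch.Tensor, preds: torch.Tensor) -> torch.Tensor:
    # cast incoming (possibly reduced-precision) predictions up to the
    # state's dtype before summing, so accumulation stays in fp32
    return state + preds.detach().to(state.dtype).sum()


state = torch.zeros((), dtype=torch.float32)
with torch.autocast("cpu", dtype=torch.bfloat16):
    preds = torch.ones(10, dtype=torch.bfloat16)  # stand-in for model output
    state = accumulate(state, preds)
print(state.dtype)  # torch.float32
```

The model forward can run in reduced precision while the metric state never leaves fp32, which avoids the overflow/precision issues discussed above.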

@mergify mergify bot removed the has conflicts label Aug 24, 2021
Contributor

@tchaton tchaton left a comment


Great!

@mergify mergify bot removed the has conflicts label Aug 26, 2021
@Borda
Member

Borda commented Aug 26, 2021

but it is not a new thing introduced by this diff. Let's file an issue and track it? I'm assuming people will get nan/inf as the metric result in case of fp16 overflow, so at least we are not going to silently screw them up.

yes, pls do so 🐰

@Borda Borda enabled auto-merge (squash) August 26, 2021 07:31
@Borda
Member

Borda commented Aug 26, 2021

@SkafteNicki mind checking/resolving the last comments?

@Borda Borda merged commit b10dba4 into master Aug 26, 2021
@Borda Borda deleted the device_placement branch August 26, 2021 08:17
@ananthsub
Contributor

@SkafteNicki we are having a very related discussion about device & dtype properties here: https://docs.google.com/document/d/1xHU7-iQSpp9KJTjI3As2EM0mfNHHr37WZYpDpwLkivA/edit#heading=h.cvihcwdhwas5

Given metrics are nn.Modules, what happens if metrics have parameters which live on different devices or have different dtypes? Then we're at odds with this: pytorch/pytorch#7460 (comment)

This makes metrics a restricted set of modules, which could potentially limit use cases in the future.

@leezu

leezu commented Aug 28, 2021

@ananthsub

Does this mean metrics used in a lightning module where mixed-precision training is used would be converted to fp16 precision as well? Is that always desirable? Do people want to compute metrics in fp32 while doing the rest of model training in fp16?

It's not desirable, but this behavior was already introduced by accident in cda5dbd. I opened #484 for tracking.

Borda pushed a commit that referenced this pull request Aug 30, 2021
* add gpu testing
* change super
* move to metric + simplify
* fix bert
* update docs
* add typing
* changelog

* Apply suggestions from code review

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

(cherry picked from commit b10dba4)
@SkafteNicki SkafteNicki mentioned this pull request Sep 2, 2021
4 tasks
@SkafteNicki
Member Author

@ananthsub, @maximsch2 please see PR #493, which should fix the problems with autocast. half, double and float are getting disabled for now and will not change the dtype of the metric states, which should hopefully fix the problems with mixed-precision training.
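The approach described (disabling dtype casts while leaving device movement intact) might look roughly like this minimal, hypothetical sketch:

```python
import torch
from torch import nn


class MetricSketch(nn.Module):
    """Hypothetical sketch: dtype-changing methods become no-ops so metric
    states stay in full precision, while device moves (cpu/cuda/to) keep
    working through the usual nn.Module machinery."""

    def __init__(self) -> None:
        super().__init__()
        self.register_buffer("total", torch.zeros(()))  # fp32 state

    # dtype casts are disabled: states keep their original precision
    def half(self):
        return self

    def double(self):
        return self

    def float(self):
        return self
```

Calling `.half()` on such a metric leaves its state in fp32, while `.to("cuda")` would still move it as usual.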


Successfully merging this pull request may close these issues.

GPU support in BootStrapper
7 participants