
RFC: torch==1.12 will toggle torch.backends.matmul.allow_tf32 to False - what should we do? #16588

Closed
stas00 opened this issue Apr 4, 2022 · 15 comments

@stas00
Contributor

stas00 commented Apr 4, 2022

Ampere GPUs added a new mode called TF32. PyTorch created a new flag, torch.backends.matmul.allow_tf32, to control whether TF32 mode is enabled, and it has been True by default in PyTorch since it was added.

Having this mode on means that matrix multiplications whose inputs are in FP32 are actually performed in TF32, which makes the math significantly faster, albeit less precise (TF32 has the dynamic range of BF16 and the precision of FP16).

The NVIDIA engineers have run many experiments and found that deep learning training accuracy isn't negatively impacted by using TF32 instead of FP32 (and is often better), while the speedup is significant. It's easy to see why from the A100 spec:

FP32 |  19.5 TFLOPS
TF32 | 156   TFLOPS 

(numbers with no sparsity)
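That is roughly an 8x difference in theoretical peak matmul throughput (156 / 19.5 = 8).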

And the accuracy tables are:
[Figure: AI training accuracy with TF32 Tensor Cores, from "Accelerating AI Training with NVIDIA TF32 Tensor Cores"]

However, the lost precision is a problem for some non-DL applications. Therefore, starting from PyTorch 1.12 (landing in the nightlies shortly), the default for torch.backends.matmul.allow_tf32 will be False. This won't make training accuracy worse, but it will make fp32 training significantly slower. So if we believe we should remain consistent/backward compatible, most likely we should turn it back on for pt > 1.11:

import torch
from packaging import version
if version.parse(torch.__version__) > version.parse("1.11"):
    torch.backends.matmul.allow_tf32 = True

at a single point which always gets executed for pytorch users.

The question is whether this should be done:

  1. Not at all - let the user sort it out
  2. Transformers-wide
  3. Only in the HF Trainer (and Accelerate); and if it's not forced there, add a new flag to let the user control the behavior

Additionally, other usage modes should be kept in sync:

  1. PyTorch/XLA (some other flag?)

Currently tf32 and how to flip it on/off is documented here: https://huggingface.co/docs/transformers/performance#tf32

A detailed discussion with multiple links to other related resources is here: https://dev-discuss.pytorch.org/t/pytorch-and-tensorfloat32/504

@LysandreJik, @sgugger, @patrickvonplaten, @patil-suraj

@stas00 stas00 self-assigned this Apr 4, 2022
@ngimel

ngimel commented Apr 4, 2022

You don't have to condition torch.backends.matmul.allow_tf32 = True on the torch version; on previous PyTorch versions it'll just be a no-op.

@stas00
Contributor Author

stas00 commented Apr 4, 2022

The main reason for suggesting the conditional was to be self-documenting, but also, without the conditional this code will fail on older PyTorch, for example:

$ python -c "import torch; print(torch.__version__); torch.backends.matmul.allow_tf32 = True"
1.8.1+cu102
Traceback (most recent call last):
  File "<string>", line 1, in <module>
AttributeError: module 'torch.backends' has no attribute 'matmul'
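
As an editor's illustrative sketch (not code from this thread): in recent PyTorch releases the flag is exposed as torch.backends.cuda.matmul.allow_tf32, and guarding on the attribute turns the assignment into a silent no-op on builds that don't expose it, instead of raising the AttributeError above:

import torch

# Flip the TF32 matmul flag only if this PyTorch build exposes it
# (torch.backends.cuda.matmul.allow_tf32 in recent releases); older or
# unusual builds simply skip the assignment instead of raising.
matmul_backend = getattr(getattr(torch.backends, "cuda", None), "matmul", None)
if matmul_backend is not None and hasattr(matmul_backend, "allow_tf32"):
    matmul_backend.allow_tf32 = True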

@stas00
Contributor Author

stas00 commented Apr 5, 2022

@mruberry shared on Slack that JAX has a similar flag (jax-ml/jax#6143), should you want to make this behavior consistent across all 3 frameworks and/or make it configurable. And they too have a default that is not appreciated by everyone who expects fp32 to be fp32: jax-ml/jax#7010
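
For reference, a rough sketch of the JAX-side control (an editor's addition; assumes a JAX version that exposes the jax.default_matmul_precision context manager and the jax_default_matmul_precision config option from jax-ml/jax#6143):

import jax
import jax.numpy as jnp

# Process-wide default for fp32 matmul precision (assumes the
# "jax_default_matmul_precision" config key is available).
jax.config.update("jax_default_matmul_precision", "float32")

# Or scoped to a block of code via the context manager:
with jax.default_matmul_precision("float32"):
    x = jnp.ones((1024, 1024), dtype=jnp.float32)
    y = x @ x  # matmuls in this block use full float32 precision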

@stas00 stas00 changed the title pytorch is turning off torch.backends.matmul.allow_tf32 for torch > 1.11 RFC: torch==1.12 will toggle torch.backends.matmul.allow_tf32 to False - what should we do? Apr 5, 2022
@patrickvonplaten
Contributor

patrickvonplaten commented Apr 5, 2022

Just to understand:
PyTorch added tf32, set it to True by default (from which version to which version?) and now reverted the default to False?

I think I'm in favor of not overwriting torch.backends.matmul.allow_tf32 = True, but instead adding good documentation to let the user decide what to do here. I'm also happy to add a tf32 flag to the Trainer, though I would set it to False as well. I think overwriting torch.backends.matmul.allow_tf32 = True gets us out of sync with PyTorch and might lead to unexpected behavior, no?

E.g. if a user does:

import torch
torch.backends.matmul.allow_tf32 = False

import transformers

....

Also, I think it's a good rule of thumb that in PyTorch, by default, the highest precision / lowest speed is always enabled.

I think we don't have to, or shouldn't, really care about JAX here, as the default precision/device behavior is already very different (e.g. JAX uses the lowest precision on TPU by default, and uses GPU/TPU by default, in contrast to PyTorch).

@gante
Member

gante commented Apr 5, 2022

TensorFlow has it active by default and has a flag to control it (docs). I'd say we don't need to touch it in TF, but I'm happy to go with a solution that minimizes PT-TF interface differences.
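
For completeness, a minimal sketch of the TF control referenced above (an editor's addition, using the experimental API available since TF 2.4):

import tensorflow as tf

# TF32 execution is enabled by default on Ampere GPUs; this turns it off
# so fp32 matmuls stay in full float32 precision.
tf.config.experimental.enable_tensor_float_32_execution(False)
print(tf.config.experimental.tensor_float_32_execution_enabled())  # False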

@sgugger
Collaborator

sgugger commented Apr 5, 2022

This is very complicated: on the one hand, we don't want to change the PyTorch default and surprise the user, but on the other hand we don't want most of our beginner users to experience degraded training performance on most GPUs without knowing why (as this change will be hidden in the PyTorch release notes).

I'm also in favor of not touching PyTorch's default (the same way we don't turn on things like torch.backends.cudnn.benchmark or torch.backends.cudnn.deterministic) and leaving it to the user, but we do need proper documentation. Also in favor of having a TrainingArguments flag to make it easier for the user to turn it on in our examples.

@mruberry

mruberry commented Apr 5, 2022

Just to understand: PyTorch added tf32, set it to True by default (from which version to which version?) and now reverted the default to False?

Small point of clarification: we have not changed the default to False at this time, but expect to do so in the future.

Also, I think it's a good rule of thumb that in PyTorch, by default, the highest precision / lowest speed is always enabled.

Agreed! This is the principle that motivated this change.

We will also have user-facing documentation beyond the release notes when this change becomes part of a PyTorch release, because we agree it has the potential to be surprising and disruptive to current Ampere users. We'll also provide a recommendation for developers when making this change in the nightlies.

@stas00
Contributor Author

stas00 commented Apr 5, 2022

Just to understand: PyTorch added tf32, set it to True by default (from which version to which version?) and now reverted the default to False?

I think it was added in pt-1.9, since 1.8 doesn't have this flag. see #16588 (comment)

and the plan is to revert to False in pt-1.12, but, of course, this will happen sooner in pt-nightly.

So it has been set to True in pt: 1.9, 1.10, 1.11

@stas00
Contributor Author

stas00 commented Apr 8, 2022

Also in favor of having a TrainingArguments flag to make it easier for the user to turn it on in our examples.

I forgot that I added it already when we added bf16 support:

tf32 (`bool`, *optional*):
Whether to enable tf32 mode, available in Ampere and newer GPU architectures. This is an experimental API
and it may change.

Except it has no default setting; I guess we keep it that way, w/o a default?
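
For context, a quick usage sketch of that existing flag (an editor's addition; output_dir is a placeholder and all other arguments are omitted):

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",  # placeholder path
    tf32=True,         # explicitly opt in to TF32 on Ampere+ GPUs
)
# tf32=False would explicitly opt out; leaving it unset keeps whatever
# the installed PyTorch version's default happens to be.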

@stas00
Contributor Author

stas00 commented Apr 8, 2022

I'm also in favor of not touching PyTorch's default (the same way we don't turn on things like torch.backends.cudnn.benchmark or torch.backends.cudnn.deterministic) and leaving it to the user, but we do need proper documentation.

Please review the current doc and suggest if anything needs to be changed:
https://huggingface.co/docs/transformers/performance#tf32

Thank you!

@sgugger
Collaborator

sgugger commented Apr 8, 2022

Yes, that doc is great. We should also expand the documentation of the flag in TrainingArguments a bit (and link to this doc), since this is where users might encounter TF32 for the first time. That flag should indeed be left without a default (deferring to the current PyTorch version's default).

@stas00
Contributor Author

stas00 commented Apr 8, 2022

Thank you for reviewing and the feedback, Sylvain.

Here is a PR: #16674

@github-actions

github-actions bot commented May 5, 2022

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@mruberry

mruberry commented May 5, 2022

FYI pytorch/pytorch#76509 has landed, and while it may not be perfect, we think it achieves the goal of giving users device-agnostic control over fp32 matmul precision. Please don't hesitate to reach out if you have additional questions; I'll also be producing additional documentation on this change ahead of the PyTorch 1.12 release.
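
For reference, a short sketch of the device-agnostic API that pytorch/pytorch#76509 introduces (an editor's addition; available from PyTorch 1.12):

import torch

# "highest" keeps full fp32 matmul precision, "high" allows TF32 (or a
# comparably reduced-precision path), and "medium" additionally allows bf16.
torch.set_float32_matmul_precision("high")
print(torch.get_float32_matmul_precision())  # "high"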

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
