
Improved docs for (default) matmul precision #10413

Open
sanchit-gandhi opened this issue Apr 22, 2022 · 1 comment
Labels: documentation, enhancement

Comments


sanchit-gandhi commented Apr 22, 2022

  1. It is somewhat unintuitive that the default matmul precision on TPU is bfloat16, especially for users coming from PyTorch/GPU, where the default precision is float32. Information about the default matrix-multiplication precision on TPUs is extremely difficult to find. There is a short section in the README.md of the Cloud TPU Colabs folder of the JAX repo: https://github.com/google/jax/tree/main/cloud_tpu_colabs#bfloat16-dtype However, this section is somewhat unclear: it refers to 'MXUs' (the TPU's matrix multiply units) without explaining the abbreviation, and it only shows how the default precision can be overridden manually on an op-by-op basis by setting precision=jax.lax.Precision.XXX. This gives the impression that, in order to run matmuls in float32 on TPU, one must insert the keyword argument precision=jax.lax.Precision.HIGHEST into every jax.numpy operation in one's script (see the per-op sketch after this list).

  2. It is difficult to find out how the default precision can be changed. Performing matmuls in the default bfloat16 precision can lead to undesirable results. At Hugging Face, we constantly run into problems with the fast but low-precision TPU default, as shown here for example: Diverging PT-Flax Wav2Vec2 Hidden-States huggingface/transformers#15754
    For changing the default matmul precision, the docs do mention the default matmul precision context manager: https://jax.readthedocs.io/en/latest/_autosummary/jax.default_matmul_precision.html However, they do not explicitly show how to use this context manager (for instance with an example). It is hard to know from the docs that you have to wrap your code in the context manager as follows:

with jax.default_matmul_precision('float32'):   # or 'bfloat16' for lowest
  ... = foo(...)
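
To illustrate point 1 above, here is a minimal sketch of the op-by-op alternative (the array shapes and variable names are purely illustrative):

import jax
import jax.numpy as jnp

x = jax.random.normal(jax.random.PRNGKey(0), (128, 128), dtype=jnp.float32)

# Backend default: on TPU the MXU runs this with bfloat16 inputs.
y_default = jnp.dot(x, x)

# Forces full float32 accuracy for this one operation only.
y_full = jnp.dot(x, x, precision=jax.lax.Precision.HIGHEST)

Having to thread precision=... through every single call like this is exactly what the global options below avoid.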

The docs also brush over three additional methods for changing the default matmul precision, highlighted brilliantly in this PR: #6143 (comment) These three methods require no change to the body of one's script, just a shell/command-line flag or a JAX config change, and are arguably much easier to use and less obtrusive (see the sketch below).
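
For reference, a sketch of those three approaches, assuming the jax_default_matmul_precision config name used in that PR (the exact spellings are worth double-checking against it):

# 1) Environment variable, set before the script starts:
#      JAX_DEFAULT_MATMUL_PRECISION=float32 python train.py
#
# 2) Command-line flag (only picked up if absl flags are parsed, e.g. via
#    jax.config.parse_flags_with_absl()):
#      python train.py --jax_default_matmul_precision=float32
#
# 3) One-line config update near the top of the script:
import jax
jax.config.update('jax_default_matmul_precision', 'float32')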

It would be great if the default matmul precisions for CPU/GPU/TPU were documented, along with what bfloat16, tensorfloat32, and float32 precision actually mean for matmuls in terms of the number of passes. It would also be super helpful if all four methods for manipulating the default precision were added to the docs, with short examples of how to use them, as done in the aforementioned PR.
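
As one possible shape for such an example, a small comparison of the backend default against forced float32 precision (the size of the gap depends on the backend; it is only pronounced on TPU and on GPUs that default to tensorfloat32):

import jax
import jax.numpy as jnp

a = jax.random.normal(jax.random.PRNGKey(0), (1024, 1024), dtype=jnp.float32)

low = jnp.matmul(a, a)                                        # backend default
high = jnp.matmul(a, a, precision=jax.lax.Precision.HIGHEST)  # full float32

# Near zero on CPU; clearly non-zero on TPU with the bfloat16 default.
print(jnp.max(jnp.abs(low - high)))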

shoyer (Collaborator) commented Apr 4, 2023

Note that the default precision for matrix-matrix multiplication is actually now tensorfloat32 on recent Nvidia GPUs: #14022
