Matrix multiplication precision API #2161

Closed
shoyer opened this issue Feb 4, 2020 · 8 comments · Fixed by #6143

@shoyer
Collaborator

shoyer commented Feb 4, 2020

Operations that do matrix multiplication in JAX accept a precision argument that controls the precision used when they are executed on TPUs.

The current API seems non-ideal to me:

  1. You have to pass an enum value (e.g., np.dot(x, y, precision=lax.Precision.HIGHEST)). This is a little cumbersome and inconsistent with most NumPy/SciPy APIs, which use strings (e.g., np.dot(x, y, precision='highest')).
  2. The current names for precision levels ("highest", "high" and "default") are not very descriptive. In my ideal world we would use some direct indication of the corresponding precision (e.g., bfloat16 multiplication with float32 accumulation), but at the very least, can we switch "default" to "low"?
  3. The default low precision is a bit of a footgun, at least when doing anything that isn't implementing a neural net layer. In my opinion, it would be much safer to use "highest" precision by default (which isn't that much slower) on float32 data. Neural net libraries, of course, can default to lower precision, so this really only affects users who directly use NumPy APIs or the @ infix operator.
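
To make point 1 concrete, here is a minimal sketch of how an API could accept either an enum member or a string name. The `Precision` class below is a hypothetical stand-in for lax.Precision, and `canonicalize_precision` is an illustrative helper name; neither is taken from the JAX source.

```python
import enum


class Precision(enum.Enum):
    """Hypothetical stand-in for lax.Precision, for illustration only."""
    DEFAULT = 0
    HIGH = 1
    HIGHEST = 2


def canonicalize_precision(precision):
    """Accept None, a Precision member, or a case-insensitive string name."""
    if precision is None or isinstance(precision, Precision):
        return precision
    if isinstance(precision, str):
        try:
            return Precision[precision.upper()]
        except KeyError:
            valid = ", ".join(repr(p.name.lower()) for p in Precision)
            raise ValueError(
                f"unknown precision {precision!r}; expected one of {valid}"
            ) from None
    raise TypeError(f"precision must be None, str or Precision, got {type(precision)}")


# Both spellings resolve to the same enum member.
print(canonicalize_precision("highest") is canonicalize_precision(Precision.HIGHEST))
```

With a helper like this at the API boundary, existing callers that pass enum values keep working while new callers can use the shorter string form.
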
@shoyer
Collaborator Author

shoyer commented Apr 17, 2020

On TPUs, "HIGH" precision corresponds to 3 passes of bfloat16 and "HIGHEST" precision corresponds to 6 passes, which is effectively full float32 (dropping denormals), as explained in this Intel paper.
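
As a rough illustration of why more passes recover more precision, the sketch below splits each float32 operand into three bfloat16 pieces and sums a subset of the partial products. It is a scalar, pure-Python simulation under stated assumptions: the choice of which partial products make up the 3-pass and 6-pass variants follows one common construction, accumulation happens in Python floats (float64) rather than float32, and none of this is confirmed to be what XLA actually executes. All function names are illustrative.

```python
import random
import struct


def to_float32(x):
    """Round a Python float to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bfloat16(x):
    """Round a finite value to the nearest bfloat16 (ties to even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # round to nearest, ties to even
    bits &= 0xFFFF0000                   # keep only the top 16 bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def split3(x):
    """Write x as hi + mid + lo, where each piece is a bfloat16 value."""
    hi = to_bfloat16(x)
    mid = to_bfloat16(x - hi)
    lo = to_bfloat16(x - hi - mid)
    return hi, mid, lo


# Partial products ordered from most to least significant (assumed ordering).
TERMS = [(0, 0), (0, 1), (1, 0), (1, 1), (0, 2), (2, 0)]


def multiply(x, y, passes):
    """Approximate x * y using the first `passes` bfloat16 partial products."""
    xs, ys = split3(x), split3(y)
    return sum(xs[i] * ys[j] for i, j in TERMS[:passes])


def max_relative_error(passes, pairs):
    return max(abs(multiply(x, y, passes) - x * y) / abs(x * y) for x, y in pairs)


rng = random.Random(0)
pairs = [
    (to_float32(rng.uniform(1.0, 2.0)), to_float32(rng.uniform(1.0, 2.0)))
    for _ in range(1000)
]
errors = {passes: max_relative_error(passes, pairs) for passes in (1, 3, 6)}
for passes, err in errors.items():
    print(f"{passes} pass(es): max relative error {err:.3e}")
```

In this simulation the worst-case error shrinks by several orders of magnitude at each step from 1 to 3 to 6 passes, which is the trade-off the precision levels expose.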

With that in mind, and considering that ideally we would retain some flexibility for alternative matmul optimizations that might appear on other platforms, what more descriptive naming scheme makes sense for values of the precision argument?

Some ideas:

  1. 'low', 'high', 'highest': description of precision level
  2. 'fastest', 'fast', 'slow': description of speed
  3. 'fastest', 'fast', 'accurate': mixed description, using only positive words
  4. 'bfloat16', 'float24', 'float32': rough precision of the underlying arithmetic (but what is "float24"??)

I think I lean towards option 3?

@shoyer
Collaborator Author

shoyer commented Apr 17, 2020

Notes from offline discussion:

  • This is really a "minimum precision" configuration, so perhaps a name like min_precision would be more appropriate.
  • The other way to configure matmul precision (maybe more obvious) is by explicitly setting dtype. XLA will use lowest precision on bfloat16 data regardless of the precision option.
  • We want an API that also can support new matmul precision options as they arise on different platforms (GPU, CPU, etc), e.g., precision={'tpu': 'bfloat16', 'gpu': 'float16'}.
  • Another option would be to specify precision numerically, e.g., precision=1e-2 or precision=1e-6. But this mixes together precision in the significand and the exponent, which misses important nuances like bfloat16 vs float16.
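
The last bullet can be illustrated by deriving the unit roundoff and the largest finite value from the bit layout of each format. This is a small sketch using the standard exponent/fraction widths for bfloat16, float16 and float32; the `FloatFormat` class is a hypothetical helper, not part of any library.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FloatFormat:
    """Hypothetical helper describing an IEEE-style binary float layout."""
    name: str
    exponent_bits: int
    fraction_bits: int  # explicitly stored significand bits

    @property
    def epsilon(self):
        """Spacing between 1.0 and the next representable value."""
        return 2.0 ** -self.fraction_bits

    @property
    def max_finite(self):
        """Largest finite value, assuming an IEEE-style exponent bias."""
        bias = 2 ** (self.exponent_bits - 1) - 1
        return (2.0 - self.epsilon) * 2.0 ** bias


BFLOAT16 = FloatFormat("bfloat16", exponent_bits=8, fraction_bits=7)
FLOAT16 = FloatFormat("float16", exponent_bits=5, fraction_bits=10)
FLOAT32 = FloatFormat("float32", exponent_bits=8, fraction_bits=23)

for fmt in (BFLOAT16, FLOAT16, FLOAT32):
    print(f"{fmt.name:>8}: epsilon={fmt.epsilon:.3e}  max={fmt.max_finite:.3e}")
```

Both 16-bit formats have an epsilon below 1e-2, so a request like precision=1e-2 cannot distinguish them, even though float16 overflows just above 65504 while bfloat16 keeps the full float32 exponent range.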

Given that we want to support platform specific options, descriptive names seem like the best bet.
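
One way such platform-specific options could be resolved is sketched below. The function name `resolve_precision` and the convention that a missing platform falls back to None (meaning "use the backend's default") are assumptions made for illustration, not a description of JAX behavior.

```python
from collections.abc import Mapping


def resolve_precision(precision, platform):
    """Pick the precision name to use on `platform`.

    `precision` may be None, a single name applied on every platform, or a
    mapping from platform name to precision name (hypothetical convention).
    """
    if precision is None or isinstance(precision, str):
        return precision
    if isinstance(precision, Mapping):
        # Platforms that are not listed fall back to the backend default.
        return precision.get(platform)
    raise TypeError(f"unsupported precision specification: {precision!r}")


spec = {"tpu": "bfloat16", "gpu": "float16"}
for platform in ("tpu", "gpu", "cpu"):
    print(platform, "->", resolve_precision(spec, platform))
```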

The main remaining concern is what to call "3 pass bfloat16" precision on TPUs, which approximates roughly 16 bits of precision for the significand. "intermediate" precision would be OK for TPUs, but seems very vague in general. Maybe bfloat24 or bfloat16_3x would be appropriate? (We could also support bfloat16_6x as a more precise description of float32.)

@shoyer
Collaborator Author

shoyer commented May 16, 2020

“3 pass bfloat16” is coincidentally very close to (slightly higher than?) the precision of Nvidia’s new “tensorfloat32”. So that could also be a good name for this intermediate precision on TPUs.

@hawkinsp
Collaborator

Users have also requested a way to set a more "global" default precision.

One possible mechanism to do this is via a scope, e.g.:

with jax.precision("highest"):
  ...

I would suggest that it should override only operations with default precision.

@shoyer
Collaborator Author

shoyer commented Aug 19, 2020

> I would suggest that it should override only operations with default precision.

I assume you mean only for precision=None, rather than the confusingly named precision=lax.Precision.DEFAULT (aka bfloat16)?

@hawkinsp
Collaborator

Yes, I meant None, not what we are currently calling DEFAULT.
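
The semantics agreed on here (a scope overrides only calls that pass precision=None, never an explicit value) can be sketched with the standard library. The names `precision_scope` and `effective_precision` are hypothetical and chosen for illustration; this is not the implementation that was eventually merged.

```python
import contextlib
import contextvars

# Holds the precision requested by the innermost enclosing scope, if any.
_scoped_precision = contextvars.ContextVar("scoped_precision", default=None)


@contextlib.contextmanager
def precision_scope(name):
    """Set the fallback precision for calls made inside the `with` block."""
    token = _scoped_precision.set(name)
    try:
        yield
    finally:
        _scoped_precision.reset(token)  # restore the enclosing scope's value


def effective_precision(precision=None):
    """An explicit precision always wins; None defers to the enclosing scope."""
    if precision is not None:
        return precision
    return _scoped_precision.get()


with precision_scope("highest"):
    print(effective_precision())           # scope applies
    print(effective_precision("default"))  # explicit value is left alone
print(effective_precision())               # back to no override
```

Using a context variable rather than a module-level global keeps the override local to the current thread or async task, and nested scopes unwind correctly on exit.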

@marksandler2
Contributor

I just spent several days chasing a numerical stability issue until I pinpointed it to the matmul precision. Using "low" precision by default seems like a very questionable design decision outside of optimized pipelines with carefully measured speed/quality trade-offs.

@hawkinsp
Collaborator

hawkinsp commented May 8, 2023

@marksandler2 The issue is that there are at least two communities of people using JAX:

  • machine learning researchers/practitioners, who mostly expect speed over precision
  • scientific users, who expect precision over speed.

I'm not sure there's any default setting that makes both groups happy.

You didn't say if you are using TPU or GPU, but on TPU, at least, one can argue that if you didn't want fast lower-precision matmuls most of the time then using a TPU is an odd choice. It's more complicated on GPU.
