Matrix multiplication precision API #2161
On TPUs, "HIGH" precision corresponds to 3 passes of bfloat16 and "HIGHEST" precision corresponds to 6 passes, which is effectively full float32 (dropping denormals), as explained in this Intel paper. With that in mind, and considering that ideally we would retain some flexibility for alternative precision options, some ideas:
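For intuition about the pass counts above, here is a rough sketch of the arithmetic, emulated in float32 for illustration; it is not the actual XLA/TPU lowering. Each float32 operand is split into bfloat16 pieces, and several bfloat16 products are accumulated in float32:

```python
import jax.numpy as jnp

def split_bf16(x):
    # Split a float32 array as x ~= hi + lo, where hi holds the top ~8
    # significand bits (one bfloat16's worth) and lo holds the next ~8 bits.
    hi = x.astype(jnp.bfloat16).astype(jnp.float32)
    lo = (x - hi).astype(jnp.bfloat16).astype(jnp.float32)
    return hi, lo

def matmul_3pass(x, y):
    # "3 passes of bfloat16": three products accumulated in float32,
    # dropping the lo @ lo term, which sits below ~16 significand bits.
    # The 6-pass variant splits each operand into three pieces and keeps the
    # six most significant cross terms, recovering essentially full float32.
    x_hi, x_lo = split_bf16(x)
    y_hi, y_lo = split_bf16(y)
    return x_hi @ y_hi + x_hi @ y_lo + x_lo @ y_hi
```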
I think I lean towards option 3?
Notes from offline discussion:
Given that we want to support platform-specific options, descriptive names seem like the best bet. The main remaining concern is what to call "3-pass bfloat16" precision on TPUs, which approximates roughly 16 bits of precision for the significand. "Intermediate" precision would be OK for TPUs, but seems very vague in general. Maybe
“3-pass bfloat16” is coincidentally very close to (slightly higher than?) the precision of Nvidia’s new “tensorfloat32”, so that could also be a good name for this intermediate precision on TPUs.
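To make the naming question concrete, here is what descriptive string names might look like at a call site. The exact spellings below are illustrative, not something settled in this thread:

```python
import jax.numpy as jnp

x = jnp.ones((256, 256), jnp.float32)
y = jnp.ones((256, 256), jnp.float32)

# Illustrative descriptive names for the three TPU precision levels discussed
# above; only the intent matters here, not the exact strings.
z_fast = jnp.dot(x, y, precision='bfloat16')       # 1 pass, fastest
z_mid  = jnp.dot(x, y, precision='tensorfloat32')  # ~3-pass bfloat16 level
z_full = jnp.dot(x, y, precision='float32')        # 6 passes, full float32
```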
Users have also requested a way to set a more "global" default precision. One possible mechanism to do this is via a scope, e.g.:
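A minimal sketch of what such a scope could look like; the context-manager name and accepted strings here are illustrative, sketching the proposal rather than documenting an existing API:

```python
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.ones((512, 512), jnp.float32)
y = jnp.ones((512, 512), jnp.float32)

# Everything inside the scope that doesn't request a precision explicitly
# would use the scoped default instead of the global default.
with jax.default_matmul_precision('float32'):   # name/values are illustrative
    a = x @ y            # picks up the scoped default
    b = jnp.dot(x, y)    # likewise

    # Per the discussion below, an explicit precision argument would
    # presumably still take precedence over the scoped default.
    c = jnp.dot(x, y, precision=lax.Precision.DEFAULT)
```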
I would suggest that it should override only operations with default precision.
I assume you mean only for operations that don't set `precision` explicitly?
Yes, I meant only for operations using the default.
I just spent several days chasing a numerical stability issue until I pinpointed it to the matmul precision. Using "low" precision by default seems like a very questionable design decision outside of optimized pipelines with carefully measured speed/quality trade-offs.
@marksandler2 The issue is that there are at least two communities of people using JAX: those who want the fastest possible matmuls and are tolerant of lower precision, and those who expect full float32 results by default.
I'm not sure there's any default setting that makes both groups happy. You didn't say if you are using TPU or GPU, but on TPU, at least, one can argue that if you didn't want fast lower-precision matmuls most of the time then using a TPU is an odd choice. It's more complicated on GPU.
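For what it's worth, the existing enum-based API described in the issue text below already lets a user opt into full precision per call when they hit problems like this; only the default is in question:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.ones((1024, 1024), jnp.float32)
y = jnp.ones((1024, 1024), jnp.float32)

# Opt into full float32 accumulation for this particular matmul.
z = jnp.dot(x, y, precision=lax.Precision.HIGHEST)
```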
Operations that do matrix multiplication in JAX accept a `precision` argument for controlling precision when executed on TPUs. The current API seems non-ideal to me:

- It requires passing an enum value from `lax` (e.g., `np.dot(x, y, precision=lax.Precision.HIGHEST)`). This is a little cumbersome and inconsistent with most NumPy/SciPy APIs, which use strings (e.g., `np.dot(x, y, precision='highest')`).
- There is no way to set the precision when using the `@` infix operator.
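To make the two complaints concrete, a small sketch contrasting the current call with the kind of call this issue is asking for; the string form is the proposal, not necessarily available as written:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.ones((128, 128), jnp.float32)
y = jnp.ones((128, 128), jnp.float32)

# Current API: requires importing and passing the lax.Precision enum.
z_enum = jnp.dot(x, y, precision=lax.Precision.HIGHEST)

# Proposed: also accept plain strings, matching NumPy/SciPy conventions.
# z_str = jnp.dot(x, y, precision='highest')

# The @ infix operator currently offers no way to pass a precision at all,
# so it always runs at the default precision.
z_infix = x @ y
```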