
add jax_default_matmul_precision flag and context manager #6143

Merged: 1 commit, Mar 24, 2021

Conversation

@mattjj (Collaborator) commented Mar 19, 2021

Note the "diffbase": this PR is based on the flag-cleanup branch, i.e. #6112. See the comment on #6112 for more context.

fixes #2161

This PR adds a configuration option jax_default_matmul_precision and a context manager jax.default_matmul_precision(...) to control the default precision of the internal computations used in matrix multiplies and convolutions on float32 inputs, for supported backends (currently just TPUs, but likely A100 GPUs soon as well).

For example, say we have a function foo which includes a jnp.dot call:

def foo(...):
  ...
  z = jnp.dot(x, y)  # x and y are float32 arrays
  ...

We can ensure that dot is computed at the highest (or lowest) precision using any of these methods:

  1. we can set the shell environment variable JAX_DEFAULT_MATMUL_PRECISION=float32 (or JAX_DEFAULT_MATMUL_PRECISION=bfloat16);
  2. if our main script parses flags with absl, then we can use the command-line flag --jax_default_matmul_precision=float32 (or --jax_default_matmul_precision=bfloat16);
  3. at the top of our main script, we can put jax.config.update('jax_default_matmul_precision', 'float32') (or jax.config.update('jax_default_matmul_precision', 'bfloat16'));
  4. when we call foo, we can use the jax.default_matmul_precision context manager (see the combined sketch after this list):
     with jax.default_matmul_precision('float32'):   # or 'bfloat16' for lowest
       ... = foo(...)
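
Putting options 3 and 4 together, here is a minimal end-to-end sketch (the shapes and the body of foo are placeholders for illustration, not part of this PR):

import jax
import jax.numpy as jnp

# Option 3: set the process-wide default near the top of the main script.
jax.config.update('jax_default_matmul_precision', 'float32')

def foo(x, y):
  return jnp.dot(x, y)  # no explicit precision: uses the configured default

x = jnp.ones((128, 128), jnp.float32)
y = jnp.ones((128, 128), jnp.float32)

z_precise = foo(x, y)

# Option 4: locally override the default with the context manager.
with jax.default_matmul_precision('bfloat16'):
  z_fast = foo(x, y)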

This configuration option controls only the default precision: convolution operations like lax.conv_general_dilated and matrix multiply operations like lax.dot take an optional precision argument, and this option does not change the behavior of calls that pass an explicit precision argument; it only affects calls where no such argument is provided.
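
To illustrate that precedence (a sketch reusing x and y from the example above, not code from this PR): an explicit precision argument wins over the configured default.

with jax.default_matmul_precision('bfloat16'):
  z1 = jnp.dot(x, y)                                       # no argument: follows the configured default
  z2 = jnp.dot(x, y, precision=jax.lax.Precision.HIGHEST)  # explicit argument: unaffected by this option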

This PR does not change the default value of the default matmul precision; that remains 'bfloat16'. It only adds new ways to control the default. We might change it in follow-up work.

In follow-up work we may add an analogous bit of enum state for controlling the default device. But I think we should land this part first.

TODO:

  • write PR message
  • add tests

cc @rohan-anil @sharadmv @shoyer @SiegeLordEx @jonbarron

@google-cla google-cla bot added the cla: yes label Mar 19, 2021
@mattjj mattjj force-pushed the precision-flag branch 4 times, most recently from 047be0a to c2ef1fc Compare March 20, 2021 04:09
@mattjj mattjj marked this pull request as ready for review March 20, 2021 04:12
@mattjj mattjj requested a review from hawkinsp March 20, 2021 04:12
Review threads were opened on these files (now resolved):

  • jax/_src/lax/lax.py (outdated)
  • jax/config.py (outdated)
  • tests/api_test.py (outdated)
  • tests/api_test.py
@mattjj mattjj force-pushed the flag-cleanup branch 3 times, most recently from ad63cc3 to f564880 Compare March 23, 2021 04:37
@mattjj mattjj changed the base branch from flag-cleanup to master March 23, 2021 14:59
    Tuple[PrecisionType, PrecisionType]]
    _precision_strings = {
      'bfloat16': Precision.DEFAULT,
      'tensorfloat32': Precision.HIGH,
Collaborator:

I'm wondering whether tensorfloat32 is a good description. On A100, I believe CUDA tensorfloat32 gives you 10 bits of mantissa on the input, and I'm not sure that HIGH in XLA's naming actually gives you tensorfloat32 on GPU.

On TPU it means something different, namely a multipass algorithm on bfloat16 inputs. Is conflating the two wise? I suspect we should just have two different names here.
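
One way to probe what a given Precision setting does on the backend at hand (an illustrative sketch, not code from this PR or thread) is to compare against a HIGHEST-precision reference; the relative error hints at the effective input mantissa width:

import jax
import jax.numpy as jnp
from jax import lax

key0, key1 = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.normal(key0, (1024, 1024), dtype=jnp.float32)
b = jax.random.normal(key1, (1024, 1024), dtype=jnp.float32)

ref = jnp.dot(a, b, precision=lax.Precision.HIGHEST)
for p in (lax.Precision.DEFAULT, lax.Precision.HIGH):
  out = jnp.dot(a, b, precision=p)
  rel_err = jnp.linalg.norm(out - ref) / jnp.linalg.norm(ref)
  print(p, rel_err)  # backend-dependent: roughly zero on CPU, larger under fast TPU defaults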

Collaborator:

I agree that it's probably misleading to use "tensorfloat32" if XLA doesn't actually support that on GPU (but clearly it should, at least in some form). This might be a good time to consult with the XLA GPU team to see how they feel about conflating Precision.HIGH and tensorfloat32.

See #2161 (comment) for notes on possible names. My favorite alternatives were bfloat24 and bfloat16_3x.

Collaborator:

I guess the nice thing about strings instead of enums is that we can support multiple redundant names -- they don't have to be unique, and we can extend it over time.

So perhaps we could support all of bfloat16/bfloat16_1x, bfloat16_3x and float32/bfloat16_6x for now, and extend that list later when XLA adds true support for tensorfloat32 on GPUs, in whatever form that takes.

Collaborator:

I suggest you want a couple of generic options:

  • fastest
  • most precise

and after that point you want to be completely specific about which algorithm you mean.

Collaborator:

OK, so for the current state of things, how about:

  • precision='highest' -> lax.Precision.HIGHEST
  • precision='float32' -> lax.Precision.HIGHEST
  • precision='bfloat16_3x' -> lax.Precision.HIGH
  • precision='bfloat16' -> lax.Precision.DEFAULT
  • precision='fastest' -> lax.Precision.DEFAULT

(We could add the bfloat16_1x and bfloat16_6x aliases as well, if desired)
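
If it helps, that proposal amounts to a string-to-enum table along these lines (a sketch of the suggestion with a hypothetical name, not the code in this PR):

from jax import lax

# Proposed aliases; 'bfloat16_1x' and 'bfloat16_6x' could be added the same way.
_matmul_precision_aliases = {
    'highest': lax.Precision.HIGHEST,
    'float32': lax.Precision.HIGHEST,
    'bfloat16_3x': lax.Precision.HIGH,
    'bfloat16': lax.Precision.DEFAULT,
    'fastest': lax.Precision.DEFAULT,
}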

Collaborator (Author):

cc @jekbradbury

Maybe we should discuss on chat for lower latency?

@shoyer (Collaborator) commented Mar 23, 2021

"matmul_precision" might be a more descriptive name than "dot_precision", as it only effects matrix-matrix multiplication and convolutions, which are arguably a special case of matrix/matrix multiplication.

@mattjj mattjj changed the title add jax_default_dot_precision flag and context manager add jax_default_matmul_precision flag and context manager Mar 24, 2021
@mattjj mattjj force-pushed the precision-flag branch 2 times, most recently from bc56413 to 8dd05f0 Compare March 24, 2021 20:26
@mattjj mattjj added the pull ready Ready for copybara import and testing label Mar 24, 2021
@froystig froystig added pull ready Ready for copybara import and testing and removed pull ready Ready for copybara import and testing labels Mar 24, 2021
@mattjj mattjj added pull ready Ready for copybara import and testing and removed pull ready Ready for copybara import and testing labels Mar 24, 2021
@mattjj (Collaborator, Author) commented Mar 24, 2021

Ping for robots...

@mattjj mattjj added pull ready Ready for copybara import and testing and removed pull ready Ready for copybara import and testing labels Mar 24, 2021
Labels: cla: yes · pull ready (Ready for copybara import and testing)
Merging this pull request may close: Matrix multiplication precision API
4 participants