Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

unify configuration state handling #6112

Merged
merged 1 commit into from
Mar 24, 2021
Merged

unify configuration state handling #6112

merged 1 commit into from
Mar 24, 2021

Conversation

mattjj
Copy link
Collaborator

@mattjj mattjj commented Mar 18, 2021

Reviewer: The main action is in config.py. The other changed files are downstream of those changes.

This change arose from working on adding a jax_tpu_dot_precision flag / context manager (see #6143).

We have a few configurable bits of global state. These bits of configurable state are managed by flags and sometimes dynamically by context managers, but neither the APIs nor implementations are uniform. Moreover the state is spread out through a few different files.

The bits of global state we have in mind are:

  1. core.skip_checks (state in core.py, flag defined in config.py, not properly thread-local, has a context manager, does not affect jit dispatch since it's all about trace-time errors)
  2. core.debug_state.check_leaks (state in core.py, flag defined in config.py, thread-local, has a context manager, does not affect jit dispatch since it's all about trace-time errors)
  3. jax_debug_nans / jax_debug_infs (flag defined in xla.py, not properly thread-local, no context manager, affects jit dispatch in that it adds checks to every execution on the Python side) (affects jit dispatch)
  4. jax_log_compiles (flag defined in xla.py, not properly thread-local, no context manager, does not affect jit dispatch since it's all about trace-time logging)
  5. jax_enable_x64 (work-in-progress, state in jax_jit.cc, context manager being developed in jax.experimental, thread-local, affects jit dispatch in that it's part of the compilation cache key and affects how input arguments are handled in the c++ code)
  6. jax_default_dot_precision (work-in-progress, not present yet, affects jit dispatch analogously to jax_enable_x64)
  7. disable_jit (state in jax_jit.cc, affects jit dispatch in that it's part of the compilation cache key)
  8. jax_numpy_rank_promotion (flag defined in lax_numpy.py, not thread-local and no context manager, does not affect jit dispatch in that it's all about trace-time errors)

This PR unifies all the boolean-valued instances of Python state via a single mechanism in config.py which sets up flags, thread-local state, and context manager APIs. This PR doesn't touch jax_default_dot_precision or jax_numpy_rank_promotion because those are enums rather than booleans; it doesn't touch disable_jit or jax_enable_x64 because those are in C++.

Another effect of this PR is introducing new context managers: jax.enable_checks, jax.check_tracer_leaks, jax.debug_nans, jax.debug_infs, and jax.log_compiles. Each takes a single boolean argument.

Follow-up work might put more of these bits in C++ (i.e. in jax_jit.cc) for fast dispatch, and/or speed up dispatch times for Python state bits. That work should be easier once we collect all the state in one place as in this PR. It's also follow-up work to unify the API with jax_enable_x64 (cc @jakevdp), and to add an enum version of this logic for jax_default_dot_precision, jax_numpy_rank_promotion, and perhaps the default device. After discussing with @jakevdp , we decided to unify with the implementation of jax_enable_x64 in this PR, but leave the API endpoints for the x64 stuff unchanged.

Benchmark results on benchmarks/api_benchmark.py show no real differences AIUI:

name                                old cpu/op  new cpu/op  delta
jit_trivial_dispatch                43.7µs ± 3%  43.3µs ± 2%   ~     (p=0.690 n=5+5)
jit_trivial                         44.9µs ± 2%  44.7µs ± 2%   ~     (p=0.690 n=5+5)
jit_simple_dispatch                 15.3µs ± 1%  15.4µs ± 4%   ~     (p=1.000 n=5+5)
jit_simple                          16.2µs ± 4%  16.1µs ± 3%   ~     (p=1.000 n=5+5)
jit_simple_many_args_dispatch_10    20.1µs ± 2%  20.5µs ± 3%   ~     (p=0.421 n=5+5)
jit_simple_many_args_10             21.7µs ± 2%  22.0µs ± 2%   ~     (p=0.548 n=5+5)
jit_simple_many_args_dispatch_100   65.1µs ± 1%  65.1µs ± 1%   ~     (p=0.841 n=5+5)
jit_simple_many_args_100            66.1µs ± 1%  66.4µs ± 1%   ~     (p=0.222 n=5+5)
jit_simple_many_args_dispatch_1000   535µs ± 2%   538µs ± 2%   ~     (p=0.310 n=5+5)
jit_simple_many_args_1000            549µs ± 3%   555µs ± 2%   ~     (p=0.421 n=5+5)
jit_simple_many_args_dispatch_2000  1.12ms ± 4%  1.11ms ± 2%   ~     (p=0.841 n=5+5)
jit_simple_many_args_2000           1.13ms ± 3%  1.13ms ± 3%   ~     (p=1.000 n=5+5)
jit_dispatch_without_transfer        116µs ± 6%   118µs ± 6%   ~     (p=0.841 n=5+5)
jit_dispatch_with_transfer           126µs ± 7%   124µs ±10%   ~     (p=0.841 n=5+5)

This PR doesn't currently include tests, though the code is pretty thoroughly exercised by existing test coverage for skip_checks,check_leaks, debug_nans, disable_jit, etc.

@google-cla google-cla bot added the cla: yes label Mar 18, 2021
@mattjj mattjj added pull ready Ready for copybara import and testing and removed cla: yes labels Mar 18, 2021
@google-cla google-cla bot added the cla: yes label Mar 18, 2021
@mattjj mattjj force-pushed the flag-cleanup branch 2 times, most recently from 6bf189c to 6930015 Compare March 18, 2021 01:31
@mattjj mattjj marked this pull request as ready for review March 18, 2021 01:32
@mattjj mattjj requested a review from hawkinsp March 18, 2021 01:32
jax/config.py Outdated Show resolved Hide resolved
Copy link
Collaborator

@hawkinsp hawkinsp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice!

@mattjj mattjj force-pushed the flag-cleanup branch 7 times, most recently from 2a85f85 to fa2766f Compare March 24, 2021 01:16
@copybara-service copybara-service bot merged commit d148a57 into master Mar 24, 2021
@mattjj mattjj deleted the flag-cleanup branch March 24, 2021 02:11
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cla: yes pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants