Skip to content

Commit

Permalink
unify configuration state handling
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed Mar 24, 2021
1 parent 22a2be3 commit fd7b286
Show file tree
Hide file tree
Showing 30 changed files with 267 additions and 195 deletions.
3 changes: 2 additions & 1 deletion jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@
del _cloud_tpu_init

# flake8: noqa: F401
from .config import config
from .config import (config, enable_checks, check_tracer_leaks, checking_leaks,
debug_nans, debug_infs, log_compiles)
from .api import (
ad, # TODO(phawkins): update users to avoid this.
argnums_partial, # TODO(phawkins): update Haiku to not use this.
Expand Down
4 changes: 2 additions & 2 deletions jax/_src/lax/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1113,7 +1113,7 @@ def _cond_typecheck(*avals, branches, linear):
f'called with operands of type {_avals_short(op_avals)}')

def cond_bind(*args, branches, linear):
if not core.skip_checks:
if config.jax_enable_checks:
avals = _map(core.get_aval, args)
_cond_typecheck(*avals, branches=branches, linear=linear)
for jaxpr in branches:
Expand Down Expand Up @@ -1876,7 +1876,7 @@ def _scan_typecheck(bind_time, *avals, reverse, length, num_consts, num_carry,
f'called with sequence of type\n{_avals_short(x_avals)}')

def scan_bind(*args, **params):
if not core.skip_checks:
if config.jax_enable_checks:
avals = _map(core.get_aval, args)
_scan_typecheck(True, *avals, **params)
core.check_jaxpr(params['jaxpr'].jaxpr)
Expand Down
3 changes: 1 addition & 2 deletions jax/_src/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

import numpy as np

import jax
from jax.config import config

partial = functools.partial
Expand Down Expand Up @@ -192,7 +191,7 @@ def cached(_, *args, **kwargs):

@functools.wraps(f)
def wrapper(*args, **kwargs):
if jax.core.debug_state.check_leaks:
if config.jax_check_tracer_leaks:
return f(*args, **kwargs)
else:
return cached(bool(config.x64_enabled), *args, **kwargs)
Expand Down
10 changes: 5 additions & 5 deletions jax/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from . import linear_util as lu
from . import ad_util
from . import dtypes
from .core import eval_jaxpr, checking_leaks
from .core import eval_jaxpr
from .api_util import (flatten_fun, apply_flat_fun, flatten_fun_nokwargs,
flatten_fun_nokwargs2, argnums_partial,
argnums_partial_except, flatten_axes, donation_vector,
Expand Down Expand Up @@ -362,7 +362,7 @@ def f_jitted(*args, **kwargs):
context = (getattr(core.thread_local_state.trace_state.trace_stack,
"dynamic", None), config.x64_enabled)
# TODO(jblespiau): Move this to C++.
if (FLAGS.jax_debug_nans or FLAGS.jax_debug_infs) and not _jit_is_disabled():
if (config.jax_debug_nans or config.jax_debug_infs) and not _jit_is_disabled():
device_arrays = cpp_jitted_f(context, *args, **kwargs)
try:
xla.check_special(xla.xla_call_p, [
Expand All @@ -372,7 +372,7 @@ def f_jitted(*args, **kwargs):
])
return device_arrays
except FloatingPointError:
assert FLAGS.jax_debug_nans or FLAGS.jax_debug_infs # compiled_fun can only raise in this case
assert config.jax_debug_nans or config.jax_debug_infs # compiled_fun can only raise in this case
print("Invalid nan value encountered in the output of a C++-jit "
"function. Calling the de-optimized version.")
return cache_miss(*args, **kwargs)[0] # probably won't return
Expand All @@ -389,7 +389,7 @@ def f_jitted(*args, **kwargs):
@api_boundary
def f_jitted(*args, **kwargs):
# TODO(jblespiau): Move this to C++.
if (FLAGS.jax_debug_nans or FLAGS.jax_debug_infs) and not _jit_is_disabled():
if (config.jax_debug_nans or config.jax_debug_infs) and not _jit_is_disabled():
device_arrays = cpp_jitted_f(*args, **kwargs)
try:
xla.check_special(xla.xla_call_p, [
Expand All @@ -399,7 +399,7 @@ def f_jitted(*args, **kwargs):
])
return device_arrays
except FloatingPointError:
assert FLAGS.jax_debug_nans or FLAGS.jax_debug_infs # compiled_fun can only raise in this case
assert config.jax_debug_nans or config.jax_debug_infs # compiled_fun can only raise in this case
print("Invalid nan value encountered in the output of a C++-jit "
"function. Calling the de-optimized version.")
return cache_miss(*args, **kwargs)[0] # probably won't return
Expand Down
182 changes: 161 additions & 21 deletions jax/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import functools
import os
import sys
import threading

from jax import lib

def bool_env(varname: str, default: bool) -> bool:
Expand Down Expand Up @@ -42,11 +46,16 @@ def int_env(varname: str, default: int) -> int:


class Config:
_HAS_DYNAMIC_ATTRIBUTES = True

def __init__(self):
self.values = {}
self.meta = {}
self.FLAGS = NameSpace(self.read)
self.use_absl = False
self._contextmanager_flags = set()

# TODO(mattjj): delete these when only omnistaging is available
self.omnistaging_enabled = bool_env('JAX_OMNISTAGING', True)
self._omnistaging_disablers = []

Expand All @@ -65,6 +74,13 @@ def update(self, name, val):
lib.jax_jit.global_state().enable_x64 = val

def read(self, name):
if name in self._contextmanager_flags:
raise AttributeError(
"For flags with a corresponding contextmanager, read their value "
f"via e.g. `config.{name}` rather than `config.FLAGS.{name}`.")
return self._read(name)

def _read(self, name):
if self.use_absl:
return getattr(self.absl_flags.FLAGS, name)
else:
Expand Down Expand Up @@ -143,14 +159,82 @@ def disable_omnistaging(self):
disabler()
self.omnistaging_enabled = False

@property
def x64_enabled(self):
return lib.jax_jit.get_enable_x64()

# TODO(jakevdp): make this public when thread-local x64 is fully implemented.
def _set_x64_enabled(self, state):
lib.jax_jit.thread_local_state().enable_x64 = bool(state)

# # TODO(jakevdp, mattjj): unify this with `define_bool_state` stuff below
# @property
# def x64_enabled(self):
# return lib.jax_jit.get_enable_x64()

# def _set_x64_enabled(self, state):
# lib.jax_jit.thread_local_state().enable_x64 = bool(state)

def define_bool_state(self, name: str, default: bool, help: str):
"""Set up thread-local state and return a contextmanager for managing it.
This function is a convenience wrapper. It defines a flag and corresponding
thread-local state, which can be managed via the contextmanager it returns.
The thread-local state value can be read via the ``config.<option_name>``
attribute, where ``config`` is the singleton ``Config`` instance.
Args:
name: string, converted to lowercase to define the name of the config
option (and absl flag). It is converted to uppercase to define the
corresponding shell environment variable.
default: boolean, a default value for the option.
help: string, used to populate the flag help information as well as the
docstring of the returned context manager.
Returns:
A contextmanager to control the thread-local state value.
Example:
enable_foo = config.define_bool_state(
name='jax_enable_foo',
default=False,
help='Enable foo.')
# Now the JAX_ENABLE_FOO shell environment variable and --jax_enable_foo
# command-line flag can be used to control the process-level value of
# the configuration option, in addition to using e.g.
# ``config.update("jax_enable_foo", True)`` directly. We can also use a
# context manager:
with enable_foo(True):
...
The value of the thread-local state or flag can be accessed via
``config.jax_enable_foo``. Reading it via ``config.FLAGS.jax_enable_foo`` is
an error.
"""
name = name.lower()
self.DEFINE_bool(name, bool_env(name.upper(), default), help)
self._contextmanager_flags.add(name)

def get_state(self):
val = getattr(_thread_local_state, name, unset)
return val if val is not unset else self._read(name)
setattr(Config, name, property(get_state))

@contextlib.contextmanager
def set_state(new_val: bool):
prev_val = getattr(_thread_local_state, name, unset)
setattr(_thread_local_state, name, new_val)
try:
yield
finally:
if prev_val is unset:
delattr(_thread_local_state, name)
else:
setattr(_thread_local_state, name, prev_val)
set_state.__name__ = name[4:] if name.startswith('jax_') else name
set_state.__doc__ = f"Context manager for `{name}` config option.\n\n{help}"
return set_state

_thread_local_state = threading.local()

class Unset: pass
unset = Unset()

class NameSpace(object):
def __init__(self, getter):
Expand All @@ -166,11 +250,6 @@ def __getattr__(self, name):

already_configured_with_absl = False

flags.DEFINE_bool(
'jax_enable_checks',
bool_env('JAX_ENABLE_CHECKS', False),
help='Turn on invariant checking (core.skip_checks = False)'
)

flags.DEFINE_bool(
'jax_omnistaging',
Expand All @@ -184,14 +263,6 @@ def __getattr__(self, name):
help='Set the number of stack frames in JAX tracer error messages.'
)

flags.DEFINE_bool(
'jax_check_tracer_leaks',
bool_env('JAX_CHECK_TRACER_LEAKS', False),
help=('Turn on checking for leaked tracers as soon as a trace completes. '
'Enabling leak checking may have performance impacts: some caching '
'is disabled, and other overheads may be added.'),
)

flags.DEFINE_bool(
'jax_host_callback_inline',
bool_env('JAX_HOST_CALLBACK_INLINE', False),
Expand All @@ -206,3 +277,72 @@ def __getattr__(self, name):
'until the Python callback consume more outfeeds.'),
lower_bound=int(16 * 1e6)
)


enable_checks = config.define_bool_state(
name='jax_enable_checks',
default=False,
help='Turn on invariant checking for JAX internals. Makes things slower.')

check_tracer_leaks = config.define_bool_state(
name='jax_check_tracer_leaks',
default=False,
help=('Turn on checking for leaked tracers as soon as a trace completes. '
'Enabling leak checking may have performance impacts: some caching '
'is disabled, and other overheads may be added.'))
checking_leaks = functools.partial(check_tracer_leaks, True)

debug_nans = config.define_bool_state(
name='jax_debug_nans',
default=False,
help=('Add nan checks to every operation. When a nan is detected on the '
'output of a jit-compiled computation, call into the un-compiled '
'version in an attempt to more precisely identify the operation '
'which produced the nan.'))

debug_infs = config.define_bool_state(
name='jax_debug_infs',
default=False,
help=('Add inf checks to every operation. When an inf is detected on the '
'output of a jit-compiled computation, call into the un-compiled '
'version in an attempt to more precisely identify the operation '
'which produced the inf.'))

log_compiles = config.define_bool_state(
name='jax_log_compiles',
default=False,
help=('Log a message each time every time `jit` or `pmap` compiles an XLA '
'computation. Logging is performed with `absl.logging`. When this '
'option is set, the log level is WARNING; otherwise the level is '
'DEBUG.'))

# Because jax_enable_x64 is managed by C++ code, we don't reuse the
# config.define_bool_state mechanism, though conceptually it is the same.
config.DEFINE_bool('jax_enable_x64', bool_env('JAX_ENABLE_X64', False),
help='Enable 64-bit types to be used')
lib.jax_jit.global_state().enable_x64 = bool_env('JAX_ENABLE_X64', False)

@contextlib.contextmanager
def enable_x64(new_val: bool = True):
"""Experimental context manager to temporarily enable X64 mode.
Usage::
>>> import jax.numpy as jnp
>>> with enable_x64(True):
... print(jnp.arange(10.0).dtype)
...
float64
"""
prev_val = config.jax_enable_x64
lib.jax_jit.thread_local_state().enable_x64 = bool(new_val)
try:
yield
finally:
lib.jax_jit.thread_local_state().enable_x64 = prev_val
Config.jax_enable_x64 = property(lambda self: lib.jax_jit.get_enable_x64())
# config._contextmanager_flags.add('jax_enable_x64') # TODO(mattjj): remove footgun

# The `x64_enabled` property doesn't fit the naming scheme, but we use it for
# backward compatibility.
Config.x64_enabled = Config.jax_enable_x64
Loading

0 comments on commit fd7b286

Please sign in to comment.