Skip to content

Commit

Permalink
unify api w/ jax_enable_x64 flag
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed Mar 23, 2021
1 parent 6b60bf5 commit f564880
Show file tree
Hide file tree
Showing 7 changed files with 47 additions and 47 deletions.
40 changes: 33 additions & 7 deletions jax/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,13 +159,13 @@ def disable_omnistaging(self):
disabler()
self.omnistaging_enabled = False

# TODO(jakevdp, mattjj): unify this with `define_bool_state` stuff below
@property
def x64_enabled(self):
return lib.jax_jit.get_enable_x64()
# # TODO(jakevdp, mattjj): unify this with `define_bool_state` stuff below
# @property
# def x64_enabled(self):
# return lib.jax_jit.get_enable_x64()

def _set_x64_enabled(self, state):
lib.jax_jit.thread_local_state().enable_x64 = bool(state)
# def _set_x64_enabled(self, state):
# lib.jax_jit.thread_local_state().enable_x64 = bool(state)

def define_bool_state(self, name: str, default: bool, help: str):
"""Set up thread-local state and return a contextmanager for managing it.
Expand Down Expand Up @@ -289,7 +289,6 @@ def __getattr__(self, name):
default=False,
help=('Turn on checking for leaked tracers as soon as a trace completes. '
'Enabling leak checking may have performance impacts: some caching '

'is disabled, and other overheads may be added.'))
checking_leaks = functools.partial(check_tracer_leaks, True)

Expand All @@ -316,3 +315,30 @@ def __getattr__(self, name):
'computation. Logging is performed with `absl.logging`. When this '
'option is set, the log level is WARNING; otherwise the level is '
'DEBUG.'))

# Because jax_enable_x64 is managed by C++ code, we don't reuse the
# config.define_bool_state mechanism, though conceptually it is the same.
config.DEFINE_bool('jax_enable_x64', bool_env('JAX_ENABLE_X64', False),
help='Enable 64-bit types to be used')
lib.jax_jit.global_state().enable_x64 = bool_env('JAX_ENABLE_X64', False)

@contextlib.contextmanager
def enable_x64(new_val: bool = True):
"""Experimental context manager to temporarily enable X64 mode.
Usage::
>>> import jax.numpy as jnp
>>> with enable_x64(True):
... print(jnp.arange(10.0).dtype)
...
float64
"""
prev_val = config.jax_enable_x64
lib.jax_jit.thread_local_state().enable_x64 = bool(new_val)
try:
yield
finally:
lib.jax_jit.thread_local_state().enable_x64 = prev_val
Config.jax_enable_x64 = property(lambda self: lib.jax_jit.get_enable_x64())
config._contextmanager_flags.add('jax_enable_x64')
8 changes: 0 additions & 8 deletions jax/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,27 +20,19 @@
# so we need our own implementation that deviates from NumPy in places.


from distutils.util import strtobool
import functools
import os
from typing import Dict

import numpy as np

from ._src import util
from .config import flags, config
from . import lib
from .lib import xla_client

from ._src import traceback_util
traceback_util.register_exclusion(__file__)

FLAGS = flags.FLAGS
flags.DEFINE_bool('jax_enable_x64',
strtobool(os.getenv('JAX_ENABLE_X64', 'False')),
'Enable 64-bit types to be used.')
lib.jax_jit.global_state().enable_x64 = strtobool(
os.getenv('JAX_ENABLE_X64', 'False'))

# bfloat16 support
bfloat16: type = xla_client.bfloat16
Expand Down
2 changes: 1 addition & 1 deletion jax/experimental/jax2tf/tests/jax2tf_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def test_bfloat16_returned_by_jax(self):
dtype=dtype)
for dtype in [np.int64, np.float64]))
def test_converts_64bit(self, dtype=np.int64, with_function=False):
if not config.FLAGS.jax_enable_x64:
if not config.jax_enable_x64:
self.skipTest("requires x64 mode")
big_const = np.full((5,), 2 ** 33, dtype=dtype)
self.ConvertAndCompare(jnp.sin, big_const)
Expand Down
38 changes: 10 additions & 28 deletions jax/experimental/x64_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,31 +17,17 @@
**Experimental: please give feedback, and expect changes.**
"""

from contextlib import contextmanager
from jax import config

@contextmanager
def enable_x64():
"""Experimental context manager to temporarily enable X64 mode.
# This file provides
# 1. a jax.experimental API endpoint
# 2. a monkey-patch to jax.config.Config for the `x64_enabled` property
# 3. the `disable_x64` wrapper
# TODO(jakevdp): remove this file, and consider removing `config.x64_enabled`
# and `disable_x64` for uniformity

Usage::
>>> import jax.numpy as jnp
>>> with enable_x64():
... print(jnp.arange(10.0).dtype)
...
float64
from contextlib import contextmanager
from jax.config import Config, enable_x64

See Also
--------
jax.experimental.disable_x64 : temporarily disable X64 mode.
"""
_x64_state = config.x64_enabled
config._set_x64_enabled(True)
try:
yield
finally:
config._set_x64_enabled(_x64_state)
Config.x64_enabled = Config.jax_enable_x64

@contextmanager
def disable_x64():
Expand All @@ -59,9 +45,5 @@ def disable_x64():
--------
jax.experimental.enable_x64 : temporarily enable X64 mode.
"""
_x64_state = config.x64_enabled
config._set_x64_enabled(False)
try:
with enable_x64(False):
yield
finally:
config._set_x64_enabled(_x64_state)
2 changes: 1 addition & 1 deletion jax/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ def skip_on_flag(flag_name, skip_value):
def skip(test_method): # pylint: disable=missing-docstring
@functools.wraps(test_method)
def test_method_wrapper(self, *args, **kwargs):
flag_value = getattr(FLAGS, flag_name)
flag_value = config._read(flag_name)
if flag_value == skip_value:
test_name = getattr(test_method, '__name__', '[unknown test]')
raise unittest.SkipTest(
Expand Down
2 changes: 1 addition & 1 deletion tests/lax_scipy_sparse_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def tree_unflatten(cls, aux_data, children):
))
def test_bicgstab_against_scipy(
self, shape, dtype, preconditioner):
if not config.FLAGS.jax_enable_x64:
if not config.jax_enable_x64:
raise unittest.SkipTest("requires x64 mode")

rng = jtu.rand_default(self.rng())
Expand Down
2 changes: 1 addition & 1 deletion tests/x64_context_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def test_correctly_capture_default(self, jit, enable_or_disable):
func = _maybe_jit(jit, lambda: jnp.arange(10.0))
func()

expected_dtype = "float64" if config.read("jax_enable_x64") else "float32"
expected_dtype = "float64" if config._read("jax_enable_x64") else "float32"
self.assertEqual(func().dtype, expected_dtype)

with enable_x64():
Expand Down

0 comments on commit f564880

Please sign in to comment.