From f56488031fe7d36445794889eb695d53d60e1607 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Mon, 22 Mar 2021 15:48:42 -0700 Subject: [PATCH] unify api w/ jax_enable_x64 flag --- jax/config.py | 40 ++++++++++++++++---- jax/dtypes.py | 8 ---- jax/experimental/jax2tf/tests/jax2tf_test.py | 2 +- jax/experimental/x64_context.py | 38 +++++-------------- jax/test_util.py | 2 +- tests/lax_scipy_sparse_test.py | 2 +- tests/x64_context_test.py | 2 +- 7 files changed, 47 insertions(+), 47 deletions(-) diff --git a/jax/config.py b/jax/config.py index 477794ab14e2..45f69c15f2b4 100644 --- a/jax/config.py +++ b/jax/config.py @@ -159,13 +159,13 @@ def disable_omnistaging(self): disabler() self.omnistaging_enabled = False - # TODO(jakevdp, mattjj): unify this with `define_bool_state` stuff below - @property - def x64_enabled(self): - return lib.jax_jit.get_enable_x64() +# # TODO(jakevdp, mattjj): unify this with `define_bool_state` stuff below +# @property +# def x64_enabled(self): +# return lib.jax_jit.get_enable_x64() - def _set_x64_enabled(self, state): - lib.jax_jit.thread_local_state().enable_x64 = bool(state) +# def _set_x64_enabled(self, state): +# lib.jax_jit.thread_local_state().enable_x64 = bool(state) def define_bool_state(self, name: str, default: bool, help: str): """Set up thread-local state and return a contextmanager for managing it. @@ -289,7 +289,6 @@ def __getattr__(self, name): default=False, help=('Turn on checking for leaked tracers as soon as a trace completes. ' 'Enabling leak checking may have performance impacts: some caching ' - 'is disabled, and other overheads may be added.')) checking_leaks = functools.partial(check_tracer_leaks, True) @@ -316,3 +315,30 @@ def __getattr__(self, name): 'computation. Logging is performed with `absl.logging`. When this ' 'option is set, the log level is WARNING; otherwise the level is ' 'DEBUG.')) + +# Because jax_enable_x64 is managed by C++ code, we don't reuse the +# config.define_bool_state mechanism, though conceptually it is the same. +config.DEFINE_bool('jax_enable_x64', bool_env('JAX_ENABLE_X64', False), + help='Enable 64-bit types to be used') +lib.jax_jit.global_state().enable_x64 = bool_env('JAX_ENABLE_X64', False) + +@contextlib.contextmanager +def enable_x64(new_val: bool = True): + """Experimental context manager to temporarily enable X64 mode. + + Usage:: + + >>> import jax.numpy as jnp + >>> with enable_x64(True): + ... print(jnp.arange(10.0).dtype) + ... + float64 + """ + prev_val = config.jax_enable_x64 + lib.jax_jit.thread_local_state().enable_x64 = bool(new_val) + try: + yield + finally: + lib.jax_jit.thread_local_state().enable_x64 = prev_val +Config.jax_enable_x64 = property(lambda self: lib.jax_jit.get_enable_x64()) +config._contextmanager_flags.add('jax_enable_x64') diff --git a/jax/dtypes.py b/jax/dtypes.py index dc1397d2d6f5..2987b35e55e6 100644 --- a/jax/dtypes.py +++ b/jax/dtypes.py @@ -20,27 +20,19 @@ # so we need our own implementation that deviates from NumPy in places. -from distutils.util import strtobool import functools -import os from typing import Dict import numpy as np from ._src import util from .config import flags, config -from . import lib from .lib import xla_client from ._src import traceback_util traceback_util.register_exclusion(__file__) FLAGS = flags.FLAGS -flags.DEFINE_bool('jax_enable_x64', - strtobool(os.getenv('JAX_ENABLE_X64', 'False')), - 'Enable 64-bit types to be used.') -lib.jax_jit.global_state().enable_x64 = strtobool( - os.getenv('JAX_ENABLE_X64', 'False')) # bfloat16 support bfloat16: type = xla_client.bfloat16 diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py index af2d76fbdc54..080119d89d19 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_test.py +++ b/jax/experimental/jax2tf/tests/jax2tf_test.py @@ -135,7 +135,7 @@ def test_bfloat16_returned_by_jax(self): dtype=dtype) for dtype in [np.int64, np.float64])) def test_converts_64bit(self, dtype=np.int64, with_function=False): - if not config.FLAGS.jax_enable_x64: + if not config.jax_enable_x64: self.skipTest("requires x64 mode") big_const = np.full((5,), 2 ** 33, dtype=dtype) self.ConvertAndCompare(jnp.sin, big_const) diff --git a/jax/experimental/x64_context.py b/jax/experimental/x64_context.py index 089f3a3690dc..64af6b5ac034 100644 --- a/jax/experimental/x64_context.py +++ b/jax/experimental/x64_context.py @@ -17,31 +17,17 @@ **Experimental: please give feedback, and expect changes.** """ -from contextlib import contextmanager -from jax import config - -@contextmanager -def enable_x64(): - """Experimental context manager to temporarily enable X64 mode. +# This file provides +# 1. a jax.experimental API endpoint +# 2. a monkey-patch to jax.config.Config for the `x64_enabled` property +# 3. the `disable_x64` wrapper +# TODO(jakevdp): remove this file, and consider removing `config.x64_enabled` +# and `disable_x64` for uniformity - Usage:: - - >>> import jax.numpy as jnp - >>> with enable_x64(): - ... print(jnp.arange(10.0).dtype) - ... - float64 +from contextlib import contextmanager +from jax.config import Config, enable_x64 - See Also - -------- - jax.experimental.disable_x64 : temporarily disable X64 mode. - """ - _x64_state = config.x64_enabled - config._set_x64_enabled(True) - try: - yield - finally: - config._set_x64_enabled(_x64_state) +Config.x64_enabled = Config.jax_enable_x64 @contextmanager def disable_x64(): @@ -59,9 +45,5 @@ def disable_x64(): -------- jax.experimental.enable_x64 : temporarily enable X64 mode. """ - _x64_state = config.x64_enabled - config._set_x64_enabled(False) - try: + with enable_x64(False): yield - finally: - config._set_x64_enabled(_x64_state) diff --git a/jax/test_util.py b/jax/test_util.py index bd19b0bc6975..732435323e41 100644 --- a/jax/test_util.py +++ b/jax/test_util.py @@ -445,7 +445,7 @@ def skip_on_flag(flag_name, skip_value): def skip(test_method): # pylint: disable=missing-docstring @functools.wraps(test_method) def test_method_wrapper(self, *args, **kwargs): - flag_value = getattr(FLAGS, flag_name) + flag_value = config._read(flag_name) if flag_value == skip_value: test_name = getattr(test_method, '__name__', '[unknown test]') raise unittest.SkipTest( diff --git a/tests/lax_scipy_sparse_test.py b/tests/lax_scipy_sparse_test.py index 18c8fd960667..0216be48f6d4 100644 --- a/tests/lax_scipy_sparse_test.py +++ b/tests/lax_scipy_sparse_test.py @@ -208,7 +208,7 @@ def tree_unflatten(cls, aux_data, children): )) def test_bicgstab_against_scipy( self, shape, dtype, preconditioner): - if not config.FLAGS.jax_enable_x64: + if not config.jax_enable_x64: raise unittest.SkipTest("requires x64 mode") rng = jtu.rand_default(self.rng()) diff --git a/tests/x64_context_test.py b/tests/x64_context_test.py index 84b37013929d..9e809d9a805b 100644 --- a/tests/x64_context_test.py +++ b/tests/x64_context_test.py @@ -73,7 +73,7 @@ def test_correctly_capture_default(self, jit, enable_or_disable): func = _maybe_jit(jit, lambda: jnp.arange(10.0)) func() - expected_dtype = "float64" if config.read("jax_enable_x64") else "float32" + expected_dtype = "float64" if config._read("jax_enable_x64") else "float32" self.assertEqual(func().dtype, expected_dtype) with enable_x64():