add jax_default_matmul_precision flag & context mngr
mattjj committed Mar 24, 2021
1 parent 7b4c2e3 commit 8dd05f0
Showing 6 changed files with 224 additions and 51 deletions.
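Before the file-by-file diff, a minimal usage sketch of the API this commit adds (illustrative only; array shapes are arbitrary). The context manager sets a thread-local default for matmul/conv precision, and an explicit per-operation precision still takes priority:

import jax
import jax.numpy as jnp
from jax import lax

x = jnp.ones((1024, 1024), dtype=jnp.float32)

# Ops that don't specify a precision pick up the thread-local default:
with jax.default_matmul_precision('tensorfloat32'):
  y = jnp.dot(x, x)  # lowered with precision=HIGH

# An explicit per-operation precision wins over the default:
with jax.default_matmul_precision('tensorfloat32'):
  z = jnp.dot(x, x, precision=lax.Precision.HIGHEST)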
3 changes: 2 additions & 1 deletion jax/__init__.py
@@ -31,7 +31,8 @@

 # flake8: noqa: F401
 from .config import (config, enable_checks, check_tracer_leaks, checking_leaks,
-                     debug_nans, debug_infs, log_compiles)
+                     debug_nans, debug_infs, log_compiles,
+                     default_matmul_precision, numpy_rank_promotion)
 from .api import (
   ad,  # TODO(phawkins): update users to avoid this.
   argnums_partial,  # TODO(phawkins): update Haiku to not use this.
84 changes: 55 additions & 29 deletions jax/_src/lax/lax.py
@@ -506,8 +506,17 @@ def concatenate(operands: Sequence[Array], dimension: int) -> Array:
 Precision = xla_client.PrecisionConfig.Precision
 Precision.__str__ = lambda precision: precision.name
 PrecisionType = Any
-PrecisionLike = Union[None, PrecisionType, Tuple[PrecisionType, PrecisionType]]
-
+PrecisionLike = Union[None, str, PrecisionType, Tuple[str, str],
+                      Tuple[PrecisionType, PrecisionType]]
+_precision_strings = {
+  'highest': Precision.HIGHEST,
+  'float32': Precision.HIGHEST,
+  'bfloat16_3x': Precision.HIGH,
+  'tensorfloat32': Precision.HIGH,
+  'bfloat16': Precision.DEFAULT,
+  'fastest': Precision.DEFAULT,
+  None: Precision.DEFAULT,
+}
 
 class ConvDimensionNumbers(NamedTuple):
   """Describes batch, spatial, and feature dimensions of a convolution.
@@ -555,42 +564,44 @@ def conv_general_dilated(
     rhs_dilation: `None`, or a sequence of `n` integers, giving the
       dilation factor to apply in each spatial dimension of `rhs`. RHS dilation
       is also known as atrous convolution.
-    dimension_numbers: either `None`, a `ConvDimensionNumbers` object, or
-      a 3-tuple `(lhs_spec, rhs_spec, out_spec)`, where each element is a string
-      of length `n+2`.
+    dimension_numbers: either `None`, a ``ConvDimensionNumbers`` object, or
+      a 3-tuple ``(lhs_spec, rhs_spec, out_spec)``, where each element is a
+      string of length `n+2`.
     feature_group_count: integer, default 1. See XLA HLO docs.
     batch_group_count: integer, default 1. See XLA HLO docs.
     precision: Optional. Either ``None``, which means the default precision for
       the backend, a ``lax.Precision`` enum value (``Precision.DEFAULT``,
-      ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
-      ``lax.Precision`` enums indicating precision of ``lhs`` and ``rhs``.
+      ``Precision.HIGH`` or ``Precision.HIGHEST``), a string (e.g. 'highest' or
+      'fastest', see the ``jax.default_matmul_precision`` context manager), or a
+      tuple of two ``lax.Precision`` enums or strings indicating precision of
+      ``lhs`` and ``rhs``.
 
   Returns:
     An array containing the convolution result.
 
-  In the string case of `dimension_numbers`, each character identifies by
+  In the string case of ``dimension_numbers``, each character identifies by
   position:
 
-  - the batch dimensions in `lhs`, `rhs`, and the output with the character
+  - the batch dimensions in ``lhs``, ``rhs``, and the output with the character
     'N',
   - the feature dimensions in `lhs` and the output with the character 'C',
   - the input and output feature dimensions in rhs with the characters 'I'
     and 'O' respectively, and
   - spatial dimension correspondences between lhs, rhs, and the output using
     any distinct characters.
 
-  For example, to indicate dimension numbers consistent with the `conv` function
-  with two spatial dimensions, one could use `('NCHW', 'OIHW', 'NCHW')`. As
-  another example, to indicate dimension numbers consistent with the TensorFlow
-  Conv2D operation, one could use `('NHWC', 'HWIO', 'NHWC')`. When using the
-  latter form of convolution dimension specification, window strides are
-  associated with spatial dimension character labels according to the order in
-  which the labels appear in the `rhs_spec` string, so that `window_strides[0]`
-  is matched with the dimension corresponding to the first character
-  appearing in rhs_spec that is not `'I'` or `'O'`.
-  If `dimension_numbers` is `None`, the default is `('NCHW', 'OIHW', 'NCHW')`
-  (for a 2D convolution).
+  For example, to indicate dimension numbers consistent with the ``conv``
+  function with two spatial dimensions, one could use ``('NCHW', 'OIHW',
+  'NCHW')``. As another example, to indicate dimension numbers consistent with
+  the TensorFlow Conv2D operation, one could use ``('NHWC', 'HWIO', 'NHWC')``.
+  When using the latter form of convolution dimension specification, window
+  strides are associated with spatial dimension character labels according to
+  the order in which the labels appear in the ``rhs_spec`` string, so that
+  ``window_strides[0]`` is matched with the dimension corresponding to the first
+  character appearing in rhs_spec that is not ``'I'`` or ``'O'``.
+  If ``dimension_numbers`` is ``None``, the default is ``('NCHW', 'OIHW',
+  'NCHW')`` (for a 2D convolution).
   """
   dnums = conv_dimension_numbers(lhs.shape, rhs.shape, dimension_numbers)
   if lhs_dilation is None:
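As an aside between hunks: to make the docstring concrete, a small sketch of a call using the TensorFlow Conv2D convention together with the new string precision (hypothetical shapes):

import jax.numpy as jnp
from jax import lax

lhs = jnp.ones((8, 32, 32, 3))   # NHWC: batch, height, width, features
rhs = jnp.ones((3, 3, 3, 16))    # HWIO: height, width, in-features, out-features
out = lax.conv_general_dilated(
    lhs, rhs,
    window_strides=(1, 1),
    padding='SAME',
    dimension_numbers=('NHWC', 'HWIO', 'NHWC'),
    precision='highest')         # or lax.Precision.HIGHEST, or a 2-tuple
print(out.shape)                 # (8, 32, 32, 16)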
@@ -6394,16 +6405,31 @@ def remaining(original, *removed_lists):

 def _canonicalize_precision(precision):
   if precision is None:
-    return None
-  if isinstance(precision, Precision) or (
-      isinstance(precision, tuple)
-      and len(precision) == 2
-      and all(isinstance(p, Precision) for p in precision)
-  ):
+    if config.jax_default_matmul_precision is None:
+      return None
+    try:
+      return _precision_strings[config.jax_default_matmul_precision]
+    except KeyError:
+      raise ValueError(
+          "jax_default_matmul_precision flag must be set to None or a value in "
+          f"{_precision_strings}, but got {config.jax_default_matmul_precision}"
+      ) from None
+  elif isinstance(precision, str) and precision in _precision_strings:
+    return _precision_strings.get(precision)
+  elif isinstance(precision, Precision):
     return precision
+  elif (isinstance(precision, (list, tuple)) and len(precision) == 2 and
+        all(isinstance(p, Precision) for p in precision)):
+    return precision
+  elif (isinstance(precision, (list, tuple)) and len(precision) == 2 and
+        all(isinstance(s, str) for s in precision)):
+    s1, s2 = precision
+    return (_canonicalize_precision(s1), _canonicalize_precision(s2))
   else:
-    raise ValueError("Precision argument must be None, a lax.Precision value "
-                     f"or a tuple of two lax.Precision values; got {precision}")
+    raise ValueError(
+        f"Precision argument must be None, a string in {_precision_strings}, "
+        "a lax.Precision value or a tuple of two lax.Precision values or "
+        f"strings; got {precision}.")
 
 
 def conv_dimension_numbers(lhs_shape, rhs_shape, dimension_numbers
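Given the `_precision_strings` table and the branches above, canonicalization should behave as in this sketch (assuming the `jax_default_matmul_precision` flag and context manager are both unset for the `None` case):

from jax import lax
from jax._src.lax.lax import _canonicalize_precision

assert _canonicalize_precision(None) is None  # defer to the backend default
assert _canonicalize_precision('float32') == lax.Precision.HIGHEST
assert _canonicalize_precision('fastest') == lax.Precision.DEFAULT
# A pair of strings canonicalizes elementwise to a pair of enums:
assert _canonicalize_precision(('bfloat16', 'float32')) == (
    lax.Precision.DEFAULT, lax.Precision.HIGHEST)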
17 changes: 4 additions & 13 deletions jax/_src/numpy/lax_numpy.py
@@ -28,7 +28,6 @@
 import collections
 import collections.abc
 import operator
-import os
 import types
 from typing import Any, Sequence, FrozenSet, Optional, Tuple, Union, cast
 from textwrap import dedent as _dedent
@@ -45,7 +44,7 @@
 from jax import dtypes
 from jax import errors
 from jax.core import UnshapedArray, ShapedArray, ConcreteArray, canonicalize_shape
-from jax.config import flags, config
+from jax.config import config
 from jax.interpreters.xla import DeviceArray, _DeviceArray, _CppDeviceArray
 from jax.interpreters.masking import Poly
 from jax import lax
@@ -55,14 +54,6 @@
   canonicalize_axis as _canonicalize_axis, maybe_named_axis)
 from jax.tree_util import tree_leaves, tree_flatten, tree_map
 
-FLAGS = flags.FLAGS
-flags.DEFINE_enum(
-    'jax_numpy_rank_promotion', os.getenv('JAX_NUMPY_RANK_PROMOTION', 'allow'),
-    enum_values=['allow', 'warn', 'raise'],
-    help=
-    'Control NumPy-style automatic rank promotion broadcasting '
-    '("allow", "warn", or "raise").')
-
 newaxis = None
 
 # Common docstring additions:
@@ -247,20 +238,20 @@ def _promote_shapes(fun_name, *args):
   if not nonscalar_ranks or len(set(nonscalar_ranks)) == 1:
     return args
   else:
-    if FLAGS.jax_numpy_rank_promotion != "allow":
+    if config.jax_numpy_rank_promotion != "allow":
       _rank_promotion_warning_or_error(fun_name, shapes)
     result_rank = len(lax.broadcast_shapes(*shapes))
     return [broadcast_to(arg, (1,) * (result_rank - len(shp)) + shp)
             for arg, shp in zip(args, shapes)]
 
 def _rank_promotion_warning_or_error(fun_name, shapes):
-  if FLAGS.jax_numpy_rank_promotion == "warn":
+  if config.jax_numpy_rank_promotion == "warn":
     msg = ("Following NumPy automatic rank promotion for {} on shapes {}. "
            "Set the jax_numpy_rank_promotion config option to 'allow' to "
            "disable this warning; for more information, see "
            "https://jax.readthedocs.io/en/latest/rank_promotion_warning.html.")
     warnings.warn(msg.format(fun_name, ' '.join(map(str, shapes))))
-  elif FLAGS.jax_numpy_rank_promotion == "raise":
+  elif config.jax_numpy_rank_promotion == "raise":
     msg = ("Operands could not be broadcast together for {} on shapes {} "
            "and with the config option jax_numpy_rank_promotion='raise'. "
            "For more information, see "
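With the option now config-backed, the same behavior is controllable per thread; a sketch of the three modes (the 'raise' branch presumably raises ValueError carrying the message built above):

import jax
import jax.numpy as jnp

x = jnp.ones((3, 4))
y = jnp.ones(4)

x + y  # default 'allow': silently broadcasts y up to rank 2

with jax.numpy_rank_promotion('warn'):
  x + y  # broadcasts, but emits the rank-promotion warning

with jax.numpy_rank_promotion('raise'):
  try:
    x + y
  except ValueError:
    pass  # rank promotion disallowed under 'raise'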
91 changes: 86 additions & 5 deletions jax/config.py
@@ -17,9 +17,9 @@
 import os
 import sys
 import threading
+from typing import List, Callable, Optional
 
 from jax import lib
-from typing import Callable, Optional
 
 def bool_env(varname: str, default: bool) -> bool:
   """Read an environment variable and interpret it as a boolean.
@@ -52,7 +52,7 @@ class Config:
   def __init__(self):
     self.values = {}
     self.meta = {}
-    self.FLAGS = NameSpace(self.read)
+    self.FLAGS = NameSpace(self.read, self.update)
     self.use_absl = False
     self._contextmanager_flags = set()
 
@@ -255,18 +255,70 @@ def set_state(new_val: bool):
     set_state.__doc__ = f"Context manager for `{name}` config option.\n\n{help}"
     return set_state
 
+  def define_enum_state(self, name: str, enum_values: List[str],
+                        default: Optional[str], help: str):
+    """Set up thread-local state and return a contextmanager for managing it.
+
+    Args:
+      name: string, converted to lowercase to define the name of the config
+        option (and absl flag). It is converted to uppercase to define the
+        corresponding shell environment variable.
+      enum_values: list of strings representing the possible values for the
+        option.
+      default: optional string, default value.
+      help: string, used to populate the flag help information as well as the
+        docstring of the returned context manager.
+
+    Returns:
+      A contextmanager to control the thread-local state value.
+
+    See docstring for ``define_bool_state``.
+    """
+    name = name.lower()
+    self.DEFINE_enum(name, os.getenv(name.upper(), default),
+                     enum_values=enum_values, help=help)
+    self._contextmanager_flags.add(name)
+
+    def get_state(self):
+      val = getattr(_thread_local_state, name, unset)
+      return val if val is not unset else self._read(name)
+    setattr(Config, name, property(get_state))
+
+    @contextlib.contextmanager
+    def set_state(new_val: Optional[str]):
+      if (new_val is not None and
+          (type(new_val) is not str or new_val not in enum_values)):
+        raise ValueError(f"new enum value must be None or in {enum_values}, "
+                         f"got {new_val} of type {type(new_val)}.")
+      prev_val = getattr(_thread_local_state, name, unset)
+      setattr(_thread_local_state, name, new_val)
+      try:
+        yield
+      finally:
+        if prev_val is unset:
+          delattr(_thread_local_state, name)
+        else:
+          setattr(_thread_local_state, name, prev_val)
+    set_state.__name__ = name[4:] if name.startswith('jax_') else name
+    set_state.__doc__ = f"Context manager for `{name}` config option.\n\n{help}"
+    return set_state
 
 
 _thread_local_state = threading.local()
 
 class Unset: pass
 unset = Unset()
 
-class NameSpace(object):
-  def __init__(self, getter):
-    self._getter = getter
+class NameSpace:
+  def __init__(self, getter, setter):
+    # must use super because we override this class's __setattr__, see
+    # https://docs.python.org/3/reference/datamodel.html#object.__setattr__
+    super().__setattr__('_getter', getter)
+    super().__setattr__('_setter', setter)
 
   def __getattr__(self, name):
     return self._getter(name)
 
+  def __setattr__(self, name, val):
+    self._setter(name, val)
+
 
 config = Config()
 flags = config
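With ``define_enum_state`` and the now-writable ``FLAGS`` namespace above, an enum option can be set three ways; a sketch using the matmul-precision state defined at the end of this file (the thread-local context manager takes priority over the global value on its thread):

# 1. Environment variable, read when the option is defined:
#    JAX_DEFAULT_MATMUL_PRECISION=tensorfloat32 python train.py

# 2. Global runtime update:
from jax.config import config
config.update('jax_default_matmul_precision', 'bfloat16')
# equivalently: config.FLAGS.jax_default_matmul_precision = 'bfloat16'

# 3. Thread-local override, restored on exit:
import jax
with jax.default_matmul_precision('float32'):
  ...  # code here sees 'float32'; other threads keep 'bfloat16'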
Expand Down Expand Up @@ -357,3 +409,32 @@ def _update_x64_thread_local(val):
 config._contextmanager_flags.remove("jax_enable_x64")
 
 Config.x64_enabled = Config.jax_enable_x64  # type: ignore
+
+
+numpy_rank_promotion = config.define_enum_state(
+    name='jax_numpy_rank_promotion',
+    enum_values=['allow', 'warn', 'raise'],
+    default='allow',
+    help=('Control NumPy-style automatic rank promotion broadcasting '
+          '("allow", "warn", or "raise").'))
+
+default_matmul_precision = config.define_enum_state(
+    name='jax_default_matmul_precision',
+    enum_values=['bfloat16', 'tensorfloat32', 'float32'],
+    default=None,
+    help=('Control the default matmul and conv precision for 32bit inputs.\n\n'
+
+          'Some platforms, like TPU, offer configurable precision levels for '
+          'matrix multiplication and convolution computations, trading off '
+          'accuracy for speed. The precision can be controlled for each '
+          'operation; for example, see the :func:`jax.lax.conv_general_dilated` '
+          'and :func:`jax.lax.dot` docstrings. But it can be useful to control '
+          'the default behavior obtained when an operation is not given a '
+          'specific precision.\n\n'
+
+          'This option can be used to control the default precision '
+          'level for computations involved in matrix multiplication and '
+          'convolution on 32bit inputs. The levels roughly describe the '
+          "precision at which scalar products are computed. The 'bfloat16' "
+          "option is the fastest and least precise; 'float32' is similar to "
+          "full float32 precision; 'tensorfloat32' is intermediate.\n\n"))
53 changes: 53 additions & 0 deletions tests/api_test.py
@@ -25,6 +25,7 @@
 import weakref
 import functools
 import itertools as it
+import operator as op
 
 from absl import logging
 from absl.testing import absltest, parameterized
@@ -2399,6 +2400,58 @@ def test_large_python_int_to_float(self):
     out = lax.convert_element_type(2 ** 100, jnp.float32)  # doesn't crash
     self.assertArraysEqual(out, np.float32(2 ** 100))
 
+  def test_dot_precision_context_manager(self):
+    x = jnp.zeros((2, 2))
+
+    with jax.default_matmul_precision(None):
+      jnp.dot(x, x)  # doesn't crash
+      jaxpr = jax.make_jaxpr(jnp.dot)(x, x)
+    self.assertIn('precision=DEFAULT', str(jaxpr))
+
+    with jax.default_matmul_precision("bfloat16"):
+      x @ x  # doesn't crash
+      jaxpr = jax.make_jaxpr(op.matmul)(x, x)
+    self.assertIn('precision=DEFAULT', str(jaxpr))
+
+    with jax.default_matmul_precision("tensorfloat32"):
+      jnp.dot(x, x)  # doesn't crash
+      jaxpr = jax.make_jaxpr(jnp.dot)(x, x)
+    self.assertIn('precision=HIGH\n', str(jaxpr))
+
+    with jax.default_matmul_precision("float32"):
+      jnp.dot(x, x)  # doesn't crash
+      jaxpr = jax.make_jaxpr(jnp.dot)(x, x)
+    self.assertIn('precision=HIGHEST', str(jaxpr))
+
+    dot = partial(jnp.dot, precision=lax.Precision.HIGHEST)
+    with jax.default_matmul_precision("tensorfloat32"):
+      dot(x, x)  # doesn't crash
+      jaxpr = jax.make_jaxpr(dot)(x, x)
+    self.assertIn('precision=HIGHEST', str(jaxpr))
+
+  def test_dot_precision_flag(self):
+    x = jnp.zeros((2, 2))
+
+    prev_val = config._read("jax_default_matmul_precision")
+    try:
+      config.FLAGS.jax_default_matmul_precision = "tensorfloat32"
+      jnp.dot(x, x)  # doesn't crash
+      jaxpr = jax.make_jaxpr(jnp.dot)(x, x)
+    finally:
+      config.FLAGS.jax_default_matmul_precision = prev_val
+    self.assertIn('precision=HIGH', str(jaxpr))
+    self.assertEqual(prev_val, config._read("jax_default_matmul_precision"))
+
+    prev_val = config._read("jax_default_matmul_precision")
+    try:
+      config.update('jax_default_matmul_precision', 'tensorfloat32')
+      jnp.dot(x, x)  # doesn't crash
+      jaxpr = jax.make_jaxpr(jnp.dot)(x, x)
+    finally:
+      config.update('jax_default_matmul_precision', prev_val)
+    self.assertIn('precision=HIGH', str(jaxpr))
+    self.assertEqual(prev_val, config._read("jax_default_matmul_precision"))
 
 
 class RematTest(jtu.JaxTestCase):
