From 0796bfe6e71e7a1500d8a7e10b4d4db9edca20c0 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 23 Mar 2021 11:23:20 -0700 Subject: [PATCH] errors: add NonConcreteBooleanIndexError & debugging tips --- docs/errors.rst | 1 + docs/jax.ops.rst | 1 + jax/_src/errors.py | 132 +++++++++++++++++++++++++++++++++--- jax/_src/numpy/lax_numpy.py | 6 +- jax/errors.py | 2 + tests/errors_test.py | 2 +- 6 files changed, 130 insertions(+), 14 deletions(-) diff --git a/docs/errors.rst b/docs/errors.rst index 98f65b1ce6d5..4d93426e26f4 100644 --- a/docs/errors.rst +++ b/docs/errors.rst @@ -7,5 +7,6 @@ along with representative examples of how one might fix them. .. currentmodule:: jax.errors .. autoclass:: ConcretizationTypeError +.. autoclass:: NonConcreteBooleanIndexError .. autoclass:: TracerArrayConversionError .. autoclass:: TracerIntegerConversionError \ No newline at end of file diff --git a/docs/jax.ops.rst b/docs/jax.ops.rst index f333d7a2e422..cdc6a072cdd4 100644 --- a/docs/jax.ops.rst +++ b/docs/jax.ops.rst @@ -24,6 +24,7 @@ pure alternatives, namely :func:`jax.ops.index_update` and its relatives. index_min index_max +.. _syntactic-sugar-for-ops: Syntactic sugar for indexed update operators -------------------------------------------- diff --git a/jax/_src/errors.py b/jax/_src/errors.py index a230bbfeffd4..d3c20a271992 100644 --- a/jax/_src/errors.py +++ b/jax/_src/errors.py @@ -14,16 +14,26 @@ from jax import core -class JAXTypeError(TypeError): - """Base class for JAX-specific TypeErrors""" +class _JAXErrorMixin: + """Mixin for JAX-specific errors""" _error_page = 'https://jax.readthedocs.io/en/latest/errors.html' + _module_name = "jax.errors" def __init__(self, message: str): error_page = self._error_page - module_name = getattr(self, '_module_name', self.__class__.__module__) + module_name = self._module_name class_name = self.__class__.__name__ error_msg = f'{message} ({error_page}#{module_name}.{class_name})' - super().__init__(error_msg) + # https://github.com/python/mypy/issues/5887 + super().__init__(error_msg) # type: ignore + + +class JAXTypeError(_JAXErrorMixin, TypeError): + pass + + +class JAXIndexError(_JAXErrorMixin, IndexError): + pass class ConcretizationTypeError(JAXTypeError): @@ -130,14 +140,120 @@ class ConcretizationTypeError(JAXTypeError): To understand more subtleties having to do with tracers vs. regular values, and concrete vs. abstract values, you may want to read :ref:`faq-different-kinds-of-jax-values`. """ - _module_name = "jax.errors" - def __init__(self, tracer: "core.Tracer", context: str = ""): super().__init__( "Abstract tracer value encountered where concrete value is expected: " f"{tracer}\n{context}\n{tracer._origin_msg()}\n") +class NonConcreteBooleanIndexError(JAXIndexError): + """ + This error occurs when a program attempts to use non-concrete boolean indices + in a traced indexing operation. Under JIT compilation, JAX arrays must have static + shapes (i.e. shapes that are known at compile-time) and so boolean masks must be + used carefully. Some logic implemented via boolean masking is simply not possible + under JAX's JIT compilation model; in other cases, the logic can be re-expressed in + a JIT-compatible way, often using the three-argument version of :func:`~jax.numpy.where`. + + Following are a few examples of when this error might arise. + + Constructing arrays via boolean masking + This most commonly arises when attempting to create an array via a boolean mask + within a JIT context. For example:: + + >>> import jax + >>> import jax.numpy as jnp + + >>> @jax.jit + ... def positive_values(x): + ... return x[x > 0] + + >>> positive_values(jnp.arange(-5, 5)) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + NonConcreteBooleanIndexError: Array boolean indices must be concrete: ShapedArray(bool[10]) + + This function is attempting to return only the positive values in the input array; the size of + this returned array cannot be determined at compile-time unless `x` is marked as static, and so + operations like this cannot be performed under JIT compilation. + + Reexpressible Boolean Logic + Although creating dynamically sized arrays is not supported directly, in many cases it is + possible to re-express the logic of the computation in terms of a JIT-compatible operation. + For example, here is another function that fails under JIT for the same reason:: + + >>> @jax.jit + ... def sum_of_positive(x): + ... return x[x > 0].sum() + + >>> sum_of_positive(jnp.arange(-5, 5)) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + NonConcreteBooleanIndexError: Array boolean indices must be concrete: ShapedArray(bool[10]) + + In this case, however, the problematic array is only an intermediate value, and we can + instead express the same logic in terms of the JIT-compatible three-argument version of + :func:`jax.numpy.where`:: + + >>> @jax.jit + ... def sum_of_positive(x): + ... return jnp.where(x > 0, x, 0).sum() + + >>> sum_of_positive(jnp.arange(-5, 5)) + DeviceArray(10, dtype=int32) + + This pattern of replacing boolean masking with three-argument :func:`~jax.numpy.where` is a + common solution to this sort of problem. + + Boolean indices in :mod:`jax.ops` + The other situation where this error often arises is when using boolean indices within functions + in :mod:`jax.ops`, such as :func:`jax.ops.index_update`. Here is a simple example:: + + >>> @jax.jit + ... def manual_clip(x): + ... return jax.ops.index_update(x, x < 0, 0) + + >>> manual_clip(jnp.arange(-2, 2)) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + NonConcreteBooleanIndexError: Array boolean indices must be concrete: ShapedArray(bool[4]) + + This function is attempting to set values smaller than zero to a scalar fill value. As above, + this can be addressed by re-expressing the logic in terms of :func:`~jax.numpy.where`:: + + >>> @jax.jit + ... def manual_clip(x): + ... return jnp.where(x < 0, 0, x) + + >>> manual_clip(jnp.arange(-2, 2)) + DeviceArray([0, 0, 0, 1], dtype=int32) + + These operations also commonly are written in terms of the :ref:`syntactic-sugar-for-ops`; + for example, this is syntactic sugar for :func:`~jax.ops.index_mul`, and fails under JIT:: + + >>> @jax.jit + ... def manual_abs(x): + ... return x.at[x < 0].mul(-1) + + >>> manual_abs(jnp.arange(-2, 2)) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + NonConcreteBooleanIndexError: Array boolean indices must be concrete: ShapedArray(bool[4]) + + As above, the solution is to re-express this in terms of :func:`~jax.numpy.where`:: + + >>> @jax.jit + ... def manual_abs(x): + ... return jnp.where(x < 0, x * -1, x) + + >>> manual_abs(jnp.arange(-2, 2)) + DeviceArray([2, 1, 0, 1], dtype=int32) + """ + def __init__(self, tracer: "core.Tracer"): + super().__init__( + f"Array boolean indices must be concrete; got {tracer}\n") + + class TracerArrayConversionError(JAXTypeError): """ This error occurs when a program attempts to convert a JAX Tracer object into a @@ -205,8 +321,6 @@ class TracerArrayConversionError(JAXTypeError): To understand more subtleties having to do with tracers vs. regular values, and concrete vs. abstract values, you may want to read :ref:`faq-different-kinds-of-jax-values`. """ - _module_name = "jax.errors" - def __init__(self, tracer: "core.Tracer"): super().__init__( "The numpy.ndarray conversion method __array__() was called on " @@ -293,8 +407,6 @@ class TracerIntegerConversionError(JAXTypeError): To understand more subtleties having to do with tracers vs. regular values, and concrete vs. abstract values, you may want to read :ref:`faq-different-kinds-of-jax-values`. """ - _module_name = "jax.errors" - def __init__(self, tracer: "core.Tracer"): super().__init__( f"The __index__() method was called on the JAX Tracer object {tracer}") diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index b299c1a2bf12..4c2fbe9c91cf 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -43,6 +43,7 @@ from .util import _wraps from jax import core from jax import dtypes +from jax import errors from jax.core import UnshapedArray, ShapedArray, ConcreteArray, canonicalize_shape from jax.config import flags, config from jax.interpreters.xla import DeviceArray, _DeviceArray, _CppDeviceArray @@ -4782,9 +4783,8 @@ def _expand_bool_indices(idx): abstract_i = core.get_aval(i) if not type(abstract_i) is ConcreteArray: - # TODO(mattjj): improve this error by tracking _why_ the indices are not - # concrete - raise IndexError("Array boolean indices must be concrete.") + # TODO(mattjj): improve this error by tracking _why_ the indices are not concrete + raise errors.NonConcreteBooleanIndexError(abstract_i) else: out.extend(np.where(i)) else: diff --git a/jax/errors.py b/jax/errors.py index 3481fbbdb690..71c6c6ec343f 100644 --- a/jax/errors.py +++ b/jax/errors.py @@ -14,6 +14,8 @@ # flake8: noqa: F401 from ._src.errors import (JAXTypeError, + JAXIndexError, ConcretizationTypeError, + NonConcreteBooleanIndexError, TracerArrayConversionError, TracerIntegerConversionError) diff --git a/tests/errors_test.py b/tests/errors_test.py index 75fd6387850a..255747ce1c6e 100644 --- a/tests/errors_test.py +++ b/tests/errors_test.py @@ -313,7 +313,7 @@ class CustomErrorsTest(jtu.JaxTestCase): @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_{}".format(errorclass), "errorclass": errorclass} for errorclass in dir(jax.errors) - if errorclass.endswith('Error') and errorclass != 'JAXTypeError')) + if errorclass.endswith('Error') and errorclass not in ['JaxIndexError', 'JAXTypeError'])) def testErrorsURL(self, errorclass): class FakeTracer(core.Tracer): aval = None