Skip to content

Commit

Permalink
errors: add NonConcreteBooleanIndexError & debugging tips
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Mar 23, 2021
1 parent 7890d6c commit 0796bfe
Show file tree
Hide file tree
Showing 6 changed files with 130 additions and 14 deletions.
1 change: 1 addition & 0 deletions docs/errors.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,6 @@ along with representative examples of how one might fix them.

.. currentmodule:: jax.errors
.. autoclass:: ConcretizationTypeError
.. autoclass:: NonConcreteBooleanIndexError
.. autoclass:: TracerArrayConversionError
.. autoclass:: TracerIntegerConversionError
1 change: 1 addition & 0 deletions docs/jax.ops.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ pure alternatives, namely :func:`jax.ops.index_update` and its relatives.
index_min
index_max

.. _syntactic-sugar-for-ops:

Syntactic sugar for indexed update operators
--------------------------------------------
Expand Down
132 changes: 122 additions & 10 deletions jax/_src/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,26 @@
from jax import core


class JAXTypeError(TypeError):
"""Base class for JAX-specific TypeErrors"""
class _JAXErrorMixin:
"""Mixin for JAX-specific errors"""
_error_page = 'https://jax.readthedocs.io/en/latest/errors.html'
_module_name = "jax.errors"

def __init__(self, message: str):
error_page = self._error_page
module_name = getattr(self, '_module_name', self.__class__.__module__)
module_name = self._module_name
class_name = self.__class__.__name__
error_msg = f'{message} ({error_page}#{module_name}.{class_name})'
super().__init__(error_msg)
# https://github.com/python/mypy/issues/5887
super().__init__(error_msg) # type: ignore


class JAXTypeError(_JAXErrorMixin, TypeError):
pass


class JAXIndexError(_JAXErrorMixin, IndexError):
pass


class ConcretizationTypeError(JAXTypeError):
Expand Down Expand Up @@ -130,14 +140,120 @@ class ConcretizationTypeError(JAXTypeError):
To understand more subtleties having to do with tracers vs. regular values, and
concrete vs. abstract values, you may want to read :ref:`faq-different-kinds-of-jax-values`.
"""
_module_name = "jax.errors"

def __init__(self, tracer: "core.Tracer", context: str = ""):
super().__init__(
"Abstract tracer value encountered where concrete value is expected: "
f"{tracer}\n{context}\n{tracer._origin_msg()}\n")


class NonConcreteBooleanIndexError(JAXIndexError):
"""
This error occurs when a program attempts to use non-concrete boolean indices
in a traced indexing operation. Under JIT compilation, JAX arrays must have static
shapes (i.e. shapes that are known at compile-time) and so boolean masks must be
used carefully. Some logic implemented via boolean masking is simply not possible
under JAX's JIT compilation model; in other cases, the logic can be re-expressed in
a JIT-compatible way, often using the three-argument version of :func:`~jax.numpy.where`.
Following are a few examples of when this error might arise.
Constructing arrays via boolean masking
This most commonly arises when attempting to create an array via a boolean mask
within a JIT context. For example::
>>> import jax
>>> import jax.numpy as jnp
>>> @jax.jit
... def positive_values(x):
... return x[x > 0]
>>> positive_values(jnp.arange(-5, 5)) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
NonConcreteBooleanIndexError: Array boolean indices must be concrete: ShapedArray(bool[10])
This function is attempting to return only the positive values in the input array; the size of
this returned array cannot be determined at compile-time unless `x` is marked as static, and so
operations like this cannot be performed under JIT compilation.
Reexpressible Boolean Logic
Although creating dynamically sized arrays is not supported directly, in many cases it is
possible to re-express the logic of the computation in terms of a JIT-compatible operation.
For example, here is another function that fails under JIT for the same reason::
>>> @jax.jit
... def sum_of_positive(x):
... return x[x > 0].sum()
>>> sum_of_positive(jnp.arange(-5, 5)) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
NonConcreteBooleanIndexError: Array boolean indices must be concrete: ShapedArray(bool[10])
In this case, however, the problematic array is only an intermediate value, and we can
instead express the same logic in terms of the JIT-compatible three-argument version of
:func:`jax.numpy.where`::
>>> @jax.jit
... def sum_of_positive(x):
... return jnp.where(x > 0, x, 0).sum()
>>> sum_of_positive(jnp.arange(-5, 5))
DeviceArray(10, dtype=int32)
This pattern of replacing boolean masking with three-argument :func:`~jax.numpy.where` is a
common solution to this sort of problem.
Boolean indices in :mod:`jax.ops`
The other situation where this error often arises is when using boolean indices within functions
in :mod:`jax.ops`, such as :func:`jax.ops.index_update`. Here is a simple example::
>>> @jax.jit
... def manual_clip(x):
... return jax.ops.index_update(x, x < 0, 0)
>>> manual_clip(jnp.arange(-2, 2)) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
NonConcreteBooleanIndexError: Array boolean indices must be concrete: ShapedArray(bool[4])
This function is attempting to set values smaller than zero to a scalar fill value. As above,
this can be addressed by re-expressing the logic in terms of :func:`~jax.numpy.where`::
>>> @jax.jit
... def manual_clip(x):
... return jnp.where(x < 0, 0, x)
>>> manual_clip(jnp.arange(-2, 2))
DeviceArray([0, 0, 0, 1], dtype=int32)
These operations also commonly are written in terms of the :ref:`syntactic-sugar-for-ops`;
for example, this is syntactic sugar for :func:`~jax.ops.index_mul`, and fails under JIT::
>>> @jax.jit
... def manual_abs(x):
... return x.at[x < 0].mul(-1)
>>> manual_abs(jnp.arange(-2, 2)) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
NonConcreteBooleanIndexError: Array boolean indices must be concrete: ShapedArray(bool[4])
As above, the solution is to re-express this in terms of :func:`~jax.numpy.where`::
>>> @jax.jit
... def manual_abs(x):
... return jnp.where(x < 0, x * -1, x)
>>> manual_abs(jnp.arange(-2, 2))
DeviceArray([2, 1, 0, 1], dtype=int32)
"""
def __init__(self, tracer: "core.Tracer"):
super().__init__(
f"Array boolean indices must be concrete; got {tracer}\n")


class TracerArrayConversionError(JAXTypeError):
"""
This error occurs when a program attempts to convert a JAX Tracer object into a
Expand Down Expand Up @@ -205,8 +321,6 @@ class TracerArrayConversionError(JAXTypeError):
To understand more subtleties having to do with tracers vs. regular values, and concrete vs.
abstract values, you may want to read :ref:`faq-different-kinds-of-jax-values`.
"""
_module_name = "jax.errors"

def __init__(self, tracer: "core.Tracer"):
super().__init__(
"The numpy.ndarray conversion method __array__() was called on "
Expand Down Expand Up @@ -293,8 +407,6 @@ class TracerIntegerConversionError(JAXTypeError):
To understand more subtleties having to do with tracers vs. regular values, and concrete vs.
abstract values, you may want to read :ref:`faq-different-kinds-of-jax-values`.
"""
_module_name = "jax.errors"

def __init__(self, tracer: "core.Tracer"):
super().__init__(
f"The __index__() method was called on the JAX Tracer object {tracer}")
6 changes: 3 additions & 3 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from .util import _wraps
from jax import core
from jax import dtypes
from jax import errors
from jax.core import UnshapedArray, ShapedArray, ConcreteArray, canonicalize_shape
from jax.config import flags, config
from jax.interpreters.xla import DeviceArray, _DeviceArray, _CppDeviceArray
Expand Down Expand Up @@ -4782,9 +4783,8 @@ def _expand_bool_indices(idx):
abstract_i = core.get_aval(i)

if not type(abstract_i) is ConcreteArray:
# TODO(mattjj): improve this error by tracking _why_ the indices are not
# concrete
raise IndexError("Array boolean indices must be concrete.")
# TODO(mattjj): improve this error by tracking _why_ the indices are not concrete
raise errors.NonConcreteBooleanIndexError(abstract_i)
else:
out.extend(np.where(i))
else:
Expand Down
2 changes: 2 additions & 0 deletions jax/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

# flake8: noqa: F401
from ._src.errors import (JAXTypeError,
JAXIndexError,
ConcretizationTypeError,
NonConcreteBooleanIndexError,
TracerArrayConversionError,
TracerIntegerConversionError)
2 changes: 1 addition & 1 deletion tests/errors_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ class CustomErrorsTest(jtu.JaxTestCase):
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(errorclass), "errorclass": errorclass}
for errorclass in dir(jax.errors)
if errorclass.endswith('Error') and errorclass != 'JAXTypeError'))
if errorclass.endswith('Error') and errorclass not in ['JaxIndexError', 'JAXTypeError']))
def testErrorsURL(self, errorclass):
class FakeTracer(core.Tracer):
aval = None
Expand Down

0 comments on commit 0796bfe

Please sign in to comment.