Skip to content

Commit

Permalink
Raise an error on non-hashable static arguments for jax.jit and xla_c…
Browse files Browse the repository at this point in the history
…omputation.

Up to now, Jax was silently wrapping the object to ensure objects which are not hashable will be hashed using `id` and compared using `is`:

```
class WrapHashably(object):
  __slots__ = ["val"]
  def __init__(self, val):
    self.val = val
  def __hash__(self):
    return id(self.val)
  def __eq__(self, other):
    return self.val is other.val
```

This means that when providing different instances of objects that are non hashable, a recompilation was always occurring. This can be non-intuitive, for example with:

@partial(jax.jit, static_argnums=(1,))
def sum(a, b):
  return a+ b
sum(np.asarray([1,2,3]), np.asarray([4,5,6])
# The next line will recompile, because the 1-indexed argument is non
# hashable and thus compared by identity with different instances
sum(np.asarray([1,2,3]), np.asarray([4,5,6])

or more simply
np.pad(a, [2, 3], 'constant', constant_values=(4, 6))
          ^^^^^^
          non-hashable static argument.

The same problems can occur with any non-hashable types such as lists, dicts, etc. Even JAX itself was having some issues with this (which shows the behaviour was non-trivial to reason about).

If this commit breaks you, you usually have one of the following options:
- If specifying numpy array or jnp arrays arguments as static, you probably simply need to make them non static.
- When using non-hashable values, such as list, dicts or sets, you can simply use non-mutable versions, with tuples, frozendict, and frozenset.
- You can also change the way the function is defined, to capture these non-hashable arguments by closure, returning the jitted function.

PiperOrigin-RevId: 339351798
  • Loading branch information
jblespiau authored and jax authors committed Oct 27, 2020
1 parent a7de694 commit cb48f42
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 10 deletions.
11 changes: 4 additions & 7 deletions jax/api_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,10 @@ def argnums_partial_except(f: lu.WrappedFun, static_argnums: Tuple[int, ...],
try:
hash(static_arg)
except TypeError:
logging.warning(
"Static argument (index %s) of type %s for function %s is "
"non-hashable. As this can lead to unexpected cache-misses, it "
"will raise an error in a near future.", i, type(static_arg),
f.__name__)
# e.g. ndarrays, DeviceArrays
fixed_args[i] = WrapHashably(static_arg) # type: ignore
raise ValueError(
"Non-hashable static arguments are not supported, as this can lead "
f"to unexpected cache-misses. Static argument (index {i}) of type "
f"{type(static_arg)} for function {f.__name__} is non-hashable.")
else:
fixed_args[i] = Hashable(static_arg) # type: ignore

Expand Down
18 changes: 15 additions & 3 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,18 @@ def test_jit_reference_dropping(self):
del g # no more references to x
assert x() is None # x is gone

def test_jit_raises_on_first_invocation_on_non_hashable_static_argnum(self):
if self.jit != jax.api._python_jit:
raise unittest.SkipTest("this test only applies to _python_jit")
f = lambda x, y: x + 3
jitted_f = self.jit(f, static_argnums=(1,))

msg = ("Non-hashable static arguments are not supported, as this can lead "
"to unexpected cache-misses. Static argument (index 1) of type "
"<class 'numpy.ndarray'> for function <lambda> is non-hashable.")
with self.assertRaisesRegex(ValueError, re.escape(msg)):
jitted_f(1, np.asarray(1))

def test_cpp_jit_raises_on_non_hashable_static_argnum(self):
if version < (0, 1, 58):
raise unittest.SkipTest("Disabled because it depends on some future "
Expand All @@ -428,9 +440,9 @@ def test_cpp_jit_raises_on_non_hashable_static_argnum(self):

jitted_f(1, 1)

msg = (
"""Non-hashable static arguments are not supported. An error occured while trying to hash an object of type <class 'numpy.ndarray'>, 1. The error was:
TypeError: unhashable type: 'numpy.ndarray'""")
msg = ("Non-hashable static arguments are not supported. An error occured "
"while trying to hash an object of type <class 'numpy.ndarray'>, 1. "
"The error was:\nTypeError: unhashable type: 'numpy.ndarray'")

with self.assertRaisesRegex(ValueError, re.escape(msg)):
jitted_f(1, np.asarray(1))
Expand Down

3 comments on commit cb48f42

@gnecula
Copy link
Collaborator

@gnecula gnecula commented on cb48f42 Nov 3, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just ran into this in a Flax example. This is a non-backwards-compatible change, it ought to be mentioned in the CHANGELOG.md

@gnecula
Copy link
Collaborator

@gnecula gnecula commented on cb48f42 Nov 3, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jblespiau Perhaps add a mention to CHANGELOG?

@jblespiau
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jblespiau Perhaps add a mention to CHANGELOG?

Sent https://critique-ng.corp.google.com/cl/340395327

Please sign in to comment.