Skip to content

Commit cb48f42

Browse files
jblespiaujax authors
authored andcommitted
Raise an error on non-hashable static arguments for jax.jit and xla_computation.
Up to now, Jax was silently wrapping the object to ensure objects which are not hashable will be hashed using `id` and compared using `is`: ``` class WrapHashably(object): __slots__ = ["val"] def __init__(self, val): self.val = val def __hash__(self): return id(self.val) def __eq__(self, other): return self.val is other.val ``` This means that when providing different instances of objects that are non hashable, a recompilation was always occurring. This can be non-intuitive, for example with: @partial(jax.jit, static_argnums=(1,)) def sum(a, b): return a+ b sum(np.asarray([1,2,3]), np.asarray([4,5,6]) # The next line will recompile, because the 1-indexed argument is non # hashable and thus compared by identity with different instances sum(np.asarray([1,2,3]), np.asarray([4,5,6]) or more simply np.pad(a, [2, 3], 'constant', constant_values=(4, 6)) ^^^^^^ non-hashable static argument. The same problems can occur with any non-hashable types such as lists, dicts, etc. Even JAX itself was having some issues with this (which shows the behaviour was non-trivial to reason about). If this commit breaks you, you usually have one of the following options: - If specifying numpy array or jnp arrays arguments as static, you probably simply need to make them non static. - When using non-hashable values, such as list, dicts or sets, you can simply use non-mutable versions, with tuples, frozendict, and frozenset. - You can also change the way the function is defined, to capture these non-hashable arguments by closure, returning the jitted function. PiperOrigin-RevId: 339351798
1 parent a7de694 commit cb48f42

File tree

2 files changed

+19
-10
lines changed

2 files changed

+19
-10
lines changed

jax/api_util.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -88,13 +88,10 @@ def argnums_partial_except(f: lu.WrappedFun, static_argnums: Tuple[int, ...],
8888
try:
8989
hash(static_arg)
9090
except TypeError:
91-
logging.warning(
92-
"Static argument (index %s) of type %s for function %s is "
93-
"non-hashable. As this can lead to unexpected cache-misses, it "
94-
"will raise an error in a near future.", i, type(static_arg),
95-
f.__name__)
96-
# e.g. ndarrays, DeviceArrays
97-
fixed_args[i] = WrapHashably(static_arg) # type: ignore
91+
raise ValueError(
92+
"Non-hashable static arguments are not supported, as this can lead "
93+
f"to unexpected cache-misses. Static argument (index {i}) of type "
94+
f"{type(static_arg)} for function {f.__name__} is non-hashable.")
9895
else:
9996
fixed_args[i] = Hashable(static_arg) # type: ignore
10097

tests/api_test.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,18 @@ def test_jit_reference_dropping(self):
415415
del g # no more references to x
416416
assert x() is None # x is gone
417417

418+
def test_jit_raises_on_first_invocation_on_non_hashable_static_argnum(self):
419+
if self.jit != jax.api._python_jit:
420+
raise unittest.SkipTest("this test only applies to _python_jit")
421+
f = lambda x, y: x + 3
422+
jitted_f = self.jit(f, static_argnums=(1,))
423+
424+
msg = ("Non-hashable static arguments are not supported, as this can lead "
425+
"to unexpected cache-misses. Static argument (index 1) of type "
426+
"<class 'numpy.ndarray'> for function <lambda> is non-hashable.")
427+
with self.assertRaisesRegex(ValueError, re.escape(msg)):
428+
jitted_f(1, np.asarray(1))
429+
418430
def test_cpp_jit_raises_on_non_hashable_static_argnum(self):
419431
if version < (0, 1, 58):
420432
raise unittest.SkipTest("Disabled because it depends on some future "
@@ -428,9 +440,9 @@ def test_cpp_jit_raises_on_non_hashable_static_argnum(self):
428440

429441
jitted_f(1, 1)
430442

431-
msg = (
432-
"""Non-hashable static arguments are not supported. An error occured while trying to hash an object of type <class 'numpy.ndarray'>, 1. The error was:
433-
TypeError: unhashable type: 'numpy.ndarray'""")
443+
msg = ("Non-hashable static arguments are not supported. An error occured "
444+
"while trying to hash an object of type <class 'numpy.ndarray'>, 1. "
445+
"The error was:\nTypeError: unhashable type: 'numpy.ndarray'")
434446

435447
with self.assertRaisesRegex(ValueError, re.escape(msg)):
436448
jitted_f(1, np.asarray(1))

0 commit comments

Comments
 (0)