Skip to content

Commit

Permalink
make constant handlers follow type mro
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj authored and NeilGirdhar committed Apr 1, 2021
1 parent 9d60c8d commit 95c072c
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 6 deletions.
15 changes: 9 additions & 6 deletions jax/lib/xla_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,13 @@
logging._warn_preinit_stderr = 0

from ..config import flags
from jax._src import util
from jax._src import util, traceback_util
from .. import dtypes
import numpy as np
import threading

traceback_util.register_exclusion(__file__)

try:
from . import tpu_client
except ImportError:
Expand Down Expand Up @@ -335,11 +337,12 @@ def constant(builder, py_val, canonicalize_types=True):
Returns:
A representation of the constant, either a ComputationDataHandle or None
"""
py_type = type(py_val)
if py_type in _constant_handlers:
return _constant_handlers[py_type](builder, py_val, canonicalize_types)
else:
raise TypeError("No constant handler for type: {}".format(py_type))
for t in type(py_val).mro():
handler = _constant_handlers.get(t)
if handler: return handler(builder, py_val, canonicalize_types)
if hasattr(py_val, '__jax_array__'):
return constant(builder, py_val.__jax_array__(), canonicalize_types)
raise TypeError("No constant handler for type: {}".format(type(py_val)))

# HLO instructions optionally can be annotated to say how the output should be
# spatially partitioned (represented in XLA as OpSharding protos, see
Expand Down
14 changes: 14 additions & 0 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import collections
from contextlib import contextmanager
import copy
import enum
from functools import partial
import re
import unittest
Expand Down Expand Up @@ -2359,6 +2360,19 @@ def __jax_array__(self):
for f in [jnp.isscalar, jnp.size, jnp.shape, jnp.dtype]:
self.assertEqual(f(x), f(a))

def test_constant_handler_mro(self):
# https://github.com/google/jax/issues/6129

class Foo(enum.IntEnum):
bar = 1

@api.pmap
def f(_):
return Foo.bar

ans = f(jnp.arange(1)) # doesn't crash
expected = jnp.arange(1) + 1
self.assertAllClose(ans, expected)


class RematTest(jtu.JaxTestCase):
Expand Down

0 comments on commit 95c072c

Please sign in to comment.