Skip to content

Commit

Permalink
Fix IntEnum test when checking is enabled. (#2981)
Browse files Browse the repository at this point in the history
  • Loading branch information
hawkinsp authored May 7, 2020
1 parent d284755 commit 50dc44b
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 13 deletions.
14 changes: 8 additions & 6 deletions jax/abstract_arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial

import numpy as onp

from . import ad_util
Expand Down Expand Up @@ -57,16 +59,16 @@ def zeros_like_shaped_array(aval):

core.literalable_types.update(array_types)

def _zeros_like_python_scalar(x):
return onp.array(0, dtypes.python_scalar_dtypes[type(x)])
def _zeros_like_python_scalar(t, x):
return onp.array(0, dtypes.python_scalar_dtypes[t])

def _make_concrete_python_scalar(x):
def _make_concrete_python_scalar(t, x):
return ConcreteArray(
onp.array(x, dtype=dtypes.python_scalar_dtypes[type(x)]),
onp.array(x, dtype=dtypes.python_scalar_dtypes[t]),
weak_type=True)

for t in dtypes.python_scalar_dtypes.keys():
core.pytype_aval_mappings[t] = _make_concrete_python_scalar
ad_util.jaxval_zeros_likers[t] = _zeros_like_python_scalar
core.pytype_aval_mappings[t] = partial(_make_concrete_python_scalar, t)
ad_util.jaxval_zeros_likers[t] = partial(_zeros_like_python_scalar, t)

core.literalable_types.update(dtypes.python_scalar_dtypes.keys())
8 changes: 4 additions & 4 deletions jax/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,10 +698,10 @@ def valid_jaxtype(x):


def concrete_aval(x):
try:
return pytype_aval_mappings[type(x)](x)
except KeyError as err:
raise TypeError("{} is not a valid Jax type".format(type(x))) from err
for typ in type(x).mro():
handler = pytype_aval_mappings.get(typ)
if handler: return handler(x)
raise TypeError(f"{type(x)} is not a valid Jax type")


def get_aval(x):
Expand Down
4 changes: 1 addition & 3 deletions tests/dtypes_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,9 +168,7 @@ class AnEnum(enum.IntEnum):
A = 42
B = 101
np.testing.assert_equal(np.array(42), np.array(AnEnum.A))
with core.skipping_checks():
# Passing AnEnum.A to jnp.array fails the type check in bind
np.testing.assert_equal(jnp.array(42), jnp.array(AnEnum.A))
np.testing.assert_equal(jnp.array(42), jnp.array(AnEnum.A))
np.testing.assert_equal(np.int32(101), np.int32(AnEnum.B))
np.testing.assert_equal(jnp.int32(101), jnp.int32(AnEnum.B))

Expand Down

0 comments on commit 50dc44b

Please sign in to comment.