Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed Mar 22, 2021
1 parent 78f5c3d commit df891e5
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 3 deletions.
5 changes: 2 additions & 3 deletions jax/abstract_arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,8 @@ def _zeros_like_python_scalar(t, x):
return np.array(0, dtypes.python_scalar_dtypes[t])

def _make_concrete_python_scalar(t, x):
return ConcreteArray(
np.array(x, dtype=dtypes.python_scalar_dtypes[t]),
weak_type=True)
return ConcreteArray(np.array(x, dtype=dtypes.python_scalar_dtypes[t]),
weak_type=True)

for t in dtypes.python_scalar_dtypes:
core.pytype_aval_mappings[t] = partial(_make_concrete_python_scalar, t)
Expand Down
12 changes: 12 additions & 0 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2393,6 +2393,18 @@ def f(_):
expected = jnp.arange(1) + 1
self.assertAllClose(ans, expected)

def test_large_python_int_to_float(self):
# https://github.com/google/jax/pull/6165
# We skip checks because otherwise we end up calling valid_jaxtype(2**100),
# which tries to form a ConcreteArray with that value and thus leads to a
# NumPy OverflowError. It's true that 2**100 does not inhabit a jax type,
# but as an issue of Python embedding we can handle operations like those in
# the tests below.
with jax.core.skipping_checks():
jnp.multiply(2 ** 100, 3.) # doesn't crash
out = lax.convert_element_type(2 ** 100, jnp.float32) # doesn't crash
self.assertArraysEqual(out, np.float32(2 ** 100))


class RematTest(jtu.JaxTestCase):

Expand Down

0 comments on commit df891e5

Please sign in to comment.