Skip to content

Commit

Permalink
Fix packbits/unpackbits tests (#2702)
Browse files Browse the repository at this point in the history
  • Loading branch information
skye authored Apr 14, 2020
1 parent 5ba9b6b commit 4e9f640
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1123,9 +1123,9 @@ def attempt_sideeffect(x):
jnp_input = jnp.ones((1))
expected_onp_input_after_call = onp.ones((1))
expected_jnp_input_after_call = jnp.ones((1))

self.assertIs(type(jnp.concatenate([onp_input])), jnp.DeviceArray)

attempt_sideeffect(onp_input)
attempt_sideeffect(jnp_input)

Expand Down Expand Up @@ -2040,6 +2040,8 @@ def testRollaxis(self, shape, dtype, start, axis, rng_factory):
for axis in [None, 0, 1, -2, -1]
for rng_factory in [jtu.rand_some_zero]))
def testPackbits(self, shape, dtype, axis, bitorder, rng_factory):
if numpy_version < (1, 17, 0):
raise SkipTest("bitorder arg added in numpy 1.17.0")
rng = rng_factory()
args_maker = lambda: [rng(shape, dtype)]
jnp_op = partial(jnp.packbits, axis=axis, bitorder=bitorder)
Expand All @@ -2059,6 +2061,8 @@ def testPackbits(self, shape, dtype, axis, bitorder, rng_factory):
for count in [None, 20]
for rng_factory in [jtu.rand_int]))
def testUnpackbits(self, shape, dtype, axis, bitorder, count, rng_factory):
if numpy_version < (1, 17, 0):
raise SkipTest("bitorder arg added in numpy 1.17.0")
rng = rng_factory(0, 256)
args_maker = lambda: [rng(shape, dtype)]
jnp_op = partial(jnp.unpackbits, axis=axis, bitorder=bitorder)
Expand Down Expand Up @@ -2252,7 +2256,7 @@ def onp_fun(*args):
"_a_shape={}_axis={}_keepdims={}".format(
jtu.format_shape_dtype_string(a_shape, a_dtype),
axis, keepdims),
"a_rng": jtu.rand_default(),
"a_rng": jtu.rand_default(),
"a_shape": a_shape, "a_dtype": a_dtype,
"axis": axis,
"keepdims": keepdims}
Expand Down Expand Up @@ -2952,7 +2956,7 @@ def testPrecision(self):
"axis": axis,
"dtype": dtype, "rng_factory": rng_factory}
for shape in [(10,), (10, 15), (10, 15, 20)]
for _num_axes in range(len(shape))
for _num_axes in range(len(shape))
for axis in itertools.combinations(range(len(shape)), _num_axes)
for dtype in inexact_dtypes
for rng_factory in [jtu.rand_default]))
Expand Down

0 comments on commit 4e9f640

Please sign in to comment.