Skip to content

Commit

Permalink
tweak handling of random.permutation scalar case
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed Apr 24, 2020
1 parent 595e7f2 commit ac0697d
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 4 deletions.
11 changes: 7 additions & 4 deletions jax/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,11 +444,14 @@ def permutation(key, x):
Returns:
A shuffled version of x or array range
"""
if isinstance(x, (int, onp.integer)):
arr = _shuffle(key, onp.arange(x), 0)
if not onp.shape(x):
# scalar case, must be a concrete integer
if not onp.issubdtype(lax.dtype(x), onp.integer):
raise TypeError("x must be an integer or at least 1-dimensional")
x = int(x) # TODO(mattjj): concrete tracer error from core.py
return _shuffle(key, np.arange(x), 0)
else:
arr = _shuffle(key, x, 0)
return arr
return _shuffle(key, x, 0)


@partial(jit, static_argnums=(2,))
Expand Down
8 changes: 8 additions & 0 deletions tests/random_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import scipy.stats

from jax import api
from jax import core
from jax import grad
from jax import lax
from jax import numpy as np
Expand Down Expand Up @@ -202,6 +203,13 @@ def testPermutationInteger(self):
self.assertFalse(onp.all(perm1 == onp.arange(100))) # seems unlikely!
self.assertTrue(onp.all(onp.sort(perm1) == onp.arange(100)))

def testPermutationErrors(self):
key = random.PRNGKey(0)
with self.assertRaises(TypeError):
random.permutation(key, 10.)
with self.assertRaises(core.ConcretizationTypeError):
api.jit(random.permutation)(key, 10)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_p={}_{}".format(p, dtype),
"p": p, "dtype": onp.dtype(dtype).name}
Expand Down

0 comments on commit ac0697d

Please sign in to comment.