Skip to content

Commit

Permalink
Feature/permutation (#1568)
Browse files Browse the repository at this point in the history
* added test for random.permutation

* added permutation that wraps shuffle with behaviour of np.random.permutation

* update docstring

* need to shuffle also the integer range input

* fixed test for permutation with integer

* tweak handling of random.permutation scalar case

* NotImplementedError for random.permutation on >1d

pending resolution to #2066

* address reviewer comments: improve tests

Co-authored-by: Matthew Johnson <mattjj@google.com>
  • Loading branch information
MichaelMarien and mattjj authored Apr 24, 2020
1 parent fc4203c commit e0d42e9
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 3 deletions.
26 changes: 26 additions & 0 deletions jax/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,32 @@ def shuffle(key: np.ndarray, x: np.ndarray, axis: int = 0) -> np.ndarray:
"""
return _shuffle(key, x, axis)


def permutation(key, x):
"""
Permute elements of an array along its first axis or return a permuted range.
Args:n
key: a PRNGKey used as the random key.
x: the array or integer range to be shuffled.
Returns:
A shuffled version of x or array range
"""
if not onp.ndim(x):
# scalar case, must be a concrete integer
if not onp.issubdtype(lax.dtype(x), onp.integer):
raise TypeError("x must be an integer or at least 1-dimensional")
x = int(x)
return _shuffle(key, np.arange(x), 0)
elif onp.ndim(x) == 1:
return _shuffle(key, x, 0)
else:
msg = ("permutation for >1d inputs x not yet implemented, see "
"https://github.com/google/jax/issues/2066 for updates.")
raise NotImplementedError(msg)


@partial(jit, static_argnums=(2,))
def _shuffle(key, x, axis):
# On parallel architectures, Fisher-Yates is more expensive than doing
Expand Down
46 changes: 43 additions & 3 deletions tests/random_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import scipy.stats

from jax import api
from jax import core
from jax import grad
from jax import lax
from jax import numpy as np
Expand Down Expand Up @@ -164,10 +165,49 @@ def testShuffle(self, dtype):
perm1 = rand(key)
perm2 = crand(key)

self.assertTrue(onp.all(perm1 == perm2))
self.assertTrue(onp.all(perm1.dtype == perm2.dtype))
self.assertAllClose(perm1, perm2, check_dtypes=True)
self.assertFalse(onp.all(perm1 == x)) # seems unlikely!
self.assertTrue(onp.all(onp.sort(perm1) == x))
self.assertAllClose(onp.sort(perm1), x, check_dtypes=False)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(dtype), "dtype": onp.dtype(dtype).name}
for dtype in [onp.float32, onp.float64, onp.int32, onp.int64]))
def testPermutationArray(self, dtype):
key = random.PRNGKey(0)
x = onp.arange(100).astype(dtype)
rand = lambda key: random.permutation(key, x)
crand = api.jit(rand)

perm1 = rand(key)
perm2 = crand(key)

self.assertAllClose(perm1, perm2, check_dtypes=True)
self.assertEqual(perm1.dtype, perm2.dtype)
self.assertFalse(onp.all(perm1 == x)) # seems unlikely!
self.assertAllClose(onp.sort(perm1), x, check_dtypes=False)
self.assertArraysAllClose(x, onp.arange(100).astype(dtype),
check_dtypes=True)

def testPermutationInteger(self):
key = random.PRNGKey(0)
x = 100
rand = lambda key: random.permutation(key, x)
crand = api.jit(rand)

perm1 = rand(key)
perm2 = crand(key)

self.assertAllClose(perm1, perm2, check_dtypes=True)
self.assertEqual(perm1.dtype, perm2.dtype)
self.assertFalse(onp.all(perm1 == onp.arange(100))) # seems unlikely!
self.assertAllClose(onp.sort(perm1), onp.arange(100), check_dtypes=False)

def testPermutationErrors(self):
key = random.PRNGKey(0)
with self.assertRaises(TypeError):
random.permutation(key, 10.)
with self.assertRaises(core.ConcretizationTypeError):
api.jit(random.permutation)(key, 10)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_p={}_{}".format(p, dtype),
Expand Down

0 comments on commit e0d42e9

Please sign in to comment.