Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/permutation #1568

Merged
merged 8 commits into from
Apr 24, 2020
Merged
26 changes: 26 additions & 0 deletions jax/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,32 @@ def shuffle(key: np.ndarray, x: np.ndarray, axis: int = 0) -> np.ndarray:
"""
return _shuffle(key, x, axis)


def permutation(key, x):
"""
Permute elements of an array along its first axis or return a permuted range.

Args:n
key: a PRNGKey used as the random key.
x: the array or integer range to be shuffled.

Returns:
A shuffled version of x or array range
"""
if not onp.ndim(x):
# scalar case, must be a concrete integer
if not onp.issubdtype(lax.dtype(x), onp.integer):
raise TypeError("x must be an integer or at least 1-dimensional")
x = int(x)
return _shuffle(key, np.arange(x), 0)
elif onp.ndim(x) == 1:
return _shuffle(key, x, 0)
else:
msg = ("permutation for >1d inputs x not yet implemented, see "
"https://github.com/google/jax/issues/2066 for updates.")
raise NotImplementedError(msg)


@partial(jit, static_argnums=(2,))
def _shuffle(key, x, axis):
# On parallel architectures, Fisher-Yates is more expensive than doing
Expand Down
46 changes: 43 additions & 3 deletions tests/random_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import scipy.stats

from jax import api
from jax import core
from jax import grad
from jax import lax
from jax import numpy as np
Expand Down Expand Up @@ -164,10 +165,49 @@ def testShuffle(self, dtype):
perm1 = rand(key)
perm2 = crand(key)

self.assertTrue(onp.all(perm1 == perm2))
self.assertTrue(onp.all(perm1.dtype == perm2.dtype))
self.assertAllClose(perm1, perm2, check_dtypes=True)
self.assertFalse(onp.all(perm1 == x)) # seems unlikely!
self.assertTrue(onp.all(onp.sort(perm1) == x))
self.assertAllClose(onp.sort(perm1), x, check_dtypes=False)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(dtype), "dtype": onp.dtype(dtype).name}
for dtype in [onp.float32, onp.float64, onp.int32, onp.int64]))
def testPermutationArray(self, dtype):
key = random.PRNGKey(0)
x = onp.arange(100).astype(dtype)
rand = lambda key: random.permutation(key, x)
crand = api.jit(rand)

perm1 = rand(key)
perm2 = crand(key)

self.assertAllClose(perm1, perm2, check_dtypes=True)
self.assertEqual(perm1.dtype, perm2.dtype)
self.assertFalse(onp.all(perm1 == x)) # seems unlikely!
self.assertAllClose(onp.sort(perm1), x, check_dtypes=False)
self.assertArraysAllClose(x, onp.arange(100).astype(dtype),
check_dtypes=True)

def testPermutationInteger(self):
key = random.PRNGKey(0)
x = 100
rand = lambda key: random.permutation(key, x)
crand = api.jit(rand)

perm1 = rand(key)
perm2 = crand(key)

self.assertAllClose(perm1, perm2, check_dtypes=True)
self.assertEqual(perm1.dtype, perm2.dtype)
self.assertFalse(onp.all(perm1 == onp.arange(100))) # seems unlikely!
self.assertAllClose(onp.sort(perm1), onp.arange(100), check_dtypes=False)

def testPermutationErrors(self):
key = random.PRNGKey(0)
with self.assertRaises(TypeError):
random.permutation(key, 10.)
with self.assertRaises(core.ConcretizationTypeError):
api.jit(random.permutation)(key, 10)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_p={}_{}".format(p, dtype),
Expand Down