Skip to content

Commit

Permalink
Merge pull request #8117 from LenaMartens:changelist/400933831
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 401418826
  • Loading branch information
jax authors committed Oct 7, 2021
2 parents bfbdfa8 + 342948d commit 8f0589f
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 0 deletions.
20 changes: 20 additions & 0 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -6826,6 +6826,25 @@ def _convert_2xU64_to_4xU32_without_bitcast(c, key):
def _rng_bit_generator_named_shape_rule(key, *, shape, dtype, algorithm):
return [key.named_shape, key.named_shape]

def _rng_bit_generator_batching_rule(batched_args, batch_dims, *, shape, dtype, algorithm):
"""Calls RBG in a loop and stacks the results."""
key, = batched_args
bd, = batch_dims
if bd is batching.not_mapped:
return rng_bit_generator_p.bind(key, shape=shape, dtype=dtype,
algorithm=algorithm), (None, None)
key = batching.moveaxis(key, bd, 0)
out_keys = []
out_bits = []
for k in key:
updated_key, bits = rng_bit_generator_p.bind(k, shape=shape, dtype=dtype,
algorithm=algorithm)
out_keys.append(reshape(updated_key, (1,)+updated_key.shape))
out_bits.append(reshape(bits, (1,)+bits.shape))
stacked_keys = concatenate(out_keys, 0)
stacked_bits = concatenate(out_bits, 0)
return (stacked_keys, stacked_bits), (0, 0)

rng_bit_generator_p = Primitive("rng_bit_generator")
rng_bit_generator_p.multiple_results = True
rng_bit_generator_p.def_impl(
Expand All @@ -6835,6 +6854,7 @@ def _rng_bit_generator_named_shape_rule(key, *, shape, dtype, algorithm):
_rng_bit_generator_shape_rule, _rng_bit_generator_dtype_rule,
_rng_bit_generator_weak_type_rule,
_rng_bit_generator_named_shape_rule))
batching.primitive_batchers[rng_bit_generator_p] = _rng_bit_generator_batching_rule
xla.translations[rng_bit_generator_p] = \
partial(_rng_bit_generator_translation_rule, False)
xla.backend_specific_translations['gpu'][rng_bit_generator_p] = \
Expand Down
26 changes: 26 additions & 0 deletions tests/random_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1137,6 +1137,32 @@ def test_vmap_fold_in_shape(self):
keys = vmap(lambda i: random.fold_in(key, i))(jnp.arange(3))
self.assertEqual(keys.shape, (3,))

def test_vmap_split_not_mapped_key(self):
key = self.seed_prng(73)
single_split_key = random.split(key)
vmapped_keys = vmap(lambda _: random.split(key))(jnp.zeros(3,))
self.assertEqual(vmapped_keys.shape, (3, 2))
for vk in vmapped_keys:
self.assertArraysEqual(vk.keys, single_split_key.keys)

def test_vmap_split_mapped_key(self):
key = self.seed_prng(73)
mapped_keys = random.split(key, num=3)
forloop_keys = [random.split(k) for k in mapped_keys]
vmapped_keys = vmap(random.split)(mapped_keys)
self.assertEqual(vmapped_keys.shape, (3, 2))
for fk, vk in zip(forloop_keys, vmapped_keys):
self.assertArraysEqual(fk.keys, vk.keys)

def test_vmap_random_bits(self):
rand_fun = lambda key: random.randint(key, (), 0, 100)
key = self.seed_prng(73)
mapped_keys = random.split(key, num=3)
forloop_rand_nums = [rand_fun(k) for k in mapped_keys]
rand_nums = vmap(rand_fun)(mapped_keys)
self.assertEqual(rand_nums.shape, (3,))
self.assertArraysEqual(rand_nums, jnp.array(forloop_rand_nums))

def test_cannot_add(self):
key = self.seed_prng(73)
self.assertRaisesRegex(
Expand Down

0 comments on commit 8f0589f

Please sign in to comment.