
PRNG handling akin to parameters #24

Open
samuela opened this issue Apr 8, 2020 · 2 comments

Comments

@samuela

samuela commented Apr 8, 2020

Handling parameters in JAX can get annoying, but what concerns me even more is handling PRNG keys. JAX has done a lot of great work to build a very strong PRNG system, but unfortunately splitting and managing random keys can be very messy and especially error-prone. It's alarmingly easy to accidentally reuse a PRNG key. It would be great to have a system analogous to @parametrized and parameter() but for random keys and seeds.
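A minimal illustration of the key-reuse hazard, in plain JAX (the seed `0` and the shapes are arbitrary choices): sampling twice with the same key returns the same values, while splitting the key first gives independent draws.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# Bug: reusing one key for two draws yields identical "random" numbers.
a = jax.random.normal(key, shape=(3,))
b = jax.random.normal(key, shape=(3,))
print(jnp.array_equal(a, b))  # True: the two draws are perfectly correlated

# Fix: split the key so each draw consumes its own fresh subkey.
key, sub_a, sub_b = jax.random.split(key, 3)
a = jax.random.normal(sub_a, shape=(3,))
b = jax.random.normal(sub_b, shape=(3,))
print(jnp.array_equal(a, b))  # False
```

Nothing in JAX stops the first pattern; keeping track of which keys have been consumed is entirely the caller's job, which is what a transform like the one proposed here would automate.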

I envision an API providing something like @random and rng():

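```python
# Proposed API (does not exist yet): @random threads a key, rng() hands out fresh subkeys.
@random
def my_func(x):
  W = jax.random.normal(rng(), shape=(2, 2))
  b = jax.random.exponential(rng(), shape=(2,))
  return W @ x + b
```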

And then ~ magic ~ happens, after which we get a function like:

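```python
import jax

def my_func(x, rng):
  # rng must be a PRNG key (a None default would fail in split); each draw gets a fresh subkey.
  rng0, rng = jax.random.split(rng)
  W = jax.random.normal(rng0, shape=(2, 2))
  rng1, rng = jax.random.split(rng)
  b = jax.random.exponential(rng1, shape=(2,))
  return W @ x + b
```
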
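One way the ~ magic ~ could work, sketched under assumptions: a hypothetical `with_rng` decorator (standing in for the proposed `@random`, renamed so it is not confused with the `jax.random` module) stores the caller's key in a `contextvars.ContextVar`, and a hypothetical `rng()` splits that key each time it is called. Neither name exists in JAX or jaxnet. `rng()` splits in the same order as the hand-written `my_func`, so for the same key both versions should produce the same numbers.

```python
import contextvars
import jax
import jax.numpy as jnp

# Hypothetical helpers: `rng` and `with_rng` are NOT part of JAX or jaxnet.
# The key threaded through a decorated call lives in this context variable.
_current_key = contextvars.ContextVar("_current_key")

def rng():
    """Return a fresh subkey and advance the hidden key (only valid inside with_rng)."""
    key = _current_key.get()  # raises LookupError when called outside a with_rng call
    sub, key = jax.random.split(key)  # same split order as the hand-written my_func
    _current_key.set(key)
    return sub

def with_rng(f):
    """Turn f, which calls rng() internally, into a function taking an explicit key."""
    def wrapped(*args, rng_key, **kwargs):
        token = _current_key.set(rng_key)
        try:
            return f(*args, **kwargs)
        finally:
            _current_key.reset(token)  # restore the outer value so calls can nest
    return wrapped

@with_rng
def my_func(x):
    W = jax.random.normal(rng(), shape=(2, 2))
    b = jax.random.exponential(rng(), shape=(2,))
    return W @ x + b

y = my_func(jnp.ones(2), rng_key=jax.random.PRNGKey(0))
```

Because the key only lives in the context variable for the duration of a call, the decorated function stays pure from the outside: the key is an ordinary argument, so wrapping it in `jax.jit` and passing `rng_key` as a traced keyword argument should work.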
@juliuskunze
Owner

You can already use random_key() within @parametrized:

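```python
from jax import numpy as np, random
from jaxnet import parametrized, random_key  # import path assumed

def Dropout(rate):  # assumed enclosing factory; `rate` is free in the original excerpt
    @parametrized
    def dropout(inputs):
        keep_rate = 1 - rate
        keep = random.bernoulli(random_key(), keep_rate, inputs.shape)
        return np.where(keep, inputs / keep_rate, 0)
    return dropout
```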
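For comparison, a sketch of the same dropout written against plain JAX, without jaxnet. The `dropout(key, inputs, rate)` signature is a hypothetical illustration, not jaxnet's API; it shows the bookkeeping that `random_key()` hides, since here the caller has to split and pass a fresh key on every call.

```python
import jax
import jax.numpy as jnp

def dropout(key, inputs, rate):
    """Hypothetical plain-JAX inverted dropout; the caller must never reuse `key`."""
    keep_rate = 1 - rate
    keep = jax.random.bernoulli(key, keep_rate, inputs.shape)
    # Scale survivors by 1/keep_rate so the expected activation is unchanged.
    return jnp.where(keep, inputs / keep_rate, 0)

key = jax.random.PRNGKey(0)
key, sub = jax.random.split(key)  # consume a fresh subkey for this call
out = dropout(sub, jnp.ones((4, 4)), rate=0.5)
```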

An independent seed transform as you describe would make sense; if I find time, I will factor it out.

@samuela
Author

samuela commented Apr 9, 2020

Neat, I was not aware of that! Yeah, I think having a separate transform would be great.
