
In JAX random linking, splitting should happen before using the key #1344

Closed
AdrienCorenflos opened this issue Dec 9, 2022 · 0 comments · Fixed by #1345
Labels
JAX Involves JAX transpilation random variables Involves random variables and/or sampling


@AdrienCorenflos
Contributor

See, for example:

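```python
def sample_fn(rng, size, dtype, *parameters):
    rng_key = rng["jax_state"]
    # The key is consumed by the sampler here...
    sample = jax_op(rng_key, *parameters, shape=size, dtype=dtype)
    # ...and then split to produce the next state, so the same key is used twice.
    rng["jax_state"] = jax.random.split(rng_key, num=1)[0]
    return (rng, sample)
```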

To stay agnostic to JAX's PRNG internals (a key should not be reused once it has been consumed), the key should be split before it is used:

```python
def sample_fn(rng, size, dtype, *parameters):
    rng_key = rng["jax_state"]
    # Split first: one key becomes the new state, the other is used once for sampling.
    rng_key, sample_key = jax.random.split(rng_key, 2)
    sample = jax_op(sample_key, *parameters, shape=size, dtype=dtype)
    rng["jax_state"] = rng_key
    return (rng, sample)
```
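A self-contained sketch of the proposed split-before-use pattern follows. It assumes a hypothetical factory `make_sample_fn` and uses `jax.random.normal` as the wrapped sampler; in the snippets above, `jax_op` comes from the enclosing scope instead. The usage lines draw twice from the same state dict and show that the state advances between calls.

```python
import jax
import jax.numpy as jnp


def make_sample_fn(jax_op):
    """Hypothetical factory (not the project's API): wraps a jax.random
    sampler in the split-before-use pattern proposed in this issue."""

    def sample_fn(rng, size, dtype, *parameters):
        rng_key = rng["jax_state"]
        # Split first: `rng_key` becomes the next state and `sample_key` is
        # consumed exactly once by the sampler.
        rng_key, sample_key = jax.random.split(rng_key, 2)
        sample = jax_op(sample_key, *parameters, shape=size, dtype=dtype)
        rng["jax_state"] = rng_key
        return (rng, sample)

    return sample_fn


# Usage: two successive draws from the same (mutated) state dict.
rng = {"jax_state": jax.random.PRNGKey(0)}
sample_normal = make_sample_fn(jax.random.normal)

rng, first = sample_normal(rng, (3,), jnp.float32)
rng, second = sample_normal(rng, (3,), jnp.float32)
```

Because the stored state is always the "other half" of a fresh split, no key is ever fed to both the sampler and the next `split`, so the independence of successive draws does not rely on how a particular sampler derives its bits from a key.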

@brandonwillard added the JAX (Involves JAX transpilation) and random variables (Involves random variables and/or sampling) labels on Dec 9, 2022

3 participants