Handling parameters in JAX can get annoying, but what concerns me even more is handling PRNG keys. JAX has done a lot of great work to build a very strong PRNG system, but unfortunately splitting and managing random keys can be very messy and especially error-prone. It's alarmingly easy to accidentally reuse a PRNG key. It would be great to have a system analogous to `@parameterized` and `parameter()`, but for random keys and seeds.
I envision an API providing something like `@random` and `rng()`:
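Here is a minimal sketch of how `@random` and `rng()` might be built on top of `jax.random`. It assumes the decorator turns `fn(*args)` into `fn(key, *args)` and that each `rng()` call splits off a fresh subkey. Both names and this calling convention are hypothetical and are not an existing library API.

```python
import functools
import threading

import jax
import jax.numpy as jnp

# Per-thread stack of the keys owned by the currently executing @random functions.
_state = threading.local()


def rng():
    """Hypothetical: return a fresh subkey from the innermost @random call."""
    stack = getattr(_state, "keys", None)
    if not stack:
        raise RuntimeError("rng() called outside a @random function")
    key, subkey = jax.random.split(stack[-1])
    stack[-1] = key  # advance the stored key so the next rng() gets a new subkey
    return subkey


def random(fn):
    """Hypothetical decorator: fn(*args) becomes fn(key, *args).

    Note: the name shadows Python's stdlib `random` module in this file.
    """
    @functools.wraps(fn)
    def wrapped(key, *args, **kwargs):
        if not hasattr(_state, "keys"):
            _state.keys = []
        _state.keys.append(key)
        try:
            return fn(*args, **kwargs)
        finally:
            _state.keys.pop()  # the key never outlives the call that owns it
    return wrapped


@random
def noisy_linear(x):
    # Two draws, two distinct subkeys, with no manual splitting in user code.
    w = jax.random.normal(rng(), (x.shape[-1], 3))
    b = jax.random.normal(rng(), (3,))
    return x @ w + b


y = noisy_linear(jax.random.PRNGKey(0), jnp.ones((2, 4)))
```

Each `rng()` call advances the stored key before returning a subkey, so no key is handed out twice within one call. A nested `@random` function would simply be called with `rng()` as its key.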
And then ~ magic ~ happens, after which point we get a function like:
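Suppose the transformed function just takes a PRNG key as an explicit first argument. In that case it would match what one writes by hand today with `jax.random.split`, as in this hypothetical sketch (the name `noisy_linear` and the shapes are illustrative only):

```python
import jax
import jax.numpy as jnp


def noisy_linear(key, x):
    # One subkey per random draw; each subkey is consumed exactly once.
    w_key, b_key = jax.random.split(key)
    w = jax.random.normal(w_key, (x.shape[-1], 3))
    b = jax.random.normal(b_key, (3,))
    return x @ w + b


key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)  # keep `key` for later calls, spend `subkey` now
y = noisy_linear(subkey, jnp.ones((2, 4)))
```

This is exactly the bookkeeping the proposal aims to hide. Forgetting a split and passing the same key to both draws raises no error; it just silently reuses the same randomness.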