Replies: 1 comment 3 replies
Hey, thanks. I don't explicitly use jax.jit because I use pmap to distribute over multiple devices (or just one, if that's all you have), and pmapping in JAX will automatically JIT your functions for you.

Regarding the hyperparameter search: no, it's not too early to do something like this. In fact, for some of my own research I did exactly that over multiple seeds. I will say, though, that since I use config dicts, you would need to slightly hack the input of the train function so that it takes hyperparameters you can vmap over. Alternatively, everything in Stoix is ready to use with Optuna Bayesian optimisation, so you can tune parameters more intelligently. Although that isn't done in parallel, it efficiently explores the space of possible hyperparameters rather than requiring you to define the specific values you care about up front. I personally find this more efficient than choosing my own set of values, especially for algorithms I'm less familiar with, and even more so when tuning multiple hyperparameters at once, since a multidimensional vmap over a large grid is not the most efficient approach. Rough sketches of both approaches are below.
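As a minimal illustration (this is not Stoix's actual API; the toy train function, its arguments, and the metric name are made-up placeholders), here is the PureJaxRL-style pattern of vmapping a jitted training function over both seeds and a learning-rate axis:

```python
import jax
import jax.numpy as jnp


def train(rng, learning_rate):
    # Toy stand-in for a real training loop: gradient descent on a quadratic.
    params = jax.random.normal(rng, (4,))

    def step(params, _):
        grads = jax.grad(lambda p: jnp.sum(p ** 2))(params)
        return params - learning_rate * grads, None

    params, _ = jax.lax.scan(step, params, None, length=100)
    return {"final_loss": jnp.sum(params ** 2)}


seeds = jax.random.split(jax.random.PRNGKey(0), 8)      # 8 independent seeds
learning_rates = jnp.array([1e-3, 3e-3, 1e-2, 3e-2])    # grid of 4 learning rates

# Inner vmap over seeds, outer vmap over learning rates, then jit the whole sweep.
sweep = jax.jit(jax.vmap(jax.vmap(train, in_axes=(0, None)), in_axes=(None, 0)))
metrics = sweep(seeds, learning_rates)  # metrics["final_loss"] has shape (4, 8)
```

And the Optuna alternative, again only a sketch reusing the toy train function above: the objective and the search range are placeholders, but create_study, optimize, and suggest_float are standard Optuna calls, and the default sampler (TPE) does the Bayesian-style exploration.

```python
import optuna


def objective(trial):
    # Optuna's default TPE sampler proposes learning rates sequentially.
    lr = trial.suggest_float("learning_rate", 1e-4, 1e-1, log=True)
    metrics = train(jax.random.PRNGKey(trial.number), lr)
    return float(metrics["final_loss"])


study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=30)
print(study.best_params)
```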
Hi! Really cool that you took the ideas from PureJaxRL to the next level.
I searched for the keyword "jit" in the codebase and nothing shows up 😅. Are the training loops jitted like in PureJaxRL, or is that left to the user?
I was considering implementing a hyperparameter search similar to PureJaxRL's, where training loops are jitted and vmapped. Is the current state of the codebase too early for this?