Want to know what I'm doing wrong in implementation here #444
dhruvsreenivas asked this question in Q&A (unanswered)
Hi everyone, hope you're doing well! I'm working on a research project with the DeepMind JAX ecosystem (Haiku, Optax), but for some reason, when I train over a dataset, the training loss doesn't go down, as shown in this screenshot.
I'm trying to do something pretty simple: train Random Network Distillation (https://arxiv.org/abs/1810.12894, https://github.com/deepmind/acme/tree/master/acme/agents/jax/rnd) on an offline dataset of D4RL MuJoCo data. I tried a few sanity checks, including training on one random data point for some number of iterations. That loss also doesn't go down: it basically stays at 0.005 for 1000 straight epochs (shown in the screenshots below).
Here are some snippets:
RND neural network + trainer code:
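For context, here is a minimal sketch of the general pattern I'm following (the MLP sizes and the name `rnd_loss` are placeholders, not my exact code):

```python
import haiku as hk
import jax
import jax.numpy as jnp

# Placeholder architecture: the predictor and the frozen target share this
# definition but keep separate parameter sets.
def _rnd_net(obs):
    return hk.nets.MLP((256, 256, 64))(obs)

network = hk.without_apply_rng(hk.transform(_rnd_net))

def rnd_loss(predictor_params, target_params, obs):
    # Distillation loss: MSE between the trained predictor and the fixed random target.
    pred = network.apply(predictor_params, obs)
    target = jax.lax.stop_gradient(network.apply(target_params, obs))
    return jnp.mean(jnp.square(pred - target))
```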
Training loop code:
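Roughly, the loop follows this pattern (again just a sketch: `obs_dim`, `dataset_iterator`, and `num_epochs` are placeholders standing in for my actual workspace/config, and it reuses `network` and `rnd_loss` from the sketch above):

```python
import optax

obs_dim = 17        # placeholder, e.g. HalfCheetah observation size
num_epochs = 1000   # placeholder
batch_size = 256    # placeholder

def dataset_iterator():
    # Placeholder: random batches standing in for D4RL observations.
    key = jax.random.PRNGKey(42)
    for _ in range(100):
        key, sub = jax.random.split(key)
        yield jax.random.normal(sub, (batch_size, obs_dim))

rng_t, rng_p = jax.random.split(jax.random.PRNGKey(0))
dummy_obs = jnp.zeros((1, obs_dim))
target_params = network.init(rng_t, dummy_obs)     # never updated
predictor_params = network.init(rng_p, dummy_obs)  # trained

optimizer = optax.adam(1e-3)
opt_state = optimizer.init(predictor_params)

@jax.jit
def update(predictor_params, opt_state, obs):
    loss, grads = jax.value_and_grad(rnd_loss)(predictor_params, target_params, obs)
    updates, opt_state = optimizer.update(grads, opt_state, predictor_params)
    predictor_params = optax.apply_updates(predictor_params, updates)
    return predictor_params, opt_state, loss

for epoch in range(num_epochs):
    for obs_batch in dataset_iterator():
        # The new params/opt_state returned by `update` are threaded back into the loop.
        predictor_params, opt_state, loss = update(predictor_params, opt_state, obs_batch)
```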
In my actual code, `self` refers to a workspace with an experiment config `cfg` in which I train and save everything of interest. As shown, I use the `optax.adam` optimizer with learning rate `1e-3`. I think this is standard (maybe a bit large, but I've swept through a few learning rates, both larger and smaller, and get the same results).

I'm wondering where I'm going wrong in this training approach. I think the setup is correct, but there's certainly something I'm missing. Any help would be greatly appreciated! If you have any additional questions, I'll be happy to send updates either here or over a video chat. Also, let me know whether the Optax repo is the right place for this message; I don't think it's an issue yet (it's more on me than on the package), so I'm putting it in the Discussions tab.
Regarding package versions, I am using Haiku 0.0.7, Optax 0.1.3, and JAX 0.3.16 on CUDA for these experiments. I love the framework, by the way!