Issue with gln loss #6
Hello! I was trying to use the code in the "experiment_dblpend" directory. As long as I train the neural network with the "baseline_nn" loss, everything works fine.

The moment I switch to the "gln" loss (the one that should reproduce the learning described in the paper), something goes wrong. Without my touching the code at all, I get the following error:

```
TypeError: Gradient only defined for scalar-output functions. Output had shape: (4,).
```

The error is raised when gln_loss is applied at the line

```python
preds = jax.vmap(partial(lagrangian_eom, learned_dynamics(params)))(state)
```

and, more specifically, at the call

```python
jax.grad(lagrangian, 0)(q, q_t)
```

inside the lagrangian_eom function. I honestly cannot explain this, because everything looks correct to me.
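For context, here is a minimal sketch (not the repository's code; both function names are illustrative) of why jax.grad raises this TypeError: it differentiates only scalar-output functions, so a Lagrangian whose output is not squeezed down to a scalar, or a network configured with several outputs, triggers exactly this shape complaint.

```python
import jax
import jax.numpy as jnp

def scalar_lagrangian(q, q_t):
    # A one-output "network" squeezed to a scalar: the shape jax.grad expects.
    out = jnp.sum(q**2 - q_t**2, keepdims=True)  # shape (1,)
    return jnp.squeeze(out, axis=-1)             # shape ()

def vector_lagrangian(q, q_t):
    # A four-output "network", e.g. one configured for a baseline-style loss.
    return jnp.concatenate([q, q_t])             # shape (4,)

q, q_t = jnp.ones(2), jnp.ones(2)
print(jax.grad(scalar_lagrangian, 0)(q, q_t))    # fine: output is a scalar
# jax.grad(vector_lagrangian, 0)(q, q_t)
# -> TypeError: Gradient only defined for scalar-output functions.
#    Output had shape: (4,).
```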
Comments

Hm, this seems weird. The code in the notebooks still works, though, right? Maybe it is just that you need to …

Sorry, I forgot to mention that in the "learned_dynamics" function I had to remove the jnp.squeeze call, jnp.squeeze(nn_forward_fn(params, state), axis=-1), because otherwise the code would not run. Sorry for the trouble; your project is really well done, and most likely it is me who is struggling with JAX.

Oh, I see. Is that a bug in the code (maybe due to new JAX updates)? If so, do you think you could submit a PR to patch it? It would be very much appreciated (I am currently drowning in grant writing...)! Cheers,
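Given the squeeze discussion above, here is a hedged sketch of how a learned_dynamics-style wrapper could keep the squeeze while failing loudly when the network has the wrong number of outputs. nn_forward_fn is a hypothetical stand-in for the repository's network, and the assertion is an illustration, not a confirmed patch:

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the repository's nn_forward_fn: a linear map
# with a single output, playing the role of the learned Lagrangian.
def nn_forward_fn(params, state):
    return state @ params  # params: (4, 1), so the output has shape (1,)

def learned_dynamics(params):
    def dynamics(q, q_t):
        state = jnp.concatenate([q, q_t])
        out = nn_forward_fn(params, state)
        # The gln path needs a scalar Lagrangian; fail loudly if the network
        # was built with more than one output (e.g. a baseline configuration).
        assert out.shape == (1,), f"expected a 1-output network, got {out.shape}"
        return jnp.squeeze(out, axis=-1)
    return dynamics

params = jnp.ones((4, 1))
q, q_t = jnp.ones(2), jnp.ones(2)
lagrangian = learned_dynamics(params)
print(jax.grad(lagrangian, 0)(q, q_t))  # scalar output, so the grad is defined
```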