Hessians of network functions #2682
Comments
Note:
prints:
and
In other words, the output of
If you don't want JAX to maintain the sparse structure for you, you can simply flatten the input into a 1D dense array before calling
which prints:
Does that help answer your question?
This is perfect, thank you!
I'm trying to get the Hessian of the function produced by a neural network, and I cannot figure out a way to actually access the Hessian matrix.
Is there a way to manipulate the Hessian so that I can view it directly as a matrix? Ideally, one would be able to batch this computation and call
hessian = get_hessian(apply_fn)(params, train_xs)
using the full training data and get an n x p x p array back containing the Hessians evaluated at each of the n data points.
I'm not sure if I'm missing something already built in or if this is something I would need to build. Any help is appreciated!
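For concreteness, the call described above could be sketched roughly as follows. This is only an illustration, not a built-in JAX function: `get_hessian` is the hypothetical helper named in the question, the toy `apply_fn` and parameter shapes are made up, and the sketch assumes `apply_fn(params, x)` returns a single scalar for one example `x`. It flattens the parameters with `jax.flatten_util.ravel_pytree`, takes `jax.hessian` with respect to the flat vector, and uses `jax.vmap` to batch over the data.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def get_hessian(apply_fn):
    """Hypothetical helper: maps (params, xs) to an (n, p, p) array of Hessians."""

    def batched_hessian(params, xs):
        # Flatten the parameter pytree into a 1D vector of length p.
        flat_params, unravel = ravel_pytree(params)

        def scalar_output(flat, x):
            # Assumption: apply_fn yields one scalar per example.
            return jnp.squeeze(apply_fn(unravel(flat), x))

        # Hessian w.r.t. the flat parameters, vectorized over the data axis.
        per_example = jax.vmap(jax.hessian(scalar_output), in_axes=(None, 0))
        return per_example(flat_params, xs)

    return batched_hessian


# Made-up toy network, used only to exercise the helper.
def apply_fn(params, x):
    hidden = jnp.tanh(x @ params["w1"] + params["b1"])
    return hidden @ params["w2"]


key1, key2 = jax.random.split(jax.random.PRNGKey(0))
params = {
    "w1": jax.random.normal(key1, (3, 4)),
    "b1": jnp.zeros(4),
    "w2": jax.random.normal(key2, (4,)),
}
train_xs = jnp.ones((5, 3))  # n = 5 examples with 3 features each

hessians = get_hessian(apply_fn)(params, train_xs)
print(hessians.shape)  # n x p x p, with p = 3*4 + 4 + 4 = 20
```

Note that the memory cost grows as n x p x p, so for networks with many parameters this dense form is only practical on small models or small batches.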