Feature: Hessian-vector product function #126
Conversation
I would prefer to default to a system that works rather than one that throws errors. A first implementation can build the entire Hessian with finite diffs of gradients, then multiply. A less memory-intensive version would generate one row of the Hessian at a time with finite diffs of gradients, then multiply to get a single entry in the result. This is, of course, parallelizable, as would be the Hessian algorithm in AD, for that matter.
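A minimal sketch of that first implementation (my illustration only — grad_f, the step h, and the helper names are placeholders, not BridgeStan code):

```python
import numpy as np

def fd_hessian(grad_f, x, h=1e-5):
    # central finite differences of the gradient, one column at a time
    n = x.size
    H = np.empty((n, n))
    for i in range(n):
        e = np.zeros(n)
        e[i] = h
        H[:, i] = (grad_f(x + e) - grad_f(x - e)) / (2 * h)
    return H

def hvp_via_full_hessian(grad_f, x, v):
    # build the whole Hessian, then multiply: 2n gradient calls
    return fd_hessian(grad_f, x) @ v
```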
In the case of using finite diffs, why not approximate $\nabla^2 f(x)\,v \approx (\nabla f(x + hv) - \nabla f(x))/h$ directly? Or related finite diff schemes like $(\nabla f(x + hv) - \nabla f(x - hv))/(2h)$.
We're doing the central version (second thing you wrote under "related finite diff"). It's implemented as a functional in Stan's C++ here and we just plugged that into BridgeStan:
Sorry if I'm missing something here, but doesn't that compute the complete Hessian (in 2n grad calls)? I was referring to the finite diff version of a Hessian-vector product (with 2 grad calls, ignoring questions of selecting the step size h).
Nope---I was the one missing the point. Yes, what I pointed to computes the complete Hessian. I hadn't thought about updating the finite diff algorithm to do Hessian-vector products innately. The place to add that would probably be the Stan math lib. It'd be faster (and more accurate?) than what we're doing now for tests, which I'm pretty sure just computes the Hessian and multiplies.
Exposed this function in each interface. For the moment I've left it as the naive "calculate the Hessian via finite differences and multiply" approach when autodiff Hessians are unavailable.
I think that's OK for a first implementation. @aseyboldt: do you know how to write the algorithm to do this more efficiently with finite differences over gradients, without computing the whole Hessian?
I guess pretty much this one:
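Something like this two-gradient-call sketch (my reconstruction of the idea — grad_f and h are placeholders):

```python
def hvp(grad_f, x, v, h=1e-5):
    # central finite difference of the gradient along v:
    # H(x) @ v ~= (grad f(x + h v) - grad f(x - h v)) / (2 h)
    # only two gradient calls, no Hessian ever formed
    return (grad_f(x + h * v) - grad_f(x - h * v)) / (2 * h)
```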
The problem is that I really don't know too much about the numerical issues that I'm pretty sure are buried here, and googling a bit for literature on this, I couldn't find much that looked useful...
Here's a blog post / reference about that algorithm, which claims the error rate for the approximation is O(h^2) for small h: https://justindomke.wordpress.com/2009/01/17/hessian-vector-products/ Isn't Justin Domke at FI right now? Maybe Bob and/or Brian could check with Justin about this.

I can't find a reference on the error for Stan Math's central difference approximation to the Hessian. I guess it too is O(h^2), but I don't know that. If both methods have O(h^2) error (assuming constant terms are similar enough), then the faster one would be my preference for the default in BridgeStan when BRIDGESTAN_AD_HESSIAN is not defined.

P.S. The reference given in the Stan Math source code for the central difference approximation is broken. Following bread crumbs from the intent of the broken link ends at the book Numerical Methods for Unconstrained Optimization and Nonlinear Equations by Dennis and Schnabel. There might be an error analysis within this book, but I don't have access, electronic or otherwise.
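For intuition, here's the standard Taylor argument for the central scheme's error order (my sketch, not from either reference):

$$\nabla f(x \pm hv) = \nabla f(x) \pm h\,\nabla^2 f(x)\,v + \tfrac{h^2}{2}\,\nabla^3 f(x)[v,v] \pm \mathcal{O}(h^3),$$

so the even-order terms cancel in the difference and

$$\frac{\nabla f(x + hv) - \nabla f(x - hv)}{2h} = \nabla^2 f(x)\,v + \mathcal{O}(h^2).$$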
Ah, that blog post might be where I first heard about this trick. I couldn't quite remember where I got it from...
Thanks for the algorithm and reference. @justindomke was visiting, but we can still check with him. @bgoodri may also know---I think the central finite diffs rather than one-sided versions were his idea. The place to code the faster version is in the Stan math library.
Hi everyone! I have a few things to add that might be helpful.

One is that you can think of a Hessian-vector product as two composed gradient evaluations. Say you want to compute $\nabla^2 f(x)\,v$. Define $g(x) = \nabla f(x)^\top v$; then $\nabla g(x) = \nabla^2 f(x)\,v$. I don't know if Stan can do second-order derivatives like this, but here's a little demo using JAX:

```python
import jax
from jax import numpy as jnp

def f(x):
    "random nonlinear function"
    return (x[0] * jnp.sum(jnp.sin(x))) ** x[1]

def naive_hvp(x, v):
    # build the full Hessian, then multiply
    return jax.hessian(f)(x) @ v

def recursive_hvp(x, v):
    # gradient of (gradient dotted with v): never forms the Hessian
    g = jax.grad(f)
    h = lambda x: g(x) @ v
    return jax.grad(h)(x)

x = jnp.array([1.1, 2.2, 3.3])
v = jnp.array([4.4, 5.5, 6.6])
```

Both naive_hvp(x, v) and recursive_hvp(x, v) produce the same output.
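(A quick check of that claim — this line is my addition, not part of the original demo:)

```python
# the two implementations should agree to floating point tolerance
assert jnp.allclose(naive_hvp(x, v), recursive_hvp(x, v))
```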
If you can do it this way, you probably should.

Second, there's essentially an infinite number of finite difference schemes, with different tradeoffs between computational cost and accuracy. You can go beyond 1-sided or 2-sided differences to "4-sided" or higher (see the sketch at the end of this comment). Here's a short post from me giving the basic idea. This is for derivatives rather than Hessian-vector products, but the same idea applies. You could modify @aseyboldt's code above to use any of these. In practice, many people seem to find two-sided differences to be a reasonable compromise. They at least work "most of the time", whereas one-sided differences will fall over for even slightly challenging problems.

Finally, if you're doing finite differences, don't neglect the problem of choosing the perturbation size. As far as I know, this is hard and there is no general scheme that's always guaranteed to work well. But probably there are heuristics that are better than always using numerical epsilon. I've had good luck using the heuristic
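Here's what one of those higher-order stencils looks like for a plain derivative (my sketch — these are the standard fourth-order central coefficients, not taken from the post):

```python
def fd4(f, x, h=1e-3):
    # fourth-order central difference for f'(x); error is O(h**4)
    return (-f(x + 2*h) + 8*f(x + h) - 8*f(x - h) + f(x - 2*h)) / (12*h)
```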
Thanks for the explanation and refs, @justindomke. Stan has a reasonable Hessian-vector product implemented for everything but our implicit functions. For models using implicit functions, we want to provide finite differences. I opened an issue in the math lib, which is where the implementation should go. When we were discussing finite diffs, someone made exactly the same point about 2-sided diffs being a good compromise.
I took a pass at the above in stan-dev/math#2914, more or less a direct translation of what @aseyboldt had above. For BridgeStan, we have a couple choices in the meantime:

1. Wait for the math PR to land in a Stan release before merging this.
2. Expose the function now, but have it error when autodiff Hessians are unavailable.
3. Expose the function now with the naive "compute the Hessian and multiply" fallback, and swap in the math version later.
I'm OK with any of those choices---they all have pros and cons. I'm leaning slightly toward (2) or (3) just to get the API in place even if we change implementations later. |
@justindomke It seems like that paper suggests a different heuristic (namely
I too am OK with any of those choices, lean towards (2) or (3), and I vote (3) just because that's already written up in this PR.
No, I'm sure you're right! I missed a division symbol above (fixed now), and also mine was for gradients. Yours looks good.
With the release of Stan 2.33, this has been updated to use the finite diff Hessian-vector product from stan::math if autodiff Hessians are unavailable. I believe it's ready for review @roualdes @bob-carpenter
A few comments/questions. Mostly nitpicks, I think.
Otherwise, it looks good to me. Thanks for orchestrating yet another multiple-repository feature-add.
```cpp
if (error_msg) {
  std::stringstream error;
  error << "log_density_hessian_vector_product() failed with unknown "
           "exception"
```
Should there be another chevron here, <<?
It's actually optional: C++ lets you place two string literals next to each other, and the result is that they are just pasted together (see "Concatenation" at https://en.cppreference.com/w/cpp/language/string_literal). If you prefer the extra << stylistically, I'm happy to add it.
Looks good. I just didn't know this. Thanks.
Thanks for the fixes.
Closes #125 and completes part of #84.
This adds a function with the following signature to the C API:
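(The original signature didn't survive extraction; the sketch below is reconstructed from BridgeStan's current C API docs, and the exact parameter names at the time of this PR may have differed:)

```c
/* reconstructed from current BridgeStan docs; names are best-effort,
   not copied verbatim from this PR */
int bs_log_density_hessian_vector_product(const bs_model *m, bool propto,
                                          bool jacobian,
                                          const double *theta_unc,
                                          const double *vector, double *val,
                                          double *Hvp, char **error_msg);
```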
When BRIDGESTAN_AD_HESSIAN is defined, this uses stan::math::hessian_vector_product. Otherwise, I currently have it implemented to compute the Hessian and then just multiply it by the supplied vector. In Stan 2.33 we can use stan-dev/math#2914.