Add Stein Variational Gradient Descent #385
I've been having a go at implementing this. The problem I'm encountering is that SVGD is really just a gradient-descent-type procedure, which means the user will probably want the following: control over which optimizer is used, the ability to jit-compile the step function, and the ability to update kernel parameters between steps (all shown in the example below).
Here are my thoughts.

**Proposal implementation** (to be clear, this would not be the final implementation, just to pin down ideas).

The types:

```python
class SvgdStepFn(Protocol):
    def __call__(self, state: State) -> State:
        ...


class SVGDAlgorithm(NamedTuple):
    init: InitFn
    step: SvgdStepFn
```

The `init` and `step` functions:

```python
class SVGDState(NamedTuple):
    position: PyTree
    kernel_parameters: PyTree
    opt_state: Any


def init(initial_position: PyTree, kernel_parameters: PyTree) -> SVGDState:
    opt_state = optimizer.init(initial_position)
    return SVGDState(initial_position, kernel_parameters, opt_state)


def step(state: SVGDState, log_p: Callable, kernel: Callable, optimizer: Optimizer) -> SVGDState:
    # Computation where we call the kernel as kernel(x, y, **state.kernel_parameters)
    ...
```

The user-facing interface:

```python
class Optimizer(Protocol):
    init: Callable
    update: Callable


class svgd:
    init = staticmethod(vi.svgd.init)
    step = staticmethod(vi.svgd.step)

    def __new__(cls, log_p: Callable, kernel: Callable, optimizer: Optimizer) -> SVGDAlgorithm:
        def init_fn(initial_position: PyTree, kernel_parameters: Dict):
            return cls.init(initial_position, kernel_parameters)

        def step_fn(state: SVGDState):
            return cls.step(state, log_p, kernel, optimizer)

        return SVGDAlgorithm(init_fn, step_fn)
```

An example from the user's point of view:

```python
svgd = svgd(log_p, kernel, optimizer)  # example optimizer: optax.adam(0.1)
state = svgd.init(initial_position, kernel_parameters)
step = jax.jit(svgd.step)


def update_kernel_parameters(state: SVGDState) -> SVGDState:
    position, kernel_parameters, opt_state = state
    # f could be the median heuristic, for example
    return SVGDState(position, f(kernel_parameters, position), opt_state)


num_iterations = 5000
for _ in range(num_iterations):
    state = step(state)
    state = update_kernel_parameters(state)
```
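For concreteness, here is a rough sketch of the computation `step` could perform, assuming the particles have already been stacked into an `(n, d)` array and the kernel takes two flat particles. `rbf_kernel`, `svgd_phi`, and `svgd_step` are hypothetical names for this sketch, not part of the proposal above:

```python
import jax
import jax.numpy as jnp
import optax


def rbf_kernel(x, y, length_scale=1.0):
    # hypothetical RBF kernel on flattened particles
    return jnp.exp(-jnp.sum((x - y) ** 2) / (2 * length_scale**2))


def svgd_phi(particles, log_p, kernel):
    # Stein variational gradient (Liu & Wang, 2016):
    # phi(x_i) = (1/n) sum_j [ k(x_j, x_i) grad log p(x_j) + grad_{x_j} k(x_j, x_i) ]
    n = particles.shape[0]
    score = jax.vmap(jax.grad(log_p))(particles)  # (n, d)
    pairwise = lambda f: jax.vmap(lambda xj: jax.vmap(lambda xi: f(xj, xi))(particles))(particles)
    k = pairwise(kernel)                            # (n, n), k[j, i] = k(x_j, x_i)
    grad_k = pairwise(jax.grad(kernel, argnums=0))  # (n, n, d)
    return (k.T @ score + grad_k.sum(axis=0)) / n


def svgd_step(particles, opt_state, log_p, kernel, optimizer):
    # gradient *ascent* along phi; optax optimizers minimise, hence the sign flip
    phi = svgd_phi(particles, log_p, kernel)
    updates, opt_state = optimizer.update(-phi, opt_state)
    return optax.apply_updates(particles, updates), opt_state
```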
Thanks for the detailed analysis!
@antotocar34 Your proposal sounds about right. I think Pathfinder's current implementation and the implementation of mean-field VI will need to be updated to come closer to the approach you suggested. We could maybe start with implementing the kernelized Stein discrepancy #384 in a first PR. I need it for another idea 😊 Any reason why
If the
Assuming:

```python
import jax.experimental.host_callback as hcb


def f_non_jit(x):
    # run the non-jittable Python function f on the host from inside jitted code;
    # result_shape assumes f returns something with the same shape/dtype as x
    return hcb.call(f, x, result_shape=x)
```

Do you have any other blockers on this? If so you can always open a draft PR!
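(For the median-heuristic case mentioned above, `f` is pure JAX ops and jits fine, so no host callback should be needed. A sketch following the bandwidth rule from the SVGD paper, where `length_scale` is the hypothetical kernel parameter name from the earlier sketch:)

```python
import jax.numpy as jnp


def median_heuristic(particles):
    # particles: (n, d). Bandwidth rule from Liu & Wang (2016): h = med^2 / log n,
    # where med is the median pairwise Euclidean distance between particles.
    sq_dists = jnp.sum((particles[:, None, :] - particles[None, :, :]) ** 2, axis=-1)
    med_sq = jnp.median(sq_dists)  # median of squared distances = med^2 (sqrt is monotone)
    h = med_sq / jnp.log(particles.shape[0])
    # with k(x, y) = exp(-||x - y||^2 / (2 l^2)), choosing l^2 = h / 2 gives exp(-||x - y||^2 / h)
    return {"length_scale": jnp.sqrt(h / 2)}
```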
It seems to me like the main distinction here is that SVGD is doing gradient descent on a bunch of particles, not really approximating a distribution in itself but rather samples from that distribution. In that sense SVGD is closer to SMC than to VI, even if theoretically it minimizes a distance between the target distribution and (samples from) an approximate distribution. @antotocar34's proposal to code it as a sampling algorithm rather than an approximate inference algorithm is best. But other VI algorithms should be implemented the way mean-field VI currently is, because there is value in having access to the approximate distribution and not only to samples from it.
I think it also makes sense to give fine-grained control over the optimisation for VI algorithms, where the step function performs a one-step update of the approximation by sampling to estimate the ELBO and taking an optimisation step. Among other things, it allows you to run the procedure for a few more steps if you realise the approximation hasn't quite converged.
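A minimal sketch of such a one-step VI update, assuming a reparameterised variational family; `sample_fn`, `logq_fn`, and the other names are hypothetical and not from this thread:

```python
import jax
import jax.numpy as jnp
import optax


def vi_step(rng_key, params, opt_state, logdensity_fn, sample_fn, logq_fn, optimizer, num_samples=16):
    # one optimisation step on a Monte Carlo estimate of the negative ELBO
    def neg_elbo(params):
        samples = sample_fn(rng_key, params, num_samples)  # (num_samples, d), reparameterised
        logq = jax.vmap(logq_fn, in_axes=(None, 0))(params, samples)
        logp = jax.vmap(logdensity_fn)(samples)
        return jnp.mean(logq - logp)

    grads = jax.grad(neg_elbo)(params)
    updates, opt_state = optimizer.update(grads, opt_state)
    return optax.apply_updates(params, updates), opt_state
```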
My blocker is the following: I wrote the algorithm initially assuming I could destructure the PyTree of particles into an array of shape `(num_particles, num_dims)`. The issue is that getting the particle array back into the original PyTree structure is not straightforward. For example:

```python
import jax
import jax.numpy as jnp
from jax.tree_util import tree_flatten, tree_unflatten


def pytree_to_array(particles_pytree):
    _particle_array, pytree_def = tree_flatten(particles_pytree)
    particle_array = jnp.stack(_particle_array).squeeze().T
    return particle_array, pytree_def


def array_to_pytree(particle_array, pytree_def):
    arr = particle_array.T  # Works for 1, not for 2 and 3
    # arr = [particle_array.T]  # Works for 2 and 3, not for 1
    # arr = arr  # Works for 3, not for 1 and 2
    return tree_unflatten(
        pytree_def,
        arr,  # What to put here?
    )


particles_1 = {
    "mu": jnp.array([1, 2, -3, -1 / 2]),
    "sigma": jnp.array([1.1, 2.3, 0.1, 2]),
}
particles_2 = {
    "theta": jnp.array([0.6, 0.54, 0.3, 0.1]),
}
particles_3 = jnp.array([[0.6, 0.54, 0.3, 0.1], [0.4, 0.1, 0.3, 0.4]])

particle_array, pytree_def = pytree_to_array(particles_1)
pytree = array_to_pytree(particle_array, pytree_def)
print(pytree)
```

Playing with the above code should reveal the problem I'm facing. I tried checking what the pytree looks like in the

Possible solutions:
For reference, here is a current version of the algorithm: https://gist.github.com/antotocar34/3e6a762df1427a7db6105cfe72f66185
This works for (pretty sure) all pytrees whose leaves share the same dimension 0, i.e. the number of particles:

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def pytree_to_array(particles_pytree):
    # flatten each particle into a 1D vector, giving a (num_particles, dim) array
    particle_array = jax.vmap(lambda p: ravel_pytree(p)[0])(particles_pytree)
    # build the inverse mapping from a single example particle
    example_particle = jax.tree_util.tree_map(lambda p: p[0], particles_pytree)
    _, unravel_fn = ravel_pytree(example_particle)
    return particle_array, unravel_fn


def array_to_pytree(particle_array, unravel_fn):
    return jax.vmap(unravel_fn)(particle_array)


particles_1 = {
    "mu": jnp.array([1, 2, -3, -1 / 2]),
    "sigma": jnp.array([1.1, 2.3, 0.1, 2]),
}
particles_2 = {
    "theta": jnp.array([0.6, 0.54, 0.3, 0.1]),
}
particles_3 = jnp.array([[0.6, 0.54, 0.3, 0.1], [0.4, 0.1, 0.3, 0.4]])
particles_4 = {
    "mu": jnp.zeros((4, 4)),
    "diag_cov": jnp.ones((4, 4)),
}
particles_5 = {
    "mu": jnp.zeros((4, 5)),
    "dense_cov": jnp.ones((4, 5, 5)),
}

particle_array, unravel_fn = pytree_to_array(particles_5)
pytree = array_to_pytree(particle_array, unravel_fn)
print(pytree)
```
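A sketch of how this could slot into the `step` from the proposal, assuming `log_p` is defined on a single (unflattened) particle, the optimizer state was initialised on the flattened particle array, and `svgd_phi` is the hypothetical update function from the earlier sketch:

```python
import optax


def step(state, log_p, kernel, optimizer):
    particle_array, unravel_fn = pytree_to_array(state.position)
    # evaluate the target on a flat particle by unravelling it first
    log_p_flat = lambda flat_particle: log_p(unravel_fn(flat_particle))
    phi = svgd_phi(particle_array, log_p_flat, kernel)
    updates, opt_state = optimizer.update(-phi, state.opt_state)
    particle_array = optax.apply_updates(particle_array, updates)
    return SVGDState(
        array_to_pytree(particle_array, unravel_fn),
        state.kernel_parameters,
        opt_state,
    )
```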
This is much better, thanks a million Alberto.
No problem Antoine. Let us know here if you have any other blockers (I'm interested in having SVGD implemented in blackjax) 🚀
Reference: Liu & Wang, Stein Variational Gradient Descent: A General Purpose Bayesian Inference Algorithm, https://arxiv.org/abs/1608.04471