Add auto-batched (low-rank) multivariate normal guides. #1737
Conversation
numpyro/infer/autoguide.py
```python
for site in self.prototype_trace.values():
    if site["type"] == "sample" and not site["is_observed"]:
        shape = site["value"].shape
        if site["value"].ndim < self.batch_ndim:
```
I think a safer check is `site["value"].ndim < self.batch_ndim + site["fn"].event_dim`.
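To see why the `event_dim` term matters: a site whose value has only event dimensions (e.g. a vector-valued latent with no batch dimension) has the same `ndim` as a scalar-event site with one batch dimension, so a check against `batch_ndim` alone cannot tell them apart. A toy illustration with plain dicts standing in for trace sites (in numpyro the event dimension would come from `site["fn"].event_dim`; the `event_dim` key here is a stand-in):

```python
import numpy as np

batch_ndim = 1
# Hypothetical "trace sites": "x" has batch shape (3,) and scalar events;
# "y" has event shape (3,) and *no* batch dimension. Both have ndim == 1.
sites = {
    "x": {"value": np.zeros(3), "event_dim": 0},
    "y": {"value": np.zeros(3), "event_dim": 1},
}

# Naive check: compares ndim against batch_ndim only.
naive = {name: s["value"].ndim < batch_ndim for name, s in sites.items()}
# Safer check: also accounts for the site's event dimensions.
safer = {name: s["value"].ndim < batch_ndim + s["event_dim"]
         for name, s in sites.items()}

print(naive)  # {'x': False, 'y': False} -- misses the problematic site "y"
print(safer)  # {'x': False, 'y': True}  -- flags "y", which lacks a batch dim
```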
Thanks for the great contribution, @tillahoffmann! LGTM pending the comment above.
It turns out that for larger datasets, we run into jax-ml/jax#19885. The issue could probably be worked around in
Oh, what a subtle issue. It would be nice to have a fix here (if the solution is simple, like changing operators around).
* Add `ReshapeTransform`.
* Add `AutoBatchedMultivariateNormal`.
* Refactor to use `AutoBatchedMixin`.
* Add `AutoLowRankMultivariateNormal`.
* Fix import order.
* Disable batching along event dimensions.
This PR implements auto-guides that support batching along leading dimensions of the parameters. The guides are motivated by models that have conditional independence structure but possibly strong correlation within each instance of a `plate`. The interface is exactly the same as `Auto[LowRank]MultivariateNormal`, with an additional argument `batch_ndims` that specifies the number of dimensions to treat as independent in the posterior approximation.

Example
Consider a random walk model with `n` time series and `t` observations. Then the number of parameters is `n * t + 2 * n` (a matrix of latent time series, one scale parameter for the random-walk innovations of each series, and one scale parameter for the observation noise of each series). For concreteness, here's the model.

Suppose we use different auto-guides and count the number of parameters we need to optimize. The example below is for `n = 10` and `t = 20`. `AutoDiagonalNormal` of course has the fewest parameters and `AutoMultivariateNormal` the most. The number of location parameters is the same across all guides. The batched versions have significantly fewer scale/covariance parameters (but of course cannot model dependence between different series). There is no free lunch, but I believe these batched guides can strike a reasonable compromise between modeling dependence and computational cost.

Implementation
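Before the implementation details, the parameter counts discussed above can be checked with quick arithmetic (my own back-of-envelope calculation; I assume the batched guide keeps one full lower-triangular scale factor per series):

```python
n, t = 10, 20   # series and time steps, as in the example above
d = t + 2       # latents per series: t walk values + 2 scale parameters
D = n * d       # total latent dimension: n * t + 2 * n = 220

def tril(k):
    # Number of entries in a k x k lower-triangular matrix.
    return k * (k + 1) // 2

counts = {
    "AutoDiagonalNormal": D + D,                        # loc + per-dim scale
    "AutoMultivariateNormal": D + tril(D),              # loc + full scale_tril
    "AutoBatchedMultivariateNormal": D + n * tril(d),   # loc + scale_tril per series
}
print(counts)
# {'AutoDiagonalNormal': 440, 'AutoMultivariateNormal': 24530,
#  'AutoBatchedMultivariateNormal': 2750}
```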
The implementation uses a mixin `AutoBatchedMixin` to

The two batched guides are implemented analogously to the non-batched guides, with the addition of the mixin and slight modifications to the parameters.
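The batch-shape bookkeeping such a mixin has to do can be sketched roughly as follows (a simplified illustration, not the PR's code; `get_batch_shape` is a hypothetical helper):

```python
import numpy as np

def get_batch_shape(site_shapes, event_dims, batch_ndim=1):
    """Broadcast the leading `batch_ndim` dims of each latent site.

    Each site must have at least batch_ndim + event_dim dimensions,
    which is exactly the check discussed in the review above.
    """
    batch_shapes = []
    for shape, event_dim in zip(site_shapes, event_dims):
        if len(shape) < batch_ndim + event_dim:
            raise ValueError(f"site with shape {shape} cannot be batched")
        batch_shapes.append(shape[:batch_ndim])
    # All sites must broadcast to a common batch shape.
    return np.broadcast_shapes(*batch_shapes)

# Two sites batched along a leading plate of size 10: a (10, 20) matrix of
# latent series (event_dim 1) and a (10,) vector of scales (event_dim 0).
print(get_batch_shape([(10, 20), (10,)], [1, 0], batch_ndim=1))  # (10,)
```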
I added a `ReshapeTransform` to take care of the shapes. That could probably also be squeezed into the `UnpackTransform`. I decided on the former approach because `UnpackTransform` and

Note
I didn't implement the `get_base_dist`, `get_transform`, and `get_posterior` methods because I couldn't find the corresponding tests.
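As an aside, the `ReshapeTransform` mentioned above can be sketched as a standalone class (a hypothetical NumPy simplification for illustration, not numpyro's implementation):

```python
import numpy as np

class ReshapeTransform:
    """Bijectively reshape the trailing dimensions of an array.

    Reshaping only reorders elements, so the log-det-Jacobian is zero.
    """

    def __init__(self, forward_shape, inverse_shape):
        if np.prod(forward_shape) != np.prod(inverse_shape):
            raise ValueError("shapes must have the same number of elements")
        self.forward_shape = tuple(forward_shape)
        self.inverse_shape = tuple(inverse_shape)

    def __call__(self, x):
        # Preserve any leading batch dimensions, reshape the rest.
        batch = x.shape[: x.ndim - len(self.inverse_shape)]
        return x.reshape(batch + self.forward_shape)

    def inv(self, y):
        batch = y.shape[: y.ndim - len(self.forward_shape)]
        return y.reshape(batch + self.inverse_shape)

    def log_abs_det_jacobian(self, x, y):
        # Zero for every batch element.
        return np.zeros(x.shape[: x.ndim - len(self.inverse_shape)])

reshape_tr = ReshapeTransform(forward_shape=(3, 4), inverse_shape=(12,))
x = np.arange(24.0).reshape(2, 12)  # batch of 2 flat 12-vectors
y = reshape_tr(x)                   # shape (2, 3, 4)
```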