Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add auto-batched (low-rank) multivariate normal guides. #1737

Merged
merged 6 commits into from
Feb 21, 2024

Conversation

tillahoffmann
Copy link
Contributor

@tillahoffmann tillahoffmann commented Feb 17, 2024

This PR implements auto-guides that support batching along leading dimensions of the parameters. The guides are motivated by models that have conditional independence structure but possibly strong correlation within each instance of a plate. The interface is exactly the same as Auto[LowRank]MultivariateNormal with an additional argument batch_ndims that specifies the number of dimensions to treat as independent in the posterior approximation.

Example

Consider a random walk model with n time series and t observations. Then the number of parameters is n * t + 2 * n (matrix of latent time series, one scale parameter for random walk innovations for each series, and one scale parameter for observation noise for each series). For concreteness, here's the model.

def model(n, t):
    with numpyro.plate("n", n):
        # Model for time series.
        innovation_scale = numpyro.sample(
            "innovation_scale",
            distributions.HalfCauchy(1),
        )
        innovations = numpyro.sample(
            "innovations",
            distributions.Normal().expand([t]).to_event(1),
        )
        series = numpyro.deterministic(
            "series",
            innovations.cumsum(axis=-1),
        )
        
        # Model for observations.
        noise_scale = numpyro.sample(
            "noise_scale",
            distributions.HalfCauchy(1),
        )
        data = numpyro.sample(
            "data",
            distributions.Normal(series, noise_scale[:, None]).to_event(1),
        )

Suppose we use different auto-guides and count the number of parameters we need to optimize. The example below is for n = 10 and t = 20

# [guide class] [total number of parameters]
# 	[parameter shapes]
AutoDiagonalNormal 440
	 {'auto_loc': (220,), 'auto_scale': (220,)}
AutoLowRankMultivariateNormal 3740
	 {'auto_loc': (220,), 'auto_cov_factor': (220, 15), 'auto_scale': (220,)}
AutoMultivariateNormal 48620
	 {'auto_loc': (220,), 'auto_scale_tril': (220, 220)}
AutoBatchedLowRankMultivariateNormal 1540
	 {'auto_loc': (10, 22), 'auto_cov_factor': (10, 22, 5), 'auto_scale': (10, 22)}
AutoBatchedMultivariateNormal 5060
	 {'auto_loc': (10, 22), 'auto_scale_tril': (10, 22, 22)}

AutoDiagonalNormal of course has the fewest parameters and AutoMultivariateNormal the most. The number of location parameters is the same across all guides. The batched versions have significantly fewer scale/covariance parameters (but of course cannot model dependence between different series). There is no free lunch, but I believe these batched guides can strike a reasonable compromise between modeling dependence and computational cost.

Implementation

The implementation uses a mixin AutoBatchedMixin to

  1. determine the batch shape (and verify that a batched guide is appropriate for the model) and
  2. apply a reshaping transformation to account for the existence of batches in the variational approximation.

The two batched guides are implemented analogously to the non-batched guides with the addition of the mixin and slight modifications to the parameters.

I added a ReshapeTransform to take care of the shapes. That could probably also be squeezed into the UnpackTransform. I decided on the former approach because

  1. it separates the concerns rather than packing more logic into UnpackTransform and
  2. I've found myself looking for reshaping samples in other settings.

Note

I didn't implement the get_base_dist, get_transform, and get_posterior methods because I couldn't find the corresponding tests.

for site in self.prototype_trace.values():
if site["type"] == "sample" and not site["is_observed"]:
shape = site["value"].shape
if site["value"].ndim < self.batch_ndim:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think a safer check is site["value"].ndim < self.batch_ndim + site["fn"].event_dim.

Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the great contribution, @tillahoffmann! LGTM pending the comment above.

@fehiepsi fehiepsi merged commit b35fcec into pyro-ppl:master Feb 21, 2024
4 checks passed
@tillahoffmann tillahoffmann deleted the batched branch February 22, 2024 16:13
@tillahoffmann
Copy link
Contributor Author

It turns out that for larger datasets, we run into jax-ml/jax#19885. The issue could probably be worked around in numpyro by slightly rearranging operations in the LowRankMultivariateNormal implementation. Is that of interest or just wait for the upstream fix (I don't know how quick the jax folks usually are)?

@fehiepsi
Copy link
Member

Oh, what a subtle issue. It would be nice to have a fix here (if the solution is simple like changing operators around)

OlaRonning pushed a commit to aleatory-science/numpyro that referenced this pull request May 6, 2024
* Add `ReshapeTransform`.

* Add `AutoBatchedMultivariateNormal`.

* Refactor to use `AutoBatchedMixin`.

* Add `AutoLowRankMultivariateNormal`.

* Fix import order.

* Disable batching along event dimensions.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants