vi failing with independent prior #650

Merged
merged 2 commits into from
Jun 30, 2022

Contributor

@janfb janfb commented Feb 17, 2022

This adds a test showing a failure mode of VI:

When sampling from a VIPosterior based on a MultipleIndependent prior, there is a problem with the theta_transform and the vi_pyro_flows. I think it has to do with the change in dimensionality from transformed to untransformed space; e.g., it happens when calling link_flow(torch.zeros(event_shape, device=device)) (see trace below). Potential problem: why is this called with just the event_shape and not the entire theta.shape?

Trace:

sbi/inference/posteriors/vi_posterior.py:132: in __init__
    self.set_q(q, parameters=parameters, modules=modules)
sbi/inference/posteriors/vi_posterior.py:230: in set_q
    device=self._device,
sbi/samplers/vi/vi_pyro_flows.py:156: in build_fn
    return _FLOW_BUILDERS[name](event_shape, link_flow, device=device, **kwargs)
sbi/samplers/vi/vi_pyro_flows.py:535: in masked_autoregressive_flow_builder
    **kwargs,
sbi/samplers/vi/vi_pyro_flows.py:389: in build_flow
    link_flow(torch.zeros(event_shape, device=device))
../../anaconda3/envs/mnle/lib/python3.7/site-packages/torch/distributions/transforms.py:150: in __call__
    return self._call(x)
../../anaconda3/envs/mnle/lib/python3.7/site-packages/torch/distributions/transforms.py:443: in _call
    return self.base_transform(x)
../../anaconda3/envs/mnle/lib/python3.7/site-packages/torch/distributions/transforms.py:150: in __call__
    return self._call(x)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = CatTransform(), x = tensor([0., 0.])

    def _call(self, x):
>       assert -x.dim() <= self.dim < x.dim()
E       AssertionError

../../anaconda3/envs/mnle/lib/python3.7/site-packages/torch/distributions/transforms.py:1017: AssertionError
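
The assertion in the trace above can be reproduced without sbi. The following is a minimal sketch, not the code from this PR: it builds a CatTransform by hand as a hypothetical stand-in for the transform of a two-parameter independent prior, and it assumes the per-parameter transforms are concatenated along dim=1, so that dim 0 is left for the batch.

```python
import torch
from torch.distributions.transforms import AffineTransform, CatTransform, ExpTransform

# Hypothetical stand-in for the transform of a two-parameter independent prior:
# one transform per parameter, concatenated along dim=1 (assumption: dim 0 is
# reserved for the batch).
transform = CatTransform(
    [AffineTransform(loc=0.0, scale=1.0), ExpTransform()],
    dim=1,
    lengths=[1, 1],
)

unbatched = torch.zeros(2)   # shape (2,): event shape only, as in the trace
batched = torch.zeros(1, 2)  # shape (1, 2): same values with a batch dimension

# A 1-D input has x.dim() == 1, so `self.dim < x.dim()` is false for dim=1.
try:
    transform(unbatched)
    unbatched_failed = False
except AssertionError:
    unbatched_failed = True

# With a leading batch dimension the same transform applies without error.
batched_out = transform(batched)
```

In this sketch the only difference between the failing and the working call is the leading dimension of the input.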

Contributor

@michaeldeistler michaeldeistler left a comment

Thanks! I also opened issue #651.

@michaeldeistler
Contributor

You could also mark the test as slow and merge it. Then we fix it before releasing v0.18.0.

@janfb
Contributor Author

janfb commented Feb 17, 2022

Alternatively, we mark it as xfail and add a TODO in the code.
But I would rather leave this PR open and not include it in the release.

@michaeldeistler
Contributor

I'm against the xfail. But simply leaving the PR open sounds good to me.

@michaeldeistler michaeldeistler force-pushed the fix-vi-with-independent-prior branch from 23b400f to 578b7a9 Compare June 29, 2022 15:41
@michaeldeistler michaeldeistler force-pushed the fix-vi-with-independent-prior branch from 578b7a9 to 08b25f7 Compare June 29, 2022 15:42
@michaeldeistler
Contributor

OK, the problem is that MultipleIndependent defines a CatConstraint in order to join the individual constraints. With this, biject_to generates a CatTransform. Unfortunately, a CatTransform requires a batch dimension. This is what broke the code.

I tried to fix the support property of MultipleIndependent for a while, but I could not find a good fix. Thus, we simply have to watch out and always ensure that anything passed into our transforms has a batch dimension.
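
The mechanism described above, and the "always pass a batch dimension" rule, can be sketched as follows. This is an illustration under assumptions, not the change made in this PR: the constraint is built by hand with `constraints.cat` and dim=1 to mimic a joined support, and `apply_with_batch_dim` is a hypothetical helper name that does not exist in sbi.

```python
import torch
from torch.distributions import biject_to, constraints
from torch.distributions.transforms import CatTransform

# Assumption: a joined support for two independent parameters, concatenated
# along dim=1 (first parameter unconstrained, second parameter positive).
support = constraints.cat(
    [constraints.real, constraints.positive], dim=1, lengths=[1, 1]
)

# biject_to builds the transform automatically from the constraint.
link_transform = biject_to(support)


def apply_with_batch_dim(transform, theta):
    """Hypothetical helper: add a batch dimension to 1-D inputs before calling
    the transform, then remove it again from the result."""
    if theta.dim() == 1:
        return transform(theta.unsqueeze(0)).squeeze(0)
    return transform(theta)


# An event-shaped input now passes through the CatTransform.
out = apply_with_batch_dim(link_transform, torch.zeros(2))
```

The helper keeps the caller's shape convention intact: a 1-D input yields a 1-D output, and batched inputs are passed through unchanged.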

@michaeldeistler
Contributor

@janfb @manuelgloeckler can you have a look?

@janfb
Contributor Author

janfb commented Jun 29, 2022

Thanks for tackling this!
I am OK with this solution. Is there a way to catch the AssertionError when we build the transform and give an informative error message? I think that would be great for catching future loopholes of missing batch dimensions.

@michaeldeistler
Contributor

Unfortunately, I don't know how we would do this. The CatTransform is built automatically by biject_to, so we do not really have access to it.
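
For illustration only, the idea of an informative error can be sketched at the call site rather than inside the CatTransform. This was not proposed or adopted in the thread. `call_transform_checked` is a hypothetical name, and the sketch assumes the only AssertionError of interest is the one caused by a missing batch dimension, which a real implementation would need to verify.

```python
import torch
from torch.distributions import biject_to, constraints


def call_transform_checked(transform, theta):
    """Hypothetical wrapper: call the transform and turn a bare AssertionError
    into a ValueError that names the likely cause."""
    try:
        return transform(theta)
    except AssertionError as err:
        raise ValueError(
            "Transform failed on input of shape "
            f"{tuple(theta.shape)}; transforms built from concatenated "
            "constraints may require a leading batch dimension."
        ) from err


# Assumption: a hand-built joined support, concatenated along dim=1.
support = constraints.cat(
    [constraints.real, constraints.positive], dim=1, lengths=[1, 1]
)
transform = biject_to(support)

try:
    call_transform_checked(transform, torch.zeros(2))
    message = None
except ValueError as err:
    message = str(err)
```

A wrapper like this would have to be used at every place a transform is called, so it does not remove the need to keep batch dimensions in mind.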
