Skip to content

Commit

Permalink
fix for multipleindependent and vi
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Jun 30, 2022
1 parent ccf1334 commit a51e93b
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion sbi/samplers/vi/vi_pyro_flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,8 +385,10 @@ def build_flow(
"""
# Some transforms increase dimension by decreasing the degrees of freedom e.g.
# SoftMax.
# `unsqueeze(0)` because the `link_flow` requires a batch dimension if the prior is
# a `MultipleIndependent`.
additional_dim = (
len(link_flow(torch.zeros(event_shape, device=device)))
len(link_flow(torch.zeros(event_shape, device=device).unsqueeze(0))[0])
- torch.tensor(event_shape, device=device).item()
)
event_shape = torch.Size(
Expand Down

0 comments on commit a51e93b

Please sign in to comment.