Skip to content

Commit

Permalink
fix: add x batch dim in sbc loop.
Browse files Browse the repository at this point in the history
  • Loading branch information
janfb committed Apr 26, 2023
1 parent 5768a7c commit 86e73a9
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions sbi/analysis/sbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,9 @@ def sbc_on_batch(
dap_samples = torch.zeros_like(thetas)
ranks = torch.zeros((thetas.shape[0], len(reduce_fns)))

for idx, (tho, xo) in enumerate(zip(thetas, xs)):
for idx in range(thetas.shape[0]):
# unsqueeze for potential higher-dimensional data.
xo = xs[idx].unsqueeze(0)
# VI posterior needs to be trained on the current xo.
if isinstance(posterior, VIPosterior):
posterior.set_default_x(xo)
Expand All @@ -184,7 +186,9 @@ def sbc_on_batch(
# rank for each posterior dimension as in Talts et al. section 4.1.
for i, reduce_fn in enumerate(reduce_fns):
ranks[idx, i] = (
(reduce_fn(ths, xo) < reduce_fn(tho.unsqueeze(0), xo)).sum().item()
(reduce_fn(ths, xo) < reduce_fn(thetas[idx].unsqueeze(0), xo))
.sum()
.item()
)

return ranks, dap_samples
Expand Down

0 comments on commit 86e73a9

Please sign in to comment.