diff --git a/sbi/analysis/sbc.py b/sbi/analysis/sbc.py index 14a38daba..4264e06ca 100644 --- a/sbi/analysis/sbc.py +++ b/sbi/analysis/sbc.py @@ -169,7 +169,9 @@ def sbc_on_batch( dap_samples = torch.zeros_like(thetas) ranks = torch.zeros((thetas.shape[0], len(reduce_fns))) - for idx, (tho, xo) in enumerate(zip(thetas, xs)): + for idx in range(thetas.shape[0]): + # unsqueeze for potential higher-dimensional data. + xo = xs[idx].unsqueeze(0) # VI posterior needs to be trained on the current xo. if isinstance(posterior, VIPosterior): posterior.set_default_x(xo) @@ -184,7 +186,9 @@ def sbc_on_batch( # rank for each posterior dimension as in Talts et al. section 4.1. for i, reduce_fn in enumerate(reduce_fns): ranks[idx, i] = ( - (reduce_fn(ths, xo) < reduce_fn(tho.unsqueeze(0), xo)).sum().item() + (reduce_fn(ths, xo) < reduce_fn(thetas[idx].unsqueeze(0), xo)) + .sum() + .item() ) return ranks, dap_samples