From 86e73a936a9bc98c521f05ab9359c663f1547af3 Mon Sep 17 00:00:00 2001 From: janfb Date: Wed, 26 Apr 2023 10:50:20 +0200 Subject: [PATCH] fix: add x batch dim in sbc loop. --- sbi/analysis/sbc.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/sbi/analysis/sbc.py b/sbi/analysis/sbc.py index 14a38daba..4264e06ca 100644 --- a/sbi/analysis/sbc.py +++ b/sbi/analysis/sbc.py @@ -169,7 +169,9 @@ def sbc_on_batch( dap_samples = torch.zeros_like(thetas) ranks = torch.zeros((thetas.shape[0], len(reduce_fns))) - for idx, (tho, xo) in enumerate(zip(thetas, xs)): + for idx in range(thetas.shape[0]): + # unsqueeze for potential higher-dimensional data. + xo = xs[idx].unsqueeze(0) # VI posterior needs to be trained on the current xo. if isinstance(posterior, VIPosterior): posterior.set_default_x(xo) @@ -184,7 +186,9 @@ def sbc_on_batch( # rank for each posterior dimension as in Talts et al. section 4.1. for i, reduce_fn in enumerate(reduce_fns): ranks[idx, i] = ( - (reduce_fn(ths, xo) < reduce_fn(tho.unsqueeze(0), xo)).sum().item() + (reduce_fn(ths, xo) < reduce_fn(thetas[idx].unsqueeze(0), xo)) + .sum() + .item() ) return ranks, dap_samples