Skip to content

Commit

Permalink
fix: do not test batched and iid x
Browse files Browse the repository at this point in the history
  • Loading branch information
janfb committed Jan 21, 2025
1 parent a163535 commit 9ecc68b
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions tests/mnle_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ def test_mnle_with_experimental_conditions(mcmc_params_accurate: dict):

@pytest.mark.parametrize("num_thetas", [1, 10])
@pytest.mark.parametrize("num_trials", [1, 5])
@pytest.mark.parametrize("num_xs", [1, 3])
@pytest.mark.parametrize("num_xs", [1]) # batched x not supported for iid trials.
@pytest.mark.parametrize(
"num_conditions",
[
Expand Down Expand Up @@ -407,8 +407,10 @@ def test_log_likelihood_over_local_iid_theta(
)
x_i = x_o[i].reshape(num_xs, 1, -1).repeat(1, num_thetas, 1)
ll_single.append(estimator.log_prob(input=x_i, condition=theta_and_condition))
ll_single = torch.stack(ll_single).sum(0) # sum over trials
ll_single = (
torch.stack(ll_single).sum(0).squeeze(0)
) # sum over trials, squeeze x batch.

assert ll_batched.shape == torch.Size([num_xs, num_thetas])
assert ll_batched.shape == torch.Size([num_thetas])
assert ll_batched.shape == ll_single.shape
assert torch.allclose(ll_batched, ll_single, atol=1e-5)

0 comments on commit 9ecc68b

Please sign in to comment.