From 9ecc68bba72aec3531218dd9acd67941f972e229 Mon Sep 17 00:00:00 2001
From: Jan Boelts
Date: Tue, 21 Jan 2025 18:02:52 +0000
Subject: [PATCH] fix: do not test batched and iid x

---
 tests/mnle_test.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/tests/mnle_test.py b/tests/mnle_test.py
index 778cc44cd..01fd69b53 100644
--- a/tests/mnle_test.py
+++ b/tests/mnle_test.py
@@ -337,7 +337,7 @@ def test_mnle_with_experimental_conditions(mcmc_params_accurate: dict):
 
 @pytest.mark.parametrize("num_thetas", [1, 10])
 @pytest.mark.parametrize("num_trials", [1, 5])
-@pytest.mark.parametrize("num_xs", [1, 3])
+@pytest.mark.parametrize("num_xs", [1])  # batched x not supported for iid trials.
 @pytest.mark.parametrize(
     "num_conditions",
     [
@@ -407,8 +407,10 @@ def test_log_likelihood_over_local_iid_theta(
         )
         x_i = x_o[i].reshape(num_xs, 1, -1).repeat(1, num_thetas, 1)
         ll_single.append(estimator.log_prob(input=x_i, condition=theta_and_condition))
-    ll_single = torch.stack(ll_single).sum(0)  # sum over trials
+    ll_single = (
+        torch.stack(ll_single).sum(0).squeeze(0)
+    )  # sum over trials, squeeze x batch.
 
-    assert ll_batched.shape == torch.Size([num_xs, num_thetas])
+    assert ll_batched.shape == torch.Size([num_thetas])
     assert ll_batched.shape == ll_single.shape
     assert torch.allclose(ll_batched, ll_single, atol=1e-5)
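
For context, the snippet below is a minimal, self-contained sketch of the shape bookkeeping the updated assertions encode: with num_xs fixed to 1, per-trial log-probabilities carry a singleton x-batch dimension, so summing over iid trials and squeezing that dimension should leave one value per theta. The tensors here are random placeholders standing in for MNLE estimator outputs, not the sbi API itself.

import torch

# Placeholder test parameters, mirroring the parametrization in the patch.
num_trials, num_xs, num_thetas = 5, 1, 10

# Pretend per-trial log-probabilities with an explicit x-batch dimension,
# shape (num_xs, num_thetas) for each trial (stand-ins for estimator.log_prob).
per_trial_ll = [torch.randn(num_xs, num_thetas) for _ in range(num_trials)]

# Sum over trials (iid assumption), then squeeze the singleton x-batch dim,
# as done for `ll_single` in the patched test.
ll_single = torch.stack(per_trial_ll).sum(0).squeeze(0)

# The batched path is expected to return one log-likelihood per theta.
assert ll_single.shape == torch.Size([num_thetas])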