From 48747f0824998fff61e5a931924fe766cb6ba0c4 Mon Sep 17 00:00:00 2001
From: Jan Boelts
Date: Wed, 22 Jan 2025 10:20:05 +0000
Subject: [PATCH] fix coverage and testing bugs

---
 tests/mnle_test.py | 32 +++++++++++++++++++++++++-------
 1 file changed, 25 insertions(+), 7 deletions(-)

diff --git a/tests/mnle_test.py b/tests/mnle_test.py
index 01fd69b53..8d809af9e 100644
--- a/tests/mnle_test.py
+++ b/tests/mnle_test.py
@@ -40,9 +40,10 @@ def mixed_simulator(theta: Tensor, stimulus_condition: Union[Tensor, float] = 2.
     return torch.cat((rts, choices), dim=1)
 
 
-def wrapped_simulator(
+def mixed_simulator_with_conditions(
     theta_and_condition: Tensor, last_idx_parameters: int = 2
 ) -> Tensor:
+    """Simulator for mixed data with experimental conditions."""
     # simulate with experiment conditions
     theta = theta_and_condition[:, :last_idx_parameters]
     condition = theta_and_condition[:, last_idx_parameters:]
@@ -276,7 +277,7 @@ def test_mnle_with_experimental_conditions(mcmc_params_accurate: dict):
     )
 
     theta = proposal.sample((num_simulations,))
-    x = wrapped_simulator(theta)
+    x = mixed_simulator_with_conditions(theta)
     assert x.shape == (num_simulations, 2)
 
     num_trials = 10
@@ -287,14 +288,14 @@ def test_mnle_with_experimental_conditions(mcmc_params_accurate: dict):
     condition_o = theta_and_condition[:, 2:]
     theta_and_conditions_o = torch.cat((theta_o, condition_o), dim=1)
 
-    x_o = wrapped_simulator(theta_and_conditions_o)
+    x_o = mixed_simulator_with_conditions(theta_and_conditions_o)
 
     mcmc_kwargs = dict(
         method="slice_np_vectorized", init_strategy="proposal", **mcmc_params_accurate
     )
 
     # MNLE
-    estimator_fun = likelihood_nn(model="mnle", z_score_x=None)
+    estimator_fun = likelihood_nn(model="mnle", log_transform_x=True)
     trainer = MNLE(proposal, estimator_fun)
     estimator = trainer.append_simulations(theta, x).train()
 
@@ -311,6 +312,9 @@ def test_mnle_with_experimental_conditions(mcmc_params_accurate: dict):
         ],
         validate_args=False,
     )
+    # test theta with sample shape.
+    conditioned_potential_fn(prior.sample((10,)).unsqueeze(0))
+
     prior_transform = mcmc_transform(prior)
     true_posterior_samples = MCMCPosterior(
         BinomialGammaPotential(
@@ -337,14 +341,28 @@ def test_mnle_with_experimental_conditions(mcmc_params_accurate: dict):
 
 @pytest.mark.parametrize("num_thetas", [1, 10])
 @pytest.mark.parametrize("num_trials", [1, 5])
-@pytest.mark.parametrize("num_xs", [1])  # batched x not supported for iid trials.
+@pytest.mark.parametrize(
+    "num_xs",
+    [
+        1,
+        pytest.param(
+            2,
+            marks=pytest.mark.xfail(
+                reason="Batched x not supported for iid trials.",
+                raises=NotImplementedError,
+            ),
+        ),
+    ],
+)
 @pytest.mark.parametrize(
     "num_conditions",
     [
         1,
         pytest.param(
             2,
-            marks=pytest.mark.xfail(reason="Batched theta_condition is not supported"),
+            marks=pytest.mark.xfail(
+                reason="Batched theta_condition is not supported",
+            ),
         ),
     ],
 )
@@ -376,7 +394,7 @@ def test_log_likelihood_over_local_iid_theta(
     num_simulations = 100
 
     theta = proposal.sample((num_simulations,))
-    x = wrapped_simulator(theta)
+    x = mixed_simulator_with_conditions(theta)
     estimator = trainer.append_simulations(theta, x).train(max_num_epochs=1)
 
     # condition on multiple conditions