Skip to content

Commit

Permalink
fix coverage and testing bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
janfb committed Jan 22, 2025
1 parent 9ecc68b commit 48747f0
Showing 1 changed file with 25 additions and 7 deletions.
32 changes: 25 additions & 7 deletions tests/mnle_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,10 @@ def mixed_simulator(theta: Tensor, stimulus_condition: Union[Tensor, float] = 2.
return torch.cat((rts, choices), dim=1)


def wrapped_simulator(
def mixed_simulator_with_conditions(
theta_and_condition: Tensor, last_idx_parameters: int = 2
) -> Tensor:
"""Simulator for mixed data with experimental conditions."""
# simulate with experiment conditions
theta = theta_and_condition[:, :last_idx_parameters]
condition = theta_and_condition[:, last_idx_parameters:]
Expand Down Expand Up @@ -276,7 +277,7 @@ def test_mnle_with_experimental_conditions(mcmc_params_accurate: dict):
)

theta = proposal.sample((num_simulations,))
x = wrapped_simulator(theta)
x = mixed_simulator_with_conditions(theta)
assert x.shape == (num_simulations, 2)

num_trials = 10
Expand All @@ -287,14 +288,14 @@ def test_mnle_with_experimental_conditions(mcmc_params_accurate: dict):
condition_o = theta_and_condition[:, 2:]
theta_and_conditions_o = torch.cat((theta_o, condition_o), dim=1)

x_o = wrapped_simulator(theta_and_conditions_o)
x_o = mixed_simulator_with_conditions(theta_and_conditions_o)

mcmc_kwargs = dict(
method="slice_np_vectorized", init_strategy="proposal", **mcmc_params_accurate
)

# MNLE
estimator_fun = likelihood_nn(model="mnle", z_score_x=None)
estimator_fun = likelihood_nn(model="mnle", log_transform_x=True)
trainer = MNLE(proposal, estimator_fun)
estimator = trainer.append_simulations(theta, x).train()

Expand All @@ -311,6 +312,9 @@ def test_mnle_with_experimental_conditions(mcmc_params_accurate: dict):
],
validate_args=False,
)
# test theta with sample shape.
conditioned_potential_fn(prior.sample((10,)).unsqueeze(0))

prior_transform = mcmc_transform(prior)
true_posterior_samples = MCMCPosterior(
BinomialGammaPotential(
Expand All @@ -337,14 +341,28 @@ def test_mnle_with_experimental_conditions(mcmc_params_accurate: dict):

@pytest.mark.parametrize("num_thetas", [1, 10])
@pytest.mark.parametrize("num_trials", [1, 5])
@pytest.mark.parametrize("num_xs", [1]) # batched x not supported for iid trials.
@pytest.mark.parametrize(
"num_xs",
[
1,
pytest.param(
2,
marks=pytest.mark.xfail(
reason="Batched x not supported for iid trials.",
raises=NotImplementedError,
),
),
],
)
@pytest.mark.parametrize(
"num_conditions",
[
1,
pytest.param(
2,
marks=pytest.mark.xfail(reason="Batched theta_condition is not supported"),
marks=pytest.mark.xfail(
reason="Batched theta_condition is not supported",
),
),
],
)
Expand Down Expand Up @@ -376,7 +394,7 @@ def test_log_likelihood_over_local_iid_theta(

num_simulations = 100
theta = proposal.sample((num_simulations,))
x = wrapped_simulator(theta)
x = mixed_simulator_with_conditions(theta)
estimator = trainer.append_simulations(theta, x).train(max_num_epochs=1)

# condition on multiple conditions
Expand Down

0 comments on commit 48747f0

Please sign in to comment.