diff --git a/sbibm/algorithms/sbi/snle.py b/sbibm/algorithms/sbi/snle.py index eb69ca0c..aa609798 100644 --- a/sbibm/algorithms/sbi/snle.py +++ b/sbibm/algorithms/sbi/snle.py @@ -37,8 +37,8 @@ def run( "num_candidate_samples": 10000, }, }, - z_score_x: bool = True, - z_score_theta: bool = True, + z_score_x: str = "independent", + z_score_theta: str = "independent", max_num_epochs: Optional[int] = 2**31 - 1, ) -> Tuple[torch.Tensor, int, Optional[torch.Tensor]]: """Runs (S)NLE from `sbi` @@ -126,8 +126,6 @@ def run( show_train_summary=True, max_num_epochs=max_num_epochs, ) - if r > 1: - mcmc_parameters["init_strategy"] = "latest_sample" posterior = inference_method.build_posterior( density_estimator=density_estimator, @@ -135,6 +133,12 @@ def run( mcmc_method=mcmc_method, mcmc_parameters=mcmc_parameters, ) + # Change init_strategy to latest_sample after second round. + if r > 1: + posterior.init_strategy = "latest_sample" + # copy init params from round 2 posterior. + posterior._mcmc_init_params = proposal._mcmc_init_params + proposal = posterior.set_default_x(observation) posteriors.append(posterior) diff --git a/sbibm/algorithms/sbi/snpe.py b/sbibm/algorithms/sbi/snpe.py index 98f6dd29..cb7b2d01 100644 --- a/sbibm/algorithms/sbi/snpe.py +++ b/sbibm/algorithms/sbi/snpe.py @@ -27,8 +27,8 @@ def run( training_batch_size: int = 10000, num_atoms: int = 10, automatic_transforms_enabled: bool = False, - z_score_x: bool = True, - z_score_theta: bool = True, + z_score_x: str = "independent", + z_score_theta: str = "independent", max_num_epochs: Optional[int] = 2**31 - 1, ) -> Tuple[torch.Tensor, int, Optional[torch.Tensor]]: """Runs (S)NPE from `sbi` diff --git a/sbibm/algorithms/sbi/snre.py b/sbibm/algorithms/sbi/snre.py index ac785800..1b60058c 100644 --- a/sbibm/algorithms/sbi/snre.py +++ b/sbibm/algorithms/sbi/snre.py @@ -38,8 +38,8 @@ def run( "num_candidate_samples": 10000, }, }, - z_score_x: bool = True, - z_score_theta: bool = True, + z_score_x: str = "independent", + z_score_theta: str = "independent", variant: str = "B", max_num_epochs: Optional[int] = 2**31 - 1, ) -> Tuple[torch.Tensor, int, Optional[torch.Tensor]]: @@ -137,13 +137,17 @@ def run( max_num_epochs=max_num_epochs, **inference_method_kwargs, ) - if r > 1: - mcmc_parameters["init_strategy"] = "latest_sample" + posterior = inference_method.build_posterior( density_estimator, mcmc_method=mcmc_method, mcmc_parameters=mcmc_parameters, ) + # Change init_strategy to latest_sample after second round. + if r > 1: + posterior.init_strategy = "latest_sample" + # copy init params from round 2 posterior. + posterior._mcmc_init_params = proposal._mcmc_init_params proposal = posterior.set_default_x(observation) posteriors.append(posterior) diff --git a/tests/algorithms/sbi/test_sbi_run_methods.py b/tests/algorithms/sbi/test_sbi_run_methods.py index 273608a6..5f18dded 100644 --- a/tests/algorithms/sbi/test_sbi_run_methods.py +++ b/tests/algorithms/sbi/test_sbi_run_methods.py @@ -15,12 +15,13 @@ def test_sbi_api( run_method, task_name, num_observation, num_simulations=2_000, num_samples=100 ): task = sbibm.get_task(task_name) + num_rounds = 4 if run_method in (mcabc, smcabc, sl): # abc algorithms kwargs = dict() else: # neural algorithms kwargs = dict( - num_rounds=2, + num_rounds=num_rounds, training_batch_size=100, neural_net="mlp" if run_method == snre else "maf", )