diff --git a/sbi/inference/snpe/snpe_a.py b/sbi/inference/snpe/snpe_a.py
index d480e6997..69a1819f3 100644
--- a/sbi/inference/snpe/snpe_a.py
+++ b/sbi/inference/snpe/snpe_a.py
@@ -106,7 +106,6 @@ def train(
         clip_max_norm: Optional[float] = 5.0,
         calibration_kernel: Optional[Callable] = None,
         resume_training: bool = False,
-        force_first_round_loss: bool = False,
         retrain_from_scratch: bool = False,
         show_train_summary: bool = False,
         dataloader_kwargs: Optional[Dict] = None,
@@ -144,8 +143,6 @@ def train(
             force_first_round_loss: If `True`, train with maximum likelihood, i.e.,
                 potentially ignoring the correction for using a proposal distribution
                 different from the prior.
-            force_first_round_loss: If `True`, train with maximum likelihood,
-                regardless of the proposal distribution.
             retrain_from_scratch: Whether to retrain the conditional density
                 estimator for the posterior from scratch each round. Not supported for
                 SNPE-A.
@@ -174,6 +171,7 @@ def train(
         # SNPE-A always discards the prior samples.
         kwargs["discard_prior_samples"] = True
+        kwargs["force_first_round_loss"] = True
 
         self._round = max(self._data_round_index)
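
For context, a minimal usage sketch of what this change means for callers (not part of the patch itself): `SNPE_A.train()` no longer exposes `force_first_round_loss`; the maximum-likelihood (first-round) loss is always used internally, and the proposal correction is applied post hoc. The prior, simulator, and variable names below are illustrative placeholders only.

```python
import torch
from sbi.inference import SNPE_A
from sbi.utils import BoxUniform

# Illustrative toy problem; the simulator and names are placeholders, not from this patch.
prior = BoxUniform(low=-2 * torch.ones(2), high=2 * torch.ones(2))
theta = prior.sample((500,))
x = theta + 0.1 * torch.randn_like(theta)  # dummy "simulator" output

inference = SNPE_A(prior=prior)
inference.append_simulations(theta, x)

# After this change there is no `force_first_round_loss` argument to pass here:
# SNPE-A sets kwargs["force_first_round_loss"] = True internally, since it always
# trains with maximum likelihood and corrects for the proposal distribution afterwards.
density_estimator = inference.train()
posterior = inference.build_posterior(density_estimator)
```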