diff --git a/examples/HH_helper_functions.py b/examples/HH_helper_functions.py index 38cb91b06..ecced908b 100644 --- a/examples/HH_helper_functions.py +++ b/examples/HH_helper_functions.py @@ -145,7 +145,7 @@ def tau_p(x): + g_leak * E_leak + gbar_M * p[i - 1] * E_K + I[i - 1] - + nois_fact * rng.randn() / (tstep ** 0.5) + + nois_fact * rng.randn() / (tstep**0.5) ) / (tau_V_inv * C) V[i] = V_inf + (V[i - 1] - V_inf) * np.exp(-tstep * tau_V_inv) n[i] = n_inf(V[i]) + (n[i - 1] - n_inf(V[i])) * np.exp(-tstep / tau_n(V[i])) diff --git a/sbi/inference/base.py b/sbi/inference/base.py index 5a2fb7c57..880b12ef1 100644 --- a/sbi/inference/base.py +++ b/sbi/inference/base.py @@ -18,14 +18,7 @@ import sbi.inference from sbi.inference.posteriors.base_posterior import NeuralPosterior from sbi.simulators.simutils import simulate_in_batches -from sbi.utils import ( - check_prior, - get_log_root, - handle_invalid_x, - warn_if_zscoring_changes_data, - warn_on_invalid_x, - warn_on_invalid_x_for_snpec_leakage, -) +from sbi.utils import check_prior, get_log_root from sbi.utils.sbiutils import get_simulations_since_round from sbi.utils.torchutils import check_if_prior_on_device, process_device from sbi.utils.user_input_checks import prepare_for_sbi @@ -128,7 +121,9 @@ def __init__( # Initialize roundwise (theta, x, prior_masks) for storage of parameters, # simulations and masks indicating if simulations came from prior. - self._theta_roundwise, self._x_roundwise, self._prior_masks = [], [], [] + self._theta_roundwise = [] + self._x_roundwise = [] + self._prior_masks = [] self._model_bank = [] # Initialize list that indicates the round from which simulations were drawn. @@ -159,8 +154,6 @@ def __init__( def get_simulations( self, starting_round: int = 0, - exclude_invalid_x: bool = True, - warn_on_invalid: bool = True, ) -> Tuple[Tensor, Tensor, Tensor]: r"""Returns all $\theta$, $x$, and prior_masks from rounds >= `starting_round`. @@ -187,17 +180,7 @@ def get_simulations( self._prior_masks, self._data_round_index, starting_round ) - # Check for NaNs in simulations. - is_valid_x, num_nans, num_infs = handle_invalid_x(x, exclude_invalid_x) - # Check for problematic z-scoring - warn_if_zscoring_changes_data(x[is_valid_x]) - if warn_on_invalid: - warn_on_invalid_x(num_nans, num_infs, exclude_invalid_x) - warn_on_invalid_x_for_snpec_leakage( - num_nans, num_infs, exclude_invalid_x, type(self).__name__, self._round - ) - - return theta[is_valid_x], x[is_valid_x], prior_masks[is_valid_x] + return theta, x, prior_masks @abstractmethod def train( @@ -218,7 +201,7 @@ def train( def get_dataloaders( self, - dataset: data.TensorDataset, + starting_round: int = 0, training_batch_size: int = 50, validation_fraction: float = 0.1, resume_training: bool = False, @@ -239,14 +222,19 @@ def get_dataloaders( """ - # Get total number of training examples. - num_examples = len(dataset) + # Load all simulations from rounds >= starting_round. + theta, x, prior_masks = self.get_simulations(starting_round) + + dataset = data.TensorDataset(theta, x, prior_masks) + # Get total number of training examples. + num_examples = theta.size(0) # Select random train and validation splits from (theta, x) pairs.
num_training_examples = int((1 - validation_fraction) * num_examples) num_validation_examples = num_examples - num_training_examples if not resume_training: + # Separate indices for training and validation permuted_indices = torch.randperm(num_examples) self.train_indices, self.val_indices = ( permuted_indices[:num_training_examples], @@ -358,7 +346,11 @@ def _report_convergence_at_end( ) def _summarize( - self, round_: int, x_o: Union[Tensor, None], theta_bank: Tensor, x_bank: Tensor + self, + round_: int, + x_o: Union[Tensor, None], + theta_bank: Union[Tensor, None], + x_bank: Union[Tensor, None], ) -> None: """Update the summary_writer with statistics for a given round. diff --git a/sbi/inference/snle/snle_base.py b/sbi/inference/snle/snle_base.py index e9e7baace..1c04584b8 100644 --- a/sbi/inference/snle/snle_base.py +++ b/sbi/inference/snle/snle_base.py @@ -21,8 +21,11 @@ from sbi.utils import ( check_estimator_arg, check_prior, + handle_invalid_x, mask_sims_from_prior, validate_theta_and_x, + warn_if_zscoring_changes_data, + warn_on_invalid_x, x_shape_from_simulation, ) @@ -83,6 +86,7 @@ def append_simulations( theta: Tensor, x: Tensor, from_round: int = 0, + data_device: Optional[str] = None, ) -> "LikelihoodEstimator": r"""Store parameters and simulation outputs to use them for later training. @@ -99,16 +103,34 @@ def append_simulations( With default settings, this is not used at all for `SNLE`. Only when the user later on requests `.train(discard_prior_samples=True)`, we use these indices to find which training data stemmed from the prior. - + data_device: Where to store the data. Defaults to the device on which training + takes place. When training with a large dataset on a GPU with limited VRAM, + this can be set to 'cpu' to keep the data in system memory instead. Returns: NeuralInference object (returned so that this function is chainable). """ - theta, x = validate_theta_and_x(theta, x, training_device=self._device) + is_valid_x, num_nans, num_infs = handle_invalid_x(x, True) # exclude_invalid_x is hard-coded to True here + + x = x[is_valid_x] + theta = theta[is_valid_x] + + # Check for problematic z-scoring + warn_if_zscoring_changes_data(x) + warn_on_invalid_x(num_nans, num_infs, True) + + if data_device is None: + data_device = self._device + theta, x = validate_theta_and_x( + theta, x, data_device=data_device, training_device=self._device ) + + prior_masks = mask_sims_from_prior(int(from_round), theta.size(0)) self._theta_roundwise.append(theta) self._x_roundwise.append(x) - self._prior_masks.append(mask_sims_from_prior(int(from_round), theta.size(0))) + self._prior_masks.append(prior_masks) + self._data_round_index.append(int(from_round)) return self @@ -121,7 +143,6 @@ def train( stop_after_epochs: int = 20, max_num_epochs: int = 2**31 - 1, clip_max_norm: Optional[float] = 5.0, - exclude_invalid_x: bool = True, resume_training: bool = False, discard_prior_samples: bool = False, retrain_from_scratch: bool = False, @@ -131,8 +152,6 @@ def train( r"""Train the density estimator to learn the distribution $p(x|\theta)$. Args: - exclude_invalid_x: Whether to exclude simulation outputs `x=NaN` or `x=±∞` - during training. Expect errors, silent or explicit, when `False`. resume_training: Can be used in case training time is limited, e.g. on a cluster. If `True`, the split between train and validation set, the optimizer, the number of epochs, and the best validation log-prob will @@ -150,20 +169,13 @@ def train( Returns: Density estimator that has learned the distribution $p(x|\theta)$.
""" - - # Starting index for the training set (1 = discard round-0 samples). - start_idx = int(discard_prior_samples and self._round > 0) # Load data from most recent round. self._round = max(self._data_round_index) - theta, x, _ = self.get_simulations( - start_idx, exclude_invalid_x, warn_on_invalid=True - ) - - # Dataset is shared for training and validation loaders. - dataset = data.TensorDataset(theta, x) + # Starting index for the training set (1 = discard round-0 samples). + start_idx = int(discard_prior_samples and self._round > 0) train_loader, val_loader = self.get_dataloaders( - dataset, + start_idx, training_batch_size, validation_fraction, resume_training, @@ -176,10 +188,12 @@ def train( # This is passed into NeuralPosterior, to create a neural posterior which # can `sample()` and `log_prob()`. The network is accessible via `.net`. if self._neural_net is None or retrain_from_scratch: - self._neural_net = self._build_neural_net( - theta[self.train_indices], x[self.train_indices] - ) - self._x_shape = x_shape_from_simulation(x) + + # Get theta,x to initialize NN + theta, x, _ = self.get_simulations(starting_round=start_idx) + self._neural_net = self._build_neural_net(theta.to("cpu"), x.to("cpu")) + self._x_shape = x_shape_from_simulation(x.to("cpu")) + del theta, x assert ( len(self._x_shape) < 3 ), "SNLE cannot handle multi-dimensional simulator output." @@ -257,8 +271,8 @@ def train( self._summarize( round_=self._round, x_o=None, - theta_bank=theta, - x_bank=x, + theta_bank=None, + x_bank=None, ) # Update description for progress bar. diff --git a/sbi/inference/snpe/snpe_a.py b/sbi/inference/snpe/snpe_a.py index e0edefd12..d480e6997 100644 --- a/sbi/inference/snpe/snpe_a.py +++ b/sbi/inference/snpe/snpe_a.py @@ -105,7 +105,6 @@ def train( max_num_epochs: int = 2**31 - 1, clip_max_norm: Optional[float] = 5.0, calibration_kernel: Optional[Callable] = None, - exclude_invalid_x: bool = True, resume_training: bool = False, force_first_round_loss: bool = False, retrain_from_scratch: bool = False, @@ -138,8 +137,6 @@ def train( prevent exploding gradients. Use None for no clipping. calibration_kernel: A function to calibrate the loss with respect to the simulations `x`. See Lueckmann, Gonçalves et al., NeurIPS 2017. - exclude_invalid_x: Whether to exclude simulation outputs `x=NaN` or `x=±∞` - during training. Expect errors, silent or explicit, when `False`. resume_training: Can be used in case training time is limited, e.g. on a cluster. If `True`, the split between train and validation set, the optimizer, the number of epochs, and the best validation log-prob will diff --git a/sbi/inference/snpe/snpe_base.py b/sbi/inference/snpe/snpe_base.py index bdc99e3b1..481b2de95 100644 --- a/sbi/inference/snpe/snpe_base.py +++ b/sbi/inference/snpe/snpe_base.py @@ -26,8 +26,12 @@ from sbi.utils import ( RestrictedPrior, check_estimator_arg, + handle_invalid_x, test_posterior_net_for_multi_d_x, validate_theta_and_x, + warn_if_zscoring_changes_data, + warn_on_invalid_x, + warn_on_invalid_x_for_snpec_leakage, x_shape_from_simulation, ) from sbi.utils.sbiutils import ImproperEmpirical, mask_sims_from_prior @@ -88,6 +92,7 @@ def append_simulations( theta: Tensor, x: Tensor, proposal: Optional[DirectPosterior] = None, + data_device: Optional[str] = None, ) -> "PosteriorEstimator": r"""Store parameters and simulation outputs to use them for later training. @@ -103,12 +108,32 @@ def append_simulations( proposal: The distribution that the parameters $\theta$ were sampled from. 
Pass `None` if the parameters were sampled from the prior. If not `None`, it will trigger a different loss-function. + data_device: Where to store the data. Defaults to the device on which training + takes place. When training with a large dataset on a GPU with limited VRAM, + this can be set to 'cpu' to keep the data in system memory instead. Returns: NeuralInference object (returned so that this function is chainable). """ - theta, x = validate_theta_and_x(theta, x, training_device=self._device) + is_valid_x, num_nans, num_infs = handle_invalid_x(x, True) # exclude_invalid_x is hard-coded to True here + + x = x[is_valid_x] + theta = theta[is_valid_x] + + # Check for problematic z-scoring + warn_if_zscoring_changes_data(x) + warn_on_invalid_x(num_nans, num_infs, True) + warn_on_invalid_x_for_snpec_leakage( + num_nans, num_infs, True, type(self).__name__, self._round + ) + + if data_device is None: + data_device = self._device + + theta, x = validate_theta_and_x( + theta, x, data_device=data_device, training_device=self._device + ) self._check_proposal(proposal) if ( @@ -122,7 +147,7 @@ def append_simulations( # with MLE loss or with atomic loss (see, in `train()`: # self._round = max(self._data_round_index)) self._data_round_index.append(0) - self._prior_masks.append(mask_sims_from_prior(0, theta.size(0))) + prior_masks = mask_sims_from_prior(0, theta.size(0)) else: if not self._data_round_index: # This catches a pretty specific case: if, in the first round, one self._data_round_index.append(1) else: self._data_round_index.append(max(self._data_round_index) + 1) - self._prior_masks.append(mask_sims_from_prior(1, theta.size(0))) + prior_masks = mask_sims_from_prior(1, theta.size(0)) self._theta_roundwise.append(theta) self._x_roundwise.append(x) + self._prior_masks.append(prior_masks) + self._proposal_roundwise.append(proposal) if self._prior is None or isinstance(self._prior, ImproperEmpirical): @@ -161,7 +188,6 @@ def train( max_num_epochs: int = 2**31 - 1, clip_max_norm: Optional[float] = 5.0, calibration_kernel: Optional[Callable] = None, - exclude_invalid_x: bool = True, resume_training: bool = False, force_first_round_loss: bool = False, discard_prior_samples: bool = False, @@ -184,8 +210,6 @@ def train( prevent exploding gradients. Use None for no clipping. calibration_kernel: A function to calibrate the loss with respect to the simulations `x`. See Lueckmann, Gonçalves et al., NeurIPS 2017. - exclude_invalid_x: Whether to exclude simulation outputs `x=NaN` or `x=±∞` - during training. Expect errors, silent or explicit, when `False`. resume_training: Can be used in case training time is limited, e.g. on a cluster. If `True`, the split between train and validation set, the optimizer, the number of epochs, and the best validation log-prob will @@ -206,6 +230,9 @@ def train( Returns: Density estimator that approximates the distribution $p(\theta|x)$. """ + # Load data from most recent round. + self._round = max(self._data_round_index) + if self._round == 0 and self._neural_net is not None: assert force_first_round_loss, ( "You have already trained this neural network. After you had trained " @@ -235,13 +262,6 @@ def train( if self.use_non_atomic_loss or hasattr(self, "_ran_final_round"): start_idx = self._round - theta, x, prior_masks = self.get_simulations( - start_idx, exclude_invalid_x, warn_on_invalid=True - ) - - # Dataset is shared for training and validation loaders.
- dataset = data.TensorDataset(theta, x, prior_masks) - # Set the proposal to the last proposal that was passed by the user. For # atomic SNPE, it does not matter what the proposal is. For non-atomic # SNPE, we only use the latest data that was passed, i.e. the one from the @@ -249,31 +269,31 @@ def train( proposal = self._proposal_roundwise[-1] train_loader, val_loader = self.get_dataloaders( - dataset, + start_idx, training_batch_size, validation_fraction, resume_training, dataloader_kwargs=dataloader_kwargs, ) - # First round or if retraining from scratch: # Call the `self._build_neural_net` with the rounds' thetas and xs as # arguments, which will build the neural network. # This is passed into NeuralPosterior, to create a neural posterior which # can `sample()` and `log_prob()`. The network is accessible via `.net`. if self._neural_net is None or retrain_from_scratch: - self._neural_net = self._build_neural_net( - theta[self.train_indices], x[self.train_indices] + + # Get theta,x to initialize NN + theta, x, _ = self.get_simulations(starting_round=start_idx) + self._neural_net = self._build_neural_net(theta.to("cpu"), x.to("cpu")) + self._x_shape = x_shape_from_simulation(x.to("cpu")) + + test_posterior_net_for_multi_d_x( + self._neural_net, + theta.to("cpu"), + x.to("cpu"), ) - # If data on training device already move net as well. - if ( - not self._device == "cpu" - and f"{x.device.type}:{x.device.index}" == self._device - ): - self._neural_net.to(self._device) - test_posterior_net_for_multi_d_x(self._neural_net, theta, x) - self._x_shape = x_shape_from_simulation(x) + del theta, x # Move entire net to device for training. self._neural_net.to(self._device) @@ -365,7 +385,7 @@ def train( self._summary["best_validation_log_probs"].append(self._best_val_log_prob) # Update tensorboard and summary dict. - self._summarize(round_=self._round, x_o=None, theta_bank=theta, x_bank=x) + self._summarize(round_=self._round, x_o=None, theta_bank=None, x_bank=None) # Update description for progress bar. if show_train_summary: diff --git a/sbi/inference/snpe/snpe_c.py b/sbi/inference/snpe/snpe_c.py index 165414908..34c53ab4b 100644 --- a/sbi/inference/snpe/snpe_c.py +++ b/sbi/inference/snpe/snpe_c.py @@ -93,7 +93,6 @@ def train( max_num_epochs: int = 2**31 - 1, clip_max_norm: Optional[float] = 5.0, calibration_kernel: Optional[Callable] = None, - exclude_invalid_x: bool = True, resume_training: bool = False, force_first_round_loss: bool = False, discard_prior_samples: bool = False, @@ -118,8 +117,6 @@ def train( prevent exploding gradients. Use None for no clipping. calibration_kernel: A function to calibrate the loss with respect to the simulations `x`. See Lueckmann, Gonçalves et al., NeurIPS 2017. - exclude_invalid_x: Whether to exclude simulation outputs `x=NaN` or `x=±∞` - during training. Expect errors, silent or explicit, when `False`. resume_training: Can be used in case training time is limited, e.g. on a cluster. 
If `True`, the split between train and validation set, the optimizer, the number of epochs, and the best validation log-prob will diff --git a/sbi/inference/snre/snre_a.py b/sbi/inference/snre/snre_a.py index 1416abd1c..ce06a96e9 100644 --- a/sbi/inference/snre/snre_a.py +++ b/sbi/inference/snre/snre_a.py @@ -55,7 +55,6 @@ def train( stop_after_epochs: int = 20, max_num_epochs: int = 2**31 - 1, clip_max_norm: Optional[float] = 5.0, - exclude_invalid_x: bool = True, resume_training: bool = False, discard_prior_samples: bool = False, retrain_from_scratch: bool = False, @@ -75,8 +74,6 @@ def train( we train until validation loss increases (see also `stop_after_epochs`). clip_max_norm: Value at which to clip the total gradient norm in order to prevent exploding gradients. Use None for no clipping. - exclude_invalid_x: Whether to exclude simulation outputs `x=NaN` or `x=±∞` - during training. Expect errors, silent or explicit, when `False`. resume_training: Can be used in case training time is limited, e.g. on a cluster. If `True`, the split between train and validation set, the optimizer, the number of epochs, and the best validation log-prob will diff --git a/sbi/inference/snre/snre_b.py b/sbi/inference/snre/snre_b.py index d2db9212a..bb53cd09d 100644 --- a/sbi/inference/snre/snre_b.py +++ b/sbi/inference/snre/snre_b.py @@ -56,7 +56,6 @@ def train( stop_after_epochs: int = 20, max_num_epochs: int = 2**31 - 1, clip_max_norm: Optional[float] = 5.0, - exclude_invalid_x: bool = True, resume_training: bool = False, discard_prior_samples: bool = False, retrain_from_scratch: bool = False, @@ -77,8 +76,6 @@ def train( we train until validation loss increases (see also `stop_after_epochs`). clip_max_norm: Value at which to clip the total gradient norm in order to prevent exploding gradients. Use None for no clipping. - exclude_invalid_x: Whether to exclude simulation outputs `x=NaN` or `x=±∞` - during training. Expect errors, silent or explicit, when `False`. resume_training: Can be used in case training time is limited, e.g. on a cluster. If `True`, the split between train and validation set, the optimizer, the number of epochs, and the best validation log-prob will diff --git a/sbi/inference/snre/snre_base.py b/sbi/inference/snre/snre_base.py index 09898527e..f448eedd4 100644 --- a/sbi/inference/snre/snre_base.py +++ b/sbi/inference/snre/snre_base.py @@ -17,7 +17,10 @@ check_estimator_arg, check_prior, clamp_and_warn, + handle_invalid_x, validate_theta_and_x, + warn_if_zscoring_changes_data, + warn_on_invalid_x, x_shape_from_simulation, ) from sbi.utils.sbiutils import mask_sims_from_prior @@ -82,6 +85,7 @@ def append_simulations( theta: Tensor, x: Tensor, from_round: int = 0, + data_device: Optional[str] = None, ) -> "RatioEstimator": r"""Store parameters and simulation outputs to use them for later training. @@ -98,16 +102,35 @@ def append_simulations( With default settings, this is not used at all for `SNRE`. Only when the user later on requests `.train(discard_prior_samples=True)`, we use these indices to find which training data stemmed from the prior. - + data_device: Where to store the data. Defaults to the device on which training + takes place. When training with a large dataset on a GPU with limited VRAM, + this can be set to 'cpu' to keep the data in system memory instead. Returns: NeuralInference object (returned so that this function is chainable).
""" - theta, x = validate_theta_and_x(theta, x, training_device=self._device) + is_valid_x, num_nans, num_infs = handle_invalid_x(x, True) # Hardcode to True + + x = x[is_valid_x] + theta = theta[is_valid_x] + + # Check for problematic z-scoring + warn_if_zscoring_changes_data(x) + warn_on_invalid_x(num_nans, num_infs, True) + + if data_device is None: + data_device = self._device + + theta, x = validate_theta_and_x( + theta, x, data_device=data_device, training_device=self._device + ) + + prior_masks = mask_sims_from_prior(int(from_round), theta.size(0)) self._theta_roundwise.append(theta) self._x_roundwise.append(x) - self._prior_masks.append(mask_sims_from_prior(int(from_round), theta.size(0))) + self._prior_masks.append(prior_masks) + self._data_round_index.append(int(from_round)) return self @@ -121,7 +144,6 @@ def train( stop_after_epochs: int = 20, max_num_epochs: int = 2**31 - 1, clip_max_norm: Optional[float] = 5.0, - exclude_invalid_x: bool = True, resume_training: bool = False, discard_prior_samples: bool = False, retrain_from_scratch: bool = False, @@ -149,20 +171,13 @@ def train( Returns: Classifier that approximates the ratio $p(\theta,x)/p(\theta)p(x)$. """ - - # Starting index for the training set (1 = discard round-0 samples). - start_idx = int(discard_prior_samples and self._round > 0) # Load data from most recent round. self._round = max(self._data_round_index) - theta, x, _ = self.get_simulations( - start_idx, exclude_invalid_x, warn_on_invalid=True - ) - - # Dataset is shared for training and validation loaders. - dataset = data.TensorDataset(theta, x) + # Starting index for the training set (1 = discard round-0 samples). + start_idx = int(discard_prior_samples and self._round > 0) train_loader, val_loader = self.get_dataloaders( - dataset, + start_idx, training_batch_size, validation_fraction, resume_training, @@ -183,11 +198,12 @@ def train( # This is passed into NeuralPosterior, to create a neural posterior which # can `sample()` and `log_prob()`. The network is accessible via `.net`. if self._neural_net is None or retrain_from_scratch: - self._neural_net = self._build_neural_net( - theta[self.train_indices], x[self.train_indices] - ) - self._x_shape = x_shape_from_simulation(x) + # Get theta,x to initialize NN + theta, x, _ = self.get_simulations(starting_round=start_idx) + self._neural_net = self._build_neural_net(theta.to("cpu"), x.to("cpu")) + self._x_shape = x_shape_from_simulation(x.to("cpu")) + del x, theta self._neural_net.to(self._device) if not resume_training: @@ -260,8 +276,8 @@ def train( self._summarize( round_=self._round, x_o=None, - theta_bank=theta, - x_bank=x, + theta_bank=None, + x_bank=None, ) # Update description for progress bar. diff --git a/sbi/utils/user_input_checks.py b/sbi/utils/user_input_checks.py index 66f5a5ec7..8bd46c4e2 100644 --- a/sbi/utils/user_input_checks.py +++ b/sbi/utils/user_input_checks.py @@ -647,7 +647,7 @@ def check_estimator_arg(estimator: Union[str, Callable]) -> None: def validate_theta_and_x( - theta: Any, x: Any, training_device: str = "cpu" + theta: Any, x: Any, data_device: str = "cpu", training_device: str = "cpu" ) -> Tuple[Tensor, Tensor]: r""" Checks if the passed $(\theta, x)$ are valid. @@ -657,6 +657,10 @@ def validate_theta_and_x( 2) If they have the same batchsize. 3) If they are of `dtype=float32`. + Additionally, We move the data to the specified `data_device`. This is where the + data is stored and can be separate from `training_device`, where the + computations for training are performed. 
+ Raises: AssertionError: If theta or x are not torch.Tensor-like, do not yield the same batchsize and do not have dtype==float32. @@ -664,6 +668,7 @@ def validate_theta_and_x( Args: theta: Parameters. x: Simulation outputs. + data_device: Device where data is stored. training_device: Training device for net. """ assert isinstance(theta, Tensor), "Parameters theta must be a `torch.Tensor`." @@ -679,21 +684,21 @@ def validate_theta_and_x( assert theta.dtype == float32, "Type of parameters must be float32." assert x.dtype == float32, "Type of simulator outputs must be float32." - if str(x.device) != training_device: + if str(x.device) != data_device: warnings.warn( - f"Data x has device '{x.device}' " - f"different from the training_device '{training_device}', " - f"moving x to the training_device '{training_device}'." + f"Data x has device '{x.device}'. " + f"Moving x to the data_device '{data_device}'. " + f"Training will proceed on device '{training_device}'." ) - x = x.to(training_device) + x = x.to(data_device) - if str(theta.device) != training_device: + if str(theta.device) != data_device: warnings.warn( - f"Parameters theta has device '{theta.device}' " - f"different from the training_device '{training_device}', " - f"moving theta to the training_device '{training_device}'." + f"Parameters theta has device '{theta.device}'. " + f"Moving theta to the data_device '{data_device}'. " + f"Training will proceed on device '{training_device}'." ) - theta = theta.to(training_device) + theta = theta.to(data_device) return theta, x diff --git a/tests/base_test.py b/tests/base_test.py index ca78fce97..c1cef1131 100644 --- a/tests/base_test.py +++ b/tests/base_test.py @@ -1,6 +1,5 @@ import pytest import torch -from torch.utils.data import TensorDataset from sbi.inference import SNPE @@ -11,12 +10,10 @@ def test_get_dataloaders(training_batch_size): N = 1000 validation_fraction = 0.1 - dataset = TensorDataset(torch.ones(N), torch.zeros(N)) - inferer = SNPE() - + inferer.append_simulations(torch.ones(N), torch.zeros(N)) _, val_loader = inferer.get_dataloaders( - dataset, + 0, training_batch_size=training_batch_size, validation_fraction=validation_fraction, ) diff --git a/tests/inference_with_NaN_simulator_test.py b/tests/inference_with_NaN_simulator_test.py index 368c2be1d..a9d5378cc 100644 --- a/tests/inference_with_NaN_simulator_test.py +++ b/tests/inference_with_NaN_simulator_test.py @@ -102,9 +102,9 @@ def linear_gaussian_nan( inference = method(prior=prior) theta, x = simulate_for_sbi(simulator, prior, num_simulations) - _ = inference.append_simulations(theta, x).train( - exclude_invalid_x=exclude_invalid_x - ) + _ = inference.append_simulations( + theta, x, exclude_invalid_x=exclude_invalid_x + ).train() posterior = inference.build_posterior() samples = posterior.sample((num_samples,), x=x_o)
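Taken together, this patch moves NaN/inf filtering from `train()` into `append_simulations()`, adds a `data_device` argument for keeping training data off the training device, and lets `get_dataloaders()` fetch simulations itself via a `starting_round` index. Below is a minimal usage sketch of the resulting API, not part of the patch: the toy prior and "simulator" are made up for illustration, and the constructor's `device` argument is assumed from the existing `sbi` interface.

import torch
from torch.distributions import MultivariateNormal

from sbi.inference import SNPE

# Toy prior and stand-in "simulator" (illustrative only).
prior = MultivariateNormal(torch.zeros(3), torch.eye(3))
theta = prior.sample((1000,))
x = theta + 0.1 * torch.randn_like(theta)

inference = SNPE(prior=prior, device="cuda" if torch.cuda.is_available() else "cpu")

# Invalid simulations (NaN/inf) are now filtered here rather than in `train()`,
# and `data_device="cpu"` keeps theta/x in system memory even if the network
# itself is trained on the GPU.
inference.append_simulations(theta, x, data_device="cpu")

# `exclude_invalid_x` no longer exists on `train()`; the data were already
# filtered when they were appended above.
density_estimator = inference.train(training_batch_size=50)
posterior = inference.build_posterior(density_estimator)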