diff --git a/docs/docs/faq/question_04.md b/docs/docs/faq/question_04.md
index c8e785b2a..b8963a5bb 100644
--- a/docs/docs/faq/question_04.md
+++ b/docs/docs/faq/question_04.md
@@ -1,18 +1,23 @@
 # Can I use the GPU for training the density estimator?
 
-TLDR; Yes, by passing `device="cuda"`. But no speed-ups for default density estimators.
+TLDR; Yes, by passing `device="cuda"` and by passing a prior that lives on the device
+you passed. But no speed-ups for default density estimators.
 
 Yes. When creating the inference object in the flexible interface, you can pass the
 `device` as an argument, e.g.,
 
 ```python
-inference = SNPE(simulator, prior, device="cuda", density_estimator="maf")
+inference = SNPE(prior, device="cuda", density_estimator="maf")
 ```
 
 The device is set to `"cpu"` by default, and it can be set to anything, as long as it
 maps to an existing PyTorch CUDA device. `sbi` will take care of copying the `net` and
 the training data to and from the `device`.
 
+Note that the prior must be on the training device already, e.g., when passing `device="cuda:0"`,
+make sure to pass a prior object that was created on that device, e.g.,
+`prior = torch.distributions.MultivariateNormal(loc=torch.zeros(2, device="cuda:0"),
+ covariance_matrix=torch.eye(2, device="cuda:0"))`.
 
 ## Performance
diff --git a/sbi/analysis/conditional_density.py b/sbi/analysis/conditional_density.py
index 7907b0196..9ad73ac21 100644
--- a/sbi/analysis/conditional_density.py
+++ b/sbi/analysis/conditional_density.py
@@ -46,8 +46,8 @@ def eval_conditional_density(
         eps_margins2: We will evaluate the posterior along `dim2` at
             `limits[0]+eps_margins` until `limits[1]-eps_margins`. This avoids
             evaluations potentially exactly at the prior bounds.
-        return_raw_log_prob: If `True`, return the log-probability evaluated on the·
-            grid. If `False`, return the probability, scaled down by the maximum value·
+        return_raw_log_prob: If `True`, return the log-probability evaluated on the
+            grid. If `False`, return the probability, scaled down by the maximum value
            on the grid for numerical stability (i.e. exp(log_prob - max_log_prob)).
 
     Returns:
         Conditional probabilities. If `dim1 == dim2`, this will have shape
diff --git a/sbi/inference/base.py b/sbi/inference/base.py
index 0ee2e8c45..128a48b08 100644
--- a/sbi/inference/base.py
+++ b/sbi/inference/base.py
@@ -115,7 +115,7 @@ def __init__(
             0.14.0 is more mature, we will remove this argument.
         """
 
-        self._device = process_device(device)
+        self._device = process_device(device, prior=prior)
 
         if unused_args:
             warn(
@@ -381,9 +381,7 @@ def _default_summary_writer(self) -> SummaryWriter:
         method = self.__class__.__name__
 
         logdir = Path(
-            get_log_root(),
-            method,
-            datetime.now().isoformat().replace(":", "_"),
+            get_log_root(), method, datetime.now().isoformat().replace(":", "_")
         )
         return SummaryWriter(logdir)
 
@@ -437,11 +435,7 @@ def _report_convergence_at_end(
         )
 
     def _summarize(
-        self,
-        round_: int,
-        x_o: Union[Tensor, None],
-        theta_bank: Tensor,
-        x_bank: Tensor,
+        self, round_: int, x_o: Union[Tensor, None], theta_bank: Tensor, x_bank: Tensor
     ) -> None:
         """Update the summary_writer with statistics for a given round.
 
@@ -456,12 +450,7 @@ def _summarize(
 
         # Median |x - x0| for most recent round.
if x_o is not None: median_observation_distance = torch.median( - torch.sqrt( - torch.sum( - (x_bank - x_o.reshape(1, -1)) ** 2, - dim=-1, - ) - ) + torch.sqrt(torch.sum((x_bank - x_o.reshape(1, -1)) ** 2, dim=-1)) ) self._summary["median_observation_distances"].append( median_observation_distance.item() @@ -558,11 +547,7 @@ def simulate_for_sbi( theta = proposal.sample((num_simulations,)) x = simulate_in_batches( - simulator, - theta, - simulation_batch_size, - num_workers, - show_progress_bar, + simulator, theta, simulation_batch_size, num_workers, show_progress_bar ) return theta, x diff --git a/sbi/inference/posteriors/base_posterior.py b/sbi/inference/posteriors/base_posterior.py index 97ef95087..20b8c0788 100644 --- a/sbi/inference/posteriors/base_posterior.py +++ b/sbi/inference/posteriors/base_posterior.py @@ -149,7 +149,9 @@ def set_default_x(self, x: Tensor) -> "NeuralPosterior": Returns: `NeuralPosterior` that will use a default `x` when not explicitly passed. """ - self._x = process_x(x, self._x_shape, allow_iid_x=self._allow_iid_x) + self._x = process_x(x, self._x_shape, allow_iid_x=self._allow_iid_x).to( + self._device + ) self._num_iid_trials = self._x.shape[0] return self @@ -358,12 +360,10 @@ def _prepare_theta_and_x_for_log_prob_( self._ensure_single_x(x) self._ensure_x_consistent_with_default_x(x) - return theta, x + return theta.to(self._device), x.to(self._device) def _prepare_for_sample( - self, - x: Tensor, - sample_shape: Optional[Tensor], + self, x: Tensor, sample_shape: Optional[Tensor] ) -> Tuple[Tensor, int]: r""" Return checked, reshaped, potentially default values for `x` and `sample_shape`. @@ -835,7 +835,8 @@ def map( def potential_fn(theta): return self.log_prob(theta, x=x, track_gradients=True, **log_prob_kwargs) - interruption_note = "The last estimate of the MAP can be accessed via the `posterior.map_` attribute." + interruption_note = """The last estimate of the MAP can be accessed via the + `posterior.map_` attribute.""" self.map_, _ = optimize_potential_fn( potential_fn=potential_fn, @@ -1174,7 +1175,7 @@ def np_potential(self, theta: np.ndarray) -> ScalarFloat: theta_condition = deepcopy(self.condition) theta_condition[:, self.dims_to_sample] = theta - return self.potential_fn_provider.np_potential( + return self.potential_fn_provider.posterior_potential( utils.tensor2numpy(theta_condition) ) diff --git a/sbi/inference/posteriors/direct_posterior.py b/sbi/inference/posteriors/direct_posterior.py index 932659abf..a055ce3fd 100644 --- a/sbi/inference/posteriors/direct_posterior.py +++ b/sbi/inference/posteriors/direct_posterior.py @@ -82,10 +82,7 @@ def __init__( device: Training device, e.g., cpu or cuda:0 """ - kwargs = del_entries( - locals(), - entries=("self", "__class__"), - ) + kwargs = del_entries(locals(), entries=("self", "__class__")) super().__init__(**kwargs) self._purpose = ( @@ -183,9 +180,7 @@ def log_prob( with torch.set_grad_enabled(track_gradients): # Evaluate on device, move back to cpu for comparison with prior. - unnorm_log_prob = self.net.log_prob( - theta_repeated.to(self._device), x_repeated.to(self._device) - ).cpu() + unnorm_log_prob = self.net.log_prob(theta_repeated, x_repeated) # Force probability to be zero outside prior support. 
             in_prior_support = within_support(self._prior, theta)
@@ -193,7 +188,7 @@
             masked_log_prob = torch.where(
                 in_prior_support,
                 unnorm_log_prob,
-                torch.tensor(float("-inf"), dtype=torch.float32),
+                torch.tensor(float("-inf"), dtype=torch.float32, device=self._device),
             )
 
             if leakage_correction_params is None:
@@ -556,11 +551,7 @@ class PotentialFunctionProvider:
     """
 
     def __call__(
-        self,
-        prior,
-        posterior_nn: nn.Module,
-        x: Tensor,
-        method: str,
+        self, prior, posterior_nn: nn.Module, x: Tensor, method: str
     ) -> Callable:
         """Return potential function.
 
@@ -583,59 +574,47 @@ def __call__(
             NotImplementedError
 
     def posterior_potential(
-        self, theta: np.ndarray, track_gradients: bool = False
-    ) -> ScalarFloat:
-        r"""Return posterior theta log prob. $p(\theta|x)$, $-\infty$ if outside prior."
-
-        Args:
-            theta: Parameters $\theta$, batch dimension 1.
+        self, theta: Union[Tensor, np.array], track_gradients: bool = False
+    ) -> Tensor:
+        r"Return posterior theta log prob. $p(\theta|x)$, $-\infty$ if outside prior."
-
-        Returns:
-            Posterior log probability $\log(p(\theta|x))$.
-        """
-        theta = torch.as_tensor(theta, dtype=torch.float32)
-        theta = ensure_theta_batched(theta)
-        num_batch = theta.shape[0]
+        # Device is the same for net and prior.
+        theta = ensure_theta_batched(torch.as_tensor(theta, dtype=torch.float32)).to(
+            self.device
+        )
-
-        # Repeat x over batch dim to match theta batch, accounting for multi-D x.
-        x_repeated = self.x.repeat(num_batch, *(1 for _ in range(self.x.ndim - 1)))
+        theta_repeated, x_repeated = DirectPosterior._match_theta_and_x_batch_shapes(
+            theta, self.x
+        )
 
         with torch.set_grad_enabled(track_gradients):
-            target_log_prob = self.posterior_nn.log_prob(
-                inputs=theta.to(self.device),
-                context=x_repeated,
-            )
+
+            # Evaluate on device; the prior lives on the same device.
+            posterior_log_prob = self.posterior_nn.log_prob(theta_repeated, x_repeated)
+
+            # Force probability to be zero outside prior support.
             in_prior_support = within_support(self.prior, theta)
-            target_log_prob[~in_prior_support] = -float("Inf")
-            return target_log_prob
+            posterior_log_prob = torch.where(
+                in_prior_support,
+                posterior_log_prob,
+                torch.tensor(float("-inf"), dtype=torch.float32, device=self.device),
+            )
+
+        return posterior_log_prob
 
     def pyro_potential(
         self, theta: Dict[str, Tensor], track_gradients: bool = False
     ) -> Tensor:
-        r"""Return posterior log prob. of theta $p(\theta|x)$, -inf where outside prior.
+        r"""Return negative posterior log prob. of theta $p(\theta|x)$, masked outside of prior.
 
         Args:
             theta: Parameters $\theta$ (from pyro sampler).
 
         Returns:
-            Posterior log probability $p(\theta|x)$, masked outside of prior.
+            Negative posterior log probability $p(\theta|x)$, masked outside of prior.
         """
         theta = next(iter(theta.values()))
 
-        with torch.set_grad_enabled(track_gradients):
-            # Notice opposite sign to `posterior_potential`.
-            # Move theta to device for evaluation.
- log_prob_posterior = -self.posterior_nn.log_prob( - inputs=theta.to(self.device), - context=self.x, - ).cpu() - - in_prior_support = within_support(self.prior, theta) - - return torch.where( - in_prior_support, - log_prob_posterior, - float("-inf") * torch.ones_like(log_prob_posterior), - ) + return -self.posterior_potential(theta, track_gradients=track_gradients) diff --git a/sbi/inference/posteriors/likelihood_based_posterior.py b/sbi/inference/posteriors/likelihood_based_posterior.py index 4797d5b49..119862b27 100644 --- a/sbi/inference/posteriors/likelihood_based_posterior.py +++ b/sbi/inference/posteriors/likelihood_based_posterior.py @@ -120,7 +120,7 @@ def log_prob( ) # Move to cpu for comparison with prior. - return log_likelihood_trial_sum.cpu() + self._prior.log_prob(theta) + return log_likelihood_trial_sum + self._prior.log_prob(theta) def sample( self, @@ -355,10 +355,7 @@ def map( @staticmethod def _log_likelihoods_over_trials( - x: Tensor, - theta: Tensor, - net: nn.Module, - track_gradients: bool = False, + x: Tensor, theta: Tensor, net: nn.Module, track_gradients: bool = False ) -> Tensor: r"""Return log likelihoods summed over iid trials of `x`. @@ -423,11 +420,7 @@ class PotentialFunctionProvider: """ def __call__( - self, - prior, - likelihood_nn: nn.Module, - x: Tensor, - method: str, + self, prior, likelihood_nn: nn.Module, x: Tensor, method: str ) -> Callable: r"""Return potential function for posterior $p(\theta|x)$. @@ -459,35 +452,23 @@ def __call__( else: NotImplementedError - def log_likelihood(self, theta: Tensor, track_gradients: bool = False) -> Tensor: + def posterior_potential( + self, theta: Union[Tensor, np.array], track_gradients: bool = False + ) -> Tensor: """Return log likelihood of fixed data given a batch of parameters.""" + # Device is the same for net and prior. + theta = ensure_theta_batched(torch.as_tensor(theta, dtype=torch.float32)).to( + self.device + ) + log_likelihoods = LikelihoodBasedPosterior._log_likelihoods_over_trials( x=self.x, - theta=ensure_theta_batched(theta).to(self.device), + theta=theta, net=self.likelihood_nn, track_gradients=track_gradients, ) - - return log_likelihoods - - def posterior_potential( - self, theta: np.array, track_gradients: bool = False - ) -> ScalarFloat: - r"""Return posterior log prob. of theta $p(\theta|x)$" - - Args: - theta: Parameters $\theta$, batch dimension 1. - - Returns: - Posterior log probability of the theta, $-\infty$ if impossible under prior. - """ - theta = torch.as_tensor(theta, dtype=torch.float32) - - # Notice opposite sign to pyro potential. - return self.log_likelihood( - theta, track_gradients=track_gradients - ).cpu() + self.prior.log_prob(theta) + return log_likelihoods + self.prior.log_prob(theta) def pyro_potential( self, theta: Dict[str, Tensor], track_gradients: bool = False @@ -505,7 +486,5 @@ def pyro_potential( theta = next(iter(theta.values())) - return -( - self.log_likelihood(theta, track_gradients=track_gradients).cpu() - + self.prior.log_prob(theta) - ) + # Note the minus to match the pyro potential function requirements. 
+ return -self.posterior_potential(theta, track_gradients=track_gradients) diff --git a/sbi/inference/posteriors/ratio_based_posterior.py b/sbi/inference/posteriors/ratio_based_posterior.py index dc99214b8..8f7b9c14f 100644 --- a/sbi/inference/posteriors/ratio_based_posterior.py +++ b/sbi/inference/posteriors/ratio_based_posterior.py @@ -117,7 +117,7 @@ def log_prob( track_gradients=track_gradients, ) - return log_ratio.cpu() + self._prior.log_prob(theta) + return log_ratio + self._prior.log_prob(theta) def _warn_log_prob_snre(self) -> None: if self._method_family == "snre_a": @@ -390,10 +390,7 @@ def _num_trained_rounds(self, trained_rounds: int) -> None: @staticmethod def _log_ratios_over_trials( - x: Tensor, - theta: Tensor, - net: nn.Module, - track_gradients: bool = False, + x: Tensor, theta: Tensor, net: nn.Module, track_gradients: bool = False ) -> Tensor: r"""Return log ratios summed over iid trials of `x`. @@ -453,11 +450,7 @@ class PotentialFunctionProvider: """ def __call__( - self, - prior, - classifier: nn.Module, - x: Tensor, - method: str, + self, prior, classifier: nn.Module, x: Tensor, method: str ) -> Callable: r"""Return potential function for posterior $p(\theta|x)$. @@ -491,8 +484,8 @@ def __call__( NotImplementedError def posterior_potential( - self, theta: np.array, track_gradients: bool = False - ) -> ScalarFloat: + self, theta: Union[Tensor, np.array], track_gradients: bool = False + ) -> Tensor: """Returns the unnormalized posterior log-probability. This is the potential used in the numpy slice sampler and in rejection sampling. @@ -503,18 +496,19 @@ def posterior_potential( Returns: Posterior log probability of theta. """ - theta = torch.as_tensor(theta, dtype=torch.float32) - theta = ensure_theta_batched(theta) + + # Device is the same for net and prior. + theta = ensure_theta_batched(torch.as_tensor(theta, dtype=torch.float32)).to( + self.device + ) log_ratio = RatioBasedPosterior._log_ratios_over_trials( self.x, - theta.to(self.device), + theta, self.classifier, track_gradients=track_gradients, ) - - # Notice opposite sign to pyro potential. - return log_ratio.cpu() + self.prior.log_prob(theta) + return log_ratio + self.prior.log_prob(theta) def pyro_potential( self, theta: Dict[str, Tensor], track_gradients: bool = False @@ -534,14 +528,5 @@ def pyro_potential( theta = next(iter(theta.values())) - # Theta and x should have shape (1, dim). - theta = ensure_theta_batched(theta) - - log_ratio = RatioBasedPosterior._log_ratios_over_trials( - self.x, - theta.to(self.device), - self.classifier, - track_gradients=track_gradients, - ) - - return -(log_ratio.cpu() + self.prior.log_prob(theta)) + # Note the minus to match the pyro potential function requirements. + return -self.posterior_potential(theta, track_gradients=track_gradients) diff --git a/sbi/inference/snpe/snpe_a.py b/sbi/inference/snpe/snpe_a.py index 2f4e7a621..d3a18e6a6 100644 --- a/sbi/inference/snpe/snpe_a.py +++ b/sbi/inference/snpe/snpe_a.py @@ -95,12 +95,7 @@ def __init__( # to pass arguments between functions, and that's implicit state management. kwargs = utils.del_entries( locals(), - entries=( - "self", - "__class__", - "unused_args", - "num_components", - ), + entries=("self", "__class__", "unused_args", "num_components"), ) super().__init__(**kwargs) @@ -425,10 +420,7 @@ def log_prob(self, inputs, context=None): # Compute the log_prob of theta under the product. 
log_prob_proposal_posterior = utils.sbiutils.mog_log_prob( - theta, - logits_pp, - m_pp, - prec_pp, + theta, logits_pp, m_pp, prec_pp ) utils.assert_all_finite( log_prob_proposal_posterior, "proposal posterior eval" @@ -527,12 +519,7 @@ def _posthoc_correction(self, x: Tensor): # Compute the MoG parameters of the posterior. logits_p, m_p, prec_p, cov_p = self._proposal_posterior_transformation( - logits_pp, - m_pp, - prec_pp, - norm_logits_d, - m_d, - prec_d, + logits_pp, m_pp, prec_pp, norm_logits_d, m_d, prec_d ) return logits_p, m_p, prec_p @@ -575,11 +562,7 @@ def _proposal_posterior_transformation( ) means_post = self._means_posterior( - covariances_post, - means_pp, - precisions_pp, - means_d, - precisions_d, + covariances_post, means_pp, precisions_pp, means_d, precisions_d ) logits_post = SNPE_A_MDN._logits_posterior( @@ -645,9 +628,8 @@ def _set_maybe_z_scored_prior(self) -> None: prior will not be exactly have mean=0 and std=1. """ if self.z_score_theta: - # Default to cpu to avoid mismatch with prior (always on cpu) below. - scale = self._neural_net._transform._transforms[0]._scale.cpu() - shift = self._neural_net._transform._transforms[0]._shift.cpu() + scale = self._neural_net._transform._transforms[0]._scale + shift = self._neural_net._transform._transforms[0]._shift # Following the definition of the linear transform in # `standardizing_transform` in `sbiutils.py`: diff --git a/sbi/mcmc/slice.py b/sbi/mcmc/slice.py index 21309d624..1389fe587 100644 --- a/sbi/mcmc/slice.py +++ b/sbi/mcmc/slice.py @@ -154,17 +154,20 @@ def _log_prob_d(x): ) ).unsqueeze( 0 - ) # TODO: The unsqueeze seems to give a speed up, figure out when this is the case exactly + ) # TODO: The unsqueeze seems to give a speed up, figure out when + # this is the case exactly } ) # Sample uniformly from slice log_height = _log_prob_d(params[self._site_name].view(-1)[dim]) + torch.log( - torch.rand(1) + torch.rand(1, device=params[self._site_name].device) ) # Position the bracket randomly around the current sample - lower = params[self._site_name].view(-1)[dim] - self._width[dim] * torch.rand(1) + lower = params[self._site_name].view(-1)[dim] - self._width[dim] * torch.rand( + 1, device=params[self._site_name].device + ) upper = lower + self._width[dim] # Find lower bracket end @@ -182,7 +185,9 @@ def _log_prob_d(x): upper += self._width[dim] # Sample uniformly from bracket - new_parameter = (upper - lower) * torch.rand(1) + lower + new_parameter = (upper - lower) * torch.rand( + 1, device=params[self._site_name].device + ) + lower # If outside slice, reject sample and shrink bracket while _log_prob_d(new_parameter) < log_height: @@ -190,6 +195,8 @@ def _log_prob_d(x): lower = new_parameter else: upper = new_parameter - new_parameter = (upper - lower) * torch.rand(1) + lower + new_parameter = (upper - lower) * torch.rand( + 1, device=params[self._site_name].device + ) + lower return new_parameter, upper - lower diff --git a/sbi/utils/metrics.py b/sbi/utils/metrics.py index 82115b8a6..59de7bd14 100644 --- a/sbi/utils/metrics.py +++ b/sbi/utils/metrics.py @@ -9,6 +9,8 @@ from sklearn.neural_network import MLPClassifier from torch import Tensor +from sbi.utils import tensor2numpy + def c2st( X: Tensor, @@ -45,8 +47,8 @@ def c2st( X += noise_scale * torch.randn(X.shape) Y += noise_scale * torch.randn(Y.shape) - X = X.cpu().numpy() - Y = Y.cpu().numpy() + X = tensor2numpy(X) + Y = tensor2numpy(Y) ndim = X.shape[1] @@ -59,12 +61,7 @@ def c2st( ) data = np.concatenate((X, Y)) - target = np.concatenate( - ( - 
np.zeros((X.shape[0],)),
-            np.ones((Y.shape[0],)),
-        )
-    )
+    target = np.concatenate((np.zeros((X.shape[0],)), np.ones((Y.shape[0],))))
 
     shuffle = KFold(n_splits=n_folds, shuffle=True, random_state=seed)
     scores = cross_val_score(clf, data, target, cv=shuffle, scoring=scoring)
diff --git a/sbi/utils/sbiutils.py b/sbi/utils/sbiutils.py
index e67f73f4d..0ba3a2b30 100644
--- a/sbi/utils/sbiutils.py
+++ b/sbi/utils/sbiutils.py
@@ -217,10 +217,8 @@ def rejection_sample_posterior_within_prior(
     while num_remaining > 0:
 
         # Sample and reject.
-        candidates = (
-            posterior_nn.sample(sampling_batch_size, context=x)
-            .reshape(sampling_batch_size, -1)
-            .cpu()
+        candidates = posterior_nn.sample(sampling_batch_size, context=x).reshape(
+            sampling_batch_size, -1
         )
 
         # SNPE-style rejection-sampling when the proposal is the neural net.
@@ -332,7 +330,7 @@ def rejection_sample(
 
     # Define a potential as the ratio between target distribution and proposal.
     def potential_over_proposal(theta):
-        return potential_fn(theta) - proposal.log_prob(theta).cpu()
+        return potential_fn(theta) - proposal.log_prob(theta)
 
     # Search for the maximum of the ratio.
     _, max_log_ratio = optimize_potential_fn(
@@ -398,14 +396,11 @@ def log_prob(self, theta: Tensor, **kwargs) -> Tensor:
                 sampling_batch_size, -1
             )
 
-            # `candidates` will lie on CPU if the proposal is the prior and
-            # possibly on the GPU if the proposal is a neural net. Everything returned
-            # by the `potential_fn` will lie on the CPU.
             target_proposal_ratio = torch.exp(
-                potential_fn(candidates) - proposal.log_prob(candidates).cpu()
+                potential_fn(candidates) - proposal.log_prob(candidates)
             )
             uniform_rand = torch.rand(target_proposal_ratio.shape)
-            samples = candidates.cpu()[target_proposal_ratio > uniform_rand]
+            samples = candidates[target_proposal_ratio > uniform_rand]
 
             accepted.append(samples)
 
@@ -551,9 +546,7 @@ def check_warn_and_setstate(
 
 
 def get_simulations_since_round(
-    data: List,
-    data_round_indices: List,
-    starting_round_index: int,
+    data: List, data_round_indices: List, starting_round_index: int
 ) -> Tensor:
     """
     Returns tensor with all data coming from a round >= `starting_round`.
@@ -919,10 +912,7 @@ def log_prob(self, value: Tensor) -> Tensor:
 
 
 def mog_log_prob(
-    theta: Tensor,
-    logits_pp: Tensor,
-    means_pp: Tensor,
-    precisions_pp: Tensor,
+    theta: Tensor, logits_pp: Tensor, means_pp: Tensor, precisions_pp: Tensor
 ) -> Tensor:
     r"""
     Returns the log-probability of parameter sets $\theta$ under a mixture of Gaussians.
diff --git a/sbi/utils/torchutils.py b/sbi/utils/torchutils.py
index 880b51677..dd279c4eb 100644
--- a/sbi/utils/torchutils.py
+++ b/sbi/utils/torchutils.py
@@ -4,7 +4,7 @@
 """Various PyTorch utility functions."""
 
 import warnings
-from typing import Union
+from typing import Any, Optional, Union
 
 import numpy as np
 import torch
@@ -15,8 +15,11 @@
 from sbi.types import Array, OneOrMore, ScalarFloat
 
 
-def process_device(device: str) -> str:
-    """Set and return the default device to cpu or gpu."""
+def process_device(device: str, prior: Optional[Any] = None) -> str:
+    """Set and return the default device to cpu or gpu.
+
+    Throws an AssertionError if the prior is not on the training device.
+ """ if not device == "cpu": if device == "gpu": @@ -34,6 +37,16 @@ def process_device(device: str) -> str: warnings.warn(f"Device {device} not available, falling back to CPU.") device = "cpu" + if prior is not None: + prior_device = prior.sample((1,)).device + training_device = torch.zeros(1, device=device).device + assert ( + prior_device == training_device + ), f"""Prior ({prior_device}) device must match training device ( + {training_device}). When training on GPU make sure to pass a prior + initialized on the GPU as well, e.g., `prior = torch.distributions.Normal + (torch.zeros(2, device='cuda'), scale=1.0)`.""" + return device @@ -63,7 +76,8 @@ def split_leading_dim(x, shape): def merge_leading_dims(x, num_dims): - """Reshapes the tensor `x` such that the first `num_dims` dimensions are merged to one.""" + """Reshapes the tensor `x` such that the first `num_dims` dimensions are merged to + one.""" if not utils.is_positive_int(num_dims): raise TypeError("Number of leading dims must be a positive integer.") if num_dims > x.dim(): @@ -209,6 +223,7 @@ def __init__( low: ScalarFloat, high: ScalarFloat, reinterpreted_batch_ndims: int = 1, + device: str = "cpu", ): """Multidimensional uniform distribution defined on a box. @@ -227,12 +242,14 @@ def __init__( high: upper range (exclusive). reinterpreted_batch_ndims (int): the number of batch dims to reinterpret as event dims. + device: device of the prior, defaults to "cpu", should match the training + device when used in SBI. """ super().__init__( Uniform( - low=torch.as_tensor(low, dtype=torch.float32), - high=torch.as_tensor(high, dtype=torch.float32), + low=torch.as_tensor(low, dtype=torch.float32, device=device), + high=torch.as_tensor(high, dtype=torch.float32, device=device), validate_args=False, ), reinterpreted_batch_ndims, diff --git a/sbi/utils/user_input_checks_utils.py b/sbi/utils/user_input_checks_utils.py index 1daf0bf18..b1daae455 100644 --- a/sbi/utils/user_input_checks_utils.py +++ b/sbi/utils/user_input_checks_utils.py @@ -58,7 +58,7 @@ def _set_mean_and_variance(self): torch.as_tensor(self.custom_prior.sample((1000,))), dim=0 ) warnings.warn( - "Prior is lacking mean attribute, estimating prior mean from samples...", + "Prior is lacking mean attribute, estimating prior mean from samples.", UserWarning, ) if hasattr(self.custom_prior, "variance"): @@ -174,9 +174,7 @@ class MultipleIndependent(Distribution): Uniform(torch.ones(1), 2.0 * torch.ones(1))] """ - def __init__( - self, dists: Sequence[Distribution], validate_args=None, - ): + def __init__(self, dists: Sequence[Distribution], validate_args=None): self._check_distributions(dists) self.dists = dists diff --git a/tests/abc_test.py b/tests/abc_test.py index d4c8522a1..b36b5c276 100644 --- a/tests/abc_test.py +++ b/tests/abc_test.py @@ -14,10 +14,7 @@ @pytest.mark.parametrize("num_dim", (1, 2)) def test_mcabc_inference_on_linear_gaussian( - num_dim, - lra=False, - sass=False, - sass_expansion_degree=1, + num_dim, lra=False, sass=False, sass_expansion_degree=1 ): x_o = zeros((1, num_dim)) num_samples = 1000 @@ -58,10 +55,7 @@ def simulator(theta): def test_mcabc_sass_lra(lra, sass_expansion_degree, set_seed): test_mcabc_inference_on_linear_gaussian( - num_dim=2, - lra=lra, - sass=True, - sass_expansion_degree=sass_expansion_degree, + num_dim=2, lra=lra, sass=True, sass_expansion_degree=sass_expansion_degree ) @@ -86,11 +80,7 @@ def test_smcabc_inference_on_linear_gaussian( elif prior_type == "uniform": prior = BoxUniform(-ones(num_dim), ones(num_dim)) target_samples = 
samples_true_posterior_linear_gaussian_uniform_prior( - x_o[0], - likelihood_shift, - likelihood_cov, - prior, - num_samples, + x_o[0], likelihood_shift, likelihood_cov, prior, num_samples ) else: raise ValueError("Wrong prior string.") @@ -124,5 +114,9 @@ def simulator(theta): def test_smcabc_sass_lra(lra, sass_expansion_degree, set_seed): test_smcabc_inference_on_linear_gaussian( - num_dim=2, lra=lra, sass=True, sass_expansion_degree=sass_expansion_degree + num_dim=2, + lra=lra, + sass=True, + sass_expansion_degree=sass_expansion_degree, + prior_type="gaussian", ) diff --git a/tests/inference_on_device_test.py b/tests/inference_on_device_test.py index a3c7437e3..711eb60c4 100644 --- a/tests/inference_on_device_test.py +++ b/tests/inference_on_device_test.py @@ -11,32 +11,47 @@ from sbi import utils as utils from sbi.inference import SNLE, SNPE_A, SNPE_C, SNRE_A, SNRE_B, simulate_for_sbi from sbi.simulators import linear_gaussian -from sbi.utils.torchutils import process_device +from sbi.utils.torchutils import BoxUniform, process_device @pytest.mark.slow @pytest.mark.gpu @pytest.mark.parametrize( - "method, model", + "method, model, mcmc_method", [ - (SNPE_A, "mdn_snpe_a"), - (SNPE_C, "mdn"), - (SNPE_C, "maf"), - (SNLE, "maf"), - (SNLE, "nsf"), - (SNRE_A, "mlp"), - (SNRE_B, "resnet"), + (SNPE_C, "mdn", "rejection"), + (SNPE_C, "maf", "rejection"), + (SNLE, "maf", "slice"), + (SNLE, "nsf", "slice_np"), + (SNRE_A, "mlp", "slice_np_vectorized"), + (SNRE_B, "resnet", "nuts"), ], ) @pytest.mark.parametrize("data_device", ("cpu", "cuda:0")) -@pytest.mark.parametrize("training_device", ("cpu", "cuda:0")) -def test_training_and_mcmc_on_device(method, model, data_device, training_device): +@pytest.mark.parametrize( + "training_device, prior_device", + [ + pytest.param("cpu", "cuda", marks=pytest.mark.xfail), + pytest.param("cuda:0", "cpu", marks=pytest.mark.xfail), + ("cuda:0", "cuda:0"), + ("cuda:0", "cuda:0"), + ("cpu", "cpu"), + ], +) +def test_training_and_mcmc_on_device( + method, + model, + data_device, + mcmc_method, + training_device, + prior_device, + prior_type="gaussian", +): """Test training on devices. This test does not check training speeds. 
""" - training_device = process_device(training_device) num_dim = 2 num_samples = 10 @@ -44,41 +59,38 @@ def test_training_and_mcmc_on_device(method, model, data_device, training_device max_num_epochs = 5 x_o = zeros(1, num_dim).to(data_device) - likelihood_shift = -1.0 * ones(num_dim) - likelihood_cov = 0.3 * eye(num_dim) + likelihood_shift = -1.0 * ones(num_dim).to(prior_device) + likelihood_cov = 0.3 * eye(num_dim).to(prior_device) - prior_mean = zeros(num_dim) - prior_cov = eye(num_dim) - prior = MultivariateNormal(loc=prior_mean, covariance_matrix=prior_cov) + if prior_type == "gaussian": + prior_mean = zeros(num_dim).to(prior_device) + prior_cov = eye(num_dim).to(prior_device) + prior = MultivariateNormal(loc=prior_mean, covariance_matrix=prior_cov) + else: + prior = BoxUniform( + low=-2 * torch.ones(num_dim), + high=2 * torch.ones(num_dim), + device=prior_device, + ) def simulator(theta): return linear_gaussian(theta, likelihood_shift, likelihood_cov) + training_device = process_device(training_device, prior) + if method in [SNPE_A, SNPE_C]: - kwargs = dict( - density_estimator=utils.posterior_nn(model=model), - ) + kwargs = dict(density_estimator=utils.posterior_nn(model=model)) mcmc_kwargs = ( - dict( - sample_with="mcmc", - mcmc_method="slice_np", - ) + dict(sample_with="rejection", mcmc_method=mcmc_method) if method == SNPE_C else {} ) elif method == SNLE: - kwargs = dict( - density_estimator=utils.likelihood_nn(model=model), - ) - mcmc_kwargs = dict(sample_with="mcmc", mcmc_method="slice") + kwargs = dict(density_estimator=utils.likelihood_nn(model=model)) + mcmc_kwargs = dict(sample_with="mcmc", mcmc_method=mcmc_method) elif method in (SNRE_A, SNRE_B): - kwargs = dict( - classifier=utils.classifier_nn(model=model), - ) - mcmc_kwargs = dict( - sample_with="mcmc", - mcmc_method="slice_np_vectorized", - ) + kwargs = dict(classifier=utils.classifier_nn(model=model)) + mcmc_kwargs = dict(sample_with="mcmc", mcmc_method=mcmc_method) else: raise ValueError() diff --git a/tests/linearGaussian_snle_test.py b/tests/linearGaussian_snle_test.py index c7a79b62b..186b66570 100644 --- a/tests/linearGaussian_snle_test.py +++ b/tests/linearGaussian_snle_test.py @@ -106,7 +106,7 @@ def test_c2st_snl_on_linearGaussian(set_seed): @pytest.mark.slow @pytest.mark.parametrize("num_dim", (1, 2)) @pytest.mark.parametrize("prior_str", ("uniform", "gaussian")) -def test_c2st_snl_on_linearGaussian_different_dims_and_trials( +def test_c2st_and_map_snl_on_linearGaussian_different( num_dim: int, prior_str: str, set_seed ): """Test SNL on linear Gaussian, comparing to ground truth posterior via c2st. 
@@ -117,8 +117,8 @@ def test_c2st_snl_on_linearGaussian_different_dims_and_trials( set_seed: fixture for manual seeding """ num_samples = 500 - num_simulations = 7500 - trials_to_test = [1, 5, 10] + num_simulations = 5000 + trials_to_test = [1] # likelihood_mean will be likelihood_shift+theta likelihood_shift = -1.0 * ones(num_dim) @@ -138,7 +138,7 @@ def test_c2st_snl_on_linearGaussian_different_dims_and_trials( inference = SNL(prior, show_progress_bars=False) theta, x = simulate_for_sbi( - simulator, prior, num_simulations, simulation_batch_size=50 + simulator, prior, num_simulations, simulation_batch_size=10000 ) _ = inference.append_simulations(theta, x).train() @@ -150,7 +150,7 @@ def test_c2st_snl_on_linearGaussian_different_dims_and_trials( x_o, likelihood_shift, likelihood_cov, prior_mean, prior_cov ) target_samples = gt_posterior.sample((num_samples,)) - else: + elif prior_str == "uniform": target_samples = samples_true_posterior_linear_gaussian_uniform_prior( x_o, likelihood_shift, @@ -158,6 +158,9 @@ def test_c2st_snl_on_linearGaussian_different_dims_and_trials( prior=prior, num_samples=num_samples, ) + else: + raise ValueError(f"Wrong prior_str: '{prior_str}'.") + posterior = inference.build_posterior( mcmc_method="slice_np_vectorized" ).set_default_x(x_o) @@ -171,7 +174,9 @@ def test_c2st_snl_on_linearGaussian_different_dims_and_trials( samples, target_samples, alg=f"snle_a-{prior_str}-prior-{num_trials}-trials" ) - map_ = posterior.map(num_init_samples=1_000, init_method="prior") + map_ = posterior.map( + num_init_samples=1_000, init_method="prior", show_progress_bars=False + ) # TODO: we do not have a test for SNL log_prob(). This is because the output # TODO: density is not normalized, so KLd does not make sense. @@ -239,10 +244,19 @@ def test_c2st_multi_round_snl_on_linearGaussian(num_trials: int, set_seed): @pytest.mark.slow -@pytest.mark.parametrize("prior_str", ("gaussian", "uniform")) @pytest.mark.parametrize( - "sampling_method", - ("slice_np", "slice_np_vectorized", "slice", "nuts", "hmc", "rejection"), + "sampling_method, prior_str", + ( + ("slice_np", "gaussian"), + ("slice_np", "uniform"), + ("slice_np_vectorized", "gaussian"), + ("slice_np_vectorized", "uniform"), + ("slice", "gaussian"), + ("slice", "uniform"), + ("nuts", "gaussian"), + ("nuts", "uniform"), + ("hmc", "gaussian"), + ), ) @pytest.mark.parametrize("init_strategy", ("prior", "sir")) def test_api_snl_sampling_methods( @@ -259,8 +273,7 @@ def test_api_snl_sampling_methods( num_dim = 2 num_samples = 10 num_trials = 2 - # HMC with uniform prior needs good likelihood. - num_simulations = 10000 if sampling_method == "hmc" else 1000 + num_simulations = 1000 x_o = zeros((num_trials, num_dim)) # Test for multiple chains is cheap when vectorized. 
num_chains = 3 if sampling_method == "slice_np_vectorized" else 1 @@ -278,7 +291,7 @@ def test_api_snl_sampling_methods( inference = SNL(prior, show_progress_bars=False) theta, x = simulate_for_sbi( - simulator, prior, num_simulations, simulation_batch_size=50 + simulator, prior, num_simulations, simulation_batch_size=1000 ) _ = inference.append_simulations(theta, x).train(max_num_epochs=5) posterior = inference.build_posterior( diff --git a/tests/linearGaussian_snre_test.py b/tests/linearGaussian_snre_test.py index 2494dad6c..2a58e5f56 100644 --- a/tests/linearGaussian_snre_test.py +++ b/tests/linearGaussian_snre_test.py @@ -142,7 +142,7 @@ def test_c2st_sre_variants_on_linearGaussian( x_o = zeros(num_trials, num_dim) num_samples = 500 - num_simulations = 2500 if num_trials == 1 else 35000 + num_simulations = 2500 if num_trials == 1 else 40000 # `likelihood_mean` will be `likelihood_shift + theta`. likelihood_shift = -1.0 * ones(num_dim) @@ -228,10 +228,19 @@ def simulator(theta): @pytest.mark.slow -@pytest.mark.parametrize("prior_str", ("gaussian", "uniform")) @pytest.mark.parametrize( - "sampling_method", - ("slice_np", "slice_np_vectorized", "slice", "nuts", "hmc", "rejection"), + "sampling_method, prior_str", + ( + ("slice_np", "gaussian"), + ("slice_np", "uniform"), + ("slice_np_vectorized", "gaussian"), + ("slice_np_vectorized", "uniform"), + ("slice", "gaussian"), + ("slice", "uniform"), + ("nuts", "gaussian"), + ("nuts", "uniform"), + ("hmc", "gaussian"), + ), ) def test_api_sre_sampling_methods(sampling_method: str, prior_str: str, set_seed): """Test leakage correction both for MCMC and rejection sampling. @@ -244,8 +253,7 @@ def test_api_sre_sampling_methods(sampling_method: str, prior_str: str, set_seed num_dim = 2 num_samples = 10 num_trials = 2 - # HMC with uniform prior needs good likelihood. - num_simulations = 10000 if sampling_method == "hmc" else 1000 + num_simulations = 1000 x_o = zeros((num_trials, num_dim)) # Test for multiple chains is cheap when vectorized. num_chains = 3 if sampling_method == "slice_np_vectorized" else 1 diff --git a/tests/mcmc_test.py b/tests/mcmc_test.py index 4ea51c337..ff5217128 100644 --- a/tests/mcmc_test.py +++ b/tests/mcmc_test.py @@ -8,6 +8,8 @@ import torch from torch import eye, ones, zeros +from sbi import utils +from sbi.inference import SNLE, SNPE, SNRE, prepare_for_sbi, simulate_for_sbi from sbi.mcmc.slice_numpy import SliceSampler from sbi.mcmc.slice_numpy_vectorized import SliceSamplerVectorized from sbi.simulators.linear_gaussian import true_posterior_linear_gaussian_mvn_prior diff --git a/tests/user_input_checks_test.py b/tests/user_input_checks_test.py index 60dd194e8..2e15422af 100644 --- a/tests/user_input_checks_test.py +++ b/tests/user_input_checks_test.py @@ -516,7 +516,8 @@ def test_train_with_different_data_and_training_device( # simulator, prior = prepare_for_sbi(user_simulator, user_prior) prior_ = MultivariateNormal( - loc=torch.zeros(num_dim), covariance_matrix=torch.eye(num_dim) + loc=torch.zeros(num_dim).to(training_device), + covariance_matrix=torch.eye(num_dim).to(training_device), ) simulator, prior = prepare_for_sbi(diagonal_linear_gaussian, prior_)
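Taken together, the FAQ text, the `process_device(device, prior=prior)` check, and the new `device` argument of `BoxUniform` assume one usage pattern: create the prior on the training device and pass the same device string to the inference object. The following is a minimal sketch of that workflow, not part of this diff; the toy simulator, the simulation count, and the epoch budget are illustrative only, and the snippet falls back to CPU when no CUDA device is available.

```python
import torch

from sbi.inference import SNPE, simulate_for_sbi
from sbi.utils.torchutils import BoxUniform

# Train on GPU when available, otherwise on CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Prior created directly on the training device, as required by process_device.
prior = BoxUniform(low=-2 * torch.ones(2), high=2 * torch.ones(2), device=device)


def toy_simulator(theta: torch.Tensor) -> torch.Tensor:
    # Illustrative stand-in for a real simulator: shift plus Gaussian noise.
    return theta + 1.0 + 0.1 * torch.randn_like(theta)


theta, x = simulate_for_sbi(toy_simulator, prior, num_simulations=500)

# sbi copies the network and training data to `device`; the prior must already live there.
inference = SNPE(prior=prior, device=device, density_estimator="maf")
density_estimator = inference.append_simulations(theta, x).train(max_num_epochs=5)
posterior = inference.build_posterior(density_estimator)
samples = posterior.sample((100,), x=x[:1])
```

On CPU the prior/device assertion in `process_device` is trivially satisfied, so the same script runs unchanged on machines without a GPU.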