From e06fa365c6b6fcda57800aa21abd597a946259c6 Mon Sep 17 00:00:00 2001 From: michael Date: Mon, 17 May 2021 10:04:56 +0200 Subject: [PATCH] pushing the refactoring along. SNPE and SNLE work now --- sbi/inference/posteriors/base_posterior.py | 188 ++++++++++--- sbi/inference/posteriors/direct_posterior.py | 260 ++++++++++-------- .../posteriors/likelihood_based_posterior.py | 143 +++------- .../posteriors/ratio_based_posterior.py | 120 ++++++-- sbi/inference/snle/snle_base.py | 16 +- sbi/inference/snpe/snpe_a.py | 2 +- sbi/inference/snpe/snpe_base.py | 35 ++- sbi/inference/snre/snre_base.py | 13 + sbi/mcmc/__init__.py | 1 - sbi/mcmc/rejection_sampling.py | 11 - sbi/utils/__init__.py | 3 +- sbi/utils/sbiutils.py | 229 +++++++++------ tests/inference_on_device_test.py | 2 +- tests/linearGaussian_snle_test.py | 18 +- tests/linearGaussian_snpe_test.py | 16 +- tests/linearGaussian_snre_test.py | 48 ++-- 16 files changed, 664 insertions(+), 441 deletions(-) delete mode 100644 sbi/mcmc/rejection_sampling.py diff --git a/sbi/inference/posteriors/base_posterior.py b/sbi/inference/posteriors/base_posterior.py index d85b560f8..4c8fcacd2 100644 --- a/sbi/inference/posteriors/base_posterior.py +++ b/sbi/inference/posteriors/base_posterior.py @@ -14,6 +14,7 @@ from torch import Tensor, float32 from torch import multiprocessing as mp from torch import nn, optim +from torch._C import device from sbi import utils as utils from sbi.mcmc import ( @@ -29,6 +30,7 @@ check_dist_class, check_warn_and_setstate, optimize_potential_fn, + rejection_sample, ) from sbi.utils.torchutils import ( BoxUniform, @@ -66,7 +68,7 @@ def __init__( neural_net: A classifier for SNRE, a density estimator for SNPE and SNL. prior: Prior distribution with `.log_prob()` and `.sample()`. x_shape: Shape of the simulator data. - sample_with: Method to use for sampling from the posterior. Must be in + sample_with: Method to use for sampling from the posterior. Must be one of [`mcmc` | `rejection`]. mcmc_method: Method used for MCMC sampling, one of `slice_np`, `slice`, `hmc`, `nuts`. Currently defaults to `slice_np` for a custom numpy @@ -80,12 +82,12 @@ def __init__( will draw init locations from prior, whereas `sir` will use Sequential- Importance-Resampling. Init strategies may have their own keywords which can also be set from `mcmc_parameters`. - rejection_sampling_parameters: Dictionary overriding the default parameters for - rejection sampling. The following parameters are supported: - `proposal`, as the proposal distribtution. `num_samples_to_find_max` - as the number of samples that are used to find the maximum of the - `potential_fn / proposal` ratio. `m` as multiplier to that ratio. - `sampling_batch_size` as the batchsize of samples being drawn from + rejection_sampling_parameters: Dictionary overriding the default parameters + for rejection sampling. The following parameters are supported: + `proposal` as the proposal distribtution. `num_samples_to_find_max` + as the number of samples that are used to find the maximum of the + `potential_fn / proposal` ratio. `m` as multiplier to that ratio. + `sampling_batch_size` as the batchsize of samples being drawn from the proposal at every iteration. device: Training device, e.g., cpu or cuda. """ @@ -164,7 +166,7 @@ def sample_with(self, value: str) -> None: self.set_sample_with(value) def set_sample_with(self, sample_with: str) -> "NeuralPosterior": - """Turns MCMC sampling on or off and returns `NeuralPosterior`. + """Set the sampling method for the `NeuralPosterior`. 
Args: sample_with: The method to sample with. @@ -235,7 +237,7 @@ def set_mcmc_parameters(self, parameters: Dict[str, Any]) -> "NeuralPosterior": @property def rejection_sampling_parameters(self) -> dict: - """Returns rejection sampling parameter.""" + """Returns rejection sampling parameters.""" if self._rejection_sampling_parameters is None: return {} else: @@ -254,9 +256,11 @@ def set_rejection_sampling_parameters( Args: parameters: Dictonary overriding the default parameters for rejection sampling. The following parameters are supported: - `max_sampling_batch_size` to the set the batch size for drawing new - samples from the candidate distribution, e.g., the posterior. Larger - batch size speeds up sampling. + `proposal` as the proposal distribtution. `num_samples_to_find_max` + as the number of samples that are used to find the maximum of the + `potential_fn / proposal` ratio. `m` as multiplier to that ratio. + `sampling_batch_size` as the batchsize of samples being drawn from + the proposal at every iteration. Returns: `NeuralPosterior for chainable calls. @@ -357,7 +361,7 @@ def _prepare_for_sample( sample_shape: Optional[Tensor], ) -> Tuple[Tensor, int]: r""" - Return checked and correctly shaped values for `x` and `sample_shape`. + Return checked, reshaped, potentially default values for `x` and `sample_shape`. Args: sample_shape: Desired shape of samples that are drawn from posterior. If @@ -413,11 +417,13 @@ def _potentially_replace_rejection_parameters( Return potentially default values to rejection sample the posterior. Args: - rejection_sampling_parameters: Dictionary overriding the default - parameters for rejection sampling. The following parameters are - supported: `m` as multiplier to the maximum ratio between - potential function and the proposal. `proposal`, as the proposal - distribtution. + rejection_sampling_parameters: Dictionary overriding the default + parameters for rejection sampling. The following parameters are + supported: `proposal` as the proposal distribtution. + `num_samples_to_find_max` as the number of samples that are used to + find the maximum of the `potential_fn / proposal` ratio. `m` as + multiplier to that ratio. `sampling_batch_size` as the batchsize of + samples being drawn from the proposal at every iteration. Returns: Potentially default rejection sampling parameters. """ @@ -624,9 +630,11 @@ def sample_conditional( condition: Tensor, dims_to_sample: List[int], x: Optional[Tensor] = None, + sample_with: str = "mcmc", show_progress_bars: bool = True, mcmc_method: Optional[str] = None, mcmc_parameters: Optional[Dict[str, Any]] = None, + rejection_sampling_parameters: Optional[Dict[str, Any]] = None, ) -> Tensor: r""" Return samples from conditional posterior $p(\theta_i|\theta_j, x)$. @@ -635,7 +643,7 @@ def sample_conditional( from a few parameter dimensions while the other parameter dimensions are kept fixed at values specified in `condition`. - Samples are obtained with MCMC. + Samples are obtained with MCMC or rejection sampling. Args: potential_fn_provider: Returns the potential function for the unconditional @@ -653,6 +661,9 @@ def sample_conditional( x: Conditioning context for posterior $p(\theta|x)$. If not provided, fall back onto `x_o` if previously provided for multiround training, or to a set default (see `set_default_x()` method). + sample_with: Method to use for sampling from the posterior. Must be one of + [`mcmc` | `rejection`]. In this method, the value of + `self.sample_with` will be ignored. 
show_progress_bars: Whether to show sampling progress monitor. mcmc_method: Optional parameter to override `self.mcmc_method`. mcmc_parameters: Dictionary overriding the default parameters for MCMC. @@ -663,38 +674,69 @@ def sample_conditional( will draw init locations from prior, whereas `sir` will use Sequential- Importance-Resampling using `init_strategy_num_candidates` to find init locations. + rejection_sampling_parameters: Dictionary overriding the default parameters + for rejection sampling. The following parameters are supported: + `proposal` as the proposal distribtution. `num_samples_to_find_max` + as the number of samples that are used to find the maximum of the + `potential_fn / proposal` ratio. `m` as multiplier to that ratio. + `sampling_batch_size` as the batchsize of samples being drawn from + the proposal at every iteration. Returns: Samples from conditional posterior. """ - x, num_samples, mcmc_method, mcmc_parameters = self._prepare_for_sample( - x, sample_shape, mcmc_method, mcmc_parameters - ) - self.net.eval() + x, num_samples = self._prepare_for_sample(x, sample_shape) + cond_potential_fn_provider = ConditionalPotentialFunctionProvider( potential_fn_provider, condition, dims_to_sample ) - samples = self._sample_posterior_mcmc( - num_samples=num_samples, - potential_fn=cond_potential_fn_provider( - self._prior, self.net, x, mcmc_method - ), - init_fn=self._build_mcmc_init_fn( - # Restrict prior to sample only free dimensions. - RestrictedPriorForConditional(self._prior, dims_to_sample), - cond_potential_fn_provider(self._prior, self.net, x, "slice_np"), + if sample_with == "mcmc": + mcmc_method, mcmc_parameters = self._potentially_replace_mcmc_parameters( + mcmc_method, mcmc_parameters + ) + samples = self._sample_posterior_mcmc( + num_samples=num_samples, + potential_fn=cond_potential_fn_provider( + self._prior, self.net, x, mcmc_method + ), + init_fn=self._build_mcmc_init_fn( + # Restrict prior to sample only free dimensions. + RestrictedPriorForConditional(self._prior, dims_to_sample), + cond_potential_fn_provider(self._prior, self.net, x, "slice_np"), + **mcmc_parameters, + ), + mcmc_method=mcmc_method, + condition=condition, + dims_to_sample=dims_to_sample, + show_progress_bars=show_progress_bars, **mcmc_parameters, - ), - mcmc_method=mcmc_method, - condition=condition, - dims_to_sample=dims_to_sample, - show_progress_bars=show_progress_bars, - **mcmc_parameters, - ) + ) + elif sample_with == "rejection": + rejection_sampling_parameters = ( + self._potentially_replace_rejection_parameters( + rejection_sampling_parameters + ) + ) + if "proposal" not in rejection_sampling_parameters: + rejection_sampling_parameters[ + "proposal" + ] = RestrictedPriorForConditional(self._prior, dims_to_sample) + + samples, _ = rejection_sample( + potential_fn=cond_potential_fn_provider( + self._prior, self.net, x, "rejection" + ), + num_samples=num_samples, + **rejection_sampling_parameters, + ) + else: + raise NameError( + "The only implemented sampling methods are `mcmc` and `rejection`." + ) self.net.train(True) @@ -1071,18 +1113,43 @@ def __init__( self.condition = ensure_theta_batched(condition) self.dims_to_sample = dims_to_sample - def __call__(self, prior, net: nn.Module, x: Tensor, mcmc_method: str) -> Callable: + def __call__(self, prior, net: nn.Module, x: Tensor, method: str) -> Callable: """Return potential function. Switch on numpy or pyro potential function based on `mcmc_method`. """ # Set prior, net, and x as attributes of unconditional potential_fn_provider. 
-        _ = self.potential_fn_provider.__call__(prior, net, x, mcmc_method)
+        _ = self.potential_fn_provider.__call__(prior, net, x, method)
 
-        if mcmc_method in ("slice", "hmc", "nuts"):
+        if method in ("slice", "hmc", "nuts"):
             return self.pyro_potential
-        else:
+        elif "slice_np" in method:
             return self.np_potential
+        elif method == "rejection":
+            return self.rejection_potential
+        else:
+            raise NotImplementedError
+
+    def rejection_potential(self, theta: np.ndarray) -> ScalarFloat:
+        r"""
+        Return conditional posterior log-probability or $-\infty$ if outside prior.
+
+        The only differences to the `np_potential` are that it tracks the gradients and
+        does not return a `numpy` array.
+
+        Args:
+            theta: Free parameters $\theta_i$, batch dimension 1.
+
+        Returns:
+            Conditional posterior log-probability $\log(p(\theta_i|\theta_j, x))$,
+            masked outside of prior.
+        """
+        theta = torch.as_tensor(theta, dtype=torch.float32)
+
+        theta_condition = deepcopy(self.condition)
+        theta_condition[:, self.dims_to_sample] = theta
+
+        return self.potential_fn_provider.rejection_potential(theta_condition)
 
     def np_potential(self, theta: np.ndarray) -> ScalarFloat:
         r"""
@@ -1155,3 +1222,38 @@ def log_prob(self, *args, **kwargs):
         the $\theta$ under the full joint once we have added the condition.
         """
         return self.full_prior.log_prob(*args, **kwargs)
+
+
+class NeuralNetDefaultX:
+    def __init__(self, density_estimator: Any, x: Tensor) -> None:
+        r"""
+        Wraps the neural density estimator returned by `nflows` to have a default $x$.
+
+        Currently, it is used only as `proposal` for rejection sampling with `SNPE`.
+
+        Args:
+            density_estimator: The neural density estimator parameterizing $p(y|x)$.
+            x: The value for $x$ at which to evaluate or sample $p(y|x)$.
+        """
+        self.net = density_estimator
+        self.x = x
+
+    def sample(self, num_samples, **kwargs) -> Tensor:
+        """
+        Return samples from $p(y|x)$.
+
+        Args:
+            num_samples: Number of samples to return.
+        """
+        s = self.net.sample(num_samples, context=self.x, **kwargs)
+        return s
+
+    def log_prob(self, y, **kwargs) -> Tensor:
+        r"""
+        Return the log-probabilities $\log(p(y|x))$.
+
+        Args:
+            y: Location at which to evaluate the log-probability.
+        """
+        ys, xs = NeuralPosterior._match_theta_and_x_batch_shapes(y, self.x)
+        return self.net.log_prob(ys, context=xs)
diff --git a/sbi/inference/posteriors/direct_posterior.py b/sbi/inference/posteriors/direct_posterior.py
index 47c2a3fc7..576d6c709 100644
--- a/sbi/inference/posteriors/direct_posterior.py
+++ b/sbi/inference/posteriors/direct_posterior.py
@@ -1,7 +1,7 @@
 # This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
 # under the Affero General Public License v3, see <https://www.gnu.org/licenses/>.
-from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union from warnings import warn import numpy as np @@ -9,9 +9,9 @@ from torch import Tensor, log, nn from sbi import utils as utils -from sbi.inference.posteriors.base_posterior import NeuralPosterior +from sbi.inference.posteriors.base_posterior import NeuralNetDefaultX, NeuralPosterior from sbi.types import ScalarFloat, Shape -from sbi.utils import del_entries, within_support +from sbi.utils import del_entries, rejection_sample, within_support from sbi.utils.torchutils import ( atleast_2d, batched_first_of_batch, @@ -41,10 +41,10 @@ def __init__( neural_net: nn.Module, prior, x_shape: torch.Size, - rejection_sampling_parameters: Optional[Dict[str, Any]] = None, - sample_with_mcmc: bool = True, + sample_with: str = "rejection", mcmc_method: str = "slice_np", mcmc_parameters: Optional[Dict[str, Any]] = None, + rejection_sampling_parameters: Optional[Dict[str, Any]] = None, device: str = "cpu", ): """ @@ -53,14 +53,8 @@ def __init__( neural_net: A classifier for SNRE, a density estimator for SNPE and SNL. prior: Prior distribution with `.log_prob()` and `.sample()`. x_shape: Shape of a single simulator output. - rejection_sampling_parameters: Dictonary overriding the default parameters - for rejection sampling. The following parameters are supported: - `max_sampling_batch_size` to set the batch size for drawing new - samples from the candidate distribution, e.g., the posterior. Larger - batch size speeds up sampling. - sample_with_mcmc: Whether to sample with MCMC. Will always be `True` for SRE - and SNL, but can also be set to `True` for SNPE if MCMC is preferred to - deal with leakage over rejection sampling. + sample_with: Method to use for sampling from the posterior. Must be one of + [`mcmc` | `rejection`]. mcmc_method: Method used for MCMC sampling, one of `slice_np`, `slice`, `hmc`, `nuts`. Currently defaults to `slice_np` for a custom numpy implementation of slice sampling; select `hmc`, `nuts` or `slice` for @@ -73,45 +67,49 @@ def __init__( will draw init locations from prior, whereas `sir` will use Sequential- Importance-Resampling using `init_strategy_num_candidates` to find init locations. + rejection_sampling_parameters: Dictionary overriding the default parameters + for rejection sampling. The following parameters are supported: + `proposal` as the proposal distribtution (default is the trained + neural net). `num_samples_to_find_max` as the number of samples that + are used to find the maximum of the `potential_fn / proposal` ratio. + `sampling_batch_size` as the batchsize of samples being drawn from the + proposal at every iteration. device: Training device, e.g., cpu or cuda:0 """ kwargs = del_entries( locals(), - entries=( - "self", - "__class__", - "sample_with_mcmc", - "rejection_sampling_parameters", - ), + entries=("self", "__class__"), ) super().__init__(**kwargs) - self.set_sample_with_mcmc(sample_with_mcmc) - self.set_rejection_sampling_parameters(rejection_sampling_parameters) self._purpose = ( "It allows to .sample() and .log_prob() the posterior and wraps the " "output of the .net to avoid leakage into regions with 0 prior probability." ) @property - def sample_with_mcmc(self) -> bool: + def _sample_with_mcmc(self) -> bool: """ - Deprecated, will be removed. + Deprecated, will be removed in future versions of `sbi`. + Return `True` if NeuralPosterior instance should use MCMC in `.sample()`. 
""" - warn("Deprecated") return self._sample_with_mcmc - @sample_with_mcmc.setter - def sample_with_mcmc(self, value: bool) -> None: - """See `set_sample_with_mcmc`.""" - warn("Deprecated") - # XXX call `.sample_with("mcmc")` - self.set_sample_with_mcmc(value) + @_sample_with_mcmc.setter + def _sample_with_mcmc(self, value: bool) -> None: + """ + Deprecated, will be removed in future versions of `sbi`. + + See `set_sample_with_mcmc`.""" + self._set_sample_with_mcmc(value) + + def _set_sample_with_mcmc(self, use_mcmc: bool) -> "NeuralPosterior": + """ + Deprecated, will be removed in future versions of `sbi`. - def set_sample_with_mcmc(self, use_mcmc: bool) -> "NeuralPosterior": - """Turns MCMC sampling on or off and returns `NeuralPosterior`. + Turns MCMC sampling on or off and returns `NeuralPosterior`. Args: use_mcmc: Flag to set whether or not MCMC sampling is used. @@ -123,7 +121,15 @@ def set_sample_with_mcmc(self, use_mcmc: bool) -> "NeuralPosterior": ValueError: on attempt to turn off MCMC sampling for family of methods that do not support rejection sampling. """ - warn("Deprecated") + warn( + f"You set `sample_with_mcmc={use_mcmc}`. This is deprecated " + "since `sbi v0.16.0` and will lead to an error in future versions. " + "Please use `sample_with=mcmc` instead." + ) + if use_mcmc: + self.set_sample_with("mcmc") + else: + self.set_sample_with("rejection") self._sample_with_mcmc = use_mcmc return self @@ -161,7 +167,6 @@ def log_prob( Returns: `(len(θ),)`-shaped log posterior probability $\log p(\theta|x)$ for θ in the support of the prior, -∞ (corresponding to 0 probability) outside. - """ # TODO Train exited here, entered after sampling? @@ -233,12 +238,15 @@ def leakage_correction( """ def acceptance_at(x: Tensor) -> Tensor: - return utils.sample_posterior_within_prior( - self.net, - self._prior, - x.to(self._device), - num_rejection_samples, - show_progress_bars, + + potential_fn_provider = PotentialFunctionProvider() + return utils.rejection_sample( + potential_fn=potential_fn_provider( + self._prior, self.net, x.to(self._device), "rejection" + ), + proposal=NeuralNetDefaultX(self.net, x.to(self._device)), + num_samples=num_rejection_samples, + show_progress_bars=show_progress_bars, sample_for_correction_factor=True, max_sampling_batch_size=rejection_sampling_batch_size, )[1] @@ -262,10 +270,11 @@ def sample( sample_shape: Shape = torch.Size(), x: Optional[Tensor] = None, show_progress_bars: bool = True, - sample_with_mcmc: Optional[bool] = None, + sample_with: Optional[str] = None, mcmc_method: Optional[str] = None, mcmc_parameters: Optional[Dict[str, Any]] = None, rejection_sampling_parameters: Optional[Dict[str, Any]] = None, + sample_with_mcmc: Optional[bool] = None, ) -> Tensor: r""" Return samples from posterior distribution $p(\theta|x)$. @@ -282,7 +291,8 @@ def sample( x: Conditioning context for posterior $p(\theta|x)$. If not provided, fall back onto `x` passed to `set_default_x()`. show_progress_bars: Whether to show sampling progress monitor. - sample_with_mcmc: Optional parameter to override `self.sample_with_mcmc`. + sample_with: Method to use for sampling from the posterior. Must be one of + [`mcmc` | `rejection`]. mcmc_method: Optional parameter to override `self.mcmc_method`. mcmc_parameters: Dictionary overriding the default parameters for MCMC. The following parameters are supported: `thin` to set the thinning @@ -294,25 +304,38 @@ def sample( locations. rejection_sampling_parameters: Dictionary overriding the default parameters for rejection sampling. 
The following parameters are supported: - `max_sampling_batch_size` to set the batch size for drawing new - samples from the candidate distribution, e.g., the posterior. Larger - batch size speeds up sampling. + `proposal` as the proposal distribtution (default is the trained + neural net). `num_samples_to_find_max` as the number of samples that + are used to find the maximum of the `potential_fn / proposal` ratio. + `sampling_batch_size` as the batchsize of samples being drawn from the + proposal at every iteration. + sample_with_mcmc: Deprecated since `sbi v0.16.0`. Use `sample_with=mcmc` + instead. + Returns: Samples from posterior. """ - x, num_samples, mcmc_method, mcmc_parameters = self._prepare_for_sample( - x, sample_shape, mcmc_method, mcmc_parameters - ) - - sample_with_mcmc = ( - sample_with_mcmc if sample_with_mcmc is not None else self.sample_with_mcmc - ) + if sample_with_mcmc is not None: + warn( + f"You set `sample_with_mcmc={sample_with_mcmc}`. This is deprecated " + "since `sbi v0.16.0` and will lead to an error in future versions. " + "Please use `sample_with=mcmc` instead." + ) + if sample_with_mcmc: + sample_with = "mcmc" self.net.eval() - if sample_with_mcmc: - potential_fn_provider = PotentialFunctionProvider() + sample_with = sample_with if sample_with is not None else self._sample_with + + x, num_samples = self._prepare_for_sample(x, sample_shape) + + potential_fn_provider = PotentialFunctionProvider() + if sample_with == "mcmc": + mcmc_method, mcmc_parameters = self._potentially_replace_mcmc_parameters( + mcmc_method, mcmc_parameters + ) samples = self._sample_posterior_mcmc( num_samples=num_samples, potential_fn=potential_fn_provider( @@ -327,17 +350,31 @@ def sample( show_progress_bars=show_progress_bars, **mcmc_parameters, ) - else: - # Rejection sampling. - samples, _ = utils.sample_posterior_within_prior( - self.net, - self._prior, - x, + elif sample_with == "rejection": + rejection_sampling_parameters = ( + self._potentially_replace_rejection_parameters( + rejection_sampling_parameters + ) + ) + if "proposal" not in rejection_sampling_parameters: + assert ( + not self.net.training + ), "Posterior nn must be in eval mode for sampling." + + rejection_sampling_parameters["binary_rejection_criterion"] = True + rejection_sampling_parameters["proposal"] = NeuralNetDefaultX( + self.net, x + ) + + samples, _ = rejection_sample( + potential_fn=self._prior, num_samples=num_samples, show_progress_bars=show_progress_bars, - **rejection_sampling_parameters - if (rejection_sampling_parameters is not None) - else self.rejection_sampling_parameters, + **rejection_sampling_parameters, + ) + else: + raise NameError( + "The only implemented sampling methods are `mcmc` and `rejection`." ) self.net.train(True) @@ -350,9 +387,11 @@ def sample_conditional( condition: Tensor, dims_to_sample: List[int], x: Optional[Tensor] = None, + sample_with: str = "mcmc", show_progress_bars: bool = True, mcmc_method: Optional[str] = None, mcmc_parameters: Optional[Dict[str, Any]] = None, + rejection_sampling_parameters: Optional[Dict[str, Any]] = None, ) -> Tensor: r""" Return samples from conditional posterior $p(\theta_i|\theta_j, x)$. @@ -376,6 +415,9 @@ def sample_conditional( `condition`. x: Conditioning context for posterior $p(\theta|x)$. If not provided, fall back onto `x` passed to `set_default_x()`. + sample_with: Method to use for sampling from the posterior. Must be one of + [`mcmc` | `rejection`]. In this method, the value of + `self.sample_with` will be ignored. 
             show_progress_bars: Whether to show sampling progress monitor.
             mcmc_method: Optional parameter to override `self.mcmc_method`.
             mcmc_parameters: Dictionary overriding the default parameters for MCMC.
@@ -386,6 +428,13 @@ def sample_conditional(
                 will draw init locations from prior, whereas `sir` will use Sequential-
                 Importance-Resampling using `init_strategy_num_candidates` to find init
                 locations.
+            rejection_sampling_parameters: Dictionary overriding the default parameters
+                for rejection sampling. The following parameters are supported:
+                `proposal` as the proposal distribution (default is the prior).
+                `num_samples_to_find_max` as the number of samples that are used to
+                find the maximum of the `potential_fn / proposal` ratio.
+                `sampling_batch_size` as the batch size of samples being drawn from the
+                proposal at every iteration.
 
         Returns:
             Samples from conditional posterior.
@@ -397,9 +446,11 @@ def sample_conditional(
             condition,
             dims_to_sample,
             x,
+            sample_with,
             show_progress_bars,
             mcmc_method,
             mcmc_parameters,
+            rejection_sampling_parameters,
         )
 
     def map(
@@ -464,52 +515,6 @@ def map(
             log_prob_kwargs={"norm_posterior": False},
         )
 
-    @torch.no_grad()
-    def sample_posterior_within_prior(
-        self,
-        posterior_nn: nn.Module,
-        prior,
-        x: Tensor,
-        num_samples: int = 1,
-        show_progress_bars: bool = False,
-        warn_acceptance: float = 0.01,
-        sample_for_correction_factor: bool = False,
-        max_sampling_batch_size: int = 10_000,
-    ) -> Tuple[Tensor, Tensor]:
-
-        assert (
-            not posterior_nn.training
-        ), "Posterior nn must be in eval mode for sampling."
-
-        def potential_fn(theta):
-            are_within_prior = within_support(prior, theta)
-            probs = posterior_nn.log_prob(theta, context=x)
-            probs[~are_within_prior] = float("-inf")
-            return probs
-
-        class Proposal:
-            def __init__(self, posterior_nn: Any):
-                self.posterior_nn = posterior_nn
-
-            def sample(self, sample_shape, **kwargs):
-                return self.posterior_nn.sample(sample_shape.numel(), **kwargs)
-
-            def log_prob(self, theta, **kwargs):
-                return self.posterior_nn.log_prob(theta, context=x)
-
-        proposal = Proposal(posterior_nn)
-
-        samples = rejection_sample_raw(
-            potential_fn=potential_fn,
-            proposal=proposal,
-            num_samples=num_samples,
-            show_progress_bars=show_progress_bars,
-            warn_acceptance=warn_acceptance,
-            sample_for_correction_factor=sample_for_correction_factor,
-            max_sampling_batch_size=max_sampling_batch_size,
-        )
-        return samples
-
 
 class PotentialFunctionProvider:
     """
@@ -535,7 +540,7 @@ def __call__(
         prior,
         posterior_nn: nn.Module,
         x: Tensor,
-        mcmc_method: str,
+        method: str,
     ) -> Callable:
         """Return potential function.
 
@@ -546,10 +551,41 @@ def __call__(
         self.device = next(posterior_nn.parameters()).device
         self.x = atleast_2d(x).to(self.device)
 
-        if mcmc_method in ("slice", "hmc", "nuts"):
+        if method in ("slice", "hmc", "nuts"):
             return self.pyro_potential
-        else:
+        elif "slice_np" in method:
             return self.np_potential
+        elif method == "rejection":
+            return self.rejection_potential
+        else:
+            raise NotImplementedError
+
+    def rejection_potential(self, theta: np.ndarray) -> ScalarFloat:
+        r"""Return posterior theta log prob. $p(\theta|x)$, $-\infty$ if outside prior.
+
+        The only difference to the `np_potential` is that it tracks the gradients.
+
+        Args:
+            theta: Parameters $\theta$, batch dimension 1.
+
+        Returns:
+            Posterior log probability $\log(p(\theta|x))$.
+ """ + theta = torch.as_tensor(theta, dtype=torch.float32) + theta = ensure_theta_batched(theta) + num_batch = theta.shape[0] + + # Repeat x over batch dim to match theta batch, accounting for multi-D x. + x_repeated = self.x.repeat(num_batch, *(1 for _ in range(self.x.ndim - 1))) + + target_log_prob = self.posterior_nn.log_prob( + inputs=theta.to(self.device), + context=x_repeated, + ) + in_prior_support = within_support(self.prior, theta) + target_log_prob[~in_prior_support] = -float("Inf") + + return target_log_prob def np_potential(self, theta: np.ndarray) -> ScalarFloat: r"""Return posterior theta log prob. $p(\theta|x)$, $-\infty$ if outside prior." diff --git a/sbi/inference/posteriors/likelihood_based_posterior.py b/sbi/inference/posteriors/likelihood_based_posterior.py index af60b07cb..96cb99105 100644 --- a/sbi/inference/posteriors/likelihood_based_posterior.py +++ b/sbi/inference/posteriors/likelihood_based_posterior.py @@ -9,9 +9,8 @@ from torch import Tensor, nn from sbi.inference.posteriors.base_posterior import NeuralPosterior -from sbi.mcmc import rejection_sample from sbi.types import Shape -from sbi.utils import del_entries +from sbi.utils import del_entries, optimize_potential_fn, rejection_sample from sbi.utils.torchutils import ScalarFloat, atleast_2d, ensure_theta_batched @@ -48,7 +47,7 @@ def __init__( independent and identically distributed data / trials. I.e., the data is assumed to be generated based on the same (unknown) model parameters or experimental condations. - sample_with: Method to use for sampling from the posterior. Must be in + sample_with: Method to use for sampling from the posterior. Must be one of [`mcmc` | `rejection`]. mcmc_method: Method used for MCMC sampling, one of `slice_np`, `slice`, `hmc`, `nuts`. Currently defaults to `slice_np` for a custom numpy @@ -62,13 +61,13 @@ def __init__( will draw init locations from prior, whereas `sir` will use Sequential- Importance-Resampling using `init_strategy_num_candidates` to find init locations. - rejection_sampling_parameters: Dictionary overriding the default parameters for - rejection sampling. The following parameters are supported: - `proposal`, as the proposal distribtution. `num_samples_to_find_max` - as the number of samples that are used to find the maximum of the - `potential_fn / proposal` ratio. `m` as multiplier to that ratio. - `sampling_batch_size` as the batchsize of samples being drawn from - the proposal at every iteration. + rejection_sampling_parameters: Dictionary overriding the default parameters + for rejection sampling. The following parameters are supported: + `proposal` as the proposal distribtution (default is the prior). + `num_samples_to_find_max` as the number of samples that are used to + find the maximum of the `potential_fn / proposal` ratio. + `sampling_batch_size` as the batchsize of samples being drawn from the + proposal at every iteration. device: Training device, e.g., cpu or cuda:0. """ @@ -125,7 +124,7 @@ def sample( sample_shape: Shape = torch.Size(), x: Optional[Tensor] = None, show_progress_bars: bool = True, - sample_with: str = "mcmc", + sample_with: Optional[str] = None, mcmc_method: Optional[str] = None, mcmc_parameters: Optional[Dict[str, Any]] = None, rejection_sampling_parameters: Optional[Dict[str, Any]] = None, @@ -140,7 +139,7 @@ def sample( x: Conditioning context for posterior $p(\theta|x)$. If not provided, fall back onto `x` passed to `set_default_x()`. show_progress_bars: Whether to show sampling progress monitor. 
- sample_with: Method to use for sampling from the posterior. Must be in + sample_with: Method to use for sampling from the posterior. Must be one of [`mcmc` | `rejection`]. mcmc_method: Optional parameter to override `self.mcmc_method`. mcmc_parameters: Dictionary overriding the default parameters for MCMC. @@ -153,11 +152,11 @@ def sample( locations. rejection_sampling_parameters: Dictionary overriding the default parameters for rejection sampling. The following parameters are supported: - `proposal`, as the proposal distribtution. `num_samples_to_find_max` - as the number of samples that are used to find the maximum of the - `potential_fn / proposal` ratio. `m` as multiplier to that ratio. - `sampling_batch_size` as the batchsize of samples being drawn from - the proposal at every iteration. + `proposal` as the proposal distribtution (default is the prior). + `num_samples_to_find_max` as the number of samples that are used to + find the maximum of the `potential_fn / proposal` ratio. + `sampling_batch_size` as the batchsize of samples being drawn from the + proposal at every iteration. Returns: Samples from posterior. @@ -165,6 +164,7 @@ def sample( self.net.eval() + sample_with = sample_with if sample_with is not None else self._sample_with x, num_samples = self._prepare_for_sample(x, sample_shape) potential_fn_provider = PotentialFunctionProvider() @@ -197,11 +197,11 @@ def sample( if "proposal" not in rejection_sampling_parameters: rejection_sampling_parameters["proposal"] = self._prior - samples = self._sample_posterior_rejection( - num_samples=num_samples, + samples, _ = rejection_sample( potential_fn=potential_fn_provider( self._prior, self.net, x, "rejection" ), + num_samples=num_samples, **rejection_sampling_parameters, ) else: @@ -219,9 +219,11 @@ def sample_conditional( condition: Tensor, dims_to_sample: List[int], x: Optional[Tensor] = None, + sample_with: str = "mcmc", show_progress_bars: bool = True, mcmc_method: Optional[str] = None, mcmc_parameters: Optional[Dict[str, Any]] = None, + rejection_sampling_parameters: Optional[Dict[str, Any]] = None, ) -> Tensor: r""" Return samples from conditional posterior $p(\theta_i|\theta_j, x)$. @@ -245,6 +247,9 @@ def sample_conditional( `condition`. x: Conditioning context for posterior $p(\theta|x)$. If not provided, fall back onto `x` passed to `set_default_x()`. + sample_with: Method to use for sampling from the posterior. Must be one of + [`mcmc` | `rejection`]. In this method, the value of + `self._sample_with` will be ignored. show_progress_bars: Whether to show sampling progress monitor. mcmc_method: Optional parameter to override `self.mcmc_method`. mcmc_parameters: Dictionary overriding the default parameters for MCMC. @@ -255,6 +260,13 @@ def sample_conditional( will draw init locations from prior, whereas `sir` will use Sequential- Importance-Resampling using `init_strategy_num_candidates` to find init locations. + rejection_sampling_parameters: Dictionary overriding the default parameters + for rejection sampling. The following parameters are supported: + `proposal` as the proposal distribtution (default is the prior). + `num_samples_to_find_max` as the number of samples that are used to + find the maximum of the `potential_fn / proposal` ratio. + `sampling_batch_size` as the batchsize of samples being drawn from the + proposal at every iteration. Returns: Samples from conditional posterior. 
@@ -266,9 +278,11 @@ def sample_conditional( condition, dims_to_sample, x, + sample_with, show_progress_bars, mcmc_method, mcmc_parameters, + rejection_sampling_parameters, ) def map( @@ -381,95 +395,6 @@ def _log_likelihoods_over_trials( return log_likelihood_trial_sum - def _sample_posterior_rejection( - self, - num_samples: torch.Size, - potential_fn: Any, - proposal: Any, - num_samples_to_find_max: int = 10_000, - m: float = 1.2, - sampling_batch_size: int = 10_000, - ): - r""" - Return samples from a distribution via rejection sampling. - - This function is used in any case by SNLE and SNRE, but can also be used by SNPE - in order to deal with strong leakage. Depending on the inference method, a - different potential function for the rejection sampler is required. - - Args: - num_samples: Desired number of samples. - potential_fn: Potential function used for rejection sampling. - proposal: Proposal distribution for rejection sampling. - num_samples_to_find_max: Number of samples that are used to find the maximum - of the `potential_fn / proposal` ratio. - m: Multiplier to the maximum ratio between potential function and the - proposal. A higher value will ensure that the samples are indeed from - the posterior, but will increase the rejection ratio and thus - computation time. - sampling_batch_size: Batchsize of samples being drawn from - the proposal at every iteration. - - Returns: - Tensor of shape (num_samples, shape_of_single_theta). - """ - - find_max = proposal.sample((num_samples_to_find_max,)) - - # Define a potential as the ratio between target distribution and proposal. - def potential_over_proposal(theta): - return torch.squeeze(potential_fn(theta)) - proposal.log_prob(theta) - - # Search for the maximum of the ratio. - _, max_log_ratio = optimize_potential_fn( - potential_fn=potential_over_proposal, - inits=find_max, - dist_specifying_bounds=proposal, - num_iter=100, - learning_rate=0.01, - num_to_optimize=max(1, int(num_samples_to_find_max / 10)), - show_progress_bars=False, - ) - - if m < 1.0: - warn("A value of m < 1.0 will lead to systematically wrong results.") - - class ScaledProposal: - def __init__(self, proposal: Any, max_log_ratio: float, log_m: float): - self.proposal = proposal - self.max_log_ratio = max_log_ratio - self.log_m = log_m - - def sample(self, sample_shape, **kwargs): - return self.proposal.sample((sample_shape,), **kwargs) - - def log_prob(self, theta, **kwargs): - return self.proposal.log_prob(theta) + self.max_log_ratio + self.log_m - - scaled_proposal = ScaledProposal( - proposal, max_log_ratio, torch.log(torch.as_tensor(m)) - ) - - samples, _ = rejection_sample_raw( - potential_fn, scaled_proposal, num_samples=num_samples - ) - return samples - - num_accepted = 0 - all_ = [] - while num_accepted < num_samples: - candidates = proposal.sample((sampling_batch_size,)) - probs = potential_fn(candidates) - target_log_probs = potential_fn(candidates) - proposal_log_probs = proposal.log_prob(candidates) + max_ratio - target_proposal_ratio = exp(target_log_probs - proposal_log_probs) - acceptance = rand(target_proposal_ratio.shape) - accepted = candidates[target_proposal_ratio > acceptance] - num_accepted += accepted.shape[0] - all_.append(accepted) - samples = torch.cat(all_)[:num_samples] - return samples - class PotentialFunctionProvider: """ @@ -506,7 +431,7 @@ def __call__( likelihood_nn: Neural likelihood estimator that can be evaluated. x: Conditioning variable for posterior $p(\theta|x)$. Can be a batch of iid x. 
- mcmc_method: One of `slice_np`, `slice`, `hmc` or `nuts`. + method: One of `slice_np`, `slice`, `hmc` or `nuts`, `rejection`. Returns: Potential function for sampler. diff --git a/sbi/inference/posteriors/ratio_based_posterior.py b/sbi/inference/posteriors/ratio_based_posterior.py index 7c1b027e1..b529cbf1b 100644 --- a/sbi/inference/posteriors/ratio_based_posterior.py +++ b/sbi/inference/posteriors/ratio_based_posterior.py @@ -10,7 +10,7 @@ from sbi.inference.posteriors.base_posterior import NeuralPosterior from sbi.types import Shape -from sbi.utils import del_entries +from sbi.utils import del_entries, optimize_potential_fn, rejection_sample from sbi.utils.torchutils import ScalarFloat, atleast_2d, ensure_theta_batched @@ -34,8 +34,10 @@ def __init__( neural_net: nn.Module, prior, x_shape: torch.Size, + sample_with: str = "mcmc", mcmc_method: str = "slice_np", mcmc_parameters: Optional[Dict[str, Any]] = None, + rejection_sampling_parameters: Optional[Dict[str, Any]] = None, device: str = "cpu", ): """ @@ -43,7 +45,14 @@ def __init__( method_family: One of snpe, snl, snre_a or snre_b. neural_net: A classifier for SNRE, a density estimator for SNPE and SNL. prior: Prior distribution with `.log_prob()` and `.sample()`. - x_shape: Shape of a single simulator output. + x_shape: Shape of the simulated data. It can differ from the + observed data the posterior is conditioned on later in the batch + dimension. If it differs, the additional entries are interpreted as + independent and identically distributed data / trials. I.e., the data is + assumed to be generated based on the same (unknown) model parameters or + experimental condations. + sample_with: Method to use for sampling from the posterior. Must be one of + [`mcmc` | `rejection`]. mcmc_method: Method used for MCMC sampling, one of `slice_np`, `slice`, `hmc`, `nuts`. Currently defaults to `slice_np` for a custom numpy implementation of slice sampling; select `hmc`, `nuts` or `slice` for @@ -56,6 +65,13 @@ def __init__( will draw init locations from prior, whereas `sir` will use Sequential- Importance-Resampling using `init_strategy_num_candidates` to find init locations. + rejection_sampling_parameters: Dictionary overriding the default parameters + for rejection sampling. The following parameters are supported: + `proposal` as the proposal distribtution (default is the prior). + `num_samples_to_find_max` as the number of samples that are used to + find the maximum of the `potential_fn / proposal` ratio. + `sampling_batch_size` as the batchsize of samples being drawn from the + proposal at every iteration. device: Training device, e.g., cpu or cuda:0. """ kwargs = del_entries(locals(), entries=("self", "__class__")) @@ -115,10 +131,10 @@ def sample( sample_shape: Shape = torch.Size(), x: Optional[Tensor] = None, show_progress_bars: bool = True, - sample_with: str = "mcmc", + sample_with: Optional[str] = None, mcmc_method: Optional[str] = None, mcmc_parameters: Optional[Dict[str, Any]] = None, - rejection_parameters: Optional[Dict[str, Any]] = None, + rejection_sampling_parameters: Optional[Dict[str, Any]] = None, ) -> Tensor: r""" Return samples from posterior distribution $p(\theta|x)$ with MCMC. @@ -130,8 +146,8 @@ def sample( x: Conditioning context for posterior $p(\theta|x)$. If not provided, fall back onto `x` passed to `set_default_x()`. show_progress_bars: Whether to show sampling progress monitor. - sample_with: Method to use for sampling from the posterior. Must be in - [`mcmc` | `rejection` | `vi`]. 
+ sample_with: Method to use for sampling from the posterior. Must be one of + [`mcmc` | `rejection`]. mcmc_method: Optional parameter to override `self.mcmc_method`. mcmc_parameters: Dictionary overriding the default parameters for MCMC. The following parameters are supported: `thin` to set the thinning @@ -141,10 +157,13 @@ def sample( will draw init locations from prior, whereas `sir` will use Sequential- Importance-Resampling using `init_strategy_num_candidates` to find init locations. - rejection_parameters: Dictionary overriding the default parameters for - rejection sampling. The following parameters are supported: `m` as - multiplier to the maximum ratio between potential function and the - proposal. `proposal`, as the proposal distribtution. + rejection_sampling_parameters: Dictionary overriding the default parameters + for rejection sampling. The following parameters are supported: + `proposal` as the proposal distribtution (default is the prior). + `num_samples_to_find_max` as the number of samples that are used to + find the maximum of the `potential_fn / proposal` ratio. + `sampling_batch_size` as the batchsize of samples being drawn from the + proposal at every iteration. Returns: Samples from posterior. @@ -152,14 +171,19 @@ def sample( self.net.eval() + sample_with = sample_with if sample_with is not None else self._sample_with + x, num_samples = self._prepare_for_sample(x, sample_shape) + potential_fn_provider = PotentialFunctionProvider() if sample_with == "mcmc": - x, num_samples, mcmc_method, mcmc_parameters = self._prepare_for_sample( - x, sample_shape, mcmc_method, mcmc_parameters + mcmc_method, mcmc_parameters = self._potentially_replace_mcmc_parameters( + mcmc_method, mcmc_parameters ) samples = self._sample_posterior_mcmc( num_samples=num_samples, - potential_fn=potential_fn_provider(self._prior, self.net, x, mcmc_method), + potential_fn=potential_fn_provider( + self._prior, self.net, x, mcmc_method + ), init_fn=self._build_mcmc_init_fn( self._prior, potential_fn_provider(self._prior, self.net, x, "slice_np"), @@ -170,16 +194,25 @@ def sample( **mcmc_parameters, ) elif sample_with == "rejection": - samples = rejection_sample( - num_samples=num_samples, + rejection_sampling_parameters = ( + self._potentially_replace_rejection_parameters( + rejection_sampling_parameters + ) + ) + if "proposal" not in rejection_sampling_parameters: + rejection_sampling_parameters["proposal"] = self._prior + + samples, _ = rejection_sample( potential_fn=potential_fn_provider( - self._prior, self.net, x, "slice_np" + self._prior, self.net, x, "rejection" ), - proposal=self._prior, - m=1.0, + num_samples=num_samples, + **rejection_sampling_parameters, ) else: - raise NameError("The only implemented sampling methods are `mcmc` and `rejection`.") + raise NameError( + "The only implemented sampling methods are `mcmc` and `rejection`." + ) self.net.train(True) @@ -191,9 +224,11 @@ def sample_conditional( condition: Tensor, dims_to_sample: List[int], x: Optional[Tensor] = None, + sample_with: str = "mcmc", show_progress_bars: bool = True, mcmc_method: Optional[str] = None, mcmc_parameters: Optional[Dict[str, Any]] = None, + rejection_sampling_parameters: Optional[Dict[str, Any]] = None, ) -> Tensor: r""" Return samples from conditional posterior $p(\theta_i|\theta_j, x)$. @@ -217,6 +252,9 @@ def sample_conditional( `condition`. x: Conditioning context for posterior $p(\theta|x)$. If not provided, fall back onto `x` passed to `set_default_x()`. 
+            sample_with: Method to use for sampling from the posterior. Must be one of
+                [`mcmc` | `rejection`]. In this method, the value of
+                `self._sample_with` will be ignored.
             show_progress_bars: Whether to show sampling progress monitor.
             mcmc_method: Optional parameter to override `self.mcmc_method`.
             mcmc_parameters: Dictionary overriding the default parameters for MCMC.
@@ -227,6 +265,13 @@ def sample_conditional(
                 will draw init locations from prior, whereas `sir` will use Sequential-
                 Importance-Resampling using `init_strategy_num_candidates` to find init
                 locations.
+            rejection_sampling_parameters: Dictionary overriding the default parameters
+                for rejection sampling. The following parameters are supported:
+                `proposal` as the proposal distribution (default is the prior).
+                `num_samples_to_find_max` as the number of samples that are used to
+                find the maximum of the `potential_fn / proposal` ratio.
+                `sampling_batch_size` as the batch size of samples being drawn from the
+                proposal at every iteration.
 
         Returns:
             Samples from conditional posterior.
@@ -238,9 +283,11 @@ def sample_conditional(
             condition,
             dims_to_sample,
             x,
+            sample_with,
             show_progress_bars,
             mcmc_method,
             mcmc_parameters,
+            rejection_sampling_parameters,
         )
 
     def map(
@@ -404,7 +451,7 @@ def __call__(
         prior,
         classifier: nn.Module,
         x: Tensor,
-        mcmc_method: str,
+        method: str,
    ) -> Callable:
         r"""Return potential function for posterior $p(\theta|x)$.
 
@@ -415,7 +462,7 @@ def __call__(
             classifier: Binary classifier approximating the likelihood up to a constant.
             x: Conditioning variable for posterior $p(\theta|x)$.
-            mcmc_method: One of `slice_np`, `slice`, `hmc` or `nuts`.
+            method: One of `slice_np`, `slice`, `hmc` or `nuts`, `rejection`.
 
         Returns:
             Potential function for sampler.
@@ -426,10 +473,35 @@ def __call__(
         self.device = next(classifier.parameters()).device
         self.x = atleast_2d(x).to(self.device)
 
-        if mcmc_method in ("slice", "hmc", "nuts"):
+        if method in ("slice", "hmc", "nuts"):
             return self.pyro_potential
-        else:
+        elif "slice_np" in method:
             return self.np_potential
+        elif method == "rejection":
+            return self.rejection_potential
+        else:
+            raise NotImplementedError
+
+    def rejection_potential(self, theta: np.array) -> ScalarFloat:
+        r"""Return posterior log prob. of theta $p(\theta|x)$.
+
+        The only difference to the `np_potential` is that it tracks the gradients.
+
+        Args:
+            theta: Parameters $\theta$, batch dimension 1.
+
+        Returns:
+            Posterior log probability of theta, $-\infty$ if impossible under prior.
+        """
+        theta = torch.as_tensor(theta, dtype=torch.float32)
+        theta = ensure_theta_batched(theta)
+
+        log_ratio = RatioBasedPosterior._log_ratios_over_trials(
+            self.x, theta, self.classifier, track_gradients=True
+        )
+
+        # Notice opposite sign to pyro potential.
+        return log_ratio.cpu() + self.prior.log_prob(theta)
 
     def np_potential(self, theta: np.array) -> ScalarFloat:
         """Return potential for Numpy slice sampler."
@@ -472,7 +544,7 @@ def pyro_potential(self, theta: Dict[str, Tensor]) -> Tensor: theta = ensure_theta_batched(theta) log_ratio = RatioBasedPosterior._log_ratios_over_trials( - self.x, theta, self.classifier, track_gradients=False + self.x, theta, self.classifier, track_gradients=True ) return -(log_ratio.cpu() + self.prior.log_prob(theta)) diff --git a/sbi/inference/snle/snle_base.py b/sbi/inference/snle/snle_base.py index 7a82570af..037d81273 100644 --- a/sbi/inference/snle/snle_base.py +++ b/sbi/inference/snle/snle_base.py @@ -276,7 +276,7 @@ def build_posterior( Args: density_estimator: The density estimator that the posterior is based on. If `None`, use the latest neural density estimator that was trained. - sample_with: Method to use for sampling from the posterior. Must be in + sample_with: Method to use for sampling from the posterior. Must be one of [`mcmc` | `rejection`]. mcmc_method: Method used for MCMC sampling, one of `slice_np`, `slice`, `hmc`, `nuts`. Currently defaults to `slice_np` for a custom numpy @@ -290,13 +290,13 @@ def build_posterior( draw init locations from prior, whereas `sir` will use Sequential-Importance-Resampling using `init_strategy_num_candidates` to find init locations. - rejection_sampling_parameters: Dictionary overriding the default parameters for - rejection sampling. The following parameters are supported: - `proposal`, as the proposal distribtution. `num_samples_to_find_max` - as the number of samples that are used to find the maximum of the - `potential_fn / proposal` ratio. `m` as multiplier to that ratio. - `sampling_batch_size` as the batchsize of samples being drawn from - the proposal at every iteration. + rejection_sampling_parameters: Dictionary overriding the default parameters + for rejection sampling. The following parameters are supported: + `proposal` as the proposal distribtution (default is the prior). + `num_samples_to_find_max` as the number of samples that are used to + find the maximum of the `potential_fn / proposal` ratio. + `sampling_batch_size` as the batchsize of samples being drawn from the + proposal at every iteration. Returns: Posterior $p(\theta|x)$ with `.sample()` and `.log_prob()` methods diff --git a/sbi/inference/snpe/snpe_a.py b/sbi/inference/snpe/snpe_a.py index 39ef6f660..54420f9a8 100644 --- a/sbi/inference/snpe/snpe_a.py +++ b/sbi/inference/snpe/snpe_a.py @@ -298,8 +298,8 @@ def build_posterior( neural_net=wrapped_density_estimator, prior=self._prior, x_shape=self._x_shape, + sample_with="rejection", rejection_sampling_parameters=rejection_sampling_parameters, - sample_with_mcmc=False, device=device, ) diff --git a/sbi/inference/snpe/snpe_base.py b/sbi/inference/snpe/snpe_base.py index 98faf6c58..dfbbc757a 100644 --- a/sbi/inference/snpe/snpe_base.py +++ b/sbi/inference/snpe/snpe_base.py @@ -347,10 +347,11 @@ def train( def build_posterior( self, density_estimator: Optional[TorchModule] = None, - rejection_sampling_parameters: Optional[Dict[str, Any]] = None, - sample_with_mcmc: bool = False, + sample_with: str = "rejection", mcmc_method: str = "slice_np", mcmc_parameters: Optional[Dict[str, Any]] = None, + rejection_sampling_parameters: Optional[Dict[str, Any]] = None, + sample_with_mcmc: Optional[bool] = None, ) -> DirectPosterior: r""" Build posterior from the neural density estimator. @@ -367,13 +368,8 @@ def build_posterior( Args: density_estimator: The density estimator that the posterior is based on. If `None`, use the latest neural density estimator that was trained. 
- rejection_sampling_parameters: Dictionary overriding the default parameters - for rejection sampling. The following parameters are supported: - `max_sampling_batch_size` to set the batch size for drawing new - samples from the candidate distribution, e.g., the posterior. Larger - batch size speeds up sampling. - sample_with_mcmc: Whether to sample with MCMC. MCMC can be used to deal - with high leakage. + sample_with: Method to use for sampling from the posterior. Must be one of + [`rejection` | `mcmc`]. mcmc_method: Method used for MCMC sampling, one of `slice_np`, `slice`, `hmc`, `nuts`. Currently defaults to `slice_np` for a custom numpy implementation of slice sampling; select `hmc`, `nuts` or `slice` for @@ -386,10 +382,27 @@ def build_posterior( draw init locations from prior, whereas `sir` will use Sequential-Importance-Resampling using `init_strategy_num_candidates` to find init locations. + rejection_sampling_parameters: Dictionary overriding the default parameters + for rejection sampling. The following parameters are supported: + `proposal` as the proposal distribtution (default is the trained + neural net). `num_samples_to_find_max` as the number of samples that + are used to find the maximum of the `potential_fn / proposal` ratio. + `sampling_batch_size` as the batchsize of samples being drawn from the + proposal at every iteration. + sample_with_mcmc: Deprecated since `sbi v0.16.0`. Use `sample_with=mcmc` + instead. Returns: Posterior $p(\theta|x)$ with `.sample()` and `.log_prob()` methods. """ + if sample_with_mcmc is not None: + warn( + f"You set `sample_with_mcmc={sample_with_mcmc}`. This is deprecated " + "since `sbi v0.16.0` and will lead to an error in future versions. " + "Please use `sample_with=mcmc` instead." + ) + if sample_with_mcmc: + sample_with = "mcmc" if density_estimator is None: density_estimator = self._neural_net @@ -404,10 +417,10 @@ def build_posterior( neural_net=density_estimator, prior=self._prior, x_shape=self._x_shape, - rejection_sampling_parameters=rejection_sampling_parameters, - sample_with_mcmc=sample_with_mcmc, + sample_with=sample_with, mcmc_method=mcmc_method, mcmc_parameters=mcmc_parameters, + rejection_sampling_parameters=rejection_sampling_parameters, device=device, ) diff --git a/sbi/inference/snre/snre_base.py b/sbi/inference/snre/snre_base.py index 40f0f317b..fdfa26f56 100644 --- a/sbi/inference/snre/snre_base.py +++ b/sbi/inference/snre/snre_base.py @@ -267,8 +267,10 @@ def train( def build_posterior( self, density_estimator: Optional[TorchModule] = None, + sample_with: str = "mcmc", mcmc_method: str = "slice_np", mcmc_parameters: Optional[Dict[str, Any]] = None, + rejection_sampling_parameters: Optional[Dict[str, Any]] = None, ) -> RatioBasedPosterior: r""" Build posterior from the neural density estimator. @@ -285,6 +287,8 @@ def build_posterior( Args: density_estimator: The density estimator that the posterior is based on. If `None`, use the latest neural density estimator that was trained. + sample_with: Method to use for sampling from the posterior. Must be one of + [`mcmc` | `rejection`]. mcmc_method: Method used for MCMC sampling, one of `slice_np`, `slice`, `hmc`, `nuts`. Currently defaults to `slice_np` for a custom numpy implementation of slice sampling; select `hmc`, `nuts` or `slice` for @@ -297,6 +301,13 @@ def build_posterior( draw init locations from prior, whereas `sir` will use Sequential-Importance-Resampling using `init_strategy_num_candidates` to find init locations. 
+ rejection_sampling_parameters: Dictionary overriding the default parameters + for rejection sampling. The following parameters are supported: + `proposal` as the proposal distribtution (default is the prior). + `num_samples_to_find_max` as the number of samples that are used to + find the maximum of the `potential_fn / proposal` ratio. + `sampling_batch_size` as the batchsize of samples being drawn from the + proposal at every iteration. Returns: Posterior $p(\theta|x)$ with `.sample()` and `.log_prob()` methods @@ -316,8 +327,10 @@ def build_posterior( neural_net=density_estimator, prior=self._prior, x_shape=self._x_shape, + sample_with=sample_with, mcmc_method=mcmc_method, mcmc_parameters=mcmc_parameters, + rejection_sampling_parameters=rejection_sampling_parameters, device=device, ) diff --git a/sbi/mcmc/__init__.py b/sbi/mcmc/__init__.py index 83a1e2b89..7c59fb054 100644 --- a/sbi/mcmc/__init__.py +++ b/sbi/mcmc/__init__.py @@ -1,5 +1,4 @@ from sbi.mcmc.init_strategy import IterateParameters, prior_init, sir -from sbi.mcmc.rejection_sampling import rejection_sample from sbi.mcmc.slice import Slice from sbi.mcmc.slice_numpy import SliceSampler from sbi.mcmc.slice_numpy_vectorized import SliceSamplerVectorized diff --git a/sbi/mcmc/rejection_sampling.py b/sbi/mcmc/rejection_sampling.py deleted file mode 100644 index 9121e1772..000000000 --- a/sbi/mcmc/rejection_sampling.py +++ /dev/null @@ -1,11 +0,0 @@ -from typing import Any -from warnings import warn - -import torch -from torch import Tensor, exp, log, rand - -from sbi.utils import ( - optimize_potential_fn, - rejection_sample_raw, - sample_posterior_within_prior, -) diff --git a/sbi/utils/__init__.py b/sbi/utils/__init__.py index 91175b34d..7127d79c0 100644 --- a/sbi/utils/__init__.py +++ b/sbi/utils/__init__.py @@ -21,8 +21,7 @@ mask_sims_from_prior, mog_log_prob, optimize_potential_fn, - rejection_sample_raw, - sample_posterior_within_prior, + rejection_sample, standardizing_net, standardizing_transform, warn_if_zscoring_changes_data, diff --git a/sbi/utils/sbiutils.py b/sbi/utils/sbiutils.py index d86b9312f..62b729822 100644 --- a/sbi/utils/sbiutils.py +++ b/sbi/utils/sbiutils.py @@ -146,19 +146,19 @@ def standardizing_net(batch_t: Tensor, min_std: float = 1e-7) -> nn.Module: return Standardize(t_mean, t_std) -def rejection_sample_raw( - potential_fn: nn.Module, - proposal, +def rejection_sample( + potential_fn: Callable, + proposal: Any, num_samples: int = 1, show_progress_bars: bool = False, warn_acceptance: float = 0.01, sample_for_correction_factor: bool = False, max_sampling_batch_size: int = 10_000, -): - r"""Return samples from a posterior $p(\theta|x)$ only within the prior support. - - This is relevant for snpe methods and flows for which the posterior tends to have - mass outside the prior boundaries. + binary_rejection_criterion: bool = False, + num_samples_to_find_max: int = 10_000, + m: float = 1.2, +) -> Tuple[Tensor, Tensor]: + r"""Return samples from a `potential_fn` obtained via rejection sampling. This function uses rejection sampling with samples from posterior in order to 1) obtain posterior samples within the prior support, and @@ -166,103 +166,164 @@ def rejection_sample_raw( density during evaluation of the posterior. Args: - posterior_nn: Neural net representing the posterior. - prior: Distribution-like object that evaluates probabilities with `log_prob`. + potential_fn: The potential to sample from. The potential should be passed as + the logarithm of the desired distribution. 
+ proposal: The proposal from which to draw candidate samples. Must have a + `sample()` and a `log_prob()` method. num_samples: Desired number of samples. show_progress_bars: Whether to show a progressbar during sampling. warn_acceptance: A minimum acceptance rate under which to warn about slowness. sample_for_correction_factor: True if this function was called by `leakage_correction()`. False otherwise. Will be used to adapt the leakage - warning. + warning and to decide whether we have to search for the maximum. max_sampling_batch_size: Batch size for drawing samples from the posterior. Takes effect only in the second iteration of the loop below, i.e., in case of leakage or `num_samples>max_sampling_batch_size`. Larger batch size speeds up sampling. + binary_rejection_criterion: If `True`, the proposal will not be scaled up and + samples will be rejected / accepted depending on whether the potential_fn + is `-inf` or not. If `True`, the `potential_fn` must have a `.log_prob()` + method. It will be set to `True` for `SNPE` when the proposal is the + neural network estimating the posterior. + num_samples_to_find_max: Number of samples that are used to find the maximum + of the `potential_fn / proposal` ratio. + m: Multiplier to the maximum ratio between potential function and the + proposal. This factor is applied after already having scaled the proposal + with the maximum ratio of the `potential_fn / proposal` ratio. A higher + value will ensure that the samples are indeed from the correct + distribution, but will increase the fraction of rejected samples and thus + computation time. Returns: Accepted samples and acceptance rate as scalar Tensor. """ - # Progress bar can be skipped, e.g. when sampling after each round just for logging. - pbar = tqdm( - disable=not show_progress_bars, - total=num_samples, - desc=f"Drawing {num_samples} posterior samples", - ) + if not binary_rejection_criterion and not sample_for_correction_factor: + find_max = proposal.sample((num_samples_to_find_max,)) + + # Define a potential as the ratio between target distribution and proposal. + def potential_over_proposal(theta): + return potential_fn(theta) - proposal.log_prob(theta) + + # Search for the maximum of the ratio. + _, max_log_ratio = optimize_potential_fn( + potential_fn=potential_over_proposal, + inits=find_max, + dist_specifying_bounds=proposal, + num_iter=100, + learning_rate=0.01, + num_to_optimize=max(1, int(num_samples_to_find_max / 10)), + show_progress_bars=False, + ) - num_sampled_total, num_remaining = 0, num_samples - accepted, acceptance_rate = [], float("Nan") - leakage_warning_raised = False + if m < 1.0: + warnings.warn( + "A value of m < 1.0 will lead to systematically wrong results." + ) + + class ScaledProposal: + def __init__(self, proposal: Any, max_log_ratio: float, log_m: float): + self.proposal = proposal + self.max_log_ratio = max_log_ratio + self.log_m = log_m - # To cover cases with few samples without leakage: - sampling_batch_size = min(num_samples, max_sampling_batch_size) - while num_remaining > 0: + def sample(self, sample_shape, **kwargs): + return self.proposal.sample((sample_shape,), **kwargs) - # Sample and reject. - candidates = ( - proposal.sample(sampling_batch_size) - .reshape(sampling_batch_size, -1) - .cpu() # Move to cpu to evaluate under prior. 
+ def log_prob(self, theta, **kwargs): + return self.proposal.log_prob(theta) + self.max_log_ratio + self.log_m + + proposal = ScaledProposal( + proposal, max_log_ratio, torch.log(torch.as_tensor(m)) ) - # are_within_prior = within_support(prior, candidates) - # samples = candidates[are_within_prior] - target_proposal_ratio = torch.exp( - potential_fn(candidates) - proposal.log_prob(candidates) + with torch.no_grad(): + # Progress bar can be skipped, e.g. when sampling after each round just for + # logging. + pbar = tqdm( + disable=not show_progress_bars, + total=num_samples, + desc=f"Drawing {num_samples} posterior samples", ) - uniform_rand = torch.rand(target_proposal_ratio.shape) - samples = candidates[target_proposal_ratio > uniform_rand] - - accepted.append(samples) - - # Update. - num_sampled_total += sampling_batch_size - num_remaining -= samples.shape[0] - pbar.update(samples.shape[0]) - - # To avoid endless sampling when leakage is high, we raise a warning if the - # acceptance rate is too low after the first 1_000 samples. - acceptance_rate = (num_samples - num_remaining) / num_sampled_total - - # For remaining iterations (leakage or many samples) continue sampling with - # fixed batch size. - sampling_batch_size = max_sampling_batch_size - if ( - num_sampled_total > 1000 - and acceptance_rate < warn_acceptance - and not leakage_warning_raised - ): - if sample_for_correction_factor: - logging.warning( - f"""Drawing samples from posterior to estimate the normalizing - constant for `log_prob()`. However, only {acceptance_rate:.0%} - posterior samples are within the prior support. It may take a - long time to collect the remaining {num_remaining} samples. - Consider interrupting (Ctrl-C) and either basing the estimate - of the normalizing constant on fewer samples (by calling - `posterior.leakage_correction(x_o, num_rejection_samples=N)`, - where `N` is the number of samples you want to base the - estimate on (default N=10000), or not estimating the - normalizing constant at all - (`log_prob(..., norm_posterior=False)`. The latter will result - in an unnormalized `log_prob()`.""" - ) + + num_sampled_total, num_remaining = 0, num_samples + accepted, acceptance_rate = [], float("Nan") + leakage_warning_raised = False + + # To cover cases with few samples without leakage: + sampling_batch_size = min(num_samples, max_sampling_batch_size) + while num_remaining > 0: + + # Sample and reject. + candidates = proposal.sample(sampling_batch_size).reshape( + sampling_batch_size, -1 + ) + + if binary_rejection_criterion: + # SNPE-style rejection-sampling when the proposal is the neural net. + candidates = candidates.cpu() + are_within_prior = within_support(potential_fn, candidates) + samples = candidates[are_within_prior] else: - logging.warning( - f"""Only {acceptance_rate:.0%} posterior samples are within the - prior support. It may take a long time to collect the remaining - {num_remaining} samples. Consider interrupting (Ctrl-C) - and switching to `sample_with_mcmc=True`.""" - ) - leakage_warning_raised = True # Ensure warning is raised just once. - - pbar.close() - - # When in case of leakage a batch size was used there could be too many samples. - samples = torch.cat(accepted)[:num_samples] - assert ( - samples.shape[0] == num_samples - ), "Number of accepted samples must match required samples." + # Classical rejection sampling. 
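+                # Standard accept/reject step: a candidate is kept whenever a uniform
+                # draw falls below the ratio of the target (given by `potential_fn`)
+                # to the (possibly scaled) proposal density.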
+                target_proposal_ratio = torch.exp(
+                    potential_fn(candidates) - proposal.log_prob(candidates)
+                ).cpu()
+                uniform_rand = torch.rand(target_proposal_ratio.shape)
+                samples = candidates.cpu()[target_proposal_ratio > uniform_rand]
+
+            accepted.append(samples)
+
+            # Update.
+            num_sampled_total += sampling_batch_size
+            num_remaining -= samples.shape[0]
+            pbar.update(samples.shape[0])
+
+            # To avoid endless sampling when leakage is high, we raise a warning if the
+            # acceptance rate is too low after the first 1_000 samples.
+            acceptance_rate = (num_samples - num_remaining) / num_sampled_total
+
+            # For remaining iterations (leakage or many samples) continue sampling with
+            # fixed batch size.
+            sampling_batch_size = max_sampling_batch_size
+            if (
+                num_sampled_total > 1000
+                and acceptance_rate < warn_acceptance
+                and not leakage_warning_raised
+            ):
+                if sample_for_correction_factor:
+                    logging.warning(
+                        f"""Drawing samples from posterior to estimate the normalizing
+                            constant for `log_prob()`. However, only
+                            {acceptance_rate:.0%} posterior samples are within the
+                            prior support. It may take a long time to collect the
+                            remaining {num_remaining} samples.
+                            Consider interrupting (Ctrl-C) and either basing the
+                            estimate of the normalizing constant on fewer samples (by
+                            calling `posterior.leakage_correction(x_o,
+                            num_rejection_samples=N)`, where `N` is the number of
+                            samples you want to base the
+                            estimate on (default N=10000), or not estimating the
+                            normalizing constant at all
+                            (`log_prob(..., norm_posterior=False)`). The latter will
+                            result in an unnormalized `log_prob()`."""
+                    )
+                else:
+                    logging.warning(
+                        f"""Only {acceptance_rate:.0%} posterior samples are within the
+                            prior support. It may take a long time to collect the
+                            remaining {num_remaining} samples. Consider interrupting
+                            (Ctrl-C) and switching to `sample_with='mcmc'`."""
+                    )
+                leakage_warning_raised = True  # Ensure warning is raised just once.
+
+        pbar.close()
+
+        # In case of leakage, batched sampling may have produced too many samples.
+        samples = torch.cat(accepted)[:num_samples]
+        assert (
+            samples.shape[0] == num_samples
+        ), "Number of accepted samples must match required samples."
 
     return samples, as_tensor(acceptance_rate)
 
 
diff --git a/tests/inference_on_device_test.py b/tests/inference_on_device_test.py
index 804d89a31..ff1fda567 100644
--- a/tests/inference_on_device_test.py
+++ b/tests/inference_on_device_test.py
@@ -60,7 +60,7 @@ def simulator(theta):
     )
     mcmc_kwargs = (
         dict(
-            sample_with_mcmc=True,
+            sample_with="mcmc",
             mcmc_method="slice_np",
         )
         if method == SNPE_C
diff --git a/tests/linearGaussian_snle_test.py b/tests/linearGaussian_snle_test.py
index 1e0d9e2e8..bb95255b9 100644
--- a/tests/linearGaussian_snle_test.py
+++ b/tests/linearGaussian_snle_test.py
@@ -238,13 +238,13 @@ def test_c2st_multi_round_snl_on_linearGaussian(num_trials: int, set_seed):
     check_c2st(samples, target_samples, alg="multi-round-snl")
 
 
-# TODO: add test for rejection sampling.
 @pytest.mark.slow
 @pytest.mark.parametrize("prior_str", ("gaussian", "uniform"))
 @pytest.mark.parametrize(
-    "mcmc_method", ("slice_np", "slice_np_vectorized", "slice", "nuts", "hmc")
+    "sampling_method",
+    ("slice_np", "slice_np_vectorized", "slice", "nuts", "hmc", "rejection"),
 )
-def test_api_snl_sampling_methods(mcmc_method: str, prior_str: str, set_seed):
+def test_api_snl_sampling_methods(sampling_method: str, prior_str: str, set_seed):
     """Runs SNL on linear Gaussian and tests sampling from posterior via mcmc.
Args: @@ -257,10 +257,14 @@ def test_api_snl_sampling_methods(mcmc_method: str, prior_str: str, set_seed): num_samples = 10 num_trials = 2 # HMC with uniform prior needs good likelihood. - num_simulations = 10000 if mcmc_method == "hmc" else 1000 + num_simulations = 10000 if sampling_method == "hmc" else 1000 x_o = zeros((num_trials, num_dim)) # Test for multiple chains is cheap when vectorized. - num_chains = 3 if mcmc_method == "slice_np_vectorized" else 1 + num_chains = 3 if sampling_method == "slice_np_vectorized" else 1 + if sampling_method == "rejection": + sample_with = "rejection" + else: + sample_with = "mcmc" if prior_str == "gaussian": prior = MultivariateNormal(loc=zeros(num_dim), covariance_matrix=eye(num_dim)) @@ -274,7 +278,9 @@ def test_api_snl_sampling_methods(mcmc_method: str, prior_str: str, set_seed): simulator, prior, num_simulations, simulation_batch_size=50 ) _ = inference.append_simulations(theta, x).train(max_num_epochs=5) - posterior = inference.build_posterior(mcmc_method=mcmc_method).set_default_x(x_o) + posterior = inference.build_posterior( + sample_with=sample_with, mcmc_method=sampling_method + ).set_default_x(x_o) posterior.sample( sample_shape=(num_samples,), diff --git a/tests/linearGaussian_snpe_test.py b/tests/linearGaussian_snpe_test.py index 2af04c5d7..fc8c9e0a6 100644 --- a/tests/linearGaussian_snpe_test.py +++ b/tests/linearGaussian_snpe_test.py @@ -290,17 +290,17 @@ def test_c2st_multi_round_snpe_on_linearGaussian(method_str: str, set_seed): @pytest.mark.slow @pytest.mark.parametrize("snpe_method", [SNPE_A, SNPE_C]) @pytest.mark.parametrize( - "sample_with_mcmc, mcmc_method, prior_str", + "sample_with, mcmc_method, prior_str", ( - (True, "slice_np", "gaussian"), - (True, "slice", "gaussian"), + ("mcmc", "slice_np", "gaussian"), + ("mcmc", "slice", "gaussian"), # XXX (True, "slice", "uniform"), # XXX takes very long. fix when refactoring pyro sampling - (False, "rejection", "uniform"), + ("rejection", "rejection", "uniform"), ), ) def test_api_snpe_c_posterior_correction( - snpe_method: type, sample_with_mcmc, mcmc_method, prior_str, set_seed + snpe_method: type, sample_with, mcmc_method, prior_str, set_seed ): """Test that leakage correction applied to sampling works, with both MCMC and rejection. @@ -329,7 +329,7 @@ def test_api_snpe_c_posterior_correction( inference = snpe_method( prior, simulation_batch_size=50, - sample_with_mcmc=sample_with_mcmc, + sample_with=sample_with, mcmc_method=mcmc_method, show_progress_bars=False, ) @@ -338,9 +338,7 @@ def test_api_snpe_c_posterior_correction( _ = inference.append_simulations(theta, x).train(max_num_epochs=5) posterior = inference.build_posterior() - posterior = posterior.set_sample_with_mcmc(sample_with_mcmc).set_mcmc_method( - mcmc_method - ) + posterior = posterior.set_sample_with(sample_with).set_mcmc_method(mcmc_method) # Posterior should be corrected for leakage even if num_rounds just 1. 
samples = posterior.sample((10,), x=x_o) diff --git a/tests/linearGaussian_snre_test.py b/tests/linearGaussian_snre_test.py index b3490ac98..2494dad6c 100644 --- a/tests/linearGaussian_snre_test.py +++ b/tests/linearGaussian_snre_test.py @@ -228,16 +228,12 @@ def simulator(theta): @pytest.mark.slow +@pytest.mark.parametrize("prior_str", ("gaussian", "uniform")) @pytest.mark.parametrize( - "mcmc_method, prior_str", - ( - ("slice_np", "gaussian"), - ("slice_np", "uniform"), - ("slice", "gaussian"), - ("slice", "uniform"), - ), + "sampling_method", + ("slice_np", "slice_np_vectorized", "slice", "nuts", "hmc", "rejection"), ) -def test_api_sre_sampling_methods(mcmc_method: str, prior_str: str, set_seed): +def test_api_sre_sampling_methods(sampling_method: str, prior_str: str, set_seed): """Test leakage correction both for MCMC and rejection sampling. Args: @@ -245,23 +241,37 @@ def test_api_sre_sampling_methods(mcmc_method: str, prior_str: str, set_seed): prior_str: one of "gaussian" or "uniform" set_seed: fixture for manual seeding """ - num_dim = 2 - x_o = zeros(num_dim) + num_samples = 10 + num_trials = 2 + # HMC with uniform prior needs good likelihood. + num_simulations = 10000 if sampling_method == "hmc" else 1000 + x_o = zeros((num_trials, num_dim)) + # Test for multiple chains is cheap when vectorized. + num_chains = 3 if sampling_method == "slice_np_vectorized" else 1 + if sampling_method == "rejection": + sample_with = "rejection" + else: + sample_with = "mcmc" + if prior_str == "gaussian": prior = MultivariateNormal(loc=zeros(num_dim), covariance_matrix=eye(num_dim)) else: - prior = utils.BoxUniform(low=-1.0 * ones(num_dim), high=ones(num_dim)) + prior = utils.BoxUniform(-1.0 * ones(num_dim), ones(num_dim)) simulator, prior = prepare_for_sbi(diagonal_linear_gaussian, prior) - inference = SNRE_B( - prior, - classifier="resnet", - show_progress_bars=False, - ) + inference = SNRE_B(prior, classifier="resnet", show_progress_bars=False) - theta, x = simulate_for_sbi(simulator, prior, 200, simulation_batch_size=50) + theta, x = simulate_for_sbi( + simulator, prior, num_simulations, simulation_batch_size=50 + ) _ = inference.append_simulations(theta, x).train(max_num_epochs=5) - posterior = inference.build_posterior(mcmc_method=mcmc_method) + posterior = inference.build_posterior( + sample_with=sample_with, mcmc_method=sampling_method + ).set_default_x(x_o) - posterior.sample(sample_shape=(10,), x=x_o) + posterior.sample( + sample_shape=(num_samples,), + x=x_o, + mcmc_parameters={"thin": 3, "num_chains": num_chains}, + )
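
A minimal usage sketch of the refactored rejection sampler (illustrative only, not part of the patch: the Gaussian target and box-shaped proposal are assumptions; `rejection_sample` is re-exported from `sbi.utils` per the `__init__.py` hunk above, and `BoxUniform` is used the same way as in the tests):

    import torch
    from torch.distributions import MultivariateNormal

    from sbi.utils import BoxUniform, rejection_sample

    # Toy target (a standard Gaussian) and a box-shaped proposal covering its mass.
    target = MultivariateNormal(torch.zeros(2), torch.eye(2))
    proposal = BoxUniform(low=-3.0 * torch.ones(2), high=3.0 * torch.ones(2))

    # `potential_fn` is the log of the (unnormalized) target density; `m` inflates the
    # proposal after the maximum of the `potential_fn / proposal` ratio has been
    # estimated from `num_samples_to_find_max` proposal samples.
    samples, acceptance_rate = rejection_sample(
        potential_fn=target.log_prob,
        proposal=proposal,
        num_samples=1_000,
        num_samples_to_find_max=10_000,
        m=1.2,
    )

The high-level classes touched above reach the same code path via `build_posterior(sample_with="rejection")`, tunable through `rejection_sampling_parameters`.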