diff --git a/sbi/inference/posteriors/base_posterior.py b/sbi/inference/posteriors/base_posterior.py index b18dfd729..c33998b7a 100644 --- a/sbi/inference/posteriors/base_posterior.py +++ b/sbi/inference/posteriors/base_posterior.py @@ -1221,38 +1221,3 @@ def log_prob(self, *args, **kwargs): the $\theta$ under the full joint once we have added the condition. """ return self.full_prior.log_prob(*args, **kwargs) - - -class NeuralNetDefaultX: - def __init__(self, density_estimator: Any, x: Tensor) -> None: - r""" - Wraps the neural density estimator returned by `nflows` to have a default $x$. - - Currently, it is used only as `proposal` for rejection sampling with `SNPE`. - - Args: - density_estimator: The neural density estimator parameterizing $p(y|x)$. - x: The value for $x$ at which to evaluate or sample $p(y|x)$. - """ - self.net = density_estimator - self.x = x - - def sample(self, num_samples, **kwargs) -> Tensor: - """ - Return samples from $p(y|x)$. - - Args: - num_samples: Number of samples to return. - """ - s = self.net.sample(num_samples, context=self.x, **kwargs) - return s - - def log_prob(self, y, **kwargs) -> Tensor: - r""" - Return the log-probabilities $\log(p(y|x))$. - - Args: - y: Location at which to evlauate the log-probability. - """ - ys, xs = NeuralPosterior._match_theta_and_x_batch_shapes(y, self.x) - return self.net.log_prob(ys, context=xs) diff --git a/sbi/inference/posteriors/direct_posterior.py b/sbi/inference/posteriors/direct_posterior.py index 01eacfbbe..2121b11c0 100644 --- a/sbi/inference/posteriors/direct_posterior.py +++ b/sbi/inference/posteriors/direct_posterior.py @@ -240,12 +240,10 @@ def leakage_correction( def acceptance_at(x: Tensor) -> Tensor: - potential_fn_provider = PotentialFunctionProvider() - return utils.rejection_sample( - potential_fn=potential_fn_provider( - self._prior, self.net, x.to(self._device), "rejection" - ), - proposal=NeuralNetDefaultX(self.net, x.to(self._device)), + return utils.rejection_sample_posterior_within_prior( + posterior_nn=self.net, + prior=self._prior, + x=x.to(self._device), num_samples=num_rejection_samples, show_progress_bars=show_progress_bars, sample_for_correction_factor=True, @@ -362,31 +360,24 @@ def sample( not self.net.training ), "Posterior nn must be in eval mode for sampling." - # This case covers the scenario where we sample the "posterior within - # the prior". We simply want to draw samples from the posterior and - # reject them if the prior has `-inf` log-probability. It would be - # possible to implement this via standard rejection sampling by setting - # the potential function to a masked version of the neural net. - # However, this would require evaluating the proposal and potential, - # which generates overhead. Instead, we built an if-else case into the - # `rejection_sample()` method which is triggered by - # `binary_rejection_criterion=True`(see its docstring for more info). 
- rejection_sampling_parameters["binary_rejection_criterion"] = True - rejection_sampling_parameters["proposal"] = NeuralNetDefaultX( - self.net, x - ) - potential_fn = self._prior + samples = utils.rejection_sample_posterior_within_prior( + posterior_nn=self.net, + prior=self._prior, + x=x.to(self._device), + num_samples=num_samples, + show_progress_bars=show_progress_bars, + sample_for_correction_factor=True, + **rejection_sampling_parameters, + )[0] else: - potential_fn = ( - potential_fn_provider(self._prior, self.net, x, "rejection"), + samples, _ = rejection_sample( + potential_fn=potential_fn_provider( + self._prior, self.net, x, "rejection" + ), + num_samples=num_samples, + show_progress_bars=show_progress_bars, + **rejection_sampling_parameters, ) - - samples, _ = rejection_sample( - potential_fn=potential_fn, - num_samples=num_samples, - show_progress_bars=show_progress_bars, - **rejection_sampling_parameters, - ) else: raise NameError( "The only implemented sampling methods are `mcmc` and `rejection`." diff --git a/sbi/utils/__init__.py b/sbi/utils/__init__.py index 7127d79c0..a77d5fce1 100644 --- a/sbi/utils/__init__.py +++ b/sbi/utils/__init__.py @@ -22,6 +22,7 @@ mog_log_prob, optimize_potential_fn, rejection_sample, + rejection_sample_posterior_within_prior, standardizing_net, standardizing_transform, warn_if_zscoring_changes_data, diff --git a/sbi/utils/sbiutils.py b/sbi/utils/sbiutils.py index eadd6f65d..0c5b9e6c6 100644 --- a/sbi/utils/sbiutils.py +++ b/sbi/utils/sbiutils.py @@ -146,13 +146,150 @@ def standardizing_net(batch_t: Tensor, min_std: float = 1e-7) -> nn.Module: return Standardize(t_mean, t_std) +@torch.no_grad() +def rejection_sample_posterior_within_prior( + posterior_nn: Any, + prior: Callable, + x: Tensor, + num_samples: int, + show_progress_bars: bool = False, + warn_acceptance: float = 0.01, + sample_for_correction_factor: bool = False, + max_sampling_batch_size: int = 10_000, + **kwargs, +) -> Tuple[Tensor, Tensor]: + r"""Return samples from a posterior $p(\theta|x)$ only within the prior support. + + This is relevant for snpe methods and flows for which the posterior tends to have + mass outside the prior boundaries. + + This function could in principle be integrated into `rejection_sample()`. However, + to keep the warnings clean, to avoid additional code for integration, and confusing + if-cases, we decided to keep two separate functions. + + This function uses rejection sampling with samples from posterior in order to + 1) obtain posterior samples within the prior support, and + 2) calculate the fraction of accepted samples as a proxy for correcting the + density during evaluation of the posterior. + + Args: + posterior_nn: Neural net representing the posterior. + prior: Distribution-like object that evaluates probabilities with `log_prob`. + x: Conditioning variable $x$ for the posterior $p(\theta|x)$. + num_samples: Desired number of samples. + show_progress_bars: Whether to show a progressbar during sampling. + warn_acceptance: A minimum acceptance rate under which to warn about slowness. + sample_for_correction_factor: True if this function was called by + `leakage_correction()`. False otherwise. Will be used to adapt the leakage + warning and to decide whether we have to search for the maximum. + max_sampling_batch_size: Batch size for drawing samples from the posterior. + Takes effect only in the second iteration of the loop below, i.e., in case + of leakage or `num_samples>max_sampling_batch_size`. Larger batch size + speeds up sampling. 
+        kwargs: Absorb additional unused arguments that can be passed to
+            `rejection_sample()`. Warn if not empty.
+
+    Returns:
+        Accepted samples and acceptance rate as scalar Tensor.
+    """
+
+    if kwargs:
+        logging.warning(
+            f"You passed arguments to `rejection_sampling_parameters` that "
+            f"are unused when you do not specify a `proposal` in the same "
+            f"dictionary. The unused arguments are: {kwargs}"
+        )
+
+    # Progress bar can be skipped, e.g. when sampling after each round just for
+    # logging.
+    pbar = tqdm(
+        disable=not show_progress_bars,
+        total=num_samples,
+        desc=f"Drawing {num_samples} posterior samples",
+    )
+
+    num_sampled_total, num_remaining = 0, num_samples
+    accepted, acceptance_rate = [], float("Nan")
+    leakage_warning_raised = False
+
+    # To cover cases with few samples without leakage:
+    sampling_batch_size = min(num_samples, max_sampling_batch_size)
+    while num_remaining > 0:
+
+        # Draw candidates from the posterior net; this is SNPE-style rejection
+        # sampling in which the neural net itself acts as the proposal.
+        candidates = (
+            posterior_nn.sample(sampling_batch_size, context=x)
+            .reshape(sampling_batch_size, -1)
+            .cpu()
+        )
+
+        # Reject candidates that fall outside the prior support.
+        are_within_prior = within_support(prior, candidates)
+        samples = candidates[are_within_prior]
+
+        accepted.append(samples)
+
+        # Update.
+        num_sampled_total += sampling_batch_size
+        num_remaining -= samples.shape[0]
+        pbar.update(samples.shape[0])
+
+        # To avoid endless sampling when leakage is high, we raise a warning if the
+        # acceptance rate is too low after the first 1_000 samples.
+        acceptance_rate = (num_samples - num_remaining) / num_sampled_total
+
+        # For remaining iterations (leakage or many samples) continue sampling with
+        # fixed batch size.
+        sampling_batch_size = max_sampling_batch_size
+        if (
+            num_sampled_total > 1000
+            and acceptance_rate < warn_acceptance
+            and not leakage_warning_raised
+        ):
+            if sample_for_correction_factor:
+                logging.warning(
+                    f"""Drawing samples from the posterior to estimate the normalizing
+                        constant for `log_prob()`. However, only
+                        {acceptance_rate:.0%} posterior samples are within the
+                        prior support. It may take a long time to collect the
+                        remaining {num_remaining} samples.
+                        Consider interrupting (Ctrl-C) and either basing the
+                        estimate of the normalizing constant on fewer samples (by
+                        calling `posterior.leakage_correction(x_o,
+                        num_rejection_samples=N)`, where `N` is the number of
+                        samples you want to base the
+                        estimate on; default `N=10_000`), or not estimating the
+                        normalizing constant at all
+                        (`log_prob(..., norm_posterior=False)`). The latter will
+                        result in an unnormalized `log_prob()`."""
+                )
+            else:
+                logging.warning(
+                    f"""Only {acceptance_rate:.0%} posterior samples are within the
+                        prior support. It may take a long time to collect the
+                        remaining {num_remaining} samples. Consider interrupting
+                        (Ctrl-C) and switching to `sample_with='mcmc'`."""
+                )
+            leakage_warning_raised = True  # Ensure warning is raised just once.
+
+    pbar.close()
+
+    # In case of leakage, the fixed batch size can produce more samples than needed.
+    samples = torch.cat(accepted)[:num_samples]
+    assert (
+        samples.shape[0] == num_samples
+    ), "Number of accepted samples must match required samples."
+ + return samples, as_tensor(acceptance_rate) + + def rejection_sample( potential_fn: Callable, proposal: Any, num_samples: int = 1, show_progress_bars: bool = False, warn_acceptance: float = 0.01, - sample_for_correction_factor: bool = False, max_sampling_batch_size: int = 10_000, binary_rejection_criterion: bool = False, num_samples_to_find_max: int = 10_000, @@ -173,9 +310,6 @@ def rejection_sample( num_samples: Desired number of samples. show_progress_bars: Whether to show a progressbar during sampling. warn_acceptance: A minimum acceptance rate under which to warn about slowness. - sample_for_correction_factor: True if this function was called by - `leakage_correction()`. False otherwise. Will be used to adapt the leakage - warning and to decide whether we have to search for the maximum. max_sampling_batch_size: Batch size for drawing samples from the posterior. Takes effect only in the second iteration of the loop below, i.e., in case of leakage or `num_samples>max_sampling_batch_size`. Larger batch size @@ -198,44 +332,49 @@ def rejection_sample( Accepted samples and acceptance rate as scalar Tensor. """ - if not binary_rejection_criterion and not sample_for_correction_factor: - find_max = proposal.sample((num_samples_to_find_max,)) - - # Define a potential as the ratio between target distribution and proposal. - def potential_over_proposal(theta): - return potential_fn(theta) - proposal.log_prob(theta) - - # Search for the maximum of the ratio. - _, max_log_ratio = optimize_potential_fn( - potential_fn=potential_over_proposal, - inits=find_max, - dist_specifying_bounds=proposal, - num_iter=100, - learning_rate=0.01, - num_to_optimize=max(1, int(num_samples_to_find_max / 10)), - show_progress_bars=False, - ) + find_max = proposal.sample((num_samples_to_find_max,)) + + # Define a potential as the ratio between target distribution and proposal. + def potential_over_proposal(theta): + return potential_fn(theta) - proposal.log_prob(theta) + + # Search for the maximum of the ratio. + _, max_log_ratio = optimize_potential_fn( + potential_fn=potential_over_proposal, + inits=find_max, + dist_specifying_bounds=proposal, + num_iter=100, + learning_rate=0.01, + num_to_optimize=max(1, int(num_samples_to_find_max / 10)), + show_progress_bars=False, + ) - if m < 1.0: - warnings.warn( - "A value of m < 1.0 will lead to systematically wrong results." + if m < 1.0: + warnings.warn("A value of m < 1.0 will lead to systematically wrong results.") + + class ScaledProposal: + def __init__(self, proposal: Any, max_log_ratio: float, log_m: float): + self.proposal = proposal + self.max_log_ratio = max_log_ratio + self.log_m = log_m + + def sample(self, sample_shape: torch.Size, **kwargs) -> Tensor: + """ + Samples from the `ScaledProposal` are samples from the `proposal`. + """ + return self.proposal.sample((sample_shape,), **kwargs) + + def log_prob(self, theta: Tensor, **kwargs) -> Tensor: + """ + The log-prob is scaled such that the proposal is always above the potential. 
+            """
+            return (
+                self.proposal.log_prob(theta, **kwargs)
+                + self.max_log_ratio
+                + self.log_m
             )
 
-        class ScaledProposal:
-            def __init__(self, proposal: Any, max_log_ratio: float, log_m: float):
-                self.proposal = proposal
-                self.max_log_ratio = max_log_ratio
-                self.log_m = log_m
-
-            def sample(self, sample_shape, **kwargs):
-                return self.proposal.sample((sample_shape,), **kwargs)
-
-            def log_prob(self, theta, **kwargs):
-                return self.proposal.log_prob(theta) + self.max_log_ratio + self.log_m
-
-        proposal = ScaledProposal(
-            proposal, max_log_ratio, torch.log(torch.as_tensor(m))
-        )
+    proposal = ScaledProposal(proposal, max_log_ratio, torch.log(torch.as_tensor(m)))
 
     with torch.no_grad():
         # Progress bar can be skipped, e.g. when sampling after each round just for
@@ -291,30 +430,12 @@ def log_prob(self, theta, **kwargs):
             and acceptance_rate < warn_acceptance
             and not leakage_warning_raised
         ):
-            if sample_for_correction_factor:
-                logging.warning(
-                    f"""Drawing samples from posterior to estimate the normalizing
-                    constant for `log_prob()`. However, only
-                    {acceptance_rate:.0%} posterior samples are within the
-                    prior support. It may take a long time to collect the
-                    remaining {num_remaining} samples.
-                    Consider interrupting (Ctrl-C) and either basing the
-                    estimate of the normalizing constant on fewer samples (by
-                    calling `posterior.leakage_correction(x_o,
-                    num_rejection_samples=N)`, where `N` is the number of
-                    samples you want to base the
-                    estimate on (default N=10000), or not estimating the
-                    normalizing constant at all
-                    (`log_prob(..., norm_posterior=False)`. The latter will
-                    result in an unnormalized `log_prob()`."""
-                )
-            else:
-                logging.warning(
-                    f"""Only {acceptance_rate:.0%} posterior samples are within the
-                    prior support. It may take a long time to collect the
-                    remaining {num_remaining} samples. Consider interrupting
-                    (Ctrl-C) and switching to `sample_with_mcmc=True`."""
-                )
+            logging.warning(
+                f"""Only {acceptance_rate:.0%} proposal samples were accepted. It
+                    may take a long time to collect the remaining {num_remaining}
+                    samples. Consider interrupting (Ctrl-C) and switching to
+                    `sample_with='mcmc'`."""
+            )
             leakage_warning_raised = True  # Ensure warning is raised just once.
 
     pbar.close()
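
A minimal usage sketch of the new `rejection_sample_posterior_within_prior` helper through its `sbi.utils` export. The `_MockPosteriorNet` and the observation `x_o` below are hypothetical stand-ins for a trained conditional density estimator and real data; only the call signature follows the function added above:

```python
import torch
from sbi.utils import BoxUniform, rejection_sample_posterior_within_prior


class _MockPosteriorNet:
    """Hypothetical stand-in for a trained conditional density estimator."""

    def sample(self, num_samples, context=None):
        # A real posterior net would condition on `context`; here we return broad
        # Gaussian draws so that some candidates leak outside the prior support.
        return torch.randn(num_samples, 2) * 1.5


prior = BoxUniform(low=-2 * torch.ones(2), high=2 * torch.ones(2))
x_o = torch.zeros(1, 3)  # hypothetical observation, unused by the mock net

samples, acceptance_rate = rejection_sample_posterior_within_prior(
    posterior_nn=_MockPosteriorNet(),
    prior=prior,
    x=x_o,
    num_samples=1_000,
    show_progress_bars=False,
)
print(samples.shape)    # torch.Size([1000, 2])
print(acceptance_rate)  # fraction of draws that fell inside the prior support
```

The returned acceptance rate is what `DirectPosterior.leakage_correction()` extracts (index `[1]`) to normalize `log_prob()`.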
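The `ScaledProposal` introduced in `rejection_sample()` shifts the proposal log-density by `max_log_ratio + log(m)` so that it upper-bounds the potential everywhere, i.e. the classic rejection-sampling envelope M·q(θ) ≥ p̃(θ). Below is a self-contained sketch of that acceptance rule with toy distributions; the names are illustrative, and the maximum of the log-ratio is estimated from samples here, whereas the patch optimizes it with `optimize_potential_fn`:

```python
import torch
from torch.distributions import Normal

target = Normal(0.0, 1.0)    # plays the role of the (unnormalized) potential
proposal = Normal(0.0, 2.0)  # broader proposal distribution
m = 1.2                      # safety factor, analogous to the `m` argument

# Crude sample-based estimate of max log[target / proposal].
thetas = proposal.sample((10_000,))
max_log_ratio = (target.log_prob(thetas) - proposal.log_prob(thetas)).max()
log_m = torch.log(torch.tensor(m))


def scaled_log_prob(theta):
    # Scaled proposal: log q(theta) + max_log_ratio + log m >= log p(theta).
    return proposal.log_prob(theta) + max_log_ratio + log_m


# Standard rejection step: accept theta with probability p(theta) / (M * q(theta)).
candidates = proposal.sample((5_000,))
accept = torch.rand(5_000).log() < target.log_prob(candidates) - scaled_log_prob(candidates)
samples = candidates[accept]
print(samples.shape[0] / 5_000)  # empirical acceptance rate
```

With `m >= 1` the envelope stays above the target, which is why the patch warns that `m < 1.0` leads to systematically wrong results.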