Commit

Eventually decide to stick to 2 rejection_sampling methods
michaeldeistler committed May 19, 2021
1 parent f14b6a8 commit 6d954c9
Showing 4 changed files with 205 additions and 127 deletions.
35 changes: 0 additions & 35 deletions sbi/inference/posteriors/base_posterior.py
@@ -1221,38 +1221,3 @@ def log_prob(self, *args, **kwargs):
the $\theta$ under the full joint once we have added the condition.
"""
return self.full_prior.log_prob(*args, **kwargs)


class NeuralNetDefaultX:
def __init__(self, density_estimator: Any, x: Tensor) -> None:
r"""
Wraps the neural density estimator returned by `nflows` to have a default $x$.
Currently, it is used only as `proposal` for rejection sampling with `SNPE`.
Args:
density_estimator: The neural density estimator parameterizing $p(y|x)$.
x: The value for $x$ at which to evaluate or sample $p(y|x)$.
"""
self.net = density_estimator
self.x = x

def sample(self, num_samples, **kwargs) -> Tensor:
"""
Return samples from $p(y|x)$.
Args:
num_samples: Number of samples to return.
"""
s = self.net.sample(num_samples, context=self.x, **kwargs)
return s

def log_prob(self, y, **kwargs) -> Tensor:
r"""
Return the log-probabilities $\log(p(y|x))$.
Args:
y: Location at which to evaluate the log-probability.
"""
ys, xs = NeuralPosterior._match_theta_and_x_batch_shapes(y, self.x)
return self.net.log_prob(ys, context=xs)
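
For context, the removed `NeuralNetDefaultX` wrapper was intended to be used roughly as follows. This is an illustrative sketch, not part of the commit; `trained_net` and `x_o` are placeholder names for a trained density estimator and an observation.

# Sketch: fix the context x_o so the density estimator behaves like an
# unconditional distribution over theta and can serve as a `proposal`.
proposal = NeuralNetDefaultX(density_estimator=trained_net, x=x_o)

theta = proposal.sample(100)       # 100 draws from p(theta | x_o)
log_q = proposal.log_prob(theta)   # log p(theta | x_o) for each draw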
49 changes: 20 additions & 29 deletions sbi/inference/posteriors/direct_posterior.py
@@ -240,12 +240,10 @@ def leakage_correction(

def acceptance_at(x: Tensor) -> Tensor:

potential_fn_provider = PotentialFunctionProvider()
return utils.rejection_sample(
potential_fn=potential_fn_provider(
self._prior, self.net, x.to(self._device), "rejection"
),
proposal=NeuralNetDefaultX(self.net, x.to(self._device)),
return utils.rejection_sample_posterior_within_prior(
posterior_nn=self.net,
prior=self._prior,
x=x.to(self._device),
num_samples=num_rejection_samples,
show_progress_bars=show_progress_bars,
sample_for_correction_factor=True,
@@ -362,31 +360,24 @@ def sample(
not self.net.training
), "Posterior nn must be in eval mode for sampling."

# This case covers the scenario where we sample the "posterior within
# the prior". We simply want to draw samples from the posterior and
# reject them if the prior has `-inf` log-probability. It would be
# possible to implement this via standard rejection sampling by setting
# the potential function to a masked version of the neural net.
# However, this would require evaluating the proposal and potential,
# which generates overhead. Instead, we built an if-else case into the
# `rejection_sample()` method which is triggered by
`binary_rejection_criterion=True` (see its docstring for more info).
rejection_sampling_parameters["binary_rejection_criterion"] = True
rejection_sampling_parameters["proposal"] = NeuralNetDefaultX(
self.net, x
)
potential_fn = self._prior
samples = utils.rejection_sample_posterior_within_prior(
posterior_nn=self.net,
prior=self._prior,
x=x.to(self._device),
num_samples=num_samples,
show_progress_bars=show_progress_bars,
sample_for_correction_factor=True,
**rejection_sampling_parameters,
)[0]
else:
potential_fn = (
potential_fn_provider(self._prior, self.net, x, "rejection"),
samples, _ = rejection_sample(
potential_fn=potential_fn_provider(
self._prior, self.net, x, "rejection"
),
num_samples=num_samples,
show_progress_bars=show_progress_bars,
**rejection_sampling_parameters,
)

samples, _ = rejection_sample(
potential_fn=potential_fn,
num_samples=num_samples,
show_progress_bars=show_progress_bars,
**rejection_sampling_parameters,
)
else:
raise NameError(
"The only implemented sampling methods are `mcmc` and `rejection`."
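The removed comment above describes the special case that the new helper now handles: candidates are drawn from the posterior net and kept only if they fall inside the prior support, with no proposal/potential ratio involved. A minimal, self-contained sketch of that acceptance criterion for a box prior (illustrative names; the library uses `within_support(prior, candidates)` rather than the inline mask shown here):

import torch

# Stand-ins: a box prior and candidates that would normally come from
# posterior_nn.sample(...).
prior = torch.distributions.Uniform(torch.zeros(2), torch.ones(2))
candidates = torch.randn(1000, 2)

# Accept a candidate iff every dimension lies inside the prior's support.
inside = ((candidates >= prior.low) & (candidates <= prior.high)).all(dim=1)
accepted = candidates[inside]
acceptance_rate = inside.float().mean()
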
1 change: 1 addition & 0 deletions sbi/utils/__init__.py
@@ -22,6 +22,7 @@
mog_log_prob,
optimize_potential_fn,
rejection_sample,
rejection_sample_posterior_within_prior,
standardizing_net,
standardizing_transform,
warn_if_zscoring_changes_data,
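With the new export in place, the helper can be imported directly from `sbi.utils`. A usage sketch mirroring the signature added below; `posterior_nn`, `prior`, and `x_o` are assumed to exist in the calling code:

from sbi.utils import rejection_sample_posterior_within_prior

# Hypothetical call: draw 1,000 posterior samples restricted to the prior
# support and obtain the acceptance rate as a by-product.
samples, acceptance_rate = rejection_sample_posterior_within_prior(
    posterior_nn=posterior_nn,
    prior=prior,
    x=x_o,
    num_samples=1_000,
    show_progress_bars=True,
)
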
247 changes: 184 additions & 63 deletions sbi/utils/sbiutils.py
@@ -146,13 +146,150 @@ def standardizing_net(batch_t: Tensor, min_std: float = 1e-7) -> nn.Module:
return Standardize(t_mean, t_std)


@torch.no_grad()
def rejection_sample_posterior_within_prior(
posterior_nn: Any,
prior: Callable,
x: Tensor,
num_samples: int,
show_progress_bars: bool = False,
warn_acceptance: float = 0.01,
sample_for_correction_factor: bool = False,
max_sampling_batch_size: int = 10_000,
**kwargs,
) -> Tuple[Tensor, Tensor]:
r"""Return samples from a posterior $p(\theta|x)$ only within the prior support.
This is relevant for SNPE methods and flows, for which the posterior tends to have
mass outside the prior boundaries.
This function could in principle be integrated into `rejection_sample()`. However,
to keep the warnings clean and to avoid additional integration code and confusing
if-cases, we decided to keep two separate functions.
This function uses rejection sampling with samples from the posterior in order to
1) obtain posterior samples within the prior support, and
2) calculate the fraction of accepted samples as a proxy for correcting the
density during evaluation of the posterior.
Args:
posterior_nn: Neural net representing the posterior.
prior: Distribution-like object that evaluates probabilities with `log_prob`.
x: Conditioning variable $x$ for the posterior $p(\theta|x)$.
num_samples: Desired number of samples.
show_progress_bars: Whether to show a progressbar during sampling.
warn_acceptance: A minimum acceptance rate under which to warn about slowness.
sample_for_correction_factor: True if this function was called by
`leakage_correction()`. False otherwise. Will be used to adapt the leakage
warning and to decide whether we have to search for the maximum.
max_sampling_batch_size: Batch size for drawing samples from the posterior.
Takes effect only in the second iteration of the loop below, i.e., in case
of leakage or `num_samples>max_sampling_batch_size`. Larger batch size
speeds up sampling.
kwargs: Absorb additional unused arguments that can be passed to
`rejection_sample()`. Warn if not empty.
Returns:
Accepted samples and acceptance rate as scalar Tensor.
"""

if kwargs:
logging.warning(
f"You passed arguments to `rejection_sampling_parameters` that "
f"are unused when you do not specify a `proposal` in the same "
f"dictionary. The unused arguments are: {kwargs}"
)

# Progress bar can be skipped, e.g. when sampling after each round just for
# logging.
pbar = tqdm(
disable=not show_progress_bars,
total=num_samples,
desc=f"Drawing {num_samples} posterior samples",
)

num_sampled_total, num_remaining = 0, num_samples
accepted, acceptance_rate = [], float("Nan")
leakage_warning_raised = False

# To cover cases with few samples without leakage:
sampling_batch_size = min(num_samples, max_sampling_batch_size)
while num_remaining > 0:

# Sample and reject.
candidates = (
posterior_nn.sample(sampling_batch_size, context=x)
.reshape(sampling_batch_size, -1)
.cpu()
)

# SNPE-style rejection-sampling when the proposal is the neural net.
candidates = candidates.cpu()
are_within_prior = within_support(prior, candidates)
samples = candidates[are_within_prior]

accepted.append(samples)

# Update.
num_sampled_total += sampling_batch_size
num_remaining -= samples.shape[0]
pbar.update(samples.shape[0])

# To avoid endless sampling when leakage is high, we raise a warning if the
# acceptance rate is too low after the first 1_000 samples.
acceptance_rate = (num_samples - num_remaining) / num_sampled_total

# For remaining iterations (leakage or many samples) continue sampling with
# fixed batch size.
sampling_batch_size = max_sampling_batch_size
if (
num_sampled_total > 1000
and acceptance_rate < warn_acceptance
and not leakage_warning_raised
):
if sample_for_correction_factor:
logging.warning(
f"""Drawing samples from posterior to estimate the normalizing
constant for `log_prob()`. However, only
{acceptance_rate:.0%} posterior samples are within the
prior support. It may take a long time to collect the
remaining {num_remaining} samples.
Consider interrupting (Ctrl-C) and either basing the
estimate of the normalizing constant on fewer samples (by
calling `posterior.leakage_correction(x_o,
num_rejection_samples=N)`, where `N` is the number of
samples you want to base the estimate on; default
N=10000), or not estimating the normalizing constant at
all (`log_prob(..., norm_posterior=False)`). The latter
will result in an unnormalized `log_prob()`."""
)
else:
logging.warning(
f"""Only {acceptance_rate:.0%} posterior samples are within the
prior support. It may take a long time to collect the
remaining {num_remaining} samples. Consider interrupting
(Ctrl-C) and switching to `sample_with='mcmc'`."""
)
leakage_warning_raised = True # Ensure warning is raised just once.

pbar.close()

# Because sampling proceeds in batches, more samples than requested may have been
# accepted; keep only the first `num_samples`.
samples = torch.cat(accepted)[:num_samples]
assert (
samples.shape[0] == num_samples
), "Number of accepted samples must match required samples."

return samples, as_tensor(acceptance_rate)
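
As the docstring notes, the returned acceptance rate serves as a proxy for the normalizing constant: inside the prior support, the corrected density is the raw network density divided by the acceptance rate (the fraction of posterior mass that remains inside the prior). A sketch of that relationship (illustrative helper, not the library's implementation):

import torch

def corrected_log_prob(raw_log_prob: torch.Tensor, acceptance_rate: torch.Tensor) -> torch.Tensor:
    # For theta inside the prior support:
    # p_corrected(theta | x) = q(theta | x) / acceptance_rate.
    return raw_log_prob - torch.log(acceptance_rate)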


def rejection_sample(
potential_fn: Callable,
proposal: Any,
num_samples: int = 1,
show_progress_bars: bool = False,
warn_acceptance: float = 0.01,
sample_for_correction_factor: bool = False,
max_sampling_batch_size: int = 10_000,
binary_rejection_criterion: bool = False,
num_samples_to_find_max: int = 10_000,
@@ -173,9 +310,6 @@ def rejection_sample(
num_samples: Desired number of samples.
show_progress_bars: Whether to show a progressbar during sampling.
warn_acceptance: A minimum acceptance rate under which to warn about slowness.
sample_for_correction_factor: True if this function was called by
`leakage_correction()`. False otherwise. Will be used to adapt the leakage
warning and to decide whether we have to search for the maximum.
max_sampling_batch_size: Batch size for drawing samples from the posterior.
Takes effect only in the second iteration of the loop below, i.e., in case
of leakage or `num_samples>max_sampling_batch_size`. Larger batch size
Expand All @@ -198,44 +332,49 @@ def rejection_sample(
Accepted samples and acceptance rate as scalar Tensor.
"""

if not binary_rejection_criterion and not sample_for_correction_factor:
find_max = proposal.sample((num_samples_to_find_max,))

# Define a potential as the ratio between target distribution and proposal.
def potential_over_proposal(theta):
return potential_fn(theta) - proposal.log_prob(theta)

# Search for the maximum of the ratio.
_, max_log_ratio = optimize_potential_fn(
potential_fn=potential_over_proposal,
inits=find_max,
dist_specifying_bounds=proposal,
num_iter=100,
learning_rate=0.01,
num_to_optimize=max(1, int(num_samples_to_find_max / 10)),
show_progress_bars=False,
)
find_max = proposal.sample((num_samples_to_find_max,))

# Define a potential as the ratio between target distribution and proposal.
def potential_over_proposal(theta):
return potential_fn(theta) - proposal.log_prob(theta)

# Search for the maximum of the ratio.
_, max_log_ratio = optimize_potential_fn(
potential_fn=potential_over_proposal,
inits=find_max,
dist_specifying_bounds=proposal,
num_iter=100,
learning_rate=0.01,
num_to_optimize=max(1, int(num_samples_to_find_max / 10)),
show_progress_bars=False,
)

if m < 1.0:
warnings.warn(
"A value of m < 1.0 will lead to systematically wrong results."
if m < 1.0:
warnings.warn("A value of m < 1.0 will lead to systematically wrong results.")

class ScaledProposal:
def __init__(self, proposal: Any, max_log_ratio: float, log_m: float):
self.proposal = proposal
self.max_log_ratio = max_log_ratio
self.log_m = log_m

def sample(self, sample_shape: torch.Size, **kwargs) -> Tensor:
"""
Samples from the `ScaledProposal` are samples from the `proposal`.
"""
return self.proposal.sample((sample_shape,), **kwargs)

def log_prob(self, theta: Tensor, **kwargs) -> Tensor:
"""
The log-prob is scaled such that the proposal is always above the potential.
"""
return (
self.proposal.log_prob(theta, **kwargs)
+ self.max_log_ratio
+ self.log_m
)

class ScaledProposal:
def __init__(self, proposal: Any, max_log_ratio: float, log_m: float):
self.proposal = proposal
self.max_log_ratio = max_log_ratio
self.log_m = log_m

def sample(self, sample_shape, **kwargs):
return self.proposal.sample((sample_shape,), **kwargs)

def log_prob(self, theta, **kwargs):
return self.proposal.log_prob(theta) + self.max_log_ratio + self.log_m

proposal = ScaledProposal(
proposal, max_log_ratio, torch.log(torch.as_tensor(m))
)
proposal = ScaledProposal(proposal, max_log_ratio, torch.log(torch.as_tensor(m)))

with torch.no_grad():
# Progress bar can be skipped, e.g. when sampling after each round just for
@@ -291,30 +430,12 @@ def log_prob(self, theta, **kwargs):
and acceptance_rate < warn_acceptance
and not leakage_warning_raised
):
if sample_for_correction_factor:
logging.warning(
f"""Drawing samples from posterior to estimate the normalizing
constant for `log_prob()`. However, only
{acceptance_rate:.0%} posterior samples are within the
prior support. It may take a long time to collect the
remaining {num_remaining} samples.
Consider interrupting (Ctrl-C) and either basing the
estimate of the normalizing constant on fewer samples (by
calling `posterior.leakage_correction(x_o,
num_rejection_samples=N)`, where `N` is the number of
samples you want to base the
estimate on (default N=10000), or not estimating the
normalizing constant at all
(`log_prob(..., norm_posterior=False)`. The latter will
result in an unnormalized `log_prob()`."""
)
else:
logging.warning(
f"""Only {acceptance_rate:.0%} posterior samples are within the
prior support. It may take a long time to collect the
remaining {num_remaining} samples. Consider interrupting
(Ctrl-C) and switching to `sample_with_mcmc=True`."""
)
logging.warning(
f"""Only {acceptance_rate:.0%} proposal samples were accepted. It
may take a long time to collect the remaining {num_remaining}
samples. Consider interrupting (Ctrl-C) and switching to
`sample_with='mcmc'`."""
)
leakage_warning_raised = True # Ensure warning is raised just once.

pbar.close()
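For the general `rejection_sample()` path, the `ScaledProposal` shifts the proposal's log-probability by `max_log_ratio + log(m)` so that it upper-bounds the potential everywhere; each candidate is then accepted with probability exp(potential - scaled proposal log-prob). A self-contained sketch of that acceptance step with toy densities (illustrative, not the library API; the bound is found on a grid here instead of the gradient-based `optimize_potential_fn` search):

import torch

target = torch.distributions.Normal(0.0, 0.5)    # plays the role of the potential
proposal = torch.distributions.Normal(0.0, 1.0)

# Upper bound on log(target / proposal), here via a coarse grid search.
grid = torch.linspace(-5.0, 5.0, 1_000)
max_log_ratio = (target.log_prob(grid) - proposal.log_prob(grid)).max()

# Draw candidates and accept each with probability
# target / (proposal * exp(max_log_ratio)).
candidates = proposal.sample((10_000,))
log_accept_prob = target.log_prob(candidates) - (
    proposal.log_prob(candidates) + max_log_ratio
)
accepted = candidates[torch.rand(10_000).log() < log_accept_prob]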
