pushing the refactoring along. SNPE and SNLE work now
michaeldeistler committed May 19, 2021
1 parent 6c2f046 commit 8806832
Showing 16 changed files with 741 additions and 453 deletions.
172 changes: 123 additions & 49 deletions sbi/inference/posteriors/base_posterior.py
@@ -29,6 +29,7 @@
check_dist_class,
check_warn_and_setstate,
optimize_potential_fn,
rejection_sample,
)
from sbi.utils.torchutils import (
BoxUniform,
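
For orientation: the newly imported `rejection_sample` utility accepts proposal draws with probability proportional to the `potential_fn / proposal` ratio, scaled by a multiplier `m` so that the acceptance probability stays below one. A minimal sketch of that scheme (the function below is illustrative, not the `sbi` implementation; parameter names mirror the docstrings in this diff):

```python
import torch

def rejection_sample_sketch(potential_fn, proposal, num_samples, m, batch_size=10_000):
    # Illustrative rejection sampler: `potential_fn` returns the (unnormalized)
    # target log-probability, `proposal` is a torch distribution. A candidate
    # theta is accepted with probability exp(potential - proposal_log_prob) / m.
    accepted = []
    num_accepted = 0
    while num_accepted < num_samples:
        theta = proposal.sample((batch_size,))
        log_ratio = potential_fn(theta) - proposal.log_prob(theta)
        # `m` must upper-bound exp(log_ratio); in this diff it is found from
        # samples and refined by gradient ascent before sampling starts.
        accept = torch.rand(batch_size) < torch.exp(log_ratio) / m
        accepted.append(theta[accept])
        num_accepted += int(accept.sum())
    return torch.cat(accepted)[:num_samples]
```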
@@ -66,7 +67,7 @@ def __init__(
neural_net: A classifier for SNRE, a density estimator for SNPE and SNL.
prior: Prior distribution with `.log_prob()` and `.sample()`.
x_shape: Shape of the simulator data.
sample_with: Method to use for sampling from the posterior. Must be in
sample_with: Method to use for sampling from the posterior. Must be one of
[`mcmc` | `rejection`].
mcmc_method: Method used for MCMC sampling, one of `slice_np`, `slice`,
`hmc`, `nuts`. Currently defaults to `slice_np` for a custom numpy
@@ -80,13 +81,15 @@ def __init__(
will draw init locations from prior, whereas `sir` will use Sequential-
Importance-Resampling. Init strategies may have their own keywords
which can also be set from `mcmc_parameters`.
rejection_sampling_parameters: Dictionary overriding the default parameters for
rejection sampling. The following parameters are supported:
`proposal` as the proposal distribution. `num_samples_to_find_max`
as the number of samples that are used to find the maximum of the
`potential_fn / proposal` ratio. `m` as multiplier to that ratio.
`max_sampling_batch_size` as the batch size of samples being drawn from
the proposal at every iteration.
rejection_sampling_parameters: Dictionary overriding the default parameters
for rejection sampling. The following parameters are supported:
`proposal` as the proposal distribution.
`max_sampling_batch_size` as the batch size of samples being drawn from
the proposal at every iteration. `num_samples_to_find_max` as the
number of samples that are used to find the maximum of the
`potential_fn / proposal` ratio. `num_iter_to_find_max` as the number
of gradient ascent iterations to find the maximum of that ratio. `m` as
multiplier to that ratio.
device: Training device, e.g., cpu or cuda.
"""
if method_family in ("snpe", "snle", "snre_a", "snre_b"):
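
A hedged usage sketch of the new constructor arguments, assuming an `sbi` workflow in which `build_posterior` forwards `sample_with` and `rejection_sampling_parameters` to this class (`inference` and `x_o` are assumed to come from an earlier training round):

```python
# Sketch only: `inference` is a trained SNPE/SNLE object, `x_o` an observation.
posterior = inference.build_posterior(
    sample_with="rejection",
    rejection_sampling_parameters=dict(
        max_sampling_batch_size=10_000,  # proposal draws per iteration
        num_samples_to_find_max=10_000,  # samples used to locate the ratio maximum
        num_iter_to_find_max=100,        # gradient-ascent steps on that maximum
        m=1.2,                           # safety multiplier on the ratio
    ),
)
samples = posterior.sample((1_000,), x=x_o)
```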
@@ -152,7 +155,7 @@ def set_default_x(self, x: Tensor) -> "NeuralPosterior":
return self

@property
def sample_with(self) -> bool:
def sample_with(self) -> str:
"""
Return the method with which samples are drawn, one of [`mcmc` | `rejection`].
"""
@@ -164,7 +167,7 @@ def sample_with(self, value: str) -> None:
self.set_sample_with(value)

def set_sample_with(self, sample_with: str) -> "NeuralPosterior":
"""Turns MCMC sampling on or off and returns `NeuralPosterior`.
"""Set the sampling method for the `NeuralPosterior`.
Args:
sample_with: The method to sample with.
@@ -176,6 +179,10 @@ def set_sample_with(self, sample_with: str) -> "NeuralPosterior":
NameError: on attempt to set a sampling method other than
[`mcmc` | `rejection`].
"""
if sample_with not in ("mcmc", "rejection"):
raise NameError(
"The only implemented sampling methods are `mcmc` and `rejection`."
)
self._sample_with = sample_with
return self
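
With the new validation in place, switching samplers looks like this (sketch; `posterior` is assumed to be a built `NeuralPosterior`):

```python
posterior.set_sample_with("rejection")  # returns `self`, so calls chain
posterior.sample_with = "mcmc"          # property setter, same effect
posterior.set_sample_with("smc")        # raises NameError: not implemented
```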

@@ -235,7 +242,7 @@ def set_mcmc_parameters(self, parameters: Dict[str, Any]) -> "NeuralPosterior":

@property
def rejection_sampling_parameters(self) -> dict:
"""Returns rejection sampling parameter."""
"""Returns rejection sampling parameters."""
if self._rejection_sampling_parameters is None:
return {}
else:
@@ -254,9 +261,11 @@ def set_rejection_sampling_parameters(
Args:
parameters: Dictionary overriding the default parameters
for rejection sampling. The following parameters are supported:
`max_sampling_batch_size` to set the batch size for drawing new
samples from the candidate distribution, e.g., the posterior. Larger
batch size speeds up sampling.
`proposal` as the proposal distribution. `num_samples_to_find_max`
as the number of samples that are used to find the maximum of the
`potential_fn / proposal` ratio. `m` as multiplier to that ratio.
`max_sampling_batch_size` as the batch size of samples being drawn from
the proposal at every iteration.
Returns:
`NeuralPosterior` for chainable calls.
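
Since the setter returns `self`, overriding the rejection defaults chains naturally (sketch):

```python
posterior = posterior.set_rejection_sampling_parameters(
    dict(max_sampling_batch_size=5_000, m=1.2)
).set_sample_with("rejection")
```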
@@ -357,7 +366,7 @@ def _prepare_for_sample(
sample_shape: Optional[Tensor],
) -> Tuple[Tensor, int]:
r"""
Return checked and correctly shaped values for `x` and `sample_shape`.
Return checked, reshaped, potentially default values for `x` and `sample_shape`.
Args:
sample_shape: Desired shape of samples that are drawn from posterior. If
@@ -367,8 +376,8 @@ def _prepare_for_sample(
fall back onto `x_o` if previously provided for multiround training, or
to a set default (see `set_default_x()` method).
Returns: Single (potentially default) $x$ with batch dimension; an integer
number of samples
Returns: Single (default) $x$ with batch dimension; an integer number of
samples.
"""

x = atleast_2d_float32_tensor(self._x_else_default_x(x))
@@ -397,7 +406,7 @@ def _potentially_replace_mcmc_parameters(
Importance-Resampling using `init_strategy_num_candidates` to find init
locations.
Returns: A (potentially default) mcmc method and (potentially
Returns: A (default) mcmc method and (potentially
default) mcmc parameters.
"""
mcmc_method = mcmc_method if mcmc_method is not None else self.mcmc_method
@@ -413,11 +422,13 @@ def _potentially_replace_rejection_parameters(
Return potentially default values to rejection sample the posterior.
Args:
rejection_sampling_parameters: Dictionary overriding the default
parameters for rejection sampling. The following parameters are
supported: `m` as multiplier to the maximum ratio between
potential function and the proposal. `proposal` as the proposal
distribution.
rejection_sampling_parameters: Dictionary overriding the default
parameters for rejection sampling. The following parameters are
supported: `proposal` as the proposal distribution.
`num_samples_to_find_max` as the number of samples that are used to
find the maximum of the `potential_fn / proposal` ratio. `m` as
multiplier to that ratio. `max_sampling_batch_size` as the batch size
of samples being drawn from the proposal at every iteration.
Returns: Potentially default rejection sampling parameters.
"""
@@ -624,9 +635,11 @@ def sample_conditional(
condition: Tensor,
dims_to_sample: List[int],
x: Optional[Tensor] = None,
sample_with: str = "mcmc",
show_progress_bars: bool = True,
mcmc_method: Optional[str] = None,
mcmc_parameters: Optional[Dict[str, Any]] = None,
rejection_sampling_parameters: Optional[Dict[str, Any]] = None,
) -> Tensor:
r"""
Return samples from conditional posterior $p(\theta_i|\theta_j, x)$.
@@ -635,7 +648,7 @@ def sample_conditional(
from a few parameter dimensions while the other parameter dimensions are kept
fixed at values specified in `condition`.
Samples are obtained with MCMC.
Samples are obtained with MCMC or rejection sampling.
Args:
potential_fn_provider: Returns the potential function for the unconditional
@@ -653,6 +666,9 @@
x: Conditioning context for posterior $p(\theta|x)$. If not provided,
fall back onto `x_o` if previously provided for multiround training, or
to a set default (see `set_default_x()` method).
sample_with: Method to use for sampling from the posterior. Must be one of
[`mcmc` | `rejection`]. In this method, the value of
`self.sample_with` will be ignored.
show_progress_bars: Whether to show sampling progress monitor.
mcmc_method: Optional parameter to override `self.mcmc_method`.
mcmc_parameters: Dictionary overriding the default parameters for MCMC.
@@ -663,38 +679,71 @@
will draw init locations from prior, whereas `sir` will use Sequential-
Importance-Resampling using `init_strategy_num_candidates` to find init
locations.
rejection_sampling_parameters: Dictionary overriding the default parameters
for rejection sampling. The following parameters are supported:
`proposal` as the proposal distribution (default is the prior).
`max_sampling_batch_size` as the batch size of samples being drawn from
the proposal at every iteration. `num_samples_to_find_max` as the
number of samples that are used to find the maximum of the
`potential_fn / proposal` ratio. `num_iter_to_find_max` as the number
of gradient ascent iterations to find the maximum of that ratio. `m` as
multiplier to that ratio.
Returns:
Samples from conditional posterior.
"""

x, num_samples, mcmc_method, mcmc_parameters = self._prepare_for_sample(
x, sample_shape, mcmc_method, mcmc_parameters
)

self.net.eval()

x, num_samples = self._prepare_for_sample(x, sample_shape)

cond_potential_fn_provider = ConditionalPotentialFunctionProvider(
potential_fn_provider, condition, dims_to_sample
)

samples = self._sample_posterior_mcmc(
num_samples=num_samples,
potential_fn=cond_potential_fn_provider(
self._prior, self.net, x, mcmc_method
),
init_fn=self._build_mcmc_init_fn(
# Restrict prior to sample only free dimensions.
RestrictedPriorForConditional(self._prior, dims_to_sample),
cond_potential_fn_provider(self._prior, self.net, x, "slice_np"),
if sample_with == "mcmc":
mcmc_method, mcmc_parameters = self._potentially_replace_mcmc_parameters(
mcmc_method, mcmc_parameters
)
samples = self._sample_posterior_mcmc(
num_samples=num_samples,
potential_fn=cond_potential_fn_provider(
self._prior, self.net, x, mcmc_method
),
init_fn=self._build_mcmc_init_fn(
# Restrict prior to sample only free dimensions.
RestrictedPriorForConditional(self._prior, dims_to_sample),
cond_potential_fn_provider(self._prior, self.net, x, "slice_np"),
**mcmc_parameters,
),
mcmc_method=mcmc_method,
condition=condition,
dims_to_sample=dims_to_sample,
show_progress_bars=show_progress_bars,
**mcmc_parameters,
),
mcmc_method=mcmc_method,
condition=condition,
dims_to_sample=dims_to_sample,
show_progress_bars=show_progress_bars,
**mcmc_parameters,
)
)
elif sample_with == "rejection":
rejection_sampling_parameters = (
self._potentially_replace_rejection_parameters(
rejection_sampling_parameters
)
)
if "proposal" not in rejection_sampling_parameters:
rejection_sampling_parameters[
"proposal"
] = RestrictedPriorForConditional(self._prior, dims_to_sample)

samples, _ = rejection_sample(
potential_fn=cond_potential_fn_provider(
self._prior, self.net, x, "rejection"
),
num_samples=num_samples,
**rejection_sampling_parameters,
)
else:
raise NameError(
"The only implemented sampling methods are `mcmc` and `rejection`."
)

self.net.train(True)
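
A hedged usage sketch of the extended `sample_conditional`, assuming a three-dimensional posterior and a concrete posterior subclass that supplies the potential function provider internally:

```python
import torch

# Sample dimensions 0 and 2 while dimension 1 stays clamped at 0.5; with
# `sample_with="rejection"`, the proposal defaults to the restricted prior.
condition = torch.tensor([[0.0, 0.5, 0.0]])  # only dim 1 is actually used
conditional_samples = posterior.sample_conditional(
    sample_shape=(500,),
    condition=condition,
    dims_to_sample=[0, 2],
    x=x_o,
    sample_with="rejection",
)
```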

@@ -1071,18 +1120,43 @@ def __init__(
self.condition = ensure_theta_batched(condition)
self.dims_to_sample = dims_to_sample

def __call__(self, prior, net: nn.Module, x: Tensor, mcmc_method: str) -> Callable:
def __call__(self, prior, net: nn.Module, x: Tensor, method: str) -> Callable:
"""Return potential function.
Switch on numpy or pyro potential function based on `mcmc_method`.
Switch on numpy or pyro potential function based on `method`.
"""
# Set prior, net, and x as attributes of unconditional potential_fn_provider.
_ = self.potential_fn_provider.__call__(prior, net, x, mcmc_method)
_ = self.potential_fn_provider.__call__(prior, net, x, method)

if mcmc_method in ("slice", "hmc", "nuts"):
if method in ("slice", "hmc", "nuts"):
return self.pyro_potential
else:
elif "slice_np" in method:
return self.np_potential
elif method == "rejection":
return self.rejection_potential
else:
raise NotImplementedError

def rejection_potential(self, theta: np.ndarray) -> ScalarFloat:
r"""
Return conditional posterior log-probability or $-\infty$ if outside prior.
The only difference from `np_potential` is that it tracks gradients and
does not return a `numpy` array.
Args:
theta: Free parameters $\theta_i$, batch dimension 1.
Returns:
Conditional posterior log-probability $\log(p(\theta_i|\theta_j, x))$,
masked outside of prior.
"""
theta = torch.as_tensor(theta, dtype=torch.float32)

theta_condition = deepcopy(self.condition)
theta_condition[:, self.dims_to_sample] = theta

return self.potential_fn_provider.rejection_potential(theta_condition)
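
Tracking gradients here matters because, under `num_iter_to_find_max`, the maximum of the `potential_fn / proposal` ratio can be refined by gradient ascent. A minimal sketch of that refinement (illustrative, assuming both callables return log-probabilities for a batch of `theta` and the proposal's `log_prob` is differentiable):

```python
import torch

def find_max_log_ratio(potential_fn, proposal, num_init=10_000, num_iter=100, lr=0.01):
    # Coarse search: evaluate the log-ratio on proposal draws.
    theta = proposal.sample((num_init,))
    log_ratio = potential_fn(theta) - proposal.log_prob(theta)
    best = theta[log_ratio.argmax()].detach().clone().requires_grad_(True)
    # Refinement: gradient ascent on the log-ratio from the best candidate,
    # which requires a potential that tracks gradients (hence this method).
    optimizer = torch.optim.Adam([best], lr=lr)
    for _ in range(num_iter):
        optimizer.zero_grad()
        loss = -(potential_fn(best.unsqueeze(0))
                 - proposal.log_prob(best.unsqueeze(0))).sum()
        loss.backward()
        optimizer.step()
    with torch.no_grad():
        return potential_fn(best.unsqueeze(0)) - proposal.log_prob(best.unsqueeze(0))
```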

def np_potential(self, theta: np.ndarray) -> ScalarFloat:
r"""
