Skip to content

Commit

Permalink
pushing the refactoring along. SNPE and SNLE work now
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed May 18, 2021
1 parent b809c66 commit e06fa36
Show file tree
Hide file tree
Showing 16 changed files with 664 additions and 441 deletions.
188 changes: 145 additions & 43 deletions sbi/inference/posteriors/base_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from torch import Tensor, float32
from torch import multiprocessing as mp
from torch import nn, optim
from torch._C import device

from sbi import utils as utils
from sbi.mcmc import (
Expand All @@ -29,6 +30,7 @@
check_dist_class,
check_warn_and_setstate,
optimize_potential_fn,
rejection_sample,
)
from sbi.utils.torchutils import (
BoxUniform,
Expand Down Expand Up @@ -66,7 +68,7 @@ def __init__(
neural_net: A classifier for SNRE, a density estimator for SNPE and SNL.
prior: Prior distribution with `.log_prob()` and `.sample()`.
x_shape: Shape of the simulator data.
sample_with: Method to use for sampling from the posterior. Must be in
sample_with: Method to use for sampling from the posterior. Must be one of
[`mcmc` | `rejection`].
mcmc_method: Method used for MCMC sampling, one of `slice_np`, `slice`,
`hmc`, `nuts`. Currently defaults to `slice_np` for a custom numpy
Expand All @@ -80,12 +82,12 @@ def __init__(
will draw init locations from prior, whereas `sir` will use Sequential-
Importance-Resampling. Init strategies may have their own keywords
which can also be set from `mcmc_parameters`.
rejection_sampling_parameters: Dictionary overriding the default parameters for
rejection sampling. The following parameters are supported:
`proposal`, as the proposal distribtution. `num_samples_to_find_max`
as the number of samples that are used to find the maximum of the
`potential_fn / proposal` ratio. `m` as multiplier to that ratio.
`sampling_batch_size` as the batchsize of samples being drawn from
rejection_sampling_parameters: Dictionary overriding the default parameters
for rejection sampling. The following parameters are supported:
`proposal` as the proposal distribution. `num_samples_to_find_max`
as the number of samples that are used to find the maximum of the
`potential_fn / proposal` ratio. `m` as multiplier to that ratio.
`sampling_batch_size` as the batchsize of samples being drawn from
the proposal at every iteration.
device: Training device, e.g., cpu or cuda.
"""
Expand Down Expand Up @@ -164,7 +166,7 @@ def sample_with(self, value: str) -> None:
self.set_sample_with(value)

def set_sample_with(self, sample_with: str) -> "NeuralPosterior":
"""Turns MCMC sampling on or off and returns `NeuralPosterior`.
"""Set the sampling method for the `NeuralPosterior`.
Args:
sample_with: The method to sample with.
Expand Down Expand Up @@ -235,7 +237,7 @@ def set_mcmc_parameters(self, parameters: Dict[str, Any]) -> "NeuralPosterior":

@property
def rejection_sampling_parameters(self) -> dict:
"""Returns rejection sampling parameter."""
"""Returns rejection sampling parameters."""
if self._rejection_sampling_parameters is None:
return {}
else:
Expand All @@ -254,9 +256,11 @@ def set_rejection_sampling_parameters(
Args:
parameters: Dictionary overriding the default parameters
for rejection sampling. The following parameters are supported:
`max_sampling_batch_size` to the set the batch size for drawing new
samples from the candidate distribution, e.g., the posterior. Larger
batch size speeds up sampling.
`proposal` as the proposal distribution. `num_samples_to_find_max`
as the number of samples that are used to find the maximum of the
`potential_fn / proposal` ratio. `m` as multiplier to that ratio.
`sampling_batch_size` as the batchsize of samples being drawn from
the proposal at every iteration.
Returns:
`NeuralPosterior` for chainable calls.
Expand Down Expand Up @@ -357,7 +361,7 @@ def _prepare_for_sample(
sample_shape: Optional[Tensor],
) -> Tuple[Tensor, int]:
r"""
Return checked and correctly shaped values for `x` and `sample_shape`.
Return checked, reshaped, potentially default values for `x` and `sample_shape`.
Args:
sample_shape: Desired shape of samples that are drawn from posterior. If
Expand Down Expand Up @@ -413,11 +417,13 @@ def _potentially_replace_rejection_parameters(
Return potentially default values to rejection sample the posterior.
Args:
rejection_sampling_parameters: Dictionary overriding the default
parameters for rejection sampling. The following parameters are
supported: `m` as multiplier to the maximum ratio between
potential function and the proposal. `proposal`, as the proposal
distribtution.
rejection_sampling_parameters: Dictionary overriding the default
parameters for rejection sampling. The following parameters are
supported: `proposal` as the proposal distribution.
`num_samples_to_find_max` as the number of samples that are used to
find the maximum of the `potential_fn / proposal` ratio. `m` as
multiplier to that ratio. `sampling_batch_size` as the batchsize of
samples being drawn from the proposal at every iteration.
Returns: Potentially default rejection sampling parameters.
"""
Expand Down Expand Up @@ -624,9 +630,11 @@ def sample_conditional(
condition: Tensor,
dims_to_sample: List[int],
x: Optional[Tensor] = None,
sample_with: str = "mcmc",
show_progress_bars: bool = True,
mcmc_method: Optional[str] = None,
mcmc_parameters: Optional[Dict[str, Any]] = None,
rejection_sampling_parameters: Optional[Dict[str, Any]] = None,
) -> Tensor:
r"""
Return samples from conditional posterior $p(\theta_i|\theta_j, x)$.
Expand All @@ -635,7 +643,7 @@ def sample_conditional(
from a few parameter dimensions while the other parameter dimensions are kept
fixed at values specified in `condition`.
Samples are obtained with MCMC.
Samples are obtained with MCMC or rejection sampling.
Args:
potential_fn_provider: Returns the potential function for the unconditional
Expand All @@ -653,6 +661,9 @@ def sample_conditional(
x: Conditioning context for posterior $p(\theta|x)$. If not provided,
fall back onto `x_o` if previously provided for multiround training, or
to a set default (see `set_default_x()` method).
sample_with: Method to use for sampling from the posterior. Must be one of
[`mcmc` | `rejection`]. In this method, the value of
`self.sample_with` will be ignored.
show_progress_bars: Whether to show sampling progress monitor.
mcmc_method: Optional parameter to override `self.mcmc_method`.
mcmc_parameters: Dictionary overriding the default parameters for MCMC.
Expand All @@ -663,38 +674,69 @@ def sample_conditional(
will draw init locations from prior, whereas `sir` will use Sequential-
Importance-Resampling using `init_strategy_num_candidates` to find init
locations.
rejection_sampling_parameters: Dictionary overriding the default parameters
for rejection sampling. The following parameters are supported:
`proposal` as the proposal distribution. `num_samples_to_find_max`
as the number of samples that are used to find the maximum of the
`potential_fn / proposal` ratio. `m` as multiplier to that ratio.
`sampling_batch_size` as the batchsize of samples being drawn from
the proposal at every iteration.
Returns:
Samples from conditional posterior.
"""

x, num_samples, mcmc_method, mcmc_parameters = self._prepare_for_sample(
x, sample_shape, mcmc_method, mcmc_parameters
)

self.net.eval()

x, num_samples = self._prepare_for_sample(x, sample_shape)

cond_potential_fn_provider = ConditionalPotentialFunctionProvider(
potential_fn_provider, condition, dims_to_sample
)

samples = self._sample_posterior_mcmc(
num_samples=num_samples,
potential_fn=cond_potential_fn_provider(
self._prior, self.net, x, mcmc_method
),
init_fn=self._build_mcmc_init_fn(
# Restrict prior to sample only free dimensions.
RestrictedPriorForConditional(self._prior, dims_to_sample),
cond_potential_fn_provider(self._prior, self.net, x, "slice_np"),
if sample_with == "mcmc":
mcmc_method, mcmc_parameters = self._potentially_replace_mcmc_parameters(
mcmc_method, mcmc_parameters
)
samples = self._sample_posterior_mcmc(
num_samples=num_samples,
potential_fn=cond_potential_fn_provider(
self._prior, self.net, x, mcmc_method
),
init_fn=self._build_mcmc_init_fn(
# Restrict prior to sample only free dimensions.
RestrictedPriorForConditional(self._prior, dims_to_sample),
cond_potential_fn_provider(self._prior, self.net, x, "slice_np"),
**mcmc_parameters,
),
mcmc_method=mcmc_method,
condition=condition,
dims_to_sample=dims_to_sample,
show_progress_bars=show_progress_bars,
**mcmc_parameters,
),
mcmc_method=mcmc_method,
condition=condition,
dims_to_sample=dims_to_sample,
show_progress_bars=show_progress_bars,
**mcmc_parameters,
)
)
elif sample_with == "rejection":
rejection_sampling_parameters = (
self._potentially_replace_rejection_parameters(
rejection_sampling_parameters
)
)
if "proposal" not in rejection_sampling_parameters:
rejection_sampling_parameters[
"proposal"
] = RestrictedPriorForConditional(self._prior, dims_to_sample)

samples, _ = rejection_sample(
potential_fn=cond_potential_fn_provider(
self._prior, self.net, x, "rejection"
),
num_samples=num_samples,
**rejection_sampling_parameters,
)
else:
raise NameError(
"The only implemented sampling methods are `mcmc` and `rejection`."
)

self.net.train(True)

Expand Down Expand Up @@ -1071,18 +1113,43 @@ def __init__(
self.condition = ensure_theta_batched(condition)
self.dims_to_sample = dims_to_sample

def __call__(self, prior, net: nn.Module, x: Tensor, method: str) -> Callable:
    """Return the conditional potential function matching `method`.

    Args:
        prior: Prior, handed through to the wrapped `potential_fn_provider`.
        net: Neural network, handed through to the wrapped provider.
        x: Conditioning context, handed through to the wrapped provider.
        method: Sampling method to build the potential for. `slice`, `hmc`
            and `nuts` select the pyro potential, any name containing
            `slice_np` selects the numpy potential, and `rejection` selects
            the rejection-sampling potential.

    Returns:
        The potential method of this provider that corresponds to `method`.

    Raises:
        NotImplementedError: If `method` is not one of the supported methods.
    """
    # Set prior, net, and x as attributes of unconditional potential_fn_provider.
    _ = self.potential_fn_provider.__call__(prior, net, x, method)

    if method in ("slice", "hmc", "nuts"):
        return self.pyro_potential
    elif "slice_np" in method:
        # Substring match so that variants of `slice_np` are covered as well.
        return self.np_potential
    elif method == "rejection":
        return self.rejection_potential
    else:
        # The exception has to be raised explicitly: a bare
        # `NotImplementedError` expression is a no-op and would make this
        # method silently return `None` for unsupported methods.
        raise NotImplementedError(
            f"No conditional potential function is implemented for method "
            f"`{method}`."
        )

def rejection_potential(self, theta: np.ndarray) -> ScalarFloat:
    r"""
    Evaluate the conditional potential for use in rejection sampling.

    The free parameters in `theta` are written into a copy of the stored
    condition at the positions `dims_to_sample`, and the completed parameter
    set is passed to the `rejection_potential` of the wrapped provider. The
    provider's result is returned as-is, i.e. it is not converted to a `numpy`
    array here.

    Args:
        theta: Free parameters $\theta_i$, batch dimension 1.

    Returns:
        Conditional posterior log-probability $\log(p(\theta_i|\theta_j, x))$,
        as computed by the wrapped provider.
    """
    # Work on a copy so that the stored condition is never modified.
    full_theta = deepcopy(self.condition)
    free_values = torch.as_tensor(theta, dtype=torch.float32)

    # Overwrite only the free dimensions; all others keep the condition values.
    full_theta[:, self.dims_to_sample] = free_values

    return self.potential_fn_provider.rejection_potential(full_theta)

def np_potential(self, theta: np.ndarray) -> ScalarFloat:
r"""
Expand Down Expand Up @@ -1155,3 +1222,38 @@ def log_prob(self, *args, **kwargs):
the $\theta$ under the full joint once we have added the condition.
"""
return self.full_prior.log_prob(*args, **kwargs)


class NeuralNetDefaultX:
    def __init__(self, density_estimator: Any, x: Tensor) -> None:
        r"""
        Bind a conditional density estimator to a fixed default context $x$.

        The wrapper exposes `.sample()` and `.log_prob()` without a context
        argument. Currently, it is used only as `proposal` for rejection
        sampling with `SNPE`.

        Args:
            density_estimator: The neural density estimator parameterizing
                $p(y|x)$.
            x: The value for $x$ at which to evaluate or sample $p(y|x)$.
        """
        self.x = x
        self.net = density_estimator

    def sample(self, num_samples, **kwargs) -> Tensor:
        """
        Draw samples from $p(y|x)$ at the stored default $x$.

        Args:
            num_samples: Number of samples to return.
            kwargs: Passed on unchanged to the `.sample()` of the wrapped net.
        """
        return self.net.sample(num_samples, context=self.x, **kwargs)

    def log_prob(self, y, **kwargs) -> Tensor:
        r"""
        Return the log-probabilities $\log(p(y|x))$ at the stored default $x$.

        Args:
            y: Location at which to evaluate the log-probability.
            kwargs: Accepted for interface compatibility; not forwarded to the
                wrapped net.
        """
        # Match the batch shapes of `y` and the stored `x` before evaluation.
        matched_y, matched_x = NeuralPosterior._match_theta_and_x_batch_shapes(
            y, self.x
        )
        return self.net.log_prob(matched_y, context=matched_x)
Loading

0 comments on commit e06fa36

Please sign in to comment.