
Add docstrings; clean up; minor fixes
michaeldeistler committed Jan 7, 2022
1 parent 02b2b5c commit 913ac84
Showing 30 changed files with 635 additions and 339 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -1,3 +1,6 @@
# v0.18.0 (next release)
-

# v0.17.2

## Minor changes
2 changes: 1 addition & 1 deletion sbi/analysis/__init__.py
@@ -6,4 +6,4 @@
)
from sbi.analysis.plot import conditional_pairplot, pairplot
from sbi.analysis.sensitivity_analysis import ActiveSubspace
from sbi.analysis.get_maximum import get_maximum
from sbi.analysis.gradient_ascent import gradient_ascent
74 changes: 61 additions & 13 deletions sbi/analysis/conditional_density.py
@@ -1,11 +1,11 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Affero General Public License v3, see <https://www.gnu.org/licenses/>.

from typing import Any, Callable, List, Optional, Union
from typing import Any, Callable, List, Optional, Tuple, Union

import torch
from torch import Tensor, nn

import torch.distributions.transforms as torch_tf
from sbi.utils import conditional_corrcoeff as utils_conditional_corrcoeff
from sbi.utils import eval_conditional_density as utils_eval_conditional_density
from sbi.types import Shape
@@ -124,21 +124,42 @@ def conditional_corrcoeff(


def parameter_conditional_mdn(
net: nn.Module, xo: Tensor, prior: Any, condition: Tensor, dims_to_sample: List[int]
):
net: nn.Module,
x_o: Tensor,
condition: Tensor,
dims_to_sample: List[int],
) -> "ConditionedMDN":
r"""
Returns an object that can sample from, and evaluate the log-probability of, a conditional mixture of Gaussians.
Args:
net: Mixture density network that models $p(\theta|x)$.
x_o: The datapoint at which the `net` is evaluated.
condition: Parameter set that all dimensions not specified in
`dims_to_sample` will be fixed to. Should contain dim_theta elements,
i.e. it could e.g. be a sample from the posterior distribution.
The entries at all `dims_to_sample` will be ignored.
dims_to_sample: Which dimensions to sample from. The dimensions not
specified in `dims_to_sample` will be fixed to values given in
`condition`.
Returns:
A mixture of Gaussians with `.sample()` and `.log_prob()` methods.
"""

class ConditionedMDN:
def __init__(
self,
net: nn.Module,
xo: Tensor,
x_o: Tensor,
condition: Tensor,
dims_to_sample: List[int],
):
condition = atleast_2d_float32_tensor(condition)

logits, means, precfs, _ = extract_and_transform_mog(nn=net, context=xo)
logits, means, precfs, _ = extract_and_transform_mog(nn=net, context=x_o)
self.logits, self.means, self.precfs, self.sumlogdiag = condition_mog(
prior, condition, dims_to_sample, logits, means, precfs
condition, dims_to_sample, logits, means, precfs
)
self.prec = self.precfs.transpose(3, 2) @ self.precfs

@@ -159,27 +180,54 @@ def log_prob(self, theta: Tensor):
)
return log_prob

conditioned_mdn = ConditionedMDN(net, xo, condition, dims_to_sample)
conditioned_mdn = ConditionedMDN(net, x_o, condition, dims_to_sample)
return conditioned_mdn
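
A minimal usage sketch for `parameter_conditional_mdn` (not part of the diff above): it assumes a trained MDN density estimator `trained_mdn`, an observation of shape (1, 2), a 3-dimensional parameter space, and that the returned object's `.sample()` takes a sample-shape tuple; none of these are fixed by this hunk.

import torch
from sbi.analysis.conditional_density import parameter_conditional_mdn

# `trained_mdn` is assumed to be an MDN-based density estimator of p(theta | x).
x_o = torch.zeros(1, 2)        # assumed observation
condition = torch.zeros(1, 3)  # entries at `dims_to_sample` are ignored
conditioned_mdn = parameter_conditional_mdn(
    net=trained_mdn, x_o=x_o, condition=condition, dims_to_sample=[0]
)
theta_0 = conditioned_mdn.sample((1_000,))     # samples of the free dimension
log_probs = conditioned_mdn.log_prob(theta_0)  # conditional log-probabilities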


def parameter_conditonal_potential(
potential_fn: Callable,
potential_tf,
theta_transform: torch_tf,
prior: Any,
condition: Tensor,
dims_to_sample: List[int],
):
) -> Tuple[Callable, torch_tf.Transform, Any]:
r"""
Returns a potential function that can be used to sample from the conditional
distribution. It also returns a reduced transform and a reduced prior to be used
when sampling from that conditional.
The conditional potential is $p(\theta_i | \theta_j, x_o) \propto p(\theta | x_o)$
but is a function only of $\theta_i$.
Args:
potential_fn: The potential function to be conditioned.
theta_transform: The parameter transformation that should be reduced (by
ignoring dimensions not contained in `dims_to_sample`).
prior: The prior distribution that should be reduced (by ignoring dimensions
not contained in `dims_to_sample`).
condition: Parameter set that all dimensions not specified in
`dims_to_sample` will be fixed to. Should contain dim_theta elements,
i.e. it could e.g. be a sample from the posterior distribution.
The entries at all `dims_to_sample` will be ignored.
dims_to_sample: Which dimensions to sample from. The dimensions not
specified in `dims_to_sample` will be fixed to values given in
`condition`.
Returns:
The conditioned potential function, the restricted transform, and the restricted prior.
"""

restricted_tf = RestrictedTransformForConditional(
potential_tf, condition, dims_to_sample
theta_transform, condition, dims_to_sample
)

condition = atleast_2d_float32_tensor(condition)

# Transform the `condition` to unconstrained space.
transformed_condition = potential_tf(condition)
transformed_condition = theta_transform(condition)

conditioned_potential_fn = _build_conditioned_potential_fn(
conditioned_potential_fn = build_conditioned_potential_fn(
potential_fn, transformed_condition, dims_to_sample
)

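A hedged sketch of wiring up `parameter_conditonal_potential` (not part of the diff): `potential_fn`, `theta_transform`, `prior`, and `condition` are assumed to exist already, and the downstream sampler call is deliberately left out.

cond_potential_fn, restricted_transform, restricted_prior = parameter_conditonal_potential(
    potential_fn=potential_fn,
    theta_transform=theta_transform,
    prior=prior,
    condition=condition,
    dims_to_sample=[0, 1],
)
# The conditioned potential is a function of the entries at `dims_to_sample` only, so it
# can be handed to an MCMC or rejection sampler together with the reduced transform/prior.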
33 changes: 17 additions & 16 deletions sbi/analysis/get_maximum.py → sbi/analysis/gradient_ascent.py
@@ -4,10 +4,10 @@
import torch.distributions.transforms as torch_tf


def get_maximum(
def gradient_ascent(
potential_fn: Callable,
inits: Tensor,
potential_tf: Optional[torch_tf.Transform] = None,
theta_transform: Optional[torch_tf.Transform] = None,
num_iter: int = 1_000,
num_to_optimize: int = 100,
learning_rate: float = 0.01,
@@ -16,9 +16,9 @@ def get_maximum(
interruption_note: str = "",
) -> Tuple[Tensor, Tensor]:
"""
Returns the `argmax` and `max` of a `potential_fn`.
Returns the `argmax` and `max` of a `potential_fn` via gradient ascent.
The method can be interrupted (Ctrl-C) when the user sees that the log-probability
The method can be interrupted (Ctrl-C) when the user sees that the potential_fn
converges. The best estimate will be returned.
The maximum is obtained by running gradient ascent from given starting parameters.
@@ -31,9 +31,8 @@
Args:
potential_fn: The function on which to optimize.
inits: The initial parameters at which to start the gradient ascent steps.
dist_specifying_bounds: Distribution that specifies bounds for the optimization.
If it is a `sbi.utils.BoxUniform`, we transform the space into
unconstrained space and carry out the optimization there.
theta_transform: If passed, this transformation will be applied during the
optimization.
num_iter: Number of optimization steps that the algorithm takes
to find the MAP.
num_to_optimize: From the drawn `num_init_samples`, use the `num_to_optimize`
@@ -51,12 +50,12 @@
The `argmax` and `max` of the `potential_fn`.
"""

if potential_tf is None:
potential_tf = torch_tf.IndependentTransform(
if theta_transform is None:
theta_transform = torch_tf.IndependentTransform(
torch_tf.identity_transform, reinterpreted_batch_ndims=1
)
else:
potential_tf = potential_tf
theta_transform = theta_transform

init_probs = potential_fn(inits).detach()

@@ -76,7 +75,7 @@
argmax_ = best_theta_overall
max_val = best_log_prob_overall

optimize_inits = potential_tf(optimize_inits)
optimize_inits = theta_transform(optimize_inits)
optimize_inits.requires_grad_(True)
optimizer = optim.Adam([optimize_inits], lr=learning_rate)

@@ -89,7 +88,7 @@
while iter_ < num_iter:

optimizer.zero_grad()
probs = potential_fn(potential_tf.inv(optimize_inits)).squeeze()
probs = potential_fn(theta_transform.inv(optimize_inits)).squeeze()
loss = -probs.sum()
loss.backward()
optimizer.step()
@@ -98,12 +97,14 @@
if iter_ % save_best_every == 0 or iter_ == num_iter - 1:
# Evaluate the optimized locations and pick the best one.
log_probs_of_optimized = potential_fn(
potential_tf.inv(optimize_inits)
theta_transform.inv(optimize_inits)
)
best_theta_iter = optimize_inits[
torch.argmax(log_probs_of_optimized)
]
best_log_prob_iter = potential_fn(potential_tf.inv(best_theta_iter))
best_log_prob_iter = potential_fn(
theta_transform.inv(best_theta_iter)
)
if best_log_prob_iter > best_log_prob_overall:
best_theta_overall = best_theta_iter.detach().clone()
best_log_prob_overall = best_log_prob_iter.detach().clone()
@@ -116,7 +117,7 @@
{best_log_prob_iter.item():.2f} (= unnormalized log-prob""",
end="\r",
)
argmax_ = potential_tf.inv(best_theta_overall)
argmax_ = theta_transform.inv(best_theta_overall)
max_val = best_log_prob_overall

iter_ += 1
@@ -126,4 +127,4 @@
print(interruption + interruption_note)
return argmax_, max_val

return potential_tf.inv(best_theta_overall), max_val
return theta_transform.inv(best_theta_overall), max_val
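
A minimal sketch of calling `gradient_ascent` on a toy potential (the standard-Gaussian potential, tensor shapes, and iteration count are illustrative assumptions, not taken from this commit):

import torch
from sbi.analysis.gradient_ascent import gradient_ascent

def toy_potential(theta: torch.Tensor) -> torch.Tensor:
    # Unnormalized log-probability of a standard Gaussian; its maximum is at theta = 0.
    return -0.5 * (theta ** 2).sum(dim=-1)

inits = torch.randn(100, 2)  # 100 starting points in a 2-dimensional parameter space
argmax_, max_val = gradient_ascent(toy_potential, inits, num_iter=200)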
3 changes: 3 additions & 0 deletions sbi/inference/posteriors/__init__.py
@@ -0,0 +1,3 @@
from sbi.inference.posteriors.direct_posterior import DirectPosterior
from sbi.inference.posteriors.mcmc_posterior import MCMCPosterior
from sbi.inference.posteriors.rejection_posterior import RejectionPosterior
70 changes: 31 additions & 39 deletions sbi/inference/posteriors/base_posterior.py
@@ -16,13 +16,7 @@
from torch import nn

from sbi import utils as utils

from sbi.types import Array, Shape
from sbi.utils.sbiutils import (
check_warn_and_setstate,
mcmc_transform,
optimize_potential_fn,
)
from sbi.utils.torchutils import (
ScalarFloat,
atleast_2d_float32_tensor,
@@ -37,58 +31,33 @@ class NeuralPosterior(ABC):
All inference methods in sbi train a neural network which is then used to obtain
the posterior distribution. The `NeuralPosterior` class wraps the trained network
such that one can directly evaluate the (unnormalized) log probability and draw
samples from the posterior. The neural network itself can be accessed via the `.net`
attribute.
samples from the posterior.
"""

def __init__(
self,
potential_fn: Callable,
potential_tf: Optional[torch_tf.Transform] = None,
theta_transform: Optional[torch_tf.Transform] = None,
device: str = "cpu",
):
"""
Args:
method_family: One of snpe, snl, snre_a or snre_b.
neural_net: A classifier for SNRE, a density estimator for SNPE and SNL.
prior: Prior distribution with `.log_prob()` and `.sample()`.
x_shape: Shape of the simulator data.
sample_with: Method to use for sampling from the posterior. Must be one of
[`mcmc` | `rejection`].
mcmc_method: Method used for MCMC sampling, one of `slice_np`, `slice`,
`hmc`, `nuts`. Currently defaults to `slice_np` for a custom numpy
implementation of slice sampling; select `hmc`, `nuts` or `slice` for
Pyro-based sampling.
mcmc_parameters: Dictionary overriding the default parameters for MCMC.
The following parameters are supported: `thin` to set the thinning
factor for the chain, `warmup_steps` to set the initial number of
samples to discard, `num_chains` for the number of chains,
`init_strategy` for the initialisation strategy for chains; `prior`
will draw init locations from prior, whereas `sir` will use Sequential-
Importance-Resampling. Init strategies may have their own keywords
which can also be set from `mcmc_parameters`.
rejection_sampling_parameters: Dictionary overriding the default parameters
for rejection sampling. The following parameters are supported:
`proposal` as the proposal distribution.
`max_sampling_batch_size` as the batchsize of samples being drawn from
the proposal at every iteration. `num_samples_to_find_max` as the
number of samples that are used to find the maximum of the
`potential_fn / proposal` ratio. `num_iter_to_find_max` as the number
of gradient ascent iterations to find the maximum of that ratio. `m` as
multiplier to that ratio.
potential_fn: The potential function from which to draw samples.
theta_transform: Transformation that will be applied during sampling.
Allows one to perform, e.g., MCMC in unconstrained space.
device: Training device, e.g., "cpu", "cuda" or "cuda:0".
"""
# Ensure device string.
device = process_device(device)

self.potential_fn = potential_fn

if potential_tf is None:
self.potential_tf = torch_tf.IndependentTransform(
if theta_transform is None:
self.theta_transform = torch_tf.IndependentTransform(
torch_tf.identity_transform, reinterpreted_batch_ndims=1
)
else:
self.potential_tf = potential_tf
self.theta_transform = theta_transform

self._num_trained_rounds = 0
self._num_iid_trials = None
@@ -98,12 +67,35 @@ def __init__(
self._purpose = ""

def potential(self, theta: Tensor, track_gradients: bool = False) -> Tensor:
r"""
Evaluates theta under the potential that is used to sample the posterior.
The potential is the unnormalized log-probability of theta under the posterior.
Args:
theta: Parameters $\theta$.
track_gradients: Whether the returned tensor supports tracking gradients.
This can be helpful for e.g. sensitivity analysis, but increases memory
consumption.
"""
theta = ensure_theta_batched(torch.as_tensor(theta))
return self.potential_fn(
theta.to(self._device), track_gradients=track_gradients
)

def log_prob(self, theta: Tensor, track_gradients: bool = False) -> Tensor:
r"""
Returns the log-probability of theta under the posterior.
Args:
theta: Parameters $\theta$.
track_gradients: Whether the returned tensor supports tracking gradients.
This can be helpful for e.g. sensitivity analysis, but increases memory
consumption.
Returns:
`len($\theta$)`-shaped log-probability.
"""
warn(
"`.log_prob()` is deprecated for methods that can only evaluate the log-probability up to a normalizing constant. Use `.potential()` instead."
)
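
A short sketch of calling `potential()` as documented above (assumes `posterior` is a concrete trained posterior such as `DirectPosterior` and `theta` has the matching parameter dimensionality; both are placeholders, not defined in this diff):

potentials = posterior.potential(theta)                          # unnormalized log-prob, no gradient tracking
potentials_g = posterior.potential(theta, track_gradients=True)  # keep gradients, e.g. for sensitivity analysis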