Skip to content

Commit

Permalink
let abc return particles or kde, add kde methods, adapt tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
janfb committed Jul 15, 2021
1 parent a615d91 commit 9ed18ca
Show file tree
Hide file tree
Showing 7 changed files with 343 additions and 59 deletions.
77 changes: 50 additions & 27 deletions sbi/inference/abc/mcabc.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from typing import Callable, Optional, Tuple, Union
from typing import Any, Callable, Dict, Optional, Tuple, Union

import torch
from numpy import ndarray
from pyro.distributions.empirical import Empirical
from torch import Tensor, ones
from torch.distributions.distribution import Distribution
from torch import Tensor
from torch.distributions.transforms import Transform

from sbi.inference.abc.abc_base import ABCBASE
from sbi.utils.user_input_checks import process_x
from sbi.utils import KDEWrapper, get_kde, process_x


class MCABC(ABCBASE):
Expand Down Expand Up @@ -61,14 +60,21 @@ def __call__(
num_simulations: int,
eps: Optional[float] = None,
quantile: Optional[float] = None,
return_distances: bool = False,
return_x_accepted: bool = False,
lra: bool = False,
sass: bool = False,
sass_fraction: float = 0.25,
sass_expansion_degree: int = 1,
) -> Union[Distribution, Tuple[Distribution, Tensor]]:
r"""Run MCABC.
kde: bool = False,
kde_kwargs: Dict[str, Any] = dict(
kde_bandwidth="cv",
kde_transform=None,
sample_weights=None,
num_cv_partitions=20,
num_cv_repetitions=5,
),
return_summary: bool = False,
) -> Union[Tuple[Tensor, dict], Tuple[KDEWrapper, dict], Tensor, KDEWrapper]:
r"""Run MCABC and return accepted parameters or KDE object fitted on them.
Args:
x_o: Observed data.
Expand All @@ -78,20 +84,30 @@ def __call__(
quantile: Upper quantile of smallest distances for which the corresponding
parameters are returned, e.g, q=0.01 will return the top 1%. Exactly
one of quantile or `eps` have to be passed.
return_distances: Whether to return the distances corresponding to
the accepted parameters.
return_distances: Whether to return the simulated data corresponding to
the accepted parameters.
lra: Whether to run linear regression adjustment as in Beaumont et al. 2002
sass: Whether to determine semi-automatic summary statistics as in
Fearnhead & Prangle 2012.
sass_fraction: Fraction of simulation budget used for the initial sass run.
sass_expansion_degree: Degree of the polynomial feature expansion for the
sass regression, default 1 - no expansion.
kde: Whether to run KDE on the accepted parameters to return a KDE
object from which one can sample.
kde_kwargs: kwargs for performing KDE:
'bandwidth': either a float, or a string naming a bandwidth
heuristic, e.g., 'cv' (cross validation), 'silvermann' or 'scott',
default 'cv'.
'transform': transform applied to the parameters before doing KDE.
'sample_weights': weights associated with samples. See 'get_kde' for
more details
return_summary: Whether to return the distances and data corresponding to
the accepted parameters.
Returns:
posterior: Empirical distribution based on selected parameters.
distances: Tensor of distances of the selected parameters.
theta (if kde False): accepted parameters
kde (if kde True): KDE object based on accepted parameters from which one
can .sample() and .log_prob().
summary (if return_summary True): dictionary containing the accepted parameters
(if kde True), distances and simulated data x.
"""
# Exactly one of eps or quantile need to be passed.
assert (eps is not None) ^ (
Expand Down Expand Up @@ -152,19 +168,26 @@ def __call__(
# Maybe adjust theta with LRA.
if lra:
self.logger.info("Running Linear regression adjustment.")
theta_adjusted = self.run_lra(
theta_accepted, x_accepted, observation=self.x_o
)
final_theta = self.run_lra(theta_accepted, x_accepted, observation=self.x_o)
else:
theta_adjusted = theta_accepted
final_theta = theta_accepted

if kde:
self.logger.info(
f"""KDE on {final_theta.shape[0]} samples with bandwidth option
{kde_kwargs["bandwidth"]}. Beware that KDE can give unreliable
results when used with too few samples and in high dimensions."""
)

posterior = Empirical(theta_adjusted, log_weights=ones(theta_accepted.shape[0]))
kde_dist = get_kde(final_theta, **kde_kwargs)

if return_distances and return_x_accepted:
return posterior, distances_accepted, x_accepted
if return_distances:
return posterior, distances_accepted
if return_x_accepted:
return posterior, x_accepted
if return_summary:
return kde_dist, dict(
theta=final_theta, distances=distances_accepted, x=x_accepted
)
else:
return kde_dist
elif return_summary:
return final_theta, dict(distances=distances_accepted, x=x_accepted)
else:
return posterior
return final_theta
78 changes: 59 additions & 19 deletions sbi/inference/abc/smcabc.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
from typing import Callable, Optional, Tuple, Union
from typing import Any, Callable, Dict, Optional, Tuple, Union

import numpy as np
import torch
from numpy import ndarray
from pyro.distributions import Uniform
from pyro.distributions.empirical import Empirical
from torch import Tensor, ones, tensor
from torch.distributions import Distribution, Multinomial, MultivariateNormal

from sbi.inference.abc.abc_base import ABCBASE
from sbi.utils import within_support
from sbi.utils.user_input_checks import process_x
from sbi.utils import KDEWrapper, get_kde, process_x, within_support


class SMCABC(ABCBASE):
Expand Down Expand Up @@ -106,13 +104,21 @@ def __call__(
kernel_variance_scale: float = 1.0,
use_last_pop_samples: bool = True,
return_summary: bool = False,
kde: bool = False,
kde_kwargs: Dict[str, Any] = dict(
kde_bandwidth="cv",
kde_transform=None,
sample_weights=None,
num_cv_partitions=20,
num_cv_repetitions=5,
),
lra: bool = False,
lra_with_weights: bool = False,
sass: bool = False,
sass_fraction: float = 0.25,
sass_expansion_degree: int = 1,
) -> Union[Distribution, Tuple[Distribution, dict]]:
r"""Run SMCABC.
) -> Union[Tensor, KDEWrapper, Tuple[Tensor, dict], Tuple[KDEWrapper, dict]]:
r"""Run SMCABC and return accepted parameters or KDE object fitted on them.
Args:
x_o: Observed data.
Expand All @@ -130,8 +136,6 @@ def __call__(
samples from the previous population when the budget is used up. If
False, the current population is discarded and the previous population
is returned.
return_summary: Whether to return a dictionary with all accepted particles,
weights, etc. at the end.
lra: Whether to run linear regression adjustment as in Beaumont et al. 2002
lra_with_weights: Whether to run lra as weighted linear regression with SMC
weights
Expand All @@ -140,12 +144,25 @@ def __call__(
sass_fraction: Fraction of simulation budget used for the initial sass run.
sass_expansion_degree: Degree of the polynomial feature expansion for the
sass regression, default 1 - no expansion.
kde: Whether to run KDE on the accepted parameters to return a KDE
object from which one can sample.
kde_kwargs: kwargs for performing KDE:
'bandwidth': either a float, or a string naming a bandwidth
heuristic, e.g., 'cv' (cross validation), 'silvermann' or 'scott',
default 'cv'.
'transform': transform applied to the parameters before doing KDE.
'sample_weights': weights associated with samples. See 'get_kde' for
more details
return_summary: Whether to return a dictionary with all accepted particles,
weights, etc. at the end.
Returns:
posterior: Empirical posterior distribution defined by the accepted
particles and their weights.
summary (optional): A dictionary containing particles, weights, epsilons
and distances of each population.
theta (if kde False): accepted parameters of the last population.
kde (if kde True): KDE object fitted on accepted parameters, from which one
can .sample() and .log_prob().
summary (if return_summary True): dictionary containing the accepted
parameters (if kde True), distances and simulated data x of all
populations.
"""

pop_idx = 0
Expand Down Expand Up @@ -239,20 +256,43 @@ def sass_simulator(theta):
# Maybe run LRA and adjust weights.
if lra:
self.logger.info("Running Linear regression adjustment.")
adjusted_particels, adjusted_weights = self.run_lra_update_weights(
adjusted_particles, adjusted_weights = self.run_lra_update_weights(
particles=all_particles[-1],
xs=all_x[-1],
observation=x_o,
log_weights=all_log_weights[-1],
lra_with_weights=lra_with_weights,
)
posterior = Empirical(adjusted_particels, log_weights=adjusted_weights)
final_particles = adjusted_particles
else:
posterior = Empirical(all_particles[-1], log_weights=all_log_weights[-1])
final_particles = all_particles[-1]

if kde:
self.logger.info(
f"""KDE on {final_particles.shape[0]} samples with bandwidth option
{kde_kwargs["bandwidth"]}. Beware that KDE can give unreliable
results when used with too few samples and in high dimensions."""
)

kde_dist = get_kde(final_particles, **kde_kwargs)

if return_summary:
return (
kde_dist,
dict(
particles=all_particles,
weights=all_log_weights,
epsilons=all_epsilons,
distances=all_distances,
xs=all_x,
),
)
else:
return kde_dist

if return_summary:
return (
posterior,
final_particles,
dict(
particles=all_particles,
weights=all_log_weights,
Expand All @@ -262,14 +302,14 @@ def sass_simulator(theta):
),
)
else:
return posterior
return final_particles

def _set_xo_and_sample_initial_population(
self,
x_o,
num_particles: int,
num_initial_pop: int,
) -> Tuple[Tensor, float, Tensor]:
) -> Tuple[Tensor, float, Tensor, Tensor]:
"""Return particles, epsilon and distances of initial population."""

assert (
Expand Down Expand Up @@ -307,7 +347,7 @@ def _sample_next_population(
epsilon: float,
x: Tensor,
use_last_pop_samples: bool = True,
) -> Tuple[Tensor, Tensor, Tensor]:
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
"""Return particles, weights and distances of new population."""

new_particles = []
Expand Down
10 changes: 9 additions & 1 deletion sbi/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# under the Affero General Public License v3, see <https://www.gnu.org/licenses/>.


from typing import NewType, Sequence, Tuple, TypeVar, Union
from typing import NewType, Optional, Sequence, Tuple, TypeVar, Union

import numpy as np
import torch
Expand All @@ -17,6 +17,13 @@

ScalarFloat = Union[torch.Tensor, float]

# Type alias for an optional torch transform argument: a `Transform`, a
# `ComposeTransform`, or `None`.
transform_types = Optional[
Union[
torch.distributions.transforms.Transform,
torch.distributions.transforms.ComposeTransform,
]
]

# Define alias types because otherwise, the documentation by mkdocs became very long and
# made the website look ugly.
TensorboardSummaryWriter = NewType("Writer", SummaryWriter)
Expand All @@ -29,4 +36,5 @@
"ScalarFloat",
"TensorboardSummaryWriter",
"TorchModule",
"transform_types",
]
2 changes: 2 additions & 0 deletions sbi/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
)
from sbi.utils.get_nn_models import classifier_nn, likelihood_nn, posterior_nn
from sbi.utils.io import get_data_root, get_log_root, get_project_root
from sbi.utils.kde import KDEWrapper, get_kde
from sbi.utils.plot import conditional_pairplot, pairplot
from sbi.utils.restriction_estimator import RestrictedPrior, RestrictionEstimator
from sbi.utils.sbiutils import (
Expand Down Expand Up @@ -61,6 +62,7 @@
)
from sbi.utils.user_input_checks import (
check_estimator_arg,
process_x,
test_posterior_net_for_multi_d_x,
validate_theta_and_x,
)
Expand Down
Loading

0 comments on commit 9ed18ca

Please sign in to comment.