MNLE with experimental conditions #829

Merged
merged 3 commits on Aug 7, 2023
4 changes: 2 additions & 2 deletions .vscode/settings.json
@@ -9,7 +9,7 @@
// Editor settings for python
//
"[python]": {
"editor.defaultFormatter": "ms-python.python",
"editor.defaultFormatter": "ms-python.black-formatter",
"editor.formatOnSave": true,
"editor.codeActionsOnSave": {
"source.sortImports": true
@@ -36,7 +36,7 @@
// Formatting
// https://code.visualstudio.com/docs/python/editing#_formatting
//
"python.formatting.provider": "black",
"python.formatting.provider": "none",
"python.formatting.blackArgs": [
"--line-length=88"
],
6 changes: 2 additions & 4 deletions sbi/inference/posteriors/mcmc_posterior.py
@@ -2,7 +2,7 @@
# under the Affero General Public License v3, see <https://www.gnu.org/licenses/>.
from functools import partial
from math import ceil
from typing import Any, Callable, Dict, Optional, Tuple, Union
from typing import Any, Callable, Dict, Optional, Union
from warnings import warn

import arviz as az
@@ -198,7 +198,7 @@ def sample(
sample_with: Optional[str] = None,
num_workers: Optional[int] = None,
show_progress_bars: bool = True,
) -> Union[Tensor, Tuple[Tensor, InferenceData]]:
) -> Tensor:
r"""Return samples from posterior distribution $p(\theta|x)$ with MCMC.

Check the `__init__()` method for a description of all arguments as well as
@@ -452,7 +452,6 @@ def _slice_np_mcmc(

Returns:
Tensor of shape (num_samples, shape_of_single_theta).
Arviz InferenceData object.
"""

num_chains, dim_samples = initial_params.shape
@@ -516,7 +515,6 @@ def _pyro_mcmc(

Returns:
Tensor of shape (num_samples, shape_of_single_theta).
Arviz InferenceData object.
"""
num_chains = mp.cpu_count() - 1 if num_chains is None else num_chains

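Note on the return-type change above: `sample()` previously advertised `Union[Tensor, Tuple[Tensor, InferenceData]]`; it now returns a plain Tensor, and the stray "Arviz InferenceData object." lines are dropped from the `_slice_np_mcmc` and `_pyro_mcmc` docstrings to match. A minimal sketch of the resulting contract, using a toy Gaussian potential written in the same `BasePotential` style as `tests/mnle_test.py` below (the potential and all its parameters are illustrative, not part of this PR):

import torch
from sbi.inference import MCMCPosterior
from sbi.inference.potentials.base_potential import BasePotential
from sbi.utils import BoxUniform, mcmc_transform
from sbi.utils.torchutils import atleast_2d


class GaussianPotential(BasePotential):
    """Toy standard-Gaussian log-density; ignores the observation x_o."""

    allow_iid_x = True  # type: ignore

    def __call__(self, theta, track_gradients: bool = True):
        return -0.5 * (atleast_2d(theta) ** 2).sum(dim=-1)


prior = BoxUniform(-3 * torch.ones(2), 3 * torch.ones(2))
posterior = MCMCPosterior(
    GaussianPotential(prior, x_o=torch.zeros(1, 1)),
    proposal=prior,
    theta_transform=mcmc_transform(prior),
    method="slice_np_vectorized",
    num_chains=4,
    warmup_steps=50,
    init_strategy="proposal",
)
# `sample` now returns a plain Tensor of shape (100, 2) -- no InferenceData.
samples = posterior.sample((100,), x=torch.zeros(1, 1))
assert isinstance(samples, torch.Tensor)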
28 changes: 27 additions & 1 deletion sbi/inference/snle/mnle.py
@@ -64,6 +64,29 @@ def __init__(
kwargs = del_entries(locals(), entries=("self", "__class__"))
super().__init__(**kwargs)

def train(
self,
training_batch_size: int = 50,
learning_rate: float = 5e-4,
validation_fraction: float = 0.1,
stop_after_epochs: int = 20,
max_num_epochs: int = 2**31 - 1,
clip_max_norm: Optional[float] = 5.0,
resume_training: bool = False,
discard_prior_samples: bool = False,
retrain_from_scratch: bool = False,
show_train_summary: bool = False,
dataloader_kwargs: Optional[Dict] = None,
) -> MixedDensityEstimator:
density_estimator = super().train(
**del_entries(locals(), entries=("self", "__class__"))
)
assert isinstance(
density_estimator, MixedDensityEstimator
), f"""Internal net must be of type
MixedDensityEstimator but is {type(density_estimator)}."""
return density_estimator

def build_posterior(
self,
density_estimator: Optional[TorchModule] = None,
@@ -128,7 +151,10 @@ def build_posterior(
), f"""net must be of type MixedDensityEstimator but is {type
(likelihood_estimator)}."""

potential_fn, theta_transform = mixed_likelihood_estimator_based_potential(
(
potential_fn,
theta_transform,
) = mixed_likelihood_estimator_based_potential(
likelihood_estimator=likelihood_estimator, prior=prior, x_o=None
)

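The new `train()` override only forwards to the parent trainer, but it narrows the declared return type so callers receive a `MixedDensityEstimator` without casting. A minimal usage sketch with toy mixed data, mirroring the API test further down (illustrative, not part of this diff):

import torch
from sbi.inference import MNLE
from sbi.utils import BoxUniform

# Toy mixed data: one continuous column (e.g., reaction times) and one
# binary column (e.g., choices), as in tests/mnle_test.py below.
theta = torch.rand(100, 2)
x = torch.cat((torch.rand(100, 1), torch.randint(0, 2, (100, 1))), dim=1)

prior = BoxUniform(torch.zeros(2), torch.ones(2))
trainer = MNLE(prior)
# `estimator` is now typed as MixedDensityEstimator thanks to the override.
estimator = trainer.append_simulations(theta, x).train(max_num_epochs=1)
posterior = trainer.build_posterior(sample_with="mcmc")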
5 changes: 3 additions & 2 deletions sbi/inference/snpe/snpe_c.py
@@ -414,8 +414,9 @@ def _log_prob_proposal_posterior_mog(
)
utils.assert_all_finite(
log_prob_proposal_posterior,
"""the evaluation of the MoG proposal posterior. This is likely due to a
numerical instability in the training procedure. Please create an issue on Github.""",
"""the evaluation of the MoG proposal posterior. This is likely due to a
numerical instability in the training procedure. Please create an issue on
Github.""",
)

return log_prob_proposal_posterior
6 changes: 2 additions & 4 deletions tests/linearGaussian_snre_test.py
@@ -50,10 +50,7 @@ def test_api_sre_on_linearGaussian(num_dim: int, SNRE: RatioEstimator):
prior = MultivariateNormal(loc=zeros(num_dim), covariance_matrix=eye(num_dim))

simulator, prior = prepare_for_sbi(diagonal_linear_gaussian, prior)
inference = SNRE(
classifier="resnet",
show_progress_bars=False,
)
inference = SNRE(classifier="resnet", show_progress_bars=False)

theta, x = simulate_for_sbi(simulator, prior, 1000, simulation_batch_size=50)
ratio_estimator = inference.append_simulations(theta, x).train(max_num_epochs=5)
@@ -70,6 +67,7 @@ def test_api_sre_on_linearGaussian(num_dim: int, SNRE: RatioEstimator):
num_chains=2,
)
posterior.sample(sample_shape=(10,))
posterior.map(num_iter=1)


@pytest.mark.parametrize("SNRE", (SNRE_B, SNRE_C))
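Besides collapsing the constructor call onto one line, the test now smoke-tests MAP estimation via `posterior.map(num_iter=1)`. A sketch of the same call in ordinary use, assuming the trained `posterior` from the test above (the iteration count is illustrative):

# Sketch only: reuses the trained `posterior` from the test above.
map_estimate = posterior.map(num_iter=200)  # gradient-based search for the mode
print(map_estimate)  # a single parameter vector near the posterior mode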
171 changes: 132 additions & 39 deletions tests/mnle_test.py
@@ -3,25 +3,57 @@

import pytest
import torch
from numpy import isin
from pyro.distributions import InverseGamma
from torch.distributions import Beta, Binomial, Gamma
from torch.distributions import Beta, Binomial, Categorical, Gamma

from sbi.inference import MNLE, MCMCPosterior, likelihood_estimator_based_potential
from sbi.inference.posteriors.rejection_posterior import RejectionPosterior
from sbi.inference.posteriors.vi_posterior import VIPosterior
from sbi.inference.potentials.base_potential import BasePotential
from sbi.inference.potentials.likelihood_based_potential import (
MixedLikelihoodBasedPotential,
)
from sbi.utils import BoxUniform, likelihood_nn, mcmc_transform
from sbi.utils.conditional_density_utils import ConditionedPotential
from sbi.utils.torchutils import atleast_2d
from sbi.utils.user_input_checks_utils import MultipleIndependent
from tests.test_utils import check_c2st


# toy simulator for mixed data
def mixed_simulator(theta, stimulus_condition=2.0):
# Extract parameters
beta, ps = theta[:, :1], theta[:, 1:]

# Sample choices and rts independently.
choices = Binomial(probs=ps).sample()
rts = InverseGamma(
concentration=stimulus_condition * torch.ones_like(beta), rate=beta
).sample()

return torch.cat((rts, choices), dim=1)


mcmc_kwargs = dict(
num_chains=10,
warmup_steps=100,
method="slice_np_vectorized",
init_strategy="proposal",
)


@pytest.mark.gpu
@pytest.mark.parametrize("device", ("cpu", "cuda"))
def test_mnle_on_device(device):
# Generate mixed data.
num_simulations = 100
theta = torch.rand(num_simulations, 2)
x = torch.cat(
(torch.rand(num_simulations, 1), torch.randint(0, 2, (num_simulations, 1))),
(
torch.rand(num_simulations, 1),
torch.randint(0, 2, (num_simulations, 1)),
),
dim=1,
).to(device)

@@ -41,55 +73,49 @@ def test_mnle_api(sampler):
num_simulations = 100
theta = torch.rand(num_simulations, 2)
x = torch.cat(
(torch.rand(num_simulations, 1), torch.randint(0, 2, (num_simulations, 1))),
(
torch.rand(num_simulations, 1),
torch.randint(0, 2, (num_simulations, 1)),
),
dim=1,
)

# Train and infer.
prior = BoxUniform(torch.zeros(2), torch.ones(2))
x_o = x[0]
# Build estimator manually.
density_estimator = likelihood_nn(model="mnle", **dict(tail_bound=2.0))
density_estimator = likelihood_nn(model="mnle")
trainer = MNLE(density_estimator=density_estimator)
mnle = trainer.append_simulations(theta, x).train(max_num_epochs=1)
trainer.append_simulations(theta, x).train(max_num_epochs=5)

# Test different samplers.
posterior = trainer.build_posterior(prior=prior, sample_with=sampler)
posterior.set_default_x(x_o)
if sampler == "vi":
posterior.train()
posterior.sample((1,), show_progress_bars=False)

# MNLE should work with the default potential as well.
potential_fn, parameter_transform = likelihood_estimator_based_potential(
mnle, prior, x_o
)
posterior = MCMCPosterior(
potential_fn,
proposal=prior,
theta_transform=parameter_transform,
init_strategy="proposal",
)
posterior.sample((1,), show_progress_bars=False)
if isinstance(posterior, VIPosterior):
posterior.train().sample((1,))
elif isinstance(posterior, RejectionPosterior):
posterior.sample((1,))
else:
posterior.sample(
(1,),
num_chains=2,
warmup_steps=1,
method="slice_np_vectorized",
init_strategy="proposal",
thin=1,
)


@pytest.mark.slow
@pytest.mark.parametrize(
"sampler",
(
"mcmc",
"rejection",
# "vi", # Failing because of transformed space dimension mismatch.
),
)
@pytest.mark.parametrize("sampler", ("mcmc", "rejection", "vi"))
def test_mnle_accuracy(sampler):
def mixed_simulator(theta):
# Extract parameters
beta, ps = theta[:, :1], theta[:, 1:]

# Sample choices and rts independently.
choices = Binomial(probs=ps).sample()
rts = InverseGamma(concentration=2 * torch.ones_like(beta), rate=beta).sample()
rts = InverseGamma(concentration=1 * torch.ones_like(beta), rate=beta).sample()

return torch.cat((rts, choices), dim=1)

@@ -111,13 +137,6 @@ def mixed_simulator(theta):
trainer.append_simulations(theta, x).train()
posterior = trainer.build_posterior()

mcmc_kwargs = dict(
num_chains=10,
warmup_steps=100,
method="slice_np_vectorized",
init_strategy="proposal",
)

for num_trials in [10]:
theta_o = prior.sample((1,))
x_o = mixed_simulator(theta_o.repeat(num_trials, 1))
@@ -138,7 +157,7 @@ def mixed_simulator(theta):

mnle_posterior_samples = posterior.sample(
sample_shape=(num_samples,),
show_progress_bars=False,
show_progress_bars=True,
**mcmc_kwargs if sampler == "mcmc" else {},
)

@@ -154,9 +173,11 @@ class PotentialFunctionProvider(BasePotential):

allow_iid_x = True # type: ignore

def __init__(self, prior, x_o, device="cpu"):
def __init__(self, prior, x_o, concentration_scaling=1.0, device="cpu"):
super().__init__(prior, x_o, device)

self.concentration_scaling = concentration_scaling

def __call__(self, theta, track_gradients: bool = True):
theta = atleast_2d(theta)

@@ -179,7 +200,8 @@ def iid_likelihood(self, theta: torch.Tensor) -> torch.Tensor:
lp_rts = torch.stack(
[
InverseGamma(
concentration=2 * torch.ones_like(beta_i), rate=beta_i
concentration=self.concentration_scaling * torch.ones_like(beta_i),
rate=beta_i,
).log_prob(self.x_o[:, :1])
for beta_i in theta[:, :1]
],
@@ -191,3 +213,74 @@
)

return joint_likelihood.sum(0)


@pytest.mark.slow
def test_mnle_with_experiment_conditions():
def sim_wrapper(theta):
# simulate with experiment conditions
return mixed_simulator(theta[:, :2], theta[:, 2:] + 1)

proposal = MultipleIndependent(
[
Gamma(torch.tensor([1.0]), torch.tensor([0.5])),
Beta(torch.tensor([2.0]), torch.tensor([2.0])),
Categorical(probs=torch.ones(1, 3)),
],
validate_args=False,
)

num_simulations = 10000
num_samples = 1000
theta = proposal.sample((num_simulations,))
x = sim_wrapper(theta)
assert x.shape == (num_simulations, 2)

num_trials = 10
theta_o = proposal.sample((1,))
theta_o[0, 2] = 2.0 # set condition to 2 as in original simulator.
x_o = sim_wrapper(theta_o.repeat(num_trials, 1))

# MNLE
trainer = MNLE(proposal)
estimator = trainer.append_simulations(theta, x).train(max_num_epochs=1)

potential_fn = MixedLikelihoodBasedPotential(estimator, proposal, x_o)

conditioned_potential_fn = ConditionedPotential(
potential_fn, condition=theta_o, dims_to_sample=[0, 1], allow_iid_x=True
)

# True posterior samples
prior = MultipleIndependent(
[
Gamma(torch.tensor([1.0]), torch.tensor([0.5])),
Beta(torch.tensor([2.0]), torch.tensor([2.0])),
],
validate_args=False,
)
prior_transform = mcmc_transform(prior)
true_posterior_samples = MCMCPosterior(
PotentialFunctionProvider(
prior,
atleast_2d(x_o),
concentration_scaling=float(theta_o[0, 2])
+ 1.0, # add one because the sim_wrapper adds one (see above)
),
theta_transform=prior_transform,
proposal=prior,
**mcmc_kwargs,
).sample((num_samples,), x=x_o)

mcmc_posterior = MCMCPosterior(
potential_fn=conditioned_potential_fn,
theta_transform=prior_transform,
proposal=prior,
)
cond_samples = mcmc_posterior.sample((num_samples,), x=x_o)

check_c2st(
cond_samples,
true_posterior_samples,
alg="MNLE with experiment conditions",
)
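
Taken together, `test_mnle_with_experiment_conditions` exercises the feature this PR adds: the experimental condition is appended to the parameter vector, MNLE is trained on the joint vector, and the learned likelihood is then conditioned on the observed condition so that MCMC runs only over the remaining dimensions. A condensed restatement of that flow, reusing `sim_wrapper`, `theta_o`, `x_o`, `prior`, and `proposal` from the test above:

# Condensed sketch of the pattern tested above; assumes the same imports and
# the `sim_wrapper`, `theta_o`, `x_o`, `prior`, `proposal` defined in the test.
theta = proposal.sample((10_000,))  # columns: [beta, ps, condition]
x = sim_wrapper(theta)              # columns: [rts, choices]

trainer = MNLE(proposal)
estimator = trainer.append_simulations(theta, x).train()

# Fix the condition (dim 2) at its observed value and sample only (beta, ps).
potential_fn = MixedLikelihoodBasedPotential(estimator, proposal, x_o)
conditioned_potential_fn = ConditionedPotential(
    potential_fn, condition=theta_o, dims_to_sample=[0, 1], allow_iid_x=True
)
posterior = MCMCPosterior(
    potential_fn=conditioned_potential_fn,
    theta_transform=mcmc_transform(prior),  # `prior` covers (beta, ps) only
    proposal=prior,
)
samples = posterior.sample((1_000,), x=x_o)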