Skip to content

Commit

Permalink
Code revision and fixups
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Jul 1, 2022
1 parent 607c3da commit ce389f8
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 52 deletions.
8 changes: 5 additions & 3 deletions sbi/inference/posteriors/importance_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ def __init__(
generated with sampling importance resampling (SIR). With
`importance`, the `.sample()` method returns a tuple of samples and
corresponding importance weights.
oversampling_factor: Number of proposed samples form which only one is
oversampling_factor: Number of proposed samples from which only one is
selected based on its importance weight.
max_sampling_batch_size: The batchsize of samples being drawn from the
max_sampling_batch_size: The batch size of samples being drawn from the
proposal at every iteration.
device: Device on which to sample, e.g., "cpu", "cuda" or "cuda:0". If
None, `potential_fn.device` is used.
Expand Down Expand Up @@ -119,10 +119,12 @@ def log_prob(
def estimate_normalization_constant(
self, x: Tensor, num_samples: int = 10_000, force_update: bool = False
) -> Tensor:
"""Returns the normalization constant with importance sampling.
"""Returns the normalization constant via importance sampling.
Args:
num_samples: Number of importance samples used for the estimate.
force_update: Whether to re-calculate the normlization constant when x is
unchanged and have a cached value.
"""
# Check if the provided x matches the default x (short-circuit on identity).
is_new_x = self.default_x is None or (
Expand Down
2 changes: 1 addition & 1 deletion sbi/inference/posteriors/mcmc_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ def _get_initial_params(
)

# Parallelize inits for resampling only.
if num_workers > 1 and init_strategy == "resample":
if num_workers > 1 and (init_strategy == "resample" or init_strategy == "sir"):

def seeded_init_fn(seed):
torch.manual_seed(seed)
Expand Down
11 changes: 9 additions & 2 deletions sbi/samplers/importance/importance_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ def importance_sample(
def exponentiate_weights(log_weights: Tensor) -> Tensor:
"""Subtracts the maximum of the `log_weights` and then exponentiates them.
It also filters out infinite `log_weights`, thus the input and output shape can
differ.
Args:
log_weights: Logarithm of the importance weights.
Expand All @@ -48,11 +51,15 @@ def largest_weight_indices(weights: Tensor) -> Tensor:
"""Returns the indizes of the largest weights.
Args:
weights: Importance weights.
weights: Weights of which to return the largest indices. Usually importance
weights.
Returns:
Tensor: The indices of the largest importance weights.
"""
# Compute number of weights that are used for estimating the Pareto distribution.
# Vehtari, Gelman, Gabry, 2017.
# Yao, Vehtari, Simpson, Gelman, 2018
number_of_weights = int(min(len(weights) / 5, 3 * sqrt(len(weights))))
_, inds = weights.sort()
return inds[-number_of_weights:]
Expand All @@ -61,7 +68,7 @@ def largest_weight_indices(weights: Tensor) -> Tensor:
def gpdfit(
x: Tensor, sorted: bool = True, eps: float = 1e-8, return_quadrature: bool = False
) -> Tuple:
"""Maximum aposteriori estimate of a Generalized Paretto distribution.
"""Maximum a posteriori estimate of a Generalized Paretto distribution.
Pytorch version of gpdfit according to
https://github.com/avehtari/PSIS/blob/master/py/psis.py. This function will compute
Expand Down
3 changes: 2 additions & 1 deletion sbi/samplers/importance/sir.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ def sampling_importance_resampling(
"""

selected_samples = []
max_sampling_batch_size = int(max_sampling_batch_size / oversampling_factor)

max_sampling_batch_size = max(1, int(max_sampling_batch_size / oversampling_factor))
sampling_batch_size = min(num_samples, max_sampling_batch_size)

num_remaining = num_samples
Expand Down
1 change: 1 addition & 0 deletions sbi/samplers/mcmc/init_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def sir_init(
oversampling_factor=sir_num_batches * sir_batch_size,
max_sampling_batch_size=sir_batch_size,
)
print(sample.shape)
return transform(sample) # type: ignore


Expand Down
12 changes: 4 additions & 8 deletions tests/inference_with_NaN_simulator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,10 @@ def test_z_scoring_warning(snpe_method: type):

@pytest.mark.slow
@pytest.mark.parametrize(
("method", "exclude_invalid_x", "percent_nans"),
((SNPE_C, True, 0.05), (SNL, True, 0.05), (SRE, True, 0.05)),
("method", "percent_nans"),
((SNPE_C, 0.05), (SNL, 0.05), (SRE, 0.05)),
)
def test_inference_with_nan_simulator(
method: type, exclude_invalid_x: bool, percent_nans: float
):
def test_inference_with_nan_simulator(method: type, percent_nans: float):

# likelihood_mean will be likelihood_shift+theta
num_dim = 3
Expand Down Expand Up @@ -102,9 +100,7 @@ def linear_gaussian_nan(
inference = method(prior=prior)

theta, x = simulate_for_sbi(simulator, prior, num_simulations)
_ = inference.append_simulations(
theta, x, exclude_invalid_x=exclude_invalid_x
).train()
_ = inference.append_simulations(theta, x).train()
posterior = inference.build_posterior()

samples = posterior.sample((num_samples,), x=x_o)
Expand Down
78 changes: 41 additions & 37 deletions tests/linearGaussian_snle_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ def test_c2st_multi_round_snl_on_linearGaussian_vi(num_trials: int):
("importance", "gaussian"),
),
)
@pytest.mark.parametrize("init_strategy", ("proposal", "resample"))
@pytest.mark.parametrize("init_strategy", ("proposal", "resample", "sir"))
def test_api_snl_sampling_methods(
sampling_method: str, prior_str: str, init_strategy: str
):
Expand Down Expand Up @@ -410,44 +410,48 @@ def test_api_snl_sampling_methods(
else:
prior = utils.BoxUniform(-1.0 * ones(num_dim), ones(num_dim))

simulator, prior = prepare_for_sbi(diagonal_linear_gaussian, prior)
inference = SNLE(show_progress_bars=False)
# Why do we have this if-case? Only the `MCMCPosterior` uses the `init_strategy`.
# Thus, we would not like to run, e.g., VI with all init_strategies, but only once
# (namely with `init_strategy=proposal`).
if sample_with == "mcmc" or init_strategy == "proposal":
simulator, prior = prepare_for_sbi(diagonal_linear_gaussian, prior)
inference = SNLE(show_progress_bars=False)

theta, x = simulate_for_sbi(
simulator, prior, num_simulations, simulation_batch_size=1000
)
likelihood_estimator = inference.append_simulations(theta, x).train(
max_num_epochs=5
)
potential_fn, theta_transform = likelihood_estimator_based_potential(
prior=prior, likelihood_estimator=likelihood_estimator, x_o=x_o
)
if sample_with == "rejection":
posterior = RejectionPosterior(potential_fn=potential_fn, proposal=prior)
elif (
"slice" in sampling_method
or "nuts" in sampling_method
or "hmc" in sampling_method
):
posterior = MCMCPosterior(
potential_fn,
proposal=prior,
theta_transform=theta_transform,
method=sampling_method,
thin=3,
num_chains=num_chains,
init_strategy=init_strategy,
theta, x = simulate_for_sbi(
simulator, prior, num_simulations, simulation_batch_size=1000
)
elif sample_with == "importance":
posterior = ImportanceSamplingPosterior(
potential_fn,
proposal=prior,
theta_transform=theta_transform,
likelihood_estimator = inference.append_simulations(theta, x).train(
max_num_epochs=5
)
else:
posterior = VIPosterior(
potential_fn, theta_transform=theta_transform, vi_method=sampling_method
potential_fn, theta_transform = likelihood_estimator_based_potential(
prior=prior, likelihood_estimator=likelihood_estimator, x_o=x_o
)
posterior.train(max_num_iters=10)
if sample_with == "rejection":
posterior = RejectionPosterior(potential_fn=potential_fn, proposal=prior)
elif (
"slice" in sampling_method
or "nuts" in sampling_method
or "hmc" in sampling_method
):
posterior = MCMCPosterior(
potential_fn,
proposal=prior,
theta_transform=theta_transform,
method=sampling_method,
thin=3,
num_chains=num_chains,
init_strategy=init_strategy,
)
elif sample_with == "importance":
posterior = ImportanceSamplingPosterior(
potential_fn,
proposal=prior,
theta_transform=theta_transform,
)
else:
posterior = VIPosterior(
potential_fn, theta_transform=theta_transform, vi_method=sampling_method
)
posterior.train(max_num_iters=10)

posterior.sample(sample_shape=(num_samples,))
posterior.sample(sample_shape=(num_samples,))

0 comments on commit ce389f8

Please sign in to comment.