Inverted the direction of the transform for MCMC
michaeldeistler committed Jul 14, 2021
1 parent 82714dd commit a615d91
Showing 7 changed files with 50 additions and 45 deletions.
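This commit flips the convention of the MCMC transform: its forward direction now maps parameters from the prior's (possibly bounded) support into unconstrained space, and `transform.inv` maps back; previously the two directions were swapped. Potentials accordingly subtract the forward log-abs-det Jacobian instead of adding it. A minimal sketch of the new convention in plain `torch.distributions` (`biject_to` is an illustrative stand-in for what `mcmc_transform` builds, not sbi's internal code):

import torch
from torch.distributions import Uniform, biject_to

# A prior with bounded support, as in the sbi use case.
prior = Uniform(torch.zeros(2), torch.ones(2))

# biject_to(support) maps unconstrained -> constrained, so its inverse
# follows the new sbi convention: forward = constrained -> unconstrained.
transform = biject_to(prior.support).inv

theta = prior.sample((5,))     # draws in [0, 1]^2 (constrained)
u = transform(theta)           # unconstrained space, where MCMC runs
theta_back = transform.inv(u)  # back to the prior support
assert torch.allclose(theta, theta_back, atol=1e-6)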
8 changes: 4 additions & 4 deletions sbi/inference/posteriors/base_posterior.py
@@ -716,12 +716,12 @@ def sample_conditional(
condition = atleast_2d_float32_tensor(condition)

# Transform the `condition` to unconstrained space.
- tf_condition = transform.inv(condition)
+ transformed_condition = transform(condition)
cond_potential_fn_provider = ConditionalPotentialFunctionProvider(
-     potential_fn_provider, tf_condition, dims_to_sample
+     potential_fn_provider, transformed_condition, dims_to_sample
)

- tf_samples = self._sample_posterior_mcmc(
+ transformed_samples = self._sample_posterior_mcmc(
num_samples=num_samples,
potential_fn=cond_potential_fn_provider(
self._prior,
@@ -749,7 +749,7 @@ def sample_conditional(
show_progress_bars=show_progress_bars,
**mcmc_parameters,
)
- samples = transform_dims_to_sample(tf_samples)
+ samples = transform_dims_to_sample.inv(transformed_samples)
elif sample_with == "rejection":
cond_potential_fn_provider = ConditionalPotentialFunctionProvider(
potential_fn_provider, condition, dims_to_sample
25 changes: 13 additions & 12 deletions sbi/inference/posteriors/direct_posterior.py
@@ -308,7 +308,8 @@ def sample(
z-scored (and unconstrained) space.
rejection_sampling_parameters: Dictionary overriding the default parameters
for rejection sampling. The following parameters are supported:
- `proposal` as the proposal distribution (default is the prior).
+ `proposal` as the proposal distribution (default is the trained
+ neural net).
`max_sampling_batch_size` as the batch size of samples being drawn from
the proposal at every iteration.
`num_samples_to_find_max` as the number of samples that are used to
@@ -346,7 +347,7 @@ def sample(
transform = mcmc_transform(
self._prior, device=self._device, **mcmc_parameters
)
- tf_samples = self._sample_posterior_mcmc(
+ transformed_samples = self._sample_posterior_mcmc(
num_samples=num_samples,
potential_fn=potential_fn_provider(
self._prior, self.net, x, mcmc_method, transform
@@ -363,7 +364,7 @@
show_progress_bars=show_progress_bars,
**mcmc_parameters,
)
- samples = transform(tf_samples)
+ samples = transform.inv(transformed_samples)
elif sample_with == "rejection":
rejection_sampling_parameters = (
self._potentially_replace_rejection_parameters(
@@ -607,13 +608,13 @@ def posterior_potential(
"""

# Device is the same for net and prior.
- theta_tf = ensure_theta_batched(torch.as_tensor(theta, dtype=torch.float32)).to(
-     self.device
- )
+ transformed_theta = ensure_theta_batched(
+     torch.as_tensor(theta, dtype=torch.float32)
+ ).to(self.device)
# Transform `theta` from transformed (i.e. unconstrained) to untransformed
# space.
- theta = self.transform(theta_tf)
- log_abs_det = self.transform.log_abs_det_jacobian(theta_tf, theta)
+ theta = self.transform.inv(transformed_theta)
+ log_abs_det = self.transform.log_abs_det_jacobian(theta, transformed_theta)

theta_repeated, x_repeated = DirectPosterior._match_theta_and_x_batch_shapes(
theta, self.x
@@ -623,18 +624,18 @@

# Evaluate on device, move back to cpu for comparison with prior.
posterior_log_prob = self.posterior_nn.log_prob(theta_repeated, x_repeated)
- posterior_log_prob_tf = posterior_log_prob + log_abs_det
+ posterior_log_prob_transformed = posterior_log_prob - log_abs_det

# Force probability to be zero outside prior support.
in_prior_support = within_support(self.prior, theta)

- posterior_log_prob_tf = torch.where(
+ posterior_log_prob_transformed = torch.where(
in_prior_support,
- posterior_log_prob_tf,
+ posterior_log_prob_transformed,
torch.tensor(float("-inf"), dtype=torch.float32, device=self.device),
)

- return posterior_log_prob_tf
+ return posterior_log_prob_transformed

def pyro_potential(
self, theta: Dict[str, Tensor], track_gradients: bool = False
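The sign flip on `log_abs_det` above follows from the change of variables under the new convention: with `u = T(theta)` and `T` mapping constrained to unconstrained space, `p_u(u) = p(theta) / |det J_T(theta)|`, so the log density evaluated during sampling is `log p(theta) - log|det J_T(theta)|`. A quick numerical check of that identity as a standalone torch sketch (`T` is illustrative, not sbi's own transform):

import torch
from torch.distributions import TransformedDistribution, Uniform, biject_to

prior = Uniform(torch.tensor(0.0), torch.tensor(1.0))
T = biject_to(prior.support).inv  # forward: constrained -> unconstrained

theta = prior.sample((4,))
u = T(theta)

# log p_u(u) = log p(theta) - log|det J_T(theta)|: the same `- log_abs_det`
# correction applied to `posterior_log_prob` in the hunk above.
manual = prior.log_prob(theta) - T.log_abs_det_jacobian(theta, u)
reference = TransformedDistribution(prior, [T]).log_prob(u)
assert torch.allclose(manual, reference, atol=1e-5)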
18 changes: 9 additions & 9 deletions sbi/inference/posteriors/likelihood_based_posterior.py
@@ -188,7 +188,7 @@ def sample(
self._prior, device=self._device, **mcmc_parameters
)

- tf_samples = self._sample_posterior_mcmc(
+ transformed_samples = self._sample_posterior_mcmc(
num_samples=num_samples,
potential_fn=potential_fn_provider(
self._prior, self.net, x, mcmc_method, transform
@@ -205,7 +205,7 @@
show_progress_bars=show_progress_bars,
**mcmc_parameters,
)
- samples = transform(tf_samples)
+ samples = transform.inv(transformed_samples)
elif sample_with == "rejection":
rejection_sampling_parameters = (
self._potentially_replace_rejection_parameters(
@@ -489,13 +489,13 @@ def posterior_potential(
"""

# Device is the same for net and prior.
- theta_tf = ensure_theta_batched(torch.as_tensor(theta, dtype=torch.float32)).to(
-     self.device
- )
+ transformed_theta = ensure_theta_batched(
+     torch.as_tensor(theta, dtype=torch.float32)
+ ).to(self.device)
# Transform `theta` from transformed (i.e. unconstrained) to untransformed
# space.
- theta = self.transform(theta_tf)
- log_abs_det = self.transform.log_abs_det_jacobian(theta_tf, theta)
+ theta = self.transform.inv(transformed_theta)
+ log_abs_det = self.transform.log_abs_det_jacobian(theta, transformed_theta)

log_likelihoods = LikelihoodBasedPosterior._log_likelihoods_over_trials(
x=self.x,
@@ -504,8 +504,8 @@
track_gradients=track_gradients,
)
posterior_potential = log_likelihoods + self.prior.log_prob(theta)
- posterior_potential_tf = posterior_potential + log_abs_det
- return posterior_potential_tf
+ posterior_potential_transformed = posterior_potential - log_abs_det
+ return posterior_potential_transformed

def pyro_potential(
self, theta: Dict[str, Tensor], track_gradients: bool = False
18 changes: 9 additions & 9 deletions sbi/inference/posteriors/ratio_based_posterior.py
@@ -197,7 +197,7 @@ def sample(
self._prior, device=self._device, **mcmc_parameters
)

- tf_samples = self._sample_posterior_mcmc(
+ transformed_samples = self._sample_posterior_mcmc(
num_samples=num_samples,
potential_fn=potential_fn_provider(
self._prior, self.net, x, mcmc_method, transform
@@ -214,7 +214,7 @@
show_progress_bars=show_progress_bars,
**mcmc_parameters,
)
- samples = transform(tf_samples)
+ samples = transform.inv(transformed_samples)
elif sample_with == "rejection":
rejection_sampling_parameters = (
self._potentially_replace_rejection_parameters(
@@ -523,13 +523,13 @@ def posterior_potential(
"""

# Device is the same for net and prior.
- theta_tf = ensure_theta_batched(torch.as_tensor(theta, dtype=torch.float32)).to(
-     self.device
- )
+ transformed_theta = ensure_theta_batched(
+     torch.as_tensor(theta, dtype=torch.float32)
+ ).to(self.device)
# Transform `theta` from transformed (i.e. unconstrained) to untransformed
# space.
- theta = self.transform(theta_tf)
- log_abs_det = self.transform.log_abs_det_jacobian(theta_tf, theta)
+ theta = self.transform.inv(transformed_theta)
+ log_abs_det = self.transform.log_abs_det_jacobian(theta, transformed_theta)

log_ratio = RatioBasedPosterior._log_ratios_over_trials(
self.x,
@@ -538,8 +538,8 @@
track_gradients=track_gradients,
)
posterior_potential = log_ratio + self.prior.log_prob(theta)
- posterior_potential_tf = posterior_potential + log_abs_det
- return posterior_potential_tf
+ posterior_potential_transformed = posterior_potential - log_abs_det
+ return posterior_potential_transformed

def pyro_potential(
self, theta: Dict[str, Tensor], track_gradients: bool = False
10 changes: 5 additions & 5 deletions sbi/mcmc/init_strategy.py
@@ -27,8 +27,8 @@ def __call__(self):
def prior_init(prior: Any, transform: nflows.transforms, **kwargs: Any) -> Tensor:
"""Return a sample from the prior."""
prior_samples = prior.sample((1,)).detach()
- prior_samples_tf = transform.inv(prior_samples)
- return prior_samples_tf
+ transformed_prior_samples = transform(prior_samples)
+ return transformed_prior_samples


def sir(
@@ -61,9 +61,9 @@ def sir(
init_param_candidates = []
for i in range(sir_num_batches):
batch_draws = prior.sample((sir_batch_size,)).detach()
- batch_draws_tf = transform.inv(batch_draws)
- init_param_candidates.append(batch_draws_tf)
- log_weights.append(potential_fn(batch_draws_tf.numpy()).detach())
+ transformed_batch_draws = transform(batch_draws)
+ init_param_candidates.append(transformed_batch_draws)
+ log_weights.append(potential_fn(transformed_batch_draws.numpy()).detach())
log_weights = torch.cat(log_weights)
init_param_candidates = torch.cat(init_param_candidates)

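Both init strategies now return starting points that already live in unconstrained space, matching the potential function that the chain evaluates there. A hedged sketch of the two patterns (`dummy_potential` and `transform` are illustrative stand-ins, not sbi's API):

import torch
from torch.distributions import Uniform, biject_to

prior = Uniform(torch.zeros(2), torch.ones(2))
transform = biject_to(prior.support).inv  # forward: constrained -> unconstrained

def dummy_potential(u: torch.Tensor) -> torch.Tensor:
    return -(u ** 2).sum(dim=-1)  # stand-in for the real potential_fn

# prior_init pattern: draw from the prior, map forward, start the chain there.
initial_params = transform(prior.sample((1,)).detach())

# sir pattern: score candidates in unconstrained space, resample by weight.
candidates = transform(prior.sample((100,)).detach())
log_weights = dummy_potential(candidates)
idx = torch.multinomial(log_weights.softmax(dim=0), num_samples=1)
initial_params = candidates[idx]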
5 changes: 4 additions & 1 deletion sbi/utils/sbiutils.py
@@ -895,6 +895,9 @@ def mcmc_transform(
"""
Builds a transform that is applied to parameters during MCMC.
+ The resulting transform is defined such that the forward mapping maps from
+ constrained to unconstrained space.
It does two things:
1) When the prior support is bounded, it transforms the parameters into unbounded
space.
@@ -940,7 +943,7 @@ def mcmc_transform(
transform, reinterpreted_batch_ndims=1
)

- return transform
+ return transform.inv


class ImproperEmpirical(Empirical):
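With `transform.inv` returned, callers get the constrained-to-unconstrained direction as the forward map, which is exactly what the updated test below exercises. A rough usage sketch (assuming the imports resolve as in the tests; `BoxUniform` serves only as a convenient bounded prior):

import torch
from sbi.utils import BoxUniform, mcmc_transform, process_prior

prior, _, _ = process_prior(BoxUniform(torch.zeros(2), torch.ones(2)))
tf = mcmc_transform(prior)

theta = prior.sample((10,))  # constrained draws from the prior support
u = tf(theta)                # forward: constrained -> unconstrained
theta_back = tf.inv(u)       # round trip back to the support

# One log-abs-det value per sample, as asserted in the test below.
log_abs_det = tf.log_abs_det_jacobian(theta_back, u)
assert log_abs_det.shape == torch.Size([10])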
11 changes: 6 additions & 5 deletions tests/sbiutils_test.py
@@ -381,12 +381,13 @@ def test_mcmc_transform(prior, enable_transform):
Test whether the transform for MCMC returns the log_abs_det in the correct shape.
"""

+ num_samples = 1000
prior, _, _ = process_prior(prior)
tf = mcmc_transform(prior, enable_transform=enable_transform)

- samples = prior.sample((1000,))
- unconstrained_samples = tf.inv(samples)
- samples_original = tf(unconstrained_samples)
+ samples = prior.sample((num_samples,))
+ unconstrained_samples = tf(samples)
+ samples_original = tf.inv(unconstrained_samples)

- log_abs_det = tf.log_abs_det_jacobian(unconstrained_samples, samples_original)
- assert log_abs_det.shape == torch.Size([1000])
+ log_abs_det = tf.log_abs_det_jacobian(samples_original, unconstrained_samples)
+ assert log_abs_det.shape == torch.Size([num_samples])
