New shape conventions for all DensityEstimators
michaeldeistler committed Apr 9, 2024
1 parent 46db263 commit a3dce4e
Showing 19 changed files with 733 additions and 610 deletions.
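Judging from the diff below, the new convention is that a condition `x` is passed to a `DensityEstimator` with shape `(batch_dim, *event_shape)`, inputs `theta` with shape `(sample_dim, batch_dim, *event_shape)`, and `log_prob` returns shape `(sample_dim, batch_dim)`. A minimal sketch of this convention, shapes only; the variable names and sizes are illustrative, not code from the commit:

```python
import torch

# Assumed convention introduced by this commit (shapes only):
#   condition x : (batch_dim, *x_event_shape)
#   input theta : (sample_dim, batch_dim, *theta_event_shape)
#   log_prob    : (sample_dim, batch_dim)
batch_dim, sample_dim = 1, 100
x = torch.randn(batch_dim, 5)                   # one observation with 5 summary stats
theta = torch.randn(sample_dim, batch_dim, 3)   # 100 parameter sets, 3 parameters each

# A DensityEstimator following this convention would return log-probs of
# shape (sample_dim, batch_dim) == (100, 1) for these tensors.
```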
53 changes: 31 additions & 22 deletions sbi/inference/posteriors/direct_posterior.py
@@ -11,6 +11,10 @@
posterior_estimator_based_potential,
)
from sbi.neural_nets.density_estimators.base import DensityEstimator
from sbi.neural_nets.density_estimators.shape_handling import (
reshape_to_batch_event,
reshape_to_sample_batch_event,
)
from sbi.samplers.rejection.rejection import accept_reject_sample
from sbi.sbi_types import Shape
from sbi.utils import check_prior, within_support
@@ -101,17 +105,13 @@ def sample(
"""

num_samples = torch.Size(sample_shape).numel()
condition_shape = self.posterior_estimator._condition_shape
x = self._x_else_default_x(x)

try:
x = x.reshape(*condition_shape)
except RuntimeError as err:
raise ValueError(
f"Expected a single `x` which should broadcastable to shape \
{condition_shape}, but got {x.shape}. For batched eval \
see issue #990"
) from err
# [1:] because we remove batch dimension for `reshape_to_batch_event`.
# Note: This line will break if `x_shape` is `None` and if `x` is passed without
# batch dimension.
x_shape = self._x_shape[1:] if self._x_shape is not None else x.shape[1:]
x = reshape_to_batch_event(x, event_shape=x_shape)

max_sampling_batch_size = (
self.max_sampling_batch_size
@@ -171,24 +171,29 @@ def log_prob(
support of the prior, -∞ (corresponding to 0 probability) outside.
"""
x = self._x_else_default_x(x)
condition_shape = self.posterior_estimator._condition_shape
try:
x = x.reshape(*condition_shape)
except RuntimeError as err:
raise ValueError(
f"Expected a single `x` which should broadcastable to shape \
{condition_shape}, but got {x.shape}. For batched eval \
see issue #990"
) from err

# TODO Train exited here, entered after sampling?
self.posterior_estimator.eval()
# [1:] to remove batch dimension for `reshape_to_sample_batch_event`.
x_shape = self._x_shape[1:] if self._x_shape is not None else x.shape[1:]

theta = ensure_theta_batched(torch.as_tensor(theta))
theta_density_estimator = reshape_to_sample_batch_event(
theta, theta.shape[1:], leading_is_sample=True
)
x_density_estimator = reshape_to_batch_event(x, x_shape)
assert (
x_density_estimator.shape[0] == 1
), ".log_prob() supports only `batchsize == 1`."

self.posterior_estimator.eval()

with torch.set_grad_enabled(track_gradients):
# Evaluate on device, move back to cpu for comparison with prior.
unnorm_log_prob = self.posterior_estimator.log_prob(theta, condition=x)
unnorm_log_prob = self.posterior_estimator.log_prob(
theta_density_estimator, condition=x_density_estimator
)
# `log_prob` supports only a single observation (i.e. `batchsize==1`).
# We now remove this additional dimension.
unnorm_log_prob = unnorm_log_prob.squeeze(dim=1)

# Force probability to be zero outside prior support.
in_prior_support = within_support(self.prior, theta)
@@ -238,14 +243,18 @@ def leakage_correction(
"""

def acceptance_at(x: Tensor) -> Tensor:
# [1:] to remove batch-dimension for `reshape_to_batch_event`.
x_shape = self._x_shape[1:] if self._x_shape is not None else x.shape[1:]
return accept_reject_sample(
proposal=self.posterior_estimator,
accept_reject_fn=lambda theta: within_support(self.prior, theta),
num_samples=num_rejection_samples,
show_progress_bars=show_progress_bars,
sample_for_correction_factor=True,
max_sampling_batch_size=rejection_sampling_batch_size,
proposal_sampling_kwargs={"condition": x},
proposal_sampling_kwargs={
"condition": reshape_to_batch_event(x, x_shape)
},
)[1]

# Check if the provided x matches the default x (short-circuit on identity).
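For context, a minimal sketch of what the reshaping in `DirectPosterior.log_prob` amounts to, assuming `reshape_to_batch_event` prepends a batch dimension and `reshape_to_sample_batch_event` inserts a batch dimension after the leading sample dimension (helper behavior inferred from the diff; the tensors below are illustrative):

```python
import torch

x_o = torch.randn(5)         # a single observation, event_shape (5,)
theta = torch.randn(100, 3)  # 100 parameter sets, event_shape (3,)

# Condition: add a leading batch dimension -> (1, 5).
x_batch_event = x_o.reshape(1, *x_o.shape)
# Inputs: keep the leading sample dimension, insert batch_dim == 1 -> (100, 1, 3).
theta_sample_batch_event = theta.unsqueeze(1)

# posterior_estimator.log_prob(theta_sample_batch_event, condition=x_batch_event)
# would return shape (100, 1); `.squeeze(dim=1)` then removes the batch dimension.
```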
62 changes: 46 additions & 16 deletions sbi/inference/potentials/likelihood_based_potential.py
@@ -9,6 +9,10 @@

from sbi.inference.potentials.base_potential import BasePotential
from sbi.neural_nets.density_estimators import DensityEstimator
from sbi.neural_nets.density_estimators.shape_handling import (
reshape_to_batch_event,
reshape_to_sample_batch_event,
)
from sbi.neural_nets.mnle import MixedDensityEstimator
from sbi.sbi_types import TorchTransform
from sbi.utils import mcmc_transform
@@ -110,28 +114,42 @@ def _log_likelihoods_over_trials(
Repeats `x` and $\theta$ to cover all their combinations of batch entries.
Args:
x: batch of iid data.
theta: batch of parameters.
x: Batch of iid data of shape `(iid_dim, *event_shape)`.
theta: Batch of parameters of shape `(batch_dim, *event_shape)`.
estimator: DensityEstimator.
track_gradients: Whether to track gradients.
Returns:
log_likelihood_trial_sum: log likelihood for each parameter, summed over all
batch entries (iid trials) in `x`.
"""
# unsqueeze to ensure that the x-batch dimension is the first dimension for the
# broadcasting of the density estimator.
x = torch.as_tensor(x).reshape(-1, x.shape[-1]).unsqueeze(1)
# Shape of `x` is (iid_dim, *event_shape).
x = reshape_to_sample_batch_event(
x, event_shape=x.shape[1:], leading_is_sample=True
)

# Match the number of `x` to the number of conditions (`theta`). This is important
# if the potential is simultaneously evaluated at multiple `theta` (e.g.
# multi-chain MCMC).
theta_batch_size = theta.shape[0]
trailing_minus_ones = [-1 for _ in range(x.dim() - 2)]
x = x.expand(-1, theta_batch_size, *trailing_minus_ones)

assert (
next(estimator.parameters()).device == x.device and x.device == theta.device
), f"""device mismatch: estimator, x, theta: \
{next(estimator.parameters()).device}, {x.device},
{theta.device}."""

# Shape of `theta` is (batch_dim, *event_shape). Therefore, the call below should
# not change anything, and we just have it as "best practice" before calling
# `DensityEstimator.log_prob`.
theta = reshape_to_batch_event(theta, event_shape=theta.shape[1:])

# Calculate likelihood in one batch.
with torch.set_grad_enabled(track_gradients):
log_likelihood_trial_batch = estimator.log_prob(x, condition=theta)
# Reshape to (-1, theta_batch_size), sum over trial-log likelihoods.
# Sum over trial-log likelihoods.
log_likelihood_trial_sum = log_likelihood_trial_batch.sum(0)

return log_likelihood_trial_sum
@@ -170,28 +188,40 @@ def mixed_likelihood_estimator_based_potential(
class MixedLikelihoodBasedPotential(LikelihoodBasedPotential):
def __init__(
self,
likelihood_estimator: MixedDensityEstimator, # type: ignore TODO fix pyright
likelihood_estimator: MixedDensityEstimator,
prior: Distribution,
x_o: Optional[Tensor],
device: str = "cpu",
):
# TODO Fix pyright issue by making MixedDensityEstimator a subclass
# of DensityEstimator
super().__init__(likelihood_estimator, prior, x_o, device) # type: ignore
super().__init__(likelihood_estimator, prior, x_o, device)

def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
prior_log_prob = self.prior.log_prob(theta) # type: ignore

# Shape of `x` is (iid_dim, *event_shape)
theta = reshape_to_batch_event(theta, event_shape=theta.shape[1:])
x = reshape_to_sample_batch_event(
self.x_o, event_shape=self.x_o.shape[1:], leading_is_sample=True
)
theta_batch_dim = theta.shape[0]
# Match the number of `x` to the number of conditions (`theta`). This is
# important if the potential is simultaneously evaluated at multiple `theta`
# (e.g. multi-chain MCMC).
trailing_minus_ones = [-1 for _ in range(x.dim() - 2)]
x = x.expand(-1, theta_batch_dim, *trailing_minus_ones)

# Calculate likelihood in one batch.
with torch.set_grad_enabled(track_gradients):
# Call the specific log prob method of the mixed likelihood estimator as
# this optimizes the evaluation of the discrete data part.
# TODO: how to fix pyright issues?
log_likelihood_trial_batch = self.likelihood_estimator.log_prob_iid(
x=self.x_o,
context=theta.to(self.device),
) # type: ignore
# TODO log_prob_iid
log_likelihood_trial_batch = self.likelihood_estimator.log_prob(
input=x,
condition=theta.to(self.device),
)
# Reshape to (x-trials x parameters), sum over trial-log likelihoods.
log_likelihood_trial_sum = log_likelihood_trial_batch.reshape(
self.x_o.shape[0], -1
).sum(0)

return log_likelihood_trial_sum + self.prior.log_prob(theta) # type: ignore
return log_likelihood_trial_sum + prior_log_prob
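Both potentials above broadcast a batch of iid trials over a batch of parameters before summing trial-wise log-likelihoods. A minimal sketch of that broadcasting, with assumed example dimensions (shapes only, not code from the commit):

```python
import torch

iid_dim, theta_batch_dim = 10, 4
x = torch.randn(iid_dim, 5)              # iid trials, event_shape (5,)
theta = torch.randn(theta_batch_dim, 3)  # e.g. 4 MCMC chains, event_shape (3,)

# (iid_dim, *event) -> (iid_dim, 1, *event), then expand over the theta batch.
x = x.unsqueeze(1).expand(-1, theta_batch_dim, -1)   # (10, 4, 5)

# estimator.log_prob(x, condition=theta) would return shape (10, 4);
# summing over dim 0 yields one log-likelihood per parameter set, shape (4,).
```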
22 changes: 18 additions & 4 deletions sbi/inference/potentials/posterior_based_potential.py
@@ -9,6 +9,10 @@

from sbi.inference.potentials.base_potential import BasePotential
from sbi.neural_nets.density_estimators import DensityEstimator
from sbi.neural_nets.density_estimators.shape_handling import (
reshape_to_batch_event,
reshape_to_sample_batch_event,
)
from sbi.sbi_types import TorchTransform
from sbi.utils import mcmc_transform
from sbi.utils.sbiutils import within_support
@@ -98,15 +102,25 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
the potential or manually set self._x_o."
)

theta = ensure_theta_batched(torch.as_tensor(theta))
theta, x = theta.to(self.device), self.x_o.to(self.device)
theta = ensure_theta_batched(torch.as_tensor(theta)).to(self.device)

with torch.set_grad_enabled(track_gradients):
posterior_log_prob = self.posterior_estimator.log_prob(theta, condition=x)

# Force probability to be zero outside prior support.
in_prior_support = within_support(self.prior, theta)

x = reshape_to_batch_event(self.x_o, event_shape=self.x_o.shape[1:])
assert (
x.shape[0] == 1
), f"`x` has batchsize {x.shape[0]}. Only `batchsize == 1` is supported."
theta = reshape_to_sample_batch_event(
theta, event_shape=theta.shape[1:], leading_is_sample=True
)
# We assume that a single `x` is passed (i.e. batchsize==1), so we squeeze
# the batch dimension of the log-prob with `.squeeze(dim=1)`.
posterior_log_prob = self.posterior_estimator.log_prob(
theta, condition=x
).squeeze(dim=1)

posterior_log_prob = torch.where(
in_prior_support,
posterior_log_prob,
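A minimal sketch of the shape handling in the potential above, again with assumed example dimensions rather than code from the commit:

```python
import torch

x_o = torch.randn(1, 5)      # a single observation, batch_dim == 1
theta = torch.randn(100, 3)  # e.g. 100 MCMC walkers

theta = theta.unsqueeze(1)   # (sample_dim, batch_dim, *event) == (100, 1, 3)

# posterior_estimator.log_prob(theta, condition=x_o) would return (100, 1);
# squeezing dim 1 yields one log-prob per walker, shape (100,).
```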
5 changes: 3 additions & 2 deletions sbi/inference/potentials/ratio_based_potential.py
@@ -107,10 +107,11 @@ def _log_ratios_over_trials(
Repeats `x` and $\theta$ to cover all their combinations of batch entries.
Args:
x: batch of iid data.
theta: batch of parameters
x: Batch of iid data of shape `(iid_dim, *event_shape)`.
theta: Batch of parameters of shape `(batch_dim, *event_shape)`.
net: neural net representing the classifier to approximate the ratio.
track_gradients: Whether to track gradients.
Returns:
log_ratio_trial_sum: log ratio for each parameter, summed over all
batch entries (iid trials) in `x`.
8 changes: 7 additions & 1 deletion sbi/inference/snle/mnle.py
@@ -11,6 +11,10 @@
from sbi.inference.posteriors import MCMCPosterior, RejectionPosterior, VIPosterior
from sbi.inference.potentials import mixed_likelihood_estimator_based_potential
from sbi.inference.snle.snle_base import LikelihoodEstimator
from sbi.neural_nets.density_estimators.shape_handling import (
reshape_to_batch_event,
reshape_to_sample_batch_event,
)
from sbi.neural_nets.mnle import MixedDensityEstimator
from sbi.sbi_types import TensorboardSummaryWriter, TorchModule
from sbi.utils import check_prior, del_entries
@@ -205,4 +209,6 @@ def _loss(self, theta: Tensor, x: Tensor) -> Tensor:
Returns:
Negative log prob.
"""
return -self._neural_net.log_prob(x, context=theta)
theta = reshape_to_batch_event(theta, event_shape=theta.shape[1:])
x = reshape_to_sample_batch_event(x, event_shape=self._x_shape[1:])
return -self._neural_net.log_prob(x, condition=theta)
8 changes: 8 additions & 0 deletions sbi/inference/snle/snle_base.py
@@ -16,6 +16,10 @@
from sbi.inference.posteriors import MCMCPosterior, RejectionPosterior, VIPosterior
from sbi.inference.potentials import likelihood_estimator_based_potential
from sbi.neural_nets import DensityEstimator, likelihood_nn
from sbi.neural_nets.density_estimators.shape_handling import (
reshape_to_batch_event,
reshape_to_sample_batch_event,
)
from sbi.utils import check_estimator_arg, check_prior, x_shape_from_simulation


@@ -366,4 +370,8 @@ def _loss(self, theta: Tensor, x: Tensor) -> Tensor:
Returns:
Negative log prob.
"""
theta = reshape_to_batch_event(theta, event_shape=theta.shape[1:])
x = reshape_to_sample_batch_event(
x, event_shape=self._x_shape[1:], leading_is_sample=False
)
return self._neural_net.loss(x, condition=theta)
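A minimal sketch of the training-time reshaping in `_loss`, assuming `reshape_to_sample_batch_event(..., leading_is_sample=False)` inserts a sample dimension of size 1 in front of the batch dimension (the tensors and sizes below are illustrative):

```python
import torch

theta = torch.randn(200, 3)  # training batch of parameters: (batch_dim, *event)
x = torch.randn(200, 5)      # paired simulations: (batch_dim, *event)

x = x.unsqueeze(0)           # (sample_dim, batch_dim, *event) == (1, 200, 5)

# self._neural_net.loss(x, condition=theta) would then return one loss per
# training pair, i.e. shape (200,).
```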
16 changes: 13 additions & 3 deletions sbi/inference/snpe/snpe_a.py
@@ -18,6 +18,7 @@
from sbi.neural_nets.density_estimators.base import DensityEstimator
from sbi.sbi_types import TensorboardSummaryWriter, TorchModule
from sbi.utils import torchutils
from sbi.utils.torchutils import atleast_2d


class SNPE_A(PosteriorEstimator):
@@ -408,12 +409,14 @@ def __init__(
if isinstance(proposal, (utils.BoxUniform, MultivariateNormal)):
self._apply_correction = False
else:
# Add iid dimension.
default_x = proposal.default_x # type: ignore

self._apply_correction = True
(
logits_pp,
m_pp,
prec_pp,
) = proposal.posterior_estimator._posthoc_correction(proposal.default_x) # type: ignore
) = proposal.posterior_estimator._posthoc_correction(default_x)
self._logits_pp, self._m_pp, self._prec_pp = (
logits_pp.detach(),
m_pp.detach(),
@@ -544,17 +547,24 @@ def _posthoc_correction(self, x: Tensor):
estimator and the proposal.
Args:
x: Conditioning context for posterior.
x: Conditioning context for posterior, shape
`(batch_dim, *event_shape)`.
Returns:
Mixture components of the posterior.
"""
# Remove the batch dimension of `x` (SNPE-A always has a single `x`).
assert (

x.shape[0] == 1
), f"Batchsize of `x_o` == {x.shape[0]}. SNPE-A only supports a single `x_o`."
x = x.squeeze(dim=0)

# Evaluate the density estimator.
embedded_x = self._neural_net.net._embedding_net(x)
dist = self._neural_net.net._distribution # defined to avoid black formatting.
logits_d, m_d, prec_d, _, _ = dist.get_mixture_components(embedded_x)
norm_logits_d = logits_d - torch.logsumexp(logits_d, dim=-1, keepdim=True)
norm_logits_d = atleast_2d(norm_logits_d)

# The following if case is needed because, in the constructor, we call
# `_posthoc_correction` regardless of whether the `proposal` itself had a
Expand All @@ -572,6 +582,7 @@ def _posthoc_correction(self, x: Tensor):
logits_p, m_p, prec_p, cov_p = self._proposal_posterior_transformation(
logits_pp, m_pp, prec_pp, norm_logits_d, m_d, prec_d
)
logits_p = atleast_2d(logits_p)

return logits_p, m_p, prec_p

def _proposal_posterior_transformation(
@@ -606,7 +617,6 @@ def _proposal_posterior_transformation(
Returns: (Component weight, mean, precision matrix, covariance matrix) of each
Gaussian of the approximate posterior.
"""

precisions_post, covariances_post = self._precisions_posterior(
precisions_pp, precisions_d
)