Shape conventions for all DensityEstimators #1066

Merged: 1 commit, Apr 9, 2024
53 changes: 31 additions & 22 deletions sbi/inference/posteriors/direct_posterior.py
@@ -11,6 +11,10 @@
posterior_estimator_based_potential,
)
from sbi.neural_nets.density_estimators.base import DensityEstimator
from sbi.neural_nets.density_estimators.shape_handling import (
reshape_to_batch_event,
reshape_to_sample_batch_event,
)
from sbi.samplers.rejection.rejection import accept_reject_sample
from sbi.sbi_types import Shape
from sbi.utils import check_prior, within_support
@@ -101,17 +105,13 @@ def sample(
"""

num_samples = torch.Size(sample_shape).numel()
condition_shape = self.posterior_estimator._condition_shape
x = self._x_else_default_x(x)

try:
x = x.reshape(*condition_shape)
except RuntimeError as err:
raise ValueError(
f"Expected a single `x` which should broadcastable to shape \
{condition_shape}, but got {x.shape}. For batched eval \
see issue #990"
) from err
# [1:] because we remove batch dimension for `reshape_to_batch_event`.
# Note: This line will break if `x_shape` is `None` and if `x` is passed without
# batch dimension.
x_shape = self._x_shape[1:] if self._x_shape is not None else x.shape[1:]
x = reshape_to_batch_event(x, event_shape=x_shape)

max_sampling_batch_size = (
self.max_sampling_batch_size
@@ -171,24 +171,29 @@ def log_prob(
support of the prior, -∞ (corresponding to 0 probability) outside.
"""
x = self._x_else_default_x(x)
condition_shape = self.posterior_estimator._condition_shape
try:
x = x.reshape(*condition_shape)
except RuntimeError as err:
raise ValueError(
f"Expected a single `x` which should broadcastable to shape \
{condition_shape}, but got {x.shape}. For batched eval \
see issue #990"
) from err

# TODO Train exited here, entered after sampling?
self.posterior_estimator.eval()
# [1:] to remove batch dimension for `reshape_to_sample_batch_event`.
x_shape = self._x_shape[1:] if self._x_shape is not None else x.shape[1:]

theta = ensure_theta_batched(torch.as_tensor(theta))
theta_density_estimator = reshape_to_sample_batch_event(
theta, theta.shape[1:], leading_is_sample=True
)
x_density_estimator = reshape_to_batch_event(x, x_shape)
assert (
x_density_estimator.shape[0] == 1
), ".log_prob() supports only `batchsize == 1`."

self.posterior_estimator.eval()

with torch.set_grad_enabled(track_gradients):
# Evaluate on device, move back to cpu for comparison with prior.
unnorm_log_prob = self.posterior_estimator.log_prob(theta, condition=x)
unnorm_log_prob = self.posterior_estimator.log_prob(
theta_density_estimator, condition=x_density_estimator
)
# `log_prob` supports only a single observation (i.e. `batchsize==1`).
# We now remove this additional dimension.
unnorm_log_prob = unnorm_log_prob.squeeze(dim=1)

# Force probability to be zero outside prior support.
in_prior_support = within_support(self.prior, theta)
@@ -238,14 +243,18 @@ def leakage_correction(
"""

def acceptance_at(x: Tensor) -> Tensor:
# [1:] to remove batch-dimension for `reshape_to_batch_event`.
x_shape = self._x_shape[1:] if self._x_shape is not None else x.shape[1:]
return accept_reject_sample(
proposal=self.posterior_estimator,
accept_reject_fn=lambda theta: within_support(self.prior, theta),
num_samples=num_rejection_samples,
show_progress_bars=show_progress_bars,
sample_for_correction_factor=True,
max_sampling_batch_size=rejection_sampling_batch_size,
proposal_sampling_kwargs={"condition": x},
proposal_sampling_kwargs={
"condition": reshape_to_batch_event(x, x_shape)
},
)[1]

# Check if the provided x matches the default x (short-circuit on identity).
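For orientation, here is a minimal standalone sketch (not part of the diff) of the shape convention these hunks rely on, assuming the module path from this PR and assuming the helpers behave as the comments above suggest: `reshape_to_batch_event` yields `(batch_dim, *event_shape)`, `reshape_to_sample_batch_event` yields `(sample_dim, batch_dim, *event_shape)`, and `log_prob` returns one value per (sample, batch) pair, so the extra batch dimension can be squeezed away for a single `x`.

import torch
from sbi.neural_nets.density_estimators.shape_handling import (
    reshape_to_batch_event,
    reshape_to_sample_batch_event,
)

theta = torch.randn(100, 3)  # 100 parameter sets, parameter event_shape (3,)
x_o = torch.randn(1, 5)      # a single observation with batch dim, event_shape (5,)

theta_sbe = reshape_to_sample_batch_event(
    theta, event_shape=theta.shape[1:], leading_is_sample=True
)  # expected shape: (100, 1, 3)
x_be = reshape_to_batch_event(x_o, event_shape=x_o.shape[1:])  # expected: (1, 5)

# With a trained estimator one would then evaluate (hypothetical object):
# log_probs = posterior_estimator.log_prob(theta_sbe, condition=x_be)  # (100, 1)
# log_probs = log_probs.squeeze(dim=1)  # one log prob per theta, as in `log_prob` above.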
62 changes: 46 additions & 16 deletions sbi/inference/potentials/likelihood_based_potential.py
@@ -9,6 +9,10 @@

from sbi.inference.potentials.base_potential import BasePotential
from sbi.neural_nets.density_estimators import DensityEstimator
from sbi.neural_nets.density_estimators.shape_handling import (
reshape_to_batch_event,
reshape_to_sample_batch_event,
)
from sbi.neural_nets.mnle import MixedDensityEstimator
from sbi.sbi_types import TorchTransform
from sbi.utils import mcmc_transform
@@ -110,28 +114,42 @@ def _log_likelihoods_over_trials(
Repeats `x` and $\theta$ to cover all their combinations of batch entries.

Args:
x: batch of iid data.
theta: batch of parameters.
x: Batch of iid data of shape `(iid_dim, *event_shape)`.
theta: Batch of parameters of shape `(batch_dim, *event_shape)`.
estimator: DensityEstimator.
track_gradients: Whether to track gradients.

Returns:
log_likelihood_trial_sum: log likelihood for each parameter, summed over all
batch entries (iid trials) in `x`.
"""
# unsqueeze to ensure that the x-batch dimension is the first dimension for the
# broadcasting of the density estimator.
x = torch.as_tensor(x).reshape(-1, x.shape[-1]).unsqueeze(1)
# Shape of `x` is (iid_dim, *event_shape).
x = reshape_to_sample_batch_event(
x, event_shape=x.shape[1:], leading_is_sample=True
)

# Match the number of `x` to the number of conditions (`theta`). This is important
# if the potential is simultaneously evaluated at multiple `theta` (e.g.
# multi-chain MCMC).
theta_batch_size = theta.shape[0]
trailing_minus_ones = [-1 for _ in range(x.dim() - 2)]
x = x.expand(-1, theta_batch_size, *trailing_minus_ones)

assert (
next(estimator.parameters()).device == x.device and x.device == theta.device
), f"""device mismatch: estimator, x, theta: \
{next(estimator.parameters()).device}, {x.device},
{theta.device}."""

# Shape of `theta` is (batch_dim, *event_shape). Therefore, the call below should
# not change anything, and we just have it as "best practice" before calling
# `DensityEstimator.log_prob`.
theta = reshape_to_batch_event(theta, event_shape=theta.shape[1:])

# Calculate likelihood in one batch.
with torch.set_grad_enabled(track_gradients):
log_likelihood_trial_batch = estimator.log_prob(x, condition=theta)
# Reshape to (-1, theta_batch_size), sum over trial-log likelihoods.
# Sum over trial-log likelihoods.
log_likelihood_trial_sum = log_likelihood_trial_batch.sum(0)

return log_likelihood_trial_sum
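The `expand` step above is what allows evaluating the potential at several `theta` at once (e.g. multi-chain MCMC). A small standalone illustration of the broadcasting, using plain `unsqueeze` in place of `reshape_to_sample_batch_event` to keep the snippet dependency-free, with made-up shapes:

import torch

x = torch.randn(10, 4)     # 10 iid trials, data event_shape (4,)
theta = torch.randn(3, 2)  # 3 parameter sets, e.g. 3 MCMC chains

x = x.unsqueeze(1)  # (10, 1, 4): insert the batch dimension.
trailing_minus_ones = [-1 for _ in range(x.dim() - 2)]
x = x.expand(-1, theta.shape[0], *trailing_minus_ones)  # (10, 3, 4), no copy made.

# estimator.log_prob(x, condition=theta) would then return shape (10, 3);
# summing over dim 0 gives one iid log likelihood per theta, as in the code above.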
@@ -170,28 +188,40 @@ def mixed_likelihood_estimator_based_potential(
class MixedLikelihoodBasedPotential(LikelihoodBasedPotential):
def __init__(
self,
likelihood_estimator: MixedDensityEstimator, # type: ignore TODO fix pyright
likelihood_estimator: MixedDensityEstimator,
prior: Distribution,
x_o: Optional[Tensor],
device: str = "cpu",
):
# TODO Fix pyright issue by making MixedDensityEstimator a subclass
# of DensityEstimator
super().__init__(likelihood_estimator, prior, x_o, device) # type: ignore
super().__init__(likelihood_estimator, prior, x_o, device)

def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
prior_log_prob = self.prior.log_prob(theta) # type: ignore

# Shape of `x` is (iid_dim, *event_shape)
theta = reshape_to_batch_event(theta, event_shape=theta.shape[1:])
x = reshape_to_sample_batch_event(
self.x_o, event_shape=self.x_o.shape[1:], leading_is_sample=True
)
theta_batch_dim = theta.shape[0]
# Match the number of `x` to the number of conditions (`theta`). This is
# important if the potential is simultaneously evaluated at multiple `theta`
# (e.g. multi-chain MCMC).
trailing_minus_ones = [-1 for _ in range(x.dim() - 2)]
x = x.expand(-1, theta_batch_dim, *trailing_minus_ones)

# Calculate likelihood in one batch.
with torch.set_grad_enabled(track_gradients):
# Call the specific log prob method of the mixed likelihood estimator as
# this optimizes the evaluation of the discrete data part.
# TODO: how to fix pyright issues?
log_likelihood_trial_batch = self.likelihood_estimator.log_prob_iid(
x=self.x_o,
context=theta.to(self.device),
) # type: ignore
# TODO log_prob_iid
Contributor review comment: Please link these lines in the MNLE IID issue, so that I don't forget it 🙏

log_likelihood_trial_batch = self.likelihood_estimator.log_prob(
input=x,
condition=theta.to(self.device),
)
# Reshape to (x-trials x parameters), sum over trial-log likelihoods.
log_likelihood_trial_sum = log_likelihood_trial_batch.reshape(
self.x_o.shape[0], -1
).sum(0)

return log_likelihood_trial_sum + self.prior.log_prob(theta) # type: ignore
return log_likelihood_trial_sum + prior_log_prob
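The final reshape-and-sum in `__call__` collapses the iid-trial dimension into a single log likelihood per `theta`; a tiny numeric illustration (values and shapes are made up):

import torch

# Pretend `log_prob` returned per-trial values for 4 iid trials and 2 parameter sets.
log_likelihood_trial_batch = torch.arange(8.0).reshape(4, 2)

# Sum over the trial dimension: one iid log likelihood per theta.
log_likelihood_trial_sum = log_likelihood_trial_batch.reshape(4, -1).sum(0)
print(log_likelihood_trial_sum)  # tensor([12., 16.])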
22 changes: 18 additions & 4 deletions sbi/inference/potentials/posterior_based_potential.py
@@ -9,6 +9,10 @@

from sbi.inference.potentials.base_potential import BasePotential
from sbi.neural_nets.density_estimators import DensityEstimator
from sbi.neural_nets.density_estimators.shape_handling import (
reshape_to_batch_event,
reshape_to_sample_batch_event,
)
from sbi.sbi_types import TorchTransform
from sbi.utils import mcmc_transform
from sbi.utils.sbiutils import within_support
@@ -98,15 +102,25 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
the potential or manually set self._x_o."
)

theta = ensure_theta_batched(torch.as_tensor(theta))
theta, x = theta.to(self.device), self.x_o.to(self.device)
theta = ensure_theta_batched(torch.as_tensor(theta)).to(self.device)

with torch.set_grad_enabled(track_gradients):
posterior_log_prob = self.posterior_estimator.log_prob(theta, condition=x)

# Force probability to be zero outside prior support.
in_prior_support = within_support(self.prior, theta)

x = reshape_to_batch_event(self.x_o, event_shape=self.x_o.shape[1:])
assert (
x.shape[0] == 1
), f"`x` has batchsize {x.shape[0]}. Only `batchsize == 1` is supported."
theta = reshape_to_sample_batch_event(
theta, event_shape=theta.shape[1:], leading_is_sample=True
)
# We assume that a single `x` is passed (i.e. batchsize==1), so we squeeze
# the batch dimension of the log-prob with `.squeeze(dim=1)`.
posterior_log_prob = self.posterior_estimator.log_prob(
theta, condition=x
).squeeze(dim=1)

posterior_log_prob = torch.where(
in_prior_support,
posterior_log_prob,
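The `torch.where` above (cut off in the diff view) forces zero probability, i.e. -∞ log probability, outside the prior support; a minimal standalone illustration with made-up values:

import torch

posterior_log_prob = torch.tensor([-1.2, -0.3, -2.5])
in_prior_support = torch.tensor([True, False, True])

# Entries outside the prior support are set to -inf (probability zero).
masked = torch.where(
    in_prior_support,
    posterior_log_prob,
    torch.tensor(float("-inf")),
)
print(masked)  # tensor([-1.2000, -inf, -2.5000])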
5 changes: 3 additions & 2 deletions sbi/inference/potentials/ratio_based_potential.py
@@ -107,10 +107,11 @@ def _log_ratios_over_trials(
Repeats `x` and $\theta$ to cover all their combinations of batch entries.

Args:
x: batch of iid data.
theta: batch of parameters
x: Batch of iid data of shape `(iid_dim, *event_shape)`.
theta: Batch of parameters of shape `(batch_dim, *event_shape)`.
net: neural net representing the classifier to approximate the ratio.
track_gradients: Whether to track gradients.

Returns:
log_ratio_trial_sum: log ratio for each parameter, summed over all
batch entries (iid trials) in `x`.
8 changes: 7 additions & 1 deletion sbi/inference/snle/mnle.py
@@ -11,6 +11,10 @@
from sbi.inference.posteriors import MCMCPosterior, RejectionPosterior, VIPosterior
from sbi.inference.potentials import mixed_likelihood_estimator_based_potential
from sbi.inference.snle.snle_base import LikelihoodEstimator
from sbi.neural_nets.density_estimators.shape_handling import (
reshape_to_batch_event,
reshape_to_sample_batch_event,
)
from sbi.neural_nets.mnle import MixedDensityEstimator
from sbi.sbi_types import TensorboardSummaryWriter, TorchModule
from sbi.utils import check_prior, del_entries
@@ -205,4 +209,6 @@ def _loss(self, theta: Tensor, x: Tensor) -> Tensor:
Returns:
Negative log prob.
"""
return -self._neural_net.log_prob(x, context=theta)
theta = reshape_to_batch_event(theta, event_shape=theta.shape[1:])
x = reshape_to_sample_batch_event(x, event_shape=self._x_shape[1:])
return -self._neural_net.log_prob(x, condition=theta)
8 changes: 8 additions & 0 deletions sbi/inference/snle/snle_base.py
@@ -16,6 +16,10 @@
from sbi.inference.posteriors import MCMCPosterior, RejectionPosterior, VIPosterior
from sbi.inference.potentials import likelihood_estimator_based_potential
from sbi.neural_nets import DensityEstimator, likelihood_nn
from sbi.neural_nets.density_estimators.shape_handling import (
reshape_to_batch_event,
reshape_to_sample_batch_event,
)
from sbi.utils import check_estimator_arg, check_prior, x_shape_from_simulation


@@ -366,4 +370,8 @@ def _loss(self, theta: Tensor, x: Tensor) -> Tensor:
Returns:
Negative log prob.
"""
theta = reshape_to_batch_event(theta, event_shape=theta.shape[1:])
x = reshape_to_sample_batch_event(
x, event_shape=self._x_shape[1:], leading_is_sample=False
)
return self._neural_net.loss(x, condition=theta)
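In the `_loss` methods of both `mnle.py` and `snle_base.py` above, each batch entry of `x` is paired with exactly one `theta`, so the leading dimension of `x` is a batch dimension rather than a sample dimension (`leading_is_sample=False`). A hedged sketch of the expected shapes, assuming the helpers behave as their names and the PR's comments suggest:

import torch
from sbi.neural_nets.density_estimators.shape_handling import (
    reshape_to_batch_event,
    reshape_to_sample_batch_event,
)

theta = torch.randn(50, 3)  # 50 training pairs, parameter event_shape (3,)
x = torch.randn(50, 7)      # matching simulations, data event_shape (7,)

theta_be = reshape_to_batch_event(theta, event_shape=theta.shape[1:])
x_sbe = reshape_to_sample_batch_event(
    x, event_shape=x.shape[1:], leading_is_sample=False
)
# Expected: theta_be.shape == (50, 3) and x_sbe.shape == (1, 50, 7), so that
# `neural_net.loss(x_sbe, condition=theta_be)` yields one loss value per pair.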
16 changes: 13 additions & 3 deletions sbi/inference/snpe/snpe_a.py
@@ -18,6 +18,7 @@
from sbi.neural_nets.density_estimators.base import DensityEstimator
from sbi.sbi_types import TensorboardSummaryWriter, TorchModule
from sbi.utils import torchutils
from sbi.utils.torchutils import atleast_2d


class SNPE_A(PosteriorEstimator):
@@ -408,12 +409,14 @@
if isinstance(proposal, (utils.BoxUniform, MultivariateNormal)):
self._apply_correction = False
else:
# Add iid dimension.
default_x = proposal.default_x # type: ignore

self._apply_correction = True
(
logits_pp,
m_pp,
prec_pp,
) = proposal.posterior_estimator._posthoc_correction(proposal.default_x) # type: ignore
) = proposal.posterior_estimator._posthoc_correction(default_x)
self._logits_pp, self._m_pp, self._prec_pp = (
logits_pp.detach(),
m_pp.detach(),
@@ -544,17 +547,24 @@
estimator and the proposal.

Args:
x: Conditioning context for posterior.
x: Conditioning context for posterior, shape
`(batch_dim, *event_shape)`.

Returns:
Mixture components of the posterior.
"""
# Remove the batch dimension of `x` (SNPE-A always has a single `x`).
assert (

x.shape[0] == 1
), f"Batchsize of `x_o` == {x.shape[0]}. SNPE-A only supports a single `x_o`."
x = x.squeeze(dim=0)


# Evaluate the density estimator.
embedded_x = self._neural_net.net._embedding_net(x)
dist = self._neural_net.net._distribution # defined to avoid black formatting.
logits_d, m_d, prec_d, _, _ = dist.get_mixture_components(embedded_x)
norm_logits_d = logits_d - torch.logsumexp(logits_d, dim=-1, keepdim=True)
norm_logits_d = atleast_2d(norm_logits_d)


# The following if case is needed because, in the constructor, we call
# `_posthoc_correction` regardless of whether the `proposal` itself had a
@@ -572,6 +582,7 @@
logits_p, m_p, prec_p, cov_p = self._proposal_posterior_transformation(
logits_pp, m_pp, prec_pp, norm_logits_d, m_d, prec_d
)
logits_p = atleast_2d(logits_p)

return logits_p, m_p, prec_p

def _proposal_posterior_transformation(
@@ -606,7 +617,6 @@
Returns: (Component weight, mean, precision matrix, covariance matrix) of each
Gaussian of the approximate posterior.
"""

precisions_post, covariances_post = self._precisions_posterior(
precisions_pp, precisions_d
)
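Finally, a minimal standalone sketch of the single-`x_o` handling in `_posthoc_correction` above (using `torch.atleast_2d` as a stand-in for the `atleast_2d` helper imported from `sbi.utils.torchutils`; shapes and values are made up):

import torch

x = torch.randn(1, 5)  # (batch_dim=1, *event_shape): SNPE-A supports a single `x_o`.
assert x.shape[0] == 1, "SNPE-A only supports a single `x_o`."
x = x.squeeze(dim=0)   # (5,): drop the batch dimension before the embedding net.

logits_d = torch.randn(10)  # mixture logits returned without a batch dimension
norm_logits_d = logits_d - torch.logsumexp(logits_d, dim=-1, keepdim=True)
norm_logits_d = torch.atleast_2d(norm_logits_d)  # (1, 10), as the downstream code expects.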