Commit

DensityEstimator.loss does not take sample_dim
michaeldeistler committed Apr 24, 2024
1 parent 005aeac; commit 5f84775
Showing 9 changed files with 58 additions and 104 deletions.
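
In short: before this commit, `DensityEstimator.loss` accepted inputs with a leading sample dimension, `(sample_dim, batch_dim, *event_shape)`, and returned losses of shape `(sample_dim, batch_dim)`. After it, `loss` expects exactly one input per condition, `(batch_dim, *input_event_shape)`, and returns losses of shape `(batch_dim,)`; only `log_prob` keeps the leading sample dimension. A minimal sketch of the new calling convention, assuming an estimator built with sbi's `posterior_nn` builder (the model choice and all shapes here are illustrative, not part of this commit):

    import torch
    from sbi.neural_nets import posterior_nn

    # Illustrative data: 100 (theta, x) pairs, 3-dim parameters, 2-dim data.
    theta = torch.randn(100, 3)  # (batch_dim, *input_event_shape)
    x = torch.randn(100, 2)      # (batch_dim, *condition_event_shape)

    density_estimator = posterior_nn(model="maf")(theta, x)

    # New convention: `loss` takes no sample dimension, returns (batch_dim,).
    losses = density_estimator.loss(theta, condition=x)
    assert losses.shape == (100,)

    # `log_prob` still expects (sample_dim, batch_dim, *event_shape)
    # and returns (sample_dim, batch_dim).
    log_probs = density_estimator.log_prob(theta.unsqueeze(0), condition=x)
    assert log_probs.shape == (1, 100)
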
sbi/inference/snle/mnle.py: 0 additions & 16 deletions

@@ -13,7 +13,6 @@
 from sbi.inference.snle.snle_base import LikelihoodEstimator
 from sbi.neural_nets.density_estimators.shape_handling import (
     reshape_to_batch_event,
-    reshape_to_sample_batch_event,
 )
 from sbi.neural_nets.mnle import MixedDensityEstimator
 from sbi.sbi_types import TensorboardSummaryWriter, TorchModule
@@ -196,18 +195,3 @@ def build_posterior(
         self._model_bank.append(deepcopy(self._posterior))

         return deepcopy(self._posterior)
-
-    # Temporary: need to rewrite mixed likelihood estimators as DensityEstimator
-    # objects.
-    # TODO: Fix and merge issue #968
-    def _loss(self, theta: Tensor, x: Tensor) -> Tensor:
-        r"""Return loss for SNLE, which is the likelihood of $-\log q(x_i | \theta_i)$.
-
-        Returns:
-            Negative log prob.
-        """
-        theta = reshape_to_batch_event(
-            theta, event_shape=self._neural_net.condition_shape
-        )
-        x = reshape_to_sample_batch_event(x, event_shape=self._neural_net.input_shape)
-        return -self._neural_net.log_prob(x, condition=theta)
sbi/inference/snle/snle_base.py: 1 addition & 4 deletions

@@ -18,7 +18,6 @@
 from sbi.neural_nets import DensityEstimator, likelihood_nn
 from sbi.neural_nets.density_estimators.shape_handling import (
     reshape_to_batch_event,
-    reshape_to_sample_batch_event,
 )
 from sbi.utils import check_estimator_arg, check_prior, x_shape_from_simulation
@@ -369,7 +368,5 @@ def _loss(self, theta: Tensor, x: Tensor) -> Tensor:
         theta = reshape_to_batch_event(
             theta, event_shape=self._neural_net.condition_shape
         )
-        x = reshape_to_sample_batch_event(
-            x, event_shape=self._neural_net.input_shape, leading_is_sample=False
-        )
+        x = reshape_to_batch_event(x, event_shape=self._neural_net.input_shape)
         return self._neural_net.loss(x, condition=theta)
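
Both SNLE here and SNPE below now route the estimator input through `reshape_to_batch_event` rather than `reshape_to_sample_batch_event`, matching the new `loss` signature. A rough sketch of what the two helpers are assumed to do; these are illustrative reconstructions based on their names and call sites in this diff, not sbi's actual implementations:

    import torch

    def reshape_to_batch_event(x: torch.Tensor, event_shape: tuple) -> torch.Tensor:
        # Collapse any leading dims into one batch dim: (batch_dim, *event_shape).
        return x.reshape(-1, *event_shape)

    def reshape_to_sample_batch_event(
        x: torch.Tensor, event_shape: tuple, leading_is_sample: bool = False
    ) -> torch.Tensor:
        # Produce (sample_dim, batch_dim, *event_shape); a 2D batch is treated
        # as one sample per condition unless the leading dim is the sample dim.
        x = x.reshape(-1, *event_shape)
        return x.unsqueeze(1) if leading_is_sample else x.unsqueeze(0)

    x = torch.randn(10, 2)
    assert reshape_to_batch_event(x, event_shape=(2,)).shape == (10, 2)
    assert reshape_to_sample_batch_event(x, event_shape=(2,)).shape == (1, 10, 2)
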
sbi/inference/snpe/snpe_base.py: 1 addition & 1 deletion

@@ -578,7 +578,7 @@ def _loss(
             distribution different from the prior.
         """
         if self._round == 0 or force_first_round_loss:
-            theta = reshape_to_sample_batch_event(
+            theta = reshape_to_batch_event(
                 theta, event_shape=self._neural_net.input_shape
             )
             x = reshape_to_batch_event(x, event_shape=self._neural_net.condition_shape)
sbi/neural_nets/density_estimators/base.py: 12 additions & 32 deletions

@@ -47,28 +47,15 @@ def log_prob(self, input: Tensor, condition: Tensor, **kwargs) -> Tensor:

         Args:
             input: Inputs to evaluate the log probability on of shape
-                (*batch_shape1, input_size).
-            condition: Conditions of shape (*batch_shape2, *condition_shape).
+                `(sample_dim_input, batch_dim_input, *event_shape_input)`.
+            condition: Conditions of shape
+                `(batch_dim_condition, *event_shape_condition)`.

         Raises:
-            RuntimeError: If batch_shape1 and batch_shape2 are not broadcastable.
+            RuntimeError: If batch_dim_input and batch_dim_condition do not match.

         Returns:
             Sample-wise log probabilities.
-
-        Note:
-            This function should support PyTorch's automatic broadcasting. This means
-            the function should behave as follows for different input and condition
-            shapes:
-            - (input_size,) + (batch_size,*condition_shape) -> (batch_size,)
-            - (batch_size, input_size) + (*condition_shape) -> (batch_size,)
-            - (batch_size, input_size) + (batch_size, *condition_shape) -> (batch_size,)
-            - (batch_size1, input_size) + (batch_size2, *condition_shape)
-              -> RuntimeError i.e. not broadcastable
-            - (batch_size1,1, input_size) + (batch_size2, *condition_shape)
-              -> (batch_size1,batch_size2)
-            - (batch_size1, input_size) + (batch_size2,1, *condition_shape)
-              -> (batch_size2,batch_size1)
         """

         raise NotImplementedError
@@ -77,11 +64,12 @@ def loss(self, input: Tensor, condition: Tensor, **kwargs) -> Tensor:
         r"""Return the loss for training the density estimator.

         Args:
-            input: Inputs to evaluate the loss on of shape (batch_size, input_size).
-            condition: Conditions of shape (batch_size, *condition_shape).
+            input: Inputs to evaluate the loss on of shape
+                `(batch_dim, *input_event_shape)`.
+            condition: Conditions of shape `(batch_dim, *event_shape_condition)`.

         Returns:
-            Loss of shape (batch_size,)
+            Loss of shape (batch_dim,)
         """

         raise NotImplementedError
@@ -91,17 +79,10 @@ def sample(self, sample_shape: torch.Size, condition: Tensor, **kwargs) -> Tensor:

         Args:
             sample_shape: Shape of the samples to return.
-            condition: Conditions of shape (*batch_shape, *condition_shape).
+            condition: Conditions of shape `(batch_dim, *event_shape_condition)`.

         Returns:
-            Samples of shape (*batch_shape, *sample_shape, input_size).
-
-        Note:
-            This function should support batched conditions and should admit the
-            following behavior for different condition shapes:
-            - (*condition_shape) -> (*sample_shape, input_size)
-            - (*batch_shape, *condition_shape)
-              -> (*batch_shape, *sample_shape, input_size)
+            Samples of shape (*sample_shape, batch_dim, *event_shape_input).
         """

         raise NotImplementedError
@@ -113,12 +94,11 @@ def sample_and_log_prob(

         Args:
             sample_shape: Shape of the samples to return.
-            condition: Conditions of shape (*batch_shape, *condition_shape).
+            condition: Conditions of shape `(batch_dim, *event_shape_condition)`.

         Returns:
             Samples and associated log probabilities.

         Note:
             For some density estimators, computing log_probs for samples is
             more efficient than computing them separately. This method should
@@ -133,7 +113,7 @@ def _check_condition_shape(self, condition: Tensor):
         r"""This method checks whether the condition has the correct shape.

         Args:
-            condition: Conditions of shape (*batch_shape, *condition_shape).
+            condition: Conditions of shape `(batch_dim, *event_shape_condition)`.

         Raises:
             ValueError: If the condition has a dimensionality that does not match
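
Per the updated base-class docstrings, `sample` now returns `(*sample_shape, batch_dim, *event_shape_input)`, which is exactly the layout `torch.distributions` produces when sampling from a batched distribution. A small sanity check of that convention; the Gaussian stand-in below is hypothetical, not an sbi estimator:

    import torch
    from torch.distributions import Normal

    batch_dim, event_dim = 7, 3
    condition = torch.randn(batch_dim, event_dim)  # (batch_dim, *event_shape_condition)

    # A batched distribution plays the role of a conditioned density estimator.
    stand_in = Normal(condition, 1.0)
    samples = stand_in.sample((5,))  # sample_shape = (5,)

    # Samples follow (*sample_shape, batch_dim, *event_shape_input).
    assert samples.shape == (5, batch_dim, event_dim)
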
sbi/neural_nets/density_estimators/categorical_net.py: 15 additions & 15 deletions

@@ -56,54 +56,54 @@ def __init__(

         self.output_layer = nn.Linear(num_hidden, num_categories)

-    def forward(self, context: Tensor) -> Tensor:
+    def forward(self, condition: Tensor) -> Tensor:
         """Return categorical probability predicted from a batch of inputs.

         Args:
-            context: batch of context parameters for the net.
+            condition: batch of context parameters for the net.

         Returns:
             Tensor: batch of predicted categorical probabilities.
         """
         # forward path
-        context = self.activation(self.input_layer(context))
+        condition = self.activation(self.input_layer(condition))

-        # iterate n hidden layers, input context and calculate tanh activation
+        # iterate n hidden layers, input condition and calculate tanh activation
         for layer in self.hidden_layers:
-            context = self.activation(layer(context))
+            condition = self.activation(layer(condition))

-        return self.softmax(self.output_layer(context))
+        return self.softmax(self.output_layer(condition))

-    def log_prob(self, input: Tensor, context: Tensor) -> Tensor:
-        """Return categorical log probability of categories input, given context.
+    def log_prob(self, input: Tensor, condition: Tensor) -> Tensor:
+        """Return categorical log probability of categories input, given condition.

         Args:
             input: categories to evaluate.
-            context: parameters.
+            condition: parameters.

         Returns:
             Tensor: log probs with shape (input.shape[0],)
         """
         # Predict categorical ps and evaluate.
-        ps = self.forward(context)
+        ps = self.forward(condition)
         # Squeeze dim=1 because `Categorical` has `event_shape=()` but our data usually
         # has an event_shape of `(1,)`.
         return Categorical(probs=ps).log_prob(input.squeeze(dim=1))

-    def sample(self, sample_shape: torch.Size, context: Tensor) -> Tensor:
+    def sample(self, sample_shape: torch.Size, condition: Tensor) -> Tensor:
         """Returns samples from categorical random variable with probs predicted from
         the neural net.

         Args:
             sample_shape: number of samples to obtain.
-            context: batch of parameters for prediction.
+            condition: batch of parameters for prediction.

         Returns:
             Tensor: Samples with shape (num_samples, 1)
         """

         # Predict Categorical ps and sample.
-        ps = self.forward(context)
+        ps = self.forward(condition)
         return Categorical(probs=ps).sample(sample_shape=sample_shape)
@@ -176,11 +176,11 @@ def loss(self, input: Tensor, condition: Tensor, **kwargs) -> Tensor:
         r"""Return the loss for training the density estimator.

         Args:
-            input: Inputs of shape `(sample_dim, batch_dim, *input_event_shape)`.
+            input: Inputs of shape `(batch_dim, *input_event_shape)`.
             condition: Conditions of shape `(batch_dim, *condition_event_shape)`.

         Returns:
             Loss of shape `(batch_dim,)`
         """

-        return -self.log_prob(input, condition)
+        return -self.log_prob(input.unsqueeze(0), condition)[0]
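
The new `loss` bodies in this commit all use the same one-liner: `input.unsqueeze(0)` adds a singleton sample dimension so the sample-dim-aware `log_prob` can be reused, and the trailing `[0]` strips that dimension again, yielding one loss per batch element. A self-contained illustration of the pattern; the Gaussian `log_prob` is a stand-in, not sbi code:

    import torch
    from torch.distributions import Normal

    def log_prob(input: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
        # (sample_dim, batch_dim, d) x (batch_dim, d) -> (sample_dim, batch_dim)
        return Normal(condition, 1.0).log_prob(input).sum(-1)

    def loss(input: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
        # (batch_dim, d) -> add sample_dim=1, evaluate, drop it -> (batch_dim,)
        return -log_prob(input.unsqueeze(0), condition)[0]

    inputs = torch.randn(10, 4)
    conditions = torch.randn(10, 4)
    assert loss(inputs, conditions).shape == (10,)
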
sbi/neural_nets/density_estimators/mixed_density_estimator.py: 10 additions & 1 deletion

@@ -162,7 +162,16 @@ def log_prob(self, input: Tensor, condition: Tensor) -> Tensor:
         return log_probs_combined

     def loss(self, input: Tensor, condition: Tensor, **kwargs) -> Tensor:
-        return self.log_prob(input, condition)
+        r"""Return the loss for training the density estimator.
+
+        Args:
+            input: Inputs of shape `(batch_dim, *input_event_shape)`.
+            condition: Conditions of shape `(batch_dim, *condition_event_shape)`.
+
+        Returns:
+            Loss of shape `(batch_dim,)`
+        """
+        return -self.log_prob(input.unsqueeze(0), condition)[0]

     def log_prob_iid(self, input: Tensor, condition: Tensor) -> Tensor:
         """Return logprob given a batch of iid input and a different batch of condition.
sbi/neural_nets/density_estimators/nflows_flow.py: 5 additions & 20 deletions

@@ -52,20 +52,6 @@ def inverse_transform(self, input: Tensor, condition: Tensor) -> Tensor:

         Returns:
             noise: Transformed inputs.
-
-        Note:
-            This function should support PyTorch's automatic broadcasting. This means
-            the function should behave as follows for different input and condition
-            shapes:
-            - (input_size,) + (batch_size,*condition_shape) -> (batch_size,)
-            - (batch_size, input_size) + (*condition_shape) -> (batch_size,)
-            - (batch_size, input_size) + (batch_size, *condition_shape) -> (batch_size,)
-            - (batch_size1, input_size) + (batch_size2, *condition_shape)
-              -> RuntimeError i.e. not broadcastable
-            - (batch_size1,1, input_size) + (batch_size2, *condition_shape)
-              -> (batch_size1,batch_size2)
-            - (batch_size1, input_size) + (batch_size2,1, *condition_shape)
-              -> (batch_size2,batch_size1)
         """
         self._check_condition_shape(condition)
         condition_dims = len(self.condition_shape)
@@ -121,17 +107,16 @@ def log_prob(self, input: Tensor, condition: Tensor) -> Tensor:
         return log_probs.reshape((input_sample_dim, input_batch_dim))

     def loss(self, input: Tensor, condition: Tensor) -> Tensor:
-        r"""Return the loss for training the density estimator.
+        r"""Return the negative log-probability for training the density estimator.

         Args:
-            input: Inputs to evaluate the loss on of shape
-                `(sample_dim, batch_dim, *event_shape)`.
-            condition: Conditions of shape `(sample_dim, batch_dim, *event_dim)`.
+            input: Inputs of shape `(batch_dim, *input_event_shape)`.
+            condition: Conditions of shape `(batch_dim, *condition_event_shape)`.

         Returns:
-            Negative log_probability of shape `(input_sample_dim, condition_batch_dim)`.
+            Negative log-probability of shape `(batch_dim,)`.
         """
-        return -self.log_prob(input, condition)
+        return -self.log_prob(input.unsqueeze(0), condition)[0]

     def sample(self, sample_shape: Shape, condition: Tensor) -> Tensor:
         r"""Return samples from the density estimator.
sbi/neural_nets/density_estimators/zuko_flow.py: 5 additions & 6 deletions

@@ -121,18 +121,17 @@ def log_prob(self, input: Tensor, condition: Tensor) -> Tensor:
         return log_probs

     def loss(self, input: Tensor, condition: Tensor) -> Tensor:
-        r"""Return the loss for training the density estimator.
+        r"""Return the negative log-probability for training the density estimator.

         Args:
-            input: Inputs to evaluate the loss on of shape
-                `(sample_dim, batch_dim, *event_shape)`.
-            condition: Conditions of shape `(sample_dim, batch_dim, *event_dim)`.
+            input: Inputs of shape `(batch_dim, *input_event_shape)`.
+            condition: Conditions of shape `(batch_dim, *condition_event_shape)`.

         Returns:
-            Negative log_probability of shape `(input_sample_dim, condition_batch_dim)`.
+            Negative log-probability of shape `(batch_dim,)`.
         """

-        return -self.log_prob(input, condition)
+        return -self.log_prob(input.unsqueeze(0), condition)[0]

     def sample(self, sample_shape: Shape, condition: Tensor) -> Tensor:
         r"""Return samples from the density estimator.
tests/density_estimator_test.py: 9 additions & 9 deletions

@@ -150,8 +150,8 @@ def test_density_estimator_loss_shapes(
         input_sample_dim,
     )

-    losses = density_estimator.loss(inputs, condition=conditions)
-    assert losses.shape == (input_sample_dim, batch_dim)
+    losses = density_estimator.loss(inputs[0], condition=conditions)
+    assert losses.shape == (batch_dim,)


 @pytest.mark.parametrize(
@@ -193,8 +193,8 @@ def test_density_estimator_log_prob_shapes_with_embedding(
         input_sample_dim,
     )

-    losses = density_estimator.log_prob(inputs, condition=conditions)
-    assert losses.shape == (input_sample_dim, batch_dim)
+    log_probs = density_estimator.log_prob(inputs, condition=conditions)
+    assert log_probs.shape == (input_sample_dim, batch_dim)


 @pytest.mark.parametrize(
@@ -228,7 +228,7 @@ def test_density_estimator_sample_shapes(
     condition_event_shape,
     batch_dim,
 ):
-    """Test whether `loss` of DensityEstimators follow the shape convention."""
+    """Test whether `sample` of DensityEstimators follow the shape convention."""
     density_estimator, _, conditions = _build_density_estimator_and_tensors(
         density_estimator_build_fn, input_event_shape, condition_event_shape, batch_dim
     )
@@ -264,13 +264,13 @@ def test_density_estimator_sample_shapes(
 @pytest.mark.parametrize("input_event_shape", ((1,), (4,)))
 @pytest.mark.parametrize("condition_event_shape", ((1,), (7,)))
 @pytest.mark.parametrize("batch_dim", (1, 10))
-def test_correctness_of_density_estimator_loss(
+def test_correctness_of_density_estimator_log_prob(
     density_estimator_build_fn,
     input_event_shape,
     condition_event_shape,
     batch_dim,
 ):
-    """Test whether identical inputs lead to identical loss values."""
+    """Test whether identical inputs lead to identical log_prob values."""
     input_sample_dim = 2
     density_estimator, inputs, condition = _build_density_estimator_and_tensors(
         density_estimator_build_fn,
@@ -279,8 +279,8 @@ def test_correctness_of_density_estimator_loss(
         batch_dim,
         input_sample_dim,
     )
-    losses = density_estimator.loss(inputs, condition=condition)
-    assert torch.allclose(losses[0, :], losses[1, :])
+    log_probs = density_estimator.log_prob(inputs, condition=condition)
+    assert torch.allclose(log_probs[0, :], log_probs[1, :])


 def _build_density_estimator_and_tensors(
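
The updated tests pin the new contract down from the outside: `loss` is now exercised on a single sample slice, `inputs[0]`, while the correctness check is renamed to target `log_prob`, which alone retains the sample dimension. A tiny standalone version of that renamed check, duplicated inputs along the sample dimension must give identical log-probs; the Gaussian stand-in is the same hypothetical estimator used above, not one of sbi's builders:

    import torch
    from torch.distributions import Normal

    def log_prob(input: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
        return Normal(condition, 1.0).log_prob(input).sum(-1)

    # Two identical "samples" per condition along the leading sample dim.
    inputs = torch.randn(1, 10, 4).repeat(2, 1, 1)
    condition = torch.randn(10, 4)

    log_probs = log_prob(inputs, condition)  # (input_sample_dim, batch_dim)
    assert torch.allclose(log_probs[0, :], log_probs[1, :])
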
