Commit 89f401c: ruff

michaeldeistler committed Apr 9, 2024
1 parent cceb940

Showing 5 changed files with 33 additions and 22 deletions.
5 changes: 3 additions & 2 deletions sbi/neural_nets/density_estimators/categorical_net.py
@@ -122,8 +122,9 @@ def log_prob(self, input: Tensor, condition: Tensor, **kwargs) -> Tensor:
         """Return log-probability of samples.

         Args:
-            input: Input datapoints of shape `(sample_dim, batch_dim, *event_shape_input)`.
-                Must be a discrete indicator of class identity.
+            input: Input datapoints of shape
+                `(sample_dim, batch_dim, *event_shape_input)`. Must be a discrete
+                indicator of class identity.
             condition: Conditions of shape `(batch_dim, *event_shape_condition)`.

         Returns:
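As a toy illustration of the shapes this docstring describes (the values and dimensions below are made up, not from the commit):

    import torch

    # Class-identity indicators of shape (sample_dim, batch_dim, *event_shape_input).
    input = torch.randint(0, 3, (10, 2, 1))  # 10 samples, batch of 2, classes {0, 1, 2}
    # Condition of shape (batch_dim, *event_shape_condition).
    condition = torch.randn(2, 4)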
37 changes: 22 additions & 15 deletions sbi/neural_nets/density_estimators/mixed_density_estimator.py
@@ -34,9 +34,9 @@ def __init__(
         Args:
             discrete_net: neural net to model discrete part of the data.
             continuous_net: neural net to model the continuous data.
-            log_transform_input: whether to transform the continous part of the data into
-                logarithmic domain before training. This is helpful for bounded data, e.
-                g.,for reaction times.
+            log_transform_input: whether to transform the continuous part of the data
+                into the logarithmic domain before training. This is helpful for
+                bounded data, e.g., for reaction times.
         """
         super(MixedDensityEstimator, self).__init__(
             net=continuous_net, condition_shape=condition_shape
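A minimal sketch of the log-transform idea this docstring refers to (illustrative only, not the estimator's internals): positive data such as reaction times are mapped to an unbounded domain before training, and a density reported back in the original space needs the change-of-variables correction.

    import torch

    # Toy positive-valued data, e.g., reaction times in seconds.
    reaction_times = 0.1 + torch.rand(1000, 1)
    log_rts = torch.log(reaction_times)  # unbounded training targets
    # A flow q trained on log_rts gives log q(log x); the density of x itself
    # picks up the Jacobian of the transform: log p(x) = log q(log x) - log x.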
@@ -138,9 +138,11 @@ def log_prob(self, input: Tensor, condition: Tensor) -> Tensor:
         ))
         cont_log_prob = self.continuous_net.log_prob(
             # Transform to log-space if needed.
-            torch.log(cont_input_reshaped)
-            if self.log_transform_input
-            else cont_input_reshaped,
+            (
+                torch.log(cont_input_reshaped)
+                if self.log_transform_input
+                else cont_input_reshaped
+            ),
             condition=condition_reshaped,
         )
         cont_log_prob = cont_log_prob.reshape(disc_log_prob.shape)
@@ -158,7 +160,7 @@ def loss(self, input: Tensor, condition: Tensor, **kwargs) -> Tensor:
         return self.log_prob(input, condition)

     def log_prob_iid(self, input: Tensor, condition: Tensor) -> Tensor:
-        """Return log prob given a batch of iid input and a different batch of condition.
+        """Return logprob given a batch of iid input and a different batch of condition.

         This is different from `.log_prob()` to enable speed ups in evaluation during
         inference. The speed up is achieved by exploiting the fact that there are only
@@ -172,10 +174,10 @@ def log_prob_iid(self, input: Tensor, condition: Tensor) -> Tensor:
         `.log_prob()` would pass `1000 * num_conditions`.

         Args:
-            input: batch of iid data, data observed given the same underlying parameters or
-                experimental conditions.
-            condition: batch of parameters to be evaluated, i.e., each batch entry will be
-                evaluated for the entire batch of iid input.
+            input: batch of iid data, data observed given the same underlying parameters
+                or experimental conditions.
+            condition: batch of parameters to be evaluated, i.e., each batch entry will
+                be evaluated for the entire batch of iid input.

         Returns:
             log probs with shape (num_trials, num_parameters), i.e., the log prob for
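A sketch of the repeat pattern this docstring implies, with toy shapes (names and shapes are illustrative, not the library's internals):

    import torch

    num_trials, num_conditions = 1000, 100
    x = torch.randn(num_trials, 2)          # iid trials
    theta = torch.randn(num_conditions, 3)  # conditions to evaluate

    # Pair every trial with every condition, then reshape so log probs can
    # be summed over trials separately for each condition.
    x_rep = x.repeat_interleave(num_conditions, dim=0)  # (100000, 2)
    theta_rep = theta.repeat(num_trials, 1)             # (100000, 3)
    # log_probs = estimator.log_prob(x_rep, theta_rep).reshape(num_trials, num_conditions)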
@@ -192,7 +194,10 @@ def log_prob_iid(self, input: Tensor, condition: Tensor) -> Tensor:
         net_device = next(self.discrete_net.parameters()).device
         assert (
             net_device == input.device and input.device == condition.device
-        ), f"device mismatch: net, x, condition: {net_device}, {input.device}, {condition.device}."
+        ), (
+            f"device mismatch: net, x, condition: "
+            f"{net_device}, {input.device}, {condition.device}."
+        )

         input_cont_repeated, input_disc_repeated = _separate_input(input_repeated)
         input_cont, input_disc = _separate_input(input)
@@ -220,9 +225,11 @@ def log_prob_iid(self, input: Tensor, condition: Tensor) -> Tensor:

         # Get repeat discrete data and condition to match in batch shape for flow eval.
         log_probs_cont = self.continuous_net.log_prob(
-            torch.log(input_cont_repeated)
-            if self.log_transform_input
-            else input_cont_repeated,
+            (
+                torch.log(input_cont_repeated)
+                if self.log_transform_input
+                else input_cont_repeated
+            ),
             condition=torch.cat((condition_repeated, input_disc_repeated), dim=1),
         )
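One detail visible in this hunk: the continuous flow is conditioned on the concatenation of the repeated condition and the repeated discrete data. A standalone sketch of that concatenation with toy shapes (not the library's code):

    import torch

    condition_repeated = torch.randn(6, 3)   # repeated parameters, 3-dim
    input_disc_repeated = torch.randn(6, 1)  # repeated discrete data, 1-dim
    flow_condition = torch.cat((condition_repeated, input_disc_repeated), dim=1)
    assert flow_condition.shape == (6, 4)    # the flow sees both parts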
6 changes: 3 additions & 3 deletions sbi/neural_nets/density_estimators/shape_handling.py
@@ -18,9 +18,9 @@ def reshape_to_sample_batch_event(
             - (sample, batch, event)
         event_shape: The shape of a single datapoint (without batch dimension or sample
             dimension).
-        leading_is_sample: Used only if `theta_or_x` has exactly one dimension beyond the
-            `event` dims. Defines whether the leading dimension is interpreted as batch
-            dimension or as sample dimension.
+        leading_is_sample: Used only if `theta_or_x` has exactly one dimension beyond
+            the `event` dims. Defines whether the leading dimension is interpreted as
+            batch dimension or as sample dimension.

     Returns:
         A tensor of shape `(sample, batch, event)`.
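The `leading_is_sample` flag only matters in the ambiguous case the docstring names: exactly one dimension beyond the event dims. A toy illustration of the two readings (shapes only, not the function's implementation):

    import torch

    x = torch.randn(10, 5)  # one dim beyond event_shape=(5,): ambiguous
    # leading_is_sample=True  -> interpreted as (sample=10, batch=1, event=5)
    # leading_is_sample=False -> interpreted as (sample=1, batch=10, event=5)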
5 changes: 4 additions & 1 deletion sbi/utils/user_input_checks.py
@@ -749,13 +749,16 @@ def validate_theta_and_x(


 def test_posterior_net_for_multi_d_x(
-    net: "DensityEstimator", theta: Tensor, x: Tensor
+    net, theta: Tensor, x: Tensor
 ) -> None:
     """Test log prob method of the net.

     This is done to make sure the net can handle multidimensional inputs via an
     embedding net. If not, it usually fails with a RuntimeError. Here we catch the
     error, append a debug hint and raise it again.
+
+    Args:
+        net: A `DensityEstimator`.
     """
     try:
         # torch.nn.functional needs at least two inputs here.
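The catch-and-re-raise pattern the docstring describes, as a standalone sketch (the function name and hint text below are made up, not the library's):

    def check_embedding(net, theta, x):
        try:
            net.log_prob(x, condition=theta)
        except RuntimeError as err:
            # Append a debug hint and raise again, as the docstring describes.
            raise RuntimeError(
                f"{err}\nHint: multidimensional x usually requires an embedding net."
            ) from err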
2 changes: 1 addition & 1 deletion tests/base_test.py
@@ -20,7 +20,7 @@ def simulator(parameter_set):
         prior,
         method="SNPE_A",
         num_simulations=10,
-        init_kwargs={'num_components': 5},
+        init_kwargs={"num_components": 5},
         train_kwargs={"max_num_epochs": 2},
         build_posterior_kwargs={"prior": prior},
     )