
Commit

made the RatioEstimator abstraction apply to existing usecases
bkmi committed Apr 9, 2024
1 parent b037965 commit 3b8afae
Showing 3 changed files with 19 additions and 14 deletions.
sbi/inference/potentials/ratio_based_potential.py (2 changes: 1 addition & 1 deletion)
@@ -128,7 +128,7 @@ def _log_ratios_over_trials(

     # Calculate ratios in one batch.
     with torch.set_grad_enabled(track_gradients):
-        log_ratio_trial_batch = net([theta_repeated, x_repeated])
+        log_ratio_trial_batch = net(theta_repeated, x_repeated)
     # Reshape to (x-trials x parameters), sum over trial-log likelihoods.
     log_ratio_trial_sum = log_ratio_trial_batch.reshape(x.shape[0], -1).sum(0)

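For readers skimming the diff, a minimal sketch of the calling convention this hunk adopts follows. The DummyRatioNet class, its dimensions, and the sample tensors are hypothetical and only illustrate that the ratio network is now called with theta and x as two separate tensor arguments instead of a single [theta, x] list.

import torch
import torch.nn as nn

class DummyRatioNet(nn.Module):
    """Hypothetical stand-in for the ratio network; not part of this commit."""

    def __init__(self, theta_dim: int = 3, x_dim: int = 2):
        super().__init__()
        self.mlp = nn.Linear(theta_dim + x_dim, 1)

    def forward(self, theta: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        # The network concatenates internally; callers no longer wrap inputs in a list.
        return self.mlp(torch.cat([theta, x], dim=-1))

net = DummyRatioNet()
theta_repeated = torch.randn(10, 3)
x_repeated = torch.randn(10, 2)
log_ratio_trial_batch = net(theta_repeated, x_repeated)  # shape (10, 1)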
sbi/inference/snre/snre_base.py (2 changes: 1 addition & 1 deletion)
@@ -303,7 +303,7 @@ def _classifier_logits(self, theta: Tensor, x: Tensor, num_atoms: int) -> Tensor
             batch_size * num_atoms, -1
         )

-        return self._neural_net([atomic_theta, repeated_x])
+        return self._neural_net(atomic_theta, repeated_x)

     @abstractmethod
     def _loss(self, theta: Tensor, x: Tensor, num_atoms: int) -> Tensor:
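The call above relies on standard nn.Module dispatch: invoking the estimator instance runs its forward method, which (per the RatioEstimator change below) delegates to unnormalized_log_ratio. A toy sketch of that dispatch chain, using a made-up ToyEstimator class:

import torch
import torch.nn as nn

class ToyEstimator(nn.Module):
    """Made-up estimator; only demonstrates forward -> unnormalized_log_ratio dispatch."""

    def unnormalized_log_ratio(self, theta: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        # Toy stand-in for a learned classifier logit.
        return (theta * x).sum(dim=-1)

    def forward(self, *args, **kwargs) -> torch.Tensor:
        return self.unnormalized_log_ratio(*args, **kwargs)

net = ToyEstimator()
logits = net(torch.ones(4, 2), torch.ones(4, 2))  # tensor([2., 2., 2., 2.])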
sbi/neural_nets/ratio_estimators.py (29 changes: 17 additions & 12 deletions)
@@ -42,11 +42,15 @@ def __init__(
     def combine_embedded_theta_and_x(
         self, embedded_theta: Tensor, embedded_x: Tensor
     ) -> Tensor:
-        """combine embedded theta and embedded x"""
-        return None
+        r"""Combine embedded theta and embedded x sensibly for the data type.

-    @abstractmethod
-    def embed_and_combine_theta_and_x(self, theta: Tensor, x: Tensor) -> Tensor:
+        Args:
+            embedded_theta: theta after embedding
+            embedded_x: x after embedding
+
+        Returns:
+            Single object containing both embedded_theta and embedded_x
+        """
         return None

     @abstractmethod
@@ -64,6 +68,10 @@ def unnormalized_log_ratio(self, theta: Tensor, x: Tensor, **kwargs) -> Tensor:

         raise NotImplementedError

+    def forward(self, *args, **kwargs) -> Tensor:
+        r"""Wraps `unnormalized_log_ratio`"""
+        return self.unnormalized_log_ratio(*args, **kwargs)
+

 class TensorRatioEstimator(RatioEstimator):
     def __init__(
@@ -87,16 +95,13 @@ def __init__(

     @staticmethod
     def combine_embedded_theta_and_x(
-        embedded_theta: Tensor, embedded_x: Tensor
+        embedded_theta: Tensor, embedded_x: Tensor, dim: int = -1
     ) -> Tensor:
-        """concatenate embedded theta and embedded x"""
-        return torch.cat([embedded_theta, embedded_x], dim=-1)
+        """Concatenate embedded theta and embedded x"""
+        return torch.cat([embedded_theta, embedded_x], dim=dim)

-    def embed_and_combine_theta_and_x(self, theta: Tensor, x: Tensor) -> Tensor:
-        return self.combine_embedded_theta_and_x(
+    def unnormalized_log_ratio(self, theta: Tensor, x: Tensor) -> Tensor:
+        z = self.combine_embedded_theta_and_x(
             self.embedding_net_theta(theta), self.embedding_net_x(x)
         )
-
-    def unnormalized_log_ratio(self, theta: Tensor, x: Tensor) -> Tensor:
-        z = self.embed_and_combine_theta_and_x(theta, x)
         return self.net(z)
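Taken together, the reworked abstraction can be mocked roughly as below. This is a standalone sketch, not the sbi classes themselves; the MockTensorRatioEstimator name and layer sizes are invented for illustration. It shows the shape of the design: embed theta and x separately, concatenate the embeddings along a chosen dim, feed the result to a single network, and let forward wrap unnormalized_log_ratio.

import torch
import torch.nn as nn

class MockTensorRatioEstimator(nn.Module):
    """Standalone mock mirroring the structure of the reworked estimator."""

    def __init__(self, theta_dim: int = 3, x_dim: int = 2, hidden: int = 16):
        super().__init__()
        self.embedding_net_theta = nn.Linear(theta_dim, hidden)
        self.embedding_net_x = nn.Linear(x_dim, hidden)
        self.net = nn.Linear(2 * hidden, 1)

    @staticmethod
    def combine_embedded_theta_and_x(
        embedded_theta: torch.Tensor, embedded_x: torch.Tensor, dim: int = -1
    ) -> torch.Tensor:
        # Concatenation is the sensible combination for plain tensors.
        return torch.cat([embedded_theta, embedded_x], dim=dim)

    def unnormalized_log_ratio(self, theta: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        z = self.combine_embedded_theta_and_x(
            self.embedding_net_theta(theta), self.embedding_net_x(x)
        )
        return self.net(z)

    def forward(self, *args, **kwargs) -> torch.Tensor:
        return self.unnormalized_log_ratio(*args, **kwargs)

estimator = MockTensorRatioEstimator()
log_ratios = estimator(torch.randn(10, 3), torch.randn(10, 2))  # shape (10, 1)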
