Skip to content

Commit

Permalink
Bugfix for SNPE_A
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Mar 22, 2022
1 parent 3d4eac2 commit 565082c
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions sbi/inference/snpe/snpe_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def train(
calibration_kernel: Optional[Callable] = None,
exclude_invalid_x: bool = True,
resume_training: bool = False,
force_first_round_loss: bool = False,
retrain_from_scratch: bool = False,
show_train_summary: bool = False,
dataloader_kwargs: Optional[Dict] = None,
Expand Down Expand Up @@ -146,6 +147,8 @@ def train(
force_first_round_loss: If `True`, train with maximum likelihood,
i.e., potentially ignoring the correction for using a proposal
distribution different from the prior.
force_first_round_loss: If `True`, train with maximum likelihood,
regardless of the proposal distribution.
retrain_from_scratch: Whether to retrain the conditional density
estimator for the posterior from scratch each round. Not supported for
SNPE-A.
Expand Down Expand Up @@ -434,7 +437,7 @@ def log_prob(self, inputs: Tensor, context: Tensor) -> Tensor:
)
return log_prob_proposal_posterior # \hat{p} from eq (3) in [1]

def sample(self, num_samples: int, context: Tensor, batch_size: int) -> Tensor:
def sample(self, num_samples: int, context: Tensor, batch_size: int = 1) -> Tensor:
context = context.to(self._device)

if not self._apply_correction:
Expand Down Expand Up @@ -464,7 +467,7 @@ def _sample_approx_posterior_mog(

# Compute the precision factors which represent the upper triangular matrix
# of the cholesky decomposition of the prec_p.
prec_factors_p = torch.cholesky(prec_p, upper=True)
prec_factors_p = torch.linalg.cholesky(prec_p, upper=True)

assert logits_p.ndim == 2
assert m_p.ndim == 3
Expand Down Expand Up @@ -696,14 +699,14 @@ def _precisions_posterior(self, precisions_pp: Tensor, precisions_d: Tensor):
# Check if precision matrices are positive definite.
for batches in precisions_pp:
for pprior in batches:
eig_pprior = torch.symeig(pprior, eigenvectors=False).eigenvalues
eig_pprior = torch.linalg.eigvalsh(pprior, UPLO="U")
if not (eig_pprior > 0).all():
raise AssertionError(
"The precision matrix of the proposal is not positive definite!"
)
for batches in precisions_d:
for d in batches:
eig_d = torch.symeig(d, eigenvectors=False).eigenvalues
eig_d = torch.linalg.eigvalsh(d, UPLO="U")
if not (eig_d > 0).all():
raise AssertionError(
"The precision matrix of the density estimator is not "
Expand Down

0 comments on commit 565082c

Please sign in to comment.