Skip to content

Commit

Permalink
moved mask_sims_from_prior to utils
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Jul 13, 2020
1 parent 69c8d1b commit 4cd745e
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 16 deletions.
17 changes: 2 additions & 15 deletions sbi/inference/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from sbi.utils import get_log_root, handle_invalid_x, warn_on_invalid_x
from sbi.utils.plot import pairplot
from sbi.utils.torchutils import configure_default_device
from sbi.utils.sbiutils import get_data_after_round
from sbi.utils.sbiutils import get_data_after_round, mask_sims_from_prior


def infer(
Expand Down Expand Up @@ -187,7 +187,7 @@ def _append_to_round_bank(self, theta: Tensor, x: Tensor, round_: int) -> None:

self._theta_roundwise.append(theta)
self._x_roundwise.append(x)
self._prior_masks.append(self._mask_sims_from_prior(round_, theta.size(0)))
self._prior_masks.append(mask_sims_from_prior(round_, theta.size(0)))
self._data_round_index.append(round_)

def _get_from_round_bank(
Expand Down Expand Up @@ -289,19 +289,6 @@ def _converged(self, epoch: int, stop_after_epochs: int) -> bool:

return converged

def _mask_sims_from_prior(self, round_: int, num_simulations: int) -> Tensor:
"""Returns Tensor True where simulated from prior parameters.
Args:
round_: Current training round, starting at 0.
num_simulations: Actually performed simulations. This number can be below
the one fixed for the round if leakage correction through sampling is
active and `patience` is not enough to reach it.
"""

prior_mask_values = ones if round_ == 0 else zeros
return prior_mask_values((num_simulations, 1), dtype=torch.bool)

def _default_summary_writer(self) -> SummaryWriter:
"""Return summary writer logging to method- and simulator-specific directory."""
try:
Expand Down
1 change: 1 addition & 0 deletions sbi/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
warn_on_invalid_x,
x_shape_from_simulation,
get_data_after_round,
mask_sims_from_prior,
)
from sbi.utils.torchutils import (
BoxUniform,
Expand Down
16 changes: 15 additions & 1 deletion sbi/utils/sbiutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch
import torch.nn as nn
from pyknos.nflows import transforms
from torch import Tensor, as_tensor, ones
from torch import Tensor, as_tensor, ones, zeros
from tqdm.auto import tqdm


Expand Down Expand Up @@ -233,3 +233,17 @@ def get_data_after_round(
counting from 0.
"""
return torch.cat([t for t, r in zip(data, data_round_index) if r >= starting_round])


def mask_sims_from_prior(round_: int, num_simulations: int) -> Tensor:
"""Returns Tensor True where simulated from prior parameters.
Args:
round_: Current training round, starting at 0.
num_simulations: Actually performed simulations. This number can be below
the one fixed for the round if leakage correction through sampling is
active and `patience` is not enough to reach it.
"""

prior_mask_values = ones if round_ == 0 else zeros
return prior_mask_values((num_simulations, 1), dtype=torch.bool)

0 comments on commit 4cd745e

Please sign in to comment.