Skip to content

Commit

Permalink
renaming some functions for external data
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Jul 13, 2020
1 parent 671d8c1 commit 506a888
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 28 deletions.
24 changes: 11 additions & 13 deletions sbi/inference/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from sbi.utils import get_log_root, handle_invalid_x, warn_on_invalid_x
from sbi.utils.plot import pairplot
from sbi.utils.torchutils import configure_default_device
from sbi.utils.sbiutils import get_data_after_round, mask_sims_from_prior
from sbi.utils.sbiutils import get_data_since_round, mask_sims_from_prior


def infer(
Expand Down Expand Up @@ -168,9 +168,9 @@ def provide_presimulated(
from_round: Which round the data was simulated from. `from_round=0` means
that the data came from the first round.
"""
self._append_to_round_bank(theta, x, from_round)
self._append_to_data_bank(theta, x, from_round)

def _append_to_round_bank(self, theta: Tensor, x: Tensor, round_: int) -> None:
def _append_to_data_bank(self, theta: Tensor, x: Tensor, from_round: int) -> None:
r"""
Store data in as entries in a list for each type of variable (parameter/data).
Expand All @@ -181,16 +181,16 @@ def _append_to_round_bank(self, theta: Tensor, x: Tensor, round_: int) -> None:
Args:
theta: Parameter sets.
x: Simulated data.
round_: What round the $(\theta, x)$ pairs are coming from. We start
from_round: What round the $(\theta, x)$ pairs are coming from. We start
counting from round 0.
"""

self._theta_roundwise.append(theta)
self._x_roundwise.append(x)
self._prior_masks.append(mask_sims_from_prior(round_, theta.size(0)))
self._data_round_index.append(round_)
self._prior_masks.append(mask_sims_from_prior(from_round, theta.size(0)))
self._data_round_index.append(from_round)

def _get_from_round_bank(
def _get_from_data_bank(
self,
starting_round: int = 0,
exclude_invalid_x: bool = True,
Expand All @@ -212,13 +212,13 @@ def _get_from_round_bank(
Returns: Parameters, simulation outputs, prior masks.
"""

theta = get_data_after_round(
theta = get_data_since_round(
self._theta_roundwise, self._data_round_index, starting_round
)
x = get_data_after_round(
x = get_data_since_round(
self._x_roundwise, self._data_round_index, starting_round
)
prior_masks = get_data_after_round(
prior_masks = get_data_since_round(
self._prior_masks, self._data_round_index, starting_round
)

Expand All @@ -245,16 +245,14 @@ def _run_simulations(self, round_: int, num_sims: int,) -> Tuple[Tensor, Tensor]

if round_ == 0:
theta = self._prior.sample((num_sims,))

x = self._batched_simulator(theta)
else:
theta = self._posterior.sample(
(num_sims,),
x=self._posterior.default_x,
show_progress_bars=self._show_progress_bars,
)

x = self._batched_simulator(theta)
x = self._batched_simulator(theta)

return theta, x

Expand Down
6 changes: 3 additions & 3 deletions sbi/inference/snle/snle_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,10 @@ def __call__(
# Run simulations for the round.
theta, x = self._run_simulations(round_, num_sims)

self._append_to_round_bank(theta, x, round_)
self._append_to_data_bank(theta, x, round_)

# Load data from most recent round.
theta, x, _ = self._get_from_round_bank(round_, exclude_invalid_x, False)
theta, x, _ = self._get_from_data_bank(round_, exclude_invalid_x, False)

# First round or if retraining from scratch:
# Call the `self._build_neural_net` with the rounds' thetas and xs as
Expand Down Expand Up @@ -198,7 +198,7 @@ def _train(

# Starting index for the training set (1 = discard round-0 samples).
start_idx = int(discard_prior_samples and round_ > 0)
theta, x, _ = self._get_from_round_bank(start_idx, exclude_invalid_x)
theta, x, _ = self._get_from_data_bank(start_idx, exclude_invalid_x)

# Get total number of training examples.
num_examples = len(theta)
Expand Down
6 changes: 3 additions & 3 deletions sbi/inference/snpe/snpe_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,10 +161,10 @@ def __call__(

# Run simulations for the round.
theta, x = self._run_simulations(round_, num_sims)
self._append_to_round_bank(theta, x, round_)
self._append_to_data_bank(theta, x, round_)

# Load data from most recent round.
theta, x, _ = self._get_from_round_bank(round_, exclude_invalid_x, False)
theta, x, _ = self._get_from_data_bank(round_, exclude_invalid_x, False)

# First round or if retraining from scratch:
# Call the `self._build_neural_net` with the rounds' thetas and xs as
Expand Down Expand Up @@ -262,7 +262,7 @@ def _train(

# Starting index for the training set (1 = discard round-0 samples).
start_idx = int(discard_prior_samples and round_ > 0)
theta, x, prior_masks = self._get_from_round_bank(start_idx, exclude_invalid_x)
theta, x, prior_masks = self._get_from_data_bank(start_idx, exclude_invalid_x)

# Select random neural net and validation splits from (theta, x) pairs.
num_total_examples = len(theta)
Expand Down
6 changes: 3 additions & 3 deletions sbi/inference/snre/snre_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,10 @@ def __call__(
# Run simulations for the round.
theta, x = self._run_simulations(round_, num_sims)

self._append_to_round_bank(theta, x, round_)
self._append_to_data_bank(theta, x, round_)

# Load data from most recent round.
theta, x, _ = self._get_from_round_bank(round_, exclude_invalid_x, False)
theta, x, _ = self._get_from_data_bank(round_, exclude_invalid_x, False)

# First round or if retraining from scratch:
# Call the `self._build_neural_net` with the rounds' thetas and xs as
Expand Down Expand Up @@ -209,7 +209,7 @@ def _train(

# Starting index for the training set (1 = discard round-0 samples).
start_idx = int(discard_prior_samples and round_ > 0)
theta, x, _ = self._get_from_round_bank(start_idx, exclude_invalid_x)
theta, x, _ = self._get_from_data_bank(start_idx, exclude_invalid_x)

# Get total number of training examples.
num_examples = len(theta)
Expand Down
2 changes: 1 addition & 1 deletion sbi/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
handle_invalid_x,
warn_on_invalid_x,
x_shape_from_simulation,
get_data_after_round,
get_data_since_round,
mask_sims_from_prior,
)
from sbi.utils.torchutils import (
Expand Down
12 changes: 7 additions & 5 deletions sbi/utils/sbiutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,21 +218,23 @@ def warn_on_invalid_x(num_nans: int, num_infs: int, exclude_invalid_x: bool) ->
)


def get_data_after_round(
data: List, data_round_index: List, starting_round: int
def get_data_since_round(
data: List, data_round_indices: List, starting_round_index: int
) -> Tensor:
"""
Returns tensor with all data coming from a round >= `starting_round`.
Args:
data: Each list entry contains a set of data (either parameters, simulation
outputs, or prior masks).
data_round_index: List with same length as data, each entry is an integer that
data_round_indices: List with same length as data, each entry is an integer that
indicates which round the data is from.
starting_round: From which round onwards to return the data. We start
starting_round_index: From which round onwards to return the data. We start
counting from 0.
"""
return torch.cat([t for t, r in zip(data, data_round_index) if r >= starting_round])
return torch.cat(
[t for t, r in zip(data, data_round_indices) if r >= starting_round_index]
)


def mask_sims_from_prior(round_: int, num_simulations: int) -> Tensor:
Expand Down

0 comments on commit 506a888

Please sign in to comment.