Changes to how data is stored #685

Merged · 18 commits · Jun 29, 2022
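In short, this PR moves NaN/inf filtering out of train() and into append_simulations(), and has the data loaders built from the stored rounds. Below is a minimal end-to-end sketch of the resulting flow; the BoxUniform prior and the SNLE class are the standard sbi ones, while the toy simulator and data sizes are made up for illustration, so treat it as a sketch rather than the PR's own example.

import torch
from sbi.inference import SNLE
from sbi.utils import BoxUniform

# Toy prior and simulator, only to make the sketch self-contained.
prior = BoxUniform(low=-2 * torch.ones(2), high=2 * torch.ones(2))
theta = prior.sample((1000,))
x = theta + 0.1 * torch.randn_like(theta)
x[0] = float("nan")  # one invalid simulation

inference = SNLE(prior=prior)

# Invalid rows are now filtered (with a warning) when the data is appended ...
inference.append_simulations(theta, x)

# ... so train() no longer takes an `exclude_invalid_x` argument.
density_estimator = inference.train(training_batch_size=50)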
2 changes: 1 addition & 1 deletion examples/HH_helper_functions.py
@@ -145,7 +145,7 @@ def tau_p(x):
+ g_leak * E_leak
+ gbar_M * p[i - 1] * E_K
+ I[i - 1]
+ nois_fact * rng.randn() / (tstep ** 0.5)
+ nois_fact * rng.randn() / (tstep**0.5)
) / (tau_V_inv * C)
V[i] = V_inf + (V[i - 1] - V_inf) * np.exp(-tstep * tau_V_inv)
n[i] = n_inf(V[i]) + (n[i - 1] - n_inf(V[i])) * np.exp(-tstep / tau_n(V[i]))
44 changes: 18 additions & 26 deletions sbi/inference/base.py
@@ -18,14 +18,7 @@
import sbi.inference
from sbi.inference.posteriors.base_posterior import NeuralPosterior
from sbi.simulators.simutils import simulate_in_batches
from sbi.utils import (
check_prior,
get_log_root,
handle_invalid_x,
warn_if_zscoring_changes_data,
warn_on_invalid_x,
warn_on_invalid_x_for_snpec_leakage,
)
from sbi.utils import check_prior, get_log_root
from sbi.utils.sbiutils import get_simulations_since_round
from sbi.utils.torchutils import check_if_prior_on_device, process_device
from sbi.utils.user_input_checks import prepare_for_sbi
@@ -128,7 +121,9 @@ def __init__(

# Initialize roundwise (theta, x, prior_masks) for storage of parameters,
# simulations and masks indicating if simulations came from prior.
self._theta_roundwise, self._x_roundwise, self._prior_masks = [], [], []
self._theta_roundwise = []
self._x_roundwise = []
self._prior_masks = []
self._model_bank = []

# Initialize list that indicates the round from which simulations were drawn.
@@ -159,8 +154,6 @@ def __init__(
def get_simulations(
self,
starting_round: int = 0,
exclude_invalid_x: bool = True,
warn_on_invalid: bool = True,
) -> Tuple[Tensor, Tensor, Tensor]:
r"""Returns all $\theta$, $x$, and prior_masks from rounds >= `starting_round`.

@@ -187,17 +180,7 @@
self._prior_masks, self._data_round_index, starting_round
)

# Check for NaNs in simulations.
is_valid_x, num_nans, num_infs = handle_invalid_x(x, exclude_invalid_x)
# Check for problematic z-scoring
warn_if_zscoring_changes_data(x[is_valid_x])
if warn_on_invalid:
warn_on_invalid_x(num_nans, num_infs, exclude_invalid_x)
warn_on_invalid_x_for_snpec_leakage(
num_nans, num_infs, exclude_invalid_x, type(self).__name__, self._round
)

return theta[is_valid_x], x[is_valid_x], prior_masks[is_valid_x]
return theta, x, prior_masks

@abstractmethod
def train(
@@ -218,7 +201,7 @@ def train(

def get_dataloaders(
self,
dataset: data.TensorDataset,
starting_round: int = 0,
training_batch_size: int = 50,
validation_fraction: float = 0.1,
resume_training: bool = False,
@@ -239,14 +222,19 @@

"""

# Get total number of training examples.
num_examples = len(dataset)
#
theta, x, prior_masks = self.get_simulations(starting_round)

dataset = data.TensorDataset(theta, x, prior_masks)

# Get total number of training examples.
num_examples = theta.size(0)
# Select random train and validation splits from (theta, x) pairs.
num_training_examples = int((1 - validation_fraction) * num_examples)
num_validation_examples = num_examples - num_training_examples

if not resume_training:
# Separate indices for training and validation
permuted_indices = torch.randperm(num_examples)
self.train_indices, self.val_indices = (
permuted_indices[:num_training_examples],
@@ -358,7 +346,11 @@ def _report_convergence_at_end(
)

def _summarize(
self, round_: int, x_o: Union[Tensor, None], theta_bank: Tensor, x_bank: Tensor
self,
round_: int,
x_o: Union[Tensor, None],
theta_bank: Union[Tensor, None],
x_bank: Union[Tensor, None],
) -> None:
"""Update the summary_writer with statistics for a given round.

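For orientation, the new get_dataloaders no longer receives a ready-made TensorDataset; it pulls theta, x, and prior_masks from the stored rounds via get_simulations(starting_round) and performs the train/validation split itself. The following standalone snippet is an illustrative re-implementation of that split in plain PyTorch, not the sbi code itself:

import torch
from torch.utils import data

def make_loaders(theta, x, prior_masks, training_batch_size=50, validation_fraction=0.1):
    # Dataset is shared between the training and validation loaders.
    dataset = data.TensorDataset(theta, x, prior_masks)
    num_examples = theta.size(0)
    num_training = int((1 - validation_fraction) * num_examples)

    # Separate indices for training and validation.
    permuted = torch.randperm(num_examples)
    train_indices = permuted[:num_training].tolist()
    val_indices = permuted[num_training:].tolist()

    train_loader = data.DataLoader(
        dataset,
        batch_size=min(training_batch_size, num_training),
        sampler=data.SubsetRandomSampler(train_indices),
    )
    val_loader = data.DataLoader(
        dataset,
        batch_size=min(training_batch_size, num_examples - num_training),
        sampler=data.SubsetRandomSampler(val_indices),
    )
    return train_loader, val_loader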
56 changes: 36 additions & 20 deletions sbi/inference/snle/snle_base.py
@@ -21,8 +21,11 @@
from sbi.utils import (
check_estimator_arg,
check_prior,
handle_invalid_x,
mask_sims_from_prior,
validate_theta_and_x,
warn_if_zscoring_changes_data,
warn_on_invalid_x,
x_shape_from_simulation,
)

@@ -83,6 +86,7 @@ def append_simulations(
theta: Tensor,
x: Tensor,
from_round: int = 0,
data_device: Optional[str] = None,
) -> "LikelihoodEstimator":
r"""Store parameters and simulation outputs to use them for later training.

@@ -99,16 +103,34 @@
With default settings, this is not used at all for `SNLE`. Only when
the user later on requests `.train(discard_prior_samples=True)`, we
use these indices to find which training data stemmed from the prior.

data_device: Where to store the data; defaults to the device on which
training happens. When training on a GPU with limited VRAM and a large
dataset, set this to 'cpu' to keep the data in system memory instead.
Returns:
NeuralInference object (returned so that this function is chainable).
"""

theta, x = validate_theta_and_x(theta, x, training_device=self._device)
is_valid_x, num_nans, num_infs = handle_invalid_x(x, True) # Hardcode to True

x = x[is_valid_x]
theta = theta[is_valid_x]

# Check for problematic z-scoring
warn_if_zscoring_changes_data(x)
warn_on_invalid_x(num_nans, num_infs, True)

if data_device is None:
data_device = self._device
theta, x = validate_theta_and_x(
theta, x, data_device=data_device, training_device=self._device
)

prior_masks = mask_sims_from_prior(int(from_round), theta.size(0))

self._theta_roundwise.append(theta)
self._x_roundwise.append(x)
self._prior_masks.append(mask_sims_from_prior(int(from_round), theta.size(0)))
self._prior_masks.append(prior_masks)

self._data_round_index.append(int(from_round))

return self
@@ -121,7 +143,6 @@ def train(
stop_after_epochs: int = 20,
max_num_epochs: int = 2**31 - 1,
clip_max_norm: Optional[float] = 5.0,
exclude_invalid_x: bool = True,
resume_training: bool = False,
discard_prior_samples: bool = False,
retrain_from_scratch: bool = False,
@@ -131,8 +152,6 @@
r"""Train the density estimator to learn the distribution $p(x|\theta)$.

Args:
exclude_invalid_x: Whether to exclude simulation outputs `x=NaN` or `x=±∞`
during training. Expect errors, silent or explicit, when `False`.
resume_training: Can be used in case training time is limited, e.g. on a
cluster. If `True`, the split between train and validation set, the
optimizer, the number of epochs, and the best validation log-prob will
@@ -150,20 +169,13 @@
Returns:
Density estimator that has learned the distribution $p(x|\theta)$.
"""

# Starting index for the training set (1 = discard round-0 samples).
start_idx = int(discard_prior_samples and self._round > 0)
# Load data from most recent round.
self._round = max(self._data_round_index)
theta, x, _ = self.get_simulations(
start_idx, exclude_invalid_x, warn_on_invalid=True
)

# Dataset is shared for training and validation loaders.
dataset = data.TensorDataset(theta, x)
# Starting index for the training set (1 = discard round-0 samples).
start_idx = int(discard_prior_samples and self._round > 0)

train_loader, val_loader = self.get_dataloaders(
dataset,
start_idx,
training_batch_size,
validation_fraction,
resume_training,
@@ -176,10 +188,14 @@
# This is passed into NeuralPosterior, to create a neural posterior which
# can `sample()` and `log_prob()`. The network is accessible via `.net`.
if self._neural_net is None or retrain_from_scratch:

# Get theta,x from dataset to initialize NN
theta, x, _ = self.get_simulations()
Review comment (Contributor):
I think this should be

theta, x, _ = self.get_simulations(starting_round=start_idx)

self._neural_net = self._build_neural_net(
theta[self.train_indices], x[self.train_indices]
theta[:training_batch_size].to("cpu"), x[:training_batch_size].to("cpu")
Review comment (Contributor):
the theta and x batches are used by the neural net builder to build the standardizing net using sample mean and std. I am wondering whether using only first the training_batch_size data points might affect the accuracy of the standardizing transform...

)
self._x_shape = x_shape_from_simulation(x)
self._x_shape = x_shape_from_simulation(x[:training_batch_size].to("cpu"))
del theta, x
assert (
len(self._x_shape) < 3
), "SNLE cannot handle multi-dimensional simulator output."
@@ -257,8 +273,8 @@ def train(
self._summarize(
round_=self._round,
x_o=None,
theta_bank=theta,
x_bank=x,
theta_bank=None,
x_bank=None,
)

# Update description for progress bar.
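The data_device argument added above targets the case its docstring describes: training on a GPU while keeping a large dataset in system memory. A hedged usage sketch with toy data (the device handling itself is what this diff adds; everything else is illustrative):

import torch
from sbi.inference import SNLE
from sbi.utils import BoxUniform

prior = BoxUniform(low=-torch.ones(3), high=torch.ones(3))
theta = prior.sample((100_000,))
x = theta + 0.1 * torch.randn_like(theta)

device = "cuda" if torch.cuda.is_available() else "cpu"
inference = SNLE(prior=prior, device=device)

# Store the (potentially large) simulations in system memory instead of VRAM.
inference.append_simulations(theta, x, data_device="cpu")
density_estimator = inference.train()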
3 changes: 0 additions & 3 deletions sbi/inference/snpe/snpe_a.py
@@ -105,7 +105,6 @@ def train(
max_num_epochs: int = 2**31 - 1,
clip_max_norm: Optional[float] = 5.0,
calibration_kernel: Optional[Callable] = None,
exclude_invalid_x: bool = True,
resume_training: bool = False,
force_first_round_loss: bool = False,
retrain_from_scratch: bool = False,
@@ -138,8 +137,6 @@
prevent exploding gradients. Use None for no clipping.
calibration_kernel: A function to calibrate the loss with respect to the
simulations `x`. See Lueckmann, Gonçalves et al., NeurIPS 2017.
exclude_invalid_x: Whether to exclude simulation outputs `x=NaN` or `x=±∞`
during training. Expect errors, silent or explicit, when `False`.
resume_training: Can be used in case training time is limited, e.g. on a
cluster. If `True`, the split between train and validation set, the
optimizer, the number of epochs, and the best validation log-prob will
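Finally, the same cleanup applies to SNPE-A: exclude_invalid_x is removed from train(), so invalid simulations are handled when data is appended rather than at training time. A brief sketch of the updated call, with hypothetical toy data; passing exclude_invalid_x to train() would now raise a TypeError:

import torch
from sbi.inference import SNPE_A
from sbi.utils import BoxUniform

prior = BoxUniform(low=-torch.ones(2), high=torch.ones(2))
theta = prior.sample((500,))
x = theta + 0.1 * torch.randn_like(theta)

inference = SNPE_A(prior=prior)
inference.append_simulations(theta, x)
density_estimator = inference.train()  # no exclude_invalid_x keyword anymore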