Changes to how data is stored #685

Merged · 18 commits · merged Jun 29, 2022
Changes from 15 commits
3 changes: 2 additions & 1 deletion sbi/examples/minimal.py
@@ -37,7 +37,8 @@ def flexible():
inference = SNPE(prior)

theta, x = simulate_for_sbi(simulator, proposal=prior, num_simulations=500)
density_estimator = inference.append_simulations(theta, x).train()
inference.append_simulations(theta, x)
density_estimator = inference.train()
posterior = inference.build_posterior(density_estimator)
posterior.sample((100,), x=x_o)

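Read in full, the updated example amounts to the following sketch (it reuses the example's own `simulator`, `prior`, and `x_o`; the only change is that simulations are now stored on the inference object before `train()` is called):

    from sbi.inference import SNPE, simulate_for_sbi

    inference = SNPE(prior)
    theta, x = simulate_for_sbi(simulator, proposal=prior, num_simulations=500)

    # Simulations are appended to the inference object's internal storage ...
    inference.append_simulations(theta, x)
    # ... and train() reads from that stored data.
    density_estimator = inference.train()

    posterior = inference.build_posterior(density_estimator)
    posterior.sample((100,), x=x_o)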
63 changes: 34 additions & 29 deletions sbi/inference/base.py
@@ -26,7 +26,7 @@
warn_on_invalid_x,
warn_on_invalid_x_for_snpec_leakage,
)
from sbi.utils.sbiutils import get_simulations_since_round
from sbi.utils.sbiutils import get_simulations_indcies
from sbi.utils.torchutils import check_if_prior_on_device, process_device
from sbi.utils.user_input_checks import prepare_for_sbi

@@ -128,7 +128,8 @@ def __init__(

# Initialize roundwise (theta, x, prior_masks) for storage of parameters,
# simulations and masks indicating if simulations came from prior.
self._theta_roundwise, self._x_roundwise, self._prior_masks = [], [], []
self._dataset = data.Dataset()
self._num_sims_per_round = []
self._model_bank = []

# Initialize list that indicates the round from which simulations were drawn.
Expand Down Expand Up @@ -159,8 +160,6 @@ def __init__(
def get_simulations(
self,
starting_round: int = 0,
exclude_invalid_x: bool = True,
warn_on_invalid: bool = True,
) -> Tuple[Tensor, Tensor, Tensor]:
r"""Returns all $\theta$, $x$, and prior_masks from rounds >= `starting_round`.

@@ -177,27 +176,24 @@
Returns: Parameters, simulation outputs, prior masks.
"""

theta = get_simulations_since_round(
self._theta_roundwise, self._data_round_index, starting_round
)
x = get_simulations_since_round(
self._x_roundwise, self._data_round_index, starting_round
)
prior_masks = get_simulations_since_round(
self._prior_masks, self._data_round_index, starting_round
# This is a pretty clunky implementation but not sure this will be used much with
# new implementation of `get_dataloaders`
indicies = get_simulations_indcies(
self._num_sims_per_round, self._data_round_index, starting_round
)
theta, x, prior_masks = [], [], []

# Check for NaNs in simulations.
is_valid_x, num_nans, num_infs = handle_invalid_x(x, exclude_invalid_x)
# Check for problematic z-scoring
warn_if_zscoring_changes_data(x[is_valid_x])
if warn_on_invalid:
warn_on_invalid_x(num_nans, num_infs, exclude_invalid_x)
warn_on_invalid_x_for_snpec_leakage(
num_nans, num_infs, exclude_invalid_x, type(self).__name__, self._round
)
for ind in indicies:
theta_cur, x_cur, prior_mask_cur = self._dataset[ind]
theta.append(theta_cur)
x.append(x_cur)
prior_masks.append(prior_mask_cur)

theta = torch.stack(theta)
x = torch.stack(x)
prior_masks = torch.stack(prior_masks).squeeze()

return theta[is_valid_x], x[is_valid_x], prior_masks[is_valid_x]
return theta, x, prior_masks

@abstractmethod
def train(
@@ -218,7 +214,7 @@ def train(

def get_dataloaders(
self,
dataset: data.TensorDataset,
starting_round: int = 0,
training_batch_size: int = 50,
validation_fraction: float = 0.1,
resume_training: bool = False,
@@ -239,15 +235,20 @@

"""

# Get total number of training examples.
num_examples = len(dataset)
# Generate indices to use based on starting_round
indices = get_simulations_indcies(
self._num_sims_per_round, self._data_round_index, starting_round
)

# Get total number of training examples.
num_examples = len(indices)
# Select random train and validation splits from (theta, x) pairs.
num_training_examples = int((1 - validation_fraction) * num_examples)
num_validation_examples = num_examples - num_training_examples

if not resume_training:
permuted_indices = torch.randperm(num_examples)
# Separate indices for training and validation
permuted_indices = indices[torch.randperm(num_examples)]
self.train_indices, self.val_indices = (
permuted_indices[:num_training_examples],
permuted_indices[num_training_examples:],
@@ -272,8 +273,8 @@
train_loader_kwargs = dict(train_loader_kwargs, **dataloader_kwargs)
val_loader_kwargs = dict(val_loader_kwargs, **dataloader_kwargs)

train_loader = data.DataLoader(dataset, **train_loader_kwargs)
val_loader = data.DataLoader(dataset, **val_loader_kwargs)
train_loader = data.DataLoader(self._dataset, **train_loader_kwargs)
val_loader = data.DataLoader(self._dataset, **val_loader_kwargs)

return train_loader, val_loader
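The storage and loading pattern introduced here — a single growing ConcatDataset plus index-based train/validation samplers drawn from it — can be sketched in isolation. A minimal, self-contained example with made-up tensors (not the library code itself; the sampler kwargs sit outside the shown hunk and are assumed here):

    import torch
    from torch.utils import data

    # Two "rounds" of simulations, each appended as its own TensorDataset.
    theta1, x1 = torch.randn(500, 3), torch.randn(500, 2)
    theta2, x2 = torch.randn(300, 3), torch.randn(300, 2)
    dataset = data.ConcatDataset([data.TensorDataset(theta1, x1)])
    dataset = data.ConcatDataset(dataset.datasets + [data.TensorDataset(theta2, x2)])

    # Split indices into train/validation and let samplers pick from the shared dataset.
    num_examples = len(dataset)  # 800
    permuted = torch.randperm(num_examples)
    num_train = int(0.9 * num_examples)
    train_indices, val_indices = permuted[:num_train], permuted[num_train:]

    train_loader = data.DataLoader(
        dataset, batch_size=50, sampler=data.SubsetRandomSampler(train_indices.tolist())
    )
    val_loader = data.DataLoader(
        dataset, batch_size=50, sampler=data.SubsetRandomSampler(val_indices.tolist())
    )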

Expand Down Expand Up @@ -358,7 +359,11 @@ def _report_convergence_at_end(
)

def _summarize(
self, round_: int, x_o: Union[Tensor, None], theta_bank: Tensor, x_bank: Tensor
self,
round_: int,
x_o: Union[Tensor, None],
theta_bank: Union[Tensor, None],
x_bank: Union[Tensor, None],
) -> None:
"""Update the summary_writer with statistics for a given round.

87 changes: 63 additions & 24 deletions sbi/inference/snle/snle_base.py
@@ -21,8 +21,11 @@
from sbi.utils import (
check_estimator_arg,
check_prior,
handle_invalid_x,
mask_sims_from_prior,
validate_theta_and_x,
warn_if_zscoring_changes_data,
warn_on_invalid_x,
x_shape_from_simulation,
)

@@ -83,7 +86,12 @@ def append_simulations(
theta: Tensor,
x: Tensor,
from_round: int = 0,
) -> "LikelihoodEstimator":
exclude_invalid_x: bool = True,
warn_on_invalid: bool = True,
warn_if_zscoring: bool = True,
return_self: bool = True,
data_device: Optional[str] = None,
) -> Union["LikelihoodEstimator", None]:
r"""Store parameters and simulation outputs to use them for later training.

Data are stored as entries in lists for each type of variable (parameter/data).
@@ -99,19 +107,56 @@
With default settings, this is not used at all for `SNLE`. Only when
the user later on requests `.train(discard_prior_samples=True)`, we
use these indices to find which training data stemmed from the prior.

exclude_invalid_x: Whether to exclude simulation outputs `x=NaN` or `x=±∞`
during training. Expect errors, silent or explicit, when `False`.
warn_on_invalid: Whether to warn if data is invalid
warn_if_zscoring: Whether to test if z-scoring causes duplicates
return_self: Whether to return an instance of the class, which allows
chaining with `.train()`. Setting `False` decreases memory overhead.
data_device: Where to store the data; defaults to the same device on which
training happens. When training on a large dataset with a GPU that has
limited VRAM, set to 'cpu' to store the data in system memory instead.
Returns:
NeuralInference object (returned so that this function is chainable).
"""

theta, x = validate_theta_and_x(theta, x, training_device=self._device)
is_valid_x, num_nans, num_infs = handle_invalid_x(x, exclude_invalid_x)

# Check for problematic z-scoring
if warn_if_zscoring:
warn_if_zscoring_changes_data(x[is_valid_x])
if warn_on_invalid:
warn_on_invalid_x(num_nans, num_infs, exclude_invalid_x)

x = x[is_valid_x]
theta = theta[is_valid_x]

if data_device is None:
data_device = self._device
theta, x = validate_theta_and_x(theta, x, training_device=data_device)
prior_masks = mask_sims_from_prior(int(from_round), theta.size(0))

if len(self._num_sims_per_round) == 0:
# If first round, set up ConcatDataset
self._dataset = data.ConcatDataset(
[
data.TensorDataset(theta, x, prior_masks),
]
)
else:
# Otherwise append to Dataset
self._dataset = data.ConcatDataset(
self._dataset.datasets
+ [
data.TensorDataset(theta, x, prior_masks),
]
)

self._theta_roundwise.append(theta)
self._x_roundwise.append(x)
self._prior_masks.append(mask_sims_from_prior(int(from_round), theta.size(0)))
self._num_sims_per_round.append(theta.size(0))
self._data_round_index.append(int(from_round))

return self
if return_self:
return self
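A usage sketch for the extended signature (hypothetical prior and theta/x tensors; SNLE stands in for any LikelihoodEstimator subclass):

    from sbi.inference import SNLE

    inference = SNLE(prior)

    # Keep data on the CPU even if training later runs on a GPU.
    inference.append_simulations(
        theta,
        x,
        exclude_invalid_x=True,  # drop rows where x contains NaN or inf before storing
        data_device="cpu",
    )
    density_estimator = inference.train()

    # For a later round, skip the returned reference to reduce memory overhead.
    inference.append_simulations(theta_round2, x_round2, from_round=1, return_self=False)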

def train(
self,
@@ -121,7 +166,6 @@ def train(
stop_after_epochs: int = 20,
max_num_epochs: int = 2**31 - 1,
clip_max_norm: Optional[float] = 5.0,
exclude_invalid_x: bool = True,
resume_training: bool = False,
discard_prior_samples: bool = False,
retrain_from_scratch: bool = False,
@@ -131,8 +175,6 @@
r"""Train the density estimator to learn the distribution $p(x|\theta)$.

Args:
exclude_invalid_x: Whether to exclude simulation outputs `x=NaN` or `x=±∞`
during training. Expect errors, silent or explicit, when `False`.
resume_training: Can be used in case training time is limited, e.g. on a
cluster. If `True`, the split between train and validation set, the
optimizer, the number of epochs, and the best validation log-prob will
@@ -150,20 +192,13 @@
Returns:
Density estimator that has learned the distribution $p(x|\theta)$.
"""

# Starting index for the training set (1 = discard round-0 samples).
start_idx = int(discard_prior_samples and self._round > 0)
# Load data from most recent round.
self._round = max(self._data_round_index)
theta, x, _ = self.get_simulations(
start_idx, exclude_invalid_x, warn_on_invalid=True
)

# Dataset is shared for training and validation loaders.
dataset = data.TensorDataset(theta, x)
# Starting index for the training set (1 = discard round-0 samples).
start_idx = int(discard_prior_samples and self._round > 0)

train_loader, val_loader = self.get_dataloaders(
dataset,
start_idx,
training_batch_size,
validation_fraction,
resume_training,
@@ -176,10 +211,14 @@
# This is passed into NeuralPosterior, to create a neural posterior which
# can `sample()` and `log_prob()`. The network is accessible via `.net`.
if self._neural_net is None or retrain_from_scratch:

# Get theta,x from dataset to initialize NN
theta, x, _ = self.get_simulations()
Contributor comment:

I think this should be

theta, x, _ = self.get_simulations(starting_round=start_idx)

self._neural_net = self._build_neural_net(
theta[self.train_indices], x[self.train_indices]
theta[:training_batch_size].to("cpu"), x[:training_batch_size].to("cpu")
Contributor comment:

The theta and x batches are used by the neural net builder to build the standardizing net using sample mean and std. I am wondering whether using only the first training_batch_size data points might affect the accuracy of the standardizing transform... (a sketch after this hunk illustrates the concern).

)
self._x_shape = x_shape_from_simulation(x)
self._x_shape = x_shape_from_simulation(x[:training_batch_size].to("cpu"))
del theta, x
assert (
len(self._x_shape) < 3
), "SNLE cannot handle multi-dimensional simulator output."
@@ -257,8 +296,8 @@ def train(
self._summarize(
round_=self._round,
x_o=None,
theta_bank=theta,
x_bank=x,
theta_bank=None,
x_bank=None,
)

# Update description for progress bar.
3 changes: 0 additions & 3 deletions sbi/inference/snpe/snpe_a.py
@@ -105,7 +105,6 @@ def train(
max_num_epochs: int = 2**31 - 1,
clip_max_norm: Optional[float] = 5.0,
calibration_kernel: Optional[Callable] = None,
exclude_invalid_x: bool = True,
resume_training: bool = False,
force_first_round_loss: bool = False,
retrain_from_scratch: bool = False,
@@ -138,8 +137,6 @@
prevent exploding gradients. Use None for no clipping.
calibration_kernel: A function to calibrate the loss with respect to the
simulations `x`. See Lueckmann, Gonçalves et al., NeurIPS 2017.
exclude_invalid_x: Whether to exclude simulation outputs `x=NaN` or `x=±∞`
during training. Expect errors, silent or explicit, when `False`.
resume_training: Can be used in case training time is limited, e.g. on a
cluster. If `True`, the split between train and validation set, the
optimizer, the number of epochs, and the best validation log-prob will