Changes to how data is stored #685
Conversation
Hi @tbmiller-astro, thanks a lot for this PR, and sorry for the delay in the review; it seems we are all very busy at the moment. But I plan to have a look at this soon! Best,
Hi @tbmiller-astro, thanks a lot for the PR. Your code is very clean and easy to follow, and your PR description also helped a lot! I like the changes you introduced. I have two questions regarding your rationale behind some of the other changes, though:
Thanks a lot for the PR!
Hey @michaeldeistler, thanks a lot and these are great questions!
Happy to continue making changes if requested! Tim
Hi Tim,
thanks a lot for your answer, this makes sense! I left a few minor comments below.
I think that I would prefer to store the data as lists of `Tensor`s instead of as a `torch.TensorDataset`. Here's why:

- It does not require `get_simulations_indices()`; we can instead use the simpler `get_simulations_since_round()`.
- Not all operations (e.g., z-scoring) require the `torch.TensorDataset`. So, I would prefer to generate the `torch.TensorDataset` only when it is really needed (i.e. in `.train()`).
- My main reason is that lists provide a nice way of structuring data that was passed over several rounds. This structure is lost when all simulations are concatenated into a single `torch.TensorDataset`. Sure, one can potentially recover the structure via `_num_sims_per_round`, but it's just less easy to see what is happening.
So, to be clear, I would only revert these changes in `get_simulations()` and in `append_simulations()`. I am perfectly fine with your suggestion that `get_dataloaders()` should, in a single call, generate the dataloader. However, I would suggest generating the dataloaders from the lists instead of from the `torch.TensorDataset`.
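For illustration, here is a minimal sketch of the list-based storage described above. The class and helper names (`RoundwiseStore`, etc.) are hypothetical and only mirror the discussion; this is not the actual sbi implementation.

```python
from typing import List, Tuple

import torch
from torch import Tensor
from torch.utils.data import DataLoader, TensorDataset


class RoundwiseStore:
    """Hypothetical sketch: keep simulations as per-round lists of Tensors."""

    def __init__(self) -> None:
        # One list entry per round keeps the round structure visible.
        self._theta_roundwise: List[Tensor] = []
        self._x_roundwise: List[Tensor] = []

    def append_simulations(self, theta: Tensor, x: Tensor) -> "RoundwiseStore":
        self._theta_roundwise.append(theta)
        self._x_roundwise.append(x)
        return self

    def get_simulations(self, starting_round: int = 0) -> Tuple[Tensor, Tensor]:
        # Concatenate only the requested rounds.
        theta = torch.cat(self._theta_roundwise[starting_round:])
        x = torch.cat(self._x_roundwise[starting_round:])
        return theta, x

    def get_dataloaders(self, starting_round: int = 0, batch_size: int = 50) -> DataLoader:
        # Build the TensorDataset only here, i.e. when training actually needs it.
        theta, x = self.get_simulations(starting_round)
        return DataLoader(TensorDataset(theta, x), batch_size=batch_size, shuffle=True)
```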
Please let me know if you are up for implementing these changes! If you are busy, we can also merge your PR as it is right now and I'll make the changes myself.
Thanks again!
Best wishes
Michael
Hey @michaeldeistler, sounds good! Happy to make the changes suggested. For the data checks, I've found they have a big effect on the memory footprint.
Hi Tim, awesome, thanks! Hmmm, I was not aware of this. However, since we already have to copy the dataset once anyway, I agree that we should ensure that we have as few copies of the whole dataset as possible. I created issue #695 to track this. For now, I would suggest that we remove that option.
Just addressed these requests.
This looks great! Thanks so much for taking the time to fix these things and going through the review process! I left very minor comments (but feel free to ignore them if you are busy). I'll merge the PR on Monday.
All the best and thanks again!
Michael
sbi/inference/snle/snle_base.py (Outdated)
```diff
@@ -176,10 +188,14 @@ def train(
         # This is passed into NeuralPosterior, to create a neural posterior which
         # can `sample()` and `log_prob()`. The network is accessible via `.net`.
         if self._neural_net is None or retrain_from_scratch:

+            # Get theta,x from dataset to initialize NN
+            theta, x, _ = self.get_simulations()
```
I think this should be `theta, x, _ = self.get_simulations(starting_round=start_idx)`.
```python
x = self._x_roundwise[0][:training_batch_size]
theta = self._theta_roundwise[0][:training_batch_size]
self._neural_net = self._build_neural_net(theta.to("cpu"), x.to("cpu"))
self._x_shape = x_shape_from_simulation(x.to("cpu"))
```
Could you replace these lines with the code that you also used in `snle_base.py` and `snre_base.py` (i.e., use `get_simulations`)? Or is there a reason not to use `get_simulations()` here?
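For reference, a hedged sketch of what such a replacement might look like, mirroring the pattern from the other base classes. It assumes `start_idx` and `training_batch_size` are in scope within `train()`, as in the diffs above; it is not the exact change made in the PR.

```python
# Sketch only: use get_simulations() as in snle_base.py / snre_base.py.
theta, x, _ = self.get_simulations(starting_round=start_idx)
self._neural_net = self._build_neural_net(
    theta[:training_batch_size].to("cpu"), x[:training_batch_size].to("cpu")
)
self._x_shape = x_shape_from_simulation(x.to("cpu"))
```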
sbi/inference/snre/snre_base.py (Outdated)
```diff
@@ -183,11 +198,14 @@ def train(
         # This is passed into NeuralPosterior, to create a neural posterior which
         # can `sample()` and `log_prob()`. The network is accessible via `.net`.
         if self._neural_net is None or retrain_from_scratch:

+            # Get theta,x from dataset to initialize NN
+            theta, x, _ = self.get_simulations()
```
I think this should be `theta, x, _ = self.get_simulations(starting_round=start_idx)`.
Overall great! Thanks a lot for working on this!
see comments below.
sbi/inference/snle/snle_base.py (Outdated)
```diff
         self._neural_net = self._build_neural_net(
-            theta[self.train_indices], x[self.train_indices]
+            theta[:training_batch_size].to("cpu"), x[:training_batch_size].to("cpu")
```
The `theta` and `x` batches are used by the neural net builder to build the standardizing net from the sample `mean` and `std`. I am wondering whether using only the first `training_batch_size` data points might affect the accuracy of the standardizing transform...
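A toy illustration of this concern (plain PyTorch, not sbi code): the z-scoring statistics estimated from only the first `training_batch_size` samples can differ noticeably from those computed over the full set of simulations, especially for small batch sizes.

```python
import torch

torch.manual_seed(0)

# Fake "simulations" with very different scales per dimension.
x = torch.randn(10_000, 2) * torch.tensor([5.0, 0.1]) + torch.tensor([3.0, -1.0])

training_batch_size = 50

# Statistics from only the first batch vs. from all simulations.
mean_batch, std_batch = x[:training_batch_size].mean(0), x[:training_batch_size].std(0)
mean_full, std_full = x.mean(0), x.std(0)

print(mean_batch, std_batch)  # estimate from the first batch only
print(mean_full, std_full)    # estimate from the entire dataset
```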
sbi/inference/snpe/snpe_base.py (Outdated)
```diff
-            theta[self.train_indices], x[self.train_indices]

+        # Get theta,x from dataset to initialize NN
+        x = self._x_roundwise[0][:training_batch_size]
```
The same as above applies here: we are taking only the first `training_batch_size` samples instead of the entire training data set to estimate the standardizing transform, no?
sbi/inference/snre/snre_base.py (Outdated)
```diff
         self._neural_net = self._build_neural_net(
-            theta[self.train_indices], x[self.train_indices]
+            theta[:training_batch_size].to("cpu"), x[:training_batch_size].to("cpu")
```
same comment as above.
```diff
@@ -647,7 +647,7 @@ def check_estimator_arg(estimator: Union[str, Callable]) -> None:


 def validate_theta_and_x(
-    theta: Any, x: Any, training_device: str = "cpu"
+    theta: Any, x: Any, data_device: str = "cpu", training_device: str = "cpu"
 ) -> Tuple[Tensor, Tensor]:
     r"""
     Checks if the passed $(\theta, x)$ are valid.
```
I think it would be great to update this docstring and be more explicit about what this function is checking. What's the difference between the training device and the data device, and what's the overall goal of this function?
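As an illustration only, a docstring along the following lines could make the purpose and the device distinction explicit. The wording below is a suggestion based on the signature shown in the diff, not the actual sbi docstring, and the validation body is omitted.

```python
from typing import Any, Tuple

from torch import Tensor


def validate_theta_and_x(
    theta: Any, x: Any, data_device: str = "cpu", training_device: str = "cpu"
) -> Tuple[Tensor, Tensor]:
    r"""Check that the passed $(\theta, x)$ are valid and return them as Tensors.

    The goal is to catch malformed inputs early: `theta` and `x` must be
    Tensors with one simulation per row and matching numbers of rows.

    Args:
        theta: Parameters.
        x: Corresponding simulation outputs.
        data_device: Device on which the returned data is stored; keeping it
            on "cpu" avoids holding the full dataset in GPU memory.
        training_device: Device used during training; batches are moved there
            only when needed, so it may differ from where the data lives.

    Returns:
        The validated `theta` and `x`, placed on `data_device`.
    """
    ...  # validation logic omitted in this sketch
```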
Thanks both! I should be able to make these final changes tomorrow! - Tim
Thanks again! This is really fantastic work! All the best
Changes to how data is stored within the `inference` class, as discussed in #678. By default, all behavior should be the same as before, but the changes help reduce the overall memory overhead and add extra flexibility when dealing with large datasets.

Summary of major changes:

- Simulation data is now stored in a `Dataset` rather than in lists as before.
- In `train()`, only a call to `get_dataloaders` is needed. `get_dataloader` is changed to create the dataloader from the saved Dataset and takes the argument `start_round` to define which round(s) of simulations to load.
- Changes to `append_simulations()`: three arguments have been added. `return_self` controls whether the method returns a copy of the class, `data_device` controls where the data lives, independent of the device used for training, and `warn_if_zscoring` controls whether the z-scoring check is run.
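A hypothetical usage sketch of the new arguments described in this summary. The argument names are taken verbatim from the PR description; the merged API may differ (e.g., some options were discussed for removal during review), so treat this as an illustration rather than the final interface.

```python
import torch

from sbi.inference import SNPE
from sbi.utils import BoxUniform

# Toy prior and simulations.
prior = BoxUniform(low=-2 * torch.ones(3), high=2 * torch.ones(3))
theta = prior.sample((1000,))
x = theta + 0.1 * torch.randn_like(theta)

inference = SNPE(prior=prior, device="cpu")

# `data_device` and `warn_if_zscoring` are the arguments described above.
inference = inference.append_simulations(
    theta,
    x,
    data_device="cpu",       # keep the (potentially large) dataset on the CPU
    warn_if_zscoring=False,  # skip the z-scoring check when appending
)
density_estimator = inference.train(training_batch_size=50)
```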