
Adding the ability to pass dataloader related kwargs in SNPE, SNRE and SNLE #435

Closed

Conversation

@narendramukherjee (Contributor) commented Feb 7, 2021

NOTE: The majority of the code changes are formatting-related; they came up when I ran black and isort on the repo (as instructed in the contribution guidelines).

The real code changes are in the following files:

  1. inference/snpe/snpe_base.py
  2. inference/snre/snre_base.py
  3. inference/snle/snle_base.py

In each of these files, I have added a **dataloader_kwargs parameter that is passed to the train method. I am imagining these as "additional" args to be passed to the dataloader, over and above the ones that are already being used to build the train and validation dataloaders. In my specific case, this would be a collate_fn. I thought it would be a waste to make the user specify all the args to the dataloader here, as most of the defaults being used right now are great; any additional args can be passed using **dataloader_kwargs.
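For concreteness, here is a minimal usage sketch of the proposed API; the collate_fn is a hypothetical placeholder, and the commented-out train() call stands in for whichever inference object is in use:

```python
import torch

# Hypothetical collate_fn a user might want to forward to the DataLoaders;
# it stacks a list of (theta, x) pairs into batch tensors.
def my_collate_fn(batch):
    thetas, xs = zip(*batch)
    return torch.stack(thetas), torch.stack(xs)

# With the **dataloader_kwargs parameter proposed here, any extra
# DataLoader argument can be forwarded straight through train(), e.g.:
# density_estimator = inference.train(collate_fn=my_collate_fn)
```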

I tried testing this code change with SNPE; however, snpe_c.py calls the train method of snpe_base in a relatively complicated way, using kwargs built here: https://github.com/mackelab/sbi/blob/main/sbi/inference/snpe/snpe_c.py#L150
So when I pass any dataloader-related kwargs to snpe_c.train, those aren't recognized in the train method's signature. Since snpe_base.train is called in a complicated way from there, I didn't want to mess with it before getting your opinion first.

Let me know how I should proceed, and if you have any other suggestions/questions about my approach!


@janfb (Contributor) commented Feb 8, 2021

Many thanks for the PR, I will have a detailed look soon.

A first thing: I noticed a lot of formatting changes, which is great. Would it be possible to move all the formatting into a separate commit? Then we can see more clearly what was done when, where, and why.

@janfb (Contributor) left a comment:

Looks great!

Regarding your question of how to pass the additional kwarg, e.g., from snpe_c to snpe_base, I added two comments at the specific locations.

@@ -137,6 +138,8 @@ def train(
estimator for the posterior from scratch each round.
show_train_summary: Whether to print the number of epochs and validation
loss after the training.
dataloader_kwargs: Any additional kwargs to be passed to the training and
validation dataloaders (like a collate_fn)
@janfb: (like, e.g., a collate_fn)

@@ -192,6 +190,8 @@ def train(
estimator for the posterior from scratch each round.
show_train_summary: Whether to print the number of epochs and validation
loss after the training.
dataloader_kwargs: Any additional kwargs to be passed to the training and
validation dataloaders (like a collate_fn)
@janfb: see comment on the formulation above.

@@ -203,9 +203,6 @@ def train(

max_num_epochs = 2 ** 31 - 1 if max_num_epochs is None else max_num_epochs

# Load data from most recent round.
theta, x, _ = self.get_simulations(self._round, exclude_invalid_x, False)

@janfb: I think we fixed this already, which makes it look like your branch is based on an older main. Have you tried rebasing your local branch on the most recent version of our main?

@@ -139,6 +140,8 @@ def train(
samples.
retrain_from_scratch_each_round: Whether to retrain the conditional density
estimator for the posterior from scratch each round.
dataloader_kwargs: Any additional kwargs to be passed to the training and
validation dataloaders (like a collate_fn)
@janfb: see comment above.

@@ -166,6 +163,7 @@ def train(
discard_prior_samples: bool = False,
retrain_from_scratch_each_round: bool = False,
show_train_summary: bool = False,
**dataloader_kwargs: Any,
@janfb: Couldn't you make this an explicit kwarg of train()? E.g., dataloader_kwargs: Optional[dict] = None, and then the same in the snpe_c train method?

@janfb: In snpe_c line 108, add the same additional kwarg: dataloader_kwargs: Optional[dict] = None. I haven't tested it, just an idea.
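A rough sketch of what this explicit-kwarg suggestion could look like; make_loader is a hypothetical stand-in for the loader construction inside train(), and the default values are placeholders:

```python
from typing import Optional

import torch
from torch.utils.data import DataLoader, TensorDataset

def make_loader(dataset, training_batch_size: int = 50,
                dataloader_kwargs: Optional[dict] = None) -> DataLoader:
    # Defaults that train() would use anyway; user-supplied kwargs are
    # merged on top, so the user can add new entries or override defaults.
    default_kwargs = dict(batch_size=training_batch_size, drop_last=True)
    return DataLoader(dataset, **{**default_kwargs, **(dataloader_kwargs or {})})

# Usage: add a collate_fn (or override any default) via one dict.
dataset = TensorDataset(torch.randn(100, 3), torch.randn(100, 5))
loader = make_loader(dataset, dataloader_kwargs=dict(batch_size=10))
```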

@michaeldeistler (Contributor) commented:

Hi there! Thanks for creating the PR! Just a quick heads-up: we just merged #431, which will generate merge conflicts. Please let us know if you need assistance rebasing your branch :)

Michael

@narendramukherjee (Contributor, Author) commented:

Thanks for the heads-up @michaeldeistler and for the comments @janfb - I will merge the current main branch from the sbi repo and update my changes in the next couple of days!

@narendramukherjee (Contributor, Author) commented:

@janfb @michaeldeistler
I have been able to add the dataloader_kwargs as a dict to the _base files of snpe, snre and snle, as well as to the specific classes that inherit from those base classes (snpe_c, snle_a, etc.). I overwrite the default train and validation dataloader kwargs using a dict here; this way the user can add a new kwarg to be passed to the dataloaders (or overwrite any of the default kwargs as well). I have no problem calling train with the additional dataloader_kwargs.

However, there are 2 additional issues that I now run into, despite being able to equip the dataloaders with a collate_fn. As I explained earlier, my simulations consist of unequal-length time-series with 5-D observations. I prepare the simulations as a 3-D tensor of size (num_simulations x length of the longest series x 5); entries belonging to the shorter time-series in that tensor are filled with NaNs. Then, in my collate_fn, I take the longest time-series in each batch, cut all the other simulations in that batch to that length, and fill in any NaNs in the shorter time-series with appropriate values. So at the end of the collate_fn, my x has size (batch_size x length of the longest series x 5); the most important thing to note is that the length of the longest series naturally varies from batch to batch. My embedding net (a 1D CNN with global max pooling) is designed to produce a tensor of fixed size (batch_size x 32) from any such batch, irrespective of the length of the longest series in the batch.
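The thread doesn't include the actual collate_fn, but a minimal sketch of the trimming-and-filling logic described above could look like this (assuming dataset items are (theta, x) pairs, padding sits at the end of each series, and zero is an acceptable fill value):

```python
import torch

def trim_collate_fn(batch):
    # batch: list of (theta, x) pairs, each x of shape (max_len, 5), NaN-padded.
    thetas, xs = zip(*batch)
    xs = torch.stack(xs)                   # (batch, max_len, 5)
    # A time step counts as real if no channel is NaN at that step.
    valid = ~torch.isnan(xs).any(dim=-1)   # (batch, max_len)
    longest = int(valid.sum(dim=1).max())  # longest series in this batch
    xs = xs[:, :longest, :]                # cut to the batch-wise maximum
    xs[torch.isnan(xs)] = 0.0              # fill leftover padding
    # Permute to (batch, channels, time), as an nn.Conv1d embedding expects.
    return torch.stack(thetas), xs.permute(0, 2, 1)
```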

Now, using this setup with the collate_fn, here are the 2 persistent issues I have; both of them happen while constructing snpe_base._neural_net:

  1. The neural net construction (here) takes the full tensor of simulations (in my case, the big consolidated initial tensor containing NaNs) to deduce the dimensionality for the embedding_net, instead of just using a batch. It thus ignores the collate_fn completely, which is a problem. It is an additional problem for me because my collate_fn actually permutes a couple of the axes of the simulations in the batch (to bring them to the shape needed for the 1D convolutions), and not taking that into account breaks the neural net creation process.

  2. I am able to hack around 1 by storing my consolidated, NaN-containing initial tensor of simulations with the same order of axes that the collate_fn would have produced, but that in turn leads to problems while doing z_score_y here: the z-scoring, just like point 1 above, also operates on the consolidated initial tensor (and not on the simulations from a batch that have passed through the collate_fn), and as a result it is unable to deal with the varying lengths of my time-series.

In the end, I was able to run the training by turning off z_score_y, but the training seems to fail without it (I get a best validation performance of -inf). The z-scoring is probably important in my case, as the values of the 5-D observations of my time-series can vary across a couple of orders of magnitude, and the neural net is probably unable to deal with that kind of variability :|

Let me know how you think I should go about dealing with both of these issues. Also let me know if you have any other thoughts on why I might have gotten a -inf log prob when I turned z_score_y off (of course, I couldn't train with z_score_y on in the first place). Thanks in advance for all your help!

@narendramukherjee (Contributor, Author) commented:

Actually, I dug a bit more into the -inf issue, and it seems like it happens because snpe_base._loss (which in turn calls snpe_base._neural_net._log_prob) returns NaNs in my case :( I went and looked at the _log_prob in nflows, and the NaNs first come up here when calling _transform: https://github.com/bayesiains/nflows/blob/master/nflows/flows/base.py#L39

I checked that the output of the _embedding_net seems fine and I am not passing any NaNs by mistake in the batches, so I'm a little out of depth on what might be happening. Do you have any thoughts?

@janfb (Contributor) commented Feb 16, 2021

Hi @narendramukherjee and thanks for the update.

I see the problems: e.g., the build_maf function in sbi/neural_nets/flow.py will pass a single x through the embedding net to infer the final embedded shape of x, without using the collate_fn:
https://github.com/mackelab/sbi/blob/0aa376e67938a1fe30a76b408c1174f5a4506024/sbi/neural_nets/flow.py#L105

At this stage, however, I think this is quite a special case that is unlikely to affect many users. So I suggest we find a special solution for your problem, not a general one in sbi.

I suggest keeping the changes you propose in this PR, i.e., the possibility of passing the dataloader kwargs. We will finish reviewing this PR, merge it, and work on a solution for your problem separately. OK?

@janfb (Contributor) commented Feb 16, 2021

I can think of at least two options now.

  1. You can pass your own build_neural_net function to sbi. Thus, you could just take our code from neural_nets/flow.py and change it so that it takes care of your case with the varying lengths of the time series.
    The problem with the z-scoring and the NaNs is more difficult: you would have to turn off z-scoring, because otherwise it will either ignore all data entries with NaNs or produce NaNs in the transform.
    Therefore, the other option might be better:
  2. You embed the data yourself, not using the joint training in sbi. Given that you have time-series data, you could train an RNN architecture or an autoencoder unsupervised in order to learn an embedding of x. Once trained, you take, e.g., the encoding network of your autoencoder and pass all time-series through it to obtain the embedded_x. And then you use sbi without an embedding net and pass theta, embedded_x as training data (see the sketch below).

Does that make sense?
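A rough sketch of option 2, under stated assumptions: encoder is a pretrained encoding network (e.g., the encoder half of an autoencoder), theta, x, and prior already exist, and the append_simulations/train interface of newer sbi versions is used (adjust for the API of the version at hand):

```python
import torch
from sbi.inference import SNPE

# Assumed to exist already: a trained `encoder` module mapping each padded
# time series to a fixed-size vector, plus `theta`, `x`, and `prior`.
with torch.no_grad():
    embedded_x = encoder(x)          # (num_simulations, embedding_dim)

# No embedding_net inside sbi: the data are already embedded.
inference = SNPE(prior=prior)
inference.append_simulations(theta, embedded_x)
density_estimator = inference.train()
```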

@narendramukherjee (Contributor, Author) commented:

Hi @janfb: Thank you for those comments; yes, they do make sense. For now, I am also planning to just pad the whole set of simulations to a single fixed length, which works without a collate_fn. I need some results for a presentation next week, so I will just go with this approach for now and come back to the per-batch padding with a collate_fn idea in a few days' time :)

So yeah, let's evaluate this PR and merge it if it looks okay to you for now.

Just a final question: in your experience, how important is z-scoring the simulations, and what are some pros/cons of doing it vs. not doing it?

@janfb (Contributor) commented Feb 16, 2021

OK, sounds good.

To finish the PR, I would ask you to rebase and squash your commits, maybe into 2 separate commits: one with all the changes related to the dataloader kwargs, and one with all the "unrelated" formatting changes.

Regarding the z-scoring: under normal circumstances, z-scoring never hurts and only helps the NN with learning. But it can hurt when the data has many strong outliers (it then squashes potentially important differences between small values), or when the data contains NaNs, like in your case (it then produces NaNs).
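The NaN point is easy to see directly: a single NaN makes the mean and standard deviation NaN, so the entire z-scored tensor becomes NaN.

```python
import torch

x = torch.tensor([1.0, 2.0, float("nan"), 4.0])
z = (x - x.mean()) / x.std()   # mean() and std() are NaN if any entry is NaN
print(z)                       # tensor([nan, nan, nan, nan])
```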

@janfb (Contributor) commented Feb 17, 2021

I had a quick look and tried to rebase your branch on main. But this turns out to be quite complicated.

Therefore, I think it would be easier if you make a new branch from the current state of main and then basically redo your changes to add the dataloader kwargs on top of that.
Then you can also try to separate those changes from the formatting changes, i.e., make a first commit with all the formatting, and then a commit with the dataloader kwargs.

Would that be possible for you?
