Adding the ability to pass dataloader related kwargs in SNPE, SNRE and SNLE #435
Conversation
Many thanks for the PR, I will have a detailed look soon. One first thing: I noticed a lot of formatting changes, which is great. Would it be possible to move all the formatting into a separate commit? Then we can see more clearly what was done when, where, and why.
Looks great!
Regarding your question of how to pass the additional kwarg, e.g., from snpe_c to snpe_base: I added two comments at the specific locations.
sbi/inference/snle/snle_base.py (Outdated)
@@ -137,6 +138,8 @@ def train(
            estimator for the posterior from scratch each round.
        show_train_summary: Whether to print the number of epochs and validation
            loss after the training.
        dataloader_kwargs: Any additional kwargs to be passed to the training and
            validation dataloaders (like a collate_fn)
(like, e.g., a collate_fn)
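To make the mechanism under discussion concrete, here is a minimal sketch of how user-supplied dataloader kwargs could be merged with the defaults that train() already uses. The function name and the particular defaults are illustrative assumptions, not the actual sbi internals:

```python
# Hypothetical helper: merge user-supplied dataloader kwargs with the
# defaults the library already uses. Names and defaults are illustrative,
# not taken from the sbi codebase.

def build_loader_config(batch_size, indices, dataloader_kwargs=None):
    """Return keyword arguments for a DataLoader-style constructor."""
    defaults = dict(
        # Cap the batch size at the number of available training indices.
        batch_size=min(batch_size, len(indices)),
        drop_last=True,
    )
    # User-provided kwargs (e.g. a collate_fn) extend or override the defaults.
    defaults.update(dataloader_kwargs or {})
    return defaults
```

With this pattern, passing e.g. `dataloader_kwargs={"collate_fn": my_pad_fn}` adds the collate function without the user having to re-specify batch size or any other default.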
sbi/inference/snpe/snpe_base.py (Outdated)
@@ -192,6 +190,8 @@ def train(
            estimator for the posterior from scratch each round.
        show_train_summary: Whether to print the number of epochs and validation
            loss after the training.
        dataloader_kwargs: Any additional kwargs to be passed to the training and
            validation dataloaders (like a collate_fn)
see comment on formulation above.
sbi/inference/snpe/snpe_base.py (Outdated)
@@ -203,9 +203,6 @@ def train(
        max_num_epochs = 2 ** 31 - 1 if max_num_epochs is None else max_num_epochs
        # Load data from most recent round.
        theta, x, _ = self.get_simulations(self._round, exclude_invalid_x, False)
I think we fixed this already, which makes it look like you are based on an older main branch. Have you tried rebasing your local branch on the most recent version of our main?
sbi/inference/snre/snre_base.py (Outdated)
@@ -139,6 +140,8 @@ def train(
            samples.
        retrain_from_scratch_each_round: Whether to retrain the conditional density
            estimator for the posterior from scratch each round.
        dataloader_kwargs: Any additional kwargs to be passed to the training and
            validation dataloaders (like a collate_fn)
see comment above.
sbi/inference/snpe/snpe_base.py (Outdated)
@@ -166,6 +163,7 @@ def train(
        discard_prior_samples: bool = False,
        retrain_from_scratch_each_round: bool = False,
        show_train_summary: bool = False,
        **dataloader_kwargs: Any,
Couldn't you make this an explicit kwarg of train()? E.g., dataloader_kwargs: Optional[dict] = None, and then the same in the snpe_c train method?
In snpe_c line 108, add the same additional kwarg: dataloader_kwargs: Optional[dict] = None. I haven't tested it, just an idea.
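The suggestion above (an explicit `dataloader_kwargs: Optional[dict] = None` in both the subclass and base-class train() signatures) might be sketched as follows. The `del_entries` helper and class names here are stand-ins defined for this example, not the actual sbi code; only the forwarding pattern matters:

```python
# Sketch, under assumed names, of forwarding an explicit dataloader_kwargs
# from a subclass train() to a base-class train() by collecting the local
# arguments into a dict. TrainerBase/TrainerC and del_entries are
# hypothetical stand-ins for the real sbi classes and utilities.
from typing import Optional


def del_entries(d, entries=()):
    """Return a copy of dict d without the given keys."""
    return {k: v for k, v in d.items() if k not in entries}


class TrainerBase:
    def train(self, max_num_epochs=100, dataloader_kwargs: Optional[dict] = None):
        # In the real base class, dataloader_kwargs would reach the
        # DataLoader construction; here we just echo what arrived.
        return {
            "max_num_epochs": max_num_epochs,
            "dataloader_kwargs": dataloader_kwargs,
        }


class TrainerC(TrainerBase):
    def train(self, max_num_epochs=100, dataloader_kwargs: Optional[dict] = None):
        # Collect all local arguments except self and the implicit
        # __class__ cell (created by the zero-argument super() below),
        # then forward them to the base class unchanged.
        kwargs = del_entries(locals(), entries=("self", "__class__"))
        return super().train(**kwargs)
```

Because the kwarg is explicit in both signatures, it is picked up by the locals()-based forwarding without any `**kwargs` catch-all in the subclass.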
Hi there! Thanks for creating the PR! Just a quick heads-up: we just merged #431, which will generate merge conflicts. Please let us know if you need assistance in rebasing your branch :) Michael
Thanks for the heads-up @michaeldeistler and for the comments @janfb - I will merge the current main branch from the sbi repo and update my changes in the next couple of days!
@janfb @michaeldeistler However, there are two additional issues that I now run into, despite being able to equip the dataloaders with a […]. Now, using this setup with the […]
In the end, I was able to run the training by turning off […]. Let me know how you think I should go about dealing with both these issues (and also, let me know if you have any other thoughts on why I might have gotten -inf log prob when I turned […]).
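For context, a per-batch padding collate_fn of the kind being set up here might look like the following sketch. This is a pure-Python stand-in (a real version would stack torch tensors); the function name, pair layout, and pad value are assumptions for illustration:

```python
# Hypothetical collate_fn that pads variable-length simulations to the
# longest one in each batch. Pure-Python stand-in: a real implementation
# would return stacked torch tensors instead of lists.

def pad_collate(batch, pad_value=0.0):
    """batch: list of (theta, x) pairs, where x is a variable-length list."""
    max_len = max(len(x) for _, x in batch)
    thetas = [theta for theta, _ in batch]
    # Right-pad every x to the batch-wide maximum length.
    xs = [x + [pad_value] * (max_len - len(x)) for _, x in batch]
    return thetas, xs
```

A dataloader built with such a collate function sees only same-length batches, even though the stored simulations differ in length.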
Actually, I dug a bit more into the […]. I checked that the output of the […]
Hi @narendramukherjee and thanks for the update. I see the problems, e.g., the […]. At this stage, however, I think this is quite a special case which is unlikely to occur for too many users. So I suggest we find a special solution for your problem, not a general one in […].

I suggest keeping the changes you suggest in this PR, i.e., the possibility to pass the dataloader kwargs. Thus, we will finish reviewing this PR and merge it, and work on a solution for your problem separately. OK?
I can think of at least two options now. […] Does that make sense?
Hi @janfb: Thank you for those comments, yes, they do make sense. For now, I am also planning to just pad the whole set of simulations to a single fixed length, and that case works without a collate_fn. I need some results for a presentation next week, so I will just go with this approach for now and come back to the padding-in-batches-with-collate_fn idea in a few days time :) So yeah, let's evaluate this PR and merge it if it looks okay to you for now. Just a final question: in your experience, how important do you think the z-scoring of the simulations is, and what are some pros/cons of doing it vs. not doing it?
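The interim approach described here (pad the whole set of simulations to one fixed length up front, so the default dataloader needs no collate_fn) could be sketched as below. The function name and pad value are illustrative assumptions:

```python
# Hypothetical one-shot preprocessing step: pad every simulation to the
# globally longest length before handing the data to the inference
# object, so the default dataloader (no collate_fn) can batch them.

def pad_all(simulations, pad_value=0.0):
    """simulations: list of variable-length lists; returns equal-length lists."""
    max_len = max(len(x) for x in simulations)
    return [x + [pad_value] * (max_len - len(x)) for x in simulations]
```

The trade-off versus a per-batch collate_fn is that every batch carries the padding of the single longest simulation in the whole dataset, rather than only the longest one in that batch.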
OK, sounds good. To finish the PR, I would ask you to rebase and squash your commits, maybe into 2 separate commits: one with all the changes related to the dataloader kwargs, and one with all the "unrelated" formatting changes.

Regarding the z-scoring: under normal circumstances z-scoring never hurts, but only helps the NN with learning. But it can hurt when the data has many strong outliers, since it then squashes potentially important differences in small values, or when the data contains NaNs like in your case, since it then produces NaNs.
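The NaN caveat above is easy to see in a tiny sketch of z-scoring (standardizing to zero mean and unit standard deviation). This is a generic illustration of the idea, not sbi's implementation; note how a single NaN in the input poisons the mean and therefore every standardized value:

```python
# Generic z-scoring sketch (not the sbi implementation). A single NaN in
# the input makes the mean NaN, which in turn makes every output NaN -
# the failure mode described in the comment above.
import math


def z_score(values):
    """Standardize a list of floats to zero mean and unit std."""
    mean = sum(values) / len(values)
    std = math.sqrt(sum((v - mean) ** 2 for v in values) / len(values))
    return [(v - mean) / std for v in values]
```

So if NaNs are used as padding values, z-scoring has to be disabled (or the NaNs masked out) before training.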
I had a quick look and tried to rebase your branch on […]. Therefore, I think it would be easier if you make a new branch from the current state of […]. Would that be possible for you?
NOTE: The majority of the code changes are formatting-related changes, which came up when I ran black and isort on the repo (as instructed in the contribution guidelines).
The real code changes are in the following files:
inference/snpe/snpe_base.py
inference/snre/snre_base.py
inference/snle/snle_base.py
In each of these files, I have added a **dataloader_kwargs parameter to be passed to the train method. I am imagining this to be "additional" args to be passed to the dataloader, over and above the ones that are already being used to build the train and validation dataloaders. In my specific case, this would be a collate_fn. I thought that it would be a waste to make the user specify all the args to the dataloader here, as most of the default args being used right now are great - any additional args can be passed using **dataloader_kwargs.

I tried testing this code change with SNPE - however, snpe_c.py has a relatively complicated way of calling the train method of snpe_base using kwargs built here: https://github.com/mackelab/sbi/blob/main/sbi/inference/snpe/snpe_c.py#L150

So when I pass any dataloader related kwargs to snpe_c.train, those aren't recognized in the train method's signature. Since there's a complicated way in which snpe_base.train is being called from here, I didn't want to mess with it before getting your opinion first.

Let me know how I should proceed, and if you have any other suggestions/questions about my approach!
closing #435