Dealing with custom DataLoader in PosteriorEstimator.train
#415
Hi @narendramukherjee and thanks for creating the issue! Yes, here our API seems a bit stiff. It would be great if you made a PR to fix it! We think that an additional argument to `train` would be a good way to handle this.

Regarding your other point -- you are absolutely right, this seems to be a leftover from a refactoring we did recently. We will remove it, thanks!
Hi @janfb, absolutely, your suggestion makes sense and I will work on it and file a PR. It might take me a few days to get to doing this though, and I shall tag you as soon as it is ready :D
Great, @narendramukherjee! No worries, take your time and let us know if you need help.
@janfb Thanks for your previous comments - I looked into this a bit more today, and I feel like the API expects the simulation outputs (x) to be a multi-dimensional tensor, which is causing several issues (e.g., while validating theta and x). Let me explain my problem scenario a bit more - my simulations generate time-series which have variable lengths (right now, they vary from 20 to 3000). Each time-series itself is a torch Tensor whose length differs from simulation to simulation. So, right now, I have the simulations in a list, where each element of the list is one such variable-length torch Tensor. I wrote a custom dataset to hold these.
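A simplified sketch of what it looks like (the class name and internals here are illustrative rather than the exact code):

```python
import torch
from torch.utils.data import Dataset

class SimulationDataset(Dataset):
    """Pairs each parameter vector theta with one variable-length simulation output x."""

    def __init__(self, theta, xs):
        # theta: Tensor of shape (num_simulations, num_parameters)
        # xs: list of num_simulations Tensors, each with its own sequence length
        assert theta.shape[0] == len(xs)
        self.theta = theta
        self.xs = xs

    def __len__(self):
        return len(self.xs)

    def __getitem__(self, idx):
        # Each item keeps its original length; padding happens later, in the collate_fn.
        return self.theta[idx], self.xs[idx]
```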
And a `collate_fn` to go with it:
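Roughly like this (again just a sketch; `torch.nn.utils.rnn.pad_sequence` pads every sequence in a batch to the length of the longest one in that batch):

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

def pad_collate(batch):
    # batch: list of (theta, x) pairs coming from the Dataset above
    thetas, xs = zip(*batch)
    theta_batch = torch.stack(thetas)
    # Pad each x to the length of the longest sequence in this particular batch.
    x_batch = pad_sequence(list(xs), batch_first=True, padding_value=0.0)
    return theta_batch, x_batch

# loader = DataLoader(dataset, batch_size=50, shuffle=True, collate_fn=pad_collate)
```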
So basically, each batch would yield sequences padded only to the length of the longest simulation in that batch.

Wondering what your thoughts are on this issue, and how I can approach it? Is there something different I can do (and I am happy to work on any API changes that can accommodate this)? I don't really want to pad all my sequences to the size of the biggest one across all simulations, as their sizes span a couple of orders of magnitude, and that is why I thought of doing the padding on a batch-by-batch level. Thanks again for all your help, and for putting sbi together in the first place :)
I see the problem. Changing the requirement that x is a Tensor would be a major change, which we would prefer to avoid. My first approach would indeed be to pad all your simulations to the same length using, e.g., NaNs, and then to use the `collate_fn` to deal with that. I think this will not really be a memory issue unless you have millions of simulations, will it?
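Something along these lines (just a sketch of the idea, with toy data standing in for your simulations):

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Toy stand-ins for the real data: 100 parameter vectors and 100 variable-length time-series.
theta = torch.randn(100, 3)
xs = [torch.randn(int(torch.randint(20, 300, (1,)))) for _ in range(100)]

# Pad everything to the global maximum length with NaNs so that x becomes a single Tensor.
max_len = max(x.shape[0] for x in xs)
x_padded = torch.full((len(xs), max_len), float("nan"))
for i, x in enumerate(xs):
    x_padded[i, : x.shape[0]] = x

dataset = TensorDataset(theta, x_padded)

def nan_collate(batch):
    thetas, xs_batch = zip(*batch)
    theta_batch = torch.stack(thetas)
    x_batch = torch.stack(xs_batch)
    # Trim the batch to its longest actual sequence and replace the remaining NaNs with zeros.
    lengths = (~torch.isnan(x_batch)).sum(dim=1)
    x_batch = x_batch[:, : int(lengths.max())]
    x_batch = torch.where(torch.isnan(x_batch), torch.zeros_like(x_batch), x_batch)
    return theta_batch, x_batch

loader = DataLoader(dataset, batch_size=50, collate_fn=nan_collate)
```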
@janfb That's a good idea - I didn't think of padding using NaNs and then replacing those inside the `collate_fn`. Then I also probably don't need a custom Dataset and can use a TensorDataset, as is currently the case in `train`.
@janfb I have made an initial PR about this issue in #435. I have run into some problems, as (I think) we need to add ...
Hi,
First of all, thanks a ton to the Macke lab team for building this toolbox. It has made it really easy to get started with using SBI in so many application domains! :)
I wanted to check in, as right now the NN data loading process is fixed inside `PosteriorEstimator.train`. I have a situation where I am using a 1D ConvNet as an embedding network for time-series data that are unequal in length (a rough sketch of the kind of net I mean is below). So I need to add a `collate_fn` to a custom DataLoader to pad the unequal-length sequences in each batch before feeding them into the embedding network. The output from the embedding net into the density estimator will be fixed length, but the input into the embedding net can be variable length. The only way I can think of doing this right now is to write my own custom implementation of `PosteriorEstimator.train` (where only the data loading part changes), so I wanted to ask:

a) Is there any other way of doing this?

b) If not, would there be interest if I filed a PR where one could pass a custom Dataset/DataLoader to the `train` method?
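For concreteness, the embedding net I have in mind is roughly of this form (the layer sizes and the adaptive-pooling trick for collapsing the variable time axis are just illustrative, not my exact network):

```python
import torch
from torch import nn

class ConvEmbedding(nn.Module):
    """1D ConvNet mapping (padded) variable-length sequences to a fixed-size embedding."""

    def __init__(self, n_channels=1, embedding_dim=16):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv1d(n_channels, 32, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(32, 32, kernel_size=5, padding=2),
            nn.ReLU(),
        )
        # Adaptive pooling collapses the time axis, so the output size does not depend
        # on how long the (padded) input sequences in a given batch are.
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(32, embedding_dim)

    def forward(self, x):
        # x: (batch, n_channels, seq_len), with seq_len varying from batch to batch
        h = self.conv(x)
        h = self.pool(h).squeeze(-1)
        return self.fc(h)
```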
Finally, I had a very basic question about the `train` method - I am unable to understand the utility of this line when the `theta` and `x` get re-written right afterwards here. I am probably missing something very basic, and would love to have clarification from you :)