add size invariant iid embedding nets, tests. #808
Conversation
Thanks! A few points below. I'd like to think a bit about this, so let's maybe not merge it today?
Let's discuss again whether we can avoid the loop for varying trial numbers.
Codecov Report
```diff
@@            Coverage Diff             @@
##             main     #808      +/-   ##
==========================================
+ Coverage   74.76%   74.81%   +0.05%
==========================================
  Files          80       80
  Lines        6190     6191       +1
==========================================
+ Hits         4628     4632       +4
+ Misses       1562     1559       -3
```
This is working accurately for a Gaussian iid example with up to 100 trials, trained with a varying number of trials. Thanks to @manuelgloeckler's input, the forward pass is performed batched as well, by masking the NaN entries.
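A minimal sketch of the batched masking idea (illustrative only; `trial_net` and the function name are hypothetical, not this PR's API): embed all padded trials in one pass, zero out the outputs of NaN-padded trials, and aggregate over the valid ones.

```python
import torch

def masked_embedding_forward(x, trial_net):
    # x: (batch, max_num_trials, dim_x); missing trials are NaN-padded.
    is_valid = ~torch.isnan(x).any(dim=-1)        # (batch, max_num_trials)
    x_filled = torch.nan_to_num(x, nan=0.0)       # make the input safe for the net
    e = trial_net(x_filled)                       # (batch, max_num_trials, dim_embed)
    e = e * is_valid.unsqueeze(-1)                # zero out padded trials
    # Sum-aggregation is unaffected by the zeroed entries; divide by the
    # number of valid trials per batch element to get a mean instead.
    return e.sum(dim=1)

trial_net = torch.nn.Linear(2, 8)
x = torch.randn(4, 10, 2)
x[0, 5:] = float("nan")                           # first element has only 5 valid trials
print(masked_embedding_forward(x, trial_net).shape)  # torch.Size([4, 8])
```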
Hey,
Looks good. One minor comment: I like the option to pass a custom aggregation function; currently, it works perfectly for `torch.sum`. Yet for others like `torch.max`, `torch.min`, or `torch.median` it would compute a slightly different value than expected, since we substitute the invalid outputs with zero. We should perhaps add this to the docstring, so that the user can adjust for this explicitly.
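To illustrate the caveat (a standalone sketch, not code from this PR): zero-substituted entries leave `torch.sum` unchanged but corrupt order statistics such as the max.

```python
import torch

e = torch.tensor([[-1.0, -2.0, 0.0, 0.0]])        # two valid embeddings, two zero-padded
valid = torch.tensor([[True, True, False, False]])

print(e.sum(dim=1))          # tensor([-3.]) -- correct, padding contributes nothing
print(e.max(dim=1).values)   # tensor([0.])  -- wrong, max over valid entries is -1

# A possible fix for order statistics: mask invalid entries with -inf before max.
e_masked = e.masked_fill(~valid, float("-inf"))
print(e_masked.max(dim=1).values)  # tensor([-1.]) -- correct
```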
This is getting really awesome, I like the way it is implemented now!
As far as I understand, this does not work if `x` has a `NaN` value that is not just due to missing trials. We should spell this out clearly in the docstring of this embedding net. We could even add a check which tests whether there are `x` for which only some summary stats are `NaN` and raise an error in this case (I'm also happy to leave this as a TODO in the code for now; see the sketch below).
Regarding my comments: I'd be happy to have a quick call if anything is unclear.
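One possible shape for the suggested check (an illustrative sketch, assuming the NaN-padding convention above; the function name is hypothetical):

```python
import torch

def assert_no_partial_nans(x):
    # x: (batch, max_num_trials, dim_x). NaN-padding marks whole missing
    # trials, so a trial where only *some* summary stats are NaN is an error.
    nan_mask = torch.isnan(x)
    partial = nan_mask.any(dim=-1) & ~nan_mask.all(dim=-1)
    if partial.any():
        raise ValueError(
            "Some x have NaNs in only part of their summary statistics; "
            "NaNs are supported only as padding for entire missing trials."
        )
```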
I added asserts to catch NaNs in standardizing nets. This will inform users who want to use `NaN`s to encode a varying number of trials that they have to turn off z-scoring and set `exclude_invalid_x=False`.
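Roughly, such an assert might look like this (a sketch; the actual standardizing net in `sbi` may differ):

```python
import torch
from torch import nn

class Standardize(nn.Module):
    """z-scores inputs; refuses NaNs so users get a clear error message."""

    def __init__(self, mean, std):
        super().__init__()
        self.register_buffer("mean", mean)
        self.register_buffer("std", std)

    def forward(self, x):
        assert not torch.isnan(x).any(), (
            "x contains NaNs. To encode a varying number of trials with "
            "NaNs, turn off z-scoring and set exclude_invalid_x=False."
        )
        return (x - self.mean) / self.std
```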
goal: have an embedding net for NPE that can handle a varying number of iid trials, i.e., learn how the posterior changes as we change the number of trials.

problem: our training procedure assumes that `x` is a tensor with fixed dimensions, e.g., the trial dimension must be the same for all training data points.

solution: given a training data set with a varying number of trials where the maximum number of trials is `max_num_trials`, pad all `x`s with a smaller number of trials with `NaN`s such that `x.shape = (num_thetas, max_num_trials, dim_x)`. Adapt the `PermutationInvariantEmbeddingNet` such that it detects the `NaN`s and applies the `TrialEmbedding` only to the valid entries (using a loop over the batch; a padding sketch follows below).

functional tests: seems to work fine for a Gaussian example.
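A minimal sketch of the padding step (illustrative; `pad_with_nans` is not this PR's API):

```python
import torch

def pad_with_nans(xs, dim_x):
    # xs: list of tensors of shape (num_trials_i, dim_x), num_trials_i varying.
    max_num_trials = max(x.shape[0] for x in xs)
    padded = torch.full((len(xs), max_num_trials, dim_x), float("nan"))
    for i, x in enumerate(xs):
        padded[i, : x.shape[0]] = x
    return padded  # (num_thetas, max_num_trials, dim_x)

xs = [torch.randn(n, 2) for n in (3, 7, 5)]
print(pad_with_nans(xs, dim_x=2).shape)  # torch.Size([3, 7, 2])
```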
Questions: allow `x` to be a list? Treat `x` as a `torch.utils.data.Dataset` from the very start in `append_simulations`?