-
Notifications
You must be signed in to change notification settings - Fork 155
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
RatioEstimator abstraction #1097
Conversation
I added a test. it's extremely simple, but it checks that the estimated ratios are the correct size . (like |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I will follow up after Friday. |
@tomMoral I think this is done now. |
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #1097 +/- ##
==========================================
- Coverage 84.54% 75.61% -8.93%
==========================================
Files 95 96 +1
Lines 7576 7603 +27
==========================================
- Hits 6405 5749 -656
- Misses 1171 1854 +683
Flags with carried forward coverage won't be shown. Click here to find out more.
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A few extra comments but overall it looks good!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi Ben,
thanks a lot! The shape-PR for DensityEstimator
s is finally merged in #1066 so we can start pushing this along. See a few questions below.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for pushing this along! 👏
I think it makes sense that @michaeldeistler reviews this.
|
goals:
|
As commented in #1103 , please make sure the all relevant (non renaming related changes) from #1103 are moved to this PR (I think they are all in here anyways). Otherwise I think that this PR is almost done, no?
Thanks a lot! 🙏 |
I will no longer remove It looks like zukoflow expects that the batch_theta and batch_x have shape (batch, *shape). I used that convention to determine x_shape and theta_shape, as well. i.e. theta_shape = batch_theta.shape[0] and x_shape = batch_x.shape[0] I support sample_dim, and batch_dim by forcing the inputs to the RatioEstimator to have the same prefix. No broadcasting is allowed. When the unnormalized_log_ratio is computed, the prefix shape gets flattened into a single effective batch, then unflattened back to the prefix shape. |
if this passes the CI, all tests that check anything related to ratios are passing on my computer as well:
|
Right now |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great, thanks a lot @bkmi ! 🎉
What does this implement/fix? Explain your changes
It introduces a
RatioEstimator
abstraction. It will wrap neural networks that processx
andtheta
to estimate ratios. The base class is suggestive of how to create extensions that use alternative data types (dict
, etc.) while flexible enough to handle how embedding works and how embedded data is combined. Natural extensions include ratio estimators with specific network architectures such as transformers or convnets.This implements #992 which is the ratio specific issue for #1046. It also implements the ratio part of #957 and makes the documentation confirm to the shape goals of #1041.
Does this close any currently open issues?
Closes:
#992
Also #1036 since it eliminates a confusingly name and now-irrelevant class.
The others are more general and require more work on subjects related to ratios before closing.
Any relevant code examples, logs, error output, etc?
Nope
Any other comments?
There is a test that doesn't pass, namely #1090, but I think it is not because of my changes.
Checklist
Put an
x
in the boxes that apply. You can also fill these out after creatingthe PR. If you're unsure about any of them, don't hesitate to ask. We're here to
help! This is simply a reminder of what we are going to look for before merging
your code.
guidelines
with
pytest.mark.slow
.guidelines
main
(or there are no conflicts withmain
)