Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DensityEstimator.loss does not take sample_dim #1149

Merged
merged 1 commit into from
Apr 25, 2024
Merged

Conversation

michaeldeistler
Copy link
Contributor

@michaeldeistler michaeldeistler commented Apr 23, 2024

In #1066, we had defined that log_prob and loss have the same input and output shapes:

density_estimator.log_prob(input, condition)
input: (sample_input, batch_input, *event_shape_input)
condition: (batch_condition, *event_shape_condition)
returns: (sample_input, batch_input)
raises: batch_input != batch_condition

However, for .loss, we are now removing the sample_dim. Therefore, the .loss function now has the following signature:

input: (batch_input, *event_shape_input)
condition: (batch_condition, *event_shape_condition)
returns: (batch_input)
raises: batch_input != batch_condition

Checklist

Put an x in the boxes that apply. You can also fill these out after creating
the PR. If you're unsure about any of them, don't hesitate to ask. We're here to
help! This is simply a reminder of what we are going to look for before merging
your code.

  • I have read and understood the contribution
    guidelines
  • I agree with re-licensing my contribution from AGPLv3 to Apache-2.0.
  • I have commented my code, particularly in hard-to-understand areas
  • I have added tests that prove my fix is effective or that my feature works
  • I have reported how long the new tests run and potentially marked them
    with pytest.mark.slow.
  • New and existing unit tests pass locally with my changes
  • I performed linting and formatting as described in the contribution
    guidelines
  • I rebased on main (or there are no conflicts with main)
  • For reviewer: The continuous deployment (CD) workflow are passing.

@michaeldeistler michaeldeistler force-pushed the lossdensityestimator branch 3 times, most recently from 7283987 to e2a8e75 Compare April 24, 2024 07:47
Copy link

codecov bot commented Apr 24, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 77.01%. Comparing base (005aeac) to head (6ae9ede).

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1149      +/-   ##
==========================================
- Coverage   85.09%   77.01%   -8.09%     
==========================================
  Files          90       90              
  Lines        6649     6643       -6     
==========================================
- Hits         5658     5116     -542     
- Misses        991     1527     +536     
Flag Coverage Δ
unittests 77.01% <100.00%> (-8.09%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

Files Coverage Δ
sbi/inference/snle/mnle.py 85.00% <ø> (-8.48%) ⬇️
sbi/inference/snle/snle_base.py 93.61% <100.00%> (ø)
sbi/inference/snpe/snpe_base.py 89.02% <100.00%> (ø)
sbi/neural_nets/density_estimators/base.py 57.14% <ø> (ø)
.../neural_nets/density_estimators/categorical_net.py 98.03% <100.00%> (ø)
...nets/density_estimators/mixed_density_estimator.py 69.11% <100.00%> (ø)
sbi/neural_nets/density_estimators/nflows_flow.py 62.74% <100.00%> (ø)
sbi/neural_nets/density_estimators/zuko_flow.py 64.44% <100.00%> (ø)

... and 22 files with indirect coverage changes

Copy link
Contributor

@janfb janfb left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great! thanks for fixing this!
looks good, I added a couple of questions to clarify the changes.

You can run ruff check --fix sbi to fix the ruff errors.

@michaeldeistler michaeldeistler requested a review from janfb April 25, 2024 18:30
@janfb janfb merged commit afbd5e7 into main Apr 25, 2024
5 checks passed
@michaeldeistler michaeldeistler deleted the lossdensityestimator branch April 26, 2024 06:11
@bkmi bkmi mentioned this pull request Apr 26, 2024
8 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants