Functions for filtering out values being outside the prior in flow.py #206

Open
wants to merge 5 commits into base: main

@asasli asasli commented Feb 7, 2025

  1. `_get_log_prior_dict`: Constructs a dictionary of the priors (and adds boundaries for the priors that are missing them).
  2. `filter_descaled_parameters`: Filters the descaled posterior samples against the imposed prior and prints the number of discarded values (both in total and per parameter).
  3. Updated `test_step` to call `filter_descaled_parameters`, so the final descaled posteriors lie within the prior boundaries.
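
For readers following the thread, here is a minimal sketch of the kind of bounds-based filtering described above. It is not the PR's implementation: the function name `filter_by_bounds`, the `bounds` dictionary, and the upper distance bound are hypothetical, and the columns of `descaled` are assumed to follow the dictionary's key order. Each parameter's count comes from its own mask, so the per-parameter numbers do not depend on loop order, and the total counts rows rather than summing the per-parameter counts.

```python
import torch


def filter_by_bounds(descaled, bounds):
    """Drop rows of `descaled` that fall outside any (low, high) interval.

    `bounds` maps parameter name -> (low, high); columns of `descaled`
    are assumed to be in the same order as the keys (an assumption).
    """
    valid = torch.ones(descaled.shape[0], dtype=torch.bool)
    per_param = {}
    for idx, (param, (low, high)) in enumerate(bounds.items()):
        in_range = (descaled[:, idx] >= low) & (descaled[:, idx] <= high)
        # count from this parameter's own mask, not the running one
        per_param[param] = int((~in_range).sum())
        valid &= in_range
    # a row failing several parameters is counted once in the total
    total = int((~valid).sum())
    return descaled[valid], per_param, total


# illustrative values only; the upper distance bound is made up
bounds = {"phi": (-3.14, 3.14), "distance": (100.0, 3000.0)}
descaled = torch.tensor(
    [[0.0, 500.0], [4.0, 50.0], [1.0, 50.0], [-3.0, 2999.0]]
)
filtered, per_param, total = filter_by_bounds(descaled, bounds)
```

In this toy input the second row violates both bounds, so the per-parameter counts add up to 3 while only 2 rows are actually discarded.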

# 'phi': Uniform(low: -3.140000104904175, high: 3.140000104904175)
log_prior_dict[param] = getattr(self.trainer.datamodule.waveform_sampler, param)
# add low and high boundaries (for some parameters these values are missing)
log_prior_dict['distance'].low = 100
Collaborator

@asasli I'm wondering whether, instead of using the prior ranges, we could simply call `prior.prob(sample)` and discard the elements that have zero probability. That way, this functionality is robust to changes in the prior ranges, and we don't have to add these prior-specific `low` and `high` attributes. Do you think that would work?

Alternatively, we could add low and high attributes to each of the prior classes.
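
A rough sketch of the first suggestion, under the assumption that the priors are `torch.distributions` objects (the `Uniform(low: ..., high: ...)` repr in the diff points that way, but it is an assumption). Those distributions expose `log_prob` rather than `prob`, so the sketch treats a non-finite log-probability as "zero probability". The helper name `filter_by_prior` and the distance upper bound are hypothetical.

```python
import torch
from torch.distributions import Uniform


def filter_by_prior(samples, priors):
    """Keep rows whose prior log-probability is finite for every parameter.

    `priors` maps parameter name -> distribution; columns of `samples`
    are assumed to follow the key order (an assumption).
    """
    keep = torch.ones(samples.shape[0], dtype=torch.bool)
    for idx, prior in enumerate(priors.values()):
        keep &= torch.isfinite(prior.log_prob(samples[:, idx]))
    return samples[keep], keep


# validate_args=False so out-of-support values give -inf instead of raising
priors = {
    "phi": Uniform(-3.14, 3.14, validate_args=False),
    "distance": Uniform(100.0, 3000.0, validate_args=False),
}
samples = torch.tensor(
    [[0.0, 500.0], [4.0, 500.0], [1.0, 50.0], [-3.0, 2999.0]]
)
filtered, keep = filter_by_prior(samples, priors)
```

Because the check goes through the distribution itself, no `low` or `high` attributes need to be attached by hand; any prior whose `log_prob` returns `-inf` outside its support would work the same way.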

valid_idxs &= (descaled[:, idx] >= low) & (descaled[:, idx] <= high)
discarded_count = (~valid_idxs).sum().item()
num_discarded += discarded_count
print(f"Discarded samples[{param}]: {discarded_count}")
Collaborator

Instead of `print`, you should be able to call `self._logger.info`, which goes through the model's logger object.
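
As an illustration of that change, here is a small sketch using the standard-library `logging` module. The repo's `self._logger` attribute is not reproduced here; the logger name `"flow"` and the helper `report_discarded` are hypothetical stand-ins.

```python
import logging

# hypothetical stand-in for the model's own logger object
logger = logging.getLogger("flow")


def report_discarded(per_param, total, n_samples, logger=logger):
    """Log discarded-sample counts at INFO level instead of printing them."""
    for param, count in per_param.items():
        logger.info("Discarded samples[%s]: %d", param, count)
    logger.info("Total discarded samples: %d/%d", total, n_samples)


logging.basicConfig(level=logging.INFO)
report_discarded({"phi": 1, "distance": 2}, total=2, n_samples=4)
```

Routing the messages through a logger lets the verbosity and destination be configured in one place, which `print` does not allow.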

discarded_count = (~valid_idxs).sum().item()
num_discarded += discarded_count
print(f"Discarded samples[{param}]: {discarded_count}")
print(f"Total discarded samples: {num_discarded}/{descaled.shape[0]}")
Collaborator

Same here

@@ -133,6 +195,7 @@ def test_step(self, batch, _):
)
descaled = self.trainer.datamodule.scale(samples, reverse=True)
parameters = self.trainer.datamodule.scale(parameters, reverse=True)
descaled = self.filter_descaled_parameters(descaled) # filter out samples outside of prior boundaries
Collaborator

In this repo we typically put comments on the line above the relevant code. There is nothing wrong with adding them inline, but it's good to keep the style consistent.

not attached `trainer` corrected
replace `print` with `self._logger.info`
Fixing the printed numbers!
Collaborator

@deepchatterjeeligo deepchatterjeeligo left a comment

I would be interested to see whether setting validate_args=False in the priors resolves most of the issue you are seeing.
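
To make the `validate_args` point concrete, a short sketch assuming the priors behave like `torch.distributions.Uniform` (the repo's own prior classes may differ): with validation enabled, evaluating `log_prob` outside the support raises a `ValueError`, while with `validate_args=False` it returns `-inf`.

```python
import torch
from torch.distributions import Uniform

value = torch.tensor(4.0)  # outside the [-3.14, 3.14] range

# with validation on, an out-of-support value raises
strict = Uniform(-3.14, 3.14, validate_args=True)
try:
    strict.log_prob(value)
    raised = False
except ValueError:
    raised = True

# with validation off, the same call returns -inf
lenient = Uniform(-3.14, 3.14, validate_args=False)
log_p = lenient.log_prob(value)
```

If the errors seen downstream come from this validation step, turning it off would let out-of-prior samples pass through with zero probability rather than aborting the run.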

Comment on lines +131 to +132
if trainer is not None:
self.trainer = trainer
Collaborator

This step should not be necessary. (I know we talked about this; mentioning here for my own sake).

3 participants