Functions for filtering out values being outside the prior in flow.py #206
base: main
Conversation
1. `_get_log_prior_dict`: Constructs a dictionary for the priors (also adds boundaries for the missing ones).
2. `filter_descaled_parameters`: Filters the descaled posteriors based on the imposed prior and prints the number of discarded values (both total and per parameter).
3. Updated `test_step` to call `filter_descaled_parameters`.
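Based on this description, the filtering step might be sketched roughly as follows (a standalone approximation; `bounds` and `param_names` are hypothetical stand-ins for the dictionary built by `_get_log_prior_dict` and the sampler's parameter ordering):

```python
import torch


def filter_descaled_parameters(descaled, bounds, param_names):
    """Drop samples that fall outside [low, high] for any parameter.

    descaled: (n_samples, n_params) tensor of descaled posterior samples.
    bounds: hypothetical dict mapping parameter name -> (low, high).
    param_names: parameter names in column order of `descaled`.
    """
    valid_idxs = torch.ones(descaled.shape[0], dtype=torch.bool)
    for idx, param in enumerate(param_names):
        low, high = bounds[param]
        valid_idxs &= (descaled[:, idx] >= low) & (descaled[:, idx] <= high)
        # cumulative count of samples invalidated so far
        print(f"Discarded samples[{param}]: {(~valid_idxs).sum().item()}")
    print(f"Total discarded samples: {(~valid_idxs).sum().item()}/{descaled.shape[0]}")
    return descaled[valid_idxs]
```

Note the per-parameter counts here are cumulative, since `valid_idxs` accumulates across parameters.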
```python
# 'phi': Uniform(low: -3.140000104904175, high: 3.140000104904175)
log_prior_dict[param] = getattr(self.trainer.datamodule.waveform_sampler, param)
# add low and high boundaries (for some parameters these values are missing)
log_prior_dict['distance'].low = 100
```
@asasli I'm wondering if, instead of using the prior ranges, we can simply call `prior.prob(sample)` and discard the elements that have zero probability. That way, this functionality is robust to changes in the prior ranges, and we don't have to add these prior-specific `low` and `high` attributes. Do you think that would work?

Alternatively, we could add `low` and `high` attributes to each of the prior classes.
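The zero-probability check suggested here could look something like the following (a sketch assuming the priors behave like `torch.distributions` objects exposing `log_prob`, where zero probability shows up as a log-probability of `-inf`):

```python
import torch
from torch.distributions import Uniform


def filter_by_prior_prob(descaled, priors, param_names):
    """Keep only samples with nonzero probability under every prior."""
    valid = torch.ones(descaled.shape[0], dtype=torch.bool)
    for idx, param in enumerate(param_names):
        # prob == 0 is equivalent to log_prob == -inf
        logp = priors[param].log_prob(descaled[:, idx])
        valid &= torch.isfinite(logp)
    return descaled[valid]


# e.g. a hypothetical distance prior; validate_args=False makes
# out-of-support values return -inf instead of raising
priors = {"distance": Uniform(100.0, 3000.0, validate_args=False)}
```

This avoids hard-coding any `low`/`high` values and stays correct if the prior ranges change.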
amplfi/train/models/flow.py
Outdated
```python
valid_idxs &= (descaled[:, idx] >= low) & (descaled[:, idx] <= high)
discarded_count = (~valid_idxs).sum().item()
num_discarded += discarded_count
print(f"Discarded samples[{param}]: {discarded_count}")
```
Instead of `print`, you should be able to use `self._logger.info` to use the model's logger object.
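For reference, the suggested change swaps the `print` calls for the module's logger, roughly like this (a sketch using Python's stdlib `logging` as a stand-in for the model's `_logger` attribute):

```python
import io
import logging

# stand-in for the model's self._logger attribute
logger = logging.getLogger("amplfi.flow")
stream = io.StringIO()
logger.addHandler(logging.StreamHandler(stream))
logger.setLevel(logging.INFO)

param, discarded_count = "distance", 12
# instead of: print(f"Discarded samples[{param}]: {discarded_count}")
logger.info(f"Discarded samples[{param}]: {discarded_count}")
```

Unlike `print`, the logger respects configured log levels and handlers, so the messages end up in the same place as the rest of the training logs.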
amplfi/train/models/flow.py
Outdated
```python
discarded_count = (~valid_idxs).sum().item()
num_discarded += discarded_count
print(f"Discarded samples[{param}]: {discarded_count}")
print(f"Total discarded samples: {num_discarded}/{descaled.shape[0]}")
```
Same here
amplfi/train/models/flow.py
Outdated
```diff
@@ -133,6 +195,7 @@ def test_step(self, batch, _):
     )
     descaled = self.trainer.datamodule.scale(samples, reverse=True)
     parameters = self.trainer.datamodule.scale(parameters, reverse=True)
+    descaled = self.filter_descaled_parameters(descaled)  # filter out samples outside of prior boundaries
```
In general, we typically add comments above the relevant code line in this repo. Of course, nothing wrong with adding them inline, but good to keep style consistent
Corrected the case where `trainer` is not attached.

Replaced `print` with `self._logger.info`.

Fixed the printed numbers!
I would be interested to see whether setting `validate_args=False` in the priors resolves most of the issue you are seeing.
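For context, here is what `validate_args` changes for torch-style distributions (a minimal demonstration with a `Uniform`; whether the AMPLFI priors behave identically is an assumption):

```python
import torch
from torch.distributions import Uniform

# with argument validation on, evaluating log_prob on an
# out-of-support value raises a ValueError
strict = Uniform(100.0, 3000.0, validate_args=True)
try:
    strict.log_prob(torch.tensor(50.0))
    raised = False
except ValueError:
    raised = True

# with validate_args=False, the same call returns -inf (probability 0),
# which downstream filtering can handle gracefully instead of crashing
lenient = Uniform(100.0, 3000.0, validate_args=False)
out_of_support = lenient.log_prob(torch.tensor(50.0))
```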
```python
if trainer is not None:
    self.trainer = trainer
```
This step should not be necessary. (I know we talked about this; mentioning here for my own sake).
The final descaled posteriors will be within the prior boundaries.