Checkpointing #273
Conversation
Hi @jan-matthis @janfb @Meteore, for my STG project I need the mechanism described above, so I would like to implement it in sbi soon. Please let me know if you have a preference for suggestion 1 or suggestion 2. To seed this discussion:
I would like us to list somewhere the state that we are carrying around and that cannot be managed as return values that get passed back into one round of inference. I mentioned something similar on the PR about external data, and I think it is a discussion worth having, possibly in a dedicated meeting. PS: not sure I understand what
Here's a list of the state:
In the example above,
@jan-matthis @janfb @Meteore After quite some discussion yesterday, we thought that it might be a good idea to change the API of multi-round inference. In the description of this PR, I outline one way to do it. Have a look and let me know what you think. The code for this is also "ready" (though only for SNPE and not documented yet), so, if you want, also have a look at that.
Had a first look at the PR, looking good! I guess SNLE, SNRE, tests and changes to `infer()` are still forthcoming.
`sbi/inference/posterior.py` (Outdated)

```diff
@@ -93,6 +94,16 @@ def __init__(
         # Correction factor for leakage, only applicable to SNPE-family methods.
         self._leakage_density_correction_factor = None

+    def focus_training_on(self, x) -> "NeuralPosterior":
```
I'm seeing no invocation of this method in the code; shouldn't it be called at some point?
The user has to call it after inference. See the API example in the PR description.
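For concreteness, a hypothetical sketch of that flow, reusing the placeholder names `infer`, `x_o`, and `num_simulations` from the snippets further down in this thread; the exact call sequence is an assumption, not the confirmed PR API:

```python
# Hypothetical usage sketch (assumed names, not the confirmed API): after one
# round of inference, the user focuses the posterior on the observation x_o
# and passes it back in as the proposal for the next round.
posterior = infer(num_simulations=200, proposal=None)      # round 1, prior as proposal
proposal = posterior.focus_training_on(x_o)                # called by the user after inference
posterior = infer(num_simulations=200, proposal=proposal)  # round 2, focused on x_o
```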
Ah, I see. I missed the edited version. How about sticking to method names that are unspecific to training/inference?
For example:

```python
# Single round inference with prior as proposal
posterior = infer(num_simulations=200, proposal=None)  # proposal=None is also default.
posteriors = [posterior]

# Run additional rounds with the last posterior as new proposal.
for round_ in range(1, num_rounds):
    posteriors.append(
        infer(num_simulations=200, proposal=posteriors[round_ - 1].set_default_x(x_o))
    )
```
Or shorter:

```python
posteriors = []
proposal = None
for _ in range(num_rounds):
    posterior = infer(num_simulations=200, proposal=proposal)
    proposal = posterior.set_default_x(x_o)
    posteriors.append(posterior)
```
More generally, I wonder if `infer()` should keep a `num_rounds` keyword and build such a loop internally.
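A rough sketch of what such an internal loop might look like; `infer_multiround` is a hypothetical name, and `infer`, `x_o`, and `num_rounds` follow the placeholder usage in the snippets above rather than a confirmed sbi signature:

```python
def infer_multiround(num_simulations, num_rounds, x_o):
    """Hypothetical wrapper that hides the proposal loop from the user."""
    posteriors = []
    proposal = None  # first round: the prior is used as the proposal
    for _ in range(num_rounds):
        posterior = infer(num_simulations=num_simulations, proposal=proposal)
        proposal = posterior.set_default_x(x_o)  # focus the next round on x_o
        posteriors.append(posterior)
    return posteriors
```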
Checkpointing and new API for multi-round
API

Main changes:
- `x_o` argument
- `_x_o_trained_on` attribute
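If the main change is that the observation is no longer passed as an `x_o` argument but kept as a `_x_o_trained_on` attribute on the posterior, a minimal sketch of that bookkeeping could look as follows. This is an assumed illustration that ties the attribute to `set_default_x` from the snippets above; the actual sbi implementation may differ:

```python
class NeuralPosterior:
    """Stripped-down sketch; only the attribute handling discussed above is shown."""

    def __init__(self):
        # Replaces the former x_o argument: the observation is kept as state.
        self._x_o_trained_on = None

    def set_default_x(self, x_o) -> "NeuralPosterior":
        # Remember the observation to focus on and return self so the call
        # can be chained directly as `proposal=posterior.set_default_x(x_o)`.
        self._x_o_trained_on = x_o
        return self
```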