Checkpointing #273

Merged
merged 14 commits into from
Jul 28, 2020

Conversation

michaeldeistler
Contributor

@michaeldeistler michaeldeistler commented Jul 23, 2020

Checkpointing and new API for multi-round

API

# Run additional rounds with the last posterior as new proposal.
proposal = None
posteriors = []
for round_ in range(num_rounds):
    posterior = infer(num_simulations=200, proposal=proposal)
    proposal = posterior.set_default_x(x_o)
    posteriors.append(posterior)

Main changes

  • no more x_o argument
  • no more _x_o_trained_on attribute
  • proposal needs to be passed as a function only in theta
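
To illustrate the last point, here is a minimal sketch, not sbi code: `ToyPosterior` and its Gaussian density are hypothetical stand-ins, and only the `set_default_x` name comes from the example above. Fixing `x` turns a conditional density into a function of `theta` alone, which is the form a proposal needs.

```python
import math
from typing import Optional


class ToyPosterior:
    """Hypothetical stand-in for a conditional density p(theta | x)."""

    def __init__(self) -> None:
        self.default_x: Optional[float] = None

    def set_default_x(self, x: float) -> "ToyPosterior":
        # Fix the observation and return self so the call can be chained.
        self.default_x = x
        return self

    def log_prob(self, theta: float, x: Optional[float] = None) -> float:
        # Fall back to the default observation when no x is given.
        x = self.default_x if x is None else x
        if x is None:
            raise ValueError("Pass x or call set_default_x() first.")
        # Toy model (assumption): theta | x ~ Normal(mean=x, std=1).
        return -0.5 * (theta - x) ** 2 - 0.5 * math.log(2 * math.pi)


# After fixing x, the proposal can be evaluated with theta alone.
proposal = ToyPosterior().set_default_x(1.5)
log_p = proposal.log_prob(1.5)
```

Returning `self` from `set_default_x` is what allows `proposal = posterior.set_default_x(x_o)` to be written on one line.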

@michaeldeistler michaeldeistler added the enhancement (New feature or request) and API changes (This impacts the public API of the project, e.g. inference class) labels Jul 23, 2020
@michaeldeistler michaeldeistler self-assigned this Jul 23, 2020
@michaeldeistler
Contributor Author

Hi @jan-matthis @janfb @Meteore

for my STG project, I need the mechanism described above, so I would like to implement it in sbi soon. Please let me know if you have a preference for suggestion 1 or suggestion 2.

To seed this discussion:
I prefer suggestion 1 because:
a) it does not require an additional argument to __call__()
b) having a .continue() function makes it very explicit to the user what is happening.

@alvorithm
Contributor

> Hi @jan-matthis @janfb @Meteore
>
> for my STG project, I need the mechanism described above, so I would like to implement it in sbi soon. Please let me know if you have a preference for suggestion 1 or suggestion 2.
>
> To seed this discussion:
> I prefer suggestion 1 because:
> a) it does not require an additional argument to __call__()
> b) having a .continue() function makes it very explicit to the user what is happening.

I would like us to list somewhere the state that we carry around and that cannot be managed as return values passed back into a round of inference. I mentioned something similar on the PR about external data, and I think it is a discussion worth having, possibly in a dedicated meeting.

PS. not sure I understand what start_new_round is doing and how it interacts with num_rounds.

@michaeldeistler
Contributor Author

Here's a list of the state:

  • _theta_roundwise
  • _x_roundwise
  • _prior_masks
  • _data_round_index
  • _posterior
  • _model_bank
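
One way to make this state explicit is to bundle it in a single container that can be serialized for checkpointing. The sketch below is only an illustration: `InferenceState`, `to_bytes` and `from_bytes` are hypothetical names, the field names are taken from the list above, and the field types are guesses.

```python
import pickle
from dataclasses import asdict, dataclass, field
from typing import Any, List, Optional


@dataclass
class InferenceState:
    """Hypothetical container; field types are assumptions."""

    _theta_roundwise: List[Any] = field(default_factory=list)
    _x_roundwise: List[Any] = field(default_factory=list)
    _prior_masks: List[Any] = field(default_factory=list)
    _data_round_index: List[int] = field(default_factory=list)
    _posterior: Optional[Any] = None
    _model_bank: List[Any] = field(default_factory=list)


def to_bytes(state: InferenceState) -> bytes:
    # Serialize the plain-dict view of the state.
    return pickle.dumps(asdict(state))


def from_bytes(blob: bytes) -> InferenceState:
    # Rebuild the container from the pickled dict.
    return InferenceState(**pickle.loads(blob))


state = InferenceState()
state._theta_roundwise.append([0.1, 0.2])
state._data_round_index.append(0)
restored = from_bytes(to_bytes(state))
```

Writing the bytes to disk after each round would give a simple checkpoint; whether every field (for example a trained network) pickles cleanly depends on the objects stored.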

@michaeldeistler
Contributor Author

michaeldeistler commented Jul 23, 2020

In the example above, start_new_round=False and hence the next simulations will still come from the same distribution as in the round before, in this case from the prior. If it were True, we would start a new round, i.e. a second round and hence simulate from the posterior.

num_rounds simply indicates how many rounds are being run in the current call to __call__() or .continue(), respectively.
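
A toy sketch of these semantics, not the sbi implementation: `ToyInference` and `run` are hypothetical, and it is an assumption that with `num_rounds > 1` every round after the first in a call starts a new round.

```python
from typing import List


class ToyInference:
    """Hypothetical sketch that only tracks where simulations come from."""

    def __init__(self) -> None:
        self.round_index = 0
        self.simulated_from: List[str] = []

    def run(self, num_rounds: int, start_new_round: bool = False) -> None:
        for i in range(num_rounds):
            # Advance the round if requested, or (assumption) for every
            # round after the first one within this call.
            if start_new_round or i > 0:
                self.round_index += 1
            # Round 0 simulates from the prior, later rounds from the posterior.
            source = "prior" if self.round_index == 0 else "posterior"
            self.simulated_from.append(source)


inference = ToyInference()
inference.run(num_rounds=1)                         # first round: prior
inference.run(num_rounds=1, start_new_round=False)  # still the first round
inference.run(num_rounds=1, start_new_round=True)   # second round
```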

@michaeldeistler
Contributor Author

michaeldeistler commented Jul 24, 2020

@jan-matthis @janfb @Meteore After quite some discussions yesterday, we thought that it might be a good idea to change the API of multi-round. In the description of this PR, I outline one way to do it. Have a look and let me know what you think.

The code for this is also "ready" (though only for SNPE and not yet documented), so, if you want, have a look at that as well.

Contributor

@jan-matthis jan-matthis left a comment

Had a first look at the PR, looking good! I guess SNLE, SNRE, tests and changes to infer() are still forthcoming

sbi/inference/base.py (outdated)
sbi/inference/base.py (outdated)
sbi/inference/base.py
sbi/inference/snpe/snpe_base.py (outdated)
sbi/inference/snpe/snpe_base.py
sbi/inference/snpe/snpe_c.py (outdated)
sbi/inference/snpe/snpe_base.py (outdated)
@@ -93,6 +94,16 @@ def __init__(
        # Correction factor for leakage, only applicable to SNPE-family methods.
        self._leakage_density_correction_factor = None

    def focus_training_on(self, x) -> "NeuralPosterior":
Contributor

I'm seeing no invocation of this method in the code, shouldn't it be called at some point?

Contributor Author

The user has to call it after inference. See the API example in the PR description.

Contributor

Ah, I see. I missed the edited version. How about sticking to method names that are not specific to training/inference?

Contributor

For example:

# Single round inference with prior as proposal
posterior = infer(num_simulations=200, proposal=None)   # proposal=None is also default.
posteriors = [posterior]

# Run additional rounds with the last posterior as new proposal.
for round_ in range(1, num_rounds):
    posteriors.append(infer(num_simulations=200, proposal=posteriors[round_-1].set_default_x(x_o)))

Contributor

Or shorter:

posteriors = []
proposal = None
for _ in range(num_rounds):
    posterior = infer(num_simulations=200, proposal=proposal)
    proposal = posterior.set_default_x(x_o)
    posteriors.append(posterior)

Contributor

More generally, I wonder if infer() should keep a num_rounds keyword and build such a loop internally
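
A rough sketch of what such an internal loop could look like, assuming a single-round function with the `infer(num_simulations, proposal)` signature used above. `infer_multi_round`, `StubPosterior` and `stub_infer` are hypothetical names used only to make the example runnable; they are not part of sbi.

```python
from typing import Any, Callable, List


def infer_multi_round(
    infer_one_round: Callable[..., Any],
    x_o: Any,
    num_rounds: int,
    num_simulations: int = 200,
) -> List[Any]:
    """Run num_rounds rounds, feeding each posterior in as the next proposal."""
    posteriors: List[Any] = []
    proposal = None  # First round: simulate from the prior.
    for _ in range(num_rounds):
        posterior = infer_one_round(
            num_simulations=num_simulations, proposal=proposal
        )
        # Condition on the observation so the proposal depends on theta only.
        proposal = posterior.set_default_x(x_o)
        posteriors.append(posterior)
    return posteriors


class StubPosterior:
    """Hypothetical stub that records which proposal it was trained with."""

    def __init__(self, proposal: Any) -> None:
        self.trained_with = proposal
        self.default_x = None

    def set_default_x(self, x: Any) -> "StubPosterior":
        self.default_x = x
        return self


def stub_infer(num_simulations: int, proposal: Any = None) -> StubPosterior:
    # Stand-in for one round of simulation and training.
    return StubPosterior(proposal)


posteriors = infer_multi_round(stub_infer, x_o=0.0, num_rounds=3)
```

Keeping the loop in a helper like this would leave the single-round function unchanged while still offering a `num_rounds` convenience.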

@michaeldeistler michaeldeistler linked an issue Jul 24, 2020 that may be closed by this pull request
@michaeldeistler michaeldeistler force-pushed the checkpointing branch 3 times, most recently from c55e075 to 7a5a625 Compare July 24, 2020 15:27
@michaeldeistler michaeldeistler force-pushed the checkpointing branch 2 times, most recently from 1762c55 to 042f584 Compare July 28, 2020 20:43
@michaeldeistler michaeldeistler merged commit 4fbc2ba into main Jul 28, 2020
@michaeldeistler michaeldeistler deleted the checkpointing branch July 28, 2020 21:47
Successfully merging this pull request may close these issues.

Warn if invalid simulations + multiround SNPE-C
3 participants