Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow continuation of MCMC chains when using latest_sample #348

Merged
merged 2 commits into from
Sep 27, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# v0.13.1

- Make logging of vectorized numpy slice sampler slightly less verbose and address NumPy future warning (#347)
- Allow continuation of MCMC chains (#348)


# v0.13.0
Expand Down
12 changes: 10 additions & 2 deletions sbi/inference/posteriors/base_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,14 @@
from torch import nn

from sbi import utils as utils
from sbi.mcmc import Slice, SliceSampler, SliceSamplerVectorized, prior_init, sir
from sbi.mcmc import (
Slice,
SliceSampler,
SliceSamplerVectorized,
IterateParameters,
prior_init,
sir,
)
from sbi.types import Array, Shape
from sbi.user_input.user_input_checks import process_x
from sbi.utils.torchutils import (
Expand Down Expand Up @@ -594,7 +601,8 @@ def _build_mcmc_init_fn(
elif init_strategy == "sir":
return lambda: sir(prior, potential_fn, **kwargs)
elif init_strategy == "latest_sample":
return lambda: self._mcmc_init_params
latest_sample = IterateParameters(self._mcmc_init_params, **kwargs)
return latest_sample
else:
raise NotImplementedError

Expand Down
4 changes: 4 additions & 0 deletions sbi/inference/snle/snle_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,10 @@ def __call__(
mcmc_parameters=self._mcmc_parameters,
)

# Copy MCMC init parameters for latest sample init
if hasattr(proposal, "_mcmc_init_params"):
self._posterior._mcmc_init_params = proposal._mcmc_init_params

# Fit neural likelihood to newly aggregated dataset.
self._train(
training_batch_size=training_batch_size,
Expand Down
4 changes: 4 additions & 0 deletions sbi/inference/snpe/snpe_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,10 @@ def __call__(
rejection_sampling_parameters=self._rejection_sampling_parameters,
)

# Copy MCMC init parameters for latest sample init
if hasattr(proposal, "_mcmc_init_params"):
self._posterior._mcmc_init_params = proposal._mcmc_init_params

# Fit posterior using newly aggregated data set.
self._train(
proposal=proposal,
Expand Down
4 changes: 4 additions & 0 deletions sbi/inference/snre/snre_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,10 @@ def __call__(
mcmc_parameters=self._mcmc_parameters,
)

# Copy MCMC init parameters for latest sample init
if hasattr(proposal, "_mcmc_init_params"):
self._posterior._mcmc_init_params = proposal._mcmc_init_params

# Fit posterior using newly aggregated data set.
self._train(
num_atoms=num_atoms,
Expand Down
2 changes: 1 addition & 1 deletion sbi/mcmc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from sbi.mcmc.slice_numpy import SliceSampler
from sbi.mcmc.slice_numpy_vectorized import SliceSamplerVectorized
from sbi.mcmc.slice import Slice
from sbi.mcmc.init_strategy import prior_init, sir
from sbi.mcmc.init_strategy import IterateParameters, prior_init, sir
16 changes: 16 additions & 0 deletions sbi/mcmc/init_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,22 @@
from torch import Tensor


class IterateParameters:
    """Return rows of a parameter tensor one at a time, per call.

    Each call yields the next row reshaped to a (1, dim) tensor, which is
    the shape MCMC init functions are expected to produce. Once every row
    has been consumed, further calls raise ``StopIteration``.
    """

    def __init__(self, parameters: torch.Tensor, **kwargs):
        # kwargs are accepted (and ignored) so this matches the uniform
        # init-strategy call signature used by the other strategies.
        self.iter = (
            parameters[row, :].reshape(1, -1)
            for row in range(parameters.shape[0])
        )

    def __call__(self):
        # Advance to the next stored row.
        return next(self.iter)


def prior_init(prior: Any, **kwargs: Any) -> Tensor:
    """Draw one sample from the prior, detached from any autograd graph."""
    single_draw = prior.sample((1,))
    return single_draw.detach()
Expand Down