Skip to content

Commit

Permalink
Allow continuation of MCMC chains when using latest_sample
Browse files Browse the repository at this point in the history
  • Loading branch information
jan-matthis committed Sep 27, 2020
1 parent cba822c commit 3250b06
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 3 deletions.
12 changes: 10 additions & 2 deletions sbi/inference/posteriors/base_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,14 @@
from torch import nn

from sbi import utils as utils
from sbi.mcmc import Slice, SliceSampler, SliceSamplerVectorized, prior_init, sir
from sbi.mcmc import (
Slice,
SliceSampler,
SliceSamplerVectorized,
IterateParameters,
prior_init,
sir,
)
from sbi.types import Array, Shape
from sbi.user_input.user_input_checks import process_x
from sbi.utils.torchutils import (
Expand Down Expand Up @@ -594,7 +601,8 @@ def _build_mcmc_init_fn(
elif init_strategy == "sir":
return lambda: sir(prior, potential_fn, **kwargs)
elif init_strategy == "latest_sample":
return lambda: self._mcmc_init_params
latest_sample = IterateParameters(self._mcmc_init_params, **kwargs)
return latest_sample
else:
raise NotImplementedError

Expand Down
4 changes: 4 additions & 0 deletions sbi/inference/snle/snle_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,10 @@ def __call__(
mcmc_parameters=self._mcmc_parameters,
)

# Copy MCMC init parameters for latest sample init
if hasattr(proposal, "_mcmc_init_params"):
self._posterior._mcmc_init_params = proposal._mcmc_init_params

# Fit neural likelihood to newly aggregated dataset.
self._train(
training_batch_size=training_batch_size,
Expand Down
4 changes: 4 additions & 0 deletions sbi/inference/snpe/snpe_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,10 @@ def __call__(
rejection_sampling_parameters=self._rejection_sampling_parameters,
)

# Copy MCMC init parameters for latest sample init
if hasattr(proposal, "_mcmc_init_params"):
self._posterior._mcmc_init_params = proposal._mcmc_init_params

# Fit posterior using newly aggregated data set.
self._train(
proposal=proposal,
Expand Down
4 changes: 4 additions & 0 deletions sbi/inference/snre/snre_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,10 @@ def __call__(
mcmc_parameters=self._mcmc_parameters,
)

# Copy MCMC init parameters for latest sample init
if hasattr(proposal, "_mcmc_init_params"):
self._posterior._mcmc_init_params = proposal._mcmc_init_params

# Fit posterior using newly aggregated data set.
self._train(
num_atoms=num_atoms,
Expand Down
2 changes: 1 addition & 1 deletion sbi/mcmc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from sbi.mcmc.slice_numpy import SliceSampler
from sbi.mcmc.slice_numpy_vectorized import SliceSamplerVectorized
from sbi.mcmc.slice import Slice
from sbi.mcmc.init_strategy import prior_init, sir
from sbi.mcmc.init_strategy import IterateParameters, prior_init, sir
16 changes: 16 additions & 0 deletions sbi/mcmc/init_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,22 @@
from torch import Tensor


class IterateParameters:
    """Return stored parameters one row at a time, one row per call.

    Used as an MCMC initialization strategy: every invocation yields the
    next row of ``parameters`` reshaped to ``(1, dim)``.  Raises
    ``StopIteration`` once all rows have been consumed.  Extra keyword
    arguments are accepted (and ignored) so the constructor signature
    matches the other init strategies.
    """

    def __init__(self, parameters: torch.Tensor, **kwargs):
        # Lazily walk the rows of `parameters`; consumed by __call__.
        self.iter = self._make_iterator(parameters)

    @staticmethod
    def _make_iterator(t):
        # Emit each row as a (1, dim) tensor, in order.
        num_rows = t.shape[0]
        row = 0
        while row < num_rows:
            yield t[row, :].reshape(1, -1)
            row += 1

    def __call__(self):
        return next(self.iter)


def prior_init(prior: Any, **kwargs: Any) -> Tensor:
    """Draw a single sample from ``prior``, detached from the autograd graph.

    Extra keyword arguments are accepted (and ignored) so the signature
    matches the other MCMC init strategies.
    """
    sample = prior.sample((1,))
    return sample.detach()
Expand Down

0 comments on commit 3250b06

Please sign in to comment.