feedback: get arviz object with separate class method.
- fix reset of posterior sampler for new x.
- fix docstrings.
janfb committed Aug 10, 2022
1 parent 4d8d4bc commit 094625c
Showing 8 changed files with 120 additions and 1,111 deletions.
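A rough usage sketch of the API change in this commit (not part of the diff; `posterior` and `x_o` are hypothetical placeholders for an MCMCPosterior and an observation):

# Before this commit (sketch): sample() accepted a `return_arviz` flag and could
# return the InferenceData together with the samples.
# samples, idata = posterior.sample((1000,), x=x_o, return_arviz=True)

# After this commit: sample first, then request the InferenceData separately.
samples = posterior.sample((1000,), x=x_o)
idata = posterior.get_arviz_inference_data()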
3 changes: 2 additions & 1 deletion sbi/inference/posteriors/base_posterior.py
@@ -133,6 +133,8 @@ def set_default_x(self, x: Tensor) -> "NeuralPosterior":

def _x_else_default_x(self, x: Optional[Array]) -> Tensor:
if x is not None:
# New x, reset posterior sampler.
self._posterior_sampler = None
return process_x(
x, x_shape=self._x_shape, allow_iid_x=self.potential_fn.allow_iid_x
)
@@ -141,7 +143,6 @@ def _x_else_default_x(self, x: Optional[Array]) -> Tensor:
"Context `x` needed when a default has not been set."
"If you'd like to have a default, use the `.set_default_x()` method."
)
self._posterior_sampler = None
else:
return self.default_x

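A minimal sketch of the behavior fixed by the change above (assumed names `posterior`, `x_obs_1`, `x_obs_2`; not part of the diff): the cached MCMC sampler is now dropped whenever a new `x` is passed, so later diagnostics refer to the latest conditioning data rather than a stale run.

samples_1 = posterior.sample((1000,), x=x_obs_1)  # builds and caches a sampler
samples_2 = posterior.sample((1000,), x=x_obs_2)  # new x -> cached sampler is reset
idata = posterior.get_arviz_inference_data()      # diagnostics now belong to x_obs_2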
90 changes: 62 additions & 28 deletions sbi/inference/posteriors/mcmc_posterior.py
@@ -5,6 +5,7 @@
from typing import Any, Callable, Dict, Optional, Tuple, Union
from warnings import warn

import arviz as az
import torch
import torch.distributions.transforms as torch_tf
from arviz.data import InferenceData
@@ -28,12 +29,7 @@
)
from sbi.simulators.simutils import tqdm_joblib
from sbi.types import Shape, TorchTransform
from sbi.utils import (
get_arviz_diagnostics,
pyro_potential_wrapper,
tensor2numpy,
transformed_potential,
)
from sbi.utils import pyro_potential_wrapper, tensor2numpy, transformed_potential
from sbi.utils.torchutils import ensure_theta_batched


@@ -56,7 +52,6 @@ def __init__(
init_strategy_parameters: Dict[str, Any] = {},
init_strategy_num_candidates: Optional[int] = None,
num_workers: int = 1,
param_name: str = "theta",
device: Optional[str] = None,
x_shape: Optional[torch.Size] = None,
):
@@ -91,9 +86,6 @@ def __init__(
locations in `init_strategy=sir` (deprecated, use
init_strategy_parameters instead).
num_workers: number of cpu cores used to parallelize mcmc
param_name: Name of the sampled parameters used internally. When sampling
with `mcmc_method` of `slice`, `hmc`, or `nuts`, this name is used in
the sampler returned by `self.posterior_sampler` after sampling.
device: Training device, e.g., "cpu", "cuda" or "cuda:0". If None,
`potential_fn.device` is used.
x_shape: Shape of a single simulator output. If passed, it is used to check
@@ -115,8 +107,9 @@ def __init__(
self.init_strategy = init_strategy
self.init_strategy_parameters = init_strategy_parameters
self.num_workers = num_workers
self.param_name = param_name
self._posterior_sampler = None
# Hardcode parameter name to reduce clutter in the kwargs.
self.param_name = "theta"

if init_strategy_num_candidates is not None:
warn(
@@ -204,7 +197,6 @@ def sample(
mcmc_method: Optional[str] = None,
sample_with: Optional[str] = None,
num_workers: Optional[int] = None,
return_arviz: bool = False,
show_progress_bars: bool = True,
) -> Union[Tensor, Tuple[Tensor, InferenceData]]:
r"""Return samples from posterior distribution $p(\theta|x)$ with MCMC.
@@ -318,18 +310,7 @@ def sample(

samples = self.theta_transform.inv(transformed_samples)

# Maybe return Arviz inference data object.
if return_arviz:
return (
samples.reshape((*sample_shape, -1)), # type: ignore
get_arviz_diagnostics(
self._posterior_sampler,
param_name=self.param_name,
theta_transform=self.theta_transform, # type: ignore
),
)
else:
return samples.reshape((*sample_shape, -1)) # type: ignore
return samples.reshape((*sample_shape, -1)) # type: ignore

def _build_mcmc_init_fn(
self,
@@ -463,10 +444,9 @@ def _slice_np_mcmc(
initial_params: Initial parameters for MCMC chain.
thin: Thinning (subsampling) factor.
warmup_steps: Initial number of samples to discard.
vectorized: Whether to use a vectorized implementation of
the Slice sampler (still experimental).
num_workers: number of CPU cores to use
seed: seed that will be used to generate sub-seeds for each worker
vectorized: Whether to use a vectorized implementation of the Slice sampler.
num_workers: Number of CPU cores to use.
init_width: Initial width of brackets.
show_progress_bars: Whether to show a progress bar during sampling;
can only be turned off for the vectorized sampler.
@@ -667,6 +647,60 @@ def map(
force_update=force_update,
)

def get_arviz_inference_data(self) -> InferenceData:
"""Returns arviz InferenceData object constructed most recent samples.
Note: the InferenceData is constructed using the posterior samples generated in
most recent call to `.sample(...)`.
For Pyro HMC and NUTS kernels InferenceData will contain diagnostics, for Pyro
Slice or sbi slice sampling samples, only the samples are added.
Returns:
inference_data: Arviz InferenceData object.
"""
assert (
self._posterior_sampler is not None
), """No samples have been generated, call .sample() first."""

sampler: Union[
MCMC, SliceSamplerSerial, SliceSamplerVectorized
] = self._posterior_sampler

# If Pyro sampler and samples not transformed, use arviz' from_pyro.
# Exclude 'slice' kernel as it lacks the 'divergence' diagnostics key.
if isinstance(self._posterior_sampler, (HMC, NUTS)) and isinstance(
self.theta_transform, torch_tf.IndependentTransform
):
inference_data = az.from_pyro(sampler)

# otherwise get samples from sampler and transform to original space.
else:
transformed_samples = sampler.get_samples(group_by_chain=True)
# Pyro samplers return dicts; get the values.
if isinstance(transformed_samples, Dict):
# popitem returns the last (key, value) pair; [1] gets the values as a tensor.
transformed_samples = transformed_samples.popitem()[1]
# Our slice samplers return numpy arrays.
elif isinstance(transformed_samples, ndarray):
transformed_samples = torch.from_numpy(transformed_samples).type(
torch.float32
)
# For MultipleIndependent priors the transform's first dim must be the batch dim;
# thus, reshape back and forth to have the batch dim in front.
num_chains, samples_per_chain, dim_params = transformed_samples.shape
samples = self.theta_transform.inv( # type: ignore
transformed_samples.reshape(-1, dim_params)
).reshape( # type: ignore
num_chains, samples_per_chain, dim_params
)

inference_data = az.convert_to_inference_data(
{f"{self.param_name}": samples}
)

return inference_data


def _maybe_use_dict_entry(default: Any, key: str, dict_to_check: Dict) -> Any:
"""Returns `default` if `key` is not in the dict and otherwise the dict entry.
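A hedged example of how the new `get_arviz_inference_data` method might be used together with arviz (a sketch; `posterior` and `x_o` are assumed to exist, and the variable name "theta" follows the hardcoded `param_name` above):

import arviz as az

samples = posterior.sample((2000,), x=x_o, num_chains=4)  # run MCMC first
idata = posterior.get_arviz_inference_data()              # arviz InferenceData
print(idata.posterior["theta"].shape)                     # (chain, draw, dim_params)
print(az.summary(idata))                                  # r_hat, ess, ...
az.plot_trace(idata)                                      # trace plots per dimension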
1 change: 0 additions & 1 deletion sbi/samplers/mcmc/__init__.py
@@ -9,5 +9,4 @@
SliceSampler,
SliceSamplerSerial,
SliceSamplerVectorized,
run_slice_np_vectorized_parallelized,
)
100 changes: 8 additions & 92 deletions sbi/samplers/mcmc/slice_numpy.py
@@ -316,10 +316,10 @@ def run_fun(self, num_samples, inits, seed) -> np.ndarray:

def get_samples(
self, num_samples: Optional[int] = None, group_by_chain: bool = True
) -> Union[None, np.ndarray]:
) -> np.ndarray:
"""Returns samples from last call to self.run.
Returns None if no samples have been generated yet.
Raises ValueError if no samples have been generated yet.
Args:
num_samples: Number of samples to return (for each chain if grouped by
@@ -328,10 +328,10 @@ def get_samples(
x dim_params) or flattened (all_samples, dim_params).
Returns:
samples (or None if no samples have been generated yet)
samples
"""
if self._samples is None:
return None
raise ValueError("No samples found from MCMC run.")
# if not grouped by chain, flatten samples into (all_samples, dim_params)
if not group_by_chain:
samples = self._samples.reshape(-1, self._samples.shape[2])
@@ -592,10 +592,10 @@ def run(self, num_samples: int) -> np.ndarray:

def get_samples(
self, num_samples: Optional[int] = None, group_by_chain: bool = True
) -> Union[None, np.ndarray]:
) -> np.ndarray:
"""Returns samples from last call to self.run.
Returns None if no samples have been generated yet.
Raises ValueError if no samples have been generated yet.
Args:
num_samples: Number of samples to return (for each chain if grouped by
@@ -604,10 +604,10 @@ def get_samples(
x dim_params) or flattened (all_samples, dim_params).
Returns:
samples (or None if no samples have been generated yet)
samples
"""
if self._samples is None:
return None
raise ValueError("No samples found from MCMC run.")
# if not grouped by chain, flatten samples into (all_samples, dim_params)
if not group_by_chain:
samples = self._samples.reshape(-1, self._samples.shape[2])
Expand All @@ -622,87 +622,3 @@ def get_samples(
return samples[:, -num_samples:, :]
else:
return samples[-num_samples:, :]
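A small sketch of the changed `get_samples` contract (assumed setup; `log_prob` and `inits` are placeholders for a log-probability callable and a NumPy array of initial parameters): calling it before `run` now raises a `ValueError` instead of silently returning `None`.

sampler = SliceSamplerVectorized(
    init_params=inits, log_prob_fn=log_prob, num_chains=inits.shape[0]
)
try:
    sampler.get_samples()  # no run yet -> ValueError
except ValueError:
    pass
sampler.run(num_samples=200)  # produce samples first
samples = sampler.get_samples(group_by_chain=True)  # (num_chains, samples, dim)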


def run_slice_np_vectorized_parallelized(
potential_function: Callable,
initial_params: torch.Tensor,
num_samples: int,
thin: int,
warmup_steps: int,
vectorized: bool,
num_workers: int = 1,
show_progress_bars: bool = False,
):
"""Run slice np (vectorized) parallized over CPU cores.
In case of the vectorized version of slice np parallization happens over batches of
chains to still exploit vectorization.
MCMC progress bars are omitted if num_workers > 1 to reduce clutter. Instead the
progress over chains is shown.
Args:
potential_function: potential function
initial_params: initial parameters, one for each chain
num_samples: number of MCMC samples to produce
thin: thinning factor
warmup_steps: number of warmup / burnin steps
vectorized: whether to use the vectorized version
num_workers: number of CPU cores to use
show_progress_bars: whether to show progress bars
Returns:
Tensor: final MCMC samples of each chain (num_chains, num_samples, dim_samples)
"""
num_chains, dim_samples = initial_params.shape

# Generate seeds for workers from current random state.
seeds = torch.randint(high=2**31, size=(num_chains,))

# Define local function to run a batch of chains vectorized.
def run_slice_np_vectorized(inits, seed):
# Seed current job.
np.random.seed(seed)
posterior_sampler = SliceSamplerVectorized(
init_params=tensor2numpy(inits),
log_prob_fn=potential_function,
num_chains=inits.shape[0],
# Show pbars of workers only for single worker
verbose=show_progress_bars and num_workers == 1,
)
# TODO: move warmup and thinning into SliceSamplerVectorized?
warmup_ = warmup_steps * thin
num_samples_ = ceil((num_samples * thin) / num_chains)
samples = posterior_sampler.run(warmup_ + num_samples_)
samples = samples[:, warmup_:, :] # discard warmup steps
samples = samples[:, ::thin, :] # thin chains
samples = torch.from_numpy(samples) # chains x samples x dim
return samples

# For vectorized case a batch contains multiple chains to exploit vectorization.
batch_size = ceil(num_chains / num_workers)
run_fun = run_slice_np_vectorized

# Parallelize over batches of chains.
initial_params_in_batches = torch.split(initial_params, batch_size, dim=0)
num_batches = len(initial_params_in_batches)

# Show progress bars over batches.
with tqdm_joblib(
tqdm(
range(num_batches), # type: ignore
disable=not show_progress_bars or num_workers == 1,
desc=f"""Running {num_chains} MCMC chains with {num_workers} worker{"s" if
num_workers>1 else ""} (batch_size={batch_size}).""",
total=num_chains,
)
):
all_samples = Parallel(n_jobs=num_workers)(
delayed(run_fun)(initial_params_batch, seed)
for initial_params_batch, seed in zip(initial_params_in_batches, seeds)
)
all_samples = np.stack(all_samples).astype(np.float32)
samples = torch.from_numpy(all_samples)

return samples.reshape(num_chains, -1, dim_samples) # chains x samples x dim
1 change: 0 additions & 1 deletion sbi/utils/__init__.py
@@ -14,7 +14,6 @@
clamp_and_warn,
del_entries,
expit,
get_arviz_diagnostics,
get_simulations_since_round,
gradient_ascent,
handle_invalid_x,
58 changes: 0 additions & 58 deletions sbi/utils/sbiutils.py
@@ -855,61 +855,3 @@ def gradient_ascent(
return argmax_, max_val # type: ignore

return theta_transform.inv(best_theta_overall), max_val # type: ignore


def get_arviz_diagnostics(
sampler: Any,
theta_transform: Optional[TorchTransform],
param_name: Optional[str] = "theta",
) -> InferenceData:
"""Returns arviz InferenceData object constructed from sampler or samples.
Note: if parameters were obtained in unconstrained space the corresponding
theta_transform should be passed to ensure that the arviz InferenceData lives in
constrained space.
For Pyro HMC and NUTS kernels InferenceData will contain diagnostics, for Pyro Slice
or sbi slice sampling samples, only the samples are added.
Args:
sampler: Pyro MCMC object.
theta_transform: parameter transform for parameters theta, optional but needed
if MCMC was performed in unconstrained space.
param_name: internal name of the parameters (not their dimensions).
Returns:
inference_data: Arviz InferenceData object.
"""

# Construct InferenceData from the Pyro sampler object, but only if parameters
# are not transformed.
if sampler is not None and theta_transform is None:
try:
inference_data = az.from_pyro(sampler)
# Will result in KeyError on the 'divergence' diagnostics for the Slice kernel.
except KeyError:
# Construct InferenceData from samples instead.
samples = sampler.get_samples(group_by_chain=True).popitem()[1]
inference_data = az.convert_to_inference_data({f"{param_name}": samples})
# get samples from sampler and transform to original space.
else:
transformed_samples = sampler.get_samples(group_by_chain=True)
# Pyro samplers return dicts; get the values.
if isinstance(transformed_samples, Dict):
transformed_samples = transformed_samples.popitem()[1]
elif isinstance(transformed_samples, ndarray):
transformed_samples = torch.from_numpy(transformed_samples).type(
torch.float32
)
# Transform with additional chain dimension fails for MultipleIndependent
# priors, reshape back and forth instead.
num_chains, samples_per_chain, dim_params = transformed_samples.shape
samples = theta_transform.inv( # type: ignore
transformed_samples.reshape(-1, dim_params)
).reshape( # type: ignore
num_chains, samples_per_chain, dim_params
)

inference_data = az.convert_to_inference_data({f"{param_name}": samples})

return inference_data