class NestedToMCMCAdapter:
    """Expose a NumPyro ``NestedSampler`` through an MCMC-like interface.

    Posterior draws obtained from the nested sampler are split into
    ``num_chains`` artificial chains of equal length so that code written
    against the ``(chain, draw, *shape)`` layout can consume them.

    Parameters
    ----------
    nested_sampler : numpyro.contrib.nested_sampling.NestedSampler
        Sampler object; must provide ``get_samples(rng_key, num_samples)``
        returning a mapping of variable name to array, and a ``model``
        attribute.
    rng_key : jax.random.PRNGKey
        Random key forwarded to ``nested_sampler.get_samples``.
    num_samples : int
        Total number of posterior draws requested from the nested sampler.
    *args : tuple
        Positional arguments of the model, stored as ``_args``.
    num_chains : int, optional
        Keyword-only. Number of artificial chains to create (default 1).
    **kwargs : dict
        Keyword arguments of the model, stored as ``_kwargs``.

    Attributes
    ----------
    samples : dict
        Draws keyed by variable name, each shaped
        ``(num_chains, num_samples // num_chains, *event_shape)``.
    thinning : int
        Always 1; present for interface compatibility.
    sampler : NestedToMCMCAdapter
        The adapter itself, so that ``adapter.sampler.model`` resolves.
    model : callable
        ``nested_sampler.model``.
    _args : tuple
        Positional arguments passed to the model.
    _kwargs : dict
        Keyword arguments passed to the model.

    Raises
    ------
    ValueError
        If ``num_chains`` is smaller than 1.

    Warns
    -----
    UserWarning
        If ``num_samples`` is not a multiple of ``num_chains``; the trailing
        ``num_samples % num_chains`` draws are discarded.

    Notes
    -----
    ``get_extra_fields`` returns constant placeholder values. They are not
    real sampler diagnostics and must not be interpreted as such.
    """

    def __init__(self, nested_sampler, rng_key, num_samples, *args, num_chains=1, **kwargs):
        # Fail early with a clear message instead of a ZeroDivisionError or an
        # opaque reshape error further down.
        if num_chains < 1:
            raise ValueError(f"num_chains must be a positive integer, got {num_chains!r}")

        self.nested_sampler = nested_sampler
        self.rng_key = rng_key
        # NOTE(review): this is the *total* number of draws, not draws per
        # chain. Callers that derive the per-chain draw count from
        # ``num_samples`` alone should confirm this matches their expectation.
        self.num_samples = num_samples
        self.num_chains = num_chains

        discarded = num_samples - self._draws_per_chain * num_chains
        if discarded:
            import warnings  # local import: keeps the block self-contained

            warnings.warn(
                f"num_samples={num_samples} is not divisible by num_chains={num_chains}; "
                f"the last {discarded} draw(s) will be discarded.",
                UserWarning,
                stacklevel=2,
            )

        self.samples = self._reshape_samples()
        self.thinning = 1
        self.sampler = self
        self.model = nested_sampler.model
        self._args = args
        self._kwargs = kwargs

    @property
    def _draws_per_chain(self):
        """Number of draws assigned to each artificial chain."""
        return self.num_samples // self.num_chains

    @staticmethod
    def _merge_chain_axis(values):
        """Collapse the leading ``(chain, draw)`` axes of ``values`` into one.

        The leading size is computed explicitly rather than with ``-1``,
        because ``-1`` cannot be inferred when a trailing dimension has
        length zero.
        """
        shape = tuple(values.shape)
        return values.reshape((shape[0] * shape[1],) + shape[2:])

    def _reshape_samples(self):
        """Draw from the nested sampler and arrange the draws by chain.

        Returns
        -------
        dict
            Variable name -> array of shape
            ``(num_chains, num_samples // num_chains, *event_shape)``.
        """
        raw_samples = self.nested_sampler.get_samples(self.rng_key, self.num_samples)
        n_chains = self.num_chains
        n_draws = self._draws_per_chain
        n_kept = n_chains * n_draws  # draws beyond this do not fill a whole chain
        return {
            name: np.reshape(values[:n_kept], (n_chains, n_draws, *values.shape[1:]))
            for name, values in raw_samples.items()
        }

    def get_samples(self, group_by_chain=True):
        """Return the posterior draws.

        Parameters
        ----------
        group_by_chain : bool, optional
            If True (default) return arrays shaped ``(chain, draw, ...)``;
            otherwise merge the chain and draw axes into a single leading axis.

        Returns
        -------
        dict
            Variable name -> array.
        """
        if group_by_chain:
            return self.samples
        return {name: self._merge_chain_axis(values) for name, values in self.samples.items()}

    def get_extra_fields(self, group_by_chain=True):
        """Return placeholder sampling statistics.

        The nested sampler does not produce these quantities, so every entry
        is a constant array: ``accept_prob`` is 1.0, ``step_size`` is 0.1 and
        ``num_steps`` is 10.

        Parameters
        ----------
        group_by_chain : bool, optional
            If True (default) arrays are shaped ``(chain, draw)``; otherwise
            they are flattened to one dimension.

        Returns
        -------
        dict
            Field name -> constant array.
        """
        shape = (self.num_chains, self._draws_per_chain)
        extra_fields = {
            "accept_prob": np.full(shape, 1.0),  # placeholder, not a real acceptance rate
            "step_size": np.full(shape, 0.1),  # placeholder step size
            "num_steps": np.full(shape, 10),  # placeholder number of steps
        }
        if not group_by_chain:
            extra_fields = {
                name: self._merge_chain_axis(values) for name, values in extra_fields.items()
            }
        return extra_fields
__init__( dims=None, pred_dims=None, num_chains=1, + rng_key=None, + num_samples=1000, + data=None, + labels=None, ): """Convert NumPyro data into an InferenceData object. @@ -68,6 +167,15 @@ def __init__( import numpyro self.posterior = posterior + self.rng_key = rng_key + self.num_samples = num_samples + + if isinstance(posterior, numpyro.contrib.nested_sampling.NestedSampler): + posterior = NestedToMCMCAdapter( + posterior, rng_key, num_samples, num_chains=num_chains, data=data, labels=labels + ) + self.posterior = posterior + self.prior = jax.device_get(prior) self.posterior_predictive = jax.device_get(posterior_predictive) self.predictions = predictions @@ -340,6 +448,10 @@ def from_numpyro( dims=None, pred_dims=None, num_chains=1, + rng_key=None, + num_samples=1000, + data=None, + labels=None, ): """Convert NumPyro data into an InferenceData object. @@ -383,4 +495,8 @@ def from_numpyro( dims=dims, pred_dims=pred_dims, num_chains=num_chains, + rng_key=rng_key, + num_samples=num_samples, + data=data, + labels=labels, ).to_inference_data()