diff --git a/src/chainconsumer/chain.py b/src/chainconsumer/chain.py
index c4f7e64b..65db3682 100644
--- a/src/chainconsumer/chain.py
+++ b/src/chainconsumer/chain.py
@@ -395,6 +395,7 @@ def from_numpyro(
         cls,
         mcmc: numpyro.infer.MCMC,
         name: str,
+        var_names: list[str] | None = None,
         **kwargs: Any,
     ) -> Chain:
         """Constructor from numpyro samples
@@ -402,12 +403,16 @@ def from_numpyro(
         Args:
             mcmc: The numpyro sampler
             name: The name of the chain
+            var_names: The names of the parameters to include in the chain. If the entries of var_names start
+                with ~, they are excluded instead. If None or empty, all parameters are included.
             kwargs: Any other arguments to pass to the Chain constructor.
 
         Returns:
             A ChainConsumer Chain made from numpyro samples
         """
-        df = pd.DataFrame.from_dict({key: np.ravel(value) for key, value in mcmc.get_samples().items()})
+        samples = mcmc.get_samples()
+        var_names = _filter_var_names(var_names, list(samples))
+        df = pd.DataFrame.from_dict({key: np.ravel(samples[key]) for key in var_names})
         return cls(samples=df, name=name, **kwargs)
 
     @classmethod
@@ -415,6 +420,7 @@ def from_arviz(
         cls,
         arviz_id: arviz.InferenceData,
         name: str,
+        var_names: list[str] | None = None,
         **kwargs: Any,
     ) -> Chain:
         """Constructor from an arviz InferenceData object
@@ -422,13 +428,19 @@ def from_arviz(
         Args:
             arviz_id: The arviz inference data
             name: The name of the chain
+            var_names: The names of the parameters to include in the chain. If the entries of var_names start
+                with ~, they are excluded instead. If None or empty, all parameters are included.
             kwargs: Any other arguments to pass to the Chain constructor.
 
         Returns:
             A ChainConsumer Chain made from the arviz chain
         """
 
-        df = arviz_id.to_dataframe(groups="posterior").drop(columns=["chain", "draw"])
+        import arviz as az  # imported lazily so arviz stays an optional dependency of this module
+
+        var_names = _filter_var_names(var_names, list(arviz_id.posterior.keys()))
+        reduced_id = az.extract(arviz_id, var_names=var_names, group="posterior")
+        df = reduced_id.to_dataframe().drop(columns=["chain", "draw"])
         return cls(samples=df, name=name, **kwargs)
 
 
@@ -444,3 +456,37 @@ class MaxPosterior(BetterBase):
     def vec_coordinate(self) -> np.ndarray:
         """The coordinate as a numpy array, in the order the columns were given."""
         return np.array(list(self.coordinate.values()))
+
+
+def _filter_var_names(var_names: list[str] | None, all_vars: list[str]) -> list[str]:
+    """Resolve a var_names selection against the available parameter names.
+
+    Args:
+        var_names: Names to keep, or names prefixed with ~ to drop. The two forms cannot be
+            mixed. None or an empty list selects every parameter.
+        all_vars: All parameter names available in the samples.
+
+    Returns:
+        The selected names, in the order they appear in all_vars. Requested names that are
+        not in all_vars are ignored.
+
+    Raises:
+        ValueError: If var_names mixes names with and without the ~ prefix.
+    """
+    if not var_names:
+        return list(all_vars)
+
+    n_excluded = sum(var.startswith("~") for var in var_names)
+    if 0 < n_excluded < len(var_names):
+        raise ValueError(
+            "all values in var_names must start with ~ to exclude a subset OR none of them to keep a subset"
+        )
+
+    if n_excluded:
+        excluded = {var[1:] for var in var_names}
+        return [var for var in all_vars if var not in excluded]
+
+    # Filter all_vars (rather than returning var_names) so unknown names are dropped and the
+    # sampler's parameter order is kept.
+    included = set(var_names)
+    return [var for var in all_vars if var in included]
diff --git a/tests/test_translators.py b/tests/test_translators.py
index 7439ada8..199d52ea 100644
--- a/tests/test_translators.py
+++ b/tests/test_translators.py
@@ -83,12 +83,29 @@ def test_arviz_translator(self):
 
         assert chain.samples.shape == (self.n_steps * self.n_chains, self.n_params + 1)  # +1 for weight column
 
+    def test_drop_on_arviz_translator(self):
+        numpyro_mcmc = run_numpyro_mcmc(self.n_steps, self.n_chains)
+        arviz_id = az.from_numpyro(numpyro_mcmc)
+        chain = Chain.from_arviz(arviz_id, "Arviz", var_names=["mu"])
+        assert "mu" in chain.samples.columns and "sigma" not in chain.samples.columns
+
+        chain = Chain.from_arviz(arviz_id, "Arviz", var_names=["~mu"])
+        assert "mu" not in chain.samples.columns and "sigma" in chain.samples.columns
+
     def test_numpyro_translator(self):
         numpyro_mcmc = run_numpyro_mcmc(self.n_steps, self.n_chains)
         chain = Chain.from_numpyro(numpyro_mcmc, "numpyro")
 
         assert chain.samples.shape == (self.n_steps * self.n_chains, self.n_params + 1)
 
+    def test_drop_on_numpyro_translator(self):
+        numpyro_mcmc = run_numpyro_mcmc(self.n_steps, self.n_chains)
+        chain = Chain.from_numpyro(numpyro_mcmc, "numpyro", var_names=["mu"])
+        assert "mu" in chain.samples.columns and "sigma" not in chain.samples.columns
+
+        chain = Chain.from_numpyro(numpyro_mcmc, "numpyro", var_names=["~mu"])
+        assert "mu" not in chain.samples.columns and "sigma" in chain.samples.columns
+
     def test_emcee_translator(self):
         emcee_sampler = run_emcee_mcmc(self.n_steps, self.n_chains)
         chain = Chain.from_emcee(emcee_sampler, ["mu", "sigma"], "emcee")