From 6f5d333fc9afd3acfb16557f569a88d9753fbbfb Mon Sep 17 00:00:00 2001 From: sdupourque Date: Mon, 15 Apr 2024 17:33:15 +0200 Subject: [PATCH] adding "var_names" parameter to filter the variables --- src/chainconsumer/chain.py | 44 ++++++++++++++++++++++++++++++++++++-- tests/test_translators.py | 17 +++++++++++++++ 2 files changed, 59 insertions(+), 2 deletions(-) diff --git a/src/chainconsumer/chain.py b/src/chainconsumer/chain.py index c4f7e64b..a0335ae3 100644 --- a/src/chainconsumer/chain.py +++ b/src/chainconsumer/chain.py @@ -12,6 +12,7 @@ import logging from typing import TYPE_CHECKING, Any, TypeAlias +import arviz as az import numpy as np import pandas as pd from pydantic import Field, field_validator, model_validator @@ -395,6 +396,7 @@ def from_numpyro( cls, mcmc: numpyro.infer.MCMC, name: str, + var_names: list[str] = [], **kwargs: Any, ) -> Chain: """Constructor from numpyro samples @@ -402,12 +404,18 @@ def from_numpyro( Args: mcmc: The numpyro sampler name: The name of the chain + var_names: The names of the parameters to include in the chain. If the entries of var_names start with ~, + they are excluded from the variables. If empty, all parameters are included. kwargs: Any other arguments to pass to the Chain constructor. Returns: A ChainConsumer Chain made from numpyro samples """ - df = pd.DataFrame.from_dict({key: np.ravel(value) for key, value in mcmc.get_samples().items()}) + + var_names = _filter_var_names(var_names, list(mcmc.get_samples().keys())) + df = pd.DataFrame.from_dict( + {key: np.ravel(value) for key, value in mcmc.get_samples().items() if key in var_names} + ) return cls(samples=df, name=name, **kwargs) @classmethod @@ -415,6 +423,7 @@ def from_arviz( cls, arviz_id: arviz.InferenceData, name: str, + var_names: list[str] = [], **kwargs: Any, ) -> Chain: """Constructor from an arviz InferenceData object @@ -422,13 +431,17 @@ def from_arviz( Args: arviz_id: The arviz inference data name: The name of the chain + var_names: The names of the parameters to include in the chain. If the entries of var_names start with ~, + they are excluded from the variables. If empty, all parameters are included. kwargs: Any other arguments to pass to the Chain constructor. Returns: A ChainConsumer Chain made from the arviz chain """ - df = arviz_id.to_dataframe(groups="posterior").drop(columns=["chain", "draw"]) + var_names = _filter_var_names(var_names, list(arviz_id.posterior.keys())) + reduced_id = az.extract(arviz_id, var_names=var_names, group="posterior") + df = reduced_id.to_dataframe().drop(columns=["chain", "draw"]) return cls(samples=df, name=name, **kwargs) @@ -444,3 +457,30 @@ class MaxPosterior(BetterBase): def vec_coordinate(self) -> np.ndarray: """The coordinate as a numpy array, in the order the columns were given.""" return np.array(list(self.coordinate.values())) + + +def _filter_var_names(var_names: list[str], all_vars: list[str]): + """ + Helper function to return the var_names to allows filtering parameters names. + """ + + if not var_names: + return all_vars + + elif var_names: + if not (all([var.startswith("~") for var in var_names]) or all([not var.startswith("~") for var in var_names])): + raise ValueError( + "all values in var_names must start with ~ to exclude a subset OR none of them to keep a subset" + ) + + if all([var.startswith("~") for var in var_names]): + # remove the ~ from the var names + var_names = [var[1:] for var in var_names] + var_names = [var for var in all_vars if var not in var_names] + + return var_names + + else: + # keep var_names as is but check if var is in all_vars + var_names = [var for var in all_vars if var in var_names] + return var_names diff --git a/tests/test_translators.py b/tests/test_translators.py index 7439ada8..318812a4 100644 --- a/tests/test_translators.py +++ b/tests/test_translators.py @@ -83,12 +83,29 @@ def test_arviz_translator(self): assert chain.samples.shape == (self.n_steps * self.n_chains, self.n_params + 1) # +1 for weight column + def test_drop_on_arviz_translator(self): + numpyro_mcmc = run_numpyro_mcmc(self.n_steps, self.n_chains) + arviz_id = az.from_numpyro(numpyro_mcmc) + chain = Chain.from_arviz(arviz_id, "Arviz", var_names=["mu"]) + assert ("mu" in chain.samples.columns) and "sigma" not in chain.samples.columns + + chain = Chain.from_arviz(arviz_id, "Arviz", var_names=["~mu"]) + assert ("mu" not in chain.samples.columns) and ("sigma" in chain.samples.columns) + def test_numpyro_translator(self): numpyro_mcmc = run_numpyro_mcmc(self.n_steps, self.n_chains) chain = Chain.from_numpyro(numpyro_mcmc, "numpyro") assert chain.samples.shape == (self.n_steps * self.n_chains, self.n_params + 1) + def test_drop_on_numpyro_translator(self): + numpyro_mcmc = run_numpyro_mcmc(self.n_steps, self.n_chains) + chain = Chain.from_numpyro(numpyro_mcmc, "numpyro", var_names=["mu"]) + assert ("mu" in chain.samples.columns) and "sigma" not in chain.samples.columns + + chain = Chain.from_numpyro(numpyro_mcmc, "numpyro", var_names=["~mu"]) + assert ("mu" not in chain.samples.columns) and ("sigma" in chain.samples.columns) + def test_emcee_translator(self): emcee_sampler = run_emcee_mcmc(self.n_steps, self.n_chains) chain = Chain.from_emcee(emcee_sampler, ["mu", "sigma"], "emcee")