Skip to content

Commit

Permalink
adding "var_names" parameter to filter the variable to keep when buil…
Browse files Browse the repository at this point in the history
…ding chain from numpyro or arviz
  • Loading branch information
renecotyfanboy committed Apr 15, 2024
1 parent f461e02 commit 051fca9
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 3 deletions.
42 changes: 39 additions & 3 deletions src/chainconsumer/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any, TypeAlias
from typing import TYPE_CHECKING, Any, TypeAlias, Optional, List

import numpy as np
import pandas as pd
import arviz as az
from pydantic import Field, field_validator, model_validator

from .base import BetterBase
Expand Down Expand Up @@ -395,40 +396,50 @@ def from_numpyro(
cls,
mcmc: numpyro.infer.MCMC,
name: str,
var_names: Optional[List[str]] = [],
**kwargs: Any,
) -> Chain:
"""Constructor from numpyro samples
Args:
mcmc: The numpyro sampler
name: The name of the chain
var_names: The names of the parameters to include in the chain. If the entries of var_names start with ~,
they are excluded from the variables. If empty, all parameters are included.
kwargs: Any other arguments to pass to the Chain constructor.
Returns:
A ChainConsumer Chain made from numpyro samples
"""
df = pd.DataFrame.from_dict({key: np.ravel(value) for key, value in mcmc.get_samples().items()})

var_names = _filter_var_names(var_names, list(mcmc.get_samples().keys()))
df = pd.DataFrame.from_dict({key: np.ravel(value) for key, value in mcmc.get_samples().items() if key in var_names})
return cls(samples=df, name=name, **kwargs)

@classmethod
def from_arviz(
cls,
arviz_id: arviz.InferenceData,
name: str,
var_names: Optional[List[str]] = [],
**kwargs: Any,
) -> Chain:
"""Constructor from an arviz InferenceData object
Args:
arviz_id: The arviz inference data
name: The name of the chain
var_names: The names of the parameters to include in the chain. If the entries of var_names start with ~,
they are excluded from the variables. If empty, all parameters are included.
kwargs: Any other arguments to pass to the Chain constructor.
Returns:
A ChainConsumer Chain made from the arviz chain
"""

df = arviz_id.to_dataframe(groups="posterior").drop(columns=["chain", "draw"])
var_names = _filter_var_names(var_names, list(arviz_id.posterior.keys()))
reduced_id = az.extract(arviz_id, var_names=var_names, group="posterior")
df = reduced_id.to_dataframe().drop(columns=["chain", "draw"])

return cls(samples=df, name=name, **kwargs)

Expand All @@ -444,3 +455,28 @@ class MaxPosterior(BetterBase):
def vec_coordinate(self) -> np.ndarray:
"""The coordinate as a numpy array, in the order the columns were given."""
return np.array(list(self.coordinate.values()))


def _filter_var_names(var_names: List[str], all_vars: List[str]):
"""
Helper function to return the var_names to allows filtering parameters names.
"""

if not var_names:
return all_vars

elif var_names:
if not (all([var.startswith('~') for var in var_names]) or all([not var.startswith('~') for var in var_names])):
raise ValueError("all values in var_names must start with ~ to exclude a subset OR none of them to keep a subset")

if all([var.startswith('~') for var in var_names]):
# remove the ~ from the var names
var_names = [var[1:] for var in var_names]
var_names = [var for var in all_vars if var not in var_names]

return var_names

else:
# keep var_names as is but check if var is in all_vars
var_names = [var for var in all_vars if var in var_names]
return var_names
17 changes: 17 additions & 0 deletions tests/test_translators.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,29 @@ def test_arviz_translator(self):

assert chain.samples.shape == (self.n_steps * self.n_chains, self.n_params + 1) # +1 for weight column

def test_drop_on_arviz_translator(self):
numpyro_mcmc = run_numpyro_mcmc(self.n_steps, self.n_chains)
arviz_id = az.from_numpyro(numpyro_mcmc)
chain = Chain.from_arviz(arviz_id, "Arviz", var_names=["mu"])
assert ("mu" in chain.samples.columns) and not ("sigma" in chain.samples.columns)

chain = Chain.from_arviz(arviz_id, "Arviz", var_names=["~mu"])
assert ("mu" not in chain.samples.columns) and ("sigma" in chain.samples.columns)

def test_numpyro_translator(self):
numpyro_mcmc = run_numpyro_mcmc(self.n_steps, self.n_chains)
chain = Chain.from_numpyro(numpyro_mcmc, "numpyro")

assert chain.samples.shape == (self.n_steps * self.n_chains, self.n_params + 1)

def test_drop_on_numpyro_translator(self):
numpyro_mcmc = run_numpyro_mcmc(self.n_steps, self.n_chains)
chain = Chain.from_numpyro(numpyro_mcmc, "numpyro", var_names=["mu"])
assert ("mu" in chain.samples.columns) and not ("sigma" in chain.samples.columns)

chain = Chain.from_numpyro(numpyro_mcmc, "numpyro", var_names=["~mu"])
assert ("mu" not in chain.samples.columns) and ("sigma" in chain.samples.columns)

def test_emcee_translator(self):
emcee_sampler = run_emcee_mcmc(self.n_steps, self.n_chains)
chain = Chain.from_emcee(emcee_sampler, ["mu", "sigma"], "emcee")
Expand Down

0 comments on commit 051fca9

Please sign in to comment.