diff --git a/src/chainconsumer/chain.py b/src/chainconsumer/chain.py index 65db3682..47ec22e5 100644 --- a/src/chainconsumer/chain.py +++ b/src/chainconsumer/chain.py @@ -10,11 +10,11 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Any, TypeAlias, Optional, List +from typing import TYPE_CHECKING, Any, Optional, TypeAlias +import arviz as az import numpy as np import pandas as pd -import arviz as az from pydantic import Field, field_validator, model_validator from .base import BetterBase @@ -396,7 +396,7 @@ def from_numpyro( cls, mcmc: numpyro.infer.MCMC, name: str, - var_names: Optional[List[str]] = [], + var_names: Optional[list[str]] = [], **kwargs: Any, ) -> Chain: """Constructor from numpyro samples @@ -413,7 +413,9 @@ def from_numpyro( """ var_names = _filter_var_names(var_names, list(mcmc.get_samples().keys())) - df = pd.DataFrame.from_dict({key: np.ravel(value) for key, value in mcmc.get_samples().items() if key in var_names}) + df = pd.DataFrame.from_dict( + {key: np.ravel(value) for key, value in mcmc.get_samples().items() if key in var_names} + ) return cls(samples=df, name=name, **kwargs) @classmethod @@ -421,7 +423,7 @@ def from_arviz( cls, arviz_id: arviz.InferenceData, name: str, - var_names: Optional[List[str]] = [], + var_names: Optional[list[str]] = [], **kwargs: Any, ) -> Chain: """Constructor from an arviz InferenceData object @@ -457,7 +459,7 @@ def vec_coordinate(self) -> np.ndarray: return np.array(list(self.coordinate.values())) -def _filter_var_names(var_names: List[str], all_vars: List[str]): +def _filter_var_names(var_names: list[str], all_vars: list[str]): """ Helper function to return the var_names to allows filtering parameters names. """ @@ -466,10 +468,12 @@ def _filter_var_names(var_names: List[str], all_vars: List[str]): return all_vars elif var_names: - if not (all([var.startswith('~') for var in var_names]) or all([not var.startswith('~') for var in var_names])): - raise ValueError("all values in var_names must start with ~ to exclude a subset OR none of them to keep a subset") + if not (all([var.startswith("~") for var in var_names]) or all([not var.startswith("~") for var in var_names])): + raise ValueError( + "all values in var_names must start with ~ to exclude a subset OR none of them to keep a subset" + ) - if all([var.startswith('~') for var in var_names]): + if all([var.startswith("~") for var in var_names]): # remove the ~ from the var names var_names = [var[1:] for var in var_names] var_names = [var for var in all_vars if var not in var_names] diff --git a/tests/test_translators.py b/tests/test_translators.py index 199d52ea..318812a4 100644 --- a/tests/test_translators.py +++ b/tests/test_translators.py @@ -87,7 +87,7 @@ def test_drop_on_arviz_translator(self): numpyro_mcmc = run_numpyro_mcmc(self.n_steps, self.n_chains) arviz_id = az.from_numpyro(numpyro_mcmc) chain = Chain.from_arviz(arviz_id, "Arviz", var_names=["mu"]) - assert ("mu" in chain.samples.columns) and not ("sigma" in chain.samples.columns) + assert ("mu" in chain.samples.columns) and "sigma" not in chain.samples.columns chain = Chain.from_arviz(arviz_id, "Arviz", var_names=["~mu"]) assert ("mu" not in chain.samples.columns) and ("sigma" in chain.samples.columns) @@ -101,7 +101,7 @@ def test_numpyro_translator(self): def test_drop_on_numpyro_translator(self): numpyro_mcmc = run_numpyro_mcmc(self.n_steps, self.n_chains) chain = Chain.from_numpyro(numpyro_mcmc, "numpyro", var_names=["mu"]) - assert ("mu" in chain.samples.columns) and not ("sigma" in chain.samples.columns) + assert ("mu" in chain.samples.columns) and "sigma" not in chain.samples.columns chain = Chain.from_numpyro(numpyro_mcmc, "numpyro", var_names=["~mu"]) assert ("mu" not in chain.samples.columns) and ("sigma" in chain.samples.columns)