Skip to content

Commit

Permalink
mandatory linting
Browse files Browse the repository at this point in the history
  • Loading branch information
renecotyfanboy committed Apr 15, 2024
1 parent 051fca9 commit a88d0e8
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 11 deletions.
22 changes: 13 additions & 9 deletions src/chainconsumer/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any, TypeAlias, Optional, List
from typing import TYPE_CHECKING, Any, Optional, TypeAlias

import arviz as az
import numpy as np
import pandas as pd
import arviz as az
from pydantic import Field, field_validator, model_validator

from .base import BetterBase
Expand Down Expand Up @@ -396,7 +396,7 @@ def from_numpyro(
cls,
mcmc: numpyro.infer.MCMC,
name: str,
var_names: Optional[List[str]] = [],
var_names: Optional[list[str]] = [],
**kwargs: Any,
) -> Chain:
"""Constructor from numpyro samples
Expand All @@ -413,15 +413,17 @@ def from_numpyro(
"""

var_names = _filter_var_names(var_names, list(mcmc.get_samples().keys()))
df = pd.DataFrame.from_dict({key: np.ravel(value) for key, value in mcmc.get_samples().items() if key in var_names})
df = pd.DataFrame.from_dict(
{key: np.ravel(value) for key, value in mcmc.get_samples().items() if key in var_names}
)
return cls(samples=df, name=name, **kwargs)

@classmethod
def from_arviz(
cls,
arviz_id: arviz.InferenceData,
name: str,
var_names: Optional[List[str]] = [],
var_names: Optional[list[str]] = [],
**kwargs: Any,
) -> Chain:
"""Constructor from an arviz InferenceData object
Expand Down Expand Up @@ -457,7 +459,7 @@ def vec_coordinate(self) -> np.ndarray:
return np.array(list(self.coordinate.values()))


def _filter_var_names(var_names: List[str], all_vars: List[str]):
def _filter_var_names(var_names: list[str], all_vars: list[str]):
"""
Helper function to return the var_names to allows filtering parameters names.
"""
Expand All @@ -466,10 +468,12 @@ def _filter_var_names(var_names: List[str], all_vars: List[str]):
return all_vars

elif var_names:
if not (all([var.startswith('~') for var in var_names]) or all([not var.startswith('~') for var in var_names])):
raise ValueError("all values in var_names must start with ~ to exclude a subset OR none of them to keep a subset")
if not (all([var.startswith("~") for var in var_names]) or all([not var.startswith("~") for var in var_names])):
raise ValueError(
"all values in var_names must start with ~ to exclude a subset OR none of them to keep a subset"
)

if all([var.startswith('~') for var in var_names]):
if all([var.startswith("~") for var in var_names]):
# remove the ~ from the var names
var_names = [var[1:] for var in var_names]
var_names = [var for var in all_vars if var not in var_names]
Expand Down
4 changes: 2 additions & 2 deletions tests/test_translators.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def test_drop_on_arviz_translator(self):
numpyro_mcmc = run_numpyro_mcmc(self.n_steps, self.n_chains)
arviz_id = az.from_numpyro(numpyro_mcmc)
chain = Chain.from_arviz(arviz_id, "Arviz", var_names=["mu"])
assert ("mu" in chain.samples.columns) and not ("sigma" in chain.samples.columns)
assert ("mu" in chain.samples.columns) and "sigma" not in chain.samples.columns

chain = Chain.from_arviz(arviz_id, "Arviz", var_names=["~mu"])
assert ("mu" not in chain.samples.columns) and ("sigma" in chain.samples.columns)
Expand All @@ -101,7 +101,7 @@ def test_numpyro_translator(self):
def test_drop_on_numpyro_translator(self):
numpyro_mcmc = run_numpyro_mcmc(self.n_steps, self.n_chains)
chain = Chain.from_numpyro(numpyro_mcmc, "numpyro", var_names=["mu"])
assert ("mu" in chain.samples.columns) and not ("sigma" in chain.samples.columns)
assert ("mu" in chain.samples.columns) and "sigma" not in chain.samples.columns

chain = Chain.from_numpyro(numpyro_mcmc, "numpyro", var_names=["~mu"])
assert ("mu" not in chain.samples.columns) and ("sigma" in chain.samples.columns)
Expand Down

0 comments on commit a88d0e8

Please sign in to comment.