diff --git a/src/chainconsumer/chain.py b/src/chainconsumer/chain.py index dee1ac28..ad24d721 100644 --- a/src/chainconsumer/chain.py +++ b/src/chainconsumer/chain.py @@ -8,10 +8,9 @@ There are also a few helper functions and objects in here, like the `MaxPosterior` class which provides the log posterior and the coordinate at which it can be found for the chain.""" from __future__ import annotations -from typing import TYPE_CHECKING import logging -from typing import Any, TypeAlias +from typing import TYPE_CHECKING, Any, TypeAlias import numpy as np import pandas as pd @@ -365,12 +364,11 @@ def get_correlation(self, columns: list[str] | None = None) -> Named2DMatrix: @classmethod def from_emcee( - cls, - sampler: "emcee.EnsembleSampler", - columns: list[str], - name: str = 'Chain', - **kwargs: Any, - + cls, + sampler: emcee.EnsembleSampler, + columns: list[str], + name: str = "Chain", + **kwargs: Any, ) -> Chain: """ Constructor from an emcee sampler @@ -382,17 +380,16 @@ def from_emcee( kwargs: Any other arguments to pass to the Chain constructor. """ - df = pd.DataFrame.from_dict({col:val for col, val in zip(columns, sampler.get_chain(flat=True).T)}) + df = pd.DataFrame.from_dict({col: val for col, val in zip(columns, sampler.get_chain(flat=True).T)}) return cls(samples=df, name=name, **kwargs) @classmethod def from_numpyro( - cls, - mcmc: "numpyro.infer.MCMC", - name: str = 'Chain', - **kwargs: Any, - + cls, + mcmc: numpyro.infer.MCMC, + name: str = "Chain", + **kwargs: Any, ) -> Chain: """ Constructor from an emcee sampler @@ -409,11 +406,10 @@ def from_numpyro( @classmethod def from_arviz( - cls, - arviz_id: "arviz.InferenceData", - name: str = 'Chain', - **kwargs: Any, - + cls, + arviz_id: arviz.InferenceData, + name: str = "Chain", + **kwargs: Any, ) -> Chain: """ Constructor from an arviz InferenceData object diff --git a/tests/test_translators.py b/tests/test_translators.py index ad107b0e..a8e15e02 100644 --- a/tests/test_translators.py +++ b/tests/test_translators.py @@ -1,11 +1,11 @@ +import arviz as az +import emcee +import numpy as np import numpyro import numpyro.distributions as dist +from jax import random from numpyro.infer import MCMC, NUTS -import jax.random as random -import emcee -import numpy as np -import scipy.stats as stats -import arviz as az +from scipy import stats from chainconsumer import Chain @@ -20,12 +20,12 @@ def run_numpyro_mcmc(n_steps, n_chains): def model(data=None): # Prior - mu = numpyro.sample('mu', dist.Normal(0, 10)) - sigma = numpyro.sample('sigma', dist.HalfNormal(10)) + mu = numpyro.sample("mu", dist.Normal(0, 10)) + sigma = numpyro.sample("sigma", dist.HalfNormal(10)) # Likelihood - with numpyro.plate('data', size=len(data)): - numpyro.sample('obs', dist.Normal(mu, sigma), obs=data) + with numpyro.plate("data", size=len(data)): + numpyro.sample("obs", dist.Normal(mu, sigma), obs=data) # Running MCMC kernel = NUTS(model) @@ -72,29 +72,25 @@ def log_probability(theta, data): class TestTranslators: - n_steps: int = 2000 n_chains: int = 4 n_params: int = 2 def test_arviz_translator(self): - numpyro_mcmc = run_numpyro_mcmc(self.n_steps, self.n_chains) arviz_id = az.from_numpyro(numpyro_mcmc) chain = Chain.from_arviz(arviz_id) - assert chain.samples.shape == (self.n_steps * self.n_chains, self.n_params + 1) #+1 for weight column + assert chain.samples.shape == (self.n_steps * self.n_chains, self.n_params + 1) # +1 for weight column def test_numpyro_translator(self): - numpyro_mcmc = run_numpyro_mcmc(self.n_steps, self.n_chains) chain = Chain.from_numpyro(numpyro_mcmc) assert chain.samples.shape == (self.n_steps * self.n_chains, self.n_params + 1) def test_emcee_translator(self): - emcee_sampler = run_emcee_mcmc(self.n_steps, self.n_chains) - chain = Chain.from_emcee(emcee_sampler, ['mu', 'sigma']) + chain = Chain.from_emcee(emcee_sampler, ["mu", "sigma"]) assert chain.samples.shape == (self.n_steps * self.n_chains, self.n_params + 1)