Skip to content

Commit

Permalink
oups I didn't see there was a pre-commit setup
Browse files Browse the repository at this point in the history
  • Loading branch information
renecotyfanboy committed Oct 15, 2023
1 parent b52a985 commit 851f436
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 34 deletions.
34 changes: 15 additions & 19 deletions src/chainconsumer/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@
There are also a few helper functions and objects in here, like the `MaxPosterior` class which
provides the log posterior and the coordinate at which it can be found for the chain."""
from __future__ import annotations
from typing import TYPE_CHECKING

import logging
from typing import Any, TypeAlias
from typing import TYPE_CHECKING, Any, TypeAlias

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -365,12 +364,11 @@ def get_correlation(self, columns: list[str] | None = None) -> Named2DMatrix:

@classmethod
def from_emcee(
cls,
sampler: "emcee.EnsembleSampler",
columns: list[str],
name: str = 'Chain',
**kwargs: Any,

cls,
sampler: emcee.EnsembleSampler,
columns: list[str],
name: str = "Chain",
**kwargs: Any,
) -> Chain:
"""
Constructor from an emcee sampler
Expand All @@ -382,17 +380,16 @@ def from_emcee(
kwargs: Any other arguments to pass to the Chain constructor.
"""

df = pd.DataFrame.from_dict({col:val for col, val in zip(columns, sampler.get_chain(flat=True).T)})
df = pd.DataFrame.from_dict({col: val for col, val in zip(columns, sampler.get_chain(flat=True).T)})

return cls(samples=df, name=name, **kwargs)

@classmethod
def from_numpyro(
cls,
mcmc: "numpyro.infer.MCMC",
name: str = 'Chain',
**kwargs: Any,

cls,
mcmc: numpyro.infer.MCMC,
name: str = "Chain",
**kwargs: Any,
) -> Chain:
"""
Constructor from an emcee sampler
Expand All @@ -409,11 +406,10 @@ def from_numpyro(

@classmethod
def from_arviz(
cls,
arviz_id: "arviz.InferenceData",
name: str = 'Chain',
**kwargs: Any,

cls,
arviz_id: arviz.InferenceData,
name: str = "Chain",
**kwargs: Any,
) -> Chain:
"""
Constructor from an arviz InferenceData object
Expand Down
26 changes: 11 additions & 15 deletions tests/test_translators.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import arviz as az
import emcee
import numpy as np
import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro.infer import MCMC, NUTS
import jax.random as random
import emcee
import numpy as np
import scipy.stats as stats
import arviz as az
from scipy import stats

from chainconsumer import Chain

Expand All @@ -20,12 +20,12 @@ def run_numpyro_mcmc(n_steps, n_chains):

def model(data=None):
# Prior
mu = numpyro.sample('mu', dist.Normal(0, 10))
sigma = numpyro.sample('sigma', dist.HalfNormal(10))
mu = numpyro.sample("mu", dist.Normal(0, 10))
sigma = numpyro.sample("sigma", dist.HalfNormal(10))

# Likelihood
with numpyro.plate('data', size=len(data)):
numpyro.sample('obs', dist.Normal(mu, sigma), obs=data)
with numpyro.plate("data", size=len(data)):
numpyro.sample("obs", dist.Normal(mu, sigma), obs=data)

# Running MCMC
kernel = NUTS(model)
Expand Down Expand Up @@ -72,29 +72,25 @@ def log_probability(theta, data):


class TestTranslators:

n_steps: int = 2000
n_chains: int = 4
n_params: int = 2

def test_arviz_translator(self):

numpyro_mcmc = run_numpyro_mcmc(self.n_steps, self.n_chains)
arviz_id = az.from_numpyro(numpyro_mcmc)
chain = Chain.from_arviz(arviz_id)

assert chain.samples.shape == (self.n_steps * self.n_chains, self.n_params + 1) #+1 for weight column
assert chain.samples.shape == (self.n_steps * self.n_chains, self.n_params + 1) # +1 for weight column

def test_numpyro_translator(self):

numpyro_mcmc = run_numpyro_mcmc(self.n_steps, self.n_chains)
chain = Chain.from_numpyro(numpyro_mcmc)

assert chain.samples.shape == (self.n_steps * self.n_chains, self.n_params + 1)

def test_emcee_translator(self):

emcee_sampler = run_emcee_mcmc(self.n_steps, self.n_chains)
chain = Chain.from_emcee(emcee_sampler, ['mu', 'sigma'])
chain = Chain.from_emcee(emcee_sampler, ["mu", "sigma"])

assert chain.samples.shape == (self.n_steps * self.n_chains, self.n_params + 1)

0 comments on commit 851f436

Please sign in to comment.