diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e82621c5..0ab70bbf 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.5.0 + rev: v4.6.0 hooks: - id: check-added-large-files args: ["--maxkb=5000"] @@ -14,7 +14,7 @@ repos: - --unsafe - id: trailing-whitespace - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.1.14 + rev: v0.3.7 hooks: - id: ruff args: ["--fix", "--no-unsafe-fixes"] diff --git a/docs/examples/advanced_examples/plot_0_grid.py b/docs/examples/advanced_examples/plot_0_grid.py index a58697b8..1b45d674 100644 --- a/docs/examples/advanced_examples/plot_0_grid.py +++ b/docs/examples/advanced_examples/plot_0_grid.py @@ -15,6 +15,7 @@ the order of the parameters is not preserved in the dictionary. """ + import numpy as np import pandas as pd diff --git a/docs/examples/advanced_examples/plot_1_blinding.py b/docs/examples/advanced_examples/plot_1_blinding.py index b03cbdef..555f6371 100644 --- a/docs/examples/advanced_examples/plot_1_blinding.py +++ b/docs/examples/advanced_examples/plot_1_blinding.py @@ -8,6 +8,7 @@ or give it a string (or list of strings) detailing the specific parameters you want blinded! """ + from chainconsumer import Chain, ChainConsumer, PlotConfig, make_sample df = make_sample(num_dimensions=4, seed=1) diff --git a/docs/examples/advanced_examples/plot_2_kde.py b/docs/examples/advanced_examples/plot_2_kde.py index 6e427122..a0a91e58 100644 --- a/docs/examples/advanced_examples/plot_2_kde.py +++ b/docs/examples/advanced_examples/plot_2_kde.py @@ -15,6 +15,7 @@ increses the width of the marginal distributions. """ + from chainconsumer import Chain, ChainConsumer, PlotConfig, make_sample df = make_sample(num_dimensions=2, seed=3, num_points=1000) diff --git a/docs/examples/advanced_examples/plot_3_divide_chains.py b/docs/examples/advanced_examples/plot_3_divide_chains.py index c4d9b4d8..4d0dc7bb 100644 --- a/docs/examples/advanced_examples/plot_3_divide_chains.py +++ b/docs/examples/advanced_examples/plot_3_divide_chains.py @@ -12,6 +12,7 @@ In this toy example, all the chains are from the same random generator, so they're on top of each other. Except MCMC chains to not be as perfect. """ + from chainconsumer import Chain, ChainConsumer, PlotConfig, make_sample df = make_sample(num_dimensions=2, seed=3, num_points=40000) diff --git a/docs/examples/advanced_examples/plot_4_misc_chain_visuals.py b/docs/examples/advanced_examples/plot_4_misc_chain_visuals.py index 7f55f4de..dff18149 100644 --- a/docs/examples/advanced_examples/plot_4_misc_chain_visuals.py +++ b/docs/examples/advanced_examples/plot_4_misc_chain_visuals.py @@ -3,6 +3,7 @@ Rather than having one example for each option, let's condense things. """ + # %% # Shade Gradient # -------------- diff --git a/docs/examples/plot_0_contours.py b/docs/examples/plot_0_contours.py index 85995424..a44a5d35 100644 --- a/docs/examples/plot_0_contours.py +++ b/docs/examples/plot_0_contours.py @@ -5,6 +5,7 @@ handle the defaults and display. """ + from chainconsumer import Chain, ChainConfig, ChainConsumer, PlotConfig, Truth, make_sample # Here's what you might start with diff --git a/docs/examples/plot_1_summary.py b/docs/examples/plot_1_summary.py index 35e84b87..150e2aef 100644 --- a/docs/examples/plot_1_summary.py +++ b/docs/examples/plot_1_summary.py @@ -7,6 +7,7 @@ To show you how they work, let's make some sample data that all has the same average. """ + from chainconsumer import Chain, ChainConfig, ChainConsumer, PlotConfig, Truth, make_sample # Here's what you might start with diff --git a/docs/examples/plot_2_textual_output.py b/docs/examples/plot_2_textual_output.py index 8606c70e..b2cfc81d 100644 --- a/docs/examples/plot_2_textual_output.py +++ b/docs/examples/plot_2_textual_output.py @@ -4,6 +4,7 @@ Because typing those things out is a **massive pain in the ass.** """ + from chainconsumer import Chain, ChainConsumer, Truth, make_sample # Here's a sample dataset diff --git a/docs/examples/plot_3_distributions.py b/docs/examples/plot_3_distributions.py index 36ac4875..38ee4512 100644 --- a/docs/examples/plot_3_distributions.py +++ b/docs/examples/plot_3_distributions.py @@ -7,6 +7,7 @@ To show you how they work, let's make some sample data that all has the same average. """ + from chainconsumer import Chain, ChainConsumer, Truth, make_sample # Here's what you might start with diff --git a/docs/examples/plot_4_walks.py b/docs/examples/plot_4_walks.py index d4d4bf75..40352a3f 100644 --- a/docs/examples/plot_4_walks.py +++ b/docs/examples/plot_4_walks.py @@ -3,6 +3,7 @@ Want to see if your chain is behaving nicely? Use a walk! """ + from chainconsumer import Chain, ChainConsumer, Truth, make_sample # Here's a sample dataset diff --git a/docs/examples/plot_6_custom_axes.py b/docs/examples/plot_6_custom_axes.py index 6ff1f8d9..026816e0 100644 --- a/docs/examples/plot_6_custom_axes.py +++ b/docs/examples/plot_6_custom_axes.py @@ -8,6 +8,7 @@ Here's an example, noting that there are also `plot_point`, `plot_surface` available that I haven't explicitly shown. """ + import matplotlib.pyplot as plt from chainconsumer import Chain, Truth, make_sample diff --git a/src/chainconsumer/analysis.py b/src/chainconsumer/analysis.py index da67b11e..0e81688c 100644 --- a/src/chainconsumer/analysis.py +++ b/src/chainconsumer/analysis.py @@ -26,9 +26,9 @@ class Bound(BetterBase): def array(self) -> np.ndarray: return np.array( [ - self.lower if self.lower is not None else np.NaN, - self.center if self.center is not None else np.NaN, - self.upper if self.upper is not None else np.NaN, + self.lower if self.lower is not None else np.nan, + self.center if self.center is not None else np.nan, + self.upper if self.upper is not None else np.nan, ] ) diff --git a/src/chainconsumer/chain.py b/src/chainconsumer/chain.py index c4f7e64b..1321ae64 100644 --- a/src/chainconsumer/chain.py +++ b/src/chainconsumer/chain.py @@ -7,6 +7,7 @@ There are also a few helper functions and objects in here, like the `MaxPosterior` class which provides the log posterior and the coordinate at which it can be found for the chain.""" + from __future__ import annotations import logging @@ -395,6 +396,7 @@ def from_numpyro( cls, mcmc: numpyro.infer.MCMC, name: str, + var_names: list[str] | None = None, **kwargs: Any, ) -> Chain: """Constructor from numpyro samples @@ -402,12 +404,18 @@ def from_numpyro( Args: mcmc: The numpyro sampler name: The name of the chain + var_names: The names of the parameters to include in the chain. If the entries of var_names start with ~, + they are excluded from the variables. If empty, all parameters are included. kwargs: Any other arguments to pass to the Chain constructor. Returns: A ChainConsumer Chain made from numpyro samples """ - df = pd.DataFrame.from_dict({key: np.ravel(value) for key, value in mcmc.get_samples().items()}) + + var_names = _filter_var_names(var_names, list(mcmc.get_samples().keys())) + df = pd.DataFrame.from_dict( + {key: np.ravel(value) for key, value in mcmc.get_samples().items() if key in var_names} + ) return cls(samples=df, name=name, **kwargs) @classmethod @@ -415,6 +423,7 @@ def from_arviz( cls, arviz_id: arviz.InferenceData, name: str, + var_names: list[str] | None = None, **kwargs: Any, ) -> Chain: """Constructor from an arviz InferenceData object @@ -422,13 +431,19 @@ def from_arviz( Args: arviz_id: The arviz inference data name: The name of the chain + var_names: The names of the parameters to include in the chain. If the entries of var_names start with ~, + they are excluded from the variables. If empty, all parameters are included. kwargs: Any other arguments to pass to the Chain constructor. Returns: A ChainConsumer Chain made from the arviz chain """ - df = arviz_id.to_dataframe(groups="posterior").drop(columns=["chain", "draw"]) + import arviz as az + + var_names = _filter_var_names(var_names, list(arviz_id.posterior.keys())) + reduced_id = az.extract(arviz_id, var_names=var_names, group="posterior") + df = reduced_id.to_dataframe().drop(columns=["chain", "draw"]) return cls(samples=df, name=name, **kwargs) @@ -444,3 +459,31 @@ class MaxPosterior(BetterBase): def vec_coordinate(self) -> np.ndarray: """The coordinate as a numpy array, in the order the columns were given.""" return np.array(list(self.coordinate.values())) + + +def _filter_var_names(var_names: list[str] | None, all_vars: list[str]) -> list[str]: + """ + Helper function to return the var_names to allows filtering parameters names. + """ + + if var_names is None: + return all_vars + + negations = set([var.startswith("~") for var in var_names]) + + if len(negations) != 1: + raise ValueError( + "all values in var_names must start with ~ to exclude a subset OR none of them to keep a subset" + ) + + if True in negations: + # remove the ~ from the var names + var_names = [var[1:] for var in var_names] + var_names = [var for var in all_vars if var not in var_names] + + return var_names + + else: + # keep var_names as is but check if var is in all_vars + var_names = [var for var in all_vars if var in var_names] + return var_names diff --git a/tests/test_translators.py b/tests/test_translators.py index 7439ada8..318812a4 100644 --- a/tests/test_translators.py +++ b/tests/test_translators.py @@ -83,12 +83,29 @@ def test_arviz_translator(self): assert chain.samples.shape == (self.n_steps * self.n_chains, self.n_params + 1) # +1 for weight column + def test_drop_on_arviz_translator(self): + numpyro_mcmc = run_numpyro_mcmc(self.n_steps, self.n_chains) + arviz_id = az.from_numpyro(numpyro_mcmc) + chain = Chain.from_arviz(arviz_id, "Arviz", var_names=["mu"]) + assert ("mu" in chain.samples.columns) and "sigma" not in chain.samples.columns + + chain = Chain.from_arviz(arviz_id, "Arviz", var_names=["~mu"]) + assert ("mu" not in chain.samples.columns) and ("sigma" in chain.samples.columns) + def test_numpyro_translator(self): numpyro_mcmc = run_numpyro_mcmc(self.n_steps, self.n_chains) chain = Chain.from_numpyro(numpyro_mcmc, "numpyro") assert chain.samples.shape == (self.n_steps * self.n_chains, self.n_params + 1) + def test_drop_on_numpyro_translator(self): + numpyro_mcmc = run_numpyro_mcmc(self.n_steps, self.n_chains) + chain = Chain.from_numpyro(numpyro_mcmc, "numpyro", var_names=["mu"]) + assert ("mu" in chain.samples.columns) and "sigma" not in chain.samples.columns + + chain = Chain.from_numpyro(numpyro_mcmc, "numpyro", var_names=["~mu"]) + assert ("mu" not in chain.samples.columns) and ("sigma" in chain.samples.columns) + def test_emcee_translator(self): emcee_sampler = run_emcee_mcmc(self.n_steps, self.n_chains) chain = Chain.from_emcee(emcee_sampler, ["mu", "sigma"], "emcee")