Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding "var_names" parameter to filter the variable to keep when building chain from numpyro or arviz #122

Merged
merged 3 commits into from
Apr 19, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
rev: v4.6.0
hooks:
- id: check-added-large-files
args: ["--maxkb=5000"]
Expand All @@ -14,7 +14,7 @@ repos:
- --unsafe
- id: trailing-whitespace
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.14
rev: v0.3.7
hooks:
- id: ruff
args: ["--fix", "--no-unsafe-fixes"]
Expand Down
1 change: 1 addition & 0 deletions docs/examples/advanced_examples/plot_0_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
the order of the parameters is not preserved in the dictionary.

"""

import numpy as np
import pandas as pd

Expand Down
1 change: 1 addition & 0 deletions docs/examples/advanced_examples/plot_1_blinding.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
or give it a string (or list of strings) detailing the specific parameters you want blinded!

"""

from chainconsumer import Chain, ChainConsumer, PlotConfig, make_sample

df = make_sample(num_dimensions=4, seed=1)
Expand Down
1 change: 1 addition & 0 deletions docs/examples/advanced_examples/plot_2_kde.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
increses the width of the marginal distributions.

"""

from chainconsumer import Chain, ChainConsumer, PlotConfig, make_sample

df = make_sample(num_dimensions=2, seed=3, num_points=1000)
Expand Down
1 change: 1 addition & 0 deletions docs/examples/advanced_examples/plot_3_divide_chains.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
In this toy example, all the chains are from the same random generator,
so they're on top of each other. Except MCMC chains to not be as perfect.
"""

from chainconsumer import Chain, ChainConsumer, PlotConfig, make_sample

df = make_sample(num_dimensions=2, seed=3, num_points=40000)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

Rather than having one example for each option, let's condense things.
"""

# %%
# Shade Gradient
# --------------
Expand Down
1 change: 1 addition & 0 deletions docs/examples/plot_0_contours.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
handle the defaults and display.

"""

from chainconsumer import Chain, ChainConfig, ChainConsumer, PlotConfig, Truth, make_sample

# Here's what you might start with
Expand Down
1 change: 1 addition & 0 deletions docs/examples/plot_1_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
To show you how they work, let's make some sample data that all
has the same average.
"""

from chainconsumer import Chain, ChainConfig, ChainConsumer, PlotConfig, Truth, make_sample

# Here's what you might start with
Expand Down
1 change: 1 addition & 0 deletions docs/examples/plot_2_textual_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
Because typing those things out is a **massive pain in the ass.**

"""

from chainconsumer import Chain, ChainConsumer, Truth, make_sample

# Here's a sample dataset
Expand Down
1 change: 1 addition & 0 deletions docs/examples/plot_3_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
To show you how they work, let's make some sample data that all
has the same average.
"""

from chainconsumer import Chain, ChainConsumer, Truth, make_sample

# Here's what you might start with
Expand Down
1 change: 1 addition & 0 deletions docs/examples/plot_4_walks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

Want to see if your chain is behaving nicely? Use a walk!
"""

from chainconsumer import Chain, ChainConsumer, Truth, make_sample

# Here's a sample dataset
Expand Down
1 change: 1 addition & 0 deletions docs/examples/plot_6_custom_axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
Here's an example, noting that there are also `plot_point`, `plot_surface` available
that I haven't explicitly shown.
"""

import matplotlib.pyplot as plt

from chainconsumer import Chain, Truth, make_sample
Expand Down
6 changes: 3 additions & 3 deletions src/chainconsumer/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ class Bound(BetterBase):
def array(self) -> np.ndarray:
return np.array(
[
self.lower if self.lower is not None else np.NaN,
self.center if self.center is not None else np.NaN,
self.upper if self.upper is not None else np.NaN,
self.lower if self.lower is not None else np.nan,
self.center if self.center is not None else np.nan,
self.upper if self.upper is not None else np.nan,
]
)

Expand Down
45 changes: 43 additions & 2 deletions src/chainconsumer/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@

There are also a few helper functions and objects in here, like the `MaxPosterior` class which
provides the log posterior and the coordinate at which it can be found for the chain."""

from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any, TypeAlias

import arviz as az
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm reading through this now I realise that that arviz is only a dependency in the test group, so I'd be tempted to move this import down to a conditional one where we wrap it in a try - catch.

The type hint should stay as arviz.blah because that import is only eval'd if TYPE_CHECKING (so shouldnt be a problem if its not there, only the type hinter will be sad)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I move the import to the from_arviz constructor only, and didn't add the try-catch as it is very unlikely that someone using it do not have arviz

import numpy as np
import pandas as pd
from pydantic import Field, field_validator, model_validator
Expand Down Expand Up @@ -395,40 +397,52 @@ def from_numpyro(
cls,
mcmc: numpyro.infer.MCMC,
name: str,
var_names: list[str] = [],
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably make this list[str] | None = None or similar to ensure we dont have mutable kwarg defaults

**kwargs: Any,
) -> Chain:
"""Constructor from numpyro samples

Args:
mcmc: The numpyro sampler
name: The name of the chain
var_names: The names of the parameters to include in the chain. If the entries of var_names start with ~,
they are excluded from the variables. If empty, all parameters are included.
kwargs: Any other arguments to pass to the Chain constructor.

Returns:
A ChainConsumer Chain made from numpyro samples
"""
df = pd.DataFrame.from_dict({key: np.ravel(value) for key, value in mcmc.get_samples().items()})

var_names = _filter_var_names(var_names, list(mcmc.get_samples().keys()))
df = pd.DataFrame.from_dict(
{key: np.ravel(value) for key, value in mcmc.get_samples().items() if key in var_names}
)
return cls(samples=df, name=name, **kwargs)

@classmethod
def from_arviz(
cls,
arviz_id: arviz.InferenceData,
name: str,
var_names: list[str] = [],
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As with above :)

**kwargs: Any,
) -> Chain:
"""Constructor from an arviz InferenceData object

Args:
arviz_id: The arviz inference data
name: The name of the chain
var_names: The names of the parameters to include in the chain. If the entries of var_names start with ~,
they are excluded from the variables. If empty, all parameters are included.
kwargs: Any other arguments to pass to the Chain constructor.

Returns:
A ChainConsumer Chain made from the arviz chain
"""

df = arviz_id.to_dataframe(groups="posterior").drop(columns=["chain", "draw"])
var_names = _filter_var_names(var_names, list(arviz_id.posterior.keys()))
reduced_id = az.extract(arviz_id, var_names=var_names, group="posterior")
df = reduced_id.to_dataframe().drop(columns=["chain", "draw"])

return cls(samples=df, name=name, **kwargs)

Expand All @@ -444,3 +458,30 @@ class MaxPosterior(BetterBase):
def vec_coordinate(self) -> np.ndarray:
"""The coordinate as a numpy array, in the order the columns were given."""
return np.array(list(self.coordinate.values()))


def _filter_var_names(var_names: list[str], all_vars: list[str]):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing return type hint

"""
Helper function to return the var_names to allows filtering parameters names.
"""

if not var_names:
return all_vars

elif var_names:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because the primary if has a return, better to simply unindent and remove the elif - will keep the code simpler and easier to read

if not (all([var.startswith("~") for var in var_names]) or all([not var.startswith("~") for var in var_names])):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Happy with this, but other way we could do it with less duplication is:

negations =set([var.startswith("~") for var in var_names])
if len(negations) != 1:
    raise...

if True in negations:
    ...
    return blah
...
return blah

raise ValueError(
"all values in var_names must start with ~ to exclude a subset OR none of them to keep a subset"
)

if all([var.startswith("~") for var in var_names]):
# remove the ~ from the var names
var_names = [var[1:] for var in var_names]
var_names = [var for var in all_vars if var not in var_names]

return var_names

else:
# keep var_names as is but check if var is in all_vars
var_names = [var for var in all_vars if var in var_names]
return var_names
17 changes: 17 additions & 0 deletions tests/test_translators.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,29 @@ def test_arviz_translator(self):

assert chain.samples.shape == (self.n_steps * self.n_chains, self.n_params + 1) # +1 for weight column

def test_drop_on_arviz_translator(self):
numpyro_mcmc = run_numpyro_mcmc(self.n_steps, self.n_chains)
arviz_id = az.from_numpyro(numpyro_mcmc)
chain = Chain.from_arviz(arviz_id, "Arviz", var_names=["mu"])
assert ("mu" in chain.samples.columns) and "sigma" not in chain.samples.columns

chain = Chain.from_arviz(arviz_id, "Arviz", var_names=["~mu"])
assert ("mu" not in chain.samples.columns) and ("sigma" in chain.samples.columns)

def test_numpyro_translator(self):
numpyro_mcmc = run_numpyro_mcmc(self.n_steps, self.n_chains)
chain = Chain.from_numpyro(numpyro_mcmc, "numpyro")

assert chain.samples.shape == (self.n_steps * self.n_chains, self.n_params + 1)

def test_drop_on_numpyro_translator(self):
numpyro_mcmc = run_numpyro_mcmc(self.n_steps, self.n_chains)
chain = Chain.from_numpyro(numpyro_mcmc, "numpyro", var_names=["mu"])
assert ("mu" in chain.samples.columns) and "sigma" not in chain.samples.columns

chain = Chain.from_numpyro(numpyro_mcmc, "numpyro", var_names=["~mu"])
assert ("mu" not in chain.samples.columns) and ("sigma" in chain.samples.columns)

def test_emcee_translator(self):
emcee_sampler = run_emcee_mcmc(self.n_steps, self.n_chains)
chain = Chain.from_emcee(emcee_sampler, ["mu", "sigma"], "emcee")
Expand Down
Loading