-
Notifications
You must be signed in to change notification settings - Fork 18
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
adding "var_names" parameter to filter the variable to keep when building chain from numpyro or arviz #122
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,11 +7,13 @@ | |
|
||
There are also a few helper functions and objects in here, like the `MaxPosterior` class which | ||
provides the log posterior and the coordinate at which it can be found for the chain.""" | ||
|
||
from __future__ import annotations | ||
|
||
import logging | ||
from typing import TYPE_CHECKING, Any, TypeAlias | ||
|
||
import arviz as az | ||
import numpy as np | ||
import pandas as pd | ||
from pydantic import Field, field_validator, model_validator | ||
|
@@ -395,40 +397,52 @@ def from_numpyro( | |
cls, | ||
mcmc: numpyro.infer.MCMC, | ||
name: str, | ||
var_names: list[str] = [], | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should probably make this |
||
**kwargs: Any, | ||
) -> Chain: | ||
"""Constructor from numpyro samples | ||
|
||
Args: | ||
mcmc: The numpyro sampler | ||
name: The name of the chain | ||
var_names: The names of the parameters to include in the chain. If the entries of var_names start with ~, | ||
they are excluded from the variables. If empty, all parameters are included. | ||
kwargs: Any other arguments to pass to the Chain constructor. | ||
|
||
Returns: | ||
A ChainConsumer Chain made from numpyro samples | ||
""" | ||
df = pd.DataFrame.from_dict({key: np.ravel(value) for key, value in mcmc.get_samples().items()}) | ||
|
||
var_names = _filter_var_names(var_names, list(mcmc.get_samples().keys())) | ||
df = pd.DataFrame.from_dict( | ||
{key: np.ravel(value) for key, value in mcmc.get_samples().items() if key in var_names} | ||
) | ||
return cls(samples=df, name=name, **kwargs) | ||
|
||
@classmethod | ||
def from_arviz( | ||
cls, | ||
arviz_id: arviz.InferenceData, | ||
name: str, | ||
var_names: list[str] = [], | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As with above :) |
||
**kwargs: Any, | ||
) -> Chain: | ||
"""Constructor from an arviz InferenceData object | ||
|
||
Args: | ||
arviz_id: The arviz inference data | ||
name: The name of the chain | ||
var_names: The names of the parameters to include in the chain. If the entries of var_names start with ~, | ||
they are excluded from the variables. If empty, all parameters are included. | ||
kwargs: Any other arguments to pass to the Chain constructor. | ||
|
||
Returns: | ||
A ChainConsumer Chain made from the arviz chain | ||
""" | ||
|
||
df = arviz_id.to_dataframe(groups="posterior").drop(columns=["chain", "draw"]) | ||
var_names = _filter_var_names(var_names, list(arviz_id.posterior.keys())) | ||
reduced_id = az.extract(arviz_id, var_names=var_names, group="posterior") | ||
df = reduced_id.to_dataframe().drop(columns=["chain", "draw"]) | ||
|
||
return cls(samples=df, name=name, **kwargs) | ||
|
||
|
@@ -444,3 +458,30 @@ class MaxPosterior(BetterBase): | |
def vec_coordinate(self) -> np.ndarray: | ||
"""The coordinate as a numpy array, in the order the columns were given.""" | ||
return np.array(list(self.coordinate.values())) | ||
|
||
|
||
def _filter_var_names(var_names: list[str], all_vars: list[str]): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Missing return type hint |
||
""" | ||
Helper function to return the var_names to allows filtering parameters names. | ||
""" | ||
|
||
if not var_names: | ||
return all_vars | ||
|
||
elif var_names: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Because the primary |
||
if not (all([var.startswith("~") for var in var_names]) or all([not var.startswith("~") for var in var_names])): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Happy with this, but other way we could do it with less duplication is: negations =set([var.startswith("~") for var in var_names])
if len(negations) != 1:
raise...
if True in negations:
...
return blah
...
return blah |
||
raise ValueError( | ||
"all values in var_names must start with ~ to exclude a subset OR none of them to keep a subset" | ||
) | ||
|
||
if all([var.startswith("~") for var in var_names]): | ||
# remove the ~ from the var names | ||
var_names = [var[1:] for var in var_names] | ||
var_names = [var for var in all_vars if var not in var_names] | ||
|
||
return var_names | ||
|
||
else: | ||
# keep var_names as is but check if var is in all_vars | ||
var_names = [var for var in all_vars if var in var_names] | ||
return var_names |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm reading through this now I realise that that arviz is only a dependency in the
test
group, so I'd be tempted to move this import down to a conditional one where we wrap it in atry - catch
.The type hint should stay as
arviz.blah
because that import is only eval'dif TYPE_CHECKING
(so shouldnt be a problem if its not there, only the type hinter will be sad)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I move the import to the from_arviz constructor only, and didn't add the try-catch as it is very unlikely that someone using it do not have arviz