-
Notifications
You must be signed in to change notification settings - Fork 18
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
adding "var_names" parameter to filter the variable to keep when building chain from numpyro or arviz #122
Conversation
a88d0e8
to
6f5d333
Compare
It seems that the pre-commit hooks are not applying the same rules in the action and locally, I have no warnings about ruff formatter when running it locally. I will check this tomorrow |
This seems to be a weird issue related to an update of git itself |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for putting this code together, mostly minor thoughts and nits :)
src/chainconsumer/chain.py
Outdated
from __future__ import annotations | ||
|
||
import logging | ||
from typing import TYPE_CHECKING, Any, TypeAlias | ||
|
||
import arviz as az |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm reading through this now I realise that that arviz is only a dependency in the test
group, so I'd be tempted to move this import down to a conditional one where we wrap it in a try - catch
.
The type hint should stay as arviz.blah
because that import is only eval'd if TYPE_CHECKING
(so shouldnt be a problem if its not there, only the type hinter will be sad)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I move the import to the from_arviz constructor only, and didn't add the try-catch as it is very unlikely that someone using it do not have arviz
src/chainconsumer/chain.py
Outdated
@@ -395,40 +397,52 @@ def from_numpyro( | |||
cls, | |||
mcmc: numpyro.infer.MCMC, | |||
name: str, | |||
var_names: list[str] = [], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should probably make this list[str] | None = None
or similar to ensure we dont have mutable kwarg defaults
src/chainconsumer/chain.py
Outdated
return cls(samples=df, name=name, **kwargs) | ||
|
||
@classmethod | ||
def from_arviz( | ||
cls, | ||
arviz_id: arviz.InferenceData, | ||
name: str, | ||
var_names: list[str] = [], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As with above :)
src/chainconsumer/chain.py
Outdated
@@ -444,3 +458,30 @@ class MaxPosterior(BetterBase): | |||
def vec_coordinate(self) -> np.ndarray: | |||
"""The coordinate as a numpy array, in the order the columns were given.""" | |||
return np.array(list(self.coordinate.values())) | |||
|
|||
|
|||
def _filter_var_names(var_names: list[str], all_vars: list[str]): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing return type hint
src/chainconsumer/chain.py
Outdated
if not var_names: | ||
return all_vars | ||
|
||
elif var_names: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Because the primary if
has a return, better to simply unindent and remove the elif
- will keep the code simpler and easier to read
src/chainconsumer/chain.py
Outdated
return all_vars | ||
|
||
elif var_names: | ||
if not (all([var.startswith("~") for var in var_names]) or all([not var.startswith("~") for var in var_names])): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Happy with this, but other way we could do it with less duplication is:
negations =set([var.startswith("~") for var in var_names])
if len(negations) != 1:
raise...
if True in negations:
...
return blah
...
return blah
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good :)
I've added a convenience parameter in the Chain factories from_numpyro and from_arviz in the same fashion as in the arviz package. The user can either choose to keep parameters by providing a list or discard some by providing a list of parameters starting with ~.
WDYT ?