adding "var_names" parameter to filter the variable to keep when building chain from numpyro or arviz #122

Merged
merged 3 commits into from
Apr 19, 2024

Conversation

renecotyfanboy
Collaborator

I've added a convenience parameter in the Chain factories from_numpyro and from_arviz in the same fashion as in the arviz package. The user can either choose to keep parameters by providing a list or discard some by providing a list of parameters starting with ~.

WDYT ?

@renecotyfanboy
Collaborator Author

It seems that the pre-commit hooks are not applying the same rules in the action and locally: I get no warnings from the ruff formatter when running them locally. I will check this tomorrow.

@renecotyfanboy
Collaborator Author

This seems to be a weird issue related to an update of git itself

Owner

@Samreay Samreay left a comment

Thanks for putting this code together, mostly minor thoughts and nits :)

from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any, TypeAlias

import arviz as az
Owner

Hmm, reading through this now I realise that arviz is only a dependency in the test group, so I'd be tempted to move this import down to a conditional one where we wrap it in a try/except.

The type hint should stay as arviz.blah because that import is only evaluated if TYPE_CHECKING (so it shouldn't be a problem if it's not there; only the type checker will be sad).

Collaborator Author

I moved the import into the from_arviz constructor only, and didn't add the try/except, as it is very unlikely that someone using it does not have arviz installed.

@@ -395,40 +397,52 @@ def from_numpyro(
        cls,
        mcmc: numpyro.infer.MCMC,
        name: str,
        var_names: list[str] = [],
Owner

We should probably make this list[str] | None = None or similar to ensure we don't have mutable kwarg defaults.
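
As a reminder of why mutable defaults are risky, here is a minimal sketch with two hypothetical functions (neither is from the PR). The code under review never mutates `var_names`, so the problem would only be latent there, but the `None` sentinel removes the risk entirely.

```python
from typing import Optional


def keep_shared(name: str, var_names: list = []) -> list:
    # The default list is created once, at definition time,
    # and is shared by every call that omits var_names.
    var_names.append(name)
    return var_names


def keep_fresh(name: str, var_names: Optional[list] = None) -> list:
    # A None sentinel gives each call its own list.
    var_names = [] if var_names is None else list(var_names)
    var_names.append(name)
    return var_names
```

Calling `keep_shared("a")` and then `keep_shared("b")` returns `["a", "b"]` on the second call, whereas `keep_fresh("b")` returns `["b"]` regardless of earlier calls.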

        return cls(samples=df, name=name, **kwargs)

    @classmethod
    def from_arviz(
        cls,
        arviz_id: arviz.InferenceData,
        name: str,
        var_names: list[str] = [],
Owner

As with above :)

@@ -444,3 +458,30 @@ class MaxPosterior(BetterBase):
    def vec_coordinate(self) -> np.ndarray:
        """The coordinate as a numpy array, in the order the columns were given."""
        return np.array(list(self.coordinate.values()))


def _filter_var_names(var_names: list[str], all_vars: list[str]):
Owner

Missing return type hint

    if not var_names:
        return all_vars

    elif var_names:
Owner

Because the primary if has a return, it's better to simply unindent and remove the elif. That will keep the code simpler and easier to read.

        return all_vars

    elif var_names:
        if not (all([var.startswith("~") for var in var_names]) or all([not var.startswith("~") for var in var_names])):
Owner

Happy with this, but another way we could do it, with less duplication, is:

negations = set(var.startswith("~") for var in var_names)
if len(negations) != 1:
    raise...

if True in negations:
    ...
    return blah
...
return blah
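
Filled out, the suggestion above might look like the following sketch. It is not the merged implementation: the function name, the error message, and the choice to preserve the order of `all_vars` are assumptions made for illustration. It also folds in the earlier review points (return type hint, `None` default handling, no `elif` after an early return).

```python
from typing import Optional


def filter_var_names(var_names: Optional[list[str]], all_vars: list[str]) -> list[str]:
    """Keep or discard variables, using a leading '~' to mean 'exclude'."""
    if not var_names:
        return list(all_vars)

    # One boolean per name: True if it is a negation ("~name").
    negations = {var.startswith("~") for var in var_names}
    if len(negations) != 1:
        # Assumed error type and message; mixing both styles is ambiguous.
        raise ValueError("var_names must either all start with '~' or none of them")

    if True in negations:
        excluded = {var[1:] for var in var_names}
        return [var for var in all_vars if var not in excluded]

    return [var for var in all_vars if var in var_names]
```

With `all_vars = ["a", "b", "c"]`, passing `["a", "c"]` keeps those two variables, passing `["~b"]` gives the same result by exclusion, and mixing the two styles raises a `ValueError`.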

Owner

@Samreay Samreay left a comment

Looks good :)

@Samreay Samreay merged commit 97daf4e into master Apr 19, 2024
3 checks passed
Development

Successfully merging this pull request may close these issues.

Chain straight from dict / chain from numpyro with only certain fields
2 participants