-
-
Notifications
You must be signed in to change notification settings - Fork 126
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Convenient function to access inference methods and kwargs (#795)
* add inference_methods class to obtain names of methods and kwargs * re-run notebook * update notebook to include new methods * convienent methods for getting inference names and kwargs * Fix `get_model_covariates()` utility function (#801) * Support PyMC 5.13 and fix bayeux related issues (#803) * run black to fix formatting * add test to check for inference method names * test get_kwargs method of InferenceMethods class --------- Co-authored-by: Tomás Capretto <tomicapretto@gmail.com>
- Loading branch information
1 parent
793be6a
commit b5aefcf
Showing
7 changed files
with
1,286 additions
and
996 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
from .pymc import PyMCModel | ||
from .inference_methods import inference_methods | ||
|
||
__all__ = ["PyMCModel"] | ||
__all__ = ["inference_methods", "PyMCModel"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
import importlib | ||
import inspect | ||
import operator | ||
|
||
import pymc as pm | ||
|
||
|
||
class InferenceMethods: | ||
"""Obtain a dictionary of available inference methods for Bambi | ||
models and or the default kwargs of each inference method. | ||
""" | ||
|
||
def __init__(self): | ||
# In order to access inference methods, a bayeux model must be initialized | ||
self.bayeux_model = bayeux_model() | ||
self.bayeux_methods = self._get_bayeux_methods(bayeux_model()) | ||
self.pymc_methods = self._pymc_methods() | ||
|
||
def _get_bayeux_methods(self, model): | ||
# Bambi only supports bayeux MCMC methods | ||
mcmc_methods = model.methods.get("mcmc") | ||
return {"mcmc": mcmc_methods} | ||
|
||
def _pymc_methods(self): | ||
return {"mcmc": ["mcmc"], "vi": ["vi"]} | ||
|
||
def _remove_parameters(self, fn_signature_dict): | ||
# Remove 'pm.sample' parameters that are irrelevant for Bambi users | ||
params_to_remove = [ | ||
"progressbar", | ||
"progressbar_theme", | ||
"var_names", | ||
"nuts_sampler", | ||
"return_inferencedata", | ||
"idata_kwargs", | ||
"callback", | ||
"mp_ctx", | ||
"model", | ||
] | ||
return {k: v for k, v in fn_signature_dict.items() if k not in params_to_remove} | ||
|
||
def get_kwargs(self, method): | ||
"""Get the default kwargs for a given inference method. | ||
Parameters | ||
---------- | ||
method : str | ||
The name of the inference method. | ||
Returns | ||
------- | ||
dict | ||
The default kwargs for the inference method. | ||
""" | ||
if method in self.bayeux_methods.get("mcmc"): | ||
bx_method = operator.attrgetter(method)( | ||
self.bayeux_model.mcmc # pylint: disable=no-member | ||
) | ||
return bx_method.get_kwargs() | ||
elif method in self.pymc_methods.get("mcmc"): | ||
return self._remove_parameters(get_default_signature(pm.sample)) | ||
elif method in self.pymc_methods.get("vi"): | ||
return get_default_signature(pm.ADVI.fit) | ||
else: | ||
raise ValueError( | ||
f"Inference method '{method}' not found in the list of available" | ||
" methods. Use `bmb.inference_methods.names` to list the available methods." | ||
) | ||
|
||
@property | ||
def names(self): | ||
return {"pymc": self.pymc_methods, "bayeux": self.bayeux_methods} | ||
|
||
|
||
def bayeux_model(): | ||
"""Dummy bayeux model for obtaining inference methods. | ||
A dummy model is needed because algorithms are dynamically determined at | ||
runtime, based on the libraries that are installed. A model can give | ||
programmatic access to the available algorithms via the `methods` attribute. | ||
Returns | ||
------- | ||
bayeux.Model | ||
A dummy model with a simple quadratic likelihood function. | ||
""" | ||
if importlib.util.find_spec("bayeux") is None: | ||
return {"mcmc": []} | ||
|
||
import bayeux as bx # pylint: disable=import-outside-toplevel | ||
|
||
return bx.Model(lambda x: -(x**2), 0.0) | ||
|
||
|
||
def get_default_signature(fn): | ||
"""Get the default parameter values of a function. | ||
This function inspects the signature of the provided function and returns | ||
a dictionary containing the default values of its parameters. | ||
Parameters | ||
---------- | ||
fn : callable | ||
The function for which default argument values are to be retrieved. | ||
Returns | ||
------- | ||
dict | ||
A dictionary mapping argument names to their default values. | ||
""" | ||
defaults = {} | ||
for key, val in inspect.signature(fn).parameters.items(): | ||
if val.default is not inspect.Signature.empty: | ||
defaults[key] = val.default | ||
return defaults | ||
|
||
|
||
inference_methods = InferenceMethods() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.