Skip to content

Commit

Permalink
convienent methods for getting inference names and kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
GStechschulte committed Apr 13, 2024
1 parent 4a01c7c commit e3a4393
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 46 deletions.
88 changes: 72 additions & 16 deletions bambi/backend/inference_methods.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,44 @@
import importlib
import inspect
import operator

import pymc as pm


class InferenceMethods:
"""Obtain a dictionary of available inference methods for Bambi
models, and or the kwargs that each inference method accepts.
models and or the default kwargs of each inference method.
"""

def __init__(self):
# In order to access inference methods, a bayeux model must be initialized
self.bayeux_model = bayeux_model()

self.bayeux_methods = self._get_bayeux_methods(bayeux_model())
self.pymc_methods = self._pymc_methods()

def _get_bayeux_methods(self, model):
# Bambi only supports bayeux MCMC methods
mcmc_methods = model.methods.get("mcmc")
return {"mcmc": mcmc_methods}

def _pymc_methods(self):
return {"mcmc": ["mcmc"], "vi": ["vi"]}

def _remove_parameters(self, fn_signature_dict):
# Remove 'pm.sample' parameters that are irrelevant for Bambi users
params_to_remove = [
"progressbar",
"progressbar_theme",
"var_names",
"nuts_sampler",
"return_inferencedata",
"idata_kwargs",
"callback",
"mp_ctx",
"model",
]
return {k: v for k, v in fn_signature_dict.items() if k not in params_to_remove}

def get_kwargs(self, method):
"""Get the default kwargs for a given inference method.
Expand All @@ -23,29 +52,31 @@ def get_kwargs(self, method):
dict
The default kwargs for the inference method.
"""
# TODO: Somehow add the ability to retrieve PyMC kwargs of
# TODO: `pymc.sampling.mcmc.sample`
# Bambi only supports bayeux MCMC methods
if method not in self.bayeux_model.methods["mcmc"]:
if method in self.bayeux_methods.get("mcmc"):
bx_method = operator.attrgetter(method)(
self.bayeux_model.mcmc # pylint: disable=no-member
)
return bx_method.get_kwargs()
elif method in self.pymc_methods.get("mcmc"):
return self._remove_parameters(get_default_signature(pm.sample))
elif method in self.pymc_methods.get("vi"):
return get_default_signature(pm.ADVI.fit)
else:
raise ValueError(
f"Inference method '{method}' not found in the list of available"
" methods"
)
" methods. Use `bmb.inference_methods.names` to list the available methods."
)

bx_method = operator.attrgetter(method)(self.bayeux_model.mcmc)
return bx_method.get_kwargs()

@property
def names(self):
# TODO: Add PyMC MCMC methods
return self.bayeux_model.methods.get("mcmc")
return {"pymc": self.pymc_methods, "bayeux": self.bayeux_methods}


def bayeux_model():
"""Dummy bayeux model for obtaining inference methods.
A dummy model is needed because algorithms are dynamically determined at
runtime, based on the libraries that are installed. A model can give
A dummy model is needed because algorithms are dynamically determined at
runtime, based on the libraries that are installed. A model can give
programmatic access to the available algorithms via the `methods` attribute.
Returns
Expand All @@ -57,7 +88,32 @@ def bayeux_model():
return {"mcmc": []}

import bayeux as bx # pylint: disable=import-outside-toplevel

return bx.Model(lambda x: -(x**2), 0.0)


inference_methods = InferenceMethods()
def get_default_signature(fn):
"""Get the default parameter values of a function.
This function inspects the signature of the provided function and returns
a dictionary containing the default values of its parameters.
Parameters
----------
fn : callable
The function for which default argument values are to be retrieved.
Returns
-------
dict
A dictionary mapping argument names to their default values.
"""
defaults = {}
for key, val in inspect.signature(fn).parameters.items():
if val.default is not inspect.Signature.empty:
defaults[key] = val.default
return defaults


inference_methods = InferenceMethods()
32 changes: 4 additions & 28 deletions bambi/backend/pymc.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import functools
import importlib
import logging
import operator
import traceback
Expand All @@ -14,6 +13,7 @@
import pytensor.tensor as pt
from pytensor.tensor.special import softmax

from bambi.backend.inference_methods import inference_methods
from bambi.backend.links import cloglog, identity, inverse_squared, logit, probit, arctan_2
from bambi.backend.model_components import ConstantComponent, DistributionalComponent
from bambi.utils import get_aliased_name
Expand Down Expand Up @@ -47,8 +47,8 @@ def __init__(self):
self.model = None
self.spec = None
self.components = {}
self.bayeux_methods = _get_bayeux_methods()
self.pymc_methods = {"mcmc": ["mcmc"], "vi": ["vi"]}
self.bayeux_methods = inference_methods.names["bayeux"]
self.pymc_methods = inference_methods.names["pymc"]

def build(self, spec):
"""Compile the PyMC model from an abstract model specification.
Expand Down Expand Up @@ -338,8 +338,7 @@ def _run_laplace(self, draws, omit_offsets, include_mean):
Mainly for pedagogical use, provides reasonable results for approximately
Gaussian posteriors. The approximation can be very poor for some models
like hierarchical ones. Use ``mcmc``, ``vi``, or JAX based MCMC methods
for better approximations.
like hierarchical ones. Use MCMC or VI methods for better approximations.
Parameters
----------
Expand Down Expand Up @@ -388,10 +387,6 @@ def constant_components(self):
def distributional_components(self):
return {k: v for k, v in self.components.items() if isinstance(v, DistributionalComponent)}

@property
def inference_methods(self):
return {"pymc": self.pymc_methods, "bayeux": self.bayeux_methods}


def _posterior_samples_to_idata(samples, model):
"""Create InferenceData from samples.
Expand Down Expand Up @@ -431,22 +426,3 @@ def _posterior_samples_to_idata(samples, model):

idata = pm.to_inference_data(pm.backends.base.MultiTrace([strace]), model=model)
return idata


def _get_bayeux_methods():
"""Gets a dictionary of usable bayeux methods if the bayeux package is installed
within the user's environment.
Returns
-------
dict
A dict where the keys are the module names and the values are the methods
available in that module.
"""
if importlib.util.find_spec("bayeux") is None:
return {"mcmc": []}

import bayeux as bx # pylint: disable=import-outside-toplevel

# Dummy log density to get access to all methods
return bx.Model(lambda x: -(x**2), 0.0).methods
4 changes: 2 additions & 2 deletions bambi/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def fit(
Finally, ``"laplace"``, in which case a Laplace approximation is used and is not
recommended other than for pedagogical use.
To get a list of JAX based inference methods, call
``model.backend.inference_methods['bayeux']``. This will return a dictionary of the
``bmb.inference_methods.names['bayeux']``. This will return a dictionary of the
available methods such as ``blackjax_nuts``, ``numpyro_nuts``, among others.
init : str
Initialization method. Defaults to ``"auto"``. The available methods are:
Expand Down Expand Up @@ -307,7 +307,7 @@ def fit(
-------
An ArviZ ``InferenceData`` instance if inference_method is ``"mcmc"`` (default),
"laplace", or one of the MCMC methods in
``model.backend.inference_methods['bayeux']['mcmc]``.
``bmb.inference_methods.names['bayeux']['mcmc]``.
An ``Approximation`` object if ``"vi"``.
"""
method = kwargs.pop("method", None)
Expand Down

0 comments on commit e3a4393

Please sign in to comment.