Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convenient function to access inference methods and kwargs #795

Merged
merged 9 commits into from
Apr 15, 2024
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
### Maintenance and fixes

* Fix bug in predictions with models using HSGP (#780)
* Fix `get_model_covariates()` utility function (#801)
* Upgrade PyMC dependency to >= 5.13 (#803)
* Use `pm.compute_deterministics()` to compute deterministics when bayeux based samplers are used (#803)

### Documentation

Expand Down
3 changes: 2 additions & 1 deletion bambi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from pymc import math

from .backend import PyMCModel
from .backend import inference_methods, PyMCModel
from .config import config
from .data import clear_data_home, load_data
from .families import Family, Likelihood, Link
Expand All @@ -25,6 +25,7 @@
"Formula",
"clear_data_home",
"config",
"inference_methods",
"load_data",
"math",
]
Expand Down
3 changes: 2 additions & 1 deletion bambi/backend/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .pymc import PyMCModel
from .inference_methods import inference_methods

__all__ = ["PyMCModel"]
__all__ = ["inference_methods", "PyMCModel"]
119 changes: 119 additions & 0 deletions bambi/backend/inference_methods.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import importlib
import inspect
import operator

import pymc as pm


class InferenceMethods:
    """Obtain a dictionary of available inference methods for Bambi
    models and/or the default kwargs of each inference method.
    """

    def __init__(self):
        # In order to access inference methods, a bayeux model must be initialized.
        # Build it once and reuse the same instance for method discovery instead of
        # constructing a second throwaway model.
        self.bayeux_model = bayeux_model()
        self.bayeux_methods = self._get_bayeux_methods(self.bayeux_model)
        self.pymc_methods = self._pymc_methods()

    def _get_bayeux_methods(self, model):
        """Return the bayeux MCMC methods available for the given model.

        Parameters
        ----------
        model : bayeux.Model or dict
            The dummy bayeux model, or the fallback dict produced by
            ``bayeux_model()`` when bayeux is not installed.

        Returns
        -------
        dict
            Mapping with a single "mcmc" key listing the available methods.
        """
        # When bayeux is not installed, `bayeux_model()` returns a plain dict
        # ({"mcmc": []}) that has no `methods` attribute. Return it as-is
        # instead of raising an AttributeError.
        if isinstance(model, dict):
            return model
        # Bambi only supports bayeux MCMC methods
        mcmc_methods = model.methods.get("mcmc")
        return {"mcmc": mcmc_methods}

    def _pymc_methods(self):
        # The PyMC backend exposes one MCMC sampler and one VI method.
        return {"mcmc": ["mcmc"], "vi": ["vi"]}

    def _remove_parameters(self, fn_signature_dict):
        # Remove 'pm.sample' parameters that are irrelevant for Bambi users
        params_to_remove = [
            "progressbar",
            "progressbar_theme",
            "var_names",
            "nuts_sampler",
            "return_inferencedata",
            "idata_kwargs",
            "callback",
            "mp_ctx",
            "model",
        ]
        return {k: v for k, v in fn_signature_dict.items() if k not in params_to_remove}

    def get_kwargs(self, method):
        """Get the default kwargs for a given inference method.

        Parameters
        ----------
        method : str
            The name of the inference method.

        Returns
        -------
        dict
            The default kwargs for the inference method.

        Raises
        ------
        ValueError
            If `method` is not a known bayeux or PyMC inference method.
        """
        if method in self.bayeux_methods.get("mcmc"):
            # Resolve the (possibly dotted) method name on the bayeux model's
            # mcmc namespace, e.g. "blackjax_nuts".
            bx_method = operator.attrgetter(method)(
                self.bayeux_model.mcmc  # pylint: disable=no-member
            )
            return bx_method.get_kwargs()
        elif method in self.pymc_methods.get("mcmc"):
            return self._remove_parameters(get_default_signature(pm.sample))
        elif method in self.pymc_methods.get("vi"):
            return get_default_signature(pm.ADVI.fit)
        else:
            raise ValueError(
                f"Inference method '{method}' not found in the list of available"
                " methods. Use `bmb.inference_methods.names` to list the available methods."
            )

    @property
    def names(self):
        """dict: Available method names grouped by backend ("pymc" and "bayeux")."""
        return {"pymc": self.pymc_methods, "bayeux": self.bayeux_methods}


def bayeux_model():
    """Build a throwaway bayeux model used to discover inference methods.

    The set of algorithms bayeux exposes is determined dynamically at runtime
    from the libraries installed in the environment, so a model instance is
    required to query them through its `methods` attribute.

    Returns
    -------
    bayeux.Model or dict
        A model with a trivial quadratic log density when bayeux is installed;
        otherwise a dict with an empty "mcmc" method list.
    """
    if importlib.util.find_spec("bayeux") is not None:
        import bayeux as bx  # pylint: disable=import-outside-toplevel

        return bx.Model(lambda x: -(x**2), 0.0)
    return {"mcmc": []}


def get_default_signature(fn):
    """Get the default parameter values of a function.

    Inspects the signature of `fn` and collects the default value of every
    parameter that declares one; parameters without defaults are skipped.

    Parameters
    ----------
    fn : callable
        The function for which default argument values are to be retrieved.

    Returns
    -------
    dict
        A dictionary mapping argument names to their default values.
    """
    parameters = inspect.signature(fn).parameters.items()
    return {
        name: param.default
        for name, param in parameters
        if param.default is not inspect.Signature.empty
    }


# Module-level singleton: built once at import time so users can query the
# available samplers and their kwargs via `bmb.inference_methods`.
inference_methods = InferenceMethods()
72 changes: 29 additions & 43 deletions bambi/backend/pymc.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import functools
import importlib
import logging
import operator
import traceback
Expand All @@ -14,6 +13,7 @@
import pytensor.tensor as pt
from pytensor.tensor.special import softmax

from bambi.backend.inference_methods import inference_methods
from bambi.backend.links import cloglog, identity, inverse_squared, logit, probit, arctan_2
from bambi.backend.model_components import ConstantComponent, DistributionalComponent
from bambi.utils import get_aliased_name
Expand Down Expand Up @@ -47,8 +47,8 @@ def __init__(self):
self.model = None
self.spec = None
self.components = {}
self.bayeux_methods = _get_bayeux_methods()
self.pymc_methods = {"mcmc": ["mcmc"], "vi": ["vi"]}
self.bayeux_methods = inference_methods.names["bayeux"]
self.pymc_methods = inference_methods.names["pymc"]

def build(self, spec):
"""Compile the PyMC model from an abstract model specification.
Expand Down Expand Up @@ -253,45 +253,55 @@ def _run_mcmc(
return idata

def _clean_results(self, idata, omit_offsets, include_mean, idata_from):
for group in idata.groups():
# Before doing anything, make sure we compute deterministics.
if idata_from == "bayeux":
idata.posterior = pm.compute_deterministics(
idata.posterior, model=self.model, merge_dataset=True, progressbar=False
)

for group in idata.groups():
getattr(idata, group).attrs["modeling_interface"] = "bambi"
getattr(idata, group).attrs["modeling_interface_version"] = __version__

if omit_offsets:
offset_vars = [var for var in idata.posterior.data_vars if var.endswith("_offset")]
idata.posterior = idata.posterior.drop_vars(offset_vars)

# Drop variables and dimensions associated with LKJ prior
vars_to_drop = [var for var in idata.posterior.data_vars if var.startswith("_LKJ")]
dims_to_drop = [dim for dim in idata.posterior.dims if dim.startswith("_LKJ")]
# NOTE:
# This has not had an effect for a while since we haven't been supporting LKJ prior lately.

idata.posterior = idata.posterior.drop_vars(vars_to_drop)
idata.posterior = idata.posterior.drop_dims(dims_to_drop)
# Drop variables and dimensions associated with LKJ prior
# vars_to_drop = [var for var in idata.posterior.data_vars if var.startswith("_LKJ")]
# dims_to_drop = [dim for dim in idata.posterior.dims if dim.startswith("_LKJ")]
# idata.posterior = idata.posterior.drop_vars(vars_to_drop)
# idata.posterior = idata.posterior.drop_dims(dims_to_drop)

dims_original = list(self.model.coords)

# Identify bayeux idata and rename dims and coordinates to match PyMC model
if idata_from == "bayeux":
pymc_model_dims = [dim for dim in dims_original if "_obs" not in dim]
bayeux_dims = [
dim for dim in idata.posterior.dims if not dim.startswith(("chain", "draw"))
]
cleaned_dims = dict(zip(bayeux_dims, pymc_model_dims))
cleaned_dims = {
f"{dim}_0": dim
for dim in dims_original
if not dim.endswith("_obs") and f"{dim}_0" in idata.posterior.dims
}
idata = idata.rename(cleaned_dims)

# Discard dims that are in the model but unused in the posterior
# Don't select dims that are in the model but unused in the posterior
dims_original = [dim for dim in dims_original if dim in idata.posterior.dims]

# This does not add any new coordinate, it just changes the order so the ones
# ending in "__factor_dim" are placed after the others.
dims_group = [c for c in dims_original if c.endswith("__factor_dim")]
dims_group = [dim for dim in dims_original if dim.endswith("__factor_dim")]

# Keep the original order in dims_original
dims_original_set = set(dims_original) - set(dims_group)
dims_original = [c for c in dims_original if c in dims_original_set]
dims_original = [dim for dim in dims_original if dim in dims_original_set]
dims_new = ["chain", "draw"] + dims_original + dims_group
idata.posterior = idata.posterior.transpose(*dims_new)

# Drop unused dimensions before transposing
dims_to_drop = [dim for dim in idata.posterior.dims if dim not in dims_new]
idata.posterior = idata.posterior.drop_dims(dims_to_drop).transpose(*dims_new)

# Compute the actual intercept in all distributional components that have an intercept
for pymc_component in self.distributional_components.values():
Expand Down Expand Up @@ -338,8 +348,7 @@ def _run_laplace(self, draws, omit_offsets, include_mean):

Mainly for pedagogical use, provides reasonable results for approximately
Gaussian posteriors. The approximation can be very poor for some models
like hierarchical ones. Use ``mcmc``, ``vi``, or JAX based MCMC methods
for better approximations.
like hierarchical ones. Use MCMC or VI methods for better approximations.

Parameters
----------
Expand Down Expand Up @@ -388,10 +397,6 @@ def constant_components(self):
def distributional_components(self):
return {k: v for k, v in self.components.items() if isinstance(v, DistributionalComponent)}

@property
def inference_methods(self):
return {"pymc": self.pymc_methods, "bayeux": self.bayeux_methods}


def _posterior_samples_to_idata(samples, model):
"""Create InferenceData from samples.
Expand Down Expand Up @@ -431,22 +436,3 @@ def _posterior_samples_to_idata(samples, model):

idata = pm.to_inference_data(pm.backends.base.MultiTrace([strace]), model=model)
return idata


def _get_bayeux_methods():
"""Gets a dictionary of usable bayeux methods if the bayeux package is installed
within the user's environment.

Returns
-------
dict
A dict where the keys are the module names and the values are the methods
available in that module.
"""
if importlib.util.find_spec("bayeux") is None:
return {"mcmc": []}

import bayeux as bx # pylint: disable=import-outside-toplevel

# Dummy log density to get access to all methods
return bx.Model(lambda x: -(x**2), 0.0).methods
6 changes: 5 additions & 1 deletion bambi/backend/terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,13 @@ def build(self, spec):
response_dims = list(spec.response_component.response_term.coords)

dims = list(self.coords) + response_dims
coef = self.build_distribution(self.term.prior, label, dims=dims, **kwargs)

# Squeeze ensures we don't have a shape of (n, 1) when we mean (n, )
# This happens with categorical predictors with two levels and intercept.
coef = self.build_distribution(self.term.prior, label, dims=dims, **kwargs).squeeze()
# See https://github.com/pymc-devs/pymc/issues/7246
if len(coef.shape.eval()) == 2 and coef.shape.eval()[-1] == 1:
coef = pt.specify_broadcastable(coef, 1).squeeze()
coef = coef[self.term.group_index]

return coef, predictor
Expand Down
2 changes: 1 addition & 1 deletion bambi/families/univariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ def transform_backend_eta(eta, kwargs):
def transform_backend_kwargs(kwargs):
# P(Y = k) = F(threshold_k - eta) * \prod_{j=1}^{k-1}{1 - F(threshold_j - eta)}
p = kwargs.pop("p")
n_columns = p.type.shape[-1]
n_columns = p.shape.eval()[-1]
p = pt.concatenate(
[
pt.shape_padright(p[..., 0]),
Expand Down
3 changes: 3 additions & 0 deletions bambi/interpret/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,9 @@ def get_model_covariates(model: Model) -> np.ndarray:

flatten_covariates = [item for sublist in covariates for item in sublist]

# Don't include non-covariate names (#797)
flatten_covariates = [name for name in flatten_covariates if name in model.data]

return np.unique(flatten_covariates)


Expand Down
4 changes: 2 additions & 2 deletions bambi/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def fit(
Finally, ``"laplace"``, in which case a Laplace approximation is used and is not
recommended other than for pedagogical use.
To get a list of JAX based inference methods, call
``model.backend.inference_methods['bayeux']``. This will return a dictionary of the
``bmb.inference_methods.names['bayeux']``. This will return a dictionary of the
available methods such as ``blackjax_nuts``, ``numpyro_nuts``, among others.
init : str
Initialization method. Defaults to ``"auto"``. The available methods are:
Expand Down Expand Up @@ -307,7 +307,7 @@ def fit(
-------
An ArviZ ``InferenceData`` instance if inference_method is ``"mcmc"`` (default),
"laplace", or one of the MCMC methods in
``model.backend.inference_methods['bayeux']['mcmc]``.
``bmb.inference_methods.names['bayeux']['mcmc']``.
An ``Approximation`` object if ``"vi"``.
"""
method = kwargs.pop("method", None)
Expand Down
Loading
Loading