diff --git a/CHANGELOG.md b/CHANGELOG.md
index ceb9b8e00..8eb58aef4 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -13,6 +13,9 @@
 ### Maintenance and fixes

 * Fix bug in predictions with models using HSGP (#780)
+* Fix `get_model_covariates()` utility function (#801)
+* Upgrade PyMC dependency to >= 5.13 (#803)
+* Use `pm.compute_deterministics()` to compute deterministics when bayeux-based samplers are used (#803)

 ### Documentation
diff --git a/bambi/__init__.py b/bambi/__init__.py
index 660d6724c..fdec8ec16 100644
--- a/bambi/__init__.py
+++ b/bambi/__init__.py
@@ -4,7 +4,7 @@

 from pymc import math

-from .backend import PyMCModel
+from .backend import inference_methods, PyMCModel
 from .config import config
 from .data import clear_data_home, load_data
 from .families import Family, Likelihood, Link
@@ -25,6 +25,7 @@
     "Formula",
     "clear_data_home",
     "config",
+    "inference_methods",
     "load_data",
     "math",
 ]
diff --git a/bambi/backend/__init__.py b/bambi/backend/__init__.py
index 6ee2a4aa3..daef1924c 100644
--- a/bambi/backend/__init__.py
+++ b/bambi/backend/__init__.py
@@ -1,3 +1,4 @@
 from .pymc import PyMCModel
+from .inference_methods import inference_methods

-__all__ = ["PyMCModel"]
+__all__ = ["inference_methods", "PyMCModel"]
diff --git a/bambi/backend/inference_methods.py b/bambi/backend/inference_methods.py
new file mode 100644
index 000000000..900d9c262
--- /dev/null
+++ b/bambi/backend/inference_methods.py
@@ -0,0 +1,119 @@
+import importlib
+import inspect
+import operator
+
+import pymc as pm
+
+
+class InferenceMethods:
+    """Obtain a dictionary of available inference methods for Bambi models
+    and/or the default kwargs of each inference method.
+    """
+
+    def __init__(self):
+        # In order to access inference methods, a bayeux model must be initialized
+        self.bayeux_model = bayeux_model()
+        self.bayeux_methods = self._get_bayeux_methods(self.bayeux_model)
+        self.pymc_methods = self._pymc_methods()
+
+    def _get_bayeux_methods(self, model):
+        # When bayeux is not installed, `bayeux_model()` returns a plain dict with an
+        # empty list of methods (a dict has no `methods` attribute), so pass it through.
+        if isinstance(model, dict):
+            return model
+        # Bambi only supports bayeux MCMC methods
+        mcmc_methods = model.methods.get("mcmc")
+        return {"mcmc": mcmc_methods}
+
+    def _pymc_methods(self):
+        return {"mcmc": ["mcmc"], "vi": ["vi"]}
+
+    def _remove_parameters(self, fn_signature_dict):
+        # Remove 'pm.sample' parameters that are irrelevant for Bambi users
+        params_to_remove = [
+            "progressbar",
+            "progressbar_theme",
+            "var_names",
+            "nuts_sampler",
+            "return_inferencedata",
+            "idata_kwargs",
+            "callback",
+            "mp_ctx",
+            "model",
+        ]
+        return {k: v for k, v in fn_signature_dict.items() if k not in params_to_remove}
+
+    def get_kwargs(self, method):
+        """Get the default kwargs for a given inference method.
+
+        Parameters
+        ----------
+        method : str
+            The name of the inference method.
+
+        Returns
+        -------
+        dict
+            The default kwargs for the inference method.
+        """
+        if method in self.bayeux_methods.get("mcmc"):
+            bx_method = operator.attrgetter(method)(
+                self.bayeux_model.mcmc  # pylint: disable=no-member
+            )
+            return bx_method.get_kwargs()
+        elif method in self.pymc_methods.get("mcmc"):
+            return self._remove_parameters(get_default_signature(pm.sample))
+        elif method in self.pymc_methods.get("vi"):
+            return get_default_signature(pm.ADVI.fit)
+        else:
+            raise ValueError(
+                f"Inference method '{method}' not found in the list of available"
+                " methods. Use `bmb.inference_methods.names` to list the available methods."
+            )

+    @property
+    def names(self):
+        """Dictionary of available inference method names, grouped by backend."""
+        return {"pymc": self.pymc_methods, "bayeux": self.bayeux_methods}
+
+
+def bayeux_model():
+    """Dummy bayeux model for obtaining inference methods.
+
+    A dummy model is needed because algorithms are dynamically determined at
+    runtime, based on the libraries that are installed. A model can give
+    programmatic access to the available algorithms via the `methods` attribute.
+
+    Returns
+    -------
+    bayeux.Model
+        A dummy model with a simple quadratic likelihood function, or a dict with
+        an empty list of MCMC methods when bayeux is not installed.
+    """
+    if importlib.util.find_spec("bayeux") is None:
+        return {"mcmc": []}
+
+    import bayeux as bx  # pylint: disable=import-outside-toplevel
+
+    return bx.Model(lambda x: -(x**2), 0.0)
+
+
+def get_default_signature(fn):
+    """Get the default parameter values of a function.
+
+    This function inspects the signature of the provided function and returns
+    a dictionary containing the default values of its parameters.
+
+    Parameters
+    ----------
+    fn : callable
+        The function for which default argument values are to be retrieved.
+
+    Returns
+    -------
+    dict
+        A dictionary mapping argument names to their default values.
+    """
+    defaults = {}
+    for key, val in inspect.signature(fn).parameters.items():
+        if val.default is not inspect.Signature.empty:
+            defaults[key] = val.default
+    return defaults
+
+
+inference_methods = InferenceMethods()
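For context, the new public registry added above can be exercised as follows. This is a minimal, illustrative sketch (not part of the patch); the bayeux method list is empty unless bayeux and its backends are installed.

```python
import bambi as bmb

# Nested dict of method names grouped by backend, e.g.
# {"pymc": {"mcmc": ["mcmc"], "vi": ["vi"]}, "bayeux": {"mcmc": [...]}}
print(bmb.inference_methods.names)

# Default kwargs for PyMC's sampler: a filtered `pm.sample` signature
kwargs = bmb.inference_methods.get_kwargs("mcmc")
print(kwargs["draws"], kwargs["tune"])  # 1000 1000

# An unknown name raises a ValueError pointing at `bmb.inference_methods.names`
# bmb.inference_methods.get_kwargs("not_a_method")
```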
diff --git a/bambi/backend/pymc.py b/bambi/backend/pymc.py
index 82b646ebe..75a1fe318 100644
--- a/bambi/backend/pymc.py
+++ b/bambi/backend/pymc.py
@@ -1,5 +1,4 @@
 import functools
-import importlib
 import logging
 import operator
 import traceback
@@ -14,6 +13,7 @@
 import pytensor.tensor as pt
 from pytensor.tensor.special import softmax

+from bambi.backend.inference_methods import inference_methods
 from bambi.backend.links import cloglog, identity, inverse_squared, logit, probit, arctan_2
 from bambi.backend.model_components import ConstantComponent, DistributionalComponent
 from bambi.utils import get_aliased_name
@@ -47,8 +47,8 @@ def __init__(self):
         self.model = None
         self.spec = None
         self.components = {}
-        self.bayeux_methods = _get_bayeux_methods()
-        self.pymc_methods = {"mcmc": ["mcmc"], "vi": ["vi"]}
+        self.bayeux_methods = inference_methods.names["bayeux"]
+        self.pymc_methods = inference_methods.names["pymc"]

     def build(self, spec):
         """Compile the PyMC model from an abstract model specification.
@@ -253,8 +253,13 @@ def _run_mcmc(
         return idata

     def _clean_results(self, idata, omit_offsets, include_mean, idata_from):
-        for group in idata.groups():
+        # Before doing anything, make sure we compute deterministics.
+        if idata_from == "bayeux":
+            idata.posterior = pm.compute_deterministics(
+                idata.posterior, model=self.model, merge_dataset=True, progressbar=False
+            )
+
+        for group in idata.groups():
             getattr(idata, group).attrs["modeling_interface"] = "bambi"
             getattr(idata, group).attrs["modeling_interface_version"] = __version__
@@ -262,36 +267,41 @@ def _clean_results(self, idata, omit_offsets, include_mean, idata_from):
             offset_vars = [var for var in idata.posterior.data_vars if var.endswith("_offset")]
             idata.posterior = idata.posterior.drop_vars(offset_vars)

-        # Drop variables and dimensions associated with LKJ prior
-        vars_to_drop = [var for var in idata.posterior.data_vars if var.startswith("_LKJ")]
-        dims_to_drop = [dim for dim in idata.posterior.dims if dim.startswith("_LKJ")]
+        # NOTE:
+        # This has had no effect for a while, since the LKJ prior has not been supported lately.
- idata.posterior = idata.posterior.drop_vars(vars_to_drop) - idata.posterior = idata.posterior.drop_dims(dims_to_drop) + # Drop variables and dimensions associated with LKJ prior + # vars_to_drop = [var for var in idata.posterior.data_vars if var.startswith("_LKJ")] + # dims_to_drop = [dim for dim in idata.posterior.dims if dim.startswith("_LKJ")] + # idata.posterior = idata.posterior.drop_vars(vars_to_drop) + # idata.posterior = idata.posterior.drop_dims(dims_to_drop) dims_original = list(self.model.coords) # Identify bayeux idata and rename dims and coordinates to match PyMC model if idata_from == "bayeux": - pymc_model_dims = [dim for dim in dims_original if "_obs" not in dim] - bayeux_dims = [ - dim for dim in idata.posterior.dims if not dim.startswith(("chain", "draw")) - ] - cleaned_dims = dict(zip(bayeux_dims, pymc_model_dims)) + cleaned_dims = { + f"{dim}_0": dim + for dim in dims_original + if not dim.endswith("_obs") and f"{dim}_0" in idata.posterior.dims + } idata = idata.rename(cleaned_dims) - # Discard dims that are in the model but unused in the posterior + # Don't select dims that are in the model but unused in the posterior dims_original = [dim for dim in dims_original if dim in idata.posterior.dims] # This does not add any new coordinate, it just changes the order so the ones # ending in "__factor_dim" are placed after the others. - dims_group = [c for c in dims_original if c.endswith("__factor_dim")] + dims_group = [dim for dim in dims_original if dim.endswith("__factor_dim")] # Keep the original order in dims_original dims_original_set = set(dims_original) - set(dims_group) - dims_original = [c for c in dims_original if c in dims_original_set] + dims_original = [dim for dim in dims_original if dim in dims_original_set] dims_new = ["chain", "draw"] + dims_original + dims_group - idata.posterior = idata.posterior.transpose(*dims_new) + + # Drop unused dimensions before transposing + dims_to_drop = [dim for dim in idata.posterior.dims if dim not in dims_new] + idata.posterior = idata.posterior.drop_dims(dims_to_drop).transpose(*dims_new) # Compute the actual intercept in all distributional components that have an intercept for pymc_component in self.distributional_components.values(): @@ -338,8 +348,7 @@ def _run_laplace(self, draws, omit_offsets, include_mean): Mainly for pedagogical use, provides reasonable results for approximately Gaussian posteriors. The approximation can be very poor for some models - like hierarchical ones. Use ``mcmc``, ``vi``, or JAX based MCMC methods - for better approximations. + like hierarchical ones. Use MCMC or VI methods for better approximations. Parameters ---------- @@ -388,10 +397,6 @@ def constant_components(self): def distributional_components(self): return {k: v for k, v in self.components.items() if isinstance(v, DistributionalComponent)} - @property - def inference_methods(self): - return {"pymc": self.pymc_methods, "bayeux": self.bayeux_methods} - def _posterior_samples_to_idata(samples, model): """Create InferenceData from samples. @@ -431,22 +436,3 @@ def _posterior_samples_to_idata(samples, model): idata = pm.to_inference_data(pm.backends.base.MultiTrace([strace]), model=model) return idata - - -def _get_bayeux_methods(): - """Gets a dictionary of usable bayeux methods if the bayeux package is installed - within the user's environment. - - Returns - ------- - dict - A dict where the keys are the module names and the values are the methods - available in that module. 
- """ - if importlib.util.find_spec("bayeux") is None: - return {"mcmc": []} - - import bayeux as bx # pylint: disable=import-outside-toplevel - - # Dummy log density to get access to all methods - return bx.Model(lambda x: -(x**2), 0.0).methods diff --git a/bambi/backend/terms.py b/bambi/backend/terms.py index a33b40d64..ed2a2e02c 100644 --- a/bambi/backend/terms.py +++ b/bambi/backend/terms.py @@ -110,9 +110,13 @@ def build(self, spec): response_dims = list(spec.response_component.response_term.coords) dims = list(self.coords) + response_dims + coef = self.build_distribution(self.term.prior, label, dims=dims, **kwargs) + # Squeeze ensures we don't have a shape of (n, 1) when we mean (n, ) # This happens with categorical predictors with two levels and intercept. - coef = self.build_distribution(self.term.prior, label, dims=dims, **kwargs).squeeze() + # See https://github.com/pymc-devs/pymc/issues/7246 + if len(coef.shape.eval()) == 2 and coef.shape.eval()[-1] == 1: + coef = pt.specify_broadcastable(coef, 1).squeeze() coef = coef[self.term.group_index] return coef, predictor diff --git a/bambi/families/univariate.py b/bambi/families/univariate.py index fa45c43a2..93b91785e 100644 --- a/bambi/families/univariate.py +++ b/bambi/families/univariate.py @@ -403,7 +403,7 @@ def transform_backend_eta(eta, kwargs): def transform_backend_kwargs(kwargs): # P(Y = k) = F(threshold_k - eta) * \prod_{j=1}^{k-1}{1 - F(threshold_j - eta)} p = kwargs.pop("p") - n_columns = p.type.shape[-1] + n_columns = p.shape.eval()[-1] p = pt.concatenate( [ pt.shape_padright(p[..., 0]), diff --git a/bambi/interpret/utils.py b/bambi/interpret/utils.py index cbb7bde19..b06f97c11 100644 --- a/bambi/interpret/utils.py +++ b/bambi/interpret/utils.py @@ -258,6 +258,9 @@ def get_model_covariates(model: Model) -> np.ndarray: flatten_covariates = [item for sublist in covariates for item in sublist] + # Don't include non-covariate names (#797) + flatten_covariates = [name for name in flatten_covariates if name in model.data] + return np.unique(flatten_covariates) diff --git a/bambi/models.py b/bambi/models.py index ecb57700f..74286dbed 100644 --- a/bambi/models.py +++ b/bambi/models.py @@ -267,7 +267,7 @@ def fit( Finally, ``"laplace"``, in which case a Laplace approximation is used and is not recommended other than for pedagogical use. To get a list of JAX based inference methods, call - ``model.backend.inference_methods['bayeux']``. This will return a dictionary of the + ``bmb.inference_methods.names['bayeux']``. This will return a dictionary of the available methods such as ``blackjax_nuts``, ``numpyro_nuts``, among others. init : str Initialization method. Defaults to ``"auto"``. The available methods are: @@ -307,7 +307,7 @@ def fit( ------- An ArviZ ``InferenceData`` instance if inference_method is ``"mcmc"`` (default), "laplace", or one of the MCMC methods in - ``model.backend.inference_methods['bayeux']['mcmc]``. + ``bmb.inference_methods.names['bayeux']['mcmc]``. An ``Approximation`` object if ``"vi"``. 
""" method = kwargs.pop("method", None) diff --git a/docs/notebooks/alternative_samplers.ipynb b/docs/notebooks/alternative_samplers.ipynb index 24d610d96..7610df6d6 100644 --- a/docs/notebooks/alternative_samplers.ipynb +++ b/docs/notebooks/alternative_samplers.ipynb @@ -15,9 +15,17 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 1, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING (pytensor.tensor.blas): Using NumPy C-API based implementation for BLAS functions.\n" + ] + } + ], "source": [ "import arviz as az\n", "import bambi as bmb\n", @@ -62,7 +70,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 20, "metadata": {}, "outputs": [], "source": [ @@ -74,12 +82,12 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "We can call `model.backend.inference_methods` that returns a nested dictionary of the backends and list of inference methods." + "We can call `bmb.inference_methods.names` that returns a nested dictionary of the backends and list of inference methods." ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 27, "metadata": {}, "outputs": [ { @@ -100,47 +108,16 @@ " 'flowmc_realnvp_hmc',\n", " 'flowmc_realnvp_mala',\n", " 'numpyro_hmc',\n", - " 'numpyro_nuts'],\n", - " 'optimize': ['jaxopt_bfgs',\n", - " 'jaxopt_gradient_descent',\n", - " 'jaxopt_lbfgs',\n", - " 'jaxopt_nonlinear_cg',\n", - " 'optimistix_bfgs',\n", - " 'optimistix_chord',\n", - " 'optimistix_dogleg',\n", - " 'optimistix_gauss_newton',\n", - " 'optimistix_indirect_levenberg_marquardt',\n", - " 'optimistix_levenberg_marquardt',\n", - " 'optimistix_nelder_mead',\n", - " 'optimistix_newton',\n", - " 'optimistix_nonlinear_cg',\n", - " 'optax_adabelief',\n", - " 'optax_adafactor',\n", - " 'optax_adagrad',\n", - " 'optax_adam',\n", - " 'optax_adamw',\n", - " 'optax_adamax',\n", - " 'optax_amsgrad',\n", - " 'optax_fromage',\n", - " 'optax_lamb',\n", - " 'optax_lion',\n", - " 'optax_noisy_sgd',\n", - " 'optax_novograd',\n", - " 'optax_radam',\n", - " 'optax_rmsprop',\n", - " 'optax_sgd',\n", - " 'optax_sm3',\n", - " 'optax_yogi'],\n", - " 'vi': ['tfp_factored_surrogate_posterior']}}" + " 'numpyro_nuts']}}" ] }, - "execution_count": 4, + "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "methods = model.backend.inference_methods\n", + "methods = bmb.inference_methods.names\n", "methods" ] }, @@ -153,7 +130,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 28, "metadata": {}, "outputs": [ { @@ -162,7 +139,7 @@ "{'mcmc': ['mcmc'], 'vi': ['vi']}" ] }, - "execution_count": 5, + "execution_count": 28, "metadata": {}, "output_type": "execute_result" } @@ -180,36 +157,36 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "['tfp_hmc',\n", - " 'tfp_nuts',\n", - " 'tfp_snaper_hmc',\n", - " 'blackjax_hmc',\n", - " 'blackjax_chees_hmc',\n", - " 'blackjax_meads_hmc',\n", - " 'blackjax_nuts',\n", - " 'blackjax_hmc_pathfinder',\n", - " 'blackjax_nuts_pathfinder',\n", - " 'flowmc_rqspline_hmc',\n", - " 'flowmc_rqspline_mala',\n", - " 'flowmc_realnvp_hmc',\n", - " 'flowmc_realnvp_mala',\n", - " 'numpyro_hmc',\n", - " 'numpyro_nuts']" + "{'mcmc': ['tfp_hmc',\n", + " 'tfp_nuts',\n", + " 'tfp_snaper_hmc',\n", + " 'blackjax_hmc',\n", + " 'blackjax_chees_hmc',\n", + " 'blackjax_meads_hmc',\n", + " 'blackjax_nuts',\n", + " 
'blackjax_hmc_pathfinder',\n", + " 'blackjax_nuts_pathfinder',\n", + " 'flowmc_rqspline_hmc',\n", + " 'flowmc_rqspline_mala',\n", + " 'flowmc_realnvp_hmc',\n", + " 'flowmc_realnvp_mala',\n", + " 'numpyro_hmc',\n", + " 'numpyro_nuts']}" ] }, - "execution_count": 7, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "methods[\"bayeux\"][\"mcmc\"]" + "methods[\"bayeux\"]" ] }, { @@ -242,9 +219,46 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "metadata": {}, "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "909c6a6f539145ab8348ebdeb1d42a3b", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "[rich progress-bar HTML elided]
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, { "data": { "text/html": [ @@ -256,8 +270,8 @@ "
       "[ArviZ InferenceData HTML repr elided — groups: posterior, sample_stats; attrs: created_at 2024-04-13T05:34:49, arviz_version 0.18.0, modeling_interface bambi, modeling_interface_version 0.13.1.dev25+g1e7f677e.d20240413]\n",
@@ -1447,7 +1461,8 @@
     "  grid-template-columns: 125px auto;\n",
     "}\n",
     "\n",
-    ".xr-attrs dt, dd {\n",
+    ".xr-attrs dt,\n",
+    ".xr-attrs dd {\n",
     "  padding: 0;\n",
     "  margin: 0;\n",
     "  float: left;\n",
@@ -1490,7 +1505,7 @@
    "\t> sample_stats"
   ]
  },
-  "execution_count": 8,
+  "execution_count": 7,
  "metadata": {},
  "output_type": "execute_result"
 }
],
"source": [
 "blackjax_nuts_idata = model.fit(inference_method=\"blackjax_nuts\")\n",
 "blackjax_nuts_idata"
]
},
{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "Different backends have different naming conventions for the parameters specific to that MCMC method. Thus, to specify backend-specific parameters, pass your own `kwargs` to the `fit` method.\n",
  "\n",
-  "Each algorithm has a `.get_kwargs()` method that tells you how it will be called, and what functions are being called."
+  "The `get_kwargs` method can be used to identify the kwargs specific to each method."
 ]
},
{
 "cell_type": "code",
  }
 ],
 "source": [
-  "bx.Model.from_pymc(model.backend.model).mcmc.blackjax_nuts.get_kwargs()"
+  "bmb.inference_methods.get_kwargs(\"blackjax_nuts\")"
 ]
},
{
 "cell_type": "code",
-  "execution_count": 10,
+  "execution_count": 9,
 "metadata": {},
 "outputs": [
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "f12c14ad9394476085d96b2ebbaa837d",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Output()"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "[rich progress-bar HTML elided]
    \n"
    +      ],
    +      "text/plain": []
    +     },
    +     "metadata": {},
    +     "output_type": "display_data"
    +    },
    +    {
    +     "data": {
    +      "text/html": [
+       "[rich progress-bar HTML elided]
    \n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, { "data": { "text/html": [ @@ -1568,8 +1620,8 @@ "
       "[ArviZ InferenceData HTML repr elided — groups: posterior, sample_stats; attrs: created_at 2024-04-13T05:36:20, arviz_version 0.18.0, modeling_interface bambi, modeling_interface_version 0.13.1.dev25+g1e7f677e.d20240413]\n",
  • \n", " \n", " \n", " \n", @@ -3057,7 +3109,8 @@ " grid-template-columns: 125px auto;\n", "}\n", "\n", - ".xr-attrs dt, dd {\n", + ".xr-attrs dt,\n", + ".xr-attrs dd {\n", " padding: 0;\n", " margin: 0;\n", " float: left;\n", @@ -3100,7 +3153,7 @@ "\t> sample_stats" ] }, - "execution_count": 10, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -3126,9 +3179,46 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 10, "metadata": {}, "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "a25653608a3f4f26a20237fb94775629", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
    \n"
    +      ],
    +      "text/plain": []
    +     },
    +     "metadata": {},
    +     "output_type": "display_data"
    +    },
    +    {
    +     "data": {
    +      "text/html": [
+       "[rich progress-bar HTML elided]
    \n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, { "data": { "text/html": [ @@ -3140,8 +3230,8 @@ "
       "[ArviZ InferenceData HTML repr elided — groups: posterior, sample_stats; attrs: created_at 2024-04-13T05:36:30, arviz_version 0.18.0, modeling_interface bambi, modeling_interface_version 0.13.1.dev25+g1e7f677e.d20240413]\n",
  • \n", " \n", " \n", " \n", @@ -4325,7 +4415,8 @@ " grid-template-columns: 125px auto;\n", "}\n", "\n", - ".xr-attrs dt, dd {\n", + ".xr-attrs dt,\n", + ".xr-attrs dd {\n", " padding: 0;\n", " margin: 0;\n", " float: left;\n", @@ -4368,7 +4459,7 @@ "\t> sample_stats" ] }, - "execution_count": 5, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -4394,9 +4485,46 @@ "name": "stderr", "output_type": "stream", "text": [ - "sample: 100%|██████████| 1500/1500 [00:02<00:00, 551.76it/s]\n" + "sample: 100%|██████████| 1500/1500 [00:02<00:00, 599.25it/s]\n" ] }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "851ea515d7c54968926f9eb0dc8b30c1", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
    \n"
    +      ],
    +      "text/plain": []
    +     },
    +     "metadata": {},
    +     "output_type": "display_data"
    +    },
    +    {
    +     "data": {
    +      "text/html": [
+       "[rich progress-bar HTML elided]
    \n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, { "data": { "text/html": [ @@ -4408,8 +4536,8 @@ "
       "[ArviZ InferenceData HTML repr elided — groups: posterior, sample_stats; attrs: created_at 2024-04-13T05:36:33, arviz_version 0.18.0, inference_library numpyro, inference_library_version 0.14.0, modeling_interface bambi, modeling_interface_version 0.13.1.dev25+g1e7f677e.d20240413]\n",
  • \n", " \n", " \n", " \n", @@ -5603,7 +5731,8 @@ " grid-template-columns: 125px auto;\n", "}\n", "\n", - ".xr-attrs dt, dd {\n", + ".xr-attrs dt,\n", + ".xr-attrs dd {\n", " padding: 0;\n", " margin: 0;\n", " float: left;\n", @@ -5680,7 +5809,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Tuning global sampler: 100%|██████████| 5/5 [00:51<00:00, 10.23s/it]\n" + "Tuning global sampler: 100%|██████████| 5/5 [00:51<00:00, 10.37s/it]\n" ] }, { @@ -5694,9 +5823,53 @@ "name": "stderr", "output_type": "stream", "text": [ - "Production run: 100%|██████████| 5/5 [00:00<00:00, 9.38it/s]\n" + "Production run: 100%|██████████| 5/5 [00:00<00:00, 14.38it/s]" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "1865c421a05b46109fcf06c8b7da2cf4", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" ] }, + { + "data": { + "text/html": [ + "
    \n"
    +      ],
    +      "text/plain": []
    +     },
    +     "metadata": {},
    +     "output_type": "display_data"
    +    },
    +    {
    +     "data": {
    +      "text/html": [
+       "[rich progress-bar HTML elided]
    \n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, { "data": { "text/html": [ @@ -5708,8 +5881,8 @@ "
       "[ArviZ InferenceData HTML repr elided — attrs: created_at 2024-04-13T05:37:29, arviz_version 0.18.0, modeling_interface bambi, modeling_interface_version 0.13.1.dev25+g1e7f677e.d20240413]\n",
  • \n", " \n", " \n", " \n", @@ -6440,7 +6613,8 @@ " grid-template-columns: 125px auto;\n", "}\n", "\n", - ".xr-attrs dt, dd {\n", + ".xr-attrs dt,\n", + ".xr-attrs dd {\n", " padding: 0;\n", " margin: 0;\n", " float: left;\n", @@ -6503,7 +6677,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 13, "metadata": {}, "outputs": [ { @@ -6540,40 +6714,40 @@ " \n", " \n", " \n", - " y_sigma\n", - " 0.945\n", - " 0.070\n", - " 0.819\n", - " 1.080\n", - " 0.002\n", - " 0.002\n", - " 1044.0\n", - " 667.0\n", - " 1.0\n", - " \n", - " \n", " Intercept\n", - " 0.018\n", - " 0.089\n", - " -0.156\n", - " 0.185\n", + " 0.023\n", + " 0.097\n", + " -0.141\n", + " 0.209\n", + " 0.004\n", " 0.003\n", - " 0.002\n", - " 844.0\n", - " 733.0\n", - " 1.0\n", + " 694.0\n", + " 508.0\n", + " 1.00\n", " \n", " \n", " x\n", - " 0.358\n", - " 0.105\n", - " 0.163\n", - " 0.554\n", + " 0.356\n", + " 0.111\n", + " 0.162\n", + " 0.571\n", " 0.004\n", " 0.003\n", - " 829.0\n", - " 767.0\n", - " 1.0\n", + " 970.0\n", + " 675.0\n", + " 1.00\n", + " \n", + " \n", + " y_sigma\n", + " 0.950\n", + " 0.069\n", + " 0.827\n", + " 1.072\n", + " 0.002\n", + " 0.001\n", + " 1418.0\n", + " 842.0\n", + " 1.01\n", " \n", " \n", "\n", @@ -6581,17 +6755,17 @@ ], "text/plain": [ " mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk \\\n", - "y_sigma 0.945 0.070 0.819 1.080 0.002 0.002 1044.0 \n", - "Intercept 0.018 0.089 -0.156 0.185 0.003 0.002 844.0 \n", - "x 0.358 0.105 0.163 0.554 0.004 0.003 829.0 \n", + "Intercept 0.023 0.097 -0.141 0.209 0.004 0.003 694.0 \n", + "x 0.356 0.111 0.162 0.571 0.004 0.003 970.0 \n", + "y_sigma 0.950 0.069 0.827 1.072 0.002 0.001 1418.0 \n", "\n", " ess_tail r_hat \n", - "y_sigma 667.0 1.0 \n", - "Intercept 733.0 1.0 \n", - "x 767.0 1.0 " + "Intercept 508.0 1.00 \n", + "x 675.0 1.00 \n", + "y_sigma 842.0 1.01 " ] }, - "execution_count": 16, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } @@ -6602,7 +6776,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 14, "metadata": {}, "outputs": [ { @@ -6639,39 +6813,39 @@ " \n", " \n", " \n", - " y_sigma\n", - " 0.948\n", - " 0.067\n", - " 0.824\n", - " 1.073\n", + " Intercept\n", + " 0.023\n", + " 0.097\n", + " -0.157\n", + " 0.205\n", " 0.001\n", " 0.001\n", - " 8107.0\n", - " 5585.0\n", + " 6785.0\n", + " 5740.0\n", " 1.0\n", " \n", " \n", - " Intercept\n", - " 0.025\n", - " 0.095\n", - " -0.152\n", - " 0.200\n", + " x\n", + " 0.360\n", + " 0.105\n", + " 0.169\n", + " 0.563\n", " 0.001\n", " 0.001\n", - " 6772.0\n", - " 5624.0\n", + " 6988.0\n", + " 5116.0\n", " 1.0\n", " \n", " \n", - " x\n", - " 0.361\n", - " 0.104\n", - " 0.157\n", - " 0.551\n", + " y_sigma\n", + " 0.946\n", + " 0.067\n", + " 0.831\n", + " 1.081\n", " 0.001\n", " 0.001\n", - " 6682.0\n", - " 5414.0\n", + " 7476.0\n", + " 5971.0\n", " 1.0\n", " \n", " \n", @@ -6680,17 +6854,17 @@ ], "text/plain": [ " mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk \\\n", - "y_sigma 0.948 0.067 0.824 1.073 0.001 0.001 8107.0 \n", - "Intercept 0.025 0.095 -0.152 0.200 0.001 0.001 6772.0 \n", - "x 0.361 0.104 0.157 0.551 0.001 0.001 6682.0 \n", + "Intercept 0.023 0.097 -0.157 0.205 0.001 0.001 6785.0 \n", + "x 0.360 0.105 0.169 0.563 0.001 0.001 6988.0 \n", + "y_sigma 0.946 0.067 0.831 1.081 0.001 0.001 7476.0 \n", "\n", " ess_tail r_hat \n", - "y_sigma 5585.0 1.0 \n", - "Intercept 5624.0 1.0 \n", - "x 5414.0 1.0 " + "Intercept 5740.0 1.0 \n", + "x 5116.0 1.0 \n", + "y_sigma 5971.0 1.0 " ] }, - "execution_count": 6, + "execution_count": 
14, "metadata": {}, "output_type": "execute_result" } @@ -6701,7 +6875,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 15, "metadata": {}, "outputs": [ { @@ -6739,38 +6913,38 @@ " \n", " \n", " Intercept\n", - " 0.022\n", - " 0.097\n", - " -0.149\n", - " 0.217\n", + " 0.024\n", + " 0.095\n", + " -0.162\n", + " 0.195\n", " 0.001\n", " 0.001\n", - " 7412.0\n", - " 5758.0\n", + " 6851.0\n", + " 5614.0\n", " 1.0\n", " \n", " \n", " x\n", - " 0.359\n", - " 0.105\n", - " 0.159\n", - " 0.555\n", + " 0.362\n", + " 0.104\n", + " 0.176\n", + " 0.557\n", " 0.001\n", " 0.001\n", - " 7406.0\n", - " 5967.0\n", + " 9241.0\n", + " 6340.0\n", " 1.0\n", " \n", " \n", " y_sigma\n", - " 0.947\n", - " 0.069\n", - " 0.822\n", + " 0.946\n", + " 0.068\n", + " 0.826\n", " 1.079\n", " 0.001\n", " 0.001\n", - " 7371.0\n", - " 5405.0\n", + " 7247.0\n", + " 5711.0\n", " 1.0\n", " \n", " \n", @@ -6779,17 +6953,17 @@ ], "text/plain": [ " mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk \\\n", - "Intercept 0.022 0.097 -0.149 0.217 0.001 0.001 7412.0 \n", - "x 0.359 0.105 0.159 0.555 0.001 0.001 7406.0 \n", - "y_sigma 0.947 0.069 0.822 1.079 0.001 0.001 7371.0 \n", + "Intercept 0.024 0.095 -0.162 0.195 0.001 0.001 6851.0 \n", + "x 0.362 0.104 0.176 0.557 0.001 0.001 9241.0 \n", + "y_sigma 0.946 0.068 0.826 1.079 0.001 0.001 7247.0 \n", "\n", " ess_tail r_hat \n", - "Intercept 5758.0 1.0 \n", - "x 5967.0 1.0 \n", - "y_sigma 5405.0 1.0 " + "Intercept 5614.0 1.0 \n", + "x 6340.0 1.0 \n", + "y_sigma 5711.0 1.0 " ] }, - "execution_count": 17, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } @@ -6800,7 +6974,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 16, "metadata": {}, "outputs": [ { @@ -6837,39 +7011,39 @@ " \n", " \n", " \n", - " y_sigma\n", - " 0.946\n", - " 0.067\n", - " 0.825\n", - " 1.076\n", - " 0.001\n", - " 0.001\n", - " 6260.0\n", - " 5213.0\n", - " 1.00\n", - " \n", - " \n", " Intercept\n", - " 0.013\n", - " 0.093\n", - " -0.165\n", + " 0.015\n", + " 0.100\n", + " -0.186\n", " 0.190\n", + " 0.004\n", " 0.003\n", - " 0.002\n", - " 924.0\n", - " 1302.0\n", + " 758.0\n", + " 1233.0\n", " 1.02\n", " \n", " \n", " x\n", - " 0.359\n", - " 0.103\n", - " 0.166\n", - " 0.556\n", + " 0.361\n", + " 0.105\n", + " 0.174\n", + " 0.565\n", " 0.001\n", " 0.001\n", - " 5132.0\n", - " 5790.0\n", + " 5084.0\n", + " 4525.0\n", + " 1.00\n", + " \n", + " \n", + " y_sigma\n", + " 0.951\n", + " 0.070\n", + " 0.823\n", + " 1.079\n", + " 0.001\n", + " 0.001\n", + " 5536.0\n", + " 5080.0\n", " 1.00\n", " \n", " \n", @@ -6878,17 +7052,17 @@ ], "text/plain": [ " mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk \\\n", - "y_sigma 0.946 0.067 0.825 1.076 0.001 0.001 6260.0 \n", - "Intercept 0.013 0.093 -0.165 0.190 0.003 0.002 924.0 \n", - "x 0.359 0.103 0.166 0.556 0.001 0.001 5132.0 \n", + "Intercept 0.015 0.100 -0.186 0.190 0.004 0.003 758.0 \n", + "x 0.361 0.105 0.174 0.565 0.001 0.001 5084.0 \n", + "y_sigma 0.951 0.070 0.823 1.079 0.001 0.001 5536.0 \n", "\n", " ess_tail r_hat \n", - "y_sigma 5213.0 1.00 \n", - "Intercept 1302.0 1.02 \n", - "x 5790.0 1.00 " + "Intercept 1233.0 1.02 \n", + "x 4525.0 1.00 \n", + "y_sigma 5080.0 1.00 " ] }, - "execution_count": 18, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } @@ -6908,25 +7082,24 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Last updated: Fri Mar 01 2024\n", + 
"Last updated: Sat Apr 13 2024\n", "\n", "Python implementation: CPython\n", - "Python version : 3.11.7\n", - "IPython version : 8.21.0\n", + "Python version : 3.12.2\n", + "IPython version : 8.20.0\n", "\n", - "arviz : 0.17.0\n", - "bambi : 0.13.1.dev16+g9a1387a7.d20240204\n", - "numpy : 1.26.3\n", - "pandas : 2.2.0\n", - "bayeux : 0.1.9\n", - "matplotlib: 3.8.2\n", + "bambi : 0.13.1.dev25+g1e7f677e.d20240413\n", + "pandas: 2.2.1\n", + "numpy : 1.26.4\n", + "bayeux: 0.1.10\n", + "arviz : 0.18.0\n", "\n", "Watermark: 2.4.3\n", "\n" @@ -6955,7 +7128,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.7" + "version": "3.12.2" } }, "nbformat": 4, diff --git a/pyproject.toml b/pyproject.toml index 262482de1..f8b5ac674 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ dependencies = [ "formulae>=0.5.3", "graphviz", "pandas>=1.0.0", - "pymc>=5.12.0", + "pymc>=5.13.0", ] [project.optional-dependencies] diff --git a/tests/test_alternative_samplers.py b/tests/test_alternative_samplers.py index 6222f3df3..be260b6bd 100644 --- a/tests/test_alternative_samplers.py +++ b/tests/test_alternative_samplers.py @@ -7,7 +7,9 @@ MCMC_METHODS = [getattr(bx.mcmc, k).name for k in bx.mcmc.__all__] -MCMC_METHODS_FILTERED = [i for i in MCMC_METHODS if not any(x in i for x in ("flowmc", "chees", "meads"))] +MCMC_METHODS_FILTERED = [ + i for i in MCMC_METHODS if not any(x in i for x in ("flowmc", "chees", "meads")) +] @pytest.fixture(scope="module") @@ -30,6 +32,24 @@ def data_n100(): return data +def test_inference_method_names_and_kwargs(): + names = bmb.inference_methods.names + + # Check PyMC inference method family + assert "mcmc" in names["pymc"].keys() + assert "vi" in names["pymc"].keys() + + # Check bayeu inference method family. Currently, only MCMC methods are supported + assert "mcmc" in names["bayeux"].keys() + + # Ensure get_kwargs method raises an error if a non-supported method name is passed + with pytest.raises( + ValueError, + match="Inference method 'not_a_method' not found in the list of available methods. Use `bmb.inference_methods.names` to list the available methods.", + ): + bmb.inference_methods.get_kwargs("not_a_method") + + def test_laplace(): data = pd.DataFrame(np.repeat((0, 1), (30, 60)), columns=["w"]) priors = {"Intercept": bmb.Prior("Uniform", lower=0, upper=1)} @@ -56,7 +76,7 @@ def test_vi(): (mode_n.item(), std_n.item()), (mode_a.item(), std_a.item()), decimal=2 ) -# + @pytest.mark.parametrize("sampler", MCMC_METHODS_FILTERED) def test_logistic_regression_categoric_alternative_samplers(data_n100, sampler): model = bmb.Model("b1 ~ n1", data_n100, family="bernoulli") diff --git a/tests/test_interpret.py b/tests/test_interpret.py index f9be28957..083e001a0 100644 --- a/tests/test_interpret.py +++ b/tests/test_interpret.py @@ -2,12 +2,14 @@ This module contains tests for the helper functions of the 'interpret' sub-package. Tests here do not test any of the plotting functionality. 
""" + import numpy as np import pandas as pd import pytest import bambi as bmb from bambi.interpret.helpers import data_grid, select_draws +from bambi.interpret.utils import get_model_covariates CHAINS = 4 @@ -190,3 +192,18 @@ def test_select_draws_no_effect(request, mtcars, condition): assert draws.shape == (CHAINS, DRAWS, 14) elif id == "3": assert draws.shape == (CHAINS, DRAWS, 2) + + +# ------------------------------------------------------------------------------------------------ # +# Tests for utils # +# ------------------------------------------------------------------------------------------------ # + + +def test_get_model_covariates(): + """Tests `get_model_covariates()` does not include non-covariate names""" + # See issue 797 + df = pd.DataFrame({"y": np.arange(10), "x": np.random.normal(size=10)}) + knots = np.linspace(np.min(df["x"]), np.max(df["x"]), 4 + 2)[1:-1] + formula = "y ~ 1 + bs(x, degree=3, knots=knots)" + model = bmb.Model(formula, df) + assert set(get_model_covariates(model)) == {"x"} diff --git a/tests/test_models.py b/tests/test_models.py index 717abe7e3..3bdcf4a39 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -192,13 +192,13 @@ def test_cell_means_parameterization(self, crossed_data): def test_2_factors_saturated(self, crossed_data): model = bmb.Model("Y ~ threecats*fourcats", crossed_data) idata = self.fit(model) - assert list(idata.posterior.data_vars) == [ + assert set(idata.posterior.data_vars) == { "Intercept", "threecats", "fourcats", "threecats:fourcats", "Y_sigma", - ] + } assert list(idata.posterior["threecats_dim"].values) == ["b", "c"] assert list(idata.posterior["fourcats_dim"].values) == ["b", "c", "d"] assert list(idata.posterior["threecats:fourcats_dim"].values) == [ @@ -214,12 +214,12 @@ def test_2_factors_saturated(self, crossed_data): def test_2_factors_no_intercept(self, crossed_data): model = bmb.Model("Y ~ 0 + threecats*fourcats", crossed_data) idata = self.fit(model) - assert list(idata.posterior.data_vars) == [ + assert set(idata.posterior.data_vars) == { "threecats", "fourcats", "threecats:fourcats", "Y_sigma", - ] + } assert list(idata.posterior["threecats_dim"].values) == ["a", "b", "c"] assert list(idata.posterior["fourcats_dim"].values) == ["b", "c", "d"] assert list(idata.posterior["threecats:fourcats_dim"].values) == [ @@ -235,7 +235,7 @@ def test_2_factors_no_intercept(self, crossed_data): def test_2_factors_cell_means(self, crossed_data): model = bmb.Model("Y ~ 0 + threecats:fourcats", crossed_data) idata = self.fit(model) - assert list(idata.posterior.data_vars) == ["threecats:fourcats", "Y_sigma"] + assert set(idata.posterior.data_vars) == {"threecats:fourcats", "Y_sigma"} assert list(idata.posterior["threecats:fourcats_dim"].values) == [ "a, a", "a, b", @@ -255,7 +255,7 @@ def test_2_factors_cell_means(self, crossed_data): def test_cell_means_with_covariate(self, crossed_data): model = bmb.Model("Y ~ 0 + threecats + continuous", crossed_data) idata = self.fit(model) - assert list(idata.posterior.data_vars) == ["threecats", "continuous", "Y_sigma"] + assert set(idata.posterior.data_vars) == {"threecats", "continuous", "Y_sigma"} assert list(idata.posterior["threecats_dim"].values) == ["a", "b", "c"] self.predict_oos(model, idata) @@ -477,7 +477,7 @@ def test_group_specific_categorical_interaction(self, crossed_data): idata = self.fit(model) self.predict_oos(model, idata) - assert list(idata.posterior.data_vars) == [ + assert set(idata.posterior.data_vars) == { "Intercept", "continuous", "Y_sigma", @@ -485,7 
+485,7 @@ def test_group_specific_categorical_interaction(self, crossed_data): "threecats:fourcats|site_sigma", "1|site", "threecats:fourcats|site", - ] + } assert list(idata.posterior["threecats:fourcats|site"].coords) == [ "chain", "draw",
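One note on the backend change above: when sampling runs outside PyMC (the bayeux path), the posterior comes back without deterministic variables, and `_clean_results` now rebuilds them with `pm.compute_deterministics`. A minimal sketch of that mechanism, assuming PyMC >= 5.13; the toy model and variable names here are made up for illustration.

```python
import numpy as np
import pymc as pm

with pm.Model() as model:
    log_sigma = pm.Normal("log_sigma")
    sigma = pm.Deterministic("sigma", pm.math.exp(log_sigma))
    idata = pm.sample(draws=100, tune=100, chains=1, progressbar=False)

# Pretend the posterior arrived without deterministics, as a bayeux run would
posterior = idata.posterior.drop_vars(["sigma"])

# Recompute them from the model graph, mirroring what _clean_results does
posterior = pm.compute_deterministics(
    posterior, model=model, merge_dataset=True, progressbar=False
)
np.testing.assert_allclose(
    posterior["sigma"].values, np.exp(posterior["log_sigma"].values)
)
```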