From 05d0aecdf99de9b84a7665ff9ef7fb3b3b231d52 Mon Sep 17 00:00:00 2001 From: mark Date: Wed, 8 Jun 2022 11:30:06 +0100 Subject: [PATCH 01/16] Add support for numpyro and blackjax PyMC samplers --- bambi/backend/pymc.py | 68 ++++++++++++++++++++++++++++++------------- bambi/models.py | 18 ++++++++++++ 2 files changed, 65 insertions(+), 21 deletions(-) diff --git a/bambi/backend/pymc.py b/bambi/backend/pymc.py index 664bfbc39..8a1165133 100644 --- a/bambi/backend/pymc.py +++ b/bambi/backend/pymc.py @@ -3,6 +3,7 @@ import numpy as np import pymc as pm +import pymc.sampling_jax import aesara.tensor as at @@ -83,6 +84,7 @@ def run( chains=None, cores=None, random_seed=None, + sampler_backend="default", **kwargs, ): """Run PyMC sampler.""" @@ -99,6 +101,7 @@ def run( chains, cores, random_seed, + sampler_backend, **kwargs, ) elif method.lower() == "vi": @@ -209,40 +212,63 @@ def _run_mcmc( chains=None, cores=None, random_seed=None, + sampler_backend="default", **kwargs, ): with self.model: - try: - idata = pm.sample( - draws=draws, - tune=tune, - discard_tuned_samples=discard_tuned_samples, - init=init, - n_init=n_init, - chains=chains, - cores=cores, - random_seed=random_seed, - **kwargs, - ) - except (RuntimeError, ValueError): - if "ValueError: Mass matrix contains" in traceback.format_exc() and init == "auto": - _log.info( - "\nThe default initialization using init='auto' has failed, trying to " - "recover by switching to init='adapt_diag'", - ) + if sampler_backend == "default": + try: idata = pm.sample( draws=draws, tune=tune, discard_tuned_samples=discard_tuned_samples, - init="adapt_diag", + init=init, n_init=n_init, chains=chains, cores=cores, random_seed=random_seed, **kwargs, ) - else: - raise + except (RuntimeError, ValueError): + if "ValueError: Mass matrix contains" in traceback.format_exc() and init == "auto": + _log.info( + "\nThe default initialization using init='auto' has failed, trying to " + "recover by switching to init='adapt_diag'", + ) + idata = pm.sample( + draws=draws, + tune=tune, + discard_tuned_samples=discard_tuned_samples, + init="adapt_diag", + n_init=n_init, + chains=chains, + cores=cores, + random_seed=random_seed, + **kwargs, + ) + else: + raise + elif sampler_backend == "numpyro": + idata = pm.sampling_jax.sample_numpyro_nuts( + draws=draws, + tune=tune, + chains=chains, + random_seed=random_seed, + **kwargs, + ) + elif sampler_backend == "blackjax": + idata = pm.sampling_jax.sample_blackjax_nuts( + draws=draws, + tune=tune, + chains=chains, + random_seed=random_seed, + **kwargs, + ) + else: + raise ValueError( + f'sampler_backend value {sampler_backend} is not valid. Please choose one of' + f'``default``, ``numpyro`` or ``blackjax``' + ) idata = self._clean_mcmc_results(idata, omit_offsets, include_mean) return idata diff --git a/bambi/models.py b/bambi/models.py index d3108a9ef..76b503cd9 100644 --- a/bambi/models.py +++ b/bambi/models.py @@ -180,6 +180,7 @@ def fit( chains=None, cores=None, random_seed=None, + sampler_backend="default", **kwargs, ): """Fit the model using PyMC. @@ -239,6 +240,17 @@ def fit( in the system unless there are more than 4 CPUs, in which case it is set to 4. random_seed : int or list of ints A list is accepted if cores is greater than one. + sampler_backend: str + If ``default`` uses the standard PyMC fit() method. Is only valid if method == "mcmc". + Other valid sampling methods are ``numpyro`` and ``blackjax`` which both use JAX (see + https://www.pymc.io/blog/v4_announcement.html#new-jax-backend-for-faster-sampling) + and which offers both performance improvements on the CPU due to JIT compilation + and also GPU use (which will happen automatically if a GPU is detected - see the + JAX documentation https://jax.readthedocs.io/en/latest/index.html). Also, both + non-default methods will only work if you can use NUTS sampling, so your model must + be differentiable. It is also strongly recommended, if using ``numpyro`` or + ``blackjax``, to set the kwarg of ``chain_method`` to either ``parallel`` or + ``vectorized`` for optimal performance. **kwargs: For other kwargs see the documentation for ``PyMC.sample()``. @@ -248,6 +260,11 @@ def fit( An ``Approximation`` object if ``"vi"`` and a dictionary if ``"laplace"``. """ + if sampler_backend != "default" and method != "mcmc": + raise ValueError( + f"Non-default sampler_backend {sampler_backend} can only be used with method 'mcmc'" + ) + if not self.built: self.build() @@ -271,6 +288,7 @@ def fit( chains=chains, cores=cores, random_seed=random_seed, + sampler_backend=sampler_backend, **kwargs, ) From 5e6997fb5f6b9f4143122ecdf9eef5f67bd8d872 Mon Sep 17 00:00:00 2001 From: markgoodhead Date: Wed, 8 Jun 2022 13:32:25 +0100 Subject: [PATCH 02/16] Update bambi/models.py Co-authored-by: Osvaldo A Martin --- bambi/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bambi/models.py b/bambi/models.py index 76b503cd9..89326d656 100644 --- a/bambi/models.py +++ b/bambi/models.py @@ -241,7 +241,7 @@ def fit( random_seed : int or list of ints A list is accepted if cores is greater than one. sampler_backend: str - If ``default`` uses the standard PyMC fit() method. Is only valid if method == "mcmc". + If ``default`` uses the standard PyMC sampling method. Is only valid if method == "mcmc". Other valid sampling methods are ``numpyro`` and ``blackjax`` which both use JAX (see https://www.pymc.io/blog/v4_announcement.html#new-jax-backend-for-faster-sampling) and which offers both performance improvements on the CPU due to JIT compilation From 5698b75683b2a7907a4edc1e7a93a7dcdeda8c50 Mon Sep 17 00:00:00 2001 From: mark Date: Wed, 8 Jun 2022 13:47:03 +0100 Subject: [PATCH 03/16] Lazily import jax sampling. Refactor sampler_backend to instead be part of the method argument --- bambi/__init__.py | 2 -- bambi/backend/pymc.py | 16 ++++++++-------- bambi/models.py | 23 +++++------------------ 3 files changed, 13 insertions(+), 28 deletions(-) diff --git a/bambi/__init__.py b/bambi/__init__.py index 6d8dc640d..42f9843f0 100644 --- a/bambi/__init__.py +++ b/bambi/__init__.py @@ -1,7 +1,5 @@ import logging -from pymc import math - from .data import clear_data_home, load_data from .families import Family, Likelihood, Link from .models import Model diff --git a/bambi/backend/pymc.py b/bambi/backend/pymc.py index 8a1165133..831419858 100644 --- a/bambi/backend/pymc.py +++ b/bambi/backend/pymc.py @@ -3,7 +3,6 @@ import numpy as np import pymc as pm -import pymc.sampling_jax import aesara.tensor as at @@ -84,12 +83,11 @@ def run( chains=None, cores=None, random_seed=None, - sampler_backend="default", **kwargs, ): """Run PyMC sampler.""" # NOTE: Methods return different types of objects (idata, approximation, and dictionary) - if method.lower() == "mcmc": + if method.lower() in ["mcmc", "mcmc-numpyro", "mcmc-blackjax"]: result = self._run_mcmc( draws, tune, @@ -101,7 +99,7 @@ def run( chains, cores, random_seed, - sampler_backend, + method.lower(), **kwargs, ) elif method.lower() == "vi": @@ -212,11 +210,11 @@ def _run_mcmc( chains=None, cores=None, random_seed=None, - sampler_backend="default", + sampler_backend="mcmc", **kwargs, ): with self.model: - if sampler_backend == "default": + if sampler_backend == "mcmc": try: idata = pm.sample( draws=draws, @@ -248,7 +246,8 @@ def _run_mcmc( ) else: raise - elif sampler_backend == "numpyro": + elif sampler_backend == "mcmc-numpyro": + import pymc.sampling_jax # Lazy import to not force users to install Jax idata = pm.sampling_jax.sample_numpyro_nuts( draws=draws, tune=tune, @@ -256,7 +255,8 @@ def _run_mcmc( random_seed=random_seed, **kwargs, ) - elif sampler_backend == "blackjax": + elif sampler_backend == "mcmc-blackjax": + import pymc.sampling_jax # Lazy import to not force users to install Jax idata = pm.sampling_jax.sample_blackjax_nuts( draws=draws, tune=tune, diff --git a/bambi/models.py b/bambi/models.py index 76b503cd9..3dc302b5c 100644 --- a/bambi/models.py +++ b/bambi/models.py @@ -180,7 +180,6 @@ def fit( chains=None, cores=None, random_seed=None, - sampler_backend="default", **kwargs, ): """Fit the model using PyMC. @@ -209,6 +208,11 @@ def fit( using the ``fit`` function. Finally, ``"laplace"``, in which case a Laplace approximation is used and is not recommended other than for pedagogical use. + To use the PyMC numpyro and blackjax samplers, use ``mcmc-numpyro`` or ``mcmc-blackjax`` + respectively. Both methods will only work if you can use NUTS sampling, so your model must + be differentiable. It is also recommended, if using ``mcmc-numpyro`` or + ``mcmc-blackjax``, to set the kwarg of ``chain_method`` to either ``parallel`` or + ``vectorized`` for optimal performance. init: str Initialization method. Defaults to ``"auto"``. The available methods are: * auto: Use ``"jitter+adapt_diag"`` and if this method fails it uses ``"adapt_diag"``. @@ -240,17 +244,6 @@ def fit( in the system unless there are more than 4 CPUs, in which case it is set to 4. random_seed : int or list of ints A list is accepted if cores is greater than one. - sampler_backend: str - If ``default`` uses the standard PyMC fit() method. Is only valid if method == "mcmc". - Other valid sampling methods are ``numpyro`` and ``blackjax`` which both use JAX (see - https://www.pymc.io/blog/v4_announcement.html#new-jax-backend-for-faster-sampling) - and which offers both performance improvements on the CPU due to JIT compilation - and also GPU use (which will happen automatically if a GPU is detected - see the - JAX documentation https://jax.readthedocs.io/en/latest/index.html). Also, both - non-default methods will only work if you can use NUTS sampling, so your model must - be differentiable. It is also strongly recommended, if using ``numpyro`` or - ``blackjax``, to set the kwarg of ``chain_method`` to either ``parallel`` or - ``vectorized`` for optimal performance. **kwargs: For other kwargs see the documentation for ``PyMC.sample()``. @@ -260,11 +253,6 @@ def fit( An ``Approximation`` object if ``"vi"`` and a dictionary if ``"laplace"``. """ - if sampler_backend != "default" and method != "mcmc": - raise ValueError( - f"Non-default sampler_backend {sampler_backend} can only be used with method 'mcmc'" - ) - if not self.built: self.build() @@ -288,7 +276,6 @@ def fit( chains=chains, cores=cores, random_seed=random_seed, - sampler_backend=sampler_backend, **kwargs, ) From a62a11539dc3d817271de0a0f70801012965a855 Mon Sep 17 00:00:00 2001 From: mark Date: Wed, 8 Jun 2022 13:52:52 +0100 Subject: [PATCH 04/16] Fix for chains bug in numpyro and blackjax backends --- bambi/backend/pymc.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/bambi/backend/pymc.py b/bambi/backend/pymc.py index 831419858..4705669fa 100644 --- a/bambi/backend/pymc.py +++ b/bambi/backend/pymc.py @@ -248,6 +248,8 @@ def _run_mcmc( raise elif sampler_backend == "mcmc-numpyro": import pymc.sampling_jax # Lazy import to not force users to install Jax + if not chains: + chains = 4 # sample_numpyro_nuts does not handle chains = None like pm.sample does idata = pm.sampling_jax.sample_numpyro_nuts( draws=draws, tune=tune, @@ -257,6 +259,8 @@ def _run_mcmc( ) elif sampler_backend == "mcmc-blackjax": import pymc.sampling_jax # Lazy import to not force users to install Jax + if not chains: + chains = 4 # sample_blackjax_nuts does not handle chains = None like pm.sample does idata = pm.sampling_jax.sample_blackjax_nuts( draws=draws, tune=tune, From 21aa870d7022d3018c950a749284dc1ff967e6d9 Mon Sep 17 00:00:00 2001 From: mark Date: Wed, 8 Jun 2022 14:00:10 +0100 Subject: [PATCH 05/16] Minor error message fix --- bambi/backend/pymc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bambi/backend/pymc.py b/bambi/backend/pymc.py index 4705669fa..c1dfec2c2 100644 --- a/bambi/backend/pymc.py +++ b/bambi/backend/pymc.py @@ -271,7 +271,7 @@ def _run_mcmc( else: raise ValueError( f'sampler_backend value {sampler_backend} is not valid. Please choose one of' - f'``default``, ``numpyro`` or ``blackjax``' + f'``mcmc``, ``mcmc-numpyro`` or `mcmc-blackjax``' ) idata = self._clean_mcmc_results(idata, omit_offsets, include_mean) From 99cc5f041327ce93ac3360c0efc7ca93d3ad0592 Mon Sep 17 00:00:00 2001 From: mark Date: Wed, 8 Jun 2022 14:18:27 +0100 Subject: [PATCH 06/16] Rename mcmc-numpyro/blackjax to nuts_numpyro/blackjax --- bambi/backend/pymc.py | 8 ++++---- bambi/models.py | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/bambi/backend/pymc.py b/bambi/backend/pymc.py index c1dfec2c2..53d1880c2 100644 --- a/bambi/backend/pymc.py +++ b/bambi/backend/pymc.py @@ -87,7 +87,7 @@ def run( ): """Run PyMC sampler.""" # NOTE: Methods return different types of objects (idata, approximation, and dictionary) - if method.lower() in ["mcmc", "mcmc-numpyro", "mcmc-blackjax"]: + if method.lower() in ["mcmc", "nuts_numpyro", "nuts_blackjax"]: result = self._run_mcmc( draws, tune, @@ -246,7 +246,7 @@ def _run_mcmc( ) else: raise - elif sampler_backend == "mcmc-numpyro": + elif sampler_backend == "nuts_numpyro": import pymc.sampling_jax # Lazy import to not force users to install Jax if not chains: chains = 4 # sample_numpyro_nuts does not handle chains = None like pm.sample does @@ -257,7 +257,7 @@ def _run_mcmc( random_seed=random_seed, **kwargs, ) - elif sampler_backend == "mcmc-blackjax": + elif sampler_backend == "nuts_blackjax": import pymc.sampling_jax # Lazy import to not force users to install Jax if not chains: chains = 4 # sample_blackjax_nuts does not handle chains = None like pm.sample does @@ -271,7 +271,7 @@ def _run_mcmc( else: raise ValueError( f'sampler_backend value {sampler_backend} is not valid. Please choose one of' - f'``mcmc``, ``mcmc-numpyro`` or `mcmc-blackjax``' + f'``mcmc``, ``nuts_numpyro`` or ``nuts_blackjax``' ) idata = self._clean_mcmc_results(idata, omit_offsets, include_mean) diff --git a/bambi/models.py b/bambi/models.py index 3dc302b5c..b99e02b25 100644 --- a/bambi/models.py +++ b/bambi/models.py @@ -208,10 +208,10 @@ def fit( using the ``fit`` function. Finally, ``"laplace"``, in which case a Laplace approximation is used and is not recommended other than for pedagogical use. - To use the PyMC numpyro and blackjax samplers, use ``mcmc-numpyro`` or ``mcmc-blackjax`` + To use the PyMC numpyro and blackjax samplers, use ``nuts_numpyro`` or ``nuts_blackjax`` respectively. Both methods will only work if you can use NUTS sampling, so your model must - be differentiable. It is also recommended, if using ``mcmc-numpyro`` or - ``mcmc-blackjax``, to set the kwarg of ``chain_method`` to either ``parallel`` or + be differentiable. It is also recommended, if using ``nuts_numpyro`` or + ``nuts_blackjax``, to set the kwarg of ``chain_method`` to either ``parallel`` or ``vectorized`` for optimal performance. init: str Initialization method. Defaults to ``"auto"``. The available methods are: From dc2609cb3e7c75b179b23a12f375467146740ff5 Mon Sep 17 00:00:00 2001 From: mark Date: Wed, 8 Jun 2022 14:35:52 +0100 Subject: [PATCH 07/16] Remove incorrect statement about chain_method as the default is now "parallel" --- bambi/models.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/bambi/models.py b/bambi/models.py index b99e02b25..482e15858 100644 --- a/bambi/models.py +++ b/bambi/models.py @@ -210,9 +210,7 @@ def fit( recommended other than for pedagogical use. To use the PyMC numpyro and blackjax samplers, use ``nuts_numpyro`` or ``nuts_blackjax`` respectively. Both methods will only work if you can use NUTS sampling, so your model must - be differentiable. It is also recommended, if using ``nuts_numpyro`` or - ``nuts_blackjax``, to set the kwarg of ``chain_method`` to either ``parallel`` or - ``vectorized`` for optimal performance. + be differentiable. init: str Initialization method. Defaults to ``"auto"``. The available methods are: * auto: Use ``"jitter+adapt_diag"`` and if this method fails it uses ``"adapt_diag"``. From 0dfc44c3f617f372b281ca209c7f5540035faebd Mon Sep 17 00:00:00 2001 From: mark Date: Wed, 8 Jun 2022 17:14:09 +0100 Subject: [PATCH 08/16] Extend tests to also cover numpyro/blackjax samplers --- bambi/tests/test_built_models.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/bambi/tests/test_built_models.py b/bambi/tests/test_built_models.py index 9c4b33e89..4970b011b 100644 --- a/bambi/tests/test_built_models.py +++ b/bambi/tests/test_built_models.py @@ -301,6 +301,8 @@ def test_group_specific_categorical_interaction(crossed_data): crossed_data["fourcats"] = sum([[x] * 10 for x in ["a", "b", "c", "d"]], list()) * 3 model = Model("Y ~ continuous + (threecats:fourcats|site)", crossed_data) model.fit(tune=10, draws=10) + model.fit(tune=10, draws=10, method="nuts_numpyro") + model.fit(tune=10, draws=10, method="nuts_blackjax") def test_logistic_regression_empty_index(): @@ -327,6 +329,8 @@ def test_logistic_regression_categoric(): data = pd.DataFrame({"y": y, "x": np.random.normal(size=50)}) model = Model("y ~ x", data, family="bernoulli") model.fit() + model.fit(tune=10, draws=10, method="nuts_numpyro") + model.fit(tune=10, draws=10, method="nuts_blackjax") def test_poisson_regression(crossed_data): @@ -449,6 +453,8 @@ def test_gamma_regression(dm): data = dm[["order", "ind_mg_dry"]] model = Model("ind_mg_dry ~ order", data, family="gamma", link="log") model.fit(draws=10, tune=10) + model.fit(tune=10, draws=10, method="nuts_numpyro") + model.fit(tune=10, draws=10, method="nuts_blackjax") def test_beta_regression(): From 9dbb1f9a3856ce33bd992fbb4d1d3d4ffe583940 Mon Sep 17 00:00:00 2001 From: mark Date: Wed, 8 Jun 2022 18:36:00 +0100 Subject: [PATCH 09/16] Run black and pylint --- bambi/backend/pymc.py | 31 +++++++++++++------------- bambi/families/link.py | 6 ++--- bambi/models.py | 4 ++-- bambi/priors/scaler.py | 2 +- bambi/tests/test_built_models.py | 21 ++++------------- bambi/tests/test_model_construction.py | 17 ++------------ bambi/tests/test_predict.py | 8 ++----- bambi/tests/test_priors.py | 7 +----- 8 files changed, 31 insertions(+), 65 deletions(-) diff --git a/bambi/backend/pymc.py b/bambi/backend/pymc.py index 53d1880c2..41aea8756 100644 --- a/bambi/backend/pymc.py +++ b/bambi/backend/pymc.py @@ -228,7 +228,10 @@ def _run_mcmc( **kwargs, ) except (RuntimeError, ValueError): - if "ValueError: Mass matrix contains" in traceback.format_exc() and init == "auto": + if ( + "ValueError: Mass matrix contains" in traceback.format_exc() + and init == "auto" + ): _log.info( "\nThe default initialization using init='auto' has failed, trying to " "recover by switching to init='adapt_diag'", @@ -248,30 +251,28 @@ def _run_mcmc( raise elif sampler_backend == "nuts_numpyro": import pymc.sampling_jax # Lazy import to not force users to install Jax + if not chains: - chains = 4 # sample_numpyro_nuts does not handle chains = None like pm.sample does + chains = ( + 4 # sample_numpyro_nuts does not handle chains = None like pm.sample does + ) idata = pm.sampling_jax.sample_numpyro_nuts( - draws=draws, - tune=tune, - chains=chains, - random_seed=random_seed, - **kwargs, + draws=draws, tune=tune, chains=chains, random_seed=random_seed, **kwargs, ) elif sampler_backend == "nuts_blackjax": import pymc.sampling_jax # Lazy import to not force users to install Jax + if not chains: - chains = 4 # sample_blackjax_nuts does not handle chains = None like pm.sample does + chains = ( + 4 # sample_blackjax_nuts does not handle chains = None like pm.sample does + ) idata = pm.sampling_jax.sample_blackjax_nuts( - draws=draws, - tune=tune, - chains=chains, - random_seed=random_seed, - **kwargs, + draws=draws, tune=tune, chains=chains, random_seed=random_seed, **kwargs, ) else: raise ValueError( - f'sampler_backend value {sampler_backend} is not valid. Please choose one of' - f'``mcmc``, ``nuts_numpyro`` or ``nuts_blackjax``' + f"sampler_backend value {sampler_backend} is not valid. Please choose one of" + f"``mcmc``, ``nuts_numpyro`` or ``nuts_blackjax``" ) idata = self._clean_mcmc_results(idata, omit_offsets, include_mean) diff --git a/bambi/families/link.py b/bambi/families/link.py index 8ea85b28c..8d63ef24d 100644 --- a/bambi/families/link.py +++ b/bambi/families/link.py @@ -41,12 +41,12 @@ def invcloglog(eta): def probit(mu): """Probit function that ensures the input is in (0, 1)""" mu = force_within_unit_interval(mu) - return 2**0.5 * special.erfinv(2 * mu - 1) # pylint: disable=no-member + return 2 ** 0.5 * special.erfinv(2 * mu - 1) # pylint: disable=no-member def invprobit(eta): """Inverse of the probit function that ensures result is in (0, 1)""" - result = 0.5 + 0.5 * special.erf(eta / 2**0.5) # pylint: disable=no-member + result = 0.5 + 0.5 * special.erf(eta / 2 ** 0.5) # pylint: disable=no-member return force_within_unit_interval(result) @@ -70,7 +70,7 @@ def softmax(eta, axis=None): def inverse_squared(mu): - return 1 / mu**2 + return 1 / mu ** 2 def inv_inverse_squared(eta): diff --git a/bambi/models.py b/bambi/models.py index 482e15858..0f5275fa6 100644 --- a/bambi/models.py +++ b/bambi/models.py @@ -209,8 +209,8 @@ def fit( Finally, ``"laplace"``, in which case a Laplace approximation is used and is not recommended other than for pedagogical use. To use the PyMC numpyro and blackjax samplers, use ``nuts_numpyro`` or ``nuts_blackjax`` - respectively. Both methods will only work if you can use NUTS sampling, so your model must - be differentiable. + respectively. Both methods will only work if you can use NUTS sampling, so your model + must be differentiable. init: str Initialization method. Defaults to ``"auto"``. The available methods are: * auto: Use ``"jitter+adapt_diag"`` and if this method fails it uses ``"adapt_diag"``. diff --git a/bambi/priors/scaler.py b/bambi/priors/scaler.py index 686faa655..bb6596324 100644 --- a/bambi/priors/scaler.py +++ b/bambi/priors/scaler.py @@ -32,7 +32,7 @@ def get_intercept_stats(self): if self.priors: sigmas = np.hstack([prior["sigma"] for prior in self.priors.values()]) x_mean = np.hstack([self.model.terms[term].data.mean(axis=0) for term in self.priors]) - sigma = (sigma**2 + np.dot(sigmas**2, x_mean**2)) ** 0.5 + sigma = (sigma ** 2 + np.dot(sigmas ** 2, x_mean ** 2)) ** 0.5 return mu, sigma diff --git a/bambi/tests/test_built_models.py b/bambi/tests/test_built_models.py index 4970b011b..7ff5c022f 100644 --- a/bambi/tests/test_built_models.py +++ b/bambi/tests/test_built_models.py @@ -131,9 +131,7 @@ def test_many_common_many_group_specific(crossed_data): dropna=True, ) model0.fit( - tune=10, - draws=10, - chains=2, + tune=10, draws=10, chains=2, ) model1 = Model( @@ -142,9 +140,7 @@ def test_many_common_many_group_specific(crossed_data): dropna=True, ) model1.fit( - tune=10, - draws=10, - chains=2, + tune=10, draws=10, chains=2, ) # check that the group specific effects design matrices have the same shape X0 = pd.concat([pd.DataFrame(t.data) for t in model0.group_specific_terms.values()], axis=1) @@ -395,11 +391,7 @@ def test_laplace(): def test_prior_predictive(crossed_data): crossed_data["count"] = (crossed_data["Y"] - crossed_data["Y"].min()).round() # New default priors are too wide for this case... something to keep investigating - model = Model( - "count ~ threecats + continuous + dummy", - crossed_data, - family="poisson", - ) + model = Model("count ~ threecats + continuous + dummy", crossed_data, family="poisson",) model.build() print(model) pps = model.prior_predictive(draws=500) @@ -501,12 +493,7 @@ def test_potentials(): ] model = Model( - "w ~ 1", - data, - family="bernoulli", - link="identity", - priors=priors, - potentials=potentials, + "w ~ 1", data, family="bernoulli", link="identity", priors=priors, potentials=potentials, ) model.build() assert len(model.backend.model.potentials) == 2 diff --git a/bambi/tests/test_model_construction.py b/bambi/tests/test_model_construction.py index 17ef731b9..1411f3cbd 100644 --- a/bambi/tests/test_model_construction.py +++ b/bambi/tests/test_model_construction.py @@ -17,12 +17,7 @@ @pytest.fixture(scope="module") def data_numeric_xy(): - data = pd.DataFrame( - { - "y": np.random.normal(size=100), - "x": np.random.normal(size=100), - } - ) + data = pd.DataFrame({"y": np.random.normal(size=100), "x": np.random.normal(size=100),}) return data @@ -288,15 +283,7 @@ def test_hyperprior_on_common_effect(): @pytest.mark.parametrize( "family", - [ - "gaussian", - "negativebinomial", - "bernoulli", - "poisson", - "gamma", - "vonmises", - "wald", - ], + ["gaussian", "negativebinomial", "bernoulli", "poisson", "gamma", "vonmises", "wald",], ) def test_automatic_priors(family): """Test that automatic priors work correctly""" diff --git a/bambi/tests/test_predict.py b/bambi/tests/test_predict.py index 9e7619a14..344ea98ea 100644 --- a/bambi/tests/test_predict.py +++ b/bambi/tests/test_predict.py @@ -334,10 +334,7 @@ def test_predict_include_group_specific(): size = 100 data = pd.DataFrame( - { - "y": rng.choice([0, 1], size=size), - "x1": rng.choice(list("abcd"), size=size), - } + {"y": rng.choice([0, 1], size=size), "x1": rng.choice(list("abcd"), size=size),} ) model = Model("y ~ 1 + (1|x1)", data, family="bernoulli") @@ -346,8 +343,7 @@ def test_predict_include_group_specific(): idata_2 = model.predict(idata, data=data, inplace=False, include_group_specific=False) assert not np.isclose( - idata_1.posterior["y_mean"].values, - idata_2.posterior["y_mean"].values, + idata_1.posterior["y_mean"].values, idata_2.posterior["y_mean"].values, ).any() # Since it's an intercept-only model, predictions are the same for all observations if diff --git a/bambi/tests/test_priors.py b/bambi/tests/test_priors.py index 14588009d..766aefd97 100644 --- a/bambi/tests/test_priors.py +++ b/bambi/tests/test_priors.py @@ -221,12 +221,7 @@ def test_set_prior_with_tuple(): def test_set_prior_unexisting_term(): - data = pd.DataFrame( - { - "y": np.random.normal(size=100), - "x": np.random.normal(size=100), - } - ) + data = pd.DataFrame({"y": np.random.normal(size=100), "x": np.random.normal(size=100),}) prior = Prior("Uniform", lower=0, upper=50) model = Model("y ~ x", data) with pytest.raises(ValueError): From 1377b58bd9a6bb44b951944db98078ddaaac485f Mon Sep 17 00:00:00 2001 From: mark Date: Wed, 8 Jun 2022 19:11:27 +0100 Subject: [PATCH 10/16] Re-run black with latest version --- bambi/backend/pymc.py | 12 ++++++++++-- bambi/families/link.py | 6 +++--- bambi/priors/scaler.py | 2 +- bambi/tests/test_built_models.py | 21 +++++++++++++++++---- bambi/tests/test_model_construction.py | 17 +++++++++++++++-- bambi/tests/test_predict.py | 8 ++++++-- bambi/tests/test_priors.py | 7 ++++++- 7 files changed, 58 insertions(+), 15 deletions(-) diff --git a/bambi/backend/pymc.py b/bambi/backend/pymc.py index 41aea8756..0cc967096 100644 --- a/bambi/backend/pymc.py +++ b/bambi/backend/pymc.py @@ -257,7 +257,11 @@ def _run_mcmc( 4 # sample_numpyro_nuts does not handle chains = None like pm.sample does ) idata = pm.sampling_jax.sample_numpyro_nuts( - draws=draws, tune=tune, chains=chains, random_seed=random_seed, **kwargs, + draws=draws, + tune=tune, + chains=chains, + random_seed=random_seed, + **kwargs, ) elif sampler_backend == "nuts_blackjax": import pymc.sampling_jax # Lazy import to not force users to install Jax @@ -267,7 +271,11 @@ def _run_mcmc( 4 # sample_blackjax_nuts does not handle chains = None like pm.sample does ) idata = pm.sampling_jax.sample_blackjax_nuts( - draws=draws, tune=tune, chains=chains, random_seed=random_seed, **kwargs, + draws=draws, + tune=tune, + chains=chains, + random_seed=random_seed, + **kwargs, ) else: raise ValueError( diff --git a/bambi/families/link.py b/bambi/families/link.py index 8d63ef24d..8ea85b28c 100644 --- a/bambi/families/link.py +++ b/bambi/families/link.py @@ -41,12 +41,12 @@ def invcloglog(eta): def probit(mu): """Probit function that ensures the input is in (0, 1)""" mu = force_within_unit_interval(mu) - return 2 ** 0.5 * special.erfinv(2 * mu - 1) # pylint: disable=no-member + return 2**0.5 * special.erfinv(2 * mu - 1) # pylint: disable=no-member def invprobit(eta): """Inverse of the probit function that ensures result is in (0, 1)""" - result = 0.5 + 0.5 * special.erf(eta / 2 ** 0.5) # pylint: disable=no-member + result = 0.5 + 0.5 * special.erf(eta / 2**0.5) # pylint: disable=no-member return force_within_unit_interval(result) @@ -70,7 +70,7 @@ def softmax(eta, axis=None): def inverse_squared(mu): - return 1 / mu ** 2 + return 1 / mu**2 def inv_inverse_squared(eta): diff --git a/bambi/priors/scaler.py b/bambi/priors/scaler.py index bb6596324..686faa655 100644 --- a/bambi/priors/scaler.py +++ b/bambi/priors/scaler.py @@ -32,7 +32,7 @@ def get_intercept_stats(self): if self.priors: sigmas = np.hstack([prior["sigma"] for prior in self.priors.values()]) x_mean = np.hstack([self.model.terms[term].data.mean(axis=0) for term in self.priors]) - sigma = (sigma ** 2 + np.dot(sigmas ** 2, x_mean ** 2)) ** 0.5 + sigma = (sigma**2 + np.dot(sigmas**2, x_mean**2)) ** 0.5 return mu, sigma diff --git a/bambi/tests/test_built_models.py b/bambi/tests/test_built_models.py index 7ff5c022f..4970b011b 100644 --- a/bambi/tests/test_built_models.py +++ b/bambi/tests/test_built_models.py @@ -131,7 +131,9 @@ def test_many_common_many_group_specific(crossed_data): dropna=True, ) model0.fit( - tune=10, draws=10, chains=2, + tune=10, + draws=10, + chains=2, ) model1 = Model( @@ -140,7 +142,9 @@ def test_many_common_many_group_specific(crossed_data): dropna=True, ) model1.fit( - tune=10, draws=10, chains=2, + tune=10, + draws=10, + chains=2, ) # check that the group specific effects design matrices have the same shape X0 = pd.concat([pd.DataFrame(t.data) for t in model0.group_specific_terms.values()], axis=1) @@ -391,7 +395,11 @@ def test_laplace(): def test_prior_predictive(crossed_data): crossed_data["count"] = (crossed_data["Y"] - crossed_data["Y"].min()).round() # New default priors are too wide for this case... something to keep investigating - model = Model("count ~ threecats + continuous + dummy", crossed_data, family="poisson",) + model = Model( + "count ~ threecats + continuous + dummy", + crossed_data, + family="poisson", + ) model.build() print(model) pps = model.prior_predictive(draws=500) @@ -493,7 +501,12 @@ def test_potentials(): ] model = Model( - "w ~ 1", data, family="bernoulli", link="identity", priors=priors, potentials=potentials, + "w ~ 1", + data, + family="bernoulli", + link="identity", + priors=priors, + potentials=potentials, ) model.build() assert len(model.backend.model.potentials) == 2 diff --git a/bambi/tests/test_model_construction.py b/bambi/tests/test_model_construction.py index 1411f3cbd..17ef731b9 100644 --- a/bambi/tests/test_model_construction.py +++ b/bambi/tests/test_model_construction.py @@ -17,7 +17,12 @@ @pytest.fixture(scope="module") def data_numeric_xy(): - data = pd.DataFrame({"y": np.random.normal(size=100), "x": np.random.normal(size=100),}) + data = pd.DataFrame( + { + "y": np.random.normal(size=100), + "x": np.random.normal(size=100), + } + ) return data @@ -283,7 +288,15 @@ def test_hyperprior_on_common_effect(): @pytest.mark.parametrize( "family", - ["gaussian", "negativebinomial", "bernoulli", "poisson", "gamma", "vonmises", "wald",], + [ + "gaussian", + "negativebinomial", + "bernoulli", + "poisson", + "gamma", + "vonmises", + "wald", + ], ) def test_automatic_priors(family): """Test that automatic priors work correctly""" diff --git a/bambi/tests/test_predict.py b/bambi/tests/test_predict.py index 344ea98ea..9e7619a14 100644 --- a/bambi/tests/test_predict.py +++ b/bambi/tests/test_predict.py @@ -334,7 +334,10 @@ def test_predict_include_group_specific(): size = 100 data = pd.DataFrame( - {"y": rng.choice([0, 1], size=size), "x1": rng.choice(list("abcd"), size=size),} + { + "y": rng.choice([0, 1], size=size), + "x1": rng.choice(list("abcd"), size=size), + } ) model = Model("y ~ 1 + (1|x1)", data, family="bernoulli") @@ -343,7 +346,8 @@ def test_predict_include_group_specific(): idata_2 = model.predict(idata, data=data, inplace=False, include_group_specific=False) assert not np.isclose( - idata_1.posterior["y_mean"].values, idata_2.posterior["y_mean"].values, + idata_1.posterior["y_mean"].values, + idata_2.posterior["y_mean"].values, ).any() # Since it's an intercept-only model, predictions are the same for all observations if diff --git a/bambi/tests/test_priors.py b/bambi/tests/test_priors.py index 766aefd97..14588009d 100644 --- a/bambi/tests/test_priors.py +++ b/bambi/tests/test_priors.py @@ -221,7 +221,12 @@ def test_set_prior_with_tuple(): def test_set_prior_unexisting_term(): - data = pd.DataFrame({"y": np.random.normal(size=100), "x": np.random.normal(size=100),}) + data = pd.DataFrame( + { + "y": np.random.normal(size=100), + "x": np.random.normal(size=100), + } + ) prior = Prior("Uniform", lower=0, upper=50) model = Model("y ~ x", data) with pytest.raises(ValueError): From 0f4bca665aa3c6c6e60f32563c20a1475df2636a Mon Sep 17 00:00:00 2001 From: mark Date: Wed, 8 Jun 2022 19:30:05 +0100 Subject: [PATCH 11/16] Add pylint error ignores for lazy imports --- bambi/backend/pymc.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/bambi/backend/pymc.py b/bambi/backend/pymc.py index 0cc967096..b957da505 100644 --- a/bambi/backend/pymc.py +++ b/bambi/backend/pymc.py @@ -250,13 +250,14 @@ def _run_mcmc( else: raise elif sampler_backend == "nuts_numpyro": - import pymc.sampling_jax # Lazy import to not force users to install Jax + # Lazy import to not force users to install Jax + import pymc.sampling_jax # pylint: disable=import-outside-toplevel if not chains: chains = ( 4 # sample_numpyro_nuts does not handle chains = None like pm.sample does ) - idata = pm.sampling_jax.sample_numpyro_nuts( + idata = pymc.sampling_jax.sample_numpyro_nuts( draws=draws, tune=tune, chains=chains, @@ -264,13 +265,14 @@ def _run_mcmc( **kwargs, ) elif sampler_backend == "nuts_blackjax": - import pymc.sampling_jax # Lazy import to not force users to install Jax + # Lazy import to not force users to install Jax + import pymc.sampling_jax # pylint: disable=import-outside-toplevel if not chains: chains = ( 4 # sample_blackjax_nuts does not handle chains = None like pm.sample does ) - idata = pm.sampling_jax.sample_blackjax_nuts( + idata = pymc.sampling_jax.sample_blackjax_nuts( draws=draws, tune=tune, chains=chains, From 5a7c7fab66b286e925147d8d742b0121c42aaacc Mon Sep 17 00:00:00 2001 From: mark Date: Wed, 8 Jun 2022 21:20:59 +0100 Subject: [PATCH 12/16] Re-add math import on init. Lazy importing looks to have fixed circular imports and this is needed for tests to work --- bambi/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/bambi/__init__.py b/bambi/__init__.py index 42f9843f0..6d8dc640d 100644 --- a/bambi/__init__.py +++ b/bambi/__init__.py @@ -1,5 +1,7 @@ import logging +from pymc import math + from .data import clear_data_home, load_data from .families import Family, Likelihood, Link from .models import Model From 9ecdc010362cc648cf3a16f1ee5bb073c53d3a78 Mon Sep 17 00:00:00 2001 From: mark Date: Thu, 9 Jun 2022 09:44:08 +0100 Subject: [PATCH 13/16] Add optional dependencies for Jax samplers & modify test setup accordingly --- .github/workflows/test.yml | 1 + requirements-optional.txt | 4 ++++ setup.py | 9 +++++++++ 3 files changed, 14 insertions(+) create mode 100644 requirements-optional.txt diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b15c3ff26..7a6b22233 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -34,6 +34,7 @@ jobs: conda install pip pip install -r requirements.txt pip install -r requirements-dev.txt + pip install -r requirements-optional.txt pip install . python --version diff --git a/requirements-optional.txt b/requirements-optional.txt new file mode 100644 index 000000000..beee7714b --- /dev/null +++ b/requirements-optional.txt @@ -0,0 +1,4 @@ +jax>=0.3.1 +jaxlib>=0.3.1 +numpyro>=0.9.0 +blackjax>=0.7.0 \ No newline at end of file diff --git a/setup.py b/setup.py index 948be9c2c..dafea4347 100644 --- a/setup.py +++ b/setup.py @@ -7,6 +7,7 @@ README_FILE = os.path.join(PROJECT_ROOT, "README.md") VERSION_FILE = os.path.join(PROJECT_ROOT, "bambi", "version.py") REQUIREMENTS_FILE = os.path.join(PROJECT_ROOT, "requirements.txt") +OPTIONAL_REQUIREMENTS_FILE = os.path.join(PROJECT_ROOT, "requirements-optional.txt") MINIMUM_PYTHON_VERSION = (3, 7, 2) @@ -29,6 +30,11 @@ def get_requirements(): return buff.read().splitlines() +def get_optional_requirements(): + with open(OPTIONAL_REQUIREMENTS_FILE, encoding="utf-8") as buff: + return buff.read().splitlines() + + def get_version(): with open(VERSION_FILE, encoding="utf-8") as buff: exec(buff.read()) # pylint: disable=exec-used @@ -49,6 +55,9 @@ def get_version(): url="http://github.com/bambinos/bambi", download_url="https://github.com/bambinos/bambi/archive/%s.tar.gz" % __version__, install_requires=get_requirements(), + extras_require={ + "jax": [get_optional_requirements()], + }, maintainer="Tomas Capretto", maintainer_email="tomicapretto@gmail.com", packages=find_packages(exclude=["tests", "test_*"]), From 324f8388ac065834922bb387cbd3355e1bb5b5b2 Mon Sep 17 00:00:00 2001 From: Osvaldo A Martin Date: Thu, 9 Jun 2022 09:24:02 -0300 Subject: [PATCH 14/16] Update pymc.py --- bambi/backend/pymc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bambi/backend/pymc.py b/bambi/backend/pymc.py index b957da505..abcb7b252 100644 --- a/bambi/backend/pymc.py +++ b/bambi/backend/pymc.py @@ -360,7 +360,7 @@ def _clean_mcmc_results(self, idata, omit_offsets, include_mean): else: intercept_name = self.spec.intercept_term.name - idata.posterior[intercept_name] -= np.dot(X.mean(0), coefs).reshape(shape) + idata.posterior[intercept_name] = idata.posterior[intercept_name] - np.dot(X.mean(0), coefs).reshape(shape) if include_mean: self.spec.predict(idata) From 0621b5bba496fdf0c1a0fba657c9a01f13758fb9 Mon Sep 17 00:00:00 2001 From: mark Date: Thu, 9 Jun 2022 13:32:08 +0100 Subject: [PATCH 15/16] Run black --- bambi/backend/pymc.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/bambi/backend/pymc.py b/bambi/backend/pymc.py index abcb7b252..f5267c51b 100644 --- a/bambi/backend/pymc.py +++ b/bambi/backend/pymc.py @@ -360,7 +360,9 @@ def _clean_mcmc_results(self, idata, omit_offsets, include_mean): else: intercept_name = self.spec.intercept_term.name - idata.posterior[intercept_name] = idata.posterior[intercept_name] - np.dot(X.mean(0), coefs).reshape(shape) + idata.posterior[intercept_name] = idata.posterior[intercept_name] - np.dot( + X.mean(0), coefs + ).reshape(shape) if include_mean: self.spec.predict(idata) From f5dfd5ed33a3543418dfc125b544b85a65f06389 Mon Sep 17 00:00:00 2001 From: mark Date: Thu, 9 Jun 2022 14:57:18 +0100 Subject: [PATCH 16/16] Add new numpyro/blackjax only tests and revert old tests --- bambi/tests/test_built_models.py | 40 +++++++++++++++++++++++++++----- 1 file changed, 34 insertions(+), 6 deletions(-) diff --git a/bambi/tests/test_built_models.py b/bambi/tests/test_built_models.py index 4970b011b..bc8bb167d 100644 --- a/bambi/tests/test_built_models.py +++ b/bambi/tests/test_built_models.py @@ -301,8 +301,6 @@ def test_group_specific_categorical_interaction(crossed_data): crossed_data["fourcats"] = sum([[x] * 10 for x in ["a", "b", "c", "d"]], list()) * 3 model = Model("Y ~ continuous + (threecats:fourcats|site)", crossed_data) model.fit(tune=10, draws=10) - model.fit(tune=10, draws=10, method="nuts_numpyro") - model.fit(tune=10, draws=10, method="nuts_blackjax") def test_logistic_regression_empty_index(): @@ -329,8 +327,40 @@ def test_logistic_regression_categoric(): data = pd.DataFrame({"y": y, "x": np.random.normal(size=50)}) model = Model("y ~ x", data, family="bernoulli") model.fit() - model.fit(tune=10, draws=10, method="nuts_numpyro") - model.fit(tune=10, draws=10, method="nuts_blackjax") + + +def test_logistic_regression_numpyro(): + y = pd.Series(np.random.choice(["a", "b"], 50), dtype="category") + data = pd.DataFrame({"y": y, "x": np.random.normal(size=50)}) + model = Model("y ~ x", data, family="bernoulli") + model.fit(method="nuts_numpyro", chain_method="vectorized") + + +def test_logistic_regression_blackjax(): + y = pd.Series(np.random.choice(["a", "b"], 50), dtype="category") + data = pd.DataFrame({"y": y, "x": np.random.normal(size=50)}) + model = Model("y ~ x", data, family="bernoulli") + model.fit(method="nuts_blackjax", chain_method="vectorized") + + +def test_regression_blackjax(): + size = 1_000 + rng = np.random.default_rng(0) + x = rng.normal(size=size) + data = pd.DataFrame({"x": x, "y": rng.normal(loc=x, size=size)}) + + bmb_model = Model("y ~ x", data) + bmb_model.fit(method="nuts_blackjax", chain_method="vectorized") + + +def test_regression_nunpyro(): + size = 1_000 + rng = np.random.default_rng(0) + x = rng.normal(size=size) + data = pd.DataFrame({"x": x, "y": rng.normal(loc=x, size=size)}) + + bmb_model = Model("y ~ x", data) + bmb_model.fit(method="nuts_numpyro", chain_method="vectorized") def test_poisson_regression(crossed_data): @@ -453,8 +483,6 @@ def test_gamma_regression(dm): data = dm[["order", "ind_mg_dry"]] model = Model("ind_mg_dry ~ order", data, family="gamma", link="log") model.fit(draws=10, tune=10) - model.fit(tune=10, draws=10, method="nuts_numpyro") - model.fit(tune=10, draws=10, method="nuts_blackjax") def test_beta_regression():