Skip to content

Commit

Permalink
Add support for numpyro and blackjax PyMC samplers (#526)
Browse files Browse the repository at this point in the history
* Add support for numpyro and blackjax PyMC samplers

* Update bambi/models.py

Co-authored-by: Osvaldo A Martin <aloctavodia@gmail.com>

* Lazily import jax sampling. Refactor sampler_backend to instead be part of the method argument

* Fix for chains bug in numpyro and blackjax backends

* Minor error message fix

* Rename mcmc-numpyro/blackjax to nuts_numpyro/blackjax

* Remove incorrect statement about chain_method as the default is now "parallel"

* Extend tests to also cover numpyro/blackjax samplers

* Run black and pylint

* Re-run black with latest version

* Add pylint error ignores for lazy imports

* Re-add math import on init. Lazy importing looks to have fixed circular imports and this is needed for tests to work

* Add optional dependencies for Jax samplers & modify test setup accordingly

* Update pymc.py

* Run black

* Add new numpyro/blackjax only tests and revert old tests

Co-authored-by: mark <mark@longshotsystems.co.uk>
Co-authored-by: Osvaldo A Martin <aloctavodia@gmail.com>
  • Loading branch information
3 people authored Jun 10, 2022
1 parent 4b89c4f commit 91903c8
Show file tree
Hide file tree
Showing 6 changed files with 117 additions and 23 deletions.
1 change: 1 addition & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ jobs:
conda install pip
pip install -r requirements.txt
pip install -r requirements-dev.txt
pip install -r requirements-optional.txt
pip install .
python --version
Expand Down
89 changes: 66 additions & 23 deletions bambi/backend/pymc.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def run(
):
"""Run PyMC sampler."""
# NOTE: Methods return different types of objects (idata, approximation, and dictionary)
if method.lower() == "mcmc":
if method.lower() in ["mcmc", "nuts_numpyro", "nuts_blackjax"]:
result = self._run_mcmc(
draws,
tune,
Expand All @@ -99,6 +99,7 @@ def run(
chains,
cores,
random_seed,
method.lower(),
**kwargs,
)
elif method.lower() == "vi":
Expand Down Expand Up @@ -209,40 +210,80 @@ def _run_mcmc(
chains=None,
cores=None,
random_seed=None,
sampler_backend="mcmc",
**kwargs,
):
with self.model:
try:
idata = pm.sample(
draws=draws,
tune=tune,
discard_tuned_samples=discard_tuned_samples,
init=init,
n_init=n_init,
chains=chains,
cores=cores,
random_seed=random_seed,
**kwargs,
)
except (RuntimeError, ValueError):
if "ValueError: Mass matrix contains" in traceback.format_exc() and init == "auto":
_log.info(
"\nThe default initialization using init='auto' has failed, trying to "
"recover by switching to init='adapt_diag'",
)
if sampler_backend == "mcmc":
try:
idata = pm.sample(
draws=draws,
tune=tune,
discard_tuned_samples=discard_tuned_samples,
init="adapt_diag",
init=init,
n_init=n_init,
chains=chains,
cores=cores,
random_seed=random_seed,
**kwargs,
)
else:
raise
except (RuntimeError, ValueError):
if (
"ValueError: Mass matrix contains" in traceback.format_exc()
and init == "auto"
):
_log.info(
"\nThe default initialization using init='auto' has failed, trying to "
"recover by switching to init='adapt_diag'",
)
idata = pm.sample(
draws=draws,
tune=tune,
discard_tuned_samples=discard_tuned_samples,
init="adapt_diag",
n_init=n_init,
chains=chains,
cores=cores,
random_seed=random_seed,
**kwargs,
)
else:
raise
elif sampler_backend == "nuts_numpyro":
# Lazy import to not force users to install Jax
import pymc.sampling_jax # pylint: disable=import-outside-toplevel

if not chains:
chains = (
4 # sample_numpyro_nuts does not handle chains = None like pm.sample does
)
idata = pymc.sampling_jax.sample_numpyro_nuts(
draws=draws,
tune=tune,
chains=chains,
random_seed=random_seed,
**kwargs,
)
elif sampler_backend == "nuts_blackjax":
# Lazy import to not force users to install Jax
import pymc.sampling_jax # pylint: disable=import-outside-toplevel

if not chains:
chains = (
4 # sample_blackjax_nuts does not handle chains = None like pm.sample does
)
idata = pymc.sampling_jax.sample_blackjax_nuts(
draws=draws,
tune=tune,
chains=chains,
random_seed=random_seed,
**kwargs,
)
else:
raise ValueError(
f"sampler_backend value {sampler_backend} is not valid. Please choose one of"
f"``mcmc``, ``nuts_numpyro`` or ``nuts_blackjax``"
)

idata = self._clean_mcmc_results(idata, omit_offsets, include_mean)
return idata
Expand Down Expand Up @@ -319,7 +360,9 @@ def _clean_mcmc_results(self, idata, omit_offsets, include_mean):
else:
intercept_name = self.spec.intercept_term.name

idata.posterior[intercept_name] -= np.dot(X.mean(0), coefs).reshape(shape)
idata.posterior[intercept_name] = idata.posterior[intercept_name] - np.dot(
X.mean(0), coefs
).reshape(shape)

if include_mean:
self.spec.predict(idata)
Expand Down
3 changes: 3 additions & 0 deletions bambi/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,9 @@ def fit(
using the ``fit`` function.
Finally, ``"laplace"``, in which case a Laplace approximation is used and is not
recommended other than for pedagogical use.
To use the PyMC numpyro and blackjax samplers, use ``nuts_numpyro`` or ``nuts_blackjax``
respectively. Both methods will only work if you can use NUTS sampling, so your model
must be differentiable.
init: str
Initialization method. Defaults to ``"auto"``. The available methods are:
* auto: Use ``"jitter+adapt_diag"`` and if this method fails it uses ``"adapt_diag"``.
Expand Down
34 changes: 34 additions & 0 deletions bambi/tests/test_built_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,40 @@ def test_laplace_regression():
bmb_model.fit()


def test_logistic_regression_numpyro():
    """Smoke test: fit a Bernoulli model on random data with ``method="nuts_numpyro"``.

    There are no assertions; the test passes as long as fitting does not raise.
    """
    n_obs = 50
    # The response is drawn before the predictor so the global NumPy random state
    # is consumed in a fixed order.
    labels = np.random.choice(["a", "b"], n_obs)
    predictor = np.random.normal(size=n_obs)
    frame = pd.DataFrame({"y": pd.Series(labels, dtype="category"), "x": predictor})
    # chain_method is passed through ``fit`` as an extra keyword argument.
    Model("y ~ x", frame, family="bernoulli").fit(
        method="nuts_numpyro", chain_method="vectorized"
    )


def test_logistic_regression_blackjax():
    """Smoke test: fit a Bernoulli model on random data with ``method="nuts_blackjax"``.

    There are no assertions; the test passes as long as fitting does not raise.
    """
    n_obs = 50
    # The response is drawn before the predictor so the global NumPy random state
    # is consumed in a fixed order.
    labels = np.random.choice(["a", "b"], n_obs)
    predictor = np.random.normal(size=n_obs)
    frame = pd.DataFrame({"y": pd.Series(labels, dtype="category"), "x": predictor})
    # chain_method is passed through ``fit`` as an extra keyword argument.
    Model("y ~ x", frame, family="bernoulli").fit(
        method="nuts_blackjax", chain_method="vectorized"
    )


def test_regression_blackjax():
    """Smoke test: fit the default-family model ``y ~ x`` with ``method="nuts_blackjax"``.

    The data are 1,000 seeded draws where ``y`` is normal noise centred on ``x``.
    There are no assertions; the test passes as long as fitting does not raise.
    """
    n_obs = 1_000
    generator = np.random.default_rng(0)
    # The predictor is drawn first; the response then uses it as its location.
    predictor = generator.normal(size=n_obs)
    response = generator.normal(loc=predictor, size=n_obs)
    frame = pd.DataFrame({"x": predictor, "y": response})

    Model("y ~ x", frame).fit(method="nuts_blackjax", chain_method="vectorized")


def test_regression_nunpyro():
    """Smoke test: fit the default-family model ``y ~ x`` with ``method="nuts_numpyro"``.

    The data are 1,000 seeded draws where ``y`` is normal noise centred on ``x``.
    There are no assertions; the test passes as long as fitting does not raise.
    """
    # NOTE(review): the function name spells "nunpyro"; it is left unchanged here so
    # the existing test id keeps working. Consider renaming to "numpyro" separately.
    n_obs = 1_000
    generator = np.random.default_rng(0)
    # The predictor is drawn first; the response then uses it as its location.
    predictor = generator.normal(size=n_obs)
    response = generator.normal(loc=predictor, size=n_obs)
    frame = pd.DataFrame({"x": predictor, "y": response})

    Model("y ~ x", frame).fit(method="nuts_numpyro", chain_method="vectorized")


def test_poisson_regression(crossed_data):
# build model using fit and pymc
crossed_data["count"] = (crossed_data["Y"] - crossed_data["Y"].min()).round()
Expand Down
4 changes: 4 additions & 0 deletions requirements-optional.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
jax>=0.3.1
jaxlib>=0.3.1
numpyro>=0.9.0
blackjax>=0.7.0
9 changes: 9 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
README_FILE = os.path.join(PROJECT_ROOT, "README.md")
VERSION_FILE = os.path.join(PROJECT_ROOT, "bambi", "version.py")
REQUIREMENTS_FILE = os.path.join(PROJECT_ROOT, "requirements.txt")
OPTIONAL_REQUIREMENTS_FILE = os.path.join(PROJECT_ROOT, "requirements-optional.txt")
MINIMUM_PYTHON_VERSION = (3, 7, 2)


Expand All @@ -29,6 +30,11 @@ def get_requirements():
return buff.read().splitlines()


def get_optional_requirements():
    """Return the lines of the optional requirements file as a list of strings.

    The file at ``OPTIONAL_REQUIREMENTS_FILE`` is read as UTF-8 and split on line
    boundaries. Lines are returned as-is, without filtering or stripping.
    """
    with open(OPTIONAL_REQUIREMENTS_FILE, encoding="utf-8") as handle:
        contents = handle.read()
    return contents.splitlines()


def get_version():
with open(VERSION_FILE, encoding="utf-8") as buff:
exec(buff.read()) # pylint: disable=exec-used
Expand All @@ -49,6 +55,9 @@ def get_version():
url="http://github.com/bambinos/bambi",
download_url="https://github.com/bambinos/bambi/archive/%s.tar.gz" % __version__,
install_requires=get_requirements(),
extras_require={
"jax": [get_optional_requirements()],
},
maintainer="Tomas Capretto",
maintainer_email="tomicapretto@gmail.com",
packages=find_packages(exclude=["tests", "test_*"]),
Expand Down

0 comments on commit 91903c8

Please sign in to comment.