Add experimental JAX samplers (#4247)
* Add JAX NUTS samplers from TFP and numpyro. With @junpenglao.

* Add missing import.

* Remove JAX as default linker.

* Add experimental warning and clean up imports.

* Add JAX nb.

* Add NB to toc.

* Black and isort.

* Change title.

* Remove comma

* Typo.

* nbqa NB

* Run pre-commit.

* Disable pylint.

* Add to release-notes.
twiecki authored Nov 27, 2020
1 parent 22c079c commit 5ff9bbc
Showing 4 changed files with 570 additions and 1 deletion.
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
@@ -26,6 +26,7 @@ pip install theano-pymc
This new version of `Theano-PyMC` comes with an experimental JAX backend which, when combined with the new and experimental JAX samplers in PyMC3, can greatly speed up sampling in your model. As this is still very new, please do not use it in production yet, but do test it out and let us know if anything breaks and what results you are seeing, especially speed-wise.

### New features
- New experimental JAX samplers in `pymc3.sampling_jax` (see [notebook](https://docs.pymc.io/notebooks/GLM-hierarchical-jax.html) and [#4247](https://github.com/pymc-devs/pymc3/pull/4247); a minimal usage sketch follows this diff). Requires JAX and either TFP or numpyro.
- Add MLDA, a new stepper for multilevel sampling. MLDA can be used when a hierarchy of approximate posteriors of varying accuracy is available, offering improved sampling efficiency especially in high-dimensional problems and/or where gradients are not available (see [#3926](https://github.com/pymc-devs/pymc3/pull/3926)).
- Add Bayesian Additive Regression Trees (BART) (see [#4183](https://github.com/pymc-devs/pymc3/pull/4183)).
- Added `pymc3.gp.cov.Circular` kernel for Gaussian Processes on circular domains, e.g. the unit circle (see [#4082](https://github.com/pymc-devs/pymc3/pull/4082)).
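For orientation, here is a minimal usage sketch. The toy model and data are illustrative assumptions, not part of the commit; `sample_numpyro_nuts` is one of the two entry points added in `pymc3/sampling_jax.py` below:

import pymc3 as pm
from pymc3 import sampling_jax  # emits an "experimental" warning on import

with pm.Model():
    mu = pm.Normal("mu", mu=0.0, sigma=1.0)
    pm.Normal("obs", mu=mu, sigma=1.0, observed=[0.1, -0.3, 0.2])
    # JAX-based NUTS via numpyro; requires JAX and numpyro to be installed.
    idata = sampling_jax.sample_numpyro_nuts(draws=1000, tune=1000, chains=4)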
384 changes: 384 additions & 0 deletions docs/source/notebooks/GLM-hierarchical-jax.ipynb

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion docs/source/notebooks/table_of_contents_examples.js
@@ -64,5 +64,6 @@ Gallery.contents = {
"MLDA_introduction": "MCMC",
"MLDA_simple_linear_regression": "MCMC",
"MLDA_gravity_surveying": "MCMC",
"MLDA_variance_reduction_linear_regression": "MCMC"
"MLDA_variance_reduction_linear_regression": "MCMC",
"GLM-hierarchical-jax": "MCMC"
}
183 changes: 183 additions & 0 deletions pymc3/sampling_jax.py
@@ -0,0 +1,183 @@
# pylint: skip-file
import os
import re
import warnings

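# Make XLA expose many virtual host devices so jax.pmap can run chains in parallel on CPU.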
xla_flags = os.getenv("XLA_FLAGS", "").lstrip("--")
xla_flags = re.sub(r"xla_force_host_platform_device_count=.+\s", "", xla_flags).split()
os.environ["XLA_FLAGS"] = " ".join(["--xla_force_host_platform_device_count={}".format(100)] + xla_flags)

import arviz as az
import jax
import numpy as np
import pandas as pd
import theano
import theano.sandbox.jax_linker
import theano.sandbox.jaxify

import pymc3 as pm

from pymc3 import modelcontext

warnings.warn("This module is experimental.")

# Disable C compilation by default
# theano.config.cxx = ""
# This will make the JAX Linker the default
# theano.config.mode = "JAX"


def sample_tfp_nuts(
draws=1000,
tune=1000,
chains=4,
target_accept=0.8,
random_seed=10,
model=None,
num_tuning_epoch=2,
num_compute_step_size=500,
):
from tensorflow_probability.substrates import jax as tfp

model = modelcontext(model)

seed = jax.random.PRNGKey(random_seed)

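    # Convert the model's Theano log-posterior graph into a JAX-callable function.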
fgraph = theano.gof.FunctionGraph(model.free_RVs, [model.logpt])
fns = theano.sandbox.jaxify.jax_funcify(fgraph)
logp_fn_jax = fns[0]

rv_names = [rv.name for rv in model.free_RVs]
init_state = [model.test_point[rv_name] for rv_name in rv_names]
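    # Repeat the test point along a new leading axis so each chain gets its own copy.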
init_state_batched = jax.tree_map(lambda x: np.repeat(x[None, ...], chains, axis=0), init_state)

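    # Map sampling over chains with pmap, one chain per (virtual) XLA device.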
@jax.pmap
def _sample(init_state, seed):
def gen_kernel(step_size):
hmc = tfp.mcmc.NoUTurnSampler(target_log_prob_fn=logp_fn_jax, step_size=step_size)
return tfp.mcmc.DualAveragingStepSizeAdaptation(
hmc, tune // num_tuning_epoch, target_accept_prob=target_accept
)

def trace_fn(_, pkr):
return pkr.new_step_size

def get_tuned_stepsize(samples, step_size):
return step_size[-1] * jax.numpy.std(samples[-num_compute_step_size:])

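        # Start from unit step sizes, then refine them over (num_tuning_epoch - 1) adaptation windows.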
step_size = jax.tree_map(jax.numpy.ones_like, init_state)
for i in range(num_tuning_epoch - 1):
tuning_hmc = gen_kernel(step_size)
init_samples, tuning_result, kernel_results = tfp.mcmc.sample_chain(
num_results=tune // num_tuning_epoch,
current_state=init_state,
kernel=tuning_hmc,
trace_fn=trace_fn,
return_final_kernel_results=True,
seed=seed,
)

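            # Scale the last adapted step size by each variable's sample spread from this window.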
step_size = jax.tree_multimap(get_tuned_stepsize, list(init_samples), tuning_result)
init_state = [x[-1] for x in init_samples]

# Run inference
sample_kernel = gen_kernel(step_size)
mcmc_samples, leapfrog_num = tfp.mcmc.sample_chain(
num_results=draws,
num_burnin_steps=tune // num_tuning_epoch,
current_state=init_state,
kernel=sample_kernel,
trace_fn=lambda _, pkr: pkr.inner_results.leapfrogs_taken,
seed=seed,
)

return mcmc_samples, leapfrog_num

print("Compiling...")
tic2 = pd.Timestamp.now()
map_seed = jax.random.split(seed, chains)
mcmc_samples, leapfrog_num = _sample(init_state_batched, map_seed)
tic3 = pd.Timestamp.now()
print("Compilation + sampling time = ", tic3 - tic2)

# map_seed = jax.random.split(seed, chains)
# mcmc_samples = _sample(init_state_batched, map_seed)
# tic4 = pd.Timestamp.now()
# print("Sampling time = ", tic4 - tic3)

posterior = {k: v for k, v in zip(rv_names, mcmc_samples)}

az_trace = az.from_dict(posterior=posterior)
return az_trace # , leapfrog_num, tic3 - tic2


def sample_numpyro_nuts(
draws=1000,
tune=1000,
chains=4,
target_accept=0.8,
random_seed=10,
model=None,
progress_bar=True,
):
    from numpyro.infer import MCMC, NUTS

model = modelcontext(model)

seed = jax.random.PRNGKey(random_seed)

fgraph = theano.gof.FunctionGraph(model.free_RVs, [model.logpt])
fns = theano.sandbox.jaxify.jax_funcify(fgraph)
logp_fn_jax = fns[0]

rv_names = [rv.name for rv in model.free_RVs]
init_state = [model.test_point[rv_name] for rv_name in rv_names]
init_state_batched = jax.tree_map(lambda x: np.repeat(x[None, ...], chains, axis=0), init_state)

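    # JIT-compile the whole sampling routine; numpyro parallelizes chains itself
    # via chain_method="parallel".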
@jax.jit
def _sample(current_state, seed):
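        # numpyro expects a potential function, i.e. the negative log posterior.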
nuts_kernel = NUTS(
potential_fn=lambda x: -logp_fn_jax(*x),
# model=model,
target_accept_prob=target_accept,
adapt_step_size=True,
adapt_mass_matrix=True,
dense_mass=False,
)

pmap_numpyro = MCMC(
nuts_kernel,
num_warmup=tune,
num_samples=draws,
num_chains=chains,
postprocess_fn=None,
chain_method="parallel",
progress_bar=progress_bar,
)

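        # Also record the number of leapfrog steps ("num_steps") taken per draw.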
pmap_numpyro.run(seed, init_params=current_state, extra_fields=("num_steps",))
samples = pmap_numpyro.get_samples(group_by_chain=True)
leapfrogs_taken = pmap_numpyro.get_extra_fields(group_by_chain=True)["num_steps"]
return samples, leapfrogs_taken

print("Compiling...")
tic2 = pd.Timestamp.now()
map_seed = jax.random.split(seed, chains)
mcmc_samples, leapfrogs_taken = _sample(init_state_batched, map_seed)
tic3 = pd.Timestamp.now()
print("Compilation + sampling time = ", tic3 - tic2)

# map_seed = jax.random.split(seed, chains)
# mcmc_samples = _sample(init_state_batched, map_seed)
# tic4 = pd.Timestamp.now()
# print("Sampling time = ", tic4 - tic3)

posterior = {k: v for k, v in zip(rv_names, mcmc_samples)}

az_trace = az.from_dict(posterior=posterior)
return az_trace # , leapfrogs_taken, tic3 - tic2
