From 073da0ea6bb227aa71fcf260bbfe9829725455f5 Mon Sep 17 00:00:00 2001
From: "Brandon T. Willard"
Date: Thu, 15 Apr 2021 15:12:12 -0500
Subject: [PATCH] Create a NumPyro sampler Op for better JAX backend integration

---
 pymc3/sampling_jax.py            | 314 ++++++++++++++++---------------
 pymc3/tests/test_sampling_jax.py |  26 ++-
 requirements.txt                 |   2 +-
 3 files changed, 183 insertions(+), 159 deletions(-)

diff --git a/pymc3/sampling_jax.py b/pymc3/sampling_jax.py
index adbb47f9d68..6a138d3e654 100644
--- a/pymc3/sampling_jax.py
+++ b/pymc3/sampling_jax.py
@@ -3,150 +3,133 @@
 import re
 import warnings

-from collections import defaultdict
-
 xla_flags = os.getenv("XLA_FLAGS", "").lstrip("--")
 xla_flags = re.sub(r"xla_force_host_platform_device_count=.+\s", "", xla_flags).split()
 os.environ["XLA_FLAGS"] = " ".join(["--xla_force_host_platform_device_count={}".format(100)])

 import aesara.graph.fg
+import aesara.tensor as at
 import arviz as az
 import jax
 import numpy as np
 import pandas as pd

-from aesara.link.jax.jax_dispatch import jax_funcify
-
-import pymc3 as pm
+from aesara.compile import SharedVariable
+from aesara.graph.basic import Apply, Constant, clone, graph_inputs
+from aesara.graph.fg import FunctionGraph
+from aesara.graph.op import Op
+from aesara.graph.opt import MergeOptimizer
+from aesara.link.jax.dispatch import jax_funcify
+from aesara.tensor.type import TensorType

 from pymc3 import modelcontext

 warnings.warn("This module is experimental.")

-# Disable C compilation by default
-# aesara.config.cxx = ""
-# This will make the JAX Linker the default
-# aesara.config.mode = "JAX"

+class NumPyroNUTS(Op):
+    def __init__(
+        self,
+        inputs,
+        outputs,
+        target_accept=0.9,
+        draws=1000,
+        tune=1000,
+        chains=4,
+        seed=None,
+        progress_bar=True,
+    ):
+        self.draws = draws
+        self.tune = tune
+        self.chains = chains
+        self.target_accept = target_accept
+        self.progress_bar = progress_bar
+        self.seed = seed

-def sample_tfp_nuts(
-    draws=1000,
-    tune=1000,
-    chains=4,
-    target_accept=0.8,
-    random_seed=10,
-    model=None,
-    num_tuning_epoch=2,
-    num_compute_step_size=500,
-):
-    import jax
+        self.inputs, self.outputs = clone(inputs, outputs, copy_inputs=False)
+        self.inputs_type = tuple([input.type for input in inputs])
+        self.outputs_type = tuple([output.type for output in outputs])
+        self.nin = len(inputs)
+        self.nout = len(outputs)
+        self.nshared = len([v for v in inputs if isinstance(v, SharedVariable)])
+        self.samples_bcast = [self.chains == 1, self.draws == 1]

-    from tensorflow_probability.substrates import jax as tfp
+        self.fgraph = FunctionGraph(self.inputs, self.outputs, clone=False)
+        MergeOptimizer().optimize(self.fgraph)

-    model = modelcontext(model)
+        super().__init__()

-    seed = jax.random.PRNGKey(random_seed)
+    def make_node(self, *inputs):

-    fgraph = model.logp.f.maker.fgraph
-    fns = jax_funcify(fgraph)
-    logp_fn_jax = fns[0]
+        # The samples for each variable
+        outputs = [
+            TensorType(v.dtype, self.samples_bcast + list(v.broadcastable))() for v in inputs
+        ]

-    rv_names = [rv.name for rv in model.free_RVs]
-    init_state = [model.initial_point[rv_name] for rv_name in rv_names]
-    init_state_batched = jax.tree_map(lambda x: np.repeat(x[None, ...], chains, axis=0), init_state)
+        # The leapfrog statistics
+        outputs += [TensorType("int64", self.samples_bcast)()]

-    @jax.pmap
-    def _sample(init_state, seed):
-        def gen_kernel(step_size):
-            hmc = tfp.mcmc.NoUTurnSampler(target_log_prob_fn=logp_fn_jax, step_size=step_size)
-            return tfp.mcmc.DualAveragingStepSizeAdaptation(
-                hmc, tune // num_tuning_epoch, target_accept_prob=target_accept
-            )
+        all_inputs = list(inputs)
+        if self.nshared > 0:
+            all_inputs += self.inputs[-self.nshared :]

-        def trace_fn(_, pkr):
-            return pkr.new_step_size
-
-        def get_tuned_stepsize(samples, step_size):
-            return step_size[-1] * jax.numpy.std(samples[-num_compute_step_size:])
-
-        step_size = jax.tree_map(jax.numpy.ones_like, init_state)
-        for i in range(num_tuning_epoch - 1):
-            tuning_hmc = gen_kernel(step_size)
-            init_samples, tuning_result, kernel_results = tfp.mcmc.sample_chain(
-                num_results=tune // num_tuning_epoch,
-                current_state=init_state,
-                kernel=tuning_hmc,
-                trace_fn=trace_fn,
-                return_final_kernel_results=True,
-                seed=seed,
-            )
+        return Apply(self, all_inputs, outputs)

-            step_size = jax.tree_multimap(get_tuned_stepsize, list(init_samples), tuning_result)
-            init_state = [x[-1] for x in init_samples]
-
-        # Run inference
-        sample_kernel = gen_kernel(step_size)
-        mcmc_samples, leapfrog_num = tfp.mcmc.sample_chain(
-            num_results=draws,
-            num_burnin_steps=tune // num_tuning_epoch,
-            current_state=init_state,
-            kernel=sample_kernel,
-            trace_fn=lambda _, pkr: pkr.inner_results.leapfrogs_taken,
-            seed=seed,
-        )
+    def do_constant_folding(self, *args):
+        return False

-        return mcmc_samples, leapfrog_num
+    def perform(self, node, inputs, outputs):
+        raise NotImplementedError()

-    print("Compiling...")
-    tic2 = pd.Timestamp.now()
-    map_seed = jax.random.split(seed, chains)
-    mcmc_samples, leapfrog_num = _sample(init_state_batched, map_seed)
-
-    # map_seed = jax.random.split(seed, chains)
-    # mcmc_samples = _sample(init_state_batched, map_seed)
-    # tic4 = pd.Timestamp.now()
-    # print("Sampling time = ", tic4 - tic3)
-
-    posterior = {k: v for k, v in zip(rv_names, mcmc_samples)}
-    az_trace = az.from_dict(posterior=posterior)
-    tic3 = pd.Timestamp.now()
-    print("Compilation + sampling time = ", tic3 - tic2)
-    return az_trace  # , leapfrog_num, tic3 - tic2
-
-
-def sample_numpyro_nuts(
-    draws=1000,
-    tune=1000,
-    chains=4,
-    target_accept=0.8,
-    random_seed=10,
-    model=None,
-    progress_bar=True,
-    keep_untransformed=False,
-):
+@jax_funcify.register(NumPyroNUTS)
+def jax_funcify_NumPyroNUTS(op, node, **kwargs):
     from numpyro.infer import MCMC, NUTS

-    from pymc3 import modelcontext
+    draws = op.draws
+    tune = op.tune
+    chains = op.chains
+    target_accept = op.target_accept
+    progress_bar = op.progress_bar
+    seed = op.seed
+
+    # Compile the "inner" log-likelihood function. This will have extra shared
+    # variable inputs as the last arguments
+    logp_fn = jax_funcify(op.fgraph, **kwargs)
+
+    if isinstance(logp_fn, (list, tuple)):
+        # This handles the new JAX backend, which always returns a tuple
+        logp_fn = logp_fn[0]
+
+    def _sample(*inputs):
+
+        if op.nshared > 0:
+            current_state = inputs[: -op.nshared]
+            shared_inputs = tuple(op.fgraph.inputs[-op.nshared :])
+        else:
+            current_state = inputs
+            shared_inputs = ()
+
+        def log_fn_wrap(x):
+            res = logp_fn(
+                *(
+                    x
+                    # We manually obtain the shared values and add them
+                    # as arguments to our compiled "inner" function
+                    + tuple(
+                        v.get_value(borrow=True, return_internal_type=True) for v in shared_inputs
+                    )
+                )
+            )
-    model = modelcontext(model)
+            if isinstance(res, (list, tuple)):
+                # This handles the new JAX backend, which always returns a tuple
+                res = res[0]
-    seed = jax.random.PRNGKey(random_seed)
+            return -res
-    fgraph = aesara.graph.fg.FunctionGraph(model.free_RVs, [model.logpt])
-    fns = jax_funcify(fgraph)
-    logp_fn_jax = fns[0]
-
-    rv_names = [rv.name for rv in model.free_RVs]
-    init_state = [model.initial_point[rv_name] for rv_name in rv_names]
-    init_state_batched = jax.tree_map(lambda x: np.repeat(x[None, ...], chains, axis=0), init_state)
-
-    @jax.jit
-    def _sample(current_state, seed):
-        step_size = jax.tree_map(jax.numpy.ones_like, init_state)
         nuts_kernel = NUTS(
-            potential_fn=lambda x: -logp_fn_jax(*x),
-            # model=model,
+            potential_fn=log_fn_wrap,
             target_accept_prob=target_accept,
             adapt_step_size=True,
             adapt_mass_matrix=True,
@@ -166,60 +149,87 @@ def _sample(current_state, seed):
         pmap_numpyro.run(seed, init_params=current_state, extra_fields=("num_steps",))
         samples = pmap_numpyro.get_samples(group_by_chain=True)
         leapfrogs_taken = pmap_numpyro.get_extra_fields(group_by_chain=True)["num_steps"]
-        return samples, leapfrogs_taken
-
-    print("Compiling...")
-    tic2 = pd.Timestamp.now()
-    map_seed = jax.random.split(seed, chains)
-    mcmc_samples, leapfrogs_taken = _sample(init_state_batched, map_seed)
-    # map_seed = jax.random.split(seed, chains)
-    # mcmc_samples = _sample(init_state_batched, map_seed)
-    # tic4 = pd.Timestamp.now()
-    # print("Sampling time = ", tic4 - tic3)
+        return tuple(samples) + (leapfrogs_taken,)

-    posterior = {k: v for k, v in zip(rv_names, mcmc_samples)}
-    tic3 = pd.Timestamp.now()
-    posterior = _transform_samples(posterior, model, keep_untransformed=keep_untransformed)
-    tic4 = pd.Timestamp.now()
+    return _sample

-    az_trace = az.from_dict(posterior=posterior)
-    print("Compilation + sampling time = ", tic3 - tic2)
-    print("Transformation time = ", tic4 - tic3)
-    return az_trace  # , leapfrogs_taken, tic3 - tic2

+def sample_numpyro_nuts(
+    draws=1000,
+    tune=1000,
+    chains=4,
+    target_accept=0.8,
+    random_seed=10,
+    model=None,
+    progress_bar=True,
+    keep_untransformed=False,
+):
+    model = modelcontext(model)
+    seed = jax.random.PRNGKey(random_seed)

-def _transform_samples(samples, model, keep_untransformed=False):
+    rv_names = [rv.name for rv in model.value_vars]
+    init_state = [model.initial_point[rv_name] for rv_name in rv_names]
+    init_state_batched = jax.tree_map(lambda x: np.repeat(x[None, ...], chains, axis=0), init_state)
+    init_state_batched_at = [at.as_tensor(v) for v in init_state_batched]

-    # Find out which RVs we need to compute:
-    free_rv_names = {x.name for x in model.free_RVs}
-    unobserved_names = {x.name for x in model.unobserved_RVs}
+    nuts_inputs = sorted(
+        [v for v in graph_inputs([model.logpt]) if not isinstance(v, Constant)],
+        key=lambda x: isinstance(x, SharedVariable),
+    )
+    map_seed = jax.random.split(seed, chains)
+    numpyro_samples = NumPyroNUTS(
+        nuts_inputs,
+        [model.logpt],
+        target_accept=target_accept,
+        draws=draws,
+        tune=tune,
+        chains=chains,
+        seed=map_seed,
+        progress_bar=progress_bar,
+    )(*init_state_batched_at)
+
+    # Un-transform the transformed variables in JAX
+    sample_outputs = []
+    for i, (value_var, rv_samples) in enumerate(zip(model.value_vars, numpyro_samples[:-1])):
+        rv = model.values_to_rvs[value_var]
+        transform = getattr(value_var.tag, "transform", None)
+        if transform is not None:
+            untrans_value_var = transform.backward(rv, rv_samples)
+            untrans_value_var.name = rv.name
+            sample_outputs.append(untrans_value_var)
+
+            if keep_untransformed:
+                rv_samples.name = value_var.name
+                sample_outputs.append(rv_samples)
+        else:
+            rv_samples.name = rv.name
+            sample_outputs.append(rv_samples)

-    names_to_compute = unobserved_names - free_rv_names
-    ops_to_compute = [x for x in model.unobserved_RVs if x.name in names_to_compute]
+    print("Compiling...")

-    # Create function graph for these:
-    fgraph = aesara.graph.fg.FunctionGraph(model.free_RVs, ops_to_compute)
+    tic1 = pd.Timestamp.now()
+    _sample = aesara.function(
+        [],
+        sample_outputs + [numpyro_samples[-1]],
+        allow_input_downcast=True,
+        on_unused_input="ignore",
+        accept_inplace=True,
+        mode="JAX",
+    )
+    tic2 = pd.Timestamp.now()

-    # Jaxify, which returns a list of functions, one for each op
-    jax_fns = jax_funcify(fgraph)
+    print("Compilation time = ", tic2 - tic1)

-    # Put together the inputs
-    inputs = [samples[x.name] for x in model.free_RVs]
+    print("Sampling...")

-    for cur_op, cur_jax_fn in zip(ops_to_compute, jax_fns):
+    *mcmc_samples, leapfrogs_taken = _sample()
+    tic3 = pd.Timestamp.now()

-        # We need a function taking a single argument to run vmap, while the
-        # jax_fn takes a list, so:
-        result = jax.vmap(jax.vmap(cur_jax_fn))(*inputs)
+    print("Sampling time = ", tic3 - tic2)

-        # Add to sample dict
-        samples[cur_op.name] = result
+    posterior = {k.name: v for k, v in zip(sample_outputs, mcmc_samples)}

-    # Discard unwanted transformed variables, if desired:
-    vars_to_keep = set(
-        pm.util.get_default_varnames(list(samples.keys()), include_transformed=keep_untransformed)
-    )
-    samples = {x: y for x, y in samples.items() if x in vars_to_keep}
+    az_trace = az.from_dict(posterior=posterior)

-    return samples
+    return az_trace
diff --git a/pymc3/tests/test_sampling_jax.py b/pymc3/tests/test_sampling_jax.py
index 164f3eb7ec5..b2d39d130e6 100644
--- a/pymc3/tests/test_sampling_jax.py
+++ b/pymc3/tests/test_sampling_jax.py
@@ -1,21 +1,35 @@
+import aesara
 import numpy as np
-import pytest

 import pymc3 as pm

 from pymc3.sampling_jax import sample_numpyro_nuts


-@pytest.mark.xfail(reason="HalfNormal was not yet refactored")
 def test_transform_samples():
+    aesara.config.on_opt_error = "raise"
+    np.random.seed(13244)
+
+    obs = np.random.normal(10, 2, size=100)
+    obs_at = aesara.shared(obs, borrow=True, name="obs")
     with pm.Model() as model:
-
+        a = pm.Uniform("a", -20, 20)
         sigma = pm.HalfNormal("sigma")
-        b = pm.Normal("b", sigma=sigma)
-        trace = sample_numpyro_nuts(keep_untransformed=True)
+        b = pm.Normal("b", a, sigma=sigma, observed=obs_at)
+
+        trace = sample_numpyro_nuts(chains=1, random_seed=1322, keep_untransformed=True)

     log_vals = trace.posterior["sigma_log__"].values
-    trans_vals = trace.posterior["sigma"].values
+    trans_vals = trace.posterior["sigma"].values

     assert np.allclose(np.exp(log_vals), trans_vals)
+
+    assert 8 < trace.posterior["a"].mean() < 11
+    assert 1.5 < trace.posterior["sigma"].mean() < 2.5
+
+    obs_at.set_value(-obs)
+    with model:
+        trace = sample_numpyro_nuts(chains=1, random_seed=1322, keep_untransformed=False)
+
+    assert -11 < trace.posterior["a"].mean() < -8
+    assert 1.5 < trace.posterior["sigma"].mean() < 2.5
diff --git a/requirements.txt b/requirements.txt
index a508d1ef9eb..5e9d1e9f524 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-aesara>=2.0.5
+aesara>=2.0.7
 arviz>=0.11.2
 cachetools>=4.2.1
 dill