Create a NumPyro sampler Op for better JAX backend integration
brandonwillard committed May 6, 2021
1 parent 3a02dcc commit 56e1688
Showing 3 changed files with 183 additions and 159 deletions.
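The commit builds on Aesara's JAX dispatch mechanism: an Op gets a JAX implementation by registering a converter with jax_funcify, after which graphs containing that Op can be compiled with mode="JAX". Below is a minimal sketch of that mechanism only; the toy DoubleOp is hypothetical and not part of this commit.

import aesara
import aesara.tensor as at
import numpy as np

from aesara.graph.basic import Apply
from aesara.graph.op import Op
from aesara.link.jax.dispatch import jax_funcify


class DoubleOp(Op):
    # Toy Op (illustration only): doubles its input.

    def make_node(self, x):
        x = at.as_tensor_variable(x)
        return Apply(self, [x], [x.type()])

    def perform(self, node, inputs, outputs):
        # Python/NumPy fallback used outside of JAX mode
        outputs[0][0] = 2 * inputs[0]


@jax_funcify.register(DoubleOp)
def jax_funcify_DoubleOp(op, **kwargs):
    # Return a JAX-compatible callable for this Op
    def double(x):
        return 2 * x

    return double


x = at.vector("x")
# The JAX linker uses the registered converter instead of `perform`
fn = aesara.function([x], DoubleOp()(x), mode="JAX")
print(fn(np.ones(3)))  # [2. 2. 2.]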
314 changes: 162 additions & 152 deletions pymc3/sampling_jax.py
@@ -3,150 +3,133 @@
 import re
 import warnings
 
+from collections import defaultdict
+
 xla_flags = os.getenv("XLA_FLAGS", "").lstrip("--")
 xla_flags = re.sub(r"xla_force_host_platform_device_count=.+\s", "", xla_flags).split()
 os.environ["XLA_FLAGS"] = " ".join(["--xla_force_host_platform_device_count={}".format(100)])
 
 import aesara.graph.fg
+import aesara.tensor as at
 import arviz as az
 import jax
 import numpy as np
 import pandas as pd
 
-from aesara.link.jax.jax_dispatch import jax_funcify
-
 import pymc3 as pm
 
+from aesara.compile import SharedVariable
+from aesara.graph.basic import Apply, Constant, clone, graph_inputs
+from aesara.graph.fg import FunctionGraph
+from aesara.graph.op import Op
+from aesara.graph.opt import MergeOptimizer
+from aesara.link.jax.dispatch import jax_funcify
+from aesara.tensor.type import TensorType
+
 from pymc3 import modelcontext
 
 warnings.warn("This module is experimental.")
 
 # Disable C compilation by default
 # aesara.config.cxx = ""
 # This will make the JAX Linker the default
 # aesara.config.mode = "JAX"
 
-def sample_tfp_nuts(
-    draws=1000,
-    tune=1000,
-    chains=4,
-    target_accept=0.8,
-    random_seed=10,
-    model=None,
-    num_tuning_epoch=2,
-    num_compute_step_size=500,
-):
-    import jax
-
-    from tensorflow_probability.substrates import jax as tfp
-
-    model = modelcontext(model)
-
-    seed = jax.random.PRNGKey(random_seed)
-
-    fgraph = model.logp.f.maker.fgraph
-    fns = jax_funcify(fgraph)
-    logp_fn_jax = fns[0]
-
-    rv_names = [rv.name for rv in model.free_RVs]
-    init_state = [model.initial_point[rv_name] for rv_name in rv_names]
-    init_state_batched = jax.tree_map(lambda x: np.repeat(x[None, ...], chains, axis=0), init_state)
-
-    @jax.pmap
-    def _sample(init_state, seed):
-        def gen_kernel(step_size):
-            hmc = tfp.mcmc.NoUTurnSampler(target_log_prob_fn=logp_fn_jax, step_size=step_size)
-            return tfp.mcmc.DualAveragingStepSizeAdaptation(
-                hmc, tune // num_tuning_epoch, target_accept_prob=target_accept
-            )
-
-        def trace_fn(_, pkr):
-            return pkr.new_step_size
-
-        def get_tuned_stepsize(samples, step_size):
-            return step_size[-1] * jax.numpy.std(samples[-num_compute_step_size:])
-
-        step_size = jax.tree_map(jax.numpy.ones_like, init_state)
-        for i in range(num_tuning_epoch - 1):
-            tuning_hmc = gen_kernel(step_size)
-            init_samples, tuning_result, kernel_results = tfp.mcmc.sample_chain(
-                num_results=tune // num_tuning_epoch,
-                current_state=init_state,
-                kernel=tuning_hmc,
-                trace_fn=trace_fn,
-                return_final_kernel_results=True,
-                seed=seed,
-            )
-
-            step_size = jax.tree_multimap(get_tuned_stepsize, list(init_samples), tuning_result)
-            init_state = [x[-1] for x in init_samples]
-
-        # Run inference
-        sample_kernel = gen_kernel(step_size)
-        mcmc_samples, leapfrog_num = tfp.mcmc.sample_chain(
-            num_results=draws,
-            num_burnin_steps=tune // num_tuning_epoch,
-            current_state=init_state,
-            kernel=sample_kernel,
-            trace_fn=lambda _, pkr: pkr.inner_results.leapfrogs_taken,
-            seed=seed,
-        )
-
-        return mcmc_samples, leapfrog_num
-
-    print("Compiling...")
-    tic2 = pd.Timestamp.now()
-    map_seed = jax.random.split(seed, chains)
-    mcmc_samples, leapfrog_num = _sample(init_state_batched, map_seed)
-
-    # map_seed = jax.random.split(seed, chains)
-    # mcmc_samples = _sample(init_state_batched, map_seed)
-    # tic4 = pd.Timestamp.now()
-    # print("Sampling time = ", tic4 - tic3)
-
-    posterior = {k: v for k, v in zip(rv_names, mcmc_samples)}
-
-    az_trace = az.from_dict(posterior=posterior)
-    tic3 = pd.Timestamp.now()
-    print("Compilation + sampling time = ", tic3 - tic2)
-    return az_trace  # , leapfrog_num, tic3 - tic2
+class NumPyroNUTS(Op):
+    def __init__(
+        self,
+        inputs,
+        outputs,
+        target_accept=0.8,
+        draws=1000,
+        tune=1000,
+        chains=4,
+        seed=None,
+        progress_bar=True,
+    ):
+        self.draws = draws
+        self.tune = tune
+        self.chains = chains
+        self.target_accept = target_accept
+        self.progress_bar = progress_bar
+        self.seed = seed
+
+        self.inputs, self.outputs = clone(inputs, outputs, copy_inputs=False)
+        self.inputs_type = tuple([input.type for input in inputs])
+        self.outputs_type = tuple([output.type for output in outputs])
+        self.nin = len(inputs)
+        self.nout = len(outputs)
+        self.nshared = len([v for v in inputs if isinstance(v, SharedVariable)])
+        self.samples_bcast = [self.chains == 1, self.draws == 1]
+
+        self.fgraph = FunctionGraph(self.inputs, self.outputs, clone=False)
+        MergeOptimizer().optimize(self.fgraph)
+
+        super().__init__()
+
+    def make_node(self, *inputs):
+
+        # The samples for each variable
+        outputs = [
+            TensorType(v.dtype, self.samples_bcast + list(v.broadcastable))() for v in inputs
+        ]
+
+        # The leapfrog statistics
+        outputs += [TensorType("int64", self.samples_bcast)()]
+
+        all_inputs = list(inputs)
+        if self.nshared > 0:
+            all_inputs += self.inputs[-self.nshared :]
+
+        return Apply(self, all_inputs, outputs)
+
+    def do_constant_folding(self, *args):
+        return False
+
+    def perform(self, node, inputs, outputs):
+        raise NotImplementedError()
 
 
-def sample_numpyro_nuts(
-    draws=1000,
-    tune=1000,
-    chains=4,
-    target_accept=0.8,
-    random_seed=10,
-    model=None,
-    progress_bar=True,
-    keep_untransformed=False,
-):
+@jax_funcify.register(NumPyroNUTS)
+def jax_funcify_NumPyroNUTS(op, node, **kwargs):
     from numpyro.infer import MCMC, NUTS
 
-    from pymc3 import modelcontext
-
-    model = modelcontext(model)
-
-    seed = jax.random.PRNGKey(random_seed)
-
-    fgraph = aesara.graph.fg.FunctionGraph(model.free_RVs, [model.logpt])
-    fns = jax_funcify(fgraph)
-    logp_fn_jax = fns[0]
-
-    rv_names = [rv.name for rv in model.free_RVs]
-    init_state = [model.initial_point[rv_name] for rv_name in rv_names]
-    init_state_batched = jax.tree_map(lambda x: np.repeat(x[None, ...], chains, axis=0), init_state)
-
-    @jax.jit
-    def _sample(current_state, seed):
-        step_size = jax.tree_map(jax.numpy.ones_like, init_state)
+    draws = op.draws
+    tune = op.tune
+    chains = op.chains
+    target_accept = op.target_accept
+    progress_bar = op.progress_bar
+    seed = op.seed
+
+    # Compile the "inner" log-likelihood function. This will have extra shared
+    # variable inputs as the last arguments
+    logp_fn = jax_funcify(op.fgraph, **kwargs)
+
+    if isinstance(logp_fn, (list, tuple)):
+        # This handles the new JAX backend, which always returns a tuple
+        logp_fn = logp_fn[0]
+
+    def _sample(*inputs):
+
+        if op.nshared > 0:
+            current_state = inputs[: -op.nshared]
+            shared_inputs = tuple(op.fgraph.inputs[-op.nshared :])
+        else:
+            current_state = inputs
+            shared_inputs = ()
+
+        def log_fn_wrap(x):
+            res = logp_fn(
+                *(
+                    x
+                    # We manually obtain the shared values and added them
+                    # as arguments to our compiled "inner" function
+                    + tuple(
+                        v.get_value(borrow=True, return_internal_type=True) for v in shared_inputs
+                    )
+                )
+            )
+
+            if isinstance(res, (list, tuple)):
+                # This handles the new JAX backend, which always returns a tuple
+                res = res[0]
+
+            return -res
+
         nuts_kernel = NUTS(
-            potential_fn=lambda x: -logp_fn_jax(*x),
-            # model=model,
+            potential_fn=log_fn_wrap,
             target_accept_prob=target_accept,
             adapt_step_size=True,
             adapt_mass_matrix=True,
@@ -166,60 +149,87 @@ def _sample(current_state, seed):
         pmap_numpyro.run(seed, init_params=current_state, extra_fields=("num_steps",))
         samples = pmap_numpyro.get_samples(group_by_chain=True)
         leapfrogs_taken = pmap_numpyro.get_extra_fields(group_by_chain=True)["num_steps"]
-        return samples, leapfrogs_taken
-
-    print("Compiling...")
-    tic2 = pd.Timestamp.now()
-    map_seed = jax.random.split(seed, chains)
-    mcmc_samples, leapfrogs_taken = _sample(init_state_batched, map_seed)
-    # map_seed = jax.random.split(seed, chains)
-    # mcmc_samples = _sample(init_state_batched, map_seed)
-    # tic4 = pd.Timestamp.now()
-    # print("Sampling time = ", tic4 - tic3)
-
-    posterior = {k: v for k, v in zip(rv_names, mcmc_samples)}
-    tic3 = pd.Timestamp.now()
-    posterior = _transform_samples(posterior, model, keep_untransformed=keep_untransformed)
-    tic4 = pd.Timestamp.now()
-
-    az_trace = az.from_dict(posterior=posterior)
-    print("Compilation + sampling time = ", tic3 - tic2)
-    print("Transformation time = ", tic4 - tic3)
-
-    return az_trace  # , leapfrogs_taken, tic3 - tic2
-
-
-def _transform_samples(samples, model, keep_untransformed=False):
-
-    # Find out which RVs we need to compute:
-    free_rv_names = {x.name for x in model.free_RVs}
-    unobserved_names = {x.name for x in model.unobserved_RVs}
-
-    names_to_compute = unobserved_names - free_rv_names
-    ops_to_compute = [x for x in model.unobserved_RVs if x.name in names_to_compute]
-
-    # Create function graph for these:
-    fgraph = aesara.graph.fg.FunctionGraph(model.free_RVs, ops_to_compute)
-
-    # Jaxify, which returns a list of functions, one for each op
-    jax_fns = jax_funcify(fgraph)
-
-    # Put together the inputs
-    inputs = [samples[x.name] for x in model.free_RVs]
-
-    for cur_op, cur_jax_fn in zip(ops_to_compute, jax_fns):
-
-        # We need a function taking a single argument to run vmap, while the
-        # jax_fn takes a list, so:
-        result = jax.vmap(jax.vmap(cur_jax_fn))(*inputs)
-
-        # Add to sample dict
-        samples[cur_op.name] = result
-
-    # Discard unwanted transformed variables, if desired:
-    vars_to_keep = set(
-        pm.util.get_default_varnames(list(samples.keys()), include_transformed=keep_untransformed)
-    )
-    samples = {x: y for x, y in samples.items() if x in vars_to_keep}
-
-    return samples
+        return tuple(samples) + (leapfrogs_taken,)
+
+    return _sample
+
+
+def sample_numpyro_nuts(
+    draws=1000,
+    tune=1000,
+    chains=4,
+    target_accept=0.8,
+    random_seed=10,
+    model=None,
+    progress_bar=True,
+    keep_untransformed=False,
+):
+    model = modelcontext(model)
+
+    seed = jax.random.PRNGKey(random_seed)
+
+    rv_names = [rv.name for rv in model.value_vars]
+    init_state = [model.initial_point[rv_name] for rv_name in rv_names]
+    init_state_batched = jax.tree_map(lambda x: np.repeat(x[None, ...], chains, axis=0), init_state)
+    init_state_batched_at = [at.as_tensor(v) for v in init_state_batched]
+
+    nuts_inputs = sorted(
+        [v for v in graph_inputs([model.logpt]) if not isinstance(v, Constant)],
+        key=lambda x: isinstance(x, SharedVariable),
+    )
+    map_seed = jax.random.split(seed, chains)
+    numpyro_samples = NumPyroNUTS(
+        nuts_inputs,
+        [model.logpt],
+        target_accept=target_accept,
+        draws=draws,
+        tune=tune,
+        chains=chains,
+        seed=map_seed,
+        progress_bar=progress_bar,
+    )(*init_state_batched_at)
+
+    # Un-transform the transformed variables in JAX
+    sample_outputs = []
+    for i, (value_var, rv_samples) in enumerate(zip(model.value_vars, numpyro_samples[:-1])):
+        rv = model.values_to_rvs[value_var]
+        transform = getattr(value_var.tag, "transform", None)
+        if transform is not None:
+            untrans_value_var = transform.backward(rv, rv_samples)
+            untrans_value_var.name = rv.name
+            sample_outputs.append(untrans_value_var)
+
+            if keep_untransformed:
+                rv_samples.name = value_var.name
+                sample_outputs.append(rv_samples)
+        else:
+            rv_samples.name = rv.name
+            sample_outputs.append(rv_samples)
+
+    print("Compiling...")
+
+    tic1 = pd.Timestamp.now()
+    _sample = aesara.function(
+        [],
+        sample_outputs + [numpyro_samples[-1]],
+        allow_input_downcast=True,
+        on_unused_input="ignore",
+        accept_inplace=True,
+        mode="JAX",
+    )
+    tic2 = pd.Timestamp.now()
+
+    print("Compilation time = ", tic2 - tic1)
+
+    print("Sampling...")
+
+    *mcmc_samples, leapfrogs_taken = _sample()
+    tic3 = pd.Timestamp.now()
+
+    print("Sampling time = ", tic3 - tic2)
+
+    posterior = {k.name: v for k, v in zip(sample_outputs, mcmc_samples)}
+
+    az_trace = az.from_dict(posterior=posterior)
+
+    return az_trace
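For reference, the new entry point is intended to be called from inside a model context, much like pm.sample. A minimal usage sketch follows; the model and argument values are illustrative and not taken from the commit.

import numpy as np
import pymc3 as pm

from pymc3 import sampling_jax

y = np.random.randn(100)

with pm.Model():
    mu = pm.Normal("mu", 0.0, 1.0)
    sigma = pm.HalfNormal("sigma", 1.0)
    pm.Normal("obs", mu, sigma, observed=y)

    # Builds the NumPyroNUTS graph, compiles it in JAX mode, and returns
    # an ArviZ trace assembled with az.from_dict
    trace = sampling_jax.sample_numpyro_nuts(draws=500, tune=500, chains=2)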