Skip to content

Commit

Permalink
Merge pull request #5182 from zaxtax/add_numpyro_deterministic
Browse files Browse the repository at this point in the history
Adding Deterministic for sampling_numpyro
  • Loading branch information
twiecki authored Nov 16, 2021
2 parents 4b7aaad + 5ead708 commit 44cf8a7
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 30 deletions.
50 changes: 20 additions & 30 deletions pymc/sampling_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
from aesara.link.jax.dispatch import jax_funcify

from pymc import Model, modelcontext
from pymc.aesaraf import compile_rv_inplace
from pymc.aesaraf import compile_rv_inplace, inputvars
from pymc.util import get_default_varnames

warnings.warn("This module is experimental.")

Expand Down Expand Up @@ -101,13 +102,19 @@ def sample_numpyro_nuts(
target_accept=0.8,
random_seed=10,
model=None,
var_names=None,
progress_bar=True,
keep_untransformed=False,
):
from numpyro.infer import MCMC, NUTS

model = modelcontext(model)

if var_names is None:
var_names = model.unobserved_value_vars

vars_to_sample = list(get_default_varnames(var_names, include_transformed=keep_untransformed))

tic1 = pd.Timestamp.now()
print("Compiling...", file=sys.stdout)

Expand Down Expand Up @@ -143,45 +150,28 @@ def sample_numpyro_nuts(
seed = jax.random.PRNGKey(random_seed)
map_seed = jax.random.split(seed, chains)

pmap_numpyro.run(map_seed, init_params=init_state_batched, extra_fields=("num_steps",))
if chains == 1:
pmap_numpyro.run(seed, init_params=init_state, extra_fields=("num_steps",))
else:
pmap_numpyro.run(map_seed, init_params=init_state_batched, extra_fields=("num_steps",))

raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True)

tic3 = pd.Timestamp.now()
print("Sampling time = ", tic3 - tic2, file=sys.stdout)

print("Transforming variables...", file=sys.stdout)
mcmc_samples = []
for i, (value_var, raw_samples) in enumerate(zip(model.value_vars, raw_mcmc_samples)):
raw_samples = at.constant(np.asarray(raw_samples))

rv = model.values_to_rvs[value_var]
transform = getattr(value_var.tag, "transform", None)

if transform is not None:
# TODO: This will fail when the transformation depends on another variable
# such as in interval transform with RVs as edges
trans_samples = transform.backward(raw_samples, *rv.owner.inputs)
trans_samples.name = rv.name
mcmc_samples.append(trans_samples)

if keep_untransformed:
raw_samples.name = value_var.name
mcmc_samples.append(raw_samples)
else:
raw_samples.name = rv.name
mcmc_samples.append(raw_samples)

mcmc_varnames = [var.name for var in mcmc_samples]
mcmc_samples = compile_rv_inplace(
[],
mcmc_samples,
mode="JAX",
)()
mcmc_samples = {}
for v in vars_to_sample:
fgraph = FunctionGraph(model.value_vars, [v], clone=False)
jax_fn = jax_funcify(fgraph)
result = jax.vmap(jax.vmap(jax_fn))(*raw_mcmc_samples)[0]
mcmc_samples[v.name] = result

tic4 = pd.Timestamp.now()
print("Transformation time = ", tic4 - tic3, file=sys.stdout)

posterior = {k: v for k, v in zip(mcmc_varnames, mcmc_samples)}
posterior = mcmc_samples
az_trace = az.from_dict(posterior=posterior)

return az_trace
17 changes: 17 additions & 0 deletions pymc/tests/test_sampling_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,23 @@ def test_transform_samples():
assert 1.5 < trace.posterior["sigma"].mean() < 2.5


def test_deterministic_samples():
aesara.config.on_opt_error = "raise"
np.random.seed(13244)

obs = np.random.normal(10, 2, size=100)
obs_at = aesara.shared(obs, borrow=True, name="obs")
with pm.Model() as model:
a = pm.Uniform("a", -20, 20)
b = pm.Deterministic("b", a / 2.0)
c = pm.Normal("c", a, sigma=1.0, observed=obs_at)

trace = sample_numpyro_nuts(chains=2, random_seed=1322, keep_untransformed=True)

assert 8 < trace.posterior["a"].mean() < 11
assert np.allclose(trace.posterior["b"].values, trace.posterior["a"].values / 2)


def test_replace_shared_variables():
x = aesara.shared(5, name="shared_x")

Expand Down

0 comments on commit 44cf8a7

Please sign in to comment.