Skip to content

Commit

Permalink
Adding Deterministic for sampling_numpyro
Browse files Browse the repository at this point in the history
  • Loading branch information
zaxtax committed Nov 14, 2021
1 parent 140dab0 commit 9dc7feb
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 29 deletions.
61 changes: 32 additions & 29 deletions pymc/sampling_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

from pymc import Model, modelcontext
from pymc.aesaraf import compile_rv_inplace
from pymc.util import get_default_varnames

warnings.warn("This module is experimental.")

Expand Down Expand Up @@ -101,13 +102,19 @@ def sample_numpyro_nuts(
target_accept=0.8,
random_seed=10,
model=None,
var_names=None,
progress_bar=True,
keep_untransformed=False,
):
from numpyro.infer import MCMC, NUTS

model = modelcontext(model)

if var_names is None:
var_names = model.unobserved_value_vars

vars_to_sample = list(get_default_varnames(var_names, include_transformed=keep_untransformed))

tic1 = pd.Timestamp.now()
print("Compiling...", file=sys.stdout)

Expand All @@ -116,6 +123,7 @@ def sample_numpyro_nuts(
init_state_batched = jax.tree_map(lambda x: np.repeat(x[None, ...], chains, axis=0), init_state)

logp_fn = get_jaxified_logp(model)
fn = model.fastfn(vars_to_sample)

nuts_kernel = NUTS(
potential_fn=logp_fn,
Expand Down Expand Up @@ -143,45 +151,40 @@ def sample_numpyro_nuts(
seed = jax.random.PRNGKey(random_seed)
map_seed = jax.random.split(seed, chains)

pmap_numpyro.run(map_seed, init_params=init_state_batched, extra_fields=("num_steps",))
if chains == 1:
pmap_numpyro.run(seed, init_params=init_state, extra_fields=("num_steps",))
else:
pmap_numpyro.run(map_seed, init_params=init_state_batched, extra_fields=("num_steps",))

raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True)

tic3 = pd.Timestamp.now()
print("Sampling time = ", tic3 - tic2, file=sys.stdout)

print("Transforming variables...", file=sys.stdout)
mcmc_samples = []
for i, (value_var, raw_samples) in enumerate(zip(model.value_vars, raw_mcmc_samples)):
raw_samples = at.constant(np.asarray(raw_samples))

rv = model.values_to_rvs[value_var]
transform = getattr(value_var.tag, "transform", None)

if transform is not None:
# TODO: This will fail when the transformation depends on another variable
# such as in interval transform with RVs as edges
trans_samples = transform.backward(raw_samples, *rv.owner.inputs)
trans_samples.name = rv.name
mcmc_samples.append(trans_samples)

if keep_untransformed:
raw_samples.name = value_var.name
mcmc_samples.append(raw_samples)
else:
raw_samples.name = rv.name
mcmc_samples.append(raw_samples)

mcmc_varnames = [var.name for var in mcmc_samples]
mcmc_samples = compile_rv_inplace(
[],
mcmc_samples,
mode="JAX",
)()
mcmc_samples = {}
for v in vars_to_sample:
mcmc_samples[v.name] = []

for i in range(draws):
for c in range(chains):
draw = dict(
(value_var.name, raw_samples[c, i])
for value_var, raw_samples in zip(model.value_vars, raw_mcmc_samples)
)
sample = fn(draw)
for vi, v in enumerate(vars_to_sample):
mcmc_samples[v.name].append(sample[vi])

for v in vars_to_sample:
mcmc_samples[v.name] = np.array(mcmc_samples[v.name]).reshape(
(chains, draws) + mcmc_samples[v.name][-1].shape
)

tic4 = pd.Timestamp.now()
print("Transformation time = ", tic4 - tic3, file=sys.stdout)

posterior = {k: v for k, v in zip(mcmc_varnames, mcmc_samples)}
posterior = mcmc_samples
az_trace = az.from_dict(posterior=posterior)

return az_trace
17 changes: 17 additions & 0 deletions pymc/tests/test_sampling_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,23 @@ def test_transform_samples():
assert 1.5 < trace.posterior["sigma"].mean() < 2.5


def test_deterministic_samples():
aesara.config.on_opt_error = "raise"
np.random.seed(13244)

obs = np.random.normal(10, 2, size=100)
obs_at = aesara.shared(obs, borrow=True, name="obs")
with pm.Model() as model:
a = pm.Normal("a", 0, 1)
b = pm.Deterministic("b", a / 2.0)
c = pm.Normal("c", a, sigma=1.0, observed=obs_at)

trace = sample_numpyro_nuts(chains=2, random_seed=1322, keep_untransformed=True)

assert 8 < trace.posterior["a"].mean() < 11
assert 4 < trace.posterior["b"].mean() < 6


def test_replace_shared_variables():
x = aesara.shared(5, name="shared_x")

Expand Down

0 comments on commit 9dc7feb

Please sign in to comment.