Commit 4b19716

Refactor repeated compilation logic in sampling_jax and avoid Aesara Supervisor warning

ricardoV94 authored and twiecki committed Jan 30, 2022
1 parent 0dca647 commit 4b19716
Showing 2 changed files with 46 additions and 28 deletions.
62 changes: 34 additions & 28 deletions pymc/sampling_jax.py
@@ -4,9 +4,8 @@
 import sys
 import warnings

-from typing import Callable, List
+from typing import Callable, List, Optional

-from aesara.graph import optimize_graph
 from aesara.tensor import TensorVariable

 xla_flags = os.getenv("XLA_FLAGS", "")
@@ -20,7 +19,7 @@
 import pandas as pd

 from aeppl.logprob import CheckParameterValue
-from aesara.compile import SharedVariable
+from aesara.compile import SharedVariable, Supervisor, mode
 from aesara.graph.basic import clone_replace, graph_inputs
 from aesara.graph.fg import FunctionGraph
 from aesara.link.jax.dispatch import jax_funcify
@@ -69,29 +68,40 @@ def replace_shared_variables(graph: List[TensorVariable]) -> List[TensorVariable
     return new_graph


-def get_jaxified_logp(model: Model) -> Callable:
-    """Compile model.logpt into an optimized jax function"""
-
-    logpt = replace_shared_variables([model.logpt()])[0]
-
-    logpt_fgraph = FunctionGraph(outputs=[logpt], clone=True)
-    optimize_graph(logpt_fgraph, include=["fast_run"], exclude=["cxx_only", "BlasOpt"])
+def get_jaxified_graph(
+    inputs: Optional[List[TensorVariable]] = None,
+    outputs: Optional[List[TensorVariable]] = None,
+) -> List[TensorVariable]:
+    """Compile an Aesara graph into an optimized JAX function"""
+
+    graph = replace_shared_variables(outputs)
+
+    fgraph = FunctionGraph(inputs=inputs, outputs=graph, clone=True)
+    # We need to add a Supervisor to the fgraph to be able to run the
+    # JAX sequential optimizer without warnings. We made sure there
+    # are no mutable input variables, so we only need to check for
+    # "destroyers". This should be automatically handled by Aesara
+    # once https://github.com/aesara-devs/aesara/issues/637 is fixed.
+    fgraph.attach_feature(
+        Supervisor(
+            input
+            for input in fgraph.inputs
+            if not (hasattr(fgraph, "destroyers") and fgraph.has_destroyers([input]))
+        )
+    )
+    mode.JAX.optimizer.optimize(fgraph)

     # We now jaxify the optimized fgraph
-    logp_fn = jax_funcify(logpt_fgraph)
+    return jax_funcify(fgraph)

-    if isinstance(logp_fn, (list, tuple)):
-        # This handles the new JAX backend, which always returns a tuple
-        logp_fn = logp_fn[0]

-    def logp_fn_wrap(x):
-        res = logp_fn(*x)
+def get_jaxified_logp(model: Model) -> Callable:

-        if isinstance(res, (list, tuple)):
-            # This handles the new JAX backend, which always returns a tuple
-            res = res[0]
+    logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model.logpt()])

-        # Jax expects a potential with the opposite sign of model.logpt
+    def logp_fn_wrap(x):
+        # NumPyro expects a scalar potential with the opposite sign of model.logpt
+        res = logp_fn(*x)[0]
         return -res

     return logp_fn_wrap
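
For reference, a minimal usage sketch of the refactored get_jaxified_logp helper. The model, values, and printed number are illustrative only, assuming the PyMC v4 development API at the time of this commit; they are not part of the change:

import numpy as np
import pymc as pm

from pymc.sampling_jax import get_jaxified_logp

with pm.Model() as model:
    x = pm.Normal("x", 0.0, 1.0)

# The wrapper takes one raw value per entry in model.value_vars and returns the
# negative joint log-probability, i.e. the potential that NumPyro's NUTS expects.
logp_fn = get_jaxified_logp(model)
print(logp_fn([np.array(0.0)]))  # 0.5 * log(2 * pi) ~= 0.9189 for a standard normal at 0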
@@ -119,13 +129,11 @@ def _sample_stats_to_xarray(posterior):


 def _get_log_likelihood(model, samples):
-    "Compute log-likelihood for all observations"
+    """Compute log-likelihood for all observations"""
     data = {}
     for v in model.observed_RVs:
-        logp_v = replace_shared_variables([model.logpt(v, sum=False)[0]])
-        fgraph = FunctionGraph(model.value_vars, logp_v, clone=True)
-        optimize_graph(fgraph, include=["fast_run"], exclude=["cxx_only", "BlasOpt"])
-        jax_fn = jax_funcify(fgraph)
+        v_elemwise_logpt = model.logpt(v, sum=False)
+        jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=v_elemwise_logpt)
         result = jax.jit(jax.vmap(jax.vmap(jax_fn)))(*samples)[0]
         data[v.name] = result
     return data
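
The nested jax.vmap in _get_log_likelihood maps the jaxified elementwise log-likelihood over the chain and draw dimensions of the posterior samples. A rough sketch with a stand-in function and illustrative shapes (not part of the commit):

import jax
import jax.numpy as jnp

def elemwise_logp(x):
    # Stand-in for the jaxified graph, which returns a tuple of outputs.
    return (-0.5 * x**2,)

samples = jnp.zeros((4, 1000, 100))  # 4 chains, 1000 draws, 100 observations
result = jax.jit(jax.vmap(jax.vmap(elemwise_logp)))(samples)[0]
print(result.shape)  # (4, 1000, 100): per-observation log-likelihood for every draw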
@@ -229,9 +237,7 @@ def sample_numpyro_nuts(
     print("Transforming variables...", file=sys.stdout)
     mcmc_samples = {}
     for v in vars_to_sample:
-        fgraph = FunctionGraph(model.value_vars, [v], clone=True)
-        optimize_graph(fgraph, include=["fast_run"], exclude=["cxx_only", "BlasOpt"])
-        jax_fn = jax_funcify(fgraph)
+        jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[v])
         result = jax.vmap(jax.vmap(jax_fn))(*raw_mcmc_samples)[0]
         mcmc_samples[v.name] = result

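The transformation loop above is exercised end to end by sample_numpyro_nuts, which jaxifies each deterministic or untransformed variable as a function of model.value_vars and vmaps it over the raw draws. A minimal sketch, with an illustrative model and sampler arguments:

import pymc as pm

from pymc.sampling_jax import sample_numpyro_nuts

with pm.Model():
    a = pm.Normal("a", 0.0, 1.0)
    b = pm.Deterministic("b", a / 2)
    idata = sample_numpyro_nuts(draws=500, chains=2)

# Deterministics such as "b" are recovered by the loop above.
print(idata.posterior["b"].shape)  # (2, 500)
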
12 changes: 12 additions & 0 deletions pymc/tests/test_sampling_jax.py
@@ -10,6 +10,7 @@

 from pymc.sampling_jax import (
     _get_log_likelihood,
+    get_jaxified_graph,
     get_jaxified_logp,
     replace_shared_variables,
     sample_numpyro_nuts,
@@ -62,6 +63,17 @@ def test_deterministic_samples():
     assert np.allclose(trace.posterior["b"].values, trace.posterior["a"].values / 2)


+def test_get_jaxified_graph():
+    # Check that jaxifying a graph does not emit the Supervisor warning. This test can
+    # be removed once https://github.com/aesara-devs/aesara/issues/637 is sorted.
+    x = at.scalar("x")
+    y = at.exp(x)
+    with pytest.warns(None) as record:
+        fn = get_jaxified_graph(inputs=[x], outputs=[y])
+    assert not record
+    assert np.isclose(fn(0), 1)
+
+
 def test_get_log_likelihood():
     obs = np.random.normal(10, 2, size=100)
     obs_at = aesara.shared(obs, borrow=True, name="obs")
