From 4b197166782cf413632ef8040987b7ccec14dbab Mon Sep 17 00:00:00 2001
From: Ricardo
Date: Fri, 28 Jan 2022 13:01:11 +0100
Subject: [PATCH] Refactor repeated compilation logic in `sampling_jax` and avoid Aesara Supervisor warning

---
 pymc/sampling_jax.py            | 62 ++++++++++++++++++---------------
 pymc/tests/test_sampling_jax.py | 12 +++++++
 2 files changed, 46 insertions(+), 28 deletions(-)

diff --git a/pymc/sampling_jax.py b/pymc/sampling_jax.py
index 1f58529c5c..5c935fa990 100644
--- a/pymc/sampling_jax.py
+++ b/pymc/sampling_jax.py
@@ -4,9 +4,8 @@
 import sys
 import warnings
 
-from typing import Callable, List
+from typing import Callable, List, Optional
 
-from aesara.graph import optimize_graph
 from aesara.tensor import TensorVariable
 
 xla_flags = os.getenv("XLA_FLAGS", "")
@@ -20,7 +19,7 @@
 import pandas as pd
 
 from aeppl.logprob import CheckParameterValue
-from aesara.compile import SharedVariable
+from aesara.compile import SharedVariable, Supervisor, mode
 from aesara.graph.basic import clone_replace, graph_inputs
 from aesara.graph.fg import FunctionGraph
 from aesara.link.jax.dispatch import jax_funcify
@@ -69,29 +68,40 @@ def replace_shared_variables(graph: List[TensorVariable]) -> List[TensorVariable
     return new_graph
 
 
-def get_jaxified_logp(model: Model) -> Callable:
-    """Compile model.logpt into an optimized jax function"""
-
-    logpt = replace_shared_variables([model.logpt()])[0]
-
-    logpt_fgraph = FunctionGraph(outputs=[logpt], clone=True)
-    optimize_graph(logpt_fgraph, include=["fast_run"], exclude=["cxx_only", "BlasOpt"])
+def get_jaxified_graph(
+    inputs: Optional[List[TensorVariable]] = None,
+    outputs: Optional[List[TensorVariable]] = None,
+) -> Callable:
+    """Compile an Aesara graph into an optimized JAX function"""
+
+    graph = replace_shared_variables(outputs)
+
+    fgraph = FunctionGraph(inputs=inputs, outputs=graph, clone=True)
+    # We need to add a Supervisor to the fgraph to be able to run the
+    # JAX sequential optimizer without warnings. We made sure there
+    # are no mutable input variables, so we only need to check for
+    # "destroyers". This should be automatically handled by Aesara
+    # once https://github.com/aesara-devs/aesara/issues/637 is fixed.
+    fgraph.attach_feature(
+        Supervisor(
+            input
+            for input in fgraph.inputs
+            if not (hasattr(fgraph, "destroyers") and fgraph.has_destroyers([input]))
+        )
+    )
+    mode.JAX.optimizer.optimize(fgraph)
 
     # We now jaxify the optimized fgraph
-    logp_fn = jax_funcify(logpt_fgraph)
+    return jax_funcify(fgraph)
 
-    if isinstance(logp_fn, (list, tuple)):
-        # This handles the new JAX backend, which always returns a tuple
-        logp_fn = logp_fn[0]
 
-    def logp_fn_wrap(x):
-        res = logp_fn(*x)
+def get_jaxified_logp(model: Model) -> Callable:
 
-        if isinstance(res, (list, tuple)):
-            # This handles the new JAX backend, which always returns a tuple
-            res = res[0]
+    logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model.logpt()])
 
-        # Jax expects a potential with the opposite sign of model.logpt
+    def logp_fn_wrap(x):
+        # NumPyro expects a scalar potential with the opposite sign of model.logpt
+        res = logp_fn(*x)[0]
         return -res
 
     return logp_fn_wrap
@@ -119,13 +129,11 @@ def _sample_stats_to_xarray(posterior):
 
 
 def _get_log_likelihood(model, samples):
-    "Compute log-likelihood for all observations"
+    """Compute log-likelihood for all observations"""
     data = {}
     for v in model.observed_RVs:
-        logp_v = replace_shared_variables([model.logpt(v, sum=False)[0]])
-        fgraph = FunctionGraph(model.value_vars, logp_v, clone=True)
-        optimize_graph(fgraph, include=["fast_run"], exclude=["cxx_only", "BlasOpt"])
-        jax_fn = jax_funcify(fgraph)
+        v_elemwise_logpt = model.logpt(v, sum=False)
+        jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=v_elemwise_logpt)
         result = jax.jit(jax.vmap(jax.vmap(jax_fn)))(*samples)[0]
         data[v.name] = result
     return data
@@ -229,9 +237,7 @@ def sample_numpyro_nuts(
     print("Transforming variables...", file=sys.stdout)
     mcmc_samples = {}
     for v in vars_to_sample:
-        fgraph = FunctionGraph(model.value_vars, [v], clone=True)
-        optimize_graph(fgraph, include=["fast_run"], exclude=["cxx_only", "BlasOpt"])
-        jax_fn = jax_funcify(fgraph)
+        jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[v])
         result = jax.vmap(jax.vmap(jax_fn))(*raw_mcmc_samples)[0]
         mcmc_samples[v.name] = result
 
diff --git a/pymc/tests/test_sampling_jax.py b/pymc/tests/test_sampling_jax.py
index 172eceb4d0..3ddc8eea2b 100644
--- a/pymc/tests/test_sampling_jax.py
+++ b/pymc/tests/test_sampling_jax.py
@@ -10,6 +10,7 @@
 
 from pymc.sampling_jax import (
     _get_log_likelihood,
+    get_jaxified_graph,
     get_jaxified_logp,
     replace_shared_variables,
     sample_numpyro_nuts,
@@ -62,6 +63,17 @@ def test_deterministic_samples():
     assert np.allclose(trace.posterior["b"].values, trace.posterior["a"].values / 2)
 
 
+def test_get_jaxified_graph():
+    # Check that jaxifying a graph does not emit the Supervisor warning. This test can
+    # be removed once https://github.com/aesara-devs/aesara/issues/637 is sorted.
+    x = at.scalar("x")
+    y = at.exp(x)
+    with pytest.warns(None) as record:
+        fn = get_jaxified_graph(inputs=[x], outputs=[y])
+    assert not record
+    assert np.isclose(fn(0), 1)
+
+
 def test_get_log_likelihood():
     obs = np.random.normal(10, 2, size=100)
     obs_at = aesara.shared(obs, borrow=True, name="obs")
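
Usage sketch (illustration only, not part of the patch): with the refactored `get_jaxified_graph` helper above, an arbitrary Aesara graph can be compiled to a JAX callable the same way the sampling code does it. The toy graph below is a hypothetical example; as in `get_jaxified_logp`, the compiled function returns its outputs as a sequence, so the single output is read with `[0]`.

    import aesara.tensor as at
    import numpy as np

    from pymc.sampling_jax import get_jaxified_graph

    # Hypothetical toy graph: y = exp(x) + 1
    x = at.scalar("x")
    y = at.exp(x) + 1.0

    # Compile the Aesara graph into an optimized JAX callable
    jax_fn = get_jaxified_graph(inputs=[x], outputs=[y])

    # Outputs come back as a sequence, so take the first (and only) one
    assert np.isclose(jax_fn(0.0)[0], 2.0)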