diff --git a/pymc/aesaraf.py b/pymc/aesaraf.py
index 0767bee29d1..21d1dd407a2 100644
--- a/pymc/aesaraf.py
+++ b/pymc/aesaraf.py
@@ -14,6 +14,7 @@
 import warnings
 
 from typing import (
+    Any,
     Callable,
     Dict,
     Generator,
@@ -33,6 +34,7 @@
 from aeppl.abstract import MeasurableVariable
 from aeppl.logprob import CheckParameterValue
 from aesara import config, scalar
+from aesara.compile.function.types import Function
 from aesara.compile.mode import Mode, get_mode
 from aesara.gradient import grad
 from aesara.graph import local_optimizer
@@ -41,6 +43,7 @@
     Constant,
     Variable,
     clone_get_equiv,
+    general_toposort,
     graph_inputs,
     vars_between,
     walk,
@@ -82,6 +85,7 @@
     "at_rng",
     "take_along_axis",
     "convert_observed_data",
+    "compile_forward_sampling_function",
 ]
@@ -1015,3 +1019,144 @@ def compile_pymc(
         **kwargs,
     )
     return aesara_function
+
+
+def compile_forward_sampling_function(
+    outputs: List[Variable],
+    vars_in_trace: List[Variable],
+    givens_dict: Optional[Dict[Variable, Any]] = None,
+    **kwargs,
+) -> Function:
+    """Compile a function to draw samples, conditioned on the values of some variables.
+
+    The goal of this function is to walk the aesara computational graph from the list
+    of output nodes down to the root nodes, and then compile a function that will produce
+    values for these output nodes. The compiled function will take as inputs the subset of
+    variables in the ``vars_in_trace`` that are deemed to not be **volatile**.
+
+    Volatile variables are variables whose values could change between runs of the
+    compiled function. These variables are: variables in the outputs list, SharedVariable
+    instances, RandomVariable instances that have volatile parameters, RandomVariables that
+    are not in the ``vars_in_trace`` list, and any other type of variable that depends
+    on volatile variables.
+
+    Concretely, this function can be used to compile a function to sample from the
+    posterior predictive distribution of a model that has variables that are conditioned
+    on ``MutableData`` instances. The variables that depend on the mutable data will be
+    considered volatile, and as such, they won't be included as inputs into the compiled
+    function. This means that if they have values stored in the posterior, these values
+    will be ignored and new values will be computed (in the case of deterministics and
+    potentials) or sampled (in the case of random variables).
+
+    This function also enables a way to impute values for any variable in the computational
+    graph that produces the desired outputs: the ``givens_dict``. This dictionary can be used
+    to set the ``givens`` argument of the aesara function compilation. This will essentially
+    replace a node in the computational graph with any other expression that has the same
+    type as the desired node. Passing variables in the ``givens_dict`` is considered an
+    intervention that might lead to different variable values from those that could have
+    been seen during inference; as such, **any variable that is passed in the ``givens_dict``
+    will be considered volatile**.
+
+    Parameters
+    ----------
+    outputs : List[aesara.graph.basic.Variable]
+        The list of variables that will be returned by the compiled function
+    vars_in_trace : List[aesara.graph.basic.Variable]
+        The list of variables that are assumed to have values stored in the trace
+    givens_dict : Optional[Dict[aesara.graph.basic.Variable, Any]]
+        A dictionary that maps tensor variables to the values that should be used to replace them
+        in the compiled function. The types of the key and value should match or an error
+        will be raised during compilation.
+    """
+    if givens_dict is None:
+        givens_dict = {}
+
+    # We need a function graph to walk the clients and propagate the volatile property
+    fg = aesara.graph.FunctionGraph(outputs=outputs, clone=False)
+
+    # Walk the graph from inputs to outputs and tag the volatile variables
+    nodes = general_toposort(fg.outputs, deps=lambda x: x.get_parents())
+    givens = []
+    registered_givens = set()
+    volatile_nodes = set()
+    for node in nodes:
+        if (
+            isinstance(node, aesara.compile.SharedVariable)
+            or node in fg.outputs
+            or node in givens_dict
+        ):
+            volatile_nodes.add(node)
+        is_volatile = node in volatile_nodes
+
+        # Populate the givens list if the node is in the givens_dict
+        if node in givens_dict and node not in registered_givens:
+            value = givens_dict[node]
+            if not isinstance(value, (Apply, Variable)):
+                # If the value is not a tensor, it is wrapped as a Constant with the node's name
+                value = at.constant(
+                    givens_dict[node],
+                    dtype=node.dtype,
+                    name=node.name,
+                )
+            givens.append((node, value))
+            registered_givens.add(node)
+
+        # Propagate the volatile property to the node's clients
+        if is_volatile:
+            try:
+                clients = [c for c, _ in fg.clients[node]]
+            except KeyError:
+                # Some Apply nodes don't show up in the clients of a FunctionGraph,
+                # so we take the Apply node's outputs instead
+                if isinstance(node, aesara.graph.basic.Apply):
+                    clients = node.outputs
+                else:
+                    raise
+            for client in clients:
+                if (
+                    isinstance(client, aesara.graph.basic.Variable)
+                    and client.owner is not None
+                    and isinstance(client.owner.op, at.random.basic.RandomVariable)
+                    and any(rv_param in volatile_nodes for rv_param in client.owner.inputs[3:])
+                ):
+                    # Random variables that have volatile parameters are also volatile
+                    volatile_nodes.add(client)
+                elif not (
+                    client in vars_in_trace
+                    and isinstance(client, aesara.graph.basic.Variable)
+                    and client.owner is not None
+                    and isinstance(client.owner.op, at.random.basic.RandomVariable)
+                ):
+                    # Other variables that aren't random variables in the trace are volatile
+                    volatile_nodes.add(client)
+
+    # Collect the function inputs by walking the graph from the outputs. Inputs will be:
+    # 1. Random variables that are not volatile
+    # 2. Variables that have no owner and are not constant or shared
+    inputs = []
+
+    def expand(node):
+        if isinstance(node, Apply):
+            # Apply nodes are never inputs
+            return node.get_parents()
+        assert isinstance(node, Variable)
+        if (
+            (
+                node.owner is None
+                and not isinstance(node, (at.basic.Constant, aesara.compile.SharedVariable))
+            )  # Variables without owners that are not constant or shared
+            or (
+                node.owner is not None and isinstance(node.owner.op, at.random.basic.RandomVariable)
+            )  # Random variables
+            or (node in vars_in_trace)  # Variables in the trace
+        ):
+            if node not in volatile_nodes:
+                # This test will include variables without owners that are not constant
+                # or shared, because those nodes will never be considered volatile
+                inputs.append(node)
+        return node.get_parents()
+
+    # walk produces a generator, so we have to actually exhaust the generator in a list to walk
+    # the entire graph
+    list(walk(fg.outputs, expand))
+
+    return compile_pymc(inputs, fg.outputs, givens=givens, on_unused_input="ignore", **kwargs)
diff --git a/pymc/sampling.py b/pymc/sampling.py
index 452e2303229..4ba2c22943e 100644
--- a/pymc/sampling.py
+++ b/pymc/sampling.py
@@ -44,7 +44,6 @@
 import xarray
 
 from aesara.graph.basic import Constant, Variable
-from aesara.tensor import TensorVariable
 from aesara.tensor.sharedvar import SharedVariable
 from arviz import InferenceData
 from fastprogress.fastprogress import progress_bar
@@ -52,7 +51,12 @@
 
 import pymc as pm
 
-from pymc.aesaraf import change_rv_size, compile_pymc, inputvars, walk_model
+from pymc.aesaraf import (
+    change_rv_size,
+    compile_forward_sampling_function,
+    compile_pymc,
+    inputvars,
+)
 from pymc.backends.arviz import _DefaultTrace
 from pymc.backends.base import BaseTrace, MultiTrace
 from pymc.backends.ndarray import NDArray
@@ -75,6 +79,7 @@
     get_default_varnames,
     get_untransformed_name,
     is_transformed_name,
+    point_wrapper,
 )
 from pymc.vartypes import discrete_types
@@ -1534,6 +1539,16 @@ def stop_tuning(step):
     return step
 
 
+def get_vars_in_point_list(trace, model):
+    """Get the list of Variable instances in the model that have values stored in the trace."""
+    if not isinstance(trace, MultiTrace):
+        names_in_trace = list(trace[0])
+    else:
+        names_in_trace = trace.varnames
+    vars_in_trace = [model[v] for v in names_in_trace]
+    return vars_in_trace
+
+
 def sample_posterior_predictive(
     trace,
     samples: Optional[int] = None,
@@ -1718,38 +1733,22 @@ def sample_posterior_predictive(
             return trace
         return {}
 
-    inputs: Sequence[TensorVariable]
-    input_names: Sequence[str]
-    if not isinstance(_trace, MultiTrace):
-        names_in_trace = list(_trace[0])
-    else:
-        names_in_trace = _trace.varnames
-    inputs_and_names = [
-        (rv, rv.name)
-        for rv in walk_model(vars_to_sample, walk_past_rvs=True)
-        if rv not in vars_to_sample
-        and rv in model.named_vars.values()
-        and not isinstance(rv, (Constant, SharedVariable))
-        and rv.name in names_in_trace
-    ]
-    if inputs_and_names:
-        inputs, input_names = zip(*inputs_and_names)
-    else:
-        inputs, input_names = [], []
-
     if size is not None:
         vars_to_sample = [change_rv_size(v, size, expand=True) for v in vars_to_sample]
 
+    vars_in_trace = get_vars_in_point_list(_trace, model)
+
     if compile_kwargs is None:
         compile_kwargs = {}
-
-    sampler_fn = compile_pymc(
-        inputs,
-        vars_to_sample,
-        allow_input_downcast=True,
-        accept_inplace=True,
-        on_unused_input="ignore",
-        **compile_kwargs,
+    compile_kwargs.setdefault("allow_input_downcast", True)
+    compile_kwargs.setdefault("accept_inplace", True)
+
+    sampler_fn = point_wrapper(
+        compile_forward_sampling_function(
+            outputs=vars_to_sample,
+            vars_in_trace=vars_in_trace,
+            givens_dict=None,
+            **compile_kwargs,
+        )
     )
 
     ppc_trace_t = _DefaultTrace(samples)
@@ -1775,7 +1774,7 @@ def sample_posterior_predictive(
         else:
             param = _trace[idx % len_trace]
 
-        values = sampler_fn(*(param[n] for n in input_names))
+        values = sampler_fn(**param)
 
         for k, v in zip(vars_, values):
             ppc_trace_t.insert(k.name, v, idx)
diff --git a/pymc/tests/test_aesaraf.py b/pymc/tests/test_aesaraf.py
index 6b032ff2c6f..88dc7e96fbd 100644
--- a/pymc/tests/test_aesaraf.py
+++ b/pymc/tests/test_aesaraf.py
@@ -24,8 +24,9 @@
 from aeppl.abstract import MeasurableVariable
 from aeppl.logprob import ParameterValueError
+from aesara.compile import SharedVariable
 from aesara.compile.builders import OpFromGraph
-from aesara.graph.basic import Constant, Variable, ancestors, equal_computations
+from aesara.graph.basic import Constant, Variable, ancestors, equal_computations, walk
 from aesara.tensor.random.basic import normal, uniform
 from aesara.tensor.random.op import RandomVariable
 from aesara.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1
@@ -37,6 +38,7 @@
 from pymc.aesaraf import (
     _conversion_map,
     change_rv_size,
+    compile_forward_sampling_function,
     compile_pymc,
     convert_observed_data,
     extract_obs_data,
@@ -704,3 +706,126 @@ def update(self, node):
     fn = compile_pymc(inputs=[], outputs=dummy_x)
     assert fn() == 2.0
     assert fn() == 3.0
+
+
+class TestCompileForwardSampler:
+    @staticmethod
+    def get_function_roots(function):
+        roots = []
+
+        def expand(node):
+            if (
+                isinstance(node, Variable)
+                and node.owner is None
+                and getattr(node, "name", None) is not None
+            ):
+                roots.append(node)
+            return node.get_parents()
+
+        list(walk(function.maker.fgraph.outputs, expand))
+        return roots
+
+    @staticmethod
+    def get_function_inputs(function):
+        return {
+            i
+            for i, istore, *_ in function.indices
+            if istore is None and not isinstance(i.variable, SharedVariable)
+        }
+
+    def test_linear_model(self):
+        rng_state = np.random.RandomState(42)
+        rng = np.random.default_rng(42)
+
+        with pm.Model(rng_seeder=rng_state):
+            x = pm.MutableData("x", np.linspace(0, 1, 10))
+            y = pm.MutableData("y", rng.normal(2, 0.3, size=10))
+
+            alpha = pm.Normal("alpha", 0, 0.1)
+            beta = pm.Normal("beta", 0, 0.1)
+            mu = pm.Deterministic("mu", alpha + beta * x)
+            sigma = pm.HalfNormal("sigma", 0.1)
+            obs = pm.Normal("obs", mu, sigma, observed=y)
+
+        f = compile_forward_sampling_function([obs], vars_in_trace=[alpha, beta, sigma, mu])
+        assert {i.name for i in self.get_function_inputs(f)} == {"alpha", "beta", "sigma"}
+        assert {i.name for i in self.get_function_roots(f)} == {"x", "alpha", "beta", "sigma"}
+
+        with pm.Model(rng_seeder=rng_state):
+            x = pm.ConstantData("x", np.linspace(0, 1, 10))
+            y = pm.MutableData("y", rng.normal(2, 0.3, size=10))
+
+            alpha = pm.Normal("alpha", 0, 0.1)
+            beta = pm.Normal("beta", 0, 0.1)
+            mu = pm.Deterministic("mu", alpha + beta * x)
+            sigma = pm.HalfNormal("sigma", 0.1)
+            obs = pm.Normal("obs", mu, sigma, observed=y)
+
+        f = compile_forward_sampling_function([obs], vars_in_trace=[alpha, beta, sigma, mu])
+        assert {i.name for i in self.get_function_inputs(f)} == {"alpha", "beta", "sigma", "mu"}
+        assert {i.name for i in self.get_function_roots(f)} == {"mu", "sigma"}
+
+    def test_nested_observed_model(self):
+        rng_state = np.random.RandomState(42)
+        rng = np.random.default_rng(42)
+
+        with pm.Model(rng_seeder=rng_state) as model:
+            p = pm.ConstantData("p", np.array([0.25, 0.5, 0.25]))
pm.MutableData("x", rng.choice(len(p.value), p=p.value, size=10)) + y = pm.MutableData("y", rng.normal(2, 0.3, size=10)) + + category = pm.Categorical("category", p, observed=x) + beta = pm.Normal("beta", 0, 0.1, size=p.shape) + mu = pm.Deterministic("mu", beta[category]) + sigma = pm.HalfNormal("sigma", 0.1) + pm.Normal("obs", mu, sigma, observed=y) + + f = compile_forward_sampling_function( + outputs=model.observed_RVs, + vars_in_trace=[beta, mu, sigma], + ) + assert {i.name for i in self.get_function_inputs(f)} == {"beta", "sigma"} + assert {i.name for i in self.get_function_roots(f)} == {"x", "p", "beta", "sigma"} + + f = compile_forward_sampling_function( + outputs=model.observed_RVs, + vars_in_trace=[beta, mu, sigma], + givens_dict={category: np.zeros(10, dtype=category.dtype)}, + ) + assert {i.name for i in self.get_function_inputs(f)} == {"beta", "sigma"} + assert {i.name for i in self.get_function_roots(f)} == { + "x", + "p", + "category", + "beta", + "sigma", + } + + def test_volatile_parameters(self): + rng_state = np.random.RandomState(42) + rng = np.random.default_rng(42) + + with pm.Model(rng_seeder=rng_state) as model: + y = pm.MutableData("y", rng.normal(2, 0.3, size=10)) + mu = pm.Normal("mu", 0, 1) + nested_mu = pm.Normal("nested_mu", mu, 1, size=10) + sigma = pm.HalfNormal("sigma", 1) + pm.Normal("obs", nested_mu, sigma, observed=y) + + f = compile_forward_sampling_function( + outputs=model.observed_RVs, + vars_in_trace=[nested_mu, sigma], # mu isn't in the trace and will be deemed volatile + ) + assert {i.name for i in self.get_function_inputs(f)} == {"sigma"} + assert {i.name for i in self.get_function_roots(f)} == {"sigma"} + + aesara.dprint(model.observed_RVs[0]) + f = compile_forward_sampling_function( + outputs=model.observed_RVs, + vars_in_trace=[mu, nested_mu, sigma], + givens_dict={ + mu: np.array(1.0) + }, # mu will be considered volatile because it's in givens + ) + assert {i.name for i in self.get_function_inputs(f)} == {"sigma"} + assert {i.name for i in self.get_function_roots(f)} == {"mu", "sigma"} diff --git a/pymc/util.py b/pymc/util.py index 8ef7d886d32..86b56e9c5cd 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -21,6 +21,7 @@ import numpy as np import xarray +from aesara.compile import SharedVariable from cachetools import LRUCache, cachedmethod @@ -349,3 +350,25 @@ def check_dist_not_registered(dist, model=None): f"You should use an unregistered (unnamed) distribution created via " f"the `.dist()` API instead, such as:\n`dist=pm.Normal.dist(0, 1)`" ) + + +def point_wrapper(core_function): + """Wrap an aesara compiled function to be able to ingest point dictionaries whilst + ignoring the keys that are not valid inputs to the core function. + """ + ins = [ + i.name + for i, istore, *_ in core_function.indices + if istore is None and not isinstance(i.variable, SharedVariable) + ] + + def wrapped(*args, **kwargs): + input_point = {ins[i]: v for i, v in enumerate(args)} + kwargs = {k: v for k, v in kwargs.items() if k in ins} + for k in kwargs: + if k in input_point: + raise TypeError(f"Function got multiple values for argument {k}") + input_point.update(kwargs) + return core_function(**input_point) + + return wrapped