From 5a9c7558b4443fb179d62ce7f18933dd8f917058 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 25 Nov 2022 20:32:20 +0100 Subject: [PATCH] Allow measurable stack and join with interdependent inputs --- pymc/logprob/tensor.py | 75 ++++++++++++++++++++------- pymc/tests/logprob/test_tensor.py | 86 +++++++++++++++++++++++++++++++ 2 files changed, 142 insertions(+), 19 deletions(-) diff --git a/pymc/logprob/tensor.py b/pymc/logprob/tensor.py index b62a6d2b40d..20f9e58b7c2 100644 --- a/pymc/logprob/tensor.py +++ b/pymc/logprob/tensor.py @@ -129,11 +129,24 @@ class MeasurableMakeVector(MakeVector): @_logprob.register(MeasurableMakeVector) -def logprob_make_vector(op, values, *base_vars, **kwargs): +def logprob_make_vector(op, values, *base_rvs, **kwargs): """Compute the log-likelihood graph for a `MeasurableMakeVector`.""" + # TODO: Sort out this circular dependency issue + from pymc.aesaraf import replace_rvs_by_values + (value,) = values - return at.stack([logprob(base_var, value[i]) for i, base_var in enumerate(base_vars)]) + base_rvs_to_values = {base_rv: value[i] for i, base_rv in enumerate(base_rvs)} + for i, (base_rv, value) in enumerate(base_rvs_to_values.items()): + base_rv.name = f"base_rv[{i}]" + value.name = f"value[{i}]" + + logps = [logprob(base_rv, value) for base_rv, value in base_rvs_to_values.items()] + + # If the stacked variables depend on each other, we have to replace them by the respective values + logps = replace_rvs_by_values(logps, rvs_to_values=base_rvs_to_values) + + return at.stack(logps) class MeasurableJoin(Join): @@ -144,27 +157,28 @@ class MeasurableJoin(Join): @_logprob.register(MeasurableJoin) -def logprob_join(op, values, axis, *base_vars, **kwargs): +def logprob_join(op, values, axis, *base_rvs, **kwargs): """Compute the log-likelihood graph for a `Join`.""" - (value,) = values + # TODO: Find better way to avoid circular dependency + from pymc.pytensorf import constant_fold, replace_rvs_by_values - base_var_shapes = [base_var.shape[axis] for base_var in base_vars] + (value,) = values - # TODO: Find better way to avoid circular dependency - from pymc.pytensorf import constant_fold + base_rv_shapes = [base_var.shape[axis] for base_var in base_rvs] # We don't need the graph to be constant, just to have RandomVariables removed - base_var_shapes = constant_fold(base_var_shapes, raise_not_constant=False) + base_rv_shapes = constant_fold(base_rv_shapes, raise_not_constant=False) split_values = at.split( value, - splits_size=base_var_shapes, - n_splits=len(base_vars), + splits_size=base_rv_shapes, + n_splits=len(base_rvs), axis=axis, ) + base_rvs_to_split_values = {base_rv: value for base_rv, value in zip(base_rvs, split_values)} logps = [ - logprob(base_var, split_value) for base_var, split_value in zip(base_vars, split_values) + logprob(base_var, split_value) for base_var, split_value in base_rvs_to_split_values.items() ] if len({logp.ndim for logp in logps}) != 1: @@ -173,12 +187,12 @@ def logprob_join(op, values, axis, *base_vars, **kwargs): "joining univariate and multivariate distributions", ) + # If the stacked variables depend on each other, we have to replace them by the respective values + logps = replace_rvs_by_values(logps, rvs_to_values=base_rvs_to_split_values) + base_vars_ndim_supp = split_values[0].ndim - logps[0].ndim join_logprob = at.concatenate( - [ - at.atleast_1d(logprob(base_var, split_value)) - for base_var, split_value in zip(base_vars, split_values) - ], + [at.atleast_1d(logp) for logp in logps], axis=axis - base_vars_ndim_supp, ) @@ -190,6 +204,8 @@ def find_measurable_stacks( fgraph, node ) -> Optional[List[Union[MeasurableMakeVector, MeasurableJoin]]]: r"""Finds `Joins`\s and `MakeVector`\s for which a `logprob` can be computed.""" + # TODO: Fix this import circularity! + from pymc.aesaraf import _replace_rvs_in_graphs if isinstance(node.op, (MeasurableMakeVector, MeasurableJoin)): return None # pragma: no cover @@ -199,6 +215,8 @@ def find_measurable_stacks( if rv_map_feature is None: return None # pragma: no cover + rvs_to_values = rv_map_feature.rv_values + stack_out = node.outputs[0] is_join = isinstance(node.op, Join) @@ -211,18 +229,37 @@ def find_measurable_stacks( if not all( base_var.owner and isinstance(base_var.owner.op, MeasurableVariable) - and base_var not in rv_map_feature.rv_values + and base_var not in rvs_to_values for base_var in base_vars ): return None # pragma: no cover # Make base_vars unmeasurable - base_vars = [assign_custom_measurable_outputs(base_var.owner) for base_var in base_vars] + base_to_unmeasurable_vars = { + base_var: assign_custom_measurable_outputs(base_var.owner).outputs[ + base_var.owner.outputs.index(base_var) + ] + for base_var in base_vars + } + + def replacement_fn(var, replacements): + if var in base_to_unmeasurable_vars: + replacements[var] = base_to_unmeasurable_vars[var] + # We don't want to clone valued nodes. Assigning a var to itself in the + # replacements prevents this + elif var in rvs_to_values: + replacements[var] = var + + return [] + + unmeasurable_base_vars, _ = _replace_rvs_in_graphs( + graphs=base_vars, replacement_fn=replacement_fn + ) if is_join: - measurable_stack = MeasurableJoin()(axis, *base_vars) + measurable_stack = MeasurableJoin()(axis, *unmeasurable_base_vars) else: - measurable_stack = MeasurableMakeVector(node.op.dtype)(*base_vars) + measurable_stack = MeasurableMakeVector(node.op.dtype)(*unmeasurable_base_vars) measurable_stack.name = stack_out.name diff --git a/pymc/tests/logprob/test_tensor.py b/pymc/tests/logprob/test_tensor.py index fb6812ef42a..198f06cabde 100644 --- a/pymc/tests/logprob/test_tensor.py +++ b/pymc/tests/logprob/test_tensor.py @@ -48,6 +48,7 @@ from pymc.logprob import factorized_joint_logprob, joint_logprob from pymc.logprob.rewriting import logprob_rewrites_db from pymc.logprob.tensor import naive_bcast_rv_lift +from pymc.tests.helpers import assert_no_rvs def test_naive_bcast_rv_lift(): @@ -109,6 +110,91 @@ def test_measurable_make_vector(): assert np.isclose(make_vector_logp_eval.sum(), ref_logp_eval_eval) +@pytest.mark.parametrize("reverse", (False, True)) +def test_measurable_make_vector_interdependent(reverse): + """Test that we can obtain a proper graph when stacked RVs depend on each other""" + x = at.random.normal(name="x") + y_rvs = [] + prev_rv = x + for i in range(3): + next_rv = at.random.normal(prev_rv + 1, name=f"y{i}") + y_rvs.append(next_rv) + prev_rv = next_rv + + if reverse: + y_rvs = y_rvs[::-1] + + ys = at.stack(y_rvs) + ys.name = "ys" + + x_vv = x.clone() + ys_vv = ys.clone() + + logp = joint_logprob({x: x_vv, ys: ys_vv}) + assert_no_rvs(logp) + + y0_vv = y_rvs[0].clone() + y1_vv = y_rvs[1].clone() + y2_vv = y_rvs[2].clone() + + ref_logp = joint_logprob({x: x_vv, y_rvs[0]: y0_vv, y_rvs[1]: y1_vv, y_rvs[2]: y2_vv}) + + rng = np.random.default_rng() + x_vv_test = rng.normal() + ys_vv_test = rng.normal(size=3) + np.testing.assert_allclose( + logp.eval({x_vv: x_vv_test, ys_vv: ys_vv_test}), + ref_logp.eval( + {x_vv: x_vv_test, y0_vv: ys_vv_test[0], y1_vv: ys_vv_test[1], y2_vv: ys_vv_test[2]} + ), + ) + + +@pytest.mark.parametrize("reverse", (False, True)) +def test_measurable_join_interdependent(reverse): + """Test that we can obtain a proper graph when stacked RVs depend on each other""" + x = at.random.normal(name="x") + y_rvs = [] + prev_rv = x + for i in range(3): + next_rv = at.random.normal(prev_rv + 1, name=f"y{i}", size=(1, 2)) + y_rvs.append(next_rv) + prev_rv = next_rv + + if reverse: + y_rvs = y_rvs[::-1] + + ys = at.concatenate(y_rvs, axis=0) + ys.name = "ys" + + x_vv = x.clone() + ys_vv = ys.clone() + + logp = joint_logprob({x: x_vv, ys: ys_vv}) + assert_no_rvs(logp) + + y0_vv = y_rvs[0].clone() + y1_vv = y_rvs[1].clone() + y2_vv = y_rvs[2].clone() + + ref_logp = joint_logprob({x: x_vv, y_rvs[0]: y0_vv, y_rvs[1]: y1_vv, y_rvs[2]: y2_vv}) + + rng = np.random.default_rng() + x_vv_test = rng.normal() + ys_vv_test = rng.normal(size=(3, 2)) + np.testing.assert_allclose( + logp.eval({x_vv: x_vv_test, ys_vv: ys_vv_test}), + ref_logp.eval( + { + x_vv: x_vv_test, + y0_vv: ys_vv_test[0:1], + y1_vv: ys_vv_test[1:2], + y2_vv: ys_vv_test[2:3], + } + ), + ) + + @pytest.mark.parametrize( "size1, size2, axis, concatenate", [