Skip to content

Commit

Permalink
Allow measurable stack and join with interdependent inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Dec 5, 2022
1 parent f46fe43 commit 5a9c755
Show file tree
Hide file tree
Showing 2 changed files with 142 additions and 19 deletions.
75 changes: 56 additions & 19 deletions pymc/logprob/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,24 @@ class MeasurableMakeVector(MakeVector):


@_logprob.register(MeasurableMakeVector)
def logprob_make_vector(op, values, *base_vars, **kwargs):
def logprob_make_vector(op, values, *base_rvs, **kwargs):
"""Compute the log-likelihood graph for a `MeasurableMakeVector`."""
# TODO: Sort out this circular dependency issue
from pymc.aesaraf import replace_rvs_by_values

(value,) = values

return at.stack([logprob(base_var, value[i]) for i, base_var in enumerate(base_vars)])
base_rvs_to_values = {base_rv: value[i] for i, base_rv in enumerate(base_rvs)}
for i, (base_rv, value) in enumerate(base_rvs_to_values.items()):
base_rv.name = f"base_rv[{i}]"
value.name = f"value[{i}]"

logps = [logprob(base_rv, value) for base_rv, value in base_rvs_to_values.items()]

# If the stacked variables depend on each other, we have to replace them by the respective values
logps = replace_rvs_by_values(logps, rvs_to_values=base_rvs_to_values)

return at.stack(logps)


class MeasurableJoin(Join):
Expand All @@ -144,27 +157,28 @@ class MeasurableJoin(Join):


@_logprob.register(MeasurableJoin)
def logprob_join(op, values, axis, *base_vars, **kwargs):
def logprob_join(op, values, axis, *base_rvs, **kwargs):
"""Compute the log-likelihood graph for a `Join`."""
(value,) = values
# TODO: Find better way to avoid circular dependency
from pymc.pytensorf import constant_fold, replace_rvs_by_values

base_var_shapes = [base_var.shape[axis] for base_var in base_vars]
(value,) = values

# TODO: Find better way to avoid circular dependency
from pymc.pytensorf import constant_fold
base_rv_shapes = [base_var.shape[axis] for base_var in base_rvs]

# We don't need the graph to be constant, just to have RandomVariables removed
base_var_shapes = constant_fold(base_var_shapes, raise_not_constant=False)
base_rv_shapes = constant_fold(base_rv_shapes, raise_not_constant=False)

split_values = at.split(
value,
splits_size=base_var_shapes,
n_splits=len(base_vars),
splits_size=base_rv_shapes,
n_splits=len(base_rvs),
axis=axis,
)

base_rvs_to_split_values = {base_rv: value for base_rv, value in zip(base_rvs, split_values)}
logps = [
logprob(base_var, split_value) for base_var, split_value in zip(base_vars, split_values)
logprob(base_var, split_value) for base_var, split_value in base_rvs_to_split_values.items()
]

if len({logp.ndim for logp in logps}) != 1:
Expand All @@ -173,12 +187,12 @@ def logprob_join(op, values, axis, *base_vars, **kwargs):
"joining univariate and multivariate distributions",
)

# If the stacked variables depend on each other, we have to replace them by the respective values
logps = replace_rvs_by_values(logps, rvs_to_values=base_rvs_to_split_values)

base_vars_ndim_supp = split_values[0].ndim - logps[0].ndim
join_logprob = at.concatenate(
[
at.atleast_1d(logprob(base_var, split_value))
for base_var, split_value in zip(base_vars, split_values)
],
[at.atleast_1d(logp) for logp in logps],
axis=axis - base_vars_ndim_supp,
)

Expand All @@ -190,6 +204,8 @@ def find_measurable_stacks(
fgraph, node
) -> Optional[List[Union[MeasurableMakeVector, MeasurableJoin]]]:
r"""Finds `Joins`\s and `MakeVector`\s for which a `logprob` can be computed."""
# TODO: Fix this import circularity!
from pymc.aesaraf import _replace_rvs_in_graphs

if isinstance(node.op, (MeasurableMakeVector, MeasurableJoin)):
return None # pragma: no cover
Expand All @@ -199,6 +215,8 @@ def find_measurable_stacks(
if rv_map_feature is None:
return None # pragma: no cover

rvs_to_values = rv_map_feature.rv_values

stack_out = node.outputs[0]

is_join = isinstance(node.op, Join)
Expand All @@ -211,18 +229,37 @@ def find_measurable_stacks(
if not all(
base_var.owner
and isinstance(base_var.owner.op, MeasurableVariable)
and base_var not in rv_map_feature.rv_values
and base_var not in rvs_to_values
for base_var in base_vars
):
return None # pragma: no cover

# Make base_vars unmeasurable
base_vars = [assign_custom_measurable_outputs(base_var.owner) for base_var in base_vars]
base_to_unmeasurable_vars = {
base_var: assign_custom_measurable_outputs(base_var.owner).outputs[
base_var.owner.outputs.index(base_var)
]
for base_var in base_vars
}

def replacement_fn(var, replacements):
if var in base_to_unmeasurable_vars:
replacements[var] = base_to_unmeasurable_vars[var]
# We don't want to clone valued nodes. Assigning a var to itself in the
# replacements prevents this
elif var in rvs_to_values:
replacements[var] = var

return []

unmeasurable_base_vars, _ = _replace_rvs_in_graphs(
graphs=base_vars, replacement_fn=replacement_fn
)

if is_join:
measurable_stack = MeasurableJoin()(axis, *base_vars)
measurable_stack = MeasurableJoin()(axis, *unmeasurable_base_vars)
else:
measurable_stack = MeasurableMakeVector(node.op.dtype)(*base_vars)
measurable_stack = MeasurableMakeVector(node.op.dtype)(*unmeasurable_base_vars)

measurable_stack.name = stack_out.name

Expand Down
86 changes: 86 additions & 0 deletions pymc/tests/logprob/test_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from pymc.logprob import factorized_joint_logprob, joint_logprob
from pymc.logprob.rewriting import logprob_rewrites_db
from pymc.logprob.tensor import naive_bcast_rv_lift
from pymc.tests.helpers import assert_no_rvs


def test_naive_bcast_rv_lift():
Expand Down Expand Up @@ -109,6 +110,91 @@ def test_measurable_make_vector():
assert np.isclose(make_vector_logp_eval.sum(), ref_logp_eval_eval)


@pytest.mark.parametrize("reverse", (False, True))
def test_measurable_make_vector_interdependent(reverse):
"""Test that we can obtain a proper graph when stacked RVs depend on each other"""
x = at.random.normal(name="x")
y_rvs = []
prev_rv = x
for i in range(3):
next_rv = at.random.normal(prev_rv + 1, name=f"y{i}")
y_rvs.append(next_rv)
prev_rv = next_rv

if reverse:
y_rvs = y_rvs[::-1]

ys = at.stack(y_rvs)
ys.name = "ys"

x_vv = x.clone()
ys_vv = ys.clone()

logp = joint_logprob({x: x_vv, ys: ys_vv})
assert_no_rvs(logp)

y0_vv = y_rvs[0].clone()
y1_vv = y_rvs[1].clone()
y2_vv = y_rvs[2].clone()

ref_logp = joint_logprob({x: x_vv, y_rvs[0]: y0_vv, y_rvs[1]: y1_vv, y_rvs[2]: y2_vv})

rng = np.random.default_rng()
x_vv_test = rng.normal()
ys_vv_test = rng.normal(size=3)
np.testing.assert_allclose(
logp.eval({x_vv: x_vv_test, ys_vv: ys_vv_test}),
ref_logp.eval(
{x_vv: x_vv_test, y0_vv: ys_vv_test[0], y1_vv: ys_vv_test[1], y2_vv: ys_vv_test[2]}
),
)


@pytest.mark.parametrize("reverse", (False, True))
def test_measurable_join_interdependent(reverse):
"""Test that we can obtain a proper graph when stacked RVs depend on each other"""
x = at.random.normal(name="x")
y_rvs = []
prev_rv = x
for i in range(3):
next_rv = at.random.normal(prev_rv + 1, name=f"y{i}", size=(1, 2))
y_rvs.append(next_rv)
prev_rv = next_rv

if reverse:
y_rvs = y_rvs[::-1]

ys = at.concatenate(y_rvs, axis=0)
ys.name = "ys"

x_vv = x.clone()
ys_vv = ys.clone()

logp = joint_logprob({x: x_vv, ys: ys_vv})
assert_no_rvs(logp)

y0_vv = y_rvs[0].clone()
y1_vv = y_rvs[1].clone()
y2_vv = y_rvs[2].clone()

ref_logp = joint_logprob({x: x_vv, y_rvs[0]: y0_vv, y_rvs[1]: y1_vv, y_rvs[2]: y2_vv})

rng = np.random.default_rng()
x_vv_test = rng.normal()
ys_vv_test = rng.normal(size=(3, 2))
np.testing.assert_allclose(
logp.eval({x_vv: x_vv_test, ys_vv: ys_vv_test}),
ref_logp.eval(
{
x_vv: x_vv_test,
y0_vv: ys_vv_test[0:1],
y1_vv: ys_vv_test[1:2],
y2_vv: ys_vv_test[2:3],
}
),
)


@pytest.mark.parametrize(
"size1, size2, axis, concatenate",
[
Expand Down

0 comments on commit 5a9c755

Please sign in to comment.