Fix measurable stack and join with interdependent inputs #6342

Merged: 2 commits, Dec 7, 2022
76 changes: 57 additions & 19 deletions pymc/logprob/tensor.py
@@ -129,11 +129,24 @@ class MeasurableMakeVector(MakeVector):


@_logprob.register(MeasurableMakeVector)
def logprob_make_vector(op, values, *base_vars, **kwargs):
def logprob_make_vector(op, values, *base_rvs, **kwargs):
"""Compute the log-likelihood graph for a `MeasurableMakeVector`."""
# TODO: Sort out this circular dependency issue
from pymc.pytensorf import replace_rvs_by_values

(value,) = values

return at.stack([logprob(base_var, value[i]) for i, base_var in enumerate(base_vars)])
base_rvs_to_values = {base_rv: value[i] for i, base_rv in enumerate(base_rvs)}
for i, (base_rv, value) in enumerate(base_rvs_to_values.items()):
base_rv.name = f"base_rv[{i}]"
value.name = f"value[{i}]"

logps = [logprob(base_rv, value) for base_rv, value in base_rvs_to_values.items()]

# If the stacked variables depend on each other, we have to replace them by the respective values
logps = replace_rvs_by_values(logps, rvs_to_values=base_rvs_to_values)

return at.stack(logps)


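The `replace_rvs_by_values` call above is the core of the fix for `MeasurableMakeVector`: when one stacked component is defined in terms of another, its logp term initially still references the sibling random variable. A minimal sketch of the resulting user-facing behaviour, mirroring the new tests further down (variable names and evaluation points are illustrative, not part of this diff):

```python
import numpy as np
import pytensor.tensor as at

from pymc.logprob import joint_logprob

x = at.random.normal(name="x")
y0 = at.random.normal(x + 1, name="y0")
y1 = at.random.normal(y0 + 1, name="y1")  # y1 depends on its stacked sibling y0
ys = at.stack([y0, y1])
ys.name = "ys"

x_vv = x.clone()
ys_vv = ys.clone()

# The logp term for y1 initially references the random variable y0; the value
# substitution rewrites it in terms of the stack's value variable, so the
# final graph contains no RandomVariables.
logp = joint_logprob({x: x_vv, ys: ys_vv})
logp.eval({x_vv: 0.0, ys_vv: np.zeros(2)})
```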
class MeasurableJoin(Join):
@@ -144,27 +157,28 @@ class MeasurableJoin(Join):


@_logprob.register(MeasurableJoin)
def logprob_join(op, values, axis, *base_vars, **kwargs):
def logprob_join(op, values, axis, *base_rvs, **kwargs):
"""Compute the log-likelihood graph for a `Join`."""
(value,) = values
# TODO: Find better way to avoid circular dependency
from pymc.pytensorf import constant_fold, replace_rvs_by_values

base_var_shapes = [base_var.shape[axis] for base_var in base_vars]
(value,) = values

# TODO: Find better way to avoid circular dependency
from pymc.pytensorf import constant_fold
base_rv_shapes = [base_var.shape[axis] for base_var in base_rvs]

# We don't need the graph to be constant, just to have RandomVariables removed
base_var_shapes = constant_fold(base_var_shapes, raise_not_constant=False)
base_rv_shapes = constant_fold(base_rv_shapes, raise_not_constant=False)

split_values = at.split(
value,
splits_size=base_var_shapes,
n_splits=len(base_vars),
splits_size=base_rv_shapes,
n_splits=len(base_rvs),
axis=axis,
)

base_rvs_to_split_values = {base_rv: value for base_rv, value in zip(base_rvs, split_values)}
logps = [
logprob(base_var, split_value) for base_var, split_value in zip(base_vars, split_values)
logprob(base_var, split_value) for base_var, split_value in base_rvs_to_split_values.items()
]

if len({logp.ndim for logp in logps}) != 1:
@@ -173,12 +187,12 @@ def logprob_join(op, values, axis, *base_vars, **kwargs):
"joining univariate and multivariate distributions",
)

# If the stacked variables depend on each other, we have to replace them by the respective values
logps = replace_rvs_by_values(logps, rvs_to_values=base_rvs_to_split_values)

base_vars_ndim_supp = split_values[0].ndim - logps[0].ndim
join_logprob = at.concatenate(
[
at.atleast_1d(logprob(base_var, split_value))
for base_var, split_value in zip(base_vars, split_values)
],
[at.atleast_1d(logp) for logp in logps],
axis=axis - base_vars_ndim_supp,
)

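In the Join case above, the joined value is first split back into per-component slices along the join axis, using the constant-folded base RV shapes as split sizes, before the same value substitution is applied to the logp terms. A small standalone sketch of that split step, with hypothetical sizes 2 and 3:

```python
import pytensor.tensor as at

value = at.vector("value")
# A value joined from parts of length 2 and 3 along axis 0 is split back
# into those parts before each component's logprob is computed.
split_values = at.split(value, splits_size=[2, 3], n_splits=2, axis=0)
```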
@@ -199,6 +213,8 @@ def find_measurable_stacks(
if rv_map_feature is None:
return None # pragma: no cover

rvs_to_values = rv_map_feature.rv_values

stack_out = node.outputs[0]

is_join = isinstance(node.op, Join)
@@ -211,18 +227,40 @@
if not all(
base_var.owner
and isinstance(base_var.owner.op, MeasurableVariable)
and base_var not in rv_map_feature.rv_values
and base_var not in rvs_to_values
for base_var in base_vars
):
return None # pragma: no cover

# Make base_vars unmeasurable
base_vars = [assign_custom_measurable_outputs(base_var.owner) for base_var in base_vars]
base_to_unmeasurable_vars = {
base_var: assign_custom_measurable_outputs(base_var.owner).outputs[
base_var.owner.outputs.index(base_var)
]
for base_var in base_vars
}

def replacement_fn(var, replacements):
if var in base_to_unmeasurable_vars:
replacements[var] = base_to_unmeasurable_vars[var]
# We don't want to clone valued nodes. Assigning a var to itself in the
# replacements prevents this
elif var in rvs_to_values:
replacements[var] = var

return []

# TODO: Fix this import circularity!
from pymc.pytensorf import _replace_rvs_in_graphs

unmeasurable_base_vars, _ = _replace_rvs_in_graphs(
graphs=base_vars, replacement_fn=replacement_fn
)

if is_join:
measurable_stack = MeasurableJoin()(axis, *base_vars)
measurable_stack = MeasurableJoin()(axis, *unmeasurable_base_vars)
else:
measurable_stack = MeasurableMakeVector(node.op.dtype)(*base_vars)
measurable_stack = MeasurableMakeVector(node.op.dtype)(*unmeasurable_base_vars)

measurable_stack.name = stack_out.name

24 changes: 13 additions & 11 deletions pymc/pytensorf.py
@@ -333,7 +333,7 @@ def replace_rvs_by_values(
graphs: Sequence[TensorVariable],
*,
rvs_to_values: Dict[TensorVariable, TensorVariable],
rvs_to_transforms: Dict[TensorVariable, RVTransform],
rvs_to_transforms: Optional[Dict[TensorVariable, RVTransform]] = None,
**kwargs,
) -> List[TensorVariable]:
"""Clone and replace random variables in graphs with their value variables.
@@ -346,7 +346,7 @@
The graphs in which to perform the replacements.
rvs_to_values
Mapping between the original graph RVs and respective value variables
rvs_to_transforms
rvs_to_transforms, optional
Mapping between the original graph RVs and respective value transforms
"""

@@ -361,7 +361,8 @@
for rv, value in rvs_to_values.items():
equiv_rv = equiv.get(rv, rv)
equiv_rvs_to_values[equiv_rv] = equiv.get(value, value)
equiv_rvs_to_transforms[equiv_rv] = rvs_to_transforms[rv]
if rvs_to_transforms is not None:
equiv_rvs_to_transforms[equiv_rv] = rvs_to_transforms[rv]

def poulate_replacements(rv, replacements):
# Populate replacements dict with {rv: value} pairs indicating which graph
@@ -372,14 +373,15 @@ def poulate_replacements(rv, replacements):
if value is None:
return []

transform = equiv_rvs_to_transforms.get(rv, None)
if transform is not None:
# We want to replace uses of the RV by the back-transformation of its value
value = transform.backward(value, *rv.owner.inputs)
# The value may have a less precise type than the rv. In this case
# filter_variable will add a SpecifyShape to ensure they are consistent
value = rv.type.filter_variable(value, allow_convert=True)
value.name = rv.name
if rvs_to_transforms is not None:
transform = equiv_rvs_to_transforms.get(rv, None)
if transform is not None:
# We want to replace uses of the RV by the back-transformation of its value
value = transform.backward(value, *rv.owner.inputs)
# The value may have a less precise type than the rv. In this case
# filter_variable will add a SpecifyShape to ensure they are consistent
value = rv.type.filter_variable(value, allow_convert=True)
value.name = rv.name

replacements[rv] = value
# Also walk the graph of the value variable to make any additional
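Since `rvs_to_transforms` is now optional, the logprob code in `tensor.py` can call `replace_rvs_by_values` with just an RV-to-value mapping. A minimal sketch of such a call on a toy graph (`x`, `y`, and `x_vv` are illustrative names, not taken from the diff):

```python
import pytensor.tensor as at

from pymc.pytensorf import replace_rvs_by_values

x = at.random.normal(name="x")
y = x + 1.0

x_vv = x.clone()
x_vv.name = "x_vv"

# No rvs_to_transforms mapping is passed; x is simply swapped for x_vv.
(y_replaced,) = replace_rvs_by_values([y], rvs_to_values={x: x_vv})
y_replaced.eval({x_vv: 0.5})  # 1.5
```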
86 changes: 86 additions & 0 deletions pymc/tests/logprob/test_tensor.py
@@ -48,6 +48,7 @@
from pymc.logprob import factorized_joint_logprob, joint_logprob
from pymc.logprob.rewriting import logprob_rewrites_db
from pymc.logprob.tensor import naive_bcast_rv_lift
from pymc.tests.helpers import assert_no_rvs


def test_naive_bcast_rv_lift():
@@ -109,6 +110,91 @@ def test_measurable_make_vector():
assert np.isclose(make_vector_logp_eval.sum(), ref_logp_eval_eval)


@pytest.mark.parametrize("reverse", (False, True))
def test_measurable_make_vector_interdependent(reverse):
"""Test that we can obtain a proper graph when stacked RVs depend on each other"""
x = at.random.normal(name="x")
y_rvs = []
prev_rv = x
for i in range(3):
next_rv = at.random.normal(prev_rv + 1, name=f"y{i}")
y_rvs.append(next_rv)
prev_rv = next_rv

if reverse:
y_rvs = y_rvs[::-1]

ys = at.stack(y_rvs)
ys.name = "ys"

x_vv = x.clone()
ys_vv = ys.clone()

logp = joint_logprob({x: x_vv, ys: ys_vv})
assert_no_rvs(logp)

y0_vv = y_rvs[0].clone()
y1_vv = y_rvs[1].clone()
y2_vv = y_rvs[2].clone()

ref_logp = joint_logprob({x: x_vv, y_rvs[0]: y0_vv, y_rvs[1]: y1_vv, y_rvs[2]: y2_vv})

rng = np.random.default_rng()
x_vv_test = rng.normal()
ys_vv_test = rng.normal(size=3)
np.testing.assert_allclose(
logp.eval({x_vv: x_vv_test, ys_vv: ys_vv_test}),
ref_logp.eval(
{x_vv: x_vv_test, y0_vv: ys_vv_test[0], y1_vv: ys_vv_test[1], y2_vv: ys_vv_test[2]}
),
)


@pytest.mark.parametrize("reverse", (False, True))
def test_measurable_join_interdependent(reverse):
"""Test that we can obtain a proper graph when stacked RVs depend on each other"""
x = at.random.normal(name="x")
y_rvs = []
prev_rv = x
for i in range(3):
next_rv = at.random.normal(prev_rv + 1, name=f"y{i}", size=(1, 2))
y_rvs.append(next_rv)
prev_rv = next_rv

if reverse:
y_rvs = y_rvs[::-1]

ys = at.concatenate(y_rvs, axis=0)
ys.name = "ys"

x_vv = x.clone()
ys_vv = ys.clone()

logp = joint_logprob({x: x_vv, ys: ys_vv})
assert_no_rvs(logp)

y0_vv = y_rvs[0].clone()
y1_vv = y_rvs[1].clone()
y2_vv = y_rvs[2].clone()

ref_logp = joint_logprob({x: x_vv, y_rvs[0]: y0_vv, y_rvs[1]: y1_vv, y_rvs[2]: y2_vv})

rng = np.random.default_rng()
x_vv_test = rng.normal()
ys_vv_test = rng.normal(size=(3, 2))
np.testing.assert_allclose(
logp.eval({x_vv: x_vv_test, ys_vv: ys_vv_test}),
ref_logp.eval(
{
x_vv: x_vv_test,
y0_vv: ys_vv_test[0:1],
y1_vv: ys_vv_test[1:2],
y2_vv: ys_vv_test[2:3],
}
),
)


@pytest.mark.parametrize(
"size1, size2, axis, concatenate",
[