Skip to content

Commit

Permalink
Allow non-scalar measurable switch mixtures
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Jun 30, 2023
1 parent 8b79a03 commit df9f1b8
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 55 deletions.
86 changes: 53 additions & 33 deletions pymc/logprob/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@
from pytensor.graph.op import Op, compute_test_value
from pytensor.graph.rewriting.basic import node_rewriter, pre_greedy_node_rewriter
from pytensor.ifelse import IfElse, ifelse
from pytensor.scalar import Switch
from pytensor.scalar import switch as scalar_switch
from pytensor.tensor.basic import Join, MakeVector, switch
from pytensor.tensor.random.rewriting import (
local_dimshuffle_rv_lift,
Expand All @@ -55,15 +57,19 @@
AdvancedSubtensor,
AdvancedSubtensor1,
as_index_literal,
as_nontensor_scalar,
get_canonical_form_slice,
is_basic_idx,
)
from pytensor.tensor.type import TensorType
from pytensor.tensor.type_other import NoneConst, NoneTypeT, SliceConstant, SliceType
from pytensor.tensor.var import TensorVariable

from pymc.logprob.abstract import MeasurableVariable, _logprob, _logprob_helper
from pymc.logprob.abstract import (
MeasurableElemwise,
MeasurableVariable,
_logprob,
_logprob_helper,
)
from pymc.logprob.rewriting import (
PreserveRVMappings,
assume_measured_ir_outputs,
Expand Down Expand Up @@ -325,37 +331,6 @@ def find_measurable_index_mixture(fgraph, node):
return [new_mixture_rv]


@node_rewriter([switch])
def find_measurable_switch_mixture(fgraph, node):
rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)

if rv_map_feature is None:
return None # pragma: no cover

old_mixture_rv = node.default_output()
idx, *components = node.inputs

if rv_map_feature.request_measurable(components) != components:
return None

mix_op = MixtureRV(
2,
old_mixture_rv.dtype,
old_mixture_rv.broadcastable,
)
new_mixture_rv = mix_op.make_node(
*([NoneConst, as_nontensor_scalar(node.inputs[0])] + components[::-1])
).default_output()

if pytensor.config.compute_test_value != "off":
if not hasattr(old_mixture_rv.tag, "test_value"):
compute_test_value(node)

new_mixture_rv.tag.test_value = old_mixture_rv.tag.test_value

return [new_mixture_rv]


@_logprob.register(MixtureRV)
def logprob_MixtureRV(
op, values, *inputs: Optional[Union[TensorVariable, slice]], name=None, **kwargs
Expand Down Expand Up @@ -433,6 +408,51 @@ def logprob_MixtureRV(
return logp_val


class MeasurableSwitchMixture(MeasurableElemwise):
valid_scalar_types = (Switch,)


measurable_switch_mixture = MeasurableSwitchMixture(scalar_switch)


@node_rewriter([switch])
def find_measurable_switch_mixture(fgraph, node):
rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)

if rv_map_feature is None:
return None # pragma: no cover

switch_cond, *components = node.inputs

# We don't support broadcasting of components, as that yields dependent (identical) values.
# The current logp implementation assumes all component values are independent.
# Broadcasting of the switch condition is fine
out_bcast = node.outputs[0].type.broadcastable
if any(comp.type.broadcastable != out_bcast for comp in components):
return None

# Check that `switch_cond` is not potentially measurable
valued_rvs = rv_map_feature.rv_values.keys()
if check_potential_measurability([switch_cond], valued_rvs):
return None

if rv_map_feature.request_measurable(components) != components:
return None

return [measurable_switch_mixture(switch_cond, *components)]


@_logprob.register(MeasurableSwitchMixture)
def logprob_switch_mixture(op, values, switch_cond, component_true, component_false, **kwargs):
[value] = values

return switch(
switch_cond,
_logprob_helper(component_true, value),
_logprob_helper(component_false, value),
)


measurable_ir_rewrites_db.register(
"find_measurable_index_mixture",
find_measurable_index_mixture,
Expand Down
96 changes: 74 additions & 22 deletions tests/logprob/test_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,9 @@
as_index_constant,
)

from pymc.logprob.abstract import MeasurableVariable
from pymc.logprob.basic import conditional_logp, logp
from pymc.logprob.mixture import MixtureRV, expand_indices
from pymc.logprob.mixture import MeasurableSwitchMixture, MixtureRV, expand_indices
from pymc.logprob.rewriting import construct_ir_fgraph
from pymc.logprob.utils import dirac_delta
from pymc.testing import assert_no_rvs
Expand Down Expand Up @@ -907,7 +908,7 @@ def test_mixture_with_DiracDelta():
assert m_vv in logp_res


def test_switch_mixture():
def test_scalar_switch_mixture():
srng = pt.random.RandomStream(29833)

X_rv = srng.normal(-10.0, 0.1, name="X")
Expand All @@ -919,6 +920,7 @@ def test_switch_mixture():

# When I_rv == True, X_rv flows through otherwise Y_rv does
Z1_rv = pt.switch(I_rv, X_rv, Y_rv)
Z1_rv.name = "Z1"

assert Z1_rv.eval({I_rv: 0}) > 5
assert Z1_rv.eval({I_rv: 1}) < -5
Expand All @@ -927,40 +929,90 @@ def test_switch_mixture():
z_vv.name = "z1"

fgraph, _, _ = construct_ir_fgraph({Z1_rv: z_vv, I_rv: i_vv})
assert isinstance(fgraph.outputs[0].owner.op, MeasurableSwitchMixture)

assert isinstance(fgraph.outputs[0].owner.op, MixtureRV)
assert not hasattr(
fgraph.outputs[0].tag, "test_value"
) # pytensor.config.compute_test_value == "off"
assert fgraph.outputs[0].name is None

Z1_rv.name = "Z1"

fgraph, _, _ = construct_ir_fgraph({Z1_rv: z_vv, I_rv: i_vv})

# building the identical graph but with a stack to check that mixture computations are identical

# building the identical graph but with a stack to check that mixture logps are identical
Z2_rv = pt.stack((Y_rv, X_rv))[I_rv]

assert Z2_rv.eval({I_rv: 0}) > 5
assert Z2_rv.eval({I_rv: 1}) < -5

fgraph2, _, _ = construct_ir_fgraph({Z2_rv: z_vv, I_rv: i_vv})

assert equal_computations(fgraph.outputs, fgraph2.outputs)

z1_logp = conditional_logp({Z1_rv: z_vv, I_rv: i_vv})
z2_logp = conditional_logp({Z2_rv: z_vv, I_rv: i_vv})
z1_logp_combined = pt.sum([pt.sum(factor) for factor in z1_logp.values()])
z2_logp_combined = pt.sum([pt.sum(factor) for factor in z2_logp.values()])

# below should follow immediately from the equal_computations assertion above
assert equal_computations([z1_logp_combined], [z2_logp_combined])

np.testing.assert_almost_equal(0.69049938, z1_logp_combined.eval({z_vv: -10, i_vv: 1}))
np.testing.assert_almost_equal(0.69049938, z2_logp_combined.eval({z_vv: -10, i_vv: 1}))


@pytest.mark.parametrize("switch_cond_scalar", (True, False))
def test_switch_mixture_vector(switch_cond_scalar):
if switch_cond_scalar:
switch_cond = pt.scalar("switch_cond", dtype=bool)
else:
switch_cond = pt.vector("switch_cond", dtype=bool)
true_branch = pt.exp(pt.random.normal(size=(4,)))
false_branch = pt.abs(pt.random.normal(size=(4,)))

switch = pt.switch(switch_cond, true_branch, false_branch)
switch.name = "switch_mix"
switch_value = switch.clone()
switch_logp = logp(switch, switch_value)

if switch_cond_scalar:
test_switch_cond = np.array(0, dtype=bool)
else:
test_switch_cond = np.array([0, 1, 0, 1], dtype=bool)
test_switch_value = np.linspace(0.1, 2.5, 4)
np.testing.assert_allclose(
switch_logp.eval({switch_cond: test_switch_cond, switch_value: test_switch_value}),
np.where(
test_switch_cond,
logp(true_branch, test_switch_value).eval(),
logp(false_branch, test_switch_value).eval(),
),
)


def test_switch_mixture_measurable_cond_fails():
"""Test that logprob inference fails when the switch condition is an unvalued measurable variable.
Otherwise, the logp function would have to marginalize over this variable.
NOTE: This could be supported in the future, in which case this test can be removed/adapted
"""
cond_var = 1 - pt.random.bernoulli(p=0.5)
true_branch = pt.random.normal()
false_branch = pt.random.normal()

switch = pt.switch(cond_var, true_branch, false_branch)
with pytest.raises(NotImplementedError, match="Logprob method not implemented for"):
logp(switch, switch.type())


def test_switch_mixture_invalid_bcast():
"""Test that we don't mark switches where components are broadcasted as measurable"""
valid_switch_cond = pt.vector("switch_cond", dtype=bool)
invalid_switch_cond = pt.matrix("switch_cond", dtype=bool)

valid_true_branch = pt.exp(pt.random.normal(size=(4,)))
valid_false_branch = pt.abs(pt.random.normal(size=(4,)))
invalid_false_branch = pt.abs(pt.random.normal(size=()))

valid_mix = pt.switch(valid_switch_cond, valid_true_branch, valid_false_branch)
fgraph, _, _ = construct_ir_fgraph({valid_mix: valid_mix.type()})
assert isinstance(fgraph.outputs[0].owner.op, MeasurableVariable)
assert isinstance(fgraph.outputs[0].owner.op, MeasurableSwitchMixture)

invalid_mix = pt.switch(invalid_switch_cond, valid_true_branch, valid_false_branch)
fgraph, _, _ = construct_ir_fgraph({invalid_mix: invalid_mix.type()})
assert not isinstance(fgraph.outputs[0].owner.op, MeasurableVariable)

invalid_mix = pt.switch(valid_switch_cond, valid_true_branch, invalid_false_branch)
fgraph, _, _ = construct_ir_fgraph({invalid_mix: invalid_mix.type()})
assert not isinstance(fgraph.outputs[0].owner.op, MeasurableVariable)


def test_ifelse_mixture_one_component():
if_rv = pt.random.bernoulli(0.5, name="if")
scale_rv = pt.random.halfnormal(name="scale")
Expand Down

0 comments on commit df9f1b8

Please sign in to comment.