diff --git a/pymc/logprob/mixture.py b/pymc/logprob/mixture.py index a84b89a732..267670a835 100644 --- a/pymc/logprob/mixture.py +++ b/pymc/logprob/mixture.py @@ -44,6 +44,8 @@ from pytensor.graph.op import Op, compute_test_value from pytensor.graph.rewriting.basic import node_rewriter, pre_greedy_node_rewriter from pytensor.ifelse import IfElse, ifelse +from pytensor.scalar import Switch +from pytensor.scalar import switch as scalar_switch from pytensor.tensor.basic import Join, MakeVector, switch from pytensor.tensor.random.rewriting import ( local_dimshuffle_rv_lift, @@ -55,7 +57,6 @@ AdvancedSubtensor, AdvancedSubtensor1, as_index_literal, - as_nontensor_scalar, get_canonical_form_slice, is_basic_idx, ) @@ -63,7 +64,12 @@ from pytensor.tensor.type_other import NoneConst, NoneTypeT, SliceConstant, SliceType from pytensor.tensor.var import TensorVariable -from pymc.logprob.abstract import MeasurableVariable, _logprob, _logprob_helper +from pymc.logprob.abstract import ( + MeasurableElemwise, + MeasurableVariable, + _logprob, + _logprob_helper, +) from pymc.logprob.rewriting import ( PreserveRVMappings, assume_measured_ir_outputs, @@ -325,37 +331,6 @@ def find_measurable_index_mixture(fgraph, node): return [new_mixture_rv] -@node_rewriter([switch]) -def find_measurable_switch_mixture(fgraph, node): - rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None) - - if rv_map_feature is None: - return None # pragma: no cover - - old_mixture_rv = node.default_output() - idx, *components = node.inputs - - if rv_map_feature.request_measurable(components) != components: - return None - - mix_op = MixtureRV( - 2, - old_mixture_rv.dtype, - old_mixture_rv.broadcastable, - ) - new_mixture_rv = mix_op.make_node( - *([NoneConst, as_nontensor_scalar(node.inputs[0])] + components[::-1]) - ).default_output() - - if pytensor.config.compute_test_value != "off": - if not hasattr(old_mixture_rv.tag, "test_value"): - compute_test_value(node) - - new_mixture_rv.tag.test_value = old_mixture_rv.tag.test_value - - return [new_mixture_rv] - - @_logprob.register(MixtureRV) def logprob_MixtureRV( op, values, *inputs: Optional[Union[TensorVariable, slice]], name=None, **kwargs @@ -433,6 +408,51 @@ def logprob_MixtureRV( return logp_val +class MeasurableSwitchMixture(MeasurableElemwise): + valid_scalar_types = (Switch,) + + +measurable_switch_mixture = MeasurableSwitchMixture(scalar_switch) + + +@node_rewriter([switch]) +def find_measurable_switch_mixture(fgraph, node): + rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None) + + if rv_map_feature is None: + return None # pragma: no cover + + switch_cond, *components = node.inputs + + # We don't support broadcasting of components, as that yields dependent (identical) values. + # The current logp implementation assumes all component values are independent. + # Broadcasting of the switch condition is fine + out_bcast = node.outputs[0].type.broadcastable + if any(comp.type.broadcastable != out_bcast for comp in components): + return None + + # Check that `switch_cond` is not potentially measurable + valued_rvs = rv_map_feature.rv_values.keys() + if check_potential_measurability([switch_cond], valued_rvs): + return None + + if rv_map_feature.request_measurable(components) != components: + return None + + return [measurable_switch_mixture(switch_cond, *components)] + + +@_logprob.register(MeasurableSwitchMixture) +def logprob_switch_mixture(op, values, switch_cond, component_true, component_false, **kwargs): + [value] = values + + return switch( + switch_cond, + _logprob_helper(component_true, value), + _logprob_helper(component_false, value), + ) + + measurable_ir_rewrites_db.register( "find_measurable_index_mixture", find_measurable_index_mixture, diff --git a/tests/logprob/test_mixture.py b/tests/logprob/test_mixture.py index 697eaf9a38..53e2096b55 100644 --- a/tests/logprob/test_mixture.py +++ b/tests/logprob/test_mixture.py @@ -52,8 +52,9 @@ as_index_constant, ) +from pymc.logprob.abstract import MeasurableVariable from pymc.logprob.basic import conditional_logp, logp -from pymc.logprob.mixture import MixtureRV, expand_indices +from pymc.logprob.mixture import MeasurableSwitchMixture, MixtureRV, expand_indices from pymc.logprob.rewriting import construct_ir_fgraph from pymc.logprob.utils import dirac_delta from pymc.testing import assert_no_rvs @@ -907,7 +908,7 @@ def test_mixture_with_DiracDelta(): assert m_vv in logp_res -def test_switch_mixture(): +def test_scalar_switch_mixture(): srng = pt.random.RandomStream(29833) X_rv = srng.normal(-10.0, 0.1, name="X") @@ -919,6 +920,7 @@ def test_switch_mixture(): # When I_rv == True, X_rv flows through otherwise Y_rv does Z1_rv = pt.switch(I_rv, X_rv, Y_rv) + Z1_rv.name = "Z1" assert Z1_rv.eval({I_rv: 0}) > 5 assert Z1_rv.eval({I_rv: 1}) < -5 @@ -927,40 +929,90 @@ def test_switch_mixture(): z_vv.name = "z1" fgraph, _, _ = construct_ir_fgraph({Z1_rv: z_vv, I_rv: i_vv}) + assert isinstance(fgraph.outputs[0].owner.op, MeasurableSwitchMixture) - assert isinstance(fgraph.outputs[0].owner.op, MixtureRV) - assert not hasattr( - fgraph.outputs[0].tag, "test_value" - ) # pytensor.config.compute_test_value == "off" - assert fgraph.outputs[0].name is None - - Z1_rv.name = "Z1" - - fgraph, _, _ = construct_ir_fgraph({Z1_rv: z_vv, I_rv: i_vv}) - - # building the identical graph but with a stack to check that mixture computations are identical - + # building the identical graph but with a stack to check that mixture logps are identical Z2_rv = pt.stack((Y_rv, X_rv))[I_rv] assert Z2_rv.eval({I_rv: 0}) > 5 assert Z2_rv.eval({I_rv: 1}) < -5 - fgraph2, _, _ = construct_ir_fgraph({Z2_rv: z_vv, I_rv: i_vv}) - - assert equal_computations(fgraph.outputs, fgraph2.outputs) - z1_logp = conditional_logp({Z1_rv: z_vv, I_rv: i_vv}) z2_logp = conditional_logp({Z2_rv: z_vv, I_rv: i_vv}) z1_logp_combined = pt.sum([pt.sum(factor) for factor in z1_logp.values()]) z2_logp_combined = pt.sum([pt.sum(factor) for factor in z2_logp.values()]) - - # below should follow immediately from the equal_computations assertion above - assert equal_computations([z1_logp_combined], [z2_logp_combined]) - np.testing.assert_almost_equal(0.69049938, z1_logp_combined.eval({z_vv: -10, i_vv: 1})) np.testing.assert_almost_equal(0.69049938, z2_logp_combined.eval({z_vv: -10, i_vv: 1})) +@pytest.mark.parametrize("switch_cond_scalar", (True, False)) +def test_switch_mixture_vector(switch_cond_scalar): + if switch_cond_scalar: + switch_cond = pt.scalar("switch_cond", dtype=bool) + else: + switch_cond = pt.vector("switch_cond", dtype=bool) + true_branch = pt.exp(pt.random.normal(size=(4,))) + false_branch = pt.abs(pt.random.normal(size=(4,))) + + switch = pt.switch(switch_cond, true_branch, false_branch) + switch.name = "switch_mix" + switch_value = switch.clone() + switch_logp = logp(switch, switch_value) + + if switch_cond_scalar: + test_switch_cond = np.array(0, dtype=bool) + else: + test_switch_cond = np.array([0, 1, 0, 1], dtype=bool) + test_switch_value = np.linspace(0.1, 2.5, 4) + np.testing.assert_allclose( + switch_logp.eval({switch_cond: test_switch_cond, switch_value: test_switch_value}), + np.where( + test_switch_cond, + logp(true_branch, test_switch_value).eval(), + logp(false_branch, test_switch_value).eval(), + ), + ) + + +def test_switch_mixture_measurable_cond_fails(): + """Test that logprob inference fails when the switch condition is an unvalued measurable variable. + + Otherwise, the logp function would have to marginalize over this variable. + + NOTE: This could be supported in the future, in which case this test can be removed/adapted + """ + cond_var = 1 - pt.random.bernoulli(p=0.5) + true_branch = pt.random.normal() + false_branch = pt.random.normal() + + switch = pt.switch(cond_var, true_branch, false_branch) + with pytest.raises(NotImplementedError, match="Logprob method not implemented for"): + logp(switch, switch.type()) + + +def test_switch_mixture_invalid_bcast(): + """Test that we don't mark switches where components are broadcasted as measurable""" + valid_switch_cond = pt.vector("switch_cond", dtype=bool) + invalid_switch_cond = pt.matrix("switch_cond", dtype=bool) + + valid_true_branch = pt.exp(pt.random.normal(size=(4,))) + valid_false_branch = pt.abs(pt.random.normal(size=(4,))) + invalid_false_branch = pt.abs(pt.random.normal(size=())) + + valid_mix = pt.switch(valid_switch_cond, valid_true_branch, valid_false_branch) + fgraph, _, _ = construct_ir_fgraph({valid_mix: valid_mix.type()}) + assert isinstance(fgraph.outputs[0].owner.op, MeasurableVariable) + assert isinstance(fgraph.outputs[0].owner.op, MeasurableSwitchMixture) + + invalid_mix = pt.switch(invalid_switch_cond, valid_true_branch, valid_false_branch) + fgraph, _, _ = construct_ir_fgraph({invalid_mix: invalid_mix.type()}) + assert not isinstance(fgraph.outputs[0].owner.op, MeasurableVariable) + + invalid_mix = pt.switch(valid_switch_cond, valid_true_branch, invalid_false_branch) + fgraph, _, _ = construct_ir_fgraph({invalid_mix: invalid_mix.type()}) + assert not isinstance(fgraph.outputs[0].owner.op, MeasurableVariable) + + def test_ifelse_mixture_one_component(): if_rv = pt.random.bernoulli(0.5, name="if") scale_rv = pt.random.halfnormal(name="scale")