From 7540d9efbee19902e96f37bd65b985b8cd70f7de Mon Sep 17 00:00:00 2001 From: shreyas3156 Date: Sun, 9 Apr 2023 11:39:37 +0530 Subject: [PATCH 01/10] Implement logprob for binary ops --- pymc/logprob/__init__.py | 1 + pymc/logprob/binary.py | 109 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 110 insertions(+) create mode 100644 pymc/logprob/binary.py diff --git a/pymc/logprob/__init__.py b/pymc/logprob/__init__.py index f6ae51408c2..7992efac8ba 100644 --- a/pymc/logprob/__init__.py +++ b/pymc/logprob/__init__.py @@ -38,6 +38,7 @@ # isort: off # Add rewrites to the DBs +import pymc.logprob.binary import pymc.logprob.censoring import pymc.logprob.cumsum import pymc.logprob.checks diff --git a/pymc/logprob/binary.py b/pymc/logprob/binary.py new file mode 100644 index 00000000000..0422f20a823 --- /dev/null +++ b/pymc/logprob/binary.py @@ -0,0 +1,109 @@ +# Copyright 2023 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List, Optional + +import numpy as np +import pytensor.tensor as pt + +from pytensor.graph.basic import Node +from pytensor.graph.fg import FunctionGraph +from pytensor.graph.rewriting.basic import node_rewriter +from pytensor.scalar.basic import GT, LT +from pytensor.tensor.math import gt, lt + +from pymc.logprob.abstract import ( + MeasurableElemwise, + MeasurableVariable, + _logcdf, + _logprob, +) +from pymc.logprob.rewriting import measurable_ir_rewrites_db +from pymc.logprob.utils import ignore_logprob + + +class MeasurableComparison(MeasurableElemwise): + """A placeholder used to specify a log-likelihood for a binary comparison RV sub-graph.""" + + valid_scalar_types = (GT, LT) + + +@node_rewriter(tracks=[gt, lt]) +def find_measurable_comparisons( + fgraph: FunctionGraph, node: Node +) -> Optional[List[MeasurableComparison]]: + rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None) + if rv_map_feature is None: + return None # pragma: no cover + + if isinstance(node.op, MeasurableComparison): + return None # pragma: no cover + + (compared_var,) = node.outputs + base_var, const = node.inputs + + if not ( + base_var.owner + and isinstance(base_var.owner.op, MeasurableVariable) + and base_var not in rv_map_feature.rv_values + ): + return None + + # Make base_var unmeasurable + unmeasurable_base_var = ignore_logprob(base_var) + + compared_op = MeasurableComparison(node.op.scalar_op) + compared_rv = compared_op.make_node(unmeasurable_base_var, const).default_output() + compared_rv.name = compared_var.name + return [compared_rv] + + +measurable_ir_rewrites_db.register( + "find_measurable_comparisons", + find_measurable_comparisons, + "basic", + "comparison", +) + + +@_logprob.register(MeasurableComparison) +def comparison_logprob(op, values, base_rv, operand, **kwargs): + (value,) = values + + base_rv_op = base_rv.owner.op + base_rv_inputs = base_rv.owner.inputs + + logcdf = _logcdf(base_rv_op, operand, *base_rv_inputs, **kwargs) + logccdf = pt.log1mexp(logcdf) + + condn_exp = pt.eq(value, np.array(True)) + + if isinstance(op.scalar_op, GT): + logprob = pt.switch(condn_exp, logccdf, logcdf) + elif isinstance(op.scalar_op, LT): + logprob = pt.switch(condn_exp, logcdf, logccdf) + else: + raise TypeError(f"Unsupported scalar_op {op.scalar_op}") + + if base_rv.dtype.startswith("int"): + logp_point = _logprob(base_rv_op, (operand,), *base_rv_inputs, **kwargs) + if isinstance(op.scalar_op, GT): + logprob = pt.switch(condn_exp, pt.logaddexp(logprob, logp_point), logprob) + elif isinstance(op.scalar_op, LT): + logprob = pt.switch(condn_exp, logprob, pt.logaddexp(logprob, logp_point)) + + if base_rv_op.name: + logprob.name = f"{base_rv_op}_logprob" + logcdf.name = f"{base_rv_op}_logcdf" + + return logprob From 6b40e55aff93b9d267c4e87ee82d669ceacfadee Mon Sep 17 00:00:00 2001 From: shreyas3156 Date: Sun, 9 Apr 2023 11:42:23 +0530 Subject: [PATCH 02/10] Tests for binary comparison ops lopgprob --- tests/logprob/test_binary.py | 70 ++++++++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) create mode 100644 tests/logprob/test_binary.py diff --git a/tests/logprob/test_binary.py b/tests/logprob/test_binary.py new file mode 100644 index 00000000000..ca8195c40e1 --- /dev/null +++ b/tests/logprob/test_binary.py @@ -0,0 +1,70 @@ +# Copyright 2023 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +import pytensor +import pytensor.tensor as pt +import scipy.stats as st + +from pymc import logp +from pymc.testing import assert_no_rvs + + +def test_continuous_rv_comparison_lt(): + x_rv = pt.random.normal(0.5, 1) + comp_x_rv = pt.lt(x_rv, 0.5) + + comp_x_vv = comp_x_rv.clone() + comp_x_vv.tag.test_value = 0 + + logprob = logp(comp_x_rv, comp_x_vv) + assert_no_rvs(logprob) + + logp_fn = pytensor.function([comp_x_vv], logprob) + ref_scipy = st.norm(0.5, 1) + + assert np.isclose(logp_fn(0), ref_scipy.logcdf(0.5)) + assert np.isclose(logp_fn(1), ref_scipy.logsf(0.5)) + + +def test_continuous_rv_comparison_gt(): + x_rv = pt.random.normal(0.5, 1) + comp_x_rv = pt.gt(x_rv, 0.5) + + comp_x_vv = comp_x_rv.clone() + comp_x_vv.tag.test_value = 0 + + logprob = logp(comp_x_rv, comp_x_vv) + assert_no_rvs(logprob) + + logp_fn = pytensor.function([comp_x_vv], logprob) + ref_scipy = st.norm(0.5, 1) + + assert np.isclose(logp_fn(0), ref_scipy.logsf(0.5)) + assert np.isclose(logp_fn(1), ref_scipy.logcdf(0.5)) + + +def test_discrete_rv_comparison(): + x_rv = pt.random.poisson(2) + cens_x_rv = pt.lt(x_rv, 3) + + cens_x_vv = cens_x_rv.clone() + + logprob = logp(cens_x_rv, cens_x_vv) + assert_no_rvs(logprob) + + logp_fn = pytensor.function([cens_x_vv], logprob) + ref_scipy = st.poisson(2) + + assert np.isclose(logp_fn(1), ref_scipy.logcdf(3)) + assert np.isclose(logp_fn(0), np.logaddexp(ref_scipy.logsf(3), ref_scipy.logpmf(3))) From d99d66c382c84ecc2a043dcf3e8f476ef4402cf9 Mon Sep 17 00:00:00 2001 From: shreyas3156 Date: Mon, 10 Apr 2023 19:36:21 +0530 Subject: [PATCH 03/10] Add test to github workflows --- .github/workflows/tests.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index d0bd850ca8e..1a1d7a50fc3 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -104,6 +104,7 @@ jobs: tests/distributions/test_truncated.py tests/logprob/test_abstract.py tests/logprob/test_basic.py + tests/logprob/test_binary.py tests/logprob/test_censoring.py tests/logprob/test_composite_logprob.py tests/logprob/test_cumsum.py From 6aeb191b25bee685b0993b9366ebcca2a112d8a1 Mon Sep 17 00:00:00 2001 From: shreyas3156 Date: Tue, 11 Apr 2023 04:04:47 +0530 Subject: [PATCH 04/10] Use logprob and logcdf helpers --- pymc/logprob/binary.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pymc/logprob/binary.py b/pymc/logprob/binary.py index 0422f20a823..412d7ec0454 100644 --- a/pymc/logprob/binary.py +++ b/pymc/logprob/binary.py @@ -25,8 +25,9 @@ from pymc.logprob.abstract import ( MeasurableElemwise, MeasurableVariable, - _logcdf, + _logcdf_helper, _logprob, + _logprob_helper, ) from pymc.logprob.rewriting import measurable_ir_rewrites_db from pymc.logprob.utils import ignore_logprob @@ -81,9 +82,8 @@ def comparison_logprob(op, values, base_rv, operand, **kwargs): (value,) = values base_rv_op = base_rv.owner.op - base_rv_inputs = base_rv.owner.inputs - logcdf = _logcdf(base_rv_op, operand, *base_rv_inputs, **kwargs) + logcdf = _logcdf_helper(base_rv, operand, **kwargs) logccdf = pt.log1mexp(logcdf) condn_exp = pt.eq(value, np.array(True)) @@ -96,7 +96,7 @@ def comparison_logprob(op, values, base_rv, operand, **kwargs): raise TypeError(f"Unsupported scalar_op {op.scalar_op}") if base_rv.dtype.startswith("int"): - logp_point = _logprob(base_rv_op, (operand,), *base_rv_inputs, **kwargs) + logp_point = _logprob_helper(base_rv, operand, **kwargs) if isinstance(op.scalar_op, GT): logprob = pt.switch(condn_exp, pt.logaddexp(logprob, logp_point), logprob) elif isinstance(op.scalar_op, LT): From 92bacb8f8cb18b7cb9c98702ec481a1b40c2fe05 Mon Sep 17 00:00:00 2001 From: shreyas3156 Date: Tue, 11 Apr 2023 19:38:31 +0530 Subject: [PATCH 05/10] Combine lt and gt tests and add gt test for discrete --- tests/logprob/test_binary.py | 65 +++++++++++++++++++----------------- 1 file changed, 34 insertions(+), 31 deletions(-) diff --git a/tests/logprob/test_binary.py b/tests/logprob/test_binary.py index ca8195c40e1..ae973b106a3 100644 --- a/tests/logprob/test_binary.py +++ b/tests/logprob/test_binary.py @@ -14,49 +14,53 @@ import numpy as np import pytensor import pytensor.tensor as pt +import pytest import scipy.stats as st from pymc import logp from pymc.testing import assert_no_rvs -def test_continuous_rv_comparison_lt(): - x_rv = pt.random.normal(0.5, 1) - comp_x_rv = pt.lt(x_rv, 0.5) +@pytest.mark.parametrize( + "comparison_op, exp_logp_true, exp_logp_false", + [ + (pt.lt, st.norm(0, 1).logcdf, st.norm(0, 1).logsf), + (pt.gt, st.norm(0, 1).logsf, st.norm(0, 1).logcdf), + ], +) +def test_continuous_rv_comparison(comparison_op, exp_logp_true, exp_logp_false): + x_rv = pt.random.normal(0, 1) + comp_x_rv = comparison_op(x_rv, 0.5) comp_x_vv = comp_x_rv.clone() - comp_x_vv.tag.test_value = 0 logprob = logp(comp_x_rv, comp_x_vv) assert_no_rvs(logprob) logp_fn = pytensor.function([comp_x_vv], logprob) - ref_scipy = st.norm(0.5, 1) - assert np.isclose(logp_fn(0), ref_scipy.logcdf(0.5)) - assert np.isclose(logp_fn(1), ref_scipy.logsf(0.5)) - - -def test_continuous_rv_comparison_gt(): - x_rv = pt.random.normal(0.5, 1) - comp_x_rv = pt.gt(x_rv, 0.5) - - comp_x_vv = comp_x_rv.clone() - comp_x_vv.tag.test_value = 0 - - logprob = logp(comp_x_rv, comp_x_vv) - assert_no_rvs(logprob) - - logp_fn = pytensor.function([comp_x_vv], logprob) - ref_scipy = st.norm(0.5, 1) - - assert np.isclose(logp_fn(0), ref_scipy.logsf(0.5)) - assert np.isclose(logp_fn(1), ref_scipy.logcdf(0.5)) - - -def test_discrete_rv_comparison(): + assert np.isclose(logp_fn(0), exp_logp_false(0.5)) + assert np.isclose(logp_fn(1), exp_logp_true(0.5)) + + +@pytest.mark.parametrize( + "comparison_op, exp_logp_true, exp_logp_false", + [ + ( + pt.lt, + st.poisson(2).logcdf, + lambda x: np.logaddexp(st.poisson(2).logsf(x), st.poisson(2).logpmf(x)), + ), + ( + pt.gt, + lambda x: np.logaddexp(st.poisson(2).logsf(x), st.poisson(2).logpmf(x)), + st.poisson(2).logcdf, + ), + ], +) +def test_discrete_rv_comparison(comparison_op, exp_logp_true, exp_logp_false): x_rv = pt.random.poisson(2) - cens_x_rv = pt.lt(x_rv, 3) + cens_x_rv = comparison_op(x_rv, 3) cens_x_vv = cens_x_rv.clone() @@ -64,7 +68,6 @@ def test_discrete_rv_comparison(): assert_no_rvs(logprob) logp_fn = pytensor.function([cens_x_vv], logprob) - ref_scipy = st.poisson(2) - assert np.isclose(logp_fn(1), ref_scipy.logcdf(3)) - assert np.isclose(logp_fn(0), np.logaddexp(ref_scipy.logsf(3), ref_scipy.logpmf(3))) + assert np.isclose(logp_fn(1), exp_logp_true(3)) + assert np.isclose(logp_fn(0), exp_logp_false(3)) From cca969b321c824dfb9b7b06f52afd9355a95d7bd Mon Sep 17 00:00:00 2001 From: shreyas3156 Date: Wed, 12 Apr 2023 04:05:05 +0530 Subject: [PATCH 06/10] Refactor and check for potential measurability of const --- pymc/logprob/binary.py | 6 +++++- pymc/logprob/transforms.py | 17 +++-------------- pymc/logprob/utils.py | 18 ++++++++++++++++++ 3 files changed, 26 insertions(+), 15 deletions(-) diff --git a/pymc/logprob/binary.py b/pymc/logprob/binary.py index 412d7ec0454..c4e70681ac2 100644 --- a/pymc/logprob/binary.py +++ b/pymc/logprob/binary.py @@ -30,7 +30,7 @@ _logprob_helper, ) from pymc.logprob.rewriting import measurable_ir_rewrites_db -from pymc.logprob.utils import ignore_logprob +from pymc.logprob.utils import check_potential_measurability, ignore_logprob class MeasurableComparison(MeasurableElemwise): @@ -60,6 +60,10 @@ def find_measurable_comparisons( ): return None + # check for potential measurability of const + if not check_potential_measurability((const,), rv_map_feature): + return None + # Make base_var unmeasurable unmeasurable_base_var = ignore_logprob(base_var) diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py index 903f013abef..1e0d1c0aed0 100644 --- a/pymc/logprob/transforms.py +++ b/pymc/logprob/transforms.py @@ -85,7 +85,7 @@ _logprob_helper, ) from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db -from pymc.logprob.utils import ignore_logprob, walk_model +from pymc.logprob.utils import check_potential_measurability, ignore_logprob class TransformedVariable(Op): @@ -573,19 +573,8 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li # Check that other inputs are not potentially measurable, in which case this rewrite # would be invalid other_inputs = tuple(inp for inp in node.inputs if inp is not measurable_input) - if any( - ancestor_node - for ancestor_node in walk_model( - other_inputs, - walk_past_rvs=False, - stop_at_vars=set(rv_map_feature.rv_values), - ) - if ( - ancestor_node.owner - and isinstance(ancestor_node.owner.op, MeasurableVariable) - and ancestor_node not in rv_map_feature.rv_values - ) - ): + + if not check_potential_measurability(other_inputs, rv_map_feature): return None # Make base_measure outputs unmeasurable diff --git a/pymc/logprob/utils.py b/pymc/logprob/utils.py index 18f9b803e76..c44e88a500d 100644 --- a/pymc/logprob/utils.py +++ b/pymc/logprob/utils.py @@ -210,6 +210,24 @@ def indices_from_subtensor(idx_list, indices): ) +def check_potential_measurability(inputs: Tuple[TensorVariable], rv_map_feature): + if any( + ancestor_node + for ancestor_node in walk_model( + inputs, + walk_past_rvs=False, + stop_at_vars=set(rv_map_feature.rv_values), + ) + if ( + ancestor_node.owner + and isinstance(ancestor_node.owner.op, MeasurableVariable) + and ancestor_node not in rv_map_feature.rv_values + ) + ): + return None + return True + + class ParameterValueError(ValueError): """Exception for invalid parameters values in logprob graphs""" From db7e0c0fa3fa23ed80044a31320a30bddebbe9de Mon Sep 17 00:00:00 2001 From: shreyas3156 Date: Fri, 14 Apr 2023 23:32:46 +0530 Subject: [PATCH 07/10] Add test for when const is measurable --- tests/logprob/test_binary.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tests/logprob/test_binary.py b/tests/logprob/test_binary.py index ae973b106a3..cb3f2a11bb8 100644 --- a/tests/logprob/test_binary.py +++ b/tests/logprob/test_binary.py @@ -17,7 +17,10 @@ import pytest import scipy.stats as st +from pytensor import function + from pymc import logp +from pymc.logprob import factorized_joint_logprob from pymc.testing import assert_no_rvs @@ -71,3 +74,23 @@ def test_discrete_rv_comparison(comparison_op, exp_logp_true, exp_logp_false): assert np.isclose(logp_fn(1), exp_logp_true(3)) assert np.isclose(logp_fn(0), exp_logp_false(3)) + + +def test_potentially_measurable_operand(): + x_rv = pt.random.normal(2) + z_rv = pt.random.normal(x_rv) + y_rv = pt.lt(x_rv, z_rv) + + y_vv = y_rv.clone() + z_vv = z_rv.clone() + + logprob = factorized_joint_logprob({z_rv: z_vv, y_rv: y_vv})[y_vv] + assert_no_rvs(logprob) + + fn = function([z_vv, y_vv], logprob) + z_vv_test = 0.5 + y_vv_test = True + np.testing.assert_array_almost_equal( + fn(z_vv_test, y_vv_test), + st.norm(2, 1).logcdf(z_vv_test), + ) From 09004858916125de5cdb883e98786afe7e9dbede Mon Sep 17 00:00:00 2001 From: shreyas3156 Date: Sat, 15 Apr 2023 13:22:47 +0530 Subject: [PATCH 08/10] Add expected mypy failure of logprob/binary.py --- scripts/run_mypy.py | 1 + 1 file changed, 1 insertion(+) diff --git a/scripts/run_mypy.py b/scripts/run_mypy.py index 2ad9c8a6f9c..72f70130070 100644 --- a/scripts/run_mypy.py +++ b/scripts/run_mypy.py @@ -29,6 +29,7 @@ pymc/distributions/timeseries.py pymc/distributions/truncated.py pymc/initial_point.py +pymc/logprob/binary.py pymc/logprob/censoring.py pymc/logprob/basic.py pymc/logprob/mixture.py From 597e44e1a078de31060391962e6948469a279d67 Mon Sep 17 00:00:00 2001 From: shreyas3156 Date: Mon, 17 Apr 2023 20:44:21 +0530 Subject: [PATCH 09/10] Add failed test for logp when const is measurable --- tests/logprob/test_binary.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/logprob/test_binary.py b/tests/logprob/test_binary.py index cb3f2a11bb8..a154d53a9d4 100644 --- a/tests/logprob/test_binary.py +++ b/tests/logprob/test_binary.py @@ -94,3 +94,9 @@ def test_potentially_measurable_operand(): fn(z_vv_test, y_vv_test), st.norm(2, 1).logcdf(z_vv_test), ) + + with pytest.raises( + NotImplementedError, + match="Logprob method not implemented", + ): + logp(y_rv, y_vv).eval({y_vv: y_vv_test}) From 19ab8ad5e3895c847e0504055b6043b0cfcdd5aa Mon Sep 17 00:00:00 2001 From: shreyas3156 Date: Tue, 18 Apr 2023 17:00:24 +0530 Subject: [PATCH 10/10] Correction in logprob derivation of discrete distributions --- pymc/logprob/binary.py | 14 ++++++-------- tests/logprob/test_binary.py | 4 ++-- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/pymc/logprob/binary.py b/pymc/logprob/binary.py index c4e70681ac2..a35673b4547 100644 --- a/pymc/logprob/binary.py +++ b/pymc/logprob/binary.py @@ -95,17 +95,15 @@ def comparison_logprob(op, values, base_rv, operand, **kwargs): if isinstance(op.scalar_op, GT): logprob = pt.switch(condn_exp, logccdf, logcdf) elif isinstance(op.scalar_op, LT): - logprob = pt.switch(condn_exp, logcdf, logccdf) + if base_rv.dtype.startswith("int"): + logpmf = _logprob_helper(base_rv, operand, **kwargs) + logcdf_lt_true = _logcdf_helper(base_rv, operand - 1, **kwargs) + logprob = pt.switch(condn_exp, logcdf_lt_true, pt.logaddexp(logccdf, logpmf)) + else: + logprob = pt.switch(condn_exp, logcdf, logccdf) else: raise TypeError(f"Unsupported scalar_op {op.scalar_op}") - if base_rv.dtype.startswith("int"): - logp_point = _logprob_helper(base_rv, operand, **kwargs) - if isinstance(op.scalar_op, GT): - logprob = pt.switch(condn_exp, pt.logaddexp(logprob, logp_point), logprob) - elif isinstance(op.scalar_op, LT): - logprob = pt.switch(condn_exp, logprob, pt.logaddexp(logprob, logp_point)) - if base_rv_op.name: logprob.name = f"{base_rv_op}_logprob" logcdf.name = f"{base_rv_op}_logcdf" diff --git a/tests/logprob/test_binary.py b/tests/logprob/test_binary.py index a154d53a9d4..c2dc6926598 100644 --- a/tests/logprob/test_binary.py +++ b/tests/logprob/test_binary.py @@ -51,12 +51,12 @@ def test_continuous_rv_comparison(comparison_op, exp_logp_true, exp_logp_false): [ ( pt.lt, - st.poisson(2).logcdf, + lambda x: st.poisson(2).logcdf(x - 1), lambda x: np.logaddexp(st.poisson(2).logsf(x), st.poisson(2).logpmf(x)), ), ( pt.gt, - lambda x: np.logaddexp(st.poisson(2).logsf(x), st.poisson(2).logpmf(x)), + st.poisson(2).logsf, st.poisson(2).logcdf, ), ],