Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Derive logprob of < and > operations #6662

Merged
merged 10 commits into from
Apr 19, 2023
1 change: 1 addition & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ jobs:
tests/distributions/test_truncated.py
tests/logprob/test_abstract.py
tests/logprob/test_basic.py
tests/logprob/test_binary.py
tests/logprob/test_censoring.py
tests/logprob/test_composite_logprob.py
tests/logprob/test_cumsum.py
Expand Down
1 change: 1 addition & 0 deletions pymc/logprob/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@

# isort: off
# Add rewrites to the DBs
import pymc.logprob.binary
import pymc.logprob.censoring
import pymc.logprob.cumsum
import pymc.logprob.checks
Expand Down
113 changes: 113 additions & 0 deletions pymc/logprob/binary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# Copyright 2023 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional

import numpy as np
import pytensor.tensor as pt

from pytensor.graph.basic import Node
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.scalar.basic import GT, LT
from pytensor.tensor.math import gt, lt

from pymc.logprob.abstract import (
MeasurableElemwise,
MeasurableVariable,
_logcdf_helper,
_logprob,
_logprob_helper,
)
from pymc.logprob.rewriting import measurable_ir_rewrites_db
from pymc.logprob.utils import check_potential_measurability, ignore_logprob


class MeasurableComparison(MeasurableElemwise):
"""A placeholder used to specify a log-likelihood for a binary comparison RV sub-graph."""

valid_scalar_types = (GT, LT)


@node_rewriter(tracks=[gt, lt])
def find_measurable_comparisons(
fgraph: FunctionGraph, node: Node
) -> Optional[List[MeasurableComparison]]:
rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
if rv_map_feature is None:
return None # pragma: no cover

if isinstance(node.op, MeasurableComparison):
return None # pragma: no cover

(compared_var,) = node.outputs
base_var, const = node.inputs

if not (
base_var.owner
and isinstance(base_var.owner.op, MeasurableVariable)
and base_var not in rv_map_feature.rv_values
):
return None

# check for potential measurability of const
if not check_potential_measurability((const,), rv_map_feature):
return None

# Make base_var unmeasurable
ricardoV94 marked this conversation as resolved.
Show resolved Hide resolved
unmeasurable_base_var = ignore_logprob(base_var)

compared_op = MeasurableComparison(node.op.scalar_op)
compared_rv = compared_op.make_node(unmeasurable_base_var, const).default_output()
compared_rv.name = compared_var.name
return [compared_rv]


measurable_ir_rewrites_db.register(
"find_measurable_comparisons",
find_measurable_comparisons,
"basic",
"comparison",
)


@_logprob.register(MeasurableComparison)
def comparison_logprob(op, values, base_rv, operand, **kwargs):
(value,) = values

base_rv_op = base_rv.owner.op

logcdf = _logcdf_helper(base_rv, operand, **kwargs)
logccdf = pt.log1mexp(logcdf)

condn_exp = pt.eq(value, np.array(True))

if isinstance(op.scalar_op, GT):
logprob = pt.switch(condn_exp, logccdf, logcdf)
elif isinstance(op.scalar_op, LT):
logprob = pt.switch(condn_exp, logcdf, logccdf)
else:
raise TypeError(f"Unsupported scalar_op {op.scalar_op}")

if base_rv.dtype.startswith("int"):
logp_point = _logprob_helper(base_rv, operand, **kwargs)
if isinstance(op.scalar_op, GT):
logprob = pt.switch(condn_exp, pt.logaddexp(logprob, logp_point), logprob)
elif isinstance(op.scalar_op, LT):
logprob = pt.switch(condn_exp, logprob, pt.logaddexp(logprob, logp_point))

if base_rv_op.name:
logprob.name = f"{base_rv_op}_logprob"
logcdf.name = f"{base_rv_op}_logcdf"

return logprob
17 changes: 3 additions & 14 deletions pymc/logprob/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@
_logprob_helper,
)
from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db
from pymc.logprob.utils import ignore_logprob, walk_model
from pymc.logprob.utils import check_potential_measurability, ignore_logprob


class TransformedVariable(Op):
Expand Down Expand Up @@ -573,19 +573,8 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
# Check that other inputs are not potentially measurable, in which case this rewrite
# would be invalid
other_inputs = tuple(inp for inp in node.inputs if inp is not measurable_input)
if any(
ancestor_node
for ancestor_node in walk_model(
other_inputs,
walk_past_rvs=False,
stop_at_vars=set(rv_map_feature.rv_values),
)
if (
ancestor_node.owner
and isinstance(ancestor_node.owner.op, MeasurableVariable)
and ancestor_node not in rv_map_feature.rv_values
)
):

if not check_potential_measurability(other_inputs, rv_map_feature):
return None

# Make base_measure outputs unmeasurable
Expand Down
18 changes: 18 additions & 0 deletions pymc/logprob/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,24 @@ def indices_from_subtensor(idx_list, indices):
)


def check_potential_measurability(inputs: Tuple[TensorVariable], rv_map_feature):
if any(
ancestor_node
for ancestor_node in walk_model(
inputs,
walk_past_rvs=False,
stop_at_vars=set(rv_map_feature.rv_values),
)
if (
ancestor_node.owner
and isinstance(ancestor_node.owner.op, MeasurableVariable)
and ancestor_node not in rv_map_feature.rv_values
)
):
return None
return True


class ParameterValueError(ValueError):
"""Exception for invalid parameters values in logprob graphs"""

Expand Down
1 change: 1 addition & 0 deletions scripts/run_mypy.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
pymc/distributions/timeseries.py
pymc/distributions/truncated.py
pymc/initial_point.py
pymc/logprob/binary.py
pymc/logprob/censoring.py
pymc/logprob/basic.py
pymc/logprob/mixture.py
Expand Down
102 changes: 102 additions & 0 deletions tests/logprob/test_binary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# Copyright 2023 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
shreyas3156 marked this conversation as resolved.
Show resolved Hide resolved
import pytensor
import pytensor.tensor as pt
import pytest
import scipy.stats as st

from pytensor import function

from pymc import logp
from pymc.logprob import factorized_joint_logprob
from pymc.testing import assert_no_rvs


@pytest.mark.parametrize(
"comparison_op, exp_logp_true, exp_logp_false",
[
(pt.lt, st.norm(0, 1).logcdf, st.norm(0, 1).logsf),
(pt.gt, st.norm(0, 1).logsf, st.norm(0, 1).logcdf),
],
)
def test_continuous_rv_comparison(comparison_op, exp_logp_true, exp_logp_false):
x_rv = pt.random.normal(0, 1)
comp_x_rv = comparison_op(x_rv, 0.5)

comp_x_vv = comp_x_rv.clone()

logprob = logp(comp_x_rv, comp_x_vv)
assert_no_rvs(logprob)

logp_fn = pytensor.function([comp_x_vv], logprob)

assert np.isclose(logp_fn(0), exp_logp_false(0.5))
assert np.isclose(logp_fn(1), exp_logp_true(0.5))


@pytest.mark.parametrize(
"comparison_op, exp_logp_true, exp_logp_false",
[
(
pt.lt,
st.poisson(2).logcdf,
lambda x: np.logaddexp(st.poisson(2).logsf(x), st.poisson(2).logpmf(x)),
),
(
pt.gt,
lambda x: np.logaddexp(st.poisson(2).logsf(x), st.poisson(2).logpmf(x)),
st.poisson(2).logcdf,
),
],
)
def test_discrete_rv_comparison(comparison_op, exp_logp_true, exp_logp_false):
x_rv = pt.random.poisson(2)
cens_x_rv = comparison_op(x_rv, 3)

cens_x_vv = cens_x_rv.clone()

logprob = logp(cens_x_rv, cens_x_vv)
assert_no_rvs(logprob)

logp_fn = pytensor.function([cens_x_vv], logprob)

assert np.isclose(logp_fn(1), exp_logp_true(3))
assert np.isclose(logp_fn(0), exp_logp_false(3))


def test_potentially_measurable_operand():
x_rv = pt.random.normal(2)
z_rv = pt.random.normal(x_rv)
y_rv = pt.lt(x_rv, z_rv)

y_vv = y_rv.clone()
z_vv = z_rv.clone()

ricardoV94 marked this conversation as resolved.
Show resolved Hide resolved
logprob = factorized_joint_logprob({z_rv: z_vv, y_rv: y_vv})[y_vv]
assert_no_rvs(logprob)

fn = function([z_vv, y_vv], logprob)
z_vv_test = 0.5
y_vv_test = True
np.testing.assert_array_almost_equal(
fn(z_vv_test, y_vv_test),
st.norm(2, 1).logcdf(z_vv_test),
)

with pytest.raises(
NotImplementedError,
match="Logprob method not implemented",
):
logp(y_rv, y_vv).eval({y_vv: y_vv_test})