Skip to content

Commit

Permalink
Add rewrites for measurable negation and subtraction
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Dec 7, 2022
1 parent 106b86c commit 6959886
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 10 deletions.
55 changes: 54 additions & 1 deletion pymc/logprob/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
from pytensor.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter
from pytensor.scalar import Add, Exp, Log, Mul, Reciprocal
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import add, exp, log, mul, reciprocal, true_div
from pytensor.tensor.math import add, exp, log, mul, neg, reciprocal, sub, true_div
from pytensor.tensor.rewriting.basic import (
register_specialize,
register_stabilize,
Expand Down Expand Up @@ -286,6 +286,46 @@ def measurable_div_to_reciprocal_product(fgraph, node):
return [at.mul(numerator, at.reciprocal(denominator))]


@node_rewriter([neg])
def measurable_neg_to_product(fgraph, node):
"""Convert negation of `MeasurableVariable`s to product with `-1`."""

inp = node.inputs[0]
if not (inp.owner and isinstance(inp.owner.op, MeasurableVariable)):
return None

rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)
if rv_map_feature is None:
return None # pragma: no cover

# Only apply this rewrite if the variable is unvalued
if inp in rv_map_feature.rv_values:
return None # pragma: no cover

return [at.mul(inp, -1.0)]


@node_rewriter([sub])
def measurable_sub_to_neg(fgraph, node):
"""Convert subtraction involving `MeasurableVariable`s to addition with neg"""
measurable_vars = [
var for var in node.inputs if (var.owner and isinstance(var.owner.op, MeasurableVariable))
]
if not measurable_vars:
return None # pragma: no cover

rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)
if rv_map_feature is None:
return None # pragma: no cover

# Only apply this rewrite if there is one unvalued MeasurableVariable involved
if all(measurable_var in rv_map_feature.rv_values for measurable_var in measurable_vars):
return None # pragma: no cover

minuend, subtrahend = node.inputs
return [at.add(minuend, at.neg(subtrahend))]


@node_rewriter([exp, log, add, mul, reciprocal])
def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]:
"""Find measurable transformations from Elemwise operators."""
Expand Down Expand Up @@ -377,6 +417,19 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
"transform",
)

measurable_ir_rewrites_db.register(
"measurable_neg_to_product",
measurable_neg_to_product,
"basic",
"transform",
)

measurable_ir_rewrites_db.register(
"measurable_sub_to_neg",
measurable_sub_to_neg,
"basic",
"transform",
)

measurable_ir_rewrites_db.register(
"find_measurable_transforms",
Expand Down
40 changes: 31 additions & 9 deletions pymc/tests/logprob/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,17 +601,20 @@ def test_log_transform_rv():


@pytest.mark.parametrize(
"rv_size, loc_type",
"rv_size, loc_type, addition",
[
(None, at.scalar),
(2, at.vector),
((2, 1), at.col),
(None, at.scalar, True),
(2, at.vector, False),
((2, 1), at.col, True),
],
)
def test_loc_transform_rv(rv_size, loc_type):
def test_loc_transform_rv(rv_size, loc_type, addition):

loc = loc_type("loc")
y_rv = loc + at.random.normal(0, 1, size=rv_size, name="base_rv")
if addition:
y_rv = loc + at.random.normal(0, 1, size=rv_size, name="base_rv")
else:
y_rv = at.random.normal(0, 1, size=rv_size, name="base_rv") - at.neg(loc)
y_rv.name = "y"
y_vv = y_rv.clone()

Expand Down Expand Up @@ -642,9 +645,7 @@ def test_scale_transform_rv(rv_size, scale_type, product):
if product:
y_rv = at.random.normal(0, 1, size=rv_size, name="base_rv") * scale
else:
y_rv = at.random.normal(0, 1, size=rv_size, name="base_rv") / at.reciprocal(
scale
)
y_rv = at.random.normal(0, 1, size=rv_size, name="base_rv") / at.reciprocal(scale)
y_rv.name = "y"
y_vv = y_rv.clone()

Expand Down Expand Up @@ -741,3 +742,24 @@ def test_reciprocal_rv_transform(numerator):
x_logp_fn(x_test_val),
sp.stats.invgamma(shape, scale=scale * numerator).logpdf(x_test_val),
)


def test_negated_rv_transform():
x_rv = -at.random.halfnormal()
x_rv.name = "x"

x_vv = x_rv.clone()
x_logp_fn = pytensor.function([x_vv], joint_logprob({x_rv: x_vv}))

assert np.isclose(x_logp_fn(-1.5), sp.stats.halfnorm.logpdf(1.5))


def test_subtracted_rv_transform():
# Choose base RV that is assymetric around zero
x_rv = 5.0 - at.random.normal(1.0)
x_rv.name = "x"

x_vv = x_rv.clone()
x_logp_fn = pytensor.function([x_vv], joint_logprob({x_rv: x_vv}))

assert np.isclose(x_logp_fn(7.3), sp.stats.norm.logpdf(5.0 - 7.3, 1.0))

0 comments on commit 6959886

Please sign in to comment.