Add rewrite for sum of normal RVs
larryshamalama authored and brandonwillard committed Apr 18, 2023
1 parent 58b57c2 commit 325afb8
Showing 4 changed files with 180 additions and 11 deletions.
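
The new rewrite encodes the standard convolution identity for independent normals: if X ~ N(mu_x, sigma_x^2) and Y ~ N(mu_y, sigma_y^2) are independent and neither is given a value, then X ± Y ~ N(mu_x ± mu_y, sigma_x^2 + sigma_y^2). For instance, adding N(1, 0.03^2) and N(2, 0.04^2) variates yields a N(3, 0.05^2) variate, since 0.03^2 + 0.04^2 = 0.05^2. This is exactly what `add_independent_normals` below constructs: a single `NormalRV` whose mean is the sum (or difference) of the two means and whose scale is `at.sqrt(sigma_x**2 + sigma_y**2)`.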
1 change: 1 addition & 0 deletions aeppl/__init__.py
@@ -12,6 +12,7 @@
 # isort: off
 # Add rewrites to the DBs
 import aeppl.censoring
+import aeppl.convolutions
 import aeppl.cumsum
 import aeppl.mixture
 import aeppl.scan
59 changes: 59 additions & 0 deletions aeppl/convolutions.py
@@ -0,0 +1,59 @@
import aesara.tensor as at
from aesara.graph.rewriting.basic import node_rewriter
from aesara.tensor.random.basic import NormalRV, normal

from aeppl.rewriting import measurable_ir_rewrites_db


@node_rewriter((at.sub, at.add))
def add_independent_normals(fgraph, node):
    """Replace a sum of two un-valued independent normal RVs with a single normal RV."""

    if node.op == at.add:
        sub = False
    else:
        sub = True

    X_rv, Y_rv = node.inputs

    if not (X_rv.owner and Y_rv.owner) or not (
        # This also checks that the RVs are un-valued (i.e. they're not
        # `ValuedVariable`s)
        isinstance(X_rv.owner.op, NormalRV)
        and isinstance(Y_rv.owner.op, NormalRV)
    ):
        return None

    old_rv = node.outputs[0]

    mu_x, sigma_x, mu_y, sigma_y, _ = at.broadcast_arrays(
        *(X_rv.owner.inputs[-2:] + Y_rv.owner.inputs[-2:] + [old_rv])
    )

    new_rng = X_rv.owner.inputs[0].clone()

    new_node = normal.make_node(
        new_rng,
        old_rv.shape,
        old_rv.dtype,
        mu_x + mu_y if not sub else mu_x - mu_y,
        at.sqrt(sigma_x**2 + sigma_y**2),
    )

    fgraph.add_input(new_rng)

    # `new_rng` must be updated with the values of the RNGs output by `new_node`
    new_rng.default_update = new_node.outputs[0]
    new_normal_rv = new_node.default_output()

    if old_rv.name:
        new_normal_rv.name = old_rv.name

    return [new_normal_rv]


measurable_ir_rewrites_db.register(
    "add_independent_normals",
    add_independent_normals,
    "basic",
)
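
A minimal usage sketch (not part of the commit, and assuming `joint_logprob` is importable from the top-level `aeppl` namespace and returns the log-density graph together with the value variables, as in `tests/test_transforms.py` below): with this rewrite registered, the sum of two un-valued normals is measurable as a single normal.

import aesara.tensor as at

from aeppl import joint_logprob

srng = at.random.RandomStream(0)

X_rv = srng.normal(1.0, 0.5, name="X")
Y_rv = srng.normal(2.0, 1.2, name="Y")
Z_rv = X_rv + Y_rv
Z_rv.name = "Z"

# The IR rewrite should fold Z into a single NormalRV with mean 3.0 and
# scale sqrt(0.5**2 + 1.2**2) = 1.3, so a log-density can be derived for it.
logp, (z_vv,) = joint_logprob(Z_rv)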
119 changes: 119 additions & 0 deletions tests/test_convolutions.py
@@ -0,0 +1,119 @@
import aesara.tensor as at
import numpy as np
import pytest
from aesara.tensor.random.basic import NormalRV

from aeppl.rewriting import construct_ir_fgraph
from aeppl.transforms import MeasurableElemwiseTransform


@pytest.mark.parametrize(
    "mu_x, mu_y, sigma_x, sigma_y, x_shape, y_shape",
    [
        (
            np.array([1, 10, 100]),
            np.array(2),
            np.array(0.03),
            np.tile(0.04, 3),
            (),
            (),
        ),
        (
            np.array([1, 10, 100]),
            np.array(2),
            np.array(0.03),
            np.full((5, 1), 0.04),
            (),
            (5, 3),
        ),
        (
            np.array([[1, 10, 100]]),
            np.array([[0.2], [2], [20], [200], [2000]]),
            np.array(0.03),
            np.array(0.04),
            (),
            (),
        ),
        (
            np.broadcast_to(np.array([1, 10, 100]), (5, 3)),
            np.array([2, 20, 200]),
            np.array(0.03),
            np.array(0.04),
            (2, 5, 3),
            (),
        ),
        (
            np.array([[1, 10, 100]]),
            np.array([[0.2], [2], [20], [200], [2000]]),
            np.array([[0.5], [5], [50], [500], [5000]]),
            np.array([[0.4, 4, 40]]),
            (2, 5, 3),
            (),
        ),
        (
            np.array(1),
            np.array(2),
            np.array(3),
            np.array(4),
            (5, 1),
            (1,),
        ),
    ],
)
@pytest.mark.parametrize("sub", [False, True])
def test_add_independent_normals(mu_x, mu_y, sigma_x, sigma_y, x_shape, y_shape, sub):
    srng = at.random.RandomStream(29833)

    X_rv = srng.normal(mu_x, sigma_x, size=x_shape)
    X_rv.name = "X"

    Y_rv = srng.normal(mu_y, sigma_y, size=y_shape)
    Y_rv.name = "Y"

    Z_rv = X_rv + Y_rv if not sub else X_rv - Y_rv
    Z_rv.name = "Z"
    z_vv = Z_rv.clone()

    fgraph, _, _ = construct_ir_fgraph({Z_rv: z_vv})

    valued_var_out_node = fgraph.outputs[0].owner
    # The convolution should be applied, and not the transform
    assert isinstance(valued_var_out_node.inputs[0].owner.op, NormalRV)

    new_rv = fgraph.outputs[0].owner.inputs[0]

    new_rv_mu = mu_x + mu_y if not sub else mu_x - mu_y
    new_rv_sigma = np.sqrt(sigma_x**2 + sigma_y**2)

    new_rv_shape = np.broadcast_shapes(
        new_rv_mu.shape, new_rv_sigma.shape, x_shape, y_shape
    )

    new_rv_mu = np.broadcast_to(new_rv_mu, new_rv_shape)
    new_rv_sigma = np.broadcast_to(new_rv_sigma, new_rv_shape)

    assert isinstance(new_rv.owner.op, NormalRV)
    assert np.allclose(new_rv.owner.inputs[3].eval(), new_rv_mu)
    assert np.allclose(new_rv.owner.inputs[4].eval(), new_rv_sigma)
    assert new_rv.name == "Z"


def test_normal_add_input_valued():
    """Test the case when one of the normal inputs to the add `Op` is a `ValuedVariable`."""
    srng = at.random.RandomStream(0)

    X_rv = srng.normal(1.0, name="X")
    x_vv = X_rv.clone()
    Y_rv = srng.normal(1.0, name="Y")
    Z_rv = X_rv + Y_rv
    Z_rv.name = "Z"
    z_vv = Z_rv.clone()

    fgraph, _, _ = construct_ir_fgraph({Z_rv: z_vv, X_rv: x_vv})

    valued_var_out_node = fgraph.outputs[0].owner
    # We should not expect the convolution to be applied; instead, the
    # transform should be (for now)
    assert isinstance(
        valued_var_out_node.inputs[0].owner.op, MeasurableElemwiseTransform
    )
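
Note on the assertions in the first test: because the rewrite builds the new variable with `normal.make_node(new_rng, old_rv.shape, old_rv.dtype, mu, sigma)`, the node's inputs are `(rng, size, dtype, mu, sigma)`, so `new_rv.owner.inputs[3]` and `new_rv.owner.inputs[4]` are the broadcast mean and standard deviation being checked.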
12 changes: 1 addition & 11 deletions tests/test_transforms.py
@@ -675,17 +675,6 @@ def test_transformed_rv_and_value():
     )
 
 
-def test_loc_transform_multiple_rvs_fails1():
-    srng = at.random.RandomStream(0)
-
-    x_rv1 = srng.normal(name="x_rv1")
-    x_rv2 = srng.normal(name="x_rv2")
-    y_rv = x_rv1 + x_rv2
-
-    with pytest.raises(DensityNotFound):
-        joint_logprob(y_rv)
-
-
 def test_nested_loc_transform_multiple_rvs_fails2():
     srng = at.random.RandomStream(0)
 
@@ -816,6 +805,7 @@ def test_transform_sub_valued():
     Z_rv = A_rv - X_rv
 
     logp, (z_vv, a_vv) = joint_logprob(Z_rv, A_rv)
+
     z_logp_fn = aesara.function([z_vv, a_vv], logp)
     exp_logp = sp.stats.norm.logpdf(5.0 - 7.3, 1.0) + sp.stats.norm.logpdf(5.0, 1.0)
     assert np.isclose(z_logp_fn(7.3, 5.0), exp_logp)
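
The test deleted above, `test_loc_transform_multiple_rvs_fails1`, asserted that a sum of two un-valued normals had no derivable density; with the convolution rewrite registered, that graph becomes measurable, so the expectation no longer holds. A hedged sketch of the now-expected behaviour (same assumptions about the `joint_logprob` API as in the sketch after `aeppl/convolutions.py`):

import aesara.tensor as at

from aeppl import joint_logprob

srng = at.random.RandomStream(0)

x_rv1 = srng.normal(name="x_rv1")
x_rv2 = srng.normal(name="x_rv2")
y_rv = x_rv1 + x_rv2

# Previously this raised DensityNotFound; now y_rv should be rewritten to a
# single normal, i.e. N(0, sqrt(2)), and a log-density derived for it.
logp, (y_vv,) = joint_logprob(y_rv)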
