Skip to content

Commit

Permalink
Derive matmul probability
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Oct 21, 2024
1 parent c3c1d94 commit 1249c86
Show file tree
Hide file tree
Showing 6 changed files with 204 additions and 2 deletions.
1 change: 1 addition & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ jobs:
tests/logprob/test_censoring.py
tests/logprob/test_composite_logprob.py
tests/logprob/test_cumsum.py
tests/logprob/test_linalg.py
tests/logprob/test_mixture.py
tests/logprob/test_order.py
tests/logprob/test_rewriting.py
Expand Down
1 change: 1 addition & 0 deletions pymc/logprob/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import pymc.logprob.censoring
import pymc.logprob.cumsum
import pymc.logprob.checks
import pymc.logprob.linalg
import pymc.logprob.mixture
import pymc.logprob.order
import pymc.logprob.scan
Expand Down
5 changes: 5 additions & 0 deletions pymc/logprob/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from pytensor.graph import Apply, Op, Variable
from pytensor.graph.utils import MetaType
from pytensor.tensor import TensorVariable
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.random.op import RandomVariable

Expand Down Expand Up @@ -168,6 +169,10 @@ def __str__(self):
return f"Measurable{super().__str__()}"


class MeasurableBlockwise(MeasurableOp, Blockwise):
    """Common parent for `Blockwise` operations that are also `MeasurableOp`.

    The class adds no behavior of its own; it only combines the two bases so
    that measurable blockwise operations can derive from a single type.
    """


class ValuedRV(Op):
r"""Represents the association of a measurable variable and its value.
Expand Down
102 changes: 102 additions & 0 deletions pymc/logprob/linalg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# Copyright 2024 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytensor.tensor as pt

from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor.math import _matrix_matrix_matmul

from pymc.logprob.abstract import MeasurableBlockwise, MeasurableOp, _logprob, _logprob_helper
from pymc.logprob.rewriting import measurable_ir_rewrites_db
from pymc.logprob.utils import check_potential_measurability, filter_measurable_variables


class MeasurableMatMul(MeasurableBlockwise):
    """Measurable matrix multiplication that remembers which operand is measurable.

    Parameters
    ----------
    measurable_right
        ``True`` when the right operand of the product is the measurable one,
        ``False`` when it is the left operand.
    **kwargs
        Forwarded unchanged to the parent class constructor.
    """

    # Side of the product holding the measurable operand (True: right, False: left).
    right_measurable: bool

    def __init__(self, measurable_right: bool, **kwargs):
        # The flag is stored before delegating to the parent constructor.
        self.right_measurable = measurable_right
        super().__init__(**kwargs)


@_logprob.register(MeasurableMatMul)
def logprob_measurable_matmul(op, values, l, r):  # noqa: E741
    """Log-probability of ``y = l @ r`` when one of the operands is measurable.

    The product is inverted with a linear solve to recover the value of the
    measurable operand ``x``. The log-probability of ``x`` is evaluated at that
    value, reduced over the matrix dimensions it still carries, and corrected
    by the log-determinant of the Jacobian of the linear map.

    Parameters
    ----------
    op
        The ``MeasurableMatMul`` op. ``op.right_measurable`` tells whether ``r``
        (``True``) or ``l`` (``False``) is the measurable operand.
    values
        Sequence holding the single value variable of the product.
    l, r
        Left and right operands of the matrix product.

    Returns
    -------
    The log-probability of the product evaluated at the given value.
    """
    [y_value] = values
    if op.right_measurable:
        # y = A @ x  ->  x = solve(A, y); every column of x is mapped through A
        A, x = l, r
        x_value = pt.linalg.solve(A, y_value)
        n_mapped_vectors = x_value.shape[-1]
    else:
        # y = x @ A  ->  x.T = solve(A.T, y.T); every row of x is mapped through A.T
        x, A = l, r
        x_value = pt.linalg.solve(A.mT, y_value.mT).mT
        n_mapped_vectors = x_value.shape[-2]

    x_logp = _logprob_helper(x, x_value)

    # The operation has a support dimensionality of 2
    # We need to reduce it if it's still present in the base logp
    if x_logp.type.ndim == x_value.type.ndim:
        x_logp = pt.sum(x_logp, axis=(-1, -2))
    elif x_logp.type.ndim == x_value.type.ndim - 1:
        x_logp = pt.sum(x_logp, axis=-1)

    _, log_abs_det_A = pt.linalg.slogdet(A)

    # The Jacobian of the map on the whole matrix is a Kronecker product of A (or A.T)
    # with an identity whose size is the number of mapped columns (or rows), so its
    # determinant is det(A) raised to that number. Counting slogdet(A) only once is
    # correct solely for the matrix-vector case.
    # The cast keeps the integer shape from upcasting the floating point result.
    n_mapped_vectors = pt.cast(n_mapped_vectors, log_abs_det_A.dtype)
    log_abs_jac_det = n_mapped_vectors * log_abs_det_A

    return x_logp - log_abs_jac_det


@node_rewriter(tracks=[_matrix_matrix_matmul])
def find_measurable_matmul(fgraph, node):
    """Replace a matrix-matrix product with one measurable operand by `MeasurableMatMul`.

    Returns ``None`` (no replacement) when the node is already measurable, when
    the number of measurable operands is not exactly one, when the leading
    broadcastable pattern of the measurable operand differs from the output's,
    when the static shape shows the other operand is not square, or when the
    other operand is potentially measurable. Otherwise returns a one-element
    list with the output of the new measurable op.
    """
    if isinstance(node.op, MeasurableOp):
        return None

    (out,) = node.outputs
    left, right = node.inputs

    # Exactly one of l and r must be measurable
    candidates = filter_measurable_variables([left, right])
    if len(candidates) != 1:
        return None
    (rv,) = candidates

    # The measurable operand must not be broadcasted along the leading dimensions
    leading = slice(None, -2)
    if rv.type.broadcastable[leading] != out.type.broadcastable[leading]:
        return None

    measurable_right = rv is right
    A = left if measurable_right else right

    # Bail out if the static shape already reveals a non-square matrix
    n_cols = A.type.shape[-1]
    n_rows = A.type.shape[-2]
    if None not in (n_cols, n_rows) and n_cols != n_rows:
        return None

    # The non-measurable operand must not be potentially measurable either
    if check_potential_measurability([A]):
        return None

    new_op = MeasurableMatMul(measurable_right=measurable_right, **node.op._props_dict())
    return [new_op(left, right)]


# Register the matmul rewrite in the measurable IR rewrites database under the
# rewriter's own `__name__`, tagged "basic" and "linalg".
measurable_ir_rewrites_db.register(
find_measurable_matmul.__name__,
find_measurable_matmul,
"basic",
"linalg",
)
12 changes: 10 additions & 2 deletions pymc/logprob/rewriting.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,11 @@ def remove_DiracDelta(fgraph, node):
logprob_rewrites_db.register(
"local_exp_over_1_plus_exp", out2in(local_exp_over_1_plus_exp), "basic"
)
# NOTE(review): the two "pre-canonicalize" registrations below are the old and new
# sides of one diff hunk (hunk header: 7 old lines, 11 new lines). The single-line
# call appears to be the removed version — confirm, and keep only one of them.
logprob_rewrites_db.register("pre-canonicalize", optdb.query("+canonicalize"), "basic")
# New version: same canonicalize query, but excluding
# `local_eager_useless_unbatched_blockwise`.
# NOTE(review): presumably so that unbatched Blockwise matmuls stay visible to the
# measurable matmul rewrite — confirm.
logprob_rewrites_db.register(
"pre-canonicalize",
optdb.query("+canonicalize", "-local_eager_useless_unbatched_blockwise"),
"basic",
)

# These rewrites convert un-measurable variables into their measurable forms,
# but they need to be reapplied, because some of the measurable forms require
Expand All @@ -175,7 +179,11 @@ def remove_DiracDelta(fgraph, node):
)


# NOTE(review): the two "post-canonicalize" registrations below are the old and new
# sides of one diff hunk. The single-line call appears to be the removed version —
# confirm, and keep only one of them.
logprob_rewrites_db.register("post-canonicalize", optdb.query("+canonicalize"), "basic")
# New version: same canonicalize query, but excluding
# `local_eager_useless_unbatched_blockwise`.
logprob_rewrites_db.register(
"post-canonicalize",
optdb.query("+canonicalize", "-local_eager_useless_unbatched_blockwise"),
"basic",
)

# Rewrites that remove IR Ops
cleanup_ir_rewrites_db = LocalGroupDB()
Expand Down
85 changes: 85 additions & 0 deletions tests/logprob/test_linalg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Copyright 2024 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import pytest

from pytensor.tensor.type import tensor

from pymc.distributions import MatrixNormal, MvNormal, Normal
from pymc.logprob.basic import logp


@pytest.mark.parametrize("univariate", [True, False])
@pytest.mark.parametrize("batch_shape", [(), (3,)])
def test_matrix_vector_transform(univariate, batch_shape):
    """Check the derived logp of ``c + B @ x`` against the closed-form MvNormal."""
    rng = np.random.default_rng(755)
    vector_shape = (*batch_shape, 2)
    matrix_shape = (*batch_shape, 2, 2)

    # NOTE: the order of the random draws is kept fixed so the sampled values do not change
    loc = rng.normal(size=vector_shape)
    if univariate:
        scale = np.abs(rng.normal(size=vector_shape))
        cov = np.eye(2) * (scale**2)[..., None]
        x = Normal.dist(mu=loc, sigma=scale)
    else:
        factor = rng.normal(size=matrix_shape)
        cov = np.swapaxes(factor, -1, -2) @ factor
        x = MvNormal.dist(mu=loc, cov=cov)

    c = rng.normal(size=vector_shape)
    B = rng.normal(size=matrix_shape)
    y = c + (B @ x[..., None]).squeeze(-1)

    # An affine transformed MvNormal is still a MvNormal
    # https://en.wikipedia.org/wiki/Multivariate_normal_distribution#Affine_transformation
    ref_mean = c + (B @ loc[..., None]).squeeze(-1)
    ref_cov = B @ cov @ np.swapaxes(B, -1, -2)
    ref_dist = MvNormal.dist(mu=ref_mean, cov=ref_cov)

    test_y = rng.normal(size=vector_shape)
    derived = logp(y, test_y).eval()
    expected = logp(ref_dist, test_y).eval()
    np.testing.assert_allclose(derived, expected)


def test_matrix_matrix_transform():
    """Check the derived logp of ``D @ X @ C`` against the closed-form MatrixNormal."""
    rng = np.random.default_rng(46)
    n, p = 2, 3

    # NOTE: the order of the random draws is kept fixed so the sampled values do not change
    M = rng.normal(size=(n, p))
    row_factor = rng.normal(size=(n, n)) * 0.1
    U = row_factor.T @ row_factor
    col_factor = rng.normal(size=(p, p)) * 0.1
    V = col_factor.T @ col_factor
    X = MatrixNormal.dist(mu=M, rowcov=U, colcov=V)

    D = rng.normal(size=(n, n))
    C = rng.normal(size=(p, p))
    Y = D @ X @ C

    # A linearly transformed MatrixNormal is still a MatrixNormal
    # https://en.wikipedia.org/wiki/Matrix_normal_distribution#Transformation
    ref_mu = D @ M @ C
    ref_rowcov = D @ U @ D.T
    ref_colcov = C.T @ V @ C
    ref_dist = MatrixNormal.dist(mu=ref_mu, rowcov=ref_rowcov, colcov=ref_colcov)

    test_Y = rng.normal(size=(n, p))
    derived = logp(Y, test_Y).eval()
    expected = logp(ref_dist, test_Y).eval()
    np.testing.assert_allclose(derived, expected, rtol=1e-5)


def test_broadcasted_matmul_fails():
    """A product that broadcasts the measurable operand must not get a derived logp."""
    rv = Normal.dist(size=(3, 2))
    batched_matrix = tensor("A", shape=(4, 3, 3))
    # The (3, 2) variable is broadcast against the leading dimension of size 4
    product = batched_matrix @ rv
    with pytest.raises(NotImplementedError):
        logp(product, product.type())

0 comments on commit 1249c86

Please sign in to comment.