Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Derive logprob of matmul #7542

Merged
merged 3 commits into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ jobs:
tests/logprob/test_censoring.py
tests/logprob/test_composite_logprob.py
tests/logprob/test_cumsum.py
tests/logprob/test_linalg.py
tests/logprob/test_mixture.py
tests/logprob/test_order.py
tests/logprob/test_rewriting.py
Expand Down
1 change: 1 addition & 0 deletions pymc/logprob/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import pymc.logprob.censoring
import pymc.logprob.cumsum
import pymc.logprob.checks
import pymc.logprob.linalg
import pymc.logprob.mixture
import pymc.logprob.order
import pymc.logprob.scan
Expand Down
5 changes: 5 additions & 0 deletions pymc/logprob/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from pytensor.graph import Apply, Op, Variable
from pytensor.graph.utils import MetaType
from pytensor.tensor import TensorVariable
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.random.op import RandomVariable

Expand Down Expand Up @@ -168,6 +169,10 @@ def __str__(self):
return f"Measurable{super().__str__()}"


class MeasurableBlockwise(MeasurableOp, Blockwise):
    """Base class for Measurable Blockwise variables.

    Combines the :class:`MeasurableOp` marker with PyTensor's
    :class:`~pytensor.tensor.blockwise.Blockwise`, so that batched (blockwise)
    operations on measurable variables can participate in logprob derivation.
    """

class ValuedRV(Op):
r"""Represents the association of a measurable variable and its value.
Expand Down
2 changes: 1 addition & 1 deletion pymc/logprob/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ def conditional_logp(
if not isinstance(node.op, MeasurableOp):
continue

valued_nodes = get_related_valued_nodes(node, fgraph)
valued_nodes = get_related_valued_nodes(fgraph, node)

if not valued_nodes:
continue
Expand Down
102 changes: 102 additions & 0 deletions pymc/logprob/linalg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# Copyright 2024 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytensor.tensor as pt

from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor.math import _matrix_matrix_matmul

from pymc.logprob.abstract import MeasurableBlockwise, MeasurableOp, _logprob, _logprob_helper
from pymc.logprob.rewriting import measurable_ir_rewrites_db
from pymc.logprob.utils import check_potential_measurability, filter_measurable_variables


class MeasurableMatMul(MeasurableBlockwise):
    """Measurable matrix multiplication operation.

    Exactly one operand of the matmul is measurable; the other must be a
    non-measurable (square) matrix so the product can be inverted when
    deriving the logprob.
    """

    # True when the right operand (``r`` in ``l @ r``) is the measurable one.
    right_measurable: bool

    def __init__(self, measurable_right: bool, **kwargs):
        self.right_measurable = measurable_right
        super().__init__(**kwargs)


@_logprob.register(MeasurableMatMul)
def logprob_measurable_matmul(op, values, l, r):  # noqa: E741
    """Derive the logprob of ``y = l @ r`` where exactly one operand is measurable.

    The measurable operand's value is recovered by solving the linear system,
    and the change-of-variables correction ``log|det(A)|`` is subtracted,
    where ``A`` is the non-measurable (square) operand.
    """
    [y_value] = values

    # Invert the linear map to recover the value of the measurable operand.
    if op.right_measurable:
        # y = A @ x  =>  x = solve(A, y)
        A, x = l, r
        x_value = pt.linalg.solve(A, y_value)
    else:
        # y = x @ A  =>  x.mT = solve(A.mT, y.mT)
        x, A = l, r
        x_value = pt.linalg.solve(A.mT, y_value.mT).mT

    x_logp = _logprob_helper(x, x_value)

    # The matmul has a support dimensionality of 2; collapse any support
    # dimensions still present in the base logp.
    ndim_gap = x_value.type.ndim - x_logp.type.ndim
    if ndim_gap == 0:
        x_logp = x_logp.sum(axis=(-2, -1))
    elif ndim_gap == 1:
        x_logp = x_logp.sum(axis=-1)

    # Jacobian of the transformation: |det(dy/dx)| = |det(A)|
    _sign, log_abs_jac_det = pt.linalg.slogdet(A)
    return x_logp - log_abs_jac_det


@node_rewriter(tracks=[_matrix_matrix_matmul])
def find_measurable_matmul(fgraph, node):
    """Find measurable matrix-matrix multiplication operations.

    Replaces a matmul node by a :class:`MeasurableMatMul` when exactly one
    operand is measurable and the other is a (statically) square,
    non-measurable matrix.
    """
    if isinstance(node.op, MeasurableOp):
        # Node was already rewritten into its measurable form.
        return None

    [out] = node.outputs
    [l, r] = node.inputs  # noqa: E741

    # Exactly one of l and r must be measurable.
    measurable_inputs = filter_measurable_variables([l, r])
    if len(measurable_inputs) != 1:
        return None
    [measurable_input] = measurable_inputs

    # The measurable input must not be broadcasted along batch dimensions.
    if measurable_input.type.broadcastable[:-2] != out.type.broadcastable[:-2]:
        return None

    measurable_right = measurable_input is r
    a_matrix = l if measurable_right else r

    # If the static shape already reveals a non-square matrix, it cannot be
    # inverted for the change of variables.
    nrow, ncol = a_matrix.type.shape[-2:]
    if ncol is not None and nrow is not None and ncol != nrow:
        return None

    # The other input must not itself be potentially measurable.
    if check_potential_measurability([a_matrix]):
        return None

    measurable_matmul = MeasurableMatMul(measurable_right=measurable_right, **node.op._props_dict())
    return [measurable_matmul(l, r)]


# Register the rewrite in the measurable IR database so it is applied
# when deriving logprob graphs.
measurable_ir_rewrites_db.register(
    find_measurable_matmul.__name__,
    find_measurable_matmul,
    "basic",
    "linalg",
)
2 changes: 1 addition & 1 deletion pymc/logprob/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ def split_valued_ifelse(fgraph, node):
# Single outputs IfElse
return None

valued_output_nodes = get_related_valued_nodes(node, fgraph)
valued_output_nodes = get_related_valued_nodes(fgraph, node)
if not valued_output_nodes:
return None

Expand Down
12 changes: 10 additions & 2 deletions pymc/logprob/rewriting.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,11 @@ def remove_DiracDelta(fgraph, node):
logprob_rewrites_db.register(
"local_exp_over_1_plus_exp", out2in(local_exp_over_1_plus_exp), "basic"
)
logprob_rewrites_db.register("pre-canonicalize", optdb.query("+canonicalize"), "basic")
logprob_rewrites_db.register(
"pre-canonicalize",
optdb.query("+canonicalize", "-local_eager_useless_unbatched_blockwise"),
"basic",
)

# These rewrites convert un-measurable variables into their measurable forms,
# but they need to be reapplied, because some of the measurable forms require
Expand All @@ -175,7 +179,11 @@ def remove_DiracDelta(fgraph, node):
)


logprob_rewrites_db.register("post-canonicalize", optdb.query("+canonicalize"), "basic")
logprob_rewrites_db.register(
"post-canonicalize",
optdb.query("+canonicalize", "-local_eager_useless_unbatched_blockwise"),
"basic",
)

# Rewrites that remove IR Ops
cleanup_ir_rewrites_db = LocalGroupDB()
Expand Down
4 changes: 2 additions & 2 deletions pymc/logprob/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ def find_measurable_scans(fgraph, node):
# Find outputs of scan that are directly valued.
# These must be mapping outputs, such as `outputs_info = [None]` (i.e, no recurrence nit_sot outputs)
direct_valued_outputs = [
valued_node.inputs[0] for valued_node in get_related_valued_nodes(node, fgraph)
valued_node.inputs[0] for valued_node in get_related_valued_nodes(fgraph, node)
]
if not all(valued_out in scan_args.outer_out_nit_sot for valued_out in direct_valued_outputs):
return None
Expand All @@ -434,7 +434,7 @@ def find_measurable_scans(fgraph, node):
client.outputs[0]
for out in node.outputs
for client, _ in fgraph.clients[out]
if (isinstance(client.op, Subtensor) and get_related_valued_nodes(client, fgraph))
if (isinstance(client.op, Subtensor) and get_related_valued_nodes(fgraph, client))
]
indirect_valued_outputs = [out.owner.inputs[0] for out in sliced_valued_outputs]
if not all(
Expand Down
77 changes: 62 additions & 15 deletions pymc/logprob/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,19 @@
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor import TensorVariable
from pytensor.tensor.basic import Join, MakeVector
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.rewriting import (
local_dimshuffle_rv_lift,
)

from pymc.logprob.abstract import MeasurableOp, _logprob, _logprob_helper, promised_valued_rv
from pymc.logprob.abstract import (
MeasurableOp,
ValuedRV,
_logprob,
_logprob_helper,
promised_valued_rv,
)
from pymc.logprob.rewriting import (
assume_valued_outputs,
early_measurable_ir_rewrites_db,
Expand All @@ -57,6 +63,7 @@
from pymc.logprob.utils import (
check_potential_measurability,
filter_measurable_variables,
get_related_valued_nodes,
replace_rvs_by_values,
)
from pymc.pytensorf import constant_fold
Expand Down Expand Up @@ -183,6 +190,9 @@ class MeasurableDimShuffle(MeasurableOp, DimShuffle):
# find it locally and fails when a new `Op` is initialized
c_func_file = str(DimShuffle.get_path(Path(DimShuffle.c_func_file))) # type: ignore[arg-type]

    def __str__(self):
        """Return the base DimShuffle repr prefixed with "Measurable"."""
        return f"Measurable{super().__str__()}"


@_logprob.register(MeasurableDimShuffle)
def logprob_dimshuffle(op: MeasurableDimShuffle, values, base_var, **kwargs):
Expand Down Expand Up @@ -215,29 +225,66 @@ def logprob_dimshuffle(op: MeasurableDimShuffle, values, base_var, **kwargs):
return raw_logp.dimshuffle(redo_ds)


def _elemwise_univariate_chain(fgraph, node) -> bool:
    """Check whether only Elemwise operations connect a base univariate RV to the valued node.

    Walks backwards from the node's input through `MeasurableTransform`s to find
    a base RV, and forwards from the node's output through single-output
    `Elemwise` clients to a `ValuedRV`. Returns True only when both walks
    succeed (or the multivariate root is directly valued).
    """
    # Deferred imports avoid a circular dependency with pymc.distributions.
    from pymc.distributions.distribution import SymbolicRandomVariable
    from pymc.logprob.transforms import MeasurableTransform

    [inp] = node.inputs
    [out] = node.outputs

    def elemwise_root(var: TensorVariable) -> TensorVariable | None:
        # Backward walk: follow the measurable input of each MeasurableTransform
        # until a base RV is reached; anything else breaks the chain.
        if isinstance(var.owner.op, RandomVariable | SymbolicRandomVariable):
            return var
        elif isinstance(var.owner.op, MeasurableTransform):
            return elemwise_root(var.owner.inputs[var.owner.op.measurable_input_idx])
        else:
            return None

    # Check that the root is a univariate distribution linked by only elemwise operations
    root = elemwise_root(inp)
    if root is None:
        return False
    elif root.owner.op.ndim_supp != 0:
        # This is still fine if the variable is directly valued
        return any(get_related_valued_nodes(fgraph, node))

    def elemwise_leaf(var: TensorVariable, clients=fgraph.clients) -> bool:
        # Forward walk: the variable must have a single client that is either
        # the ValuedRV itself or a single-output Elemwise leading to one.
        var_clients = clients[var]
        if len(var_clients) != 1:
            return False
        [(client, _)] = var_clients
        if isinstance(client.op, ValuedRV):
            return True
        elif isinstance(client.op, Elemwise) and len(client.outputs) == 1:
            return elemwise_leaf(client.outputs[0])
        else:
            return False

    # Check that the path to the valued node consists only of elemwise operations
    return elemwise_leaf(out)


@node_rewriter([DimShuffle])
def find_measurable_dimshuffles(fgraph, node) -> list[TensorVariable] | None:
r"""Find `Dimshuffle`\s for which a `logprob` can be computed."""
from pymc.distributions.distribution import SymbolicRandomVariable

if isinstance(node.op, MeasurableOp):
return None

if not filter_measurable_variables(node.inputs):
return None

base_var = node.inputs[0]
# In cases where DimShuffle transposes dimensions, we only apply this rewrite when only Elemwise
# operations separate it from the valued node. Further transformations likely need to know where
# the support axes are for a correct implementation (and thus assume they are the rightmost axes).
# TODO: When we include the support axis as meta information in each intermediate MeasurableVariable,
# we can lift this restriction (see https://github.com/pymc-devs/pymc/issues/6360)
if tuple(node.op.shuffle) != tuple(sorted(node.op.shuffle)) and not _elemwise_univariate_chain(
fgraph, node
):
return None
Comment on lines +277 to +285
Copy link
Member Author

@ricardoV94 ricardoV94 Oct 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These DimShuffle changes were needed to naturally accommodate A @ x when x is a vector, which looks like:

import pytensor.tensor as pt

A = pt.matrix("A")
x = pt.vector("x")
y = A @ x
y.dprint()
# DropDims{axis=1} [id A]
#  └─ Blockwise{dot, (m,k),(k,n)->(m,n)} [id B]
#     ├─ A [id C]
#     └─ ExpandDims{axis=1} [id D]
#        └─ x [id E]

It's also more strict / correct than the limitation we had before, because the concerns are much more about what's after the DimShuffle not so much before.


# We can only apply this rewrite directly to `RandomVariable`s, as those are
# the only `Op`s for which we always know the support axis. Other measurable
# variables can have arbitrary support axes (e.g., if they contain separate
# `MeasurableDimShuffle`s). Most measurable variables with `DimShuffle`s
# should still be supported as long as the `DimShuffle`s can be merged/
# lifted towards the base RandomVariable.
# TODO: If we include the support axis as meta information in each
# intermediate MeasurableVariable, we can lift this restriction.
if not isinstance(base_var.owner.op, RandomVariable | SymbolicRandomVariable):
return None # pragma: no cover
base_var = node.inputs[0]

measurable_dimshuffle = MeasurableDimShuffle(node.op.input_broadcastable, node.op.new_order)(
base_var
Expand Down
2 changes: 1 addition & 1 deletion pymc/logprob/transform_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def transform_values(fgraph: FunctionGraph, node: Apply) -> list[Apply] | None:
return None

rv_node = node.inputs[0].owner
valued_nodes = get_related_valued_nodes(rv_node, fgraph)
valued_nodes = get_related_valued_nodes(fgraph, rv_node)
rvs = [valued_var.inputs[0] for valued_var in valued_nodes]
values = [valued_var.inputs[1] for valued_var in valued_nodes]
transforms = [values_to_transforms.get(value, None) for value in values]
Expand Down
2 changes: 1 addition & 1 deletion pymc/logprob/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ def find_negated_var(var):
return None


def get_related_valued_nodes(node: Apply, fgraph: FunctionGraph) -> list[Apply]:
def get_related_valued_nodes(fgraph: FunctionGraph, node: Apply) -> list[Apply]:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a much more natural argument order, matching all sorts of pytensor utilities that take a node/variable together with its fgraph.

"""Get all ValuedVars related to the same RV node.

Returns
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ ignore = [
"D101", # Missing docstring in public class
"D102", # Missing docstring in public method
"D103", # Missing docstring in public function
"D105", # Missing docstring in magic method
]

[tool.ruff.lint.pydocstyle]
Expand Down
Loading
Loading