Add OpFromGraph wrapper around alloc_diag #915

Merged: 33 commits, Jul 18, 2024
Changes from 11 commits
Commits
6857bea
Add `OpFromGraph` wrapper around `alloc_diag`
jessegrabowski Jul 10, 2024
5604d9a
Refactor `rewrite_det_diag_to_prod_diag` to use `AllocDiag2`
jessegrabowski Jul 10, 2024
b0abe17
Save arguments passed to `alloc_diag` as properties in `AllocDiag2`
jessegrabowski Jul 10, 2024
dbfe92c
Fix bug in `rewrite_det_diag_to_prod_diag` where batch case was incor…
jessegrabowski Jul 10, 2024
afe2a65
Remove `AllocDiag2` from graphs in the `specialization` phase, after …
jessegrabowski Jul 10, 2024
e810df0
Remove debugging code, formatting
jessegrabowski Jul 10, 2024
a5355b6
Remove deprecated `AllocDiag` `Op`, rename `AllocDiag2 -> AllocDiag`
jessegrabowski Jul 12, 2024
6e37d26
Correctly register `eagerly_inline_alloc_diag` as a JAX-only rewrite
jessegrabowski Jul 12, 2024
f6f27ec
Use `self` (not `type(self)`) in `OpFromGraph.make_node`
jessegrabowski Jul 12, 2024
9486c64
Use base class `OpFromGraph` when constructing `OpFromGraph` gradients
jessegrabowski Jul 12, 2024
abcde3f
Revert "Use `self` (not `type(self)`) in `OpFromGraph.make_node`"
jessegrabowski Jul 12, 2024
66013fc
Solve XY problem
jessegrabowski Jul 12, 2024
e7583d1
Appease mypy
jessegrabowski Jul 12, 2024
70b9cd6
Remove `inline` prop from wrapper class and set `inline=True`
jessegrabowski Jul 12, 2024
9605a0e
Set `inline = False`
jessegrabowski Jul 12, 2024
46fbc55
Add rewrite to inline all `OpFromGraph` `Op`s
jessegrabowski Jul 16, 2024
98cf641
Allow symbolic `offset`
jessegrabowski Jul 16, 2024
c76b54e
Exclude inline rewrite in JAX mode
jessegrabowski Jul 16, 2024
7b44e22
refactor `late_inline_ofg` rewrite to actually perform the correct re…
jessegrabowski Jul 16, 2024
1dfc3fc
Narrow scope of `late_inline` rewrite
jessegrabowski Jul 16, 2024
9f32661
Fix tests
jessegrabowski Jul 17, 2024
f30a63f
Remove `is_inline` prop
jessegrabowski Jul 17, 2024
bf705f9
Add JAX `OpFromGraph` test
jessegrabowski Jul 17, 2024
c8958a4
Don't omit `inline_ofg` rewrites in JAX mode
jessegrabowski Jul 17, 2024
665c766
Don't inline `KroneckerProduct`
jessegrabowski Jul 17, 2024
6487e61
Skip inline rewrite tests when `mode == FAST_COMPILE`
jessegrabowski Jul 17, 2024
43474fd
Incorporate review feedback
jessegrabowski Jul 17, 2024
6ef4084
Incorporate review feedback
jessegrabowski Jul 17, 2024
04ddb46
Add `is_zero_offset` helper to `Eye`
jessegrabowski Jul 17, 2024
c038109
Add `is_left_expand_dims` and `is_right_expand_dims` attributes to `D…
jessegrabowski Jul 18, 2024
19f2895
Seed `test_local_lift_through_linalg` test
jessegrabowski Jul 18, 2024
fcbccde
Fix failing diag_rewrite test
jessegrabowski Jul 18, 2024
56a3ffe
Revert symbolic offset
jessegrabowski Jul 18, 2024
4 changes: 2 additions & 2 deletions pytensor/compile/builders.py
@@ -575,7 +575,7 @@ def lop_overrides(inps, grads):
for inp_grad in input_grads
if not isinstance(inp_grad.type, DisconnectedType | NullType)
]
lop_op = type(self)(
lop_op = OpFromGraph(
inputs=inner_inputs + connected_inner_outputs + connected_output_grads,
outputs=connected_input_grads,
inline=self.is_inline,
@@ -669,7 +669,7 @@ def _build_and_cache_rop_op(self):
for out_grad in output_grads
if not isinstance(out_grad.type, DisconnectedType | NullType)
]
rop_op = type(self)(
rop_op = OpFromGraph(
inputs=inner_inputs + eval_points,
outputs=filtered_output_grads,
inline=self.is_inline,
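This change matters once `OpFromGraph` is subclassed with required extra constructor arguments, as the new `AllocDiag` wrapper below is. A rough sketch of the failure mode the base-class constructor avoids; `MyWrapper` is a hypothetical toy class, not part of the library:

import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph


class MyWrapper(OpFromGraph):
    """A toy OpFromGraph subclass with a required extra kwarg, like AllocDiag."""

    def __init__(self, *args, offset, **kwargs):
        self.offset = offset
        super().__init__(*args, **kwargs)


x = pt.vector("x")
wrapped = MyWrapper(inputs=[x], outputs=[2 * x], offset=1)

# The gradient machinery only knows the generic OpFromGraph arguments, so
# re-instantiating the subclass via type(self)(...) fails, while the base
# class constructor used in the diff above always works.
try:
    type(wrapped)(inputs=[x], outputs=[2 * x], inline=False)  # missing `offset`
except TypeError as err:
    print(f"type(self)(...) raised: {err}")

lop_like = OpFromGraph(inputs=[x], outputs=[2 * x], inline=False)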
42 changes: 42 additions & 0 deletions pytensor/link/jax/dispatch/tensor_basic.py
@@ -3,11 +3,17 @@
import jax.numpy as jnp
import numpy as np

import pytensor
from pytensor.compile import optdb
from pytensor.graph import node_rewriter
from pytensor.graph.basic import Constant
from pytensor.graph.rewriting.basic import in2out
from pytensor.graph.rewriting.db import SequenceDB
from pytensor.link.jax.dispatch.basic import jax_funcify
from pytensor.tensor import get_vector_length
from pytensor.tensor.basic import (
Alloc,
AllocDiag,
AllocEmpty,
ARange,
ExtractDiag,
@@ -205,3 +211,39 @@ def tri(*args):
return jnp.tri(*args, dtype=op.dtype)

return tri


@node_rewriter([AllocDiag])
def eagerly_inline_alloc_diag(fgraph, node):
"""
Inline `AllocDiag` OpFromGraph into the graph so the component Ops can themselves be jaxified

Parameters
----------
fgraph: FunctionGraph
The function graph being rewritten
node: Apply
Node of the function graph to be optimized

Returns
-------
list of Variable
The inlined inner-graph output that replaces the `AllocDiag` node
"""
[input] = node.inputs
[output] = node.op.inner_outputs
inner_input = output.owner.inputs[1]

inline = pytensor.clone_replace(output, {inner_input: input})

return [inline]


remove_alloc_ofg_opt = SequenceDB()
remove_alloc_ofg_opt.register(
"inline_alloc_diag",
in2out(eagerly_inline_alloc_diag),
"jax",
)

# Do this right away so other JAX rewrites can act on the inner graph
optdb.register("jax_inline_alloc_diag", remove_alloc_ofg_opt, "jax", position=0)
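In practice this means a graph built with `alloc_diag` compiles under JAX without needing a JAX dispatch for the `AllocDiag` wrapper itself: the rewrite replaces the wrapper with its inner graph before other JAX rewrites run. A minimal sketch, assuming a JAX-enabled install; the comment about the compiled graph is illustrative, not a guaranteed printout:

import pytensor
import pytensor.tensor as pt
from pytensor.tensor.basic import alloc_diag

x = pt.vector("x")
y = alloc_diag(x)  # graph contains the AllocDiag OpFromGraph wrapper

# Compiling in JAX mode triggers eagerly_inline_alloc_diag (registered at
# position 0 of optdb above), leaving only the wrapper's component Ops,
# each of which has its own JAX dispatch.
fn = pytensor.function([x], y, mode="JAX")
pytensor.dprint(fn)  # no AllocDiag node should remain in the compiled graph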
111 changes: 13 additions & 98 deletions pytensor/tensor/basic.py
@@ -21,6 +21,7 @@
import pytensor.scalar.sharedvar
from pytensor import compile, config, printing
from pytensor import scalar as ps
from pytensor.compile.builders import OpFromGraph
from pytensor.gradient import DisconnectedType, grad_undefined
from pytensor.graph import RewriteDatabaseQuery
from pytensor.graph.basic import Apply, Constant, Variable, equal_computations
@@ -3726,109 +3727,21 @@
return diagonal(a, offset=offset, axis1=axis1, axis2=axis2).sum(-1)


class AllocDiag(Op):
"""An `Op` that copies a vector to the diagonal of a zero-ed matrix."""
class AllocDiag(OpFromGraph):
"""
Wrapper Op for alloc_diag graphs
"""

__props__ = ("offset", "axis1", "axis2")
__props__ = ("offset", "axis1", "axis2", "inline")

def __init__(self, offset=0, axis1=0, axis2=1):
"""
Parameters
----------
offset: int
Offset of the diagonal from the main diagonal defined by `axis1`
and `axis2`. Can be positive or negative. Defaults to main
diagonal (i.e. 0).
axis1: int
Axis to be used as the first axis of the 2-D sub-arrays to which
the diagonals will be allocated. Defaults to first axis (i.e. 0).
axis2: int
Axis to be used as the second axis of the 2-D sub-arrays to which
the diagonals will be allocated. Defaults to second axis (i.e. 1).
"""
warnings.warn(
"AllocDiag is deprecated. Use `alloc_diag` instead",
FutureWarning,
)
def __init__(self, *args, offset, axis1, axis2, **kwargs):
inline = kwargs.pop("inline", False)
self.offset = offset
if axis1 < 0 or axis2 < 0:
raise NotImplementedError("AllocDiag does not support negative axis")
if axis1 == axis2:
raise ValueError("axis1 and axis2 cannot be the same")
self.axis1 = axis1
self.axis2 = axis2
self.inline = inline

def make_node(self, diag):
diag = as_tensor_variable(diag)
if diag.type.ndim < 1:
raise ValueError(
"AllocDiag needs an input with 1 or more dimensions", diag.type
)
return Apply(
self,
[diag],
[diag.type.clone(shape=(None,) * (diag.ndim + 1))()],
)

def perform(self, node, inputs, outputs):
(x,) = inputs
(z,) = outputs

axis1 = np.minimum(self.axis1, self.axis2)
axis2 = np.maximum(self.axis1, self.axis2)
offset = self.offset

# Create array with one extra dimension for resulting matrix
result_shape = x.shape[:-1] + (x.shape[-1] + abs(offset),) * 2
result = np.zeros(result_shape, dtype=x.dtype)

# Create slice for diagonal in final 2 axes
idxs = np.arange(x.shape[-1])
diagonal_slice = (len(result_shape) - 2) * [slice(None)] + [
idxs + np.maximum(0, -offset),
idxs + np.maximum(0, offset),
]

# Fill in final 2 axes with x
result[tuple(diagonal_slice)] = x

if len(x.shape) > 1:
# Re-order axes so they correspond to diagonals at axis1, axis2
axes = list(range(len(x.shape[:-1])))
last_idx = axes[-1]
axes = axes[:axis1] + [last_idx + 1] + axes[axis1:]
axes = axes[:axis2] + [last_idx + 2] + axes[axis2:]
result = result.transpose(axes)

z[0] = result

def grad(self, inputs, gout):
(gz,) = gout
return [diagonal(gz, offset=self.offset, axis1=self.axis1, axis2=self.axis2)]

def infer_shape(self, fgraph, nodes, shapes):
(x_shape,) = shapes
axis1 = np.minimum(self.axis1, self.axis2)
axis2 = np.maximum(self.axis1, self.axis2)

result_shape = list(x_shape[:-1])
diag_shape = x_shape[-1] + abs(self.offset)
result_shape = result_shape[:axis1] + [diag_shape] + result_shape[axis1:]
result_shape = result_shape[:axis2] + [diag_shape] + result_shape[axis2:]
return [tuple(result_shape)]

def __setstate__(self, state):
if "view_map" in state:
del state["view_map"]

self.__dict__.update(state)

if "offset" not in state:
self.offset = 0
if "axis1" not in state:
self.axis1 = 0
if "axis2" not in state:
self.axis2 = 1
super().__init__(*args, **kwargs, strict=True, inline=inline)


def alloc_diag(diag, offset=0, axis1=0, axis2=1):
@@ -3865,7 +3778,9 @@
axes = axes[:axis2] + [last_idx + 2] + axes[axis2:]
result = result.transpose(axes)

return result
return AllocDiag(
inputs=[diag], outputs=[result], offset=offset, axis1=axis1, axis2=axis2
)(diag)


def diag(v, k=0):
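With this rewrite of `basic.py`, `alloc_diag` builds the same inner graph as before but returns it wrapped in the `AllocDiag` `OpFromGraph`, whose constructor arguments are stored as properties for later rewrites to inspect. A small usage sketch, assuming this PR's branch:

import pytensor.tensor as pt
from pytensor.tensor.basic import AllocDiag, alloc_diag

x = pt.vector("x")
m = alloc_diag(x, offset=1)

# The output is produced by the AllocDiag OpFromGraph wrapper, and the
# arguments passed to alloc_diag are saved on the Op.
assert isinstance(m.owner.op, AllocDiag)
assert m.owner.op.offset == 1
assert (m.owner.op.axis1, m.owner.op.axis2) == (0, 1)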
113 changes: 69 additions & 44 deletions pytensor/tensor/rewriting/linalg.py
@@ -5,12 +5,16 @@
from pytensor import Variable
from pytensor.graph import Apply, FunctionGraph
from pytensor.graph.rewriting.basic import (
PatternNodeRewriter,
copy_stack_trace,
node_rewriter,
)
from pytensor.scalar.basic import Mul
from pytensor.tensor.basic import ARange, Eye, TensorVariable, alloc, diagonal
from pytensor.tensor.basic import (
AllocDiag,
Eye,
TensorVariable,
diagonal,
)
from pytensor.tensor.blas import Dot22
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle, Elemwise
@@ -41,7 +45,6 @@
solve,
solve_triangular,
)
from pytensor.tensor.subtensor import advanced_set_subtensor


logger = logging.getLogger(__name__)
@@ -401,30 +404,59 @@
eye_input = [
mul_input
for mul_input in inputs_to_mul
if mul_input.owner and isinstance(mul_input.owner.op, Eye)
if mul_input.owner
and (
isinstance(mul_input.owner.op, Eye)
or
# This whole condition checks if there is an Eye hiding inside a DimShuffle.
# This arises from batched elementwise multiplication between a tensor and an eye, e.g.:
# tensor(shape=(None, 3, 3) * eye(3). This is still potentially valid for diag rewrites.
(
isinstance(mul_input.owner.op, DimShuffle)
and mul_input.owner.inputs[0].owner is not None
and isinstance(mul_input.owner.inputs[0].owner.op, Eye)
)
)
]

# Check if 1's are being put on the main diagonal only (k = 0)
if eye_input and getattr(eye_input[0].owner.inputs[-1], "data", -1).item() != 0:
if not eye_input:
return None

# If the broadcast pattern of eye_input is not (False, False), we do not get a diagonal matrix and thus, dont need to apply the rewrite
if eye_input and eye_input[0].broadcastable[-2:] != (False, False):
eye_input = eye_input[0]

# If this multiplication came from a batched operation, it will be wrapped in a DimShuffle
if isinstance(eye_input.owner.op, DimShuffle):
inner_eye = eye_input.owner.inputs[0]
if not isinstance(inner_eye.owner.op, Eye):
return None
# Check if 1's are being put on the main diagonal only (k = 0)
# and if the identity matrix is degenerate (column or row matrix)
if getattr(
inner_eye.owner.inputs[-1], "data", -1
).item() != 0 or inner_eye.broadcastable[-2:] != (False, False):
jessegrabowski marked this conversation as resolved.
Show resolved Hide resolved
return None

elif getattr(
eye_input.owner.inputs[-1], "data", -1
).item() != 0 or eye_input.broadcastable[-2:] != (False, False):
return None

# Get all non Eye inputs (scalars/matrices/vectors)
non_eye_inputs = list(set(inputs_to_mul) - set(eye_input))
non_eye_inputs = list(set(inputs_to_mul) - {eye_input})
return eye_input, non_eye_inputs


@register_canonicalize("shape_unsafe")
@register_stabilize("shape_unsafe")
@node_rewriter([det])
def rewrite_det_diag_from_eye_mul(fgraph, node):
def rewrite_det_diag_to_prod_diag(fgraph, node):
"""
This rewrite takes advantage of the fact that for a diagonal matrix, the determinant value is the product of its diagonal elements.
This rewrite takes advantage of the fact that for a diagonal matrix, the determinant value is the product of its
diagonal elements.

The presence of a diagonal matrix is detected by inspecting the graph. This rewrite can identify diagonal matrices that arise as the result of elementwise multiplication with an identity matrix. Specialized computation is used to make this rewrite as efficient as possible, depending on whether the multiplication was with a scalar, vector or a matrix.
The presence of a diagonal matrix is detected by inspecting the graph. This rewrite can identify diagonal matrices
that arise as the result of elementwise multiplication with an identity matrix. Specialized computation is used to
make this rewrite as efficient as possible, depending on whether the multiplication was with a scalar,
vector or a matrix.

Parameters
----------
@@ -438,53 +470,46 @@
list of Variable, optional
List of optimized variables, or None if no optimization was performed
"""
potential_mul_input = node.inputs[0]
eye_non_eye_inputs = _find_diag_from_eye_mul(potential_mul_input)
if eye_non_eye_inputs is None:
inputs = node.inputs[0]

# Check for use of pt.diag first
if (
inputs.owner
and isinstance(inputs.owner.op, AllocDiag)
and inputs.owner.op.offset == 0
):
diag_input = inputs.owner.inputs[0]
det_val = diag_input.prod(axis=-1)
return [det_val]

# Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix
inputs_or_none = _find_diag_from_eye_mul(inputs)

if inputs_or_none is None:
return None
eye_input, non_eye_inputs = eye_non_eye_inputs

eye_input, non_eye_inputs = inputs_or_none

# Dealing with only one other input
if len(non_eye_inputs) != 1:
return None

useful_eye, useful_non_eye = eye_input[0], non_eye_inputs[0]
eye_input, non_eye_input = eye_input[0], non_eye_inputs[0]

# Checking if original x was scalar/vector/matrix
if useful_non_eye.type.broadcastable[-2:] == (True, True):
if non_eye_input.type.broadcastable[-2:] == (True, True):
# For scalar
det_val = useful_non_eye.squeeze(axis=(-1, -2)) ** (useful_eye.shape[0])
elif useful_non_eye.type.broadcastable[-2:] == (False, False):
det_val = non_eye_input.squeeze(axis=(-1, -2)) ** (eye_input.shape[0])
elif non_eye_input.type.broadcastable[-2:] == (False, False):
# For Matrix
det_val = useful_non_eye.diagonal(axis1=-1, axis2=-2).prod(axis=-1)
det_val = non_eye_input.diagonal(axis1=-1, axis2=-2).prod(axis=-1)
else:
# For vector
det_val = useful_non_eye.prod(axis=(-1, -2))
det_val = non_eye_input.prod(axis=(-1, -2))
det_val = det_val.astype(node.outputs[0].type.dtype)
return [det_val]


arange = ARange("int64")
det_diag_from_diag = PatternNodeRewriter(
(
det,
(
advanced_set_subtensor,
(alloc, 0, "sh1", "sh2"),
"x",
(arange, 0, "stop", 1),
(arange, 0, "stop", 1),
),
),
(prod, "x"),
name="det_diag_from_diag",
allow_multiple_clients=True,
)
register_canonicalize(det_diag_from_diag)
register_stabilize(det_diag_from_diag)
register_specialize(det_diag_from_diag)


@register_canonicalize
@register_stabilize
@register_specialize
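The rewrite above covers both routes to a diagonal matrix: an explicit `pt.diag` call, now detectable through the `AllocDiag` wrapper and its `offset` property, and elementwise multiplication with an identity matrix, including the batched case where the `Eye` sits behind a `DimShuffle`. A hedged sketch of both cases; the printed values are the mathematical expectation, not a claim about which rewrite fires, and `pt.linalg.det` is assumed to be the usual re-export of `nlinalg.det`:

import numpy as np
import pytensor
import pytensor.tensor as pt

v = pt.vector("v")
x = pt.matrix("x")

# Case 1: determinant of an explicit diagonal matrix -> product of the vector
d1 = pt.linalg.det(pt.diag(v))

# Case 2: diagonal matrix formed by elementwise multiplication with an identity
d2 = pt.linalg.det(pt.eye(3) * x)

fn = pytensor.function([v, x], [d1, d2])
v_val = np.array([1.0, 2.0, 3.0])
print(fn(v_val, np.diag(v_val)))  # both values should equal 6.0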