Add OpFromGraph wrapper around alloc_diag #915

Merged
merged 33 commits into from Jul 18, 2024
Changes from 29 commits
Commits
33 commits
6857bea
Add `OpFromGraph` wrapper around `alloc_diag`
jessegrabowski Jul 10, 2024
5604d9a
Refactor `rewrite_det_diag_to_prod_diag` to use `AllocDiag2`
jessegrabowski Jul 10, 2024
b0abe17
Save arguments passed to `alloc_diag` as properties in `AllocDiag2`
jessegrabowski Jul 10, 2024
dbfe92c
Fix bug in `rewrite_det_diag_to_prod_diag` where batch case was incor…
jessegrabowski Jul 10, 2024
afe2a65
Remove `AllocDiag2` from graphs in the `specialization` phase, after …
jessegrabowski Jul 10, 2024
e810df0
Remove debugging code, formatting
jessegrabowski Jul 10, 2024
a5355b6
Remove deprecated `AllocDiag` `Op`, rename `AllocDiag2 -> AllocDiag`
jessegrabowski Jul 12, 2024
6e37d26
Correctly register `eagerly_inline_alloc_diag` as a JAX-only rewrite
jessegrabowski Jul 12, 2024
f6f27ec
Use `self` (not `type(self)`) in `OpFromGraph.make_node`
jessegrabowski Jul 12, 2024
9486c64
Use base class `OpFromGraph` when constructing `OpFromGraph` gradients
jessegrabowski Jul 12, 2024
abcde3f
Revert "Use `self` (not `type(self)`) in `OpFromGraph.make_node`"
jessegrabowski Jul 12, 2024
66013fc
Solve XY problem
jessegrabowski Jul 12, 2024
e7583d1
Appease mypy
jessegrabowski Jul 12, 2024
70b9cd6
Remove `inline` prop from wrapper class and set `inline=True`
jessegrabowski Jul 12, 2024
9605a0e
Set `inline = False`
jessegrabowski Jul 12, 2024
46fbc55
Add rewrite to inline all `OpFromGraph` `Op`s
jessegrabowski Jul 16, 2024
98cf641
Allow symbolic `offset`
jessegrabowski Jul 16, 2024
c76b54e
Exclude inline rewrite in JAX mode
jessegrabowski Jul 16, 2024
7b44e22
refactor `late_inline_ofg` rewrite to actually perform the correct re…
jessegrabowski Jul 16, 2024
1dfc3fc
Narrow scope of `late_inline` rewrite
jessegrabowski Jul 16, 2024
9f32661
Fix tests
jessegrabowski Jul 17, 2024
f30a63f
Remove `is_inline` prop
jessegrabowski Jul 17, 2024
bf705f9
Add JAX `OpFromGraph` test
jessegrabowski Jul 17, 2024
c8958a4
Don't omit `inline_ofg` rewrites in JAX mode
jessegrabowski Jul 17, 2024
665c766
Don't inline `KroneckerProduct`
jessegrabowski Jul 17, 2024
6487e61
Skip inline rewrite tests when `mode == FAST_COMPILE`
jessegrabowski Jul 17, 2024
43474fd
Incorporate review feedback
jessegrabowski Jul 17, 2024
6ef4084
Incorporate review feedback
jessegrabowski Jul 17, 2024
04ddb46
Add `is_zero_offset` helper to `Eye`
jessegrabowski Jul 17, 2024
c038109
Add `is_left_expand_dims` and `is_right_expand_dims` attributes to `D…
jessegrabowski Jul 18, 2024
19f2895
Seed `test_local_lift_through_linalg` test
jessegrabowski Jul 18, 2024
fcbccde
Fix failing diag_rewrite test
jessegrabowski Jul 18, 2024
56a3ffe
Revert symbolic offset
jessegrabowski Jul 18, 2024
32 changes: 2 additions & 30 deletions pytensor/compile/builders.py
@@ -8,7 +8,6 @@

from pytensor.compile.function import function
from pytensor.compile.function.pfunc import rebuild_collect_shared
from pytensor.compile.mode import optdb
from pytensor.compile.sharedvalue import SharedVariable
from pytensor.configdefaults import config
from pytensor.gradient import DisconnectedType, Rop, grad
@@ -24,7 +23,6 @@
from pytensor.graph.null_type import NullType
from pytensor.graph.op import HasInnerGraph, Op
from pytensor.graph.replace import clone_replace
from pytensor.graph.rewriting.basic import in2out, node_rewriter
from pytensor.graph.utils import MissingInputError


@@ -575,7 +573,7 @@ def lop_overrides(inps, grads):
for inp_grad in input_grads
if not isinstance(inp_grad.type, DisconnectedType | NullType)
]
lop_op = type(self)(
lop_op = OpFromGraph(
inputs=inner_inputs + connected_inner_outputs + connected_output_grads,
outputs=connected_input_grads,
inline=self.is_inline,
@@ -669,7 +667,7 @@ def _build_and_cache_rop_op(self):
for out_grad in output_grads
if not isinstance(out_grad.type, DisconnectedType | NullType)
]
rop_op = type(self)(
rop_op = OpFromGraph(
inputs=inner_inputs + eval_points,
outputs=filtered_output_grads,
inline=self.is_inline,
@@ -852,29 +850,3 @@ def perform(self, node, inputs, outputs):
assert len(variables) == len(outputs)
for output, variable in zip(outputs, variables):
output[0] = variable


@node_rewriter([OpFromGraph])
def inline_ofg_expansion(fgraph, node):
"""
This optimization expands internal graph of OpFromGraph.
Only performed if node.op.is_inline == True
Doing so can improve optimization at the cost of compilation speed.
"""
op = node.op
if not isinstance(op, OpFromGraph):
return False
if not op.is_inline:
return False
return clone_replace(op.inner_outputs, dict(zip(op.inner_inputs, node.inputs)))


# We want to run this before the first merge optimizer
# and before the first scan optimizer.
optdb.register(
"inline_ofg_expansion",
in2out(inline_ofg_expansion),
"fast_compile",
"fast_run",
position=-0.01,
)
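
For context, a minimal sketch (not part of this diff) of the behaviour the removed rewrite provided: an `OpFromGraph` built with `inline=True` was expanded into its inner graph by `inline_ofg_expansion` during compilation. Per the commits above, that responsibility moves to the rewrites registered in `pytensor.tensor.rewriting.ofg`. The toy graph below is illustrative only.

import pytensor
import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph

x = pt.vector("x")
ofg = OpFromGraph([x], [pt.exp(x) + 1], inline=True)

y = pt.vector("y")
# Previously, inline_ofg_expansion replaced the OpFromGraph node in this
# function's graph with the inner `exp(y) + 1` graph at compile time.
f = pytensor.function([y], ofg(y))
print(f([0.0, 1.0]))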
24 changes: 24 additions & 0 deletions pytensor/link/jax/dispatch/basic.py
@@ -1,10 +1,13 @@
import warnings
from collections.abc import Callable
from functools import singledispatch

import jax
import jax.numpy as jnp
import numpy as np

from pytensor.compile import JAX
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.ops import DeepCopyOp, ViewOp
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
@@ -114,3 +117,24 @@ def viewop(x):
return x

return viewop


@jax_funcify.register(OpFromGraph)
def jax_funcify_OpFromGraph(ofg: OpFromGraph, node=None, **kwargs) -> Callable:
_ = kwargs.pop("storage_map", None)

# Apply inner rewrites
JAX.optimizer(ofg.fgraph)
fgraph_fn = jax_funcify(ofg.fgraph, **kwargs)

if len(ofg.fgraph.outputs) == 1:

def opfromgraph(*inputs):
return fgraph_fn(*inputs)[0]

else:

def opfromgraph(*inputs):
return fgraph_fn(*inputs)

return opfromgraph
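
For context, a minimal sketch (not part of this diff) of how the new dispatch is exercised; the toy graph is illustrative, but the dispatch path matches the code above. Running it assumes JAX is installed.

import pytensor
import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph

x = pt.vector("x")
ofg = OpFromGraph([x], [pt.exp(x).sum()])

y = pt.vector("y")
# Compiling in JAX mode routes the OpFromGraph node through
# jax_funcify_OpFromGraph, which first applies the JAX rewrites to the
# wrapped inner fgraph and then converts it with jax_funcify.
f = pytensor.function([y], ofg(y), mode="JAX")
print(f([0.0, 1.0]))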
150 changes: 53 additions & 97 deletions pytensor/tensor/basic.py
@@ -21,6 +21,7 @@
import pytensor.scalar.sharedvar
from pytensor import compile, config, printing
from pytensor import scalar as ps
from pytensor.compile.builders import OpFromGraph
from pytensor.gradient import DisconnectedType, grad_undefined
from pytensor.graph import RewriteDatabaseQuery
from pytensor.graph.basic import Apply, Constant, Variable, equal_computations
@@ -1333,6 +1334,25 @@ def infer_shape(self, fgraph, node, in_shapes):
def grad(self, inp, grads):
return [grad_undefined(self, i, inp[i]) for i in range(3)]

@staticmethod
def is_offset_zero(node) -> bool:
"""
Test if an Eye Op has a diagonal offset of zero

Parameters
----------
node
Eye node to test

Returns
-------
is_offset_zero: bool
True if the offset is zero (``k = 0``).
"""

offset = node.inputs[-1]
return isinstance(offset, Constant) and offset.data.item() == 0


def eye(n, m=None, k=0, dtype=None):
"""Return a 2-D array with ones on the diagonal and zeros elsewhere.
@@ -3726,109 +3746,37 @@ def trace(a, offset=0, axis1=0, axis2=1):
return diagonal(a, offset=offset, axis1=axis1, axis2=axis2).sum(-1)


class AllocDiag(Op):
"""An `Op` that copies a vector to the diagonal of a zero-ed matrix."""
class AllocDiag(OpFromGraph):
"""
Wrapper Op for alloc_diag graphs
"""

__props__ = ("offset", "axis1", "axis2")
__props__ = ("axis1", "axis2")

def __init__(self, offset=0, axis1=0, axis2=1):
"""
Parameters
----------
offset: int
Offset of the diagonal from the main diagonal defined by `axis1`
and `axis2`. Can be positive or negative. Defaults to main
diagonal (i.e. 0).
axis1: int
Axis to be used as the first axis of the 2-D sub-arrays to which
the diagonals will be allocated. Defaults to first axis (i.e. 0).
axis2: int
Axis to be used as the second axis of the 2-D sub-arrays to which
the diagonals will be allocated. Defaults to second axis (i.e. 1).
"""
warnings.warn(
"AllocDiag is deprecated. Use `alloc_diag` instead",
FutureWarning,
)
self.offset = offset
if axis1 < 0 or axis2 < 0:
raise NotImplementedError("AllocDiag does not support negative axis")
if axis1 == axis2:
raise ValueError("axis1 and axis2 cannot be the same")
def __init__(self, *args, axis1, axis2, **kwargs):
self.axis1 = axis1
self.axis2 = axis2

def make_node(self, diag):
diag = as_tensor_variable(diag)
if diag.type.ndim < 1:
raise ValueError(
"AllocDiag needs an input with 1 or more dimensions", diag.type
)
return Apply(
self,
[diag],
[diag.type.clone(shape=(None,) * (diag.ndim + 1))()],
)

def perform(self, node, inputs, outputs):
(x,) = inputs
(z,) = outputs

axis1 = np.minimum(self.axis1, self.axis2)
axis2 = np.maximum(self.axis1, self.axis2)
offset = self.offset

# Create array with one extra dimension for resulting matrix
result_shape = x.shape[:-1] + (x.shape[-1] + abs(offset),) * 2
result = np.zeros(result_shape, dtype=x.dtype)

# Create slice for diagonal in final 2 axes
idxs = np.arange(x.shape[-1])
diagonal_slice = (len(result_shape) - 2) * [slice(None)] + [
idxs + np.maximum(0, -offset),
idxs + np.maximum(0, offset),
]

# Fill in final 2 axes with x
result[tuple(diagonal_slice)] = x
super().__init__(*args, **kwargs, strict=True)

if len(x.shape) > 1:
# Re-order axes so they correspond to diagonals at axis1, axis2
axes = list(range(len(x.shape[:-1])))
last_idx = axes[-1]
axes = axes[:axis1] + [last_idx + 1] + axes[axis1:]
axes = axes[:axis2] + [last_idx + 2] + axes[axis2:]
result = result.transpose(axes)

z[0] = result

def grad(self, inputs, gout):
(gz,) = gout
return [diagonal(gz, offset=self.offset, axis1=self.axis1, axis2=self.axis2)]

def infer_shape(self, fgraph, nodes, shapes):
(x_shape,) = shapes
axis1 = np.minimum(self.axis1, self.axis2)
axis2 = np.maximum(self.axis1, self.axis2)

result_shape = list(x_shape[:-1])
diag_shape = x_shape[-1] + abs(self.offset)
result_shape = result_shape[:axis1] + [diag_shape] + result_shape[axis1:]
result_shape = result_shape[:axis2] + [diag_shape] + result_shape[axis2:]
return [tuple(result_shape)]
@staticmethod
def is_offset_zero(node) -> bool:
"""
Test if an AllocDiag Op has a diagonal offset of zero

def __setstate__(self, state):
if "view_map" in state:
del state["view_map"]
Parameters
----------
node
AllocDiag node to test

self.__dict__.update(state)
Returns
-------
is_offset_zero: bool
True if the offset is zero (``k = 0``).
"""

if "offset" not in state:
self.offset = 0
if "axis1" not in state:
self.axis1 = 0
if "axis2" not in state:
self.axis2 = 1
offset = node.inputs[-1]
return isinstance(offset, Constant) and offset.data.item() == 0


def alloc_diag(diag, offset=0, axis1=0, axis2=1):
@@ -3837,8 +3785,11 @@ def alloc_diag(diag, offset=0, axis1=0, axis2=1):
diagonal(alloc_diag(x)) == x
"""
from pytensor.tensor import set_subtensor
from pytensor.tensor.math import maximum

diag = as_tensor_variable(diag)
offset = as_tensor_variable(offset)

axis1, axis2 = normalize_axis_tuple((axis1, axis2), ndim=diag.type.ndim + 1)
if axis1 > axis2:
axis1, axis2 = axis2, axis1
@@ -3850,8 +3801,8 @@ def alloc_diag(diag, offset=0, axis1=0, axis2=1):
# Create slice for diagonal in final 2 axes
idxs = arange(diag.shape[-1])
diagonal_slice = (slice(None),) * (len(result_shape) - 2) + (
idxs + np.maximum(0, -offset),
idxs + np.maximum(0, offset),
idxs + maximum(0, -offset),
idxs + maximum(0, offset),
)

# Fill in final 2 axes with diag
@@ -3865,7 +3816,12 @@ def alloc_diag(diag, offset=0, axis1=0, axis2=1):
axes = axes[:axis2] + [last_idx + 2] + axes[axis2:]
result = result.transpose(axes)

return result
return AllocDiag(
inputs=[diag, offset],
outputs=[result],
axis1=axis1,
axis2=axis2,
)(diag, offset)


def diag(v, k=0):
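
For context, a minimal sketch (not part of this diff) of the user-facing result, assuming the symbolic-offset interface shown in this revision (a later commit reverts the symbolic offset):

import pytensor.tensor as pt
from pytensor.tensor.basic import AllocDiag, alloc_diag

v = pt.vector("v")
m = alloc_diag(v)  # m.owner.op is now an AllocDiag (an OpFromGraph subclass)

# The helper inspects the node's offset input and reports whether it is a
# constant zero, which downstream rewrites can use to simplify graphs.
print(isinstance(m.owner.op, AllocDiag))  # True
print(AllocDiag.is_offset_zero(m.owner))  # True for the default offset of 0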
10 changes: 10 additions & 0 deletions pytensor/tensor/elemwise.py
@@ -181,6 +181,16 @@ def __init__(self, input_broadcastable, new_order):
self.augment = sorted([i for i, x in enumerate(new_order) if x == "x"])
self.drop = drop

input_ndim = len(input_broadcastable)

self.is_left_expand_dims = "x" in new_order and (
input_ndim == 0 or new_order[:input_ndim] == list(range(input_ndim))
)

self.is_right_expand_dims = "x" in new_order and (
input_ndim == 0 or new_order[-input_ndim:] == list(range(input_ndim))
)

if self.inplace:
self.view_map = {0: [0]}

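
For context, a minimal sketch (not part of this diff) of how the new `DimShuffle` attributes can be inspected; which flag is set for a given pattern follows the definitions above.

import pytensor.tensor as pt

x = pt.vector("x")
row = x.dimshuffle("x", 0)  # expand_dims on the left:  shape (1, n)
col = x.dimshuffle(0, "x")  # expand_dims on the right: shape (n, 1)

for var in (row, col):
    op = var.owner.op  # a DimShuffle instance
    print(op.new_order, op.is_left_expand_dims, op.is_right_expand_dims)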
1 change: 1 addition & 0 deletions pytensor/tensor/rewriting/__init__.py
@@ -10,6 +10,7 @@
import pytensor.tensor.rewriting.jax
import pytensor.tensor.rewriting.linalg
import pytensor.tensor.rewriting.math
import pytensor.tensor.rewriting.ofg
import pytensor.tensor.rewriting.shape
import pytensor.tensor.rewriting.special
import pytensor.tensor.rewriting.subtensor
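
For context, a minimal sketch (not part of this diff): importing `pytensor.tensor.rewriting.ofg` registers the new inlining rewrites, so per the commits above an `AllocDiag` wrapper introduced by `alloc_diag` should be inlined late in compilation rather than left as an `OpFromGraph` node in the compiled graph.

import pytensor
import pytensor.tensor as pt
from pytensor.tensor.basic import alloc_diag

v = pt.vector("v")
f = pytensor.function([v], alloc_diag(v))
# Inspect the compiled graph; with the ofg rewrites enabled the wrapper
# should have been replaced by its inner zeros/set_subtensor graph.
pytensor.dprint(f)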