Skip to content

Commit

Permalink
Add OpFromGraph wrapper around alloc_diag (pymc-devs#915)
Browse files Browse the repository at this point in the history
* Add `OpFromGraph` wrapper around `alloc_diag`

* Remove depreciated `AllocDiag` `Op`, rename `AllocDiag2 -> AllocDiag`

* Set `inline = False`

* Add rewrite to inline all `OpFromGraph` `Op`s

* Add `is_zero_offset` helper to `Eye`

* Add `is_left_expand_dims` and `is_right_expand_dims` attributes to `DimShuffle`

* Seed `test_local_lift_through_linalg` test
  • Loading branch information
jessegrabowski authored and Ian Schweer committed Aug 15, 2024
1 parent 79232b2 commit 143ded6
Show file tree
Hide file tree
Showing 10 changed files with 278 additions and 173 deletions.
32 changes: 2 additions & 30 deletions pytensor/compile/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

from pytensor.compile.function import function
from pytensor.compile.function.pfunc import rebuild_collect_shared
from pytensor.compile.mode import optdb
from pytensor.compile.sharedvalue import SharedVariable
from pytensor.configdefaults import config
from pytensor.gradient import DisconnectedType, Rop, grad
Expand All @@ -24,7 +23,6 @@
from pytensor.graph.null_type import NullType
from pytensor.graph.op import HasInnerGraph, Op
from pytensor.graph.replace import clone_replace
from pytensor.graph.rewriting.basic import in2out, node_rewriter
from pytensor.graph.utils import MissingInputError


Expand Down Expand Up @@ -575,7 +573,7 @@ def lop_overrides(inps, grads):
for inp_grad in input_grads
if not isinstance(inp_grad.type, DisconnectedType | NullType)
]
lop_op = type(self)(
lop_op = OpFromGraph(
inputs=inner_inputs + connected_inner_outputs + connected_output_grads,
outputs=connected_input_grads,
inline=self.is_inline,
Expand Down Expand Up @@ -669,7 +667,7 @@ def _build_and_cache_rop_op(self):
for out_grad in output_grads
if not isinstance(out_grad.type, DisconnectedType | NullType)
]
rop_op = type(self)(
rop_op = OpFromGraph(
inputs=inner_inputs + eval_points,
outputs=filtered_output_grads,
inline=self.is_inline,
Expand Down Expand Up @@ -852,29 +850,3 @@ def perform(self, node, inputs, outputs):
assert len(variables) == len(outputs)
for output, variable in zip(outputs, variables):
output[0] = variable


@node_rewriter([OpFromGraph])
def inline_ofg_expansion(fgraph, node):
"""
This optimization expands internal graph of OpFromGraph.
Only performed if node.op.is_inline == True
Doing so can improve optimization at the cost of compilation speed.
"""
op = node.op
if not isinstance(op, OpFromGraph):
return False
if not op.is_inline:
return False
return clone_replace(op.inner_outputs, dict(zip(op.inner_inputs, node.inputs)))


# We want to run this before the first merge optimizer
# and before the first scan optimizer.
optdb.register(
"inline_ofg_expansion",
in2out(inline_ofg_expansion),
"fast_compile",
"fast_run",
position=-0.01,
)
24 changes: 24 additions & 0 deletions pytensor/link/jax/dispatch/basic.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import warnings
from collections.abc import Callable
from functools import singledispatch

import jax
import jax.numpy as jnp
import numpy as np

from pytensor.compile import JAX
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.ops import DeepCopyOp, ViewOp
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
Expand Down Expand Up @@ -114,3 +117,24 @@ def viewop(x):
return x

return viewop


@jax_funcify.register(OpFromGraph)
def jax_funcify_OpFromGraph(ofg: OpFromGraph, node=None, **kwargs) -> Callable:
_ = kwargs.pop("storage_map", None)

# Apply inner rewrites
JAX.optimizer(ofg.fgraph)
fgraph_fn = jax_funcify(ofg.fgraph, **kwargs)

if len(ofg.fgraph.outputs) == 1:

def opfromgraph(*inputs):
return fgraph_fn(*inputs)[0]

else:

def opfromgraph(*inputs):
return fgraph_fn(*inputs)

return opfromgraph
141 changes: 46 additions & 95 deletions pytensor/tensor/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import pytensor.scalar.sharedvar
from pytensor import compile, config, printing
from pytensor import scalar as ps
from pytensor.compile.builders import OpFromGraph
from pytensor.gradient import DisconnectedType, grad_undefined
from pytensor.graph import RewriteDatabaseQuery
from pytensor.graph.basic import Apply, Constant, Variable, equal_computations
Expand Down Expand Up @@ -1334,6 +1335,25 @@ def infer_shape(self, fgraph, node, in_shapes):
def grad(self, inp, grads):
return [grad_undefined(self, i, inp[i]) for i in range(3)]

@staticmethod
def is_offset_zero(node) -> bool:
"""
Test if an Eye Op has a diagonal offset of zero
Parameters
----------
node
Eye node to test
Returns
-------
is_offset_zero: bool
True if the offset is zero (``k = 0``).
"""

offset = node.inputs[-1]
return isinstance(offset, Constant) and offset.data.item() == 0


def eye(n, m=None, k=0, dtype=None):
"""Return a 2-D array with ones on the diagonal and zeros elsewhere.
Expand Down Expand Up @@ -3749,109 +3769,37 @@ def trace(a, offset=0, axis1=0, axis2=1):
return diagonal(a, offset=offset, axis1=axis1, axis2=axis2).sum(-1)


class AllocDiag(Op):
"""An `Op` that copies a vector to the diagonal of a zero-ed matrix."""
class AllocDiag(OpFromGraph):
"""
Wrapper Op for alloc_diag graphs
"""

__props__ = ("offset", "axis1", "axis2")
__props__ = ("axis1", "axis2")

def __init__(self, offset=0, axis1=0, axis2=1):
"""
Parameters
----------
offset: int
Offset of the diagonal from the main diagonal defined by `axis1`
and `axis2`. Can be positive or negative. Defaults to main
diagonal (i.e. 0).
axis1: int
Axis to be used as the first axis of the 2-D sub-arrays to which
the diagonals will be allocated. Defaults to first axis (i.e. 0).
axis2: int
Axis to be used as the second axis of the 2-D sub-arrays to which
the diagonals will be allocated. Defaults to second axis (i.e. 1).
"""
warnings.warn(
"AllocDiag is deprecated. Use `alloc_diag` instead",
FutureWarning,
)
self.offset = offset
if axis1 < 0 or axis2 < 0:
raise NotImplementedError("AllocDiag does not support negative axis")
if axis1 == axis2:
raise ValueError("axis1 and axis2 cannot be the same")
def __init__(self, *args, axis1, axis2, offset, **kwargs):
self.axis1 = axis1
self.axis2 = axis2
self.offset = offset

def make_node(self, diag):
diag = as_tensor_variable(diag)
if diag.type.ndim < 1:
raise ValueError(
"AllocDiag needs an input with 1 or more dimensions", diag.type
)
return Apply(
self,
[diag],
[diag.type.clone(shape=(None,) * (diag.ndim + 1))()],
)

def perform(self, node, inputs, outputs):
(x,) = inputs
(z,) = outputs

axis1 = np.minimum(self.axis1, self.axis2)
axis2 = np.maximum(self.axis1, self.axis2)
offset = self.offset

# Create array with one extra dimension for resulting matrix
result_shape = x.shape[:-1] + (x.shape[-1] + abs(offset),) * 2
result = np.zeros(result_shape, dtype=x.dtype)

# Create slice for diagonal in final 2 axes
idxs = np.arange(x.shape[-1])
diagonal_slice = (len(result_shape) - 2) * [slice(None)] + [
idxs + np.maximum(0, -offset),
idxs + np.maximum(0, offset),
]

# Fill in final 2 axes with x
result[tuple(diagonal_slice)] = x

if len(x.shape) > 1:
# Re-order axes so they correspond to diagonals at axis1, axis2
axes = list(range(len(x.shape[:-1])))
last_idx = axes[-1]
axes = axes[:axis1] + [last_idx + 1] + axes[axis1:]
axes = axes[:axis2] + [last_idx + 2] + axes[axis2:]
result = result.transpose(axes)

z[0] = result

def grad(self, inputs, gout):
(gz,) = gout
return [diagonal(gz, offset=self.offset, axis1=self.axis1, axis2=self.axis2)]

def infer_shape(self, fgraph, nodes, shapes):
(x_shape,) = shapes
axis1 = np.minimum(self.axis1, self.axis2)
axis2 = np.maximum(self.axis1, self.axis2)
super().__init__(*args, **kwargs, strict=True)

result_shape = list(x_shape[:-1])
diag_shape = x_shape[-1] + abs(self.offset)
result_shape = result_shape[:axis1] + [diag_shape] + result_shape[axis1:]
result_shape = result_shape[:axis2] + [diag_shape] + result_shape[axis2:]
return [tuple(result_shape)]
@staticmethod
def is_offset_zero(node) -> bool:
"""
Test if an AllocDiag Op has a diagonal offset of zero
def __setstate__(self, state):
if "view_map" in state:
del state["view_map"]
Parameters
----------
node
AllocDiag node to test
self.__dict__.update(state)
Returns
-------
is_offset_zero: bool
True if the offset is zero (``k = 0``).
"""

if "offset" not in state:
self.offset = 0
if "axis1" not in state:
self.axis1 = 0
if "axis2" not in state:
self.axis2 = 1
return node.op.offset == 0


def alloc_diag(diag, offset=0, axis1=0, axis2=1):
Expand All @@ -3862,6 +3810,7 @@ def alloc_diag(diag, offset=0, axis1=0, axis2=1):
from pytensor.tensor import set_subtensor

diag = as_tensor_variable(diag)

axis1, axis2 = normalize_axis_tuple((axis1, axis2), ndim=diag.type.ndim + 1)
if axis1 > axis2:
axis1, axis2 = axis2, axis1
Expand All @@ -3888,7 +3837,9 @@ def alloc_diag(diag, offset=0, axis1=0, axis2=1):
axes = axes[:axis2] + [last_idx + 2] + axes[axis2:]
result = result.transpose(axes)

return result
return AllocDiag(
inputs=[diag], outputs=[result], axis1=axis1, axis2=axis2, offset=offset
)(diag)


def diag(v, k=0):
Expand Down
8 changes: 8 additions & 0 deletions pytensor/tensor/elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,14 @@ def __init__(self, input_broadcastable, new_order):
self.augment = sorted(i for i, x in enumerate(new_order) if x == "x")
self.drop = drop

input_ndim = len(input_broadcastable)
self.is_left_expand_dims = self.augment and (
input_ndim == 0 or new_order[-input_ndim:] == list(range(input_ndim))
)
self.is_right_expand_dims = self.augment and new_order[:input_ndim] == list(
range(input_ndim)
)

if self.inplace:
self.view_map = {0: [0]}

Expand Down
1 change: 1 addition & 0 deletions pytensor/tensor/rewriting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import pytensor.tensor.rewriting.jax
import pytensor.tensor.rewriting.linalg
import pytensor.tensor.rewriting.math
import pytensor.tensor.rewriting.ofg
import pytensor.tensor.rewriting.shape
import pytensor.tensor.rewriting.special
import pytensor.tensor.rewriting.subtensor
Expand Down
Loading

0 comments on commit 143ded6

Please sign in to comment.