From 6857bea6f9aea71696e5f9d3591bfdae395fe513 Mon Sep 17 00:00:00 2001
From: Jesse Grabowski
Date: Wed, 10 Jul 2024 20:11:14 +0800
Subject: [PATCH 01/33] Add `OpFromGraph` wrapper around `alloc_diag`

---
 pytensor/tensor/basic.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py
index 414e3b6ed2..7269e0421b 100644
--- a/pytensor/tensor/basic.py
+++ b/pytensor/tensor/basic.py
@@ -21,6 +21,7 @@
 import pytensor.scalar.sharedvar
 from pytensor import compile, config, printing
 from pytensor import scalar as ps
+from pytensor.compile.builders import OpFromGraph
 from pytensor.gradient import DisconnectedType, grad_undefined
 from pytensor.graph import RewriteDatabaseQuery
 from pytensor.graph.basic import Apply, Constant, Variable, equal_computations
@@ -3831,6 +3832,12 @@ def __setstate__(self, state):
             self.axis2 = 1
 
 
+class AllocDiag2(OpFromGraph):
+    """
+    Wrapper Op for alloc_diag graphs
+    """
+
+
 def alloc_diag(diag, offset=0, axis1=0, axis2=1):
     """Insert a vector on the diagonal of a zero-ed matrix.
 
@@ -3865,7 +3872,7 @@ def alloc_diag(diag, offset=0, axis1=0, axis2=1):
         axes = axes[:axis2] + [last_idx + 2] + axes[axis2:]
         result = result.transpose(axes)
 
-    return result
+    return AllocDiag2(inputs=[diag], outputs=[result])(diag)
 
 
 def diag(v, k=0):

From 5604d9a9738be4226a80d28a348ddfa7a91faae7 Mon Sep 17 00:00:00 2001
From: Jesse Grabowski
Date: Wed, 10 Jul 2024 20:58:34 +0800
Subject: [PATCH 02/33] Refactor `rewrite_det_diag_to_prod_diag` to use
 `AllocDiag2`

---
 pytensor/tensor/rewriting/linalg.py | 68 +++++++++++++----------------
 1 file changed, 31 insertions(+), 37 deletions(-)

diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py
index 3c98834c94..2bf62ca985 100644
--- a/pytensor/tensor/rewriting/linalg.py
+++ b/pytensor/tensor/rewriting/linalg.py
@@ -5,12 +5,16 @@
 from pytensor import Variable
 from pytensor.graph import Apply, FunctionGraph
 from pytensor.graph.rewriting.basic import (
-    PatternNodeRewriter,
     copy_stack_trace,
     node_rewriter,
 )
 from pytensor.scalar.basic import Mul
-from pytensor.tensor.basic import ARange, Eye, TensorVariable, alloc, diagonal
+from pytensor.tensor.basic import (
+    AllocDiag2,
+    Eye,
+    TensorVariable,
+    diagonal,
+)
 from pytensor.tensor.blas import Dot22
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
@@ -41,7 +45,6 @@
     solve,
     solve_triangular,
 )
-from pytensor.tensor.subtensor import advanced_set_subtensor
 
 
 logger = logging.getLogger(__name__)
@@ -420,11 +423,15 @@ def _find_diag_from_eye_mul(potential_mul_input):
 @register_canonicalize("shape_unsafe")
 @register_stabilize("shape_unsafe")
 @node_rewriter([det])
-def rewrite_det_diag_from_eye_mul(fgraph, node):
+def rewrite_det_diag_to_prod_diag(fgraph, node):
     """
-    This rewrite takes advantage of the fact that for a diagonal matrix, the determinant value is the product of its diagonal elements.
+    This rewrite takes advantage of the fact that for a diagonal matrix, the determinant value is the product of its
+    diagonal elements.
 
-    The presence of a diagonal matrix is detected by inspecting the graph. This rewrite can identify diagonal matrices that arise as the result of elementwise multiplication with an identity matrix. Specialized computation is used to make this rewrite as efficient as possible, depending on whether the multiplication was with a scalar, vector or a matrix.
+    The presence of a diagonal matrix is detected by inspecting the graph. This rewrite can identify diagonal matrices
+    that arise as the result of elementwise multiplication with an identity matrix. Specialized computation is used to
+    make this rewrite as efficient as possible, depending on whether the multiplication was with a scalar,
+    vector or a matrix.
 
     Parameters
     ----------
@@ -438,53 +445,40 @@ def rewrite_det_diag_to_prod_diag(fgraph, node):
     list of Variable, optional
         List of optimized variables, or None if no optimization was performed
     """
-    potential_mul_input = node.inputs[0]
-    eye_non_eye_inputs = _find_diag_from_eye_mul(potential_mul_input)
-    if eye_non_eye_inputs is None:
+    inputs = node.inputs[0]
+
+    # Check for use of pt.diag first
+    if inputs.owner and isinstance(inputs.owner.op, AllocDiag2):
+        diag_input = inputs.owner.inputs[0]
+        det_val = diag_input.prod(axis=-1)
+        return [det_val]
+
+    # Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix
+    inputs_or_none = _find_diag_from_eye_mul(inputs)
+    if inputs_or_none is None:
         return None
-    eye_input, non_eye_inputs = eye_non_eye_inputs
+
+    eye_input, non_eye_inputs = inputs_or_none
 
     # Dealing with only one other input
     if len(non_eye_inputs) != 1:
         return None
 
-    useful_eye, useful_non_eye = eye_input[0], non_eye_inputs[0]
+    eye_input, non_eye_input = eye_input[0], non_eye_inputs[0]
 
     # Checking if original x was scalar/vector/matrix
-    if useful_non_eye.type.broadcastable[-2:] == (True, True):
+    if non_eye_input.type.broadcastable[-2:] == (True, True):
         # For scalar
-        det_val = useful_non_eye.squeeze(axis=(-1, -2)) ** (useful_eye.shape[0])
-    elif useful_non_eye.type.broadcastable[-2:] == (False, False):
+        det_val = non_eye_input.squeeze(axis=(-1, -2)) ** (eye_input.shape[0])
+    elif non_eye_input.type.broadcastable[-2:] == (False, False):
         # For Matrix
-        det_val = useful_non_eye.diagonal(axis1=-1, axis2=-2).prod(axis=-1)
+        det_val = non_eye_input.diagonal(axis1=-1, axis2=-2).prod(axis=-1)
     else:
         # For vector
-        det_val = useful_non_eye.prod(axis=(-1, -2))
+        det_val = non_eye_input.prod(axis=(-1, -2))
     det_val = det_val.astype(node.outputs[0].type.dtype)
     return [det_val]
 
 
-arange = ARange("int64")
-det_diag_from_diag = PatternNodeRewriter(
-    (
-        det,
-        (
-            advanced_set_subtensor,
-            (alloc, 0, "sh1", "sh2"),
-            "x",
-            (arange, 0, "stop", 1),
-            (arange, 0, "stop", 1),
-        ),
-    ),
-    (prod, "x"),
-    name="det_diag_from_diag",
-    allow_multiple_clients=True,
-)
-register_canonicalize(det_diag_from_diag)
-register_stabilize(det_diag_from_diag)
-register_specialize(det_diag_from_diag)
-
-
 @register_canonicalize
 @register_stabilize
 @register_specialize

From b0abe170ad8d12a8708a2b267d32f1bfcd437c91 Mon Sep 17 00:00:00 2001
From: Jesse Grabowski
Date: Thu, 11 Jul 2024 00:19:28 +0800
Subject: [PATCH 03/33] Save arguments passed to `alloc_diag` as properties in
 `AllocDiag2`

---
 pytensor/tensor/basic.py | 15 ++++++++++++++-
 1 file changed, 14 insertions(+), 1 deletion(-)

diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py
index 7269e0421b..c3bb825cff 100644
--- a/pytensor/tensor/basic.py
+++ b/pytensor/tensor/basic.py
@@ -3837,6 +3837,17 @@ class AllocDiag2(OpFromGraph):
     Wrapper Op for alloc_diag graphs
     """
 
+    __props__ = ("offset", "axis1", "axis2", "inline")
+
+    def __init__(self, *args, offset, axis1, axis2, **kwargs):
+        inline = kwargs.pop("inline", False)
+        self.offset = offset
+        self.axis1 = axis1
+        self.axis2 = axis2
+        self.inline = inline
+
+        super().__init__(*args, **kwargs, strict=True, inline=inline)
+
 
 def alloc_diag(diag, offset=0, axis1=0, axis2=1):
     """Insert a vector on the diagonal of a zero-ed matrix.
@@ -3872,7 +3883,9 @@ def alloc_diag(diag, offset=0, axis1=0, axis2=1):
         axes = axes[:axis2] + [last_idx + 2] + axes[axis2:]
         result = result.transpose(axes)
 
-    return AllocDiag2(inputs=[diag], outputs=[result])(diag)
+    return AllocDiag2(
+        inputs=[diag], outputs=[result], offset=offset, axis1=axis1, axis2=axis2
+    )(diag)
 
 
 def diag(v, k=0):

From dbfe92c04361f270aa6036558627f40fe1ac2189 Mon Sep 17 00:00:00 2001
From: Jesse Grabowski
Date: Thu, 11 Jul 2024 00:20:18 +0800
Subject: [PATCH 04/33] Fix bug in `rewrite_det_diag_to_prod_diag` where batch
 case was incorrectly passing

---
 pytensor/tensor/rewriting/linalg.py   | 48 ++++++++++++++++++++++-----
 tests/tensor/rewriting/test_linalg.py | 13 ++++++--
 2 files changed, 50 insertions(+), 11 deletions(-)

diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py
index 2bf62ca985..a242fe7612 100644
--- a/pytensor/tensor/rewriting/linalg.py
+++ b/pytensor/tensor/rewriting/linalg.py
@@ -404,19 +404,44 @@ def _find_diag_from_eye_mul(potential_mul_input):
     eye_input = [
         mul_input
         for mul_input in inputs_to_mul
-        if mul_input.owner and isinstance(mul_input.owner.op, Eye)
+        if mul_input.owner
+        and (
+            isinstance(mul_input.owner.op, Eye)
+            or
+            # This whole condition checks if there is an Eye hiding inside a DimShuffle.
+            # This arises from batched elementwise multiplication between a tensor and an eye, e.g.:
+            # tensor(shape=(None, 3, 3)) * eye(3). This is still potentially valid for diag rewrites.
+            (
+                isinstance(mul_input.owner.op, DimShuffle)
+                and mul_input.owner.inputs[0].owner is not None
+                and isinstance(mul_input.owner.inputs[0].owner.op, Eye)
+            )
+        )
     ]
-
-    # Check if 1's are being put on the main diagonal only (k = 0)
-    if eye_input and getattr(eye_input[0].owner.inputs[-1], "data", -1).item() != 0:
+    if not eye_input:
         return None
 
-    # If the broadcast pattern of eye_input is not (False, False), we do not get a diagonal matrix and thus, dont need to apply the rewrite
-    if eye_input and eye_input[0].broadcastable[-2:] != (False, False):
+    eye_input = eye_input[0]
+
+    # If this multiplication came from a batched operation, it will be wrapped in a DimShuffle
+    if isinstance(eye_input.owner.op, DimShuffle):
+        inner_eye = eye_input.owner.inputs[0]
+        if not isinstance(inner_eye.owner.op, Eye):
+            return None
+        # Check if 1's are being put on the main diagonal only (k = 0)
+        # and if the identity matrix is degenerate (column or row matrix)
+        if getattr(
+            inner_eye.owner.inputs[-1], "data", -1
+        ).item() != 0 or inner_eye.broadcastable[-2:] != (False, False):
+            return None
+
+    elif getattr(
+        eye_input.owner.inputs[-1], "data", -1
+    ).item() != 0 or eye_input.broadcastable[-2:] != (False, False):
         return None
 
     # Get all non Eye inputs (scalars/matrices/vectors)
-    non_eye_inputs = list(set(inputs_to_mul) - set(eye_input))
+    non_eye_inputs = list(set(inputs_to_mul) - {eye_input})
     return eye_input, non_eye_inputs
 
 
@@ -473,15 +498,22 @@ def rewrite_det_diag_to_prod_diag(fgraph, node):
     inputs = node.inputs[0]
 
     # Check for use of pt.diag first
-    if inputs.owner and isinstance(inputs.owner.op, AllocDiag2):
+    if (
+        inputs.owner
+        and isinstance(inputs.owner.op, AllocDiag2)
+        and inputs.owner.op.offset == 0
+    ):
         diag_input = inputs.owner.inputs[0]
+        diag_input.dprint()
         det_val = diag_input.prod(axis=-1)
         return [det_val]
 
     # Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix
     inputs_or_none = _find_diag_from_eye_mul(inputs)
+
     if inputs_or_none is None:
         return None
+
     eye_input, non_eye_inputs = inputs_or_none
 
     # Dealing with only one other input
diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py
index d59e3cc88f..9353b7c167 100644
--- a/tests/tensor/rewriting/test_linalg.py
+++ b/tests/tensor/rewriting/test_linalg.py
@@ -396,20 +396,26 @@ def test_local_lift_through_linalg(constructor, f_op, f, g_op, g):
 
 @pytest.mark.parametrize(
     "shape",
-    [(), (7,), (1, 7), (7, 1), (7, 7), (3, 7, 7)],
+    [(), (7,), (1, 7), (7, 1), (7, 7), pytest.param((3, 7, 7))],
     ids=["scalar", "vector", "row_vec", "col_vec", "matrix", "batched_input"],
 )
 def test_det_diag_from_eye_mul(shape):
     # Initializing x based on scalar/vector/matrix
     x = pt.tensor("x", shape=shape)
     y = pt.eye(7) * x
+    # Calculating determinant value using pt.linalg.det
     z_det = pt.linalg.det(y)
 
     # REWRITE TEST
-    f_rewritten = function([x], z_det, mode="FAST_RUN")
+    with pytensor.config.change_flags(optimizer_verbose=True):
+        f_rewritten = function([x], z_det, mode="FAST_RUN")
     nodes = f_rewritten.maker.fgraph.apply_nodes
-    assert not any(isinstance(node.op, Det) for node in nodes)
+
+    assert not any(
+        isinstance(node.op, Det) or isinstance(getattr(node.op, "core_op", None), Det)
+        for node in nodes
+    )
 
     # NUMERIC VALUE TEST
     if len(shape) == 0:
@@ -418,6 +424,7 @@ def test_det_diag_from_eye_mul(shape):
         x_test = np.random.rand(*shape).astype(config.floatX)
     else:
         x_test = np.random.rand(*shape).astype(config.floatX)
+
     x_test_matrix = np.eye(7) * x_test
     det_val = np.linalg.det(x_test_matrix)
     rewritten_val = f_rewritten(x_test)

From e810df0c8b9238168d72585289350f571dbfea53 Mon Sep 17 00:00:00 2001
From: Jesse Grabowski
Date: Thu, 11 Jul 2024 00:21:30 +0800
Subject: [PATCH 05/33] Remove `AllocDiag2` from graphs in the
 `specialization` phase, after other rewrites that need it have fired.
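
A minimal sketch of the intended ordering (using only names that appear in
this series):

    import pytensor
    import pytensor.tensor as pt

    x = pt.vector("x")
    # pt.diag builds an `AllocDiag2` wrapper around the alloc_diag graph
    z = pt.linalg.det(pt.diag(x))

    # `rewrite_det_diag_to_prod_diag` fires while the wrapper is still
    # visible; only afterwards is the wrapper inlined so the JAX backend
    # can compile the component Ops.
    f = pytensor.function([x], z, mode="JAX")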
---
 pytensor/link/jax/dispatch/tensor_basic.py | 29 ++++++++++++++++++++
 1 file changed, 29 insertions(+)

diff --git a/pytensor/link/jax/dispatch/tensor_basic.py b/pytensor/link/jax/dispatch/tensor_basic.py
index bf1a93ce5b..bcc884d223 100644
--- a/pytensor/link/jax/dispatch/tensor_basic.py
+++ b/pytensor/link/jax/dispatch/tensor_basic.py
@@ -3,11 +3,14 @@
 import jax.numpy as jnp
 import numpy as np
 
+import pytensor
+from pytensor.graph import node_rewriter
 from pytensor.graph.basic import Constant
 from pytensor.link.jax.dispatch.basic import jax_funcify
 from pytensor.tensor import get_vector_length
 from pytensor.tensor.basic import (
     Alloc,
+    AllocDiag2,
     AllocEmpty,
     ARange,
     ExtractDiag,
@@ -21,6 +24,7 @@
     get_underlying_scalar_constant_value,
 )
 from pytensor.tensor.exceptions import NotScalarConstantError
+from pytensor.tensor.rewriting.basic import register_specialize
 from pytensor.tensor.shape import Shape_i
 
 
@@ -205,3 +209,28 @@ def tri(*args):
         return jnp.tri(*args, dtype=op.dtype)
 
     return tri
+
+
+@register_specialize
+@node_rewriter([AllocDiag2])
+def eagerly_inline_alloc_diag(fgraph, node):
+    """
+    Inline `AllocDiag2` OpFromGraph into the graph so the component Ops can themselves be jaxified
+    Parameters
+    ----------
+    fgraph: FunctionGraph
+        The function graph being rewritten
+    node: Apply
+        Node of the function graph to be optimized
+
+    Returns
+    -------
+
+    """
+    [input] = node.inputs
+    [output] = node.op.inner_outputs
+    inner_input = output.owner.inputs[1]
+
+    inline = pytensor.clone_replace(output, {inner_input: input})
+
+    return [inline]

From 6e37d26aeed1cb13e24f41cc87bcdc71662ca694 Mon Sep 17 00:00:00 2001
From: Jesse Grabowski
Date: Thu, 11 Jul 2024 02:04:49 +0800
Subject: [PATCH 06/33] Remove debugging code, formatting

---
 pytensor/link/jax/dispatch/tensor_basic.py | 1 +
 pytensor/tensor/rewriting/linalg.py        | 1 -
 tests/tensor/rewriting/test_linalg.py      | 2 +-
 3 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/pytensor/link/jax/dispatch/tensor_basic.py b/pytensor/link/jax/dispatch/tensor_basic.py
index bcc884d223..26c552baf1 100644
--- a/pytensor/link/jax/dispatch/tensor_basic.py
+++ b/pytensor/link/jax/dispatch/tensor_basic.py
@@ -216,6 +216,7 @@ def eagerly_inline_alloc_diag(fgraph, node):
     """
     Inline `AllocDiag2` OpFromGraph into the graph so the component Ops can themselves be jaxified
+
     Parameters
     ----------
     fgraph: FunctionGraph
diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py
index a242fe7612..28c49a64ff 100644
--- a/pytensor/tensor/rewriting/linalg.py
+++ b/pytensor/tensor/rewriting/linalg.py
@@ -479,7 +479,6 @@ def rewrite_det_diag_to_prod_diag(fgraph, node):
         and inputs.owner.op.offset == 0
     ):
         diag_input = inputs.owner.inputs[0]
-        diag_input.dprint()
         det_val = diag_input.prod(axis=-1)
         return [det_val]
 
diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py
index 9353b7c167..eeaac53d82 100644
--- a/tests/tensor/rewriting/test_linalg.py
+++ b/tests/tensor/rewriting/test_linalg.py
@@ -396,7 +396,7 @@ def test_local_lift_through_linalg(constructor, f_op, f, g_op, g):
 
 @pytest.mark.parametrize(
     "shape",
-    [(), (7,), (1, 7), (7, 1), (7, 7), pytest.param((3, 7, 7))],
+    [(), (7,), (1, 7), (7, 1), (7, 7), (3, 7, 7)],
     ids=["scalar", "vector", "row_vec", "col_vec", "matrix", "batched_input"],
 )

From a5355b6a7adc8426da6e5e44ac4429bb1d0163c4 Mon Sep 17 00:00:00 2001
From: Jesse Grabowski
Date: Fri, 12 Jul 2024 12:46:34 +0800
Subject: [PATCH 07/33] Remove deprecated `AllocDiag` `Op`, rename
 `AllocDiag2 -> AllocDiag`

---
 pytensor/link/jax/dispatch/tensor_basic.py |   6 +-
 pytensor/tensor/basic.py                   | 109 +--------------------
 pytensor/tensor/rewriting/linalg.py        |   4 +-
 3 files changed, 7 insertions(+), 112 deletions(-)

diff --git a/pytensor/link/jax/dispatch/tensor_basic.py b/pytensor/link/jax/dispatch/tensor_basic.py
index 26c552baf1..9525136b7d 100644
--- a/pytensor/link/jax/dispatch/tensor_basic.py
+++ b/pytensor/link/jax/dispatch/tensor_basic.py
@@ -10,7 +10,7 @@
 from pytensor.tensor import get_vector_length
 from pytensor.tensor.basic import (
     Alloc,
-    AllocDiag2,
+    AllocDiag,
     AllocEmpty,
     ARange,
     ExtractDiag,
@@ -212,10 +212,10 @@ def tri(*args):
 
 
 @register_specialize
-@node_rewriter([AllocDiag2])
+@node_rewriter([AllocDiag])
 def eagerly_inline_alloc_diag(fgraph, node):
     """
-    Inline `AllocDiag2` OpFromGraph into the graph so the component Ops can themselves be jaxified
+    Inline `AllocDiag` OpFromGraph into the graph so the component Ops can themselves be jaxified
 
     Parameters
     ----------
diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py
index c3bb825cff..e67657fdf7 100644
--- a/pytensor/tensor/basic.py
+++ b/pytensor/tensor/basic.py
@@ -3727,112 +3727,7 @@ def trace(a, offset=0, axis1=0, axis2=1):
     return diagonal(a, offset=offset, axis1=axis1, axis2=axis2).sum(-1)
 
 
-class AllocDiag(Op):
-    """An `Op` that copies a vector to the diagonal of a zero-ed matrix."""
-
-    __props__ = ("offset", "axis1", "axis2")
-
-    def __init__(self, offset=0, axis1=0, axis2=1):
-        """
-        Parameters
-        ----------
-        offset: int
-            Offset of the diagonal from the main diagonal defined by `axis1`
-            and `axis2`. Can be positive or negative. Defaults to main
-            diagonal (i.e. 0).
-        axis1: int
-            Axis to be used as the first axis of the 2-D sub-arrays to which
-            the diagonals will be allocated. Defaults to first axis (i.e. 0).
-        axis2: int
-            Axis to be used as the second axis of the 2-D sub-arrays to which
-            the diagonals will be allocated. Defaults to second axis (i.e. 1).
-        """
-        warnings.warn(
-            "AllocDiag is deprecated. Use `alloc_diag` instead",
-            FutureWarning,
-        )
-        self.offset = offset
-        if axis1 < 0 or axis2 < 0:
-            raise NotImplementedError("AllocDiag does not support negative axis")
-        if axis1 == axis2:
-            raise ValueError("axis1 and axis2 cannot be the same")
-        self.axis1 = axis1
-        self.axis2 = axis2
-
-    def make_node(self, diag):
-        diag = as_tensor_variable(diag)
-        if diag.type.ndim < 1:
-            raise ValueError(
-                "AllocDiag needs an input with 1 or more dimensions", diag.type
-            )
-        return Apply(
-            self,
-            [diag],
-            [diag.type.clone(shape=(None,) * (diag.ndim + 1))()],
-        )
-
-    def perform(self, node, inputs, outputs):
-        (x,) = inputs
-        (z,) = outputs
-
-        axis1 = np.minimum(self.axis1, self.axis2)
-        axis2 = np.maximum(self.axis1, self.axis2)
-        offset = self.offset
-
-        # Create array with one extra dimension for resulting matrix
-        result_shape = x.shape[:-1] + (x.shape[-1] + abs(offset),) * 2
-        result = np.zeros(result_shape, dtype=x.dtype)
-
-        # Create slice for diagonal in final 2 axes
-        idxs = np.arange(x.shape[-1])
-        diagonal_slice = (len(result_shape) - 2) * [slice(None)] + [
-            idxs + np.maximum(0, -offset),
-            idxs + np.maximum(0, offset),
-        ]
-
-        # Fill in final 2 axes with x
-        result[tuple(diagonal_slice)] = x
-
-        if len(x.shape) > 1:
-            # Re-order axes so they correspond to diagonals at axis1, axis2
-            axes = list(range(len(x.shape[:-1])))
-            last_idx = axes[-1]
-            axes = axes[:axis1] + [last_idx + 1] + axes[axis1:]
-            axes = axes[:axis2] + [last_idx + 2] + axes[axis2:]
-            result = result.transpose(axes)
-
-        z[0] = result
-
-    def grad(self, inputs, gout):
-        (gz,) = gout
-        return [diagonal(gz, offset=self.offset, axis1=self.axis1, axis2=self.axis2)]
-
-    def infer_shape(self, fgraph, nodes, shapes):
-        (x_shape,) = shapes
-        axis1 = np.minimum(self.axis1, self.axis2)
-        axis2 = np.maximum(self.axis1, self.axis2)
-
-        result_shape = list(x_shape[:-1])
-        diag_shape = x_shape[-1] + abs(self.offset)
-        result_shape = result_shape[:axis1] + [diag_shape] + result_shape[axis1:]
-        result_shape = result_shape[:axis2] + [diag_shape] + result_shape[axis2:]
-        return [tuple(result_shape)]
-
-    def __setstate__(self, state):
-        if "view_map" in state:
-            del state["view_map"]
-
-        self.__dict__.update(state)
-
-        if "offset" not in state:
-            self.offset = 0
-        if "axis1" not in state:
-            self.axis1 = 0
-        if "axis2" not in state:
-            self.axis2 = 1
-
-
-class AllocDiag2(OpFromGraph):
+class AllocDiag(OpFromGraph):
     """
     Wrapper Op for alloc_diag graphs
     """
@@ -3883,7 +3778,7 @@ def alloc_diag(diag, offset=0, axis1=0, axis2=1):
         axes = axes[:axis2] + [last_idx + 2] + axes[axis2:]
         result = result.transpose(axes)
 
-    return AllocDiag2(
+    return AllocDiag(
         inputs=[diag], outputs=[result], offset=offset, axis1=axis1, axis2=axis2
     )(diag)
 
diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py
index 28c49a64ff..5be5b27d78 100644
--- a/pytensor/tensor/rewriting/linalg.py
+++ b/pytensor/tensor/rewriting/linalg.py
@@ -10,7 +10,7 @@
 )
 from pytensor.scalar.basic import Mul
 from pytensor.tensor.basic import (
-    AllocDiag2,
+    AllocDiag,
     Eye,
     TensorVariable,
     diagonal,
)
@@ -475,7 +475,7 @@ def rewrite_det_diag_to_prod_diag(fgraph, node):
     if (
         inputs.owner
-        and isinstance(inputs.owner.op, AllocDiag2)
+        and isinstance(inputs.owner.op, AllocDiag)
         and inputs.owner.op.offset == 0
     ):

From 6e37d26aeed1cb13e24f41cc87bcdc71662ca694 Mon Sep 17 00:00:00 2001
From: Jesse Grabowski
Date: Fri, 12 Jul 2024 12:51:31 +0800
Subject: [PATCH 08/33] Correctly register `eagerly_inline_alloc_diag` as a
 JAX-only rewrite

---
 pytensor/link/jax/dispatch/tensor_basic.py | 16 ++++++++++++++--
 1 file changed, 14 insertions(+), 2 deletions(-)

diff --git a/pytensor/link/jax/dispatch/tensor_basic.py b/pytensor/link/jax/dispatch/tensor_basic.py
index 9525136b7d..77de8bbf20 100644
--- a/pytensor/link/jax/dispatch/tensor_basic.py
+++ b/pytensor/link/jax/dispatch/tensor_basic.py
@@ -4,8 +4,11 @@
 import numpy as np
 
 import pytensor
+from pytensor.compile import optdb
 from pytensor.graph import node_rewriter
 from pytensor.graph.basic import Constant
+from pytensor.graph.rewriting.basic import in2out
+from pytensor.graph.rewriting.db import SequenceDB
 from pytensor.link.jax.dispatch.basic import jax_funcify
 from pytensor.tensor import get_vector_length
 from pytensor.tensor.basic import (
@@ -24,7 +27,6 @@
     get_underlying_scalar_constant_value,
 )
 from pytensor.tensor.exceptions import NotScalarConstantError
-from pytensor.tensor.rewriting.basic import register_specialize
 from pytensor.tensor.shape import Shape_i
 
 
@@ -211,7 +213,6 @@ def tri(*args):
     return tri
 
 
-@register_specialize
 @node_rewriter([AllocDiag])
 def eagerly_inline_alloc_diag(fgraph, node):
     """
@@ -235,3 +236,14 @@ def eagerly_inline_alloc_diag(fgraph, node):
     inline = pytensor.clone_replace(output, {inner_input: input})
 
     return [inline]
+
+
+remove_alloc_ofg_opt = SequenceDB()
+remove_alloc_ofg_opt.register(
+    "inline_alloc_diag",
+    in2out(eagerly_inline_alloc_diag),
+    "jax",
+)
+
+# Do this right away so other JAX rewrites can act on the inner graph
+optdb.register("jax_inline_alloc_diag", remove_alloc_ofg_opt, "jax", position=0)

From f6f27ecd243dd1a585a592a99eaa32d0ae68ae29 Mon Sep 17 00:00:00 2001
From: Jesse Grabowski
Date: Fri, 12 Jul 2024 13:17:00 +0800
Subject: [PATCH 09/33] Use `self` (not `type(self)`) in `OpFromGraph.make_node`

---
 pytensor/compile/builders.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytensor/compile/builders.py b/pytensor/compile/builders.py
index b535540425..1c8d2ca543 100644
--- a/pytensor/compile/builders.py
+++ b/pytensor/compile/builders.py
@@ -758,7 +758,7 @@ def make_node(self, *inputs):
         ]
 
         new_inner_inputs = self.inner_inputs[:num_expected_inps] + unshared_inputs
-        new_op = type(self)(
+        new_op = self(
             inputs=new_inner_inputs,
             outputs=new_inner_outputs,
             inline=self.is_inline,

From 9486c64f2af0a9a84c074e1d0c487857414bed12 Mon Sep 17 00:00:00 2001
From: Jesse Grabowski
Date: Fri, 12 Jul 2024 13:22:11 +0800
Subject: [PATCH 10/33] Use base class `OpFromGraph` when constructing
 `OpFromGraph` gradients

---
 pytensor/compile/builders.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pytensor/compile/builders.py b/pytensor/compile/builders.py
index 1c8d2ca543..b9f3f761be 100644
--- a/pytensor/compile/builders.py
+++ b/pytensor/compile/builders.py
@@ -575,7 +575,7 @@ def lop_overrides(inps, grads):
             for inp_grad in input_grads
             if not isinstance(inp_grad.type, DisconnectedType | NullType)
         ]
-        lop_op = type(self)(
+        lop_op = OpFromGraph(
             inputs=inner_inputs + connected_inner_outputs + connected_output_grads,
             outputs=connected_input_grads,
             inline=self.is_inline,
@@ -669,7 +669,7 @@ def _build_and_cache_rop_op(self):
             for out_grad in output_grads
             if not isinstance(out_grad.type, DisconnectedType | NullType)
         ]
-        rop_op = type(self)(
+        rop_op = OpFromGraph(
            inputs=inner_inputs + eval_points,
            outputs=filtered_output_grads,
            inline=self.is_inline,

From abcde3f4091a64068c822fe2128619dc7a1ee9a2 Mon Sep 17 00:00:00 2001
From: Jesse Grabowski
Date: Fri, 12 Jul 2024 13:39:25 +0800
Subject: [PATCH 11/33] Revert "Use `self` (not `type(self)`) in
 `OpFromGraph.make_node`"

This reverts commit f6f27ecd243dd1a585a592a99eaa32d0ae68ae29.

---
 pytensor/compile/builders.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytensor/compile/builders.py b/pytensor/compile/builders.py
index b9f3f761be..3221586df0 100644
--- a/pytensor/compile/builders.py
+++ b/pytensor/compile/builders.py
@@ -758,7 +758,7 @@ def make_node(self, *inputs):
         ]
 
         new_inner_inputs = self.inner_inputs[:num_expected_inps] + unshared_inputs
-        new_op = self(
+        new_op = type(self)(
             inputs=new_inner_inputs,
             outputs=new_inner_outputs,
             inline=self.is_inline,

From 66013fcecc6589d5996ec278c2a34c4a0b4ebb2f Mon Sep 17 00:00:00 2001
From: Jesse Grabowski
Date: Fri, 12 Jul 2024 14:45:19 +0800
Subject: [PATCH 12/33] Solve XY problem

---
 pytensor/link/jax/dispatch/basic.py        | 23 ++++++++++++++
 pytensor/link/jax/dispatch/tensor_basic.py | 42 ----------------------
 2 files changed, 23 insertions(+), 42 deletions(-)

diff --git a/pytensor/link/jax/dispatch/basic.py b/pytensor/link/jax/dispatch/basic.py
index b35759f837..0450f0201b 100644
--- a/pytensor/link/jax/dispatch/basic.py
+++ b/pytensor/link/jax/dispatch/basic.py
@@ -5,6 +5,8 @@
 import jax.numpy as jnp
 import numpy as np
 
+from pytensor.compile import JAX
+from pytensor.compile.builders import OpFromGraph
 from pytensor.compile.ops import DeepCopyOp, ViewOp
 from pytensor.configdefaults import config
 from pytensor.graph.fg import FunctionGraph
@@ -114,3 +116,24 @@ def viewop(x):
         return x
 
     return viewop
+
+
+@jax_funcify.register(OpFromGraph)
+def jax_funcify_OpFromGraph(ofg: OpFromGraph, node=None, **kwargs) -> callable:
+    _ = kwargs.pop("storage_map", None)
+
+    # Apply inner rewrites
+    JAX.optimizer(ofg.fgraph)
+    fgraph_fn = jax_funcify(ofg.fgraph, **kwargs)
+
+    if len(ofg.fgraph.outputs) == 1:
+
+        def opfromgraph(*inputs):
+            return fgraph_fn(*inputs)[0]
+
+    else:
+
+        def opfromgraph(*inputs):
+            return fgraph_fn(*inputs)
+
+    return opfromgraph
diff --git a/pytensor/link/jax/dispatch/tensor_basic.py b/pytensor/link/jax/dispatch/tensor_basic.py
index 77de8bbf20..bf1a93ce5b 100644
--- a/pytensor/link/jax/dispatch/tensor_basic.py
+++ b/pytensor/link/jax/dispatch/tensor_basic.py
@@ -3,17 +3,11 @@
 import jax.numpy as jnp
 import numpy as np
 
-import pytensor
-from pytensor.compile import optdb
-from pytensor.graph import node_rewriter
 from pytensor.graph.basic import Constant
-from pytensor.graph.rewriting.basic import in2out
-from pytensor.graph.rewriting.db import SequenceDB
 from pytensor.link.jax.dispatch.basic import jax_funcify
 from pytensor.tensor import get_vector_length
 from pytensor.tensor.basic import (
     Alloc,
-    AllocDiag,
     AllocEmpty,
     ARange,
     ExtractDiag,
@@ -205,39 +205,3 @@ def tri(*args):
         return jnp.tri(*args, dtype=op.dtype)
 
     return tri
-
-
-@node_rewriter([AllocDiag])
-def eagerly_inline_alloc_diag(fgraph, node):
-    """
-    Inline `AllocDiag` OpFromGraph into the graph so the component Ops can themselves be jaxified
-
-    Parameters
-    ----------
-    fgraph: FunctionGraph
-        The function graph being rewritten
-    node: Apply
-        Node of the function graph to be optimized
-
-    Returns
-    -------
-
-    """
-    [input] = node.inputs
-    [output] = node.op.inner_outputs
-    inner_input = output.owner.inputs[1]
-
-    inline = pytensor.clone_replace(output, {inner_input: input})
-
-    return [inline]
-
-
-remove_alloc_ofg_opt = SequenceDB()
-remove_alloc_ofg_opt.register(
-    "inline_alloc_diag",
-    in2out(eagerly_inline_alloc_diag),
-    "jax",
-)
-
-# Do this right away so other JAX rewrites can act on the inner graph
-optdb.register("jax_inline_alloc_diag", remove_alloc_ofg_opt, "jax", position=0)

From e7583d192688e24ad0be4f26c9adae4245e580cc Mon Sep 17 00:00:00 2001
From: Jesse Grabowski
Date: Fri, 12 Jul 2024 14:48:23 +0800
Subject: [PATCH 13/33] Appease mypy

---
 pytensor/link/jax/dispatch/basic.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/pytensor/link/jax/dispatch/basic.py b/pytensor/link/jax/dispatch/basic.py
index 0450f0201b..bd559ee716 100644
--- a/pytensor/link/jax/dispatch/basic.py
+++ b/pytensor/link/jax/dispatch/basic.py
@@ -1,4 +1,5 @@
 import warnings
+from collections.abc import Callable
 from functools import singledispatch
 
 import jax
@@ -119,7 +120,7 @@ def viewop(x):
 
 
 @jax_funcify.register(OpFromGraph)
-def jax_funcify_OpFromGraph(ofg: OpFromGraph, node=None, **kwargs) -> callable:
+def jax_funcify_OpFromGraph(ofg: OpFromGraph, node=None, **kwargs) -> Callable:
     _ = kwargs.pop("storage_map", None)
 
     # Apply inner rewrites

From 70b9cd6aa1b4281f12921a7cb295b0e6a6ca7543 Mon Sep 17 00:00:00 2001
From: Jesse Grabowski
Date: Fri, 12 Jul 2024 14:52:33 +0800
Subject: [PATCH 14/33] Remove `inline` prop from wrapper class and set
 `inline=True`

---
 pytensor/tensor/basic.py | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py
index e67657fdf7..f2a13ab35a 100644
--- a/pytensor/tensor/basic.py
+++ b/pytensor/tensor/basic.py
@@ -3732,16 +3732,14 @@ class AllocDiag(OpFromGraph):
     Wrapper Op for alloc_diag graphs
     """
 
-    __props__ = ("offset", "axis1", "axis2", "inline")
+    __props__ = ("offset", "axis1", "axis2")
 
     def __init__(self, *args, offset, axis1, axis2, **kwargs):
-        inline = kwargs.pop("inline", False)
         self.offset = offset
         self.axis1 = axis1
         self.axis2 = axis2
-        self.inline = inline
 
-        super().__init__(*args, **kwargs, strict=True, inline=inline)
+        super().__init__(*args, **kwargs, strict=True)
 
 
 def alloc_diag(diag, offset=0, axis1=0, axis2=1):
@@ -3777,7 +3777,12 @@ def alloc_diag(diag, offset=0, axis1=0, axis2=1):
         result = result.transpose(axes)
 
     return AllocDiag(
-        inputs=[diag], outputs=[result], offset=offset, axis1=axis1, axis2=axis2
+        inputs=[diag],
+        outputs=[result],
+        offset=offset,
+        axis1=axis1,
+        axis2=axis2,
+        inline=True,
     )(diag)

From 9605a0e1fe48027483acc617850a1a7a09be890d Mon Sep 17 00:00:00 2001
From: Jesse Grabowski
Date: Fri, 12 Jul 2024 15:21:48 +0800
Subject: [PATCH 15/33] Set `inline = False`

---
 pytensor/tensor/basic.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py
index f2a13ab35a..fdd9bc7dec 100644
--- a/pytensor/tensor/basic.py
+++ b/pytensor/tensor/basic.py
@@ -3782,7 +3782,7 @@ def alloc_diag(diag, offset=0, axis1=0, axis2=1):
         offset=offset,
         axis1=axis1,
         axis2=axis2,
-        inline=True,
+        inline=False,
     )(diag)

From 46fbc5504a25ff3fe157051c9bef46f890a95ddc Mon Sep 17 00:00:00 2001
From: Jesse Grabowski
Date: Tue, 16 Jul 2024 18:01:39 +0800
Subject: [PATCH 16/33] Add rewrite to inline all `OpFromGraph` `Op`s

---
 pytensor/tensor/basic.py              |  6 ++--
 pytensor/tensor/rewriting/__init__.py |  1 +
 pytensor/tensor/rewriting/ofg.py      | 41 ++++++++++++++++++++++++
 tests/tensor/rewriting/test_ofg.py    | 46 +++++++++++++++++++++++++++
 4 files changed, 90 insertions(+), 4 deletions(-)
 create mode 100644 pytensor/tensor/rewriting/ofg.py
 create mode 100644 tests/tensor/rewriting/test_ofg.py

diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py
index fdd9bc7dec..334e78dfee 100644
--- a/pytensor/tensor/basic.py
+++ b/pytensor/tensor/basic.py
@@ -3732,10 +3732,9 @@ class AllocDiag(OpFromGraph):
     Wrapper Op for alloc_diag graphs
     """
 
-    __props__ = ("offset", "axis1", "axis2")
+    __props__ = ("axis1", "axis2", "is_inline")
 
-    def __init__(self, *args, offset, axis1, axis2, **kwargs):
-        self.offset = offset
+    def __init__(self, *args, axis1, axis2, **kwargs):
         self.axis1 = axis1
         self.axis2 = axis2
 
@@ -3782,7 +3781,6 @@ def alloc_diag(diag, offset=0, axis1=0, axis2=1):
         offset=offset,
         axis1=axis1,
         axis2=axis2,
     )(diag)

diff --git a/pytensor/tensor/rewriting/__init__.py b/pytensor/tensor/rewriting/__init__.py
index 617eab04fa..168b636041 100644
--- a/pytensor/tensor/rewriting/__init__.py
+++ b/pytensor/tensor/rewriting/__init__.py
@@ -10,6 +10,7 @@
 import pytensor.tensor.rewriting.jax
 import pytensor.tensor.rewriting.linalg
 import pytensor.tensor.rewriting.math
+import pytensor.tensor.rewriting.ofg
 import pytensor.tensor.rewriting.shape
 import pytensor.tensor.rewriting.special
 import pytensor.tensor.rewriting.subtensor
diff --git a/pytensor/tensor/rewriting/ofg.py b/pytensor/tensor/rewriting/ofg.py
new file mode 100644
index 0000000000..641e15be47
--- /dev/null
+++ b/pytensor/tensor/rewriting/ofg.py
@@ -0,0 +1,41 @@
+from pytensor.compile.builders import OpFromGraph
+from pytensor.graph import node_rewriter
+from pytensor.tensor.rewriting.basic import register_specialize
+
+
+@register_specialize
+@node_rewriter([OpFromGraph])
+def inline_OpFromGraph(fgraph, node):
+    """
+    Inline `OpFromGraph` nodes.
+
+    OpFromGraph nodes are used to compactly represent the output of a function graph. Certain `Ops`, like einsum,
+    diag, and kron, are implemented using pytensor `Op`s. As a result, their outputs are not a single `Op`, but a
+    graph. To allow rewrites to easily spot and manipulate these "composite functions", we use the `OpFromGraph` node.
+    This node is a thin wrapper around the output graph. It is not, however, meant to be included in the final
+    program, because it hides the inner graph from certain optimizations.
+
+    This rewrite specifies that all `OpFromGraph` nodes should be replaced by their inner graphs by setting the
+    `is_inline` flag.
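+
+    A rough sketch of the expected effect (the behaviour exercised in
+    `tests/tensor/rewriting/test_ofg.py`): compiling `pt.diag(x)`, which is
+    built on an `AllocDiag` `OpFromGraph`, leaves only `OpFromGraph` nodes
+    whose `is_inline` flag is set.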
+
+    Parameters
+    ----------
+    fgraph: FunctionGraph
+        The function graph being rewritten
+    node: Apply
+        Node of the function graph to be optimized
+
+    Returns
+    -------
+
+    """
+    ofg = node.op
+
+    if ofg.is_inline:
+        return None
+
+    inputs = node.inputs
+    ofg.is_inline = True
+    new_node = ofg(*inputs)
+
+    return [new_node]
diff --git a/tests/tensor/rewriting/test_ofg.py b/tests/tensor/rewriting/test_ofg.py
new file mode 100644
index 0000000000..34041fcd6b
--- /dev/null
+++ b/tests/tensor/rewriting/test_ofg.py
@@ -0,0 +1,46 @@
+import pytensor
+import pytensor.tensor as pt
+from pytensor.compile.builders import OpFromGraph
+
+
+def test_OpFromGraph_inlined():
+    x = pt.tensor("x", shape=(None,))
+    z = x**2
+    fx = OpFromGraph([x], [z], inline=False)(x)
+    f2 = pytensor.function([x], fx)
+
+    nodes = f2.maker.fgraph.apply_nodes
+
+    assert any(isinstance(node.op, OpFromGraph) for node in nodes)
+    assert all(node.op.is_inline for node in nodes if isinstance(node.op, OpFromGraph))
+
+
+def test_inherited_ofg_class_inlined():
+    x = pt.tensor("x", shape=(None,))
+
+    # pt.diag calls AllocDiag, which inherits from OpFromGraph
+    z = pt.diag(x)
+
+    f = pytensor.function([x], z)
+    pytensor.dprint(f)
+
+    nodes = f.maker.fgraph.apply_nodes
+
+    assert any(isinstance(node.op, OpFromGraph) for node in nodes)
+    assert all(node.op.is_inline for node in nodes if isinstance(node.op, OpFromGraph))
+
+
+def test_several_ofg_inlined():
+    x = pt.tensor("x", shape=(None,))
+    y = pt.diag(x)
+
+    # pt.linalg.kron also inherits from OpFromGraph
+    z = pt.linalg.kron(y, pt.eye(2))
+
+    f = pytensor.function([x], z)
+    pytensor.dprint(f)
+
+    nodes = f.maker.fgraph.apply_nodes
+
+    assert any(isinstance(node.op, OpFromGraph) for node in nodes)
+    assert all(node.op.is_inline for node in nodes if isinstance(node.op, OpFromGraph))

From 98cf64111005879cd7aa3e13d4f4f4b1433bbf12 Mon Sep 17 00:00:00 2001
From: Jesse Grabowski
Date: Tue, 16 Jul 2024 18:13:09 +0800
Subject: [PATCH 17/33] Allow symbolic `offset`

---
 pytensor/tensor/basic.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py
index 334e78dfee..7d13bccd3e 100644
--- a/pytensor/tensor/basic.py
+++ b/pytensor/tensor/basic.py
@@ -3747,8 +3747,11 @@ def alloc_diag(diag, offset=0, axis1=0, axis2=1):
         diagonal(alloc_diag(x)) == x
     """
     from pytensor.tensor import set_subtensor
+    from pytensor.tensor.math import maximum
 
     diag = as_tensor_variable(diag)
+    offset = as_tensor_variable(offset)
+
     axis1, axis2 = normalize_axis_tuple((axis1, axis2), ndim=diag.type.ndim + 1)
     if axis1 > axis2:
         axis1, axis2 = axis2, axis1
@@ -3760,8 +3763,8 @@ def alloc_diag(diag, offset=0, axis1=0, axis2=1):
     # Create slice for diagonal in final 2 axes
     idxs = arange(diag.shape[-1])
     diagonal_slice = (slice(None),) * (len(result_shape) - 2) + (
-        idxs + np.maximum(0, -offset),
-        idxs + np.maximum(0, offset),
+        idxs + maximum(0, -offset),
+        idxs + maximum(0, offset),
     )
 
     # Fill in final 2 axes with diag
@@ -3776,12 +3779,11 @@ def alloc_diag(diag, offset=0, axis1=0, axis2=1):
         result = result.transpose(axes)
 
     return AllocDiag(
-        inputs=[diag],
+        inputs=[diag, offset],
         outputs=[result],
-        offset=offset,
         axis1=axis1,
         axis2=axis2,
-    )(diag)
+    )(diag, offset)

From c76b54e9942f2ebff68489248781d58c36dcda2c Mon Sep 17 00:00:00 2001
From: Jesse Grabowski
Date: Tue, 16 Jul 2024 18:47:09 +0800
Subject: [PATCH 18/33] Exclude inline rewrite in JAX mode

---
 pytensor/compile/mode.py           |  7 +------
 pytensor/tensor/rewriting/ofg.py   |  2 +-
 tests/tensor/rewriting/test_ofg.py | 16 ++++++++++++++++
 3 files changed, 18 insertions(+), 7 deletions(-)

diff --git a/pytensor/compile/mode.py b/pytensor/compile/mode.py
index 16019d4187..ac375f1d70 100644
--- a/pytensor/compile/mode.py
+++ b/pytensor/compile/mode.py
@@ -454,12 +454,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
     JAXLinker(),
     RewriteDatabaseQuery(
         include=["fast_run", "jax"],
-        exclude=[
-            "cxx_only",
-            "BlasOpt",
-            "fusion",
-            "inplace",
-        ],
+        exclude=["cxx_only", "BlasOpt", "fusion", "inplace", "opfromgraph"],
     ),
 )
 PYTORCH = Mode(
diff --git a/pytensor/tensor/rewriting/ofg.py b/pytensor/tensor/rewriting/ofg.py
index 641e15be47..978d1ec5f9 100644
--- a/pytensor/tensor/rewriting/ofg.py
+++ b/pytensor/tensor/rewriting/ofg.py
@@ -3,7 +3,7 @@
 from pytensor.tensor.rewriting.basic import register_specialize
 
 
-@register_specialize
+@register_specialize("opfromgraph")
 @node_rewriter([OpFromGraph])
 def inline_OpFromGraph(fgraph, node):
     """
diff --git a/tests/tensor/rewriting/test_ofg.py b/tests/tensor/rewriting/test_ofg.py
index 34041fcd6b..36c753629b 100644
--- a/tests/tensor/rewriting/test_ofg.py
+++ b/tests/tensor/rewriting/test_ofg.py
@@ -1,3 +1,5 @@
+import pytest
+
 import pytensor
 import pytensor.tensor as pt
 from pytensor.compile.builders import OpFromGraph
@@ -44,3 +46,17 @@ def test_several_ofg_inlined():
 
     assert any(isinstance(node.op, OpFromGraph) for node in nodes)
     assert all(node.op.is_inline for node in nodes if isinstance(node.op, OpFromGraph))
+
+
+def test_ofg_not_inlined_in_JAX_mode():
+    pytest.importorskip("jax")
+
+    x = pt.tensor("x", shape=(None,))
+    y = pt.diag(x)
+
+    f = pytensor.function([x], y, mode="JAX")
+    nodes = f.maker.fgraph.apply_nodes
+    assert any(isinstance(node.op, OpFromGraph) for node in nodes)
+    assert not any(
+        node.op.is_inline for node in nodes if isinstance(node.op, OpFromGraph)
+    )

From 7b44e22f6afc3d57a14b668018ff32c45544a835 Mon Sep 17 00:00:00 2001
From: Jesse Grabowski
Date: Tue, 16 Jul 2024 23:03:52 +0800
Subject: [PATCH 19/33] refactor `late_inline_ofg` rewrite to actually perform
 the correct rewrite

---
 pytensor/compile/mode.py           |  2 +-
 pytensor/tensor/rewriting/ofg.py   | 18 +++++++++---------
 tests/tensor/rewriting/test_ofg.py | 10 ++++------
 3 files changed, 14 insertions(+), 16 deletions(-)

diff --git a/pytensor/compile/mode.py b/pytensor/compile/mode.py
index ac375f1d70..82a1f0ec1e 100644
--- a/pytensor/compile/mode.py
+++ b/pytensor/compile/mode.py
@@ -454,7 +454,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
     RewriteDatabaseQuery(
         include=["fast_run", "jax"],
-        exclude=["cxx_only", "BlasOpt", "fusion", "inplace", "opfromgraph"],
+        exclude=["cxx_only", "BlasOpt", "fusion", "inplace", "inline_ofg"],
     ),
 )
 PYTORCH = Mode(
diff --git a/pytensor/tensor/rewriting/ofg.py b/pytensor/tensor/rewriting/ofg.py
index 978d1ec5f9..7213328cfb 100644
--- a/pytensor/tensor/rewriting/ofg.py
+++ b/pytensor/tensor/rewriting/ofg.py
@@ -1,11 +1,12 @@
+from pytensor import clone_replace
 from pytensor.compile.builders import OpFromGraph
 from pytensor.graph import node_rewriter
 from pytensor.tensor.rewriting.basic import register_specialize
 
 
-@register_specialize("opfromgraph")
+@register_specialize("inline_ofg")
 @node_rewriter([OpFromGraph])
-def inline_OpFromGraph(fgraph, node):
+def late_inline_OpFromGraph(fgraph, node):
     """
     Inline `OpFromGraph` nodes.
@@ -29,13 +30,12 @@ def inline_OpFromGraph(fgraph, node):
     -------
 
     """
-    ofg = node.op
+    op = node.op
 
-    if ofg.is_inline:
-        return None
+    if not isinstance(op, OpFromGraph):
+        return False
 
-    inputs = node.inputs
-    ofg.is_inline = True
-    new_node = ofg(*inputs)
+    if op.is_inline:
+        return None
 
-    return [new_node]
+    return clone_replace(op.inner_outputs, dict(zip(op.inner_inputs, node.inputs)))
diff --git a/tests/tensor/rewriting/test_ofg.py b/tests/tensor/rewriting/test_ofg.py
index 36c753629b..0b332c9085 100644
--- a/tests/tensor/rewriting/test_ofg.py
+++ b/tests/tensor/rewriting/test_ofg.py
@@ -10,11 +10,11 @@ def test_OpFromGraph_inlined():
     z = x**2
     fx = OpFromGraph([x], [z], inline=False)(x)
     f2 = pytensor.function([x], fx)
+    pytensor.dprint(f2)
 
     nodes = f2.maker.fgraph.apply_nodes
 
-    assert any(isinstance(node.op, OpFromGraph) for node in nodes)
-    assert all(node.op.is_inline for node in nodes if isinstance(node.op, OpFromGraph))
+    assert not any(isinstance(node.op, OpFromGraph) for node in nodes)
 
 
 def test_inherited_ofg_class_inlined():
@@ -28,8 +28,7 @@ def test_inherited_ofg_class_inlined():
 
     nodes = f.maker.fgraph.apply_nodes
 
-    assert any(isinstance(node.op, OpFromGraph) for node in nodes)
-    assert all(node.op.is_inline for node in nodes if isinstance(node.op, OpFromGraph))
+    assert not any(isinstance(node.op, OpFromGraph) for node in nodes)
 
 
 def test_several_ofg_inlined():
@@ -44,8 +43,7 @@ def test_several_ofg_inlined():
 
     nodes = f.maker.fgraph.apply_nodes
 
-    assert any(isinstance(node.op, OpFromGraph) for node in nodes)
-    assert all(node.op.is_inline for node in nodes if isinstance(node.op, OpFromGraph))
+    assert not any(isinstance(node.op, OpFromGraph) for node in nodes)

From 1dfc3fc2096dfdfd7e625749bde146599e8ecde0 Mon Sep 17 00:00:00 2001
From: Jesse Grabowski
Date: Wed, 17 Jul 2024 00:03:03 +0800
Subject: [PATCH 20/33] Narrow scope of `late_inline` rewrite

---
 pytensor/tensor/rewriting/ofg.py   |  4 +++-
 tests/tensor/rewriting/test_ofg.py | 12 ------------
 2 files changed, 3 insertions(+), 13 deletions(-)

diff --git a/pytensor/tensor/rewriting/ofg.py b/pytensor/tensor/rewriting/ofg.py
index 7213328cfb..3f0a5f688a 100644
--- a/pytensor/tensor/rewriting/ofg.py
+++ b/pytensor/tensor/rewriting/ofg.py
@@ -1,6 +1,8 @@
 from pytensor import clone_replace
 from pytensor.compile.builders import OpFromGraph
 from pytensor.graph import node_rewriter
+from pytensor.tensor.basic import AllocDiag
+from pytensor.tensor.nlinalg import KroneckerProduct
 from pytensor.tensor.rewriting.basic import register_specialize
 
 
@@ -32,7 +34,7 @@ def late_inline_OpFromGraph(fgraph, node):
     """
     op = node.op
 
-    if not isinstance(op, OpFromGraph):
+    if not isinstance(op, AllocDiag | KroneckerProduct):
         return False
 
     if op.is_inline:
diff --git a/tests/tensor/rewriting/test_ofg.py b/tests/tensor/rewriting/test_ofg.py
index 0b332c9085..710e853971 100644
--- a/tests/tensor/rewriting/test_ofg.py
+++ b/tests/tensor/rewriting/test_ofg.py
@@ -5,18 +5,6 @@
 from pytensor.compile.builders import OpFromGraph
 
 
-def test_OpFromGraph_inlined():
-    x = pt.tensor("x", shape=(None,))
-    z = x**2
-    fx = OpFromGraph([x], [z], inline=False)(x)
-    f2 = pytensor.function([x], fx)
-    pytensor.dprint(f2)
-
-    nodes = f2.maker.fgraph.apply_nodes
-
-    assert not any(isinstance(node.op, OpFromGraph) for node in nodes)
-
-
 def test_inherited_ofg_class_inlined():
     x = pt.tensor("x", shape=(None,))

From 9f32661dcbbd7bdae1eca2519dc3be4b2cf64e26 Mon Sep 17 00:00:00 2001
From: Jesse Grabowski
Date: Wed, 17 Jul 2024 08:33:10 +0800
Subject: [PATCH 21/33] Fix tests

---
 pytensor/tensor/rewriting/linalg.py   | 25 +++++++++++++++++++++++--
 tests/tensor/rewriting/test_linalg.py |  8 +++++++-
 2 files changed, 30 insertions(+), 3 deletions(-)

diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py
index 5be5b27d78..113cd9985d 100644
--- a/pytensor/tensor/rewriting/linalg.py
+++ b/pytensor/tensor/rewriting/linalg.py
@@ -3,7 +3,7 @@
 from typing import cast
 
 from pytensor import Variable
-from pytensor.graph import Apply, FunctionGraph
+from pytensor.graph import Apply, Constant, FunctionGraph
 from pytensor.graph.rewriting.basic import (
     copy_stack_trace,
     node_rewriter,
 )
@@ -445,6 +445,27 @@ def _find_diag_from_eye_mul(potential_mul_input):
     return eye_input, non_eye_inputs
 
 
+def is_offset_zero(node) -> bool:
+    """
+    Test if an AllocDiag Op has a diagonal offset of zero
+
+    Parameters
+    ----------
+    node
+        AllocDiag node to test
+
+    Returns
+    -------
+    is_offset_zero: bool
+        True if the offset is zero (``k = 0``).
+    """
+    if not isinstance(node.op, AllocDiag):
+        return False
+
+    offset = node.inputs[-1]
+    return isinstance(offset, Constant) and offset.data.item() == 0
+
+
 @register_canonicalize("shape_unsafe")
 @register_stabilize("shape_unsafe")
 @node_rewriter([det])
@@ -497,7 +497,7 @@ def rewrite_det_diag_to_prod_diag(fgraph, node):
     if (
         inputs.owner
         and isinstance(inputs.owner.op, AllocDiag)
-        and inputs.owner.op.offset == 0
+        and is_offset_zero(inputs.owner)
     ):
         diag_input = inputs.owner.inputs[0]
         det_val = diag_input.prod(axis=-1)
diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py
index eeaac53d82..b16985cb1c 100644
--- a/tests/tensor/rewriting/test_linalg.py
+++ b/tests/tensor/rewriting/test_linalg.py
@@ -368,9 +368,15 @@ def test_local_lift_through_linalg(constructor, f_op, f, g_op, g):
     A, B = list(map(constructor, "ab"))
     X = f(g(A, B))
 
+    # Exclude inline_ofg so we can count KroneckerProduct Ops (these are typically annihilated by rewrites)
     f1 = pytensor.function(
-        [A, B], X, mode=get_default_mode().including("local_lift_through_linalg")
+        [A, B],
+        X,
+        mode=get_default_mode()
+        .including("local_lift_through_linalg")
+        .excluding("inline_ofg"),
     )
+
     f2 = pytensor.function(
         [A, B], X, mode=get_default_mode().excluding("local_lift_through_linalg")
     )

From f30a63f3a9f3a35d5916ba871d0272b23e7eff91 Mon Sep 17 00:00:00 2001
From: Jesse Grabowski
Date: Wed, 17 Jul 2024 08:33:55 +0800
Subject: [PATCH 22/33] Remove `is_inline` prop

---
 pytensor/tensor/basic.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py
index 7d13bccd3e..3cde6d5f9c 100644
--- a/pytensor/tensor/basic.py
+++ b/pytensor/tensor/basic.py
@@ -3732,7 +3732,7 @@ class AllocDiag(OpFromGraph):
     Wrapper Op for alloc_diag graphs
     """
 
-    __props__ = ("axis1", "axis2", "is_inline")
+    __props__ = ("axis1", "axis2")
 
     def __init__(self, *args, axis1, axis2, **kwargs):
         self.axis1 = axis1

From bf705f9497e69ddc5d32851a9292070a2a7e491f Mon Sep 17 00:00:00 2001
From: Jesse Grabowski
Date: Wed, 17 Jul 2024 12:19:38 +0800
Subject: [PATCH 23/33] Add JAX `OpFromGraph` test

---
 pytensor/tensor/rewriting/ofg.py | 10 +---------
 tests/link/jax/test_basic.py     | 19 ++++++++++++++++++-
 2 files changed, 19 insertions(+), 10 deletions(-)

diff --git a/pytensor/tensor/rewriting/ofg.py b/pytensor/tensor/rewriting/ofg.py
index 3f0a5f688a..fa53c4a343 100644
--- a/pytensor/tensor/rewriting/ofg.py
+++ b/pytensor/tensor/rewriting/ofg.py
@@ -1,5 +1,4 @@
 from pytensor import clone_replace
-from pytensor.compile.builders import OpFromGraph
 from pytensor.graph import node_rewriter
 from pytensor.tensor.basic import AllocDiag
 from pytensor.tensor.nlinalg import KroneckerProduct
@@ -7,7 +6,7 @@
 
 
 @register_specialize("inline_ofg")
-@node_rewriter([OpFromGraph])
+@node_rewriter([AllocDiag, KroneckerProduct])
 def late_inline_OpFromGraph(fgraph, node):
     """
     Inline `OpFromGraph` nodes.
@@ -33,11 +32,4 @@ def late_inline_OpFromGraph(fgraph, node):
     """
     op = node.op
-
-    if not isinstance(op, AllocDiag | KroneckerProduct):
-        return False
-
-    if op.is_inline:
-        return None
-
     return clone_replace(op.inner_outputs, dict(zip(op.inner_inputs, node.inputs)))
diff --git a/tests/link/jax/test_basic.py b/tests/link/jax/test_basic.py
index 76c8b4b329..5cd2bd54c6 100644
--- a/tests/link/jax/test_basic.py
+++ b/tests/link/jax/test_basic.py
@@ -4,6 +4,7 @@
 import numpy as np
 import pytest
 
+from pytensor.compile.builders import OpFromGraph
 from pytensor.compile.function import function
 from pytensor.compile.mode import get_mode
 from pytensor.compile.sharedvalue import SharedVariable, shared
@@ -13,7 +14,7 @@
 from pytensor.graph.op import Op, get_test_value
 from pytensor.ifelse import ifelse
 from pytensor.raise_op import assert_op
-from pytensor.tensor.type import dscalar, scalar, vector
+from pytensor.tensor.type import dscalar, matrices, scalar, vector
 
 
 @pytest.fixture(scope="module", autouse=True)
@@ -209,3 +210,19 @@ def test_jax_checkandraise():
 def set_test_value(x, v):
     x.tag.test_value = v
     return x
+
+
+def test_OpFromGraph():
+    x, y, z = matrices("xyz")
+    ofg_1 = OpFromGraph([x, y], [x + y], inline=False)
+    ofg_2 = OpFromGraph([x, y], [x * y, x - y], inline=False)
+
+    o1, o2 = ofg_2(y, z)
+    out = ofg_1(x, o1) + o2
+    out_fg = FunctionGraph([x, y, z], [out])
+
+    xv = np.ones((2, 2), dtype=config.floatX)
+    yv = np.ones((2, 2), dtype=config.floatX) * 3
+    zv = np.ones((2, 2), dtype=config.floatX) * 5
+
+    compare_jax_and_py(out_fg, [xv, yv, zv])

From c8958a4ac9302e01c5f8be2daae250079c60c7c1 Mon Sep 17 00:00:00 2001
From: Jesse Grabowski
Date: Wed, 17 Jul 2024 12:20:13 +0800
Subject: [PATCH 24/33] Don't omit `inline_ofg` rewrites in JAX mode

---
 pytensor/compile/mode.py           |  2 +-
 tests/tensor/rewriting/test_ofg.py | 16 ----------------
 2 files changed, 1 insertion(+), 17 deletions(-)

diff --git a/pytensor/compile/mode.py b/pytensor/compile/mode.py
index 82a1f0ec1e..dcdc227141 100644
--- a/pytensor/compile/mode.py
+++ b/pytensor/compile/mode.py
@@ -454,7 +454,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
     RewriteDatabaseQuery(
         include=["fast_run", "jax"],
-        exclude=["cxx_only", "BlasOpt", "fusion", "inplace", "inline_ofg"],
+        exclude=["cxx_only", "BlasOpt", "fusion", "inplace"],
     ),
 )
 PYTORCH = Mode(
diff --git a/tests/tensor/rewriting/test_ofg.py b/tests/tensor/rewriting/test_ofg.py
index 710e853971..7f1a1343ff 100644
--- a/tests/tensor/rewriting/test_ofg.py
+++ b/tests/tensor/rewriting/test_ofg.py
@@ -1,5 +1,3 @@
-import pytest
-
 import pytensor
 import pytensor.tensor as pt
 from pytensor.compile.builders import OpFromGraph
@@ -32,17 +30,3 @@ def test_several_ofg_inlined():
     nodes = f.maker.fgraph.apply_nodes
 
     assert not any(isinstance(node.op, OpFromGraph) for node in nodes)
-
-
-def test_ofg_not_inlined_in_JAX_mode():
-    pytest.importorskip("jax")
-
-    x = pt.tensor("x", shape=(None,))
-    y = pt.diag(x)
-
-    f = pytensor.function([x], y, mode="JAX")
-    nodes = f.maker.fgraph.apply_nodes
-    assert any(isinstance(node.op, OpFromGraph) for node in nodes)
-    assert not any(
-        node.op.is_inline for node in nodes if isinstance(node.op, OpFromGraph)
-    )

From 665c766a01772fb46105f269bffc7f54878f2d45 Mon Sep 17 00:00:00 2001
From: Jesse Grabowski
Date: Wed, 17 Jul 2024 12:38:21 +0800
Subject: [PATCH 25/33] Don't inline `KroneckerProduct`

---
 pytensor/tensor/rewriting/ofg.py      |  3 +--
 tests/tensor/rewriting/test_linalg.py |  6 +-----
 tests/tensor/rewriting/test_ofg.py    | 19 +------------------
 3 files changed, 3 insertions(+), 25 deletions(-)

diff --git a/pytensor/tensor/rewriting/ofg.py b/pytensor/tensor/rewriting/ofg.py
index fa53c4a343..b0ace33b75 100644
--- a/pytensor/tensor/rewriting/ofg.py
+++ b/pytensor/tensor/rewriting/ofg.py
@@ -1,12 +1,11 @@
 from pytensor import clone_replace
 from pytensor.graph import node_rewriter
 from pytensor.tensor.basic import AllocDiag
-from pytensor.tensor.nlinalg import KroneckerProduct
 from pytensor.tensor.rewriting.basic import register_specialize
 
 
 @register_specialize("inline_ofg")
-@node_rewriter([AllocDiag, KroneckerProduct])
+@node_rewriter([AllocDiag])
 def late_inline_OpFromGraph(fgraph, node):
     """
     Inline `OpFromGraph` nodes.
diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py
index b16985cb1c..10f3d6a4a2 100644
--- a/tests/tensor/rewriting/test_linalg.py
+++ b/tests/tensor/rewriting/test_linalg.py
@@ -370,11 +370,7 @@ def test_local_lift_through_linalg(constructor, f_op, f, g_op, g):
     # Exclude inline_ofg so we can count KroneckerProduct Ops (these are typically annihilated by rewrites)
     f1 = pytensor.function(
-        [A, B],
-        X,
-        mode=get_default_mode()
-        .including("local_lift_through_linalg")
-        .excluding("inline_ofg"),
+        [A, B], X, mode=get_default_mode().including("local_lift_through_linalg")
     )
diff --git a/tests/tensor/rewriting/test_ofg.py b/tests/tensor/rewriting/test_ofg.py
index 7f1a1343ff..79bfe2eeff 100644
--- a/tests/tensor/rewriting/test_ofg.py
+++ b/tests/tensor/rewriting/test_ofg.py
@@ -3,29 +3,12 @@
 from pytensor.compile.builders import OpFromGraph
 
 
-def test_inherited_ofg_class_inlined():
+def test_alloc_diag_inlined():
     x = pt.tensor("x", shape=(None,))
 
     # pt.diag calls AllocDiag, which inherits from OpFromGraph
     z = pt.diag(x)
-
     f = pytensor.function([x], z)
-    pytensor.dprint(f)
 
     nodes = f.maker.fgraph.apply_nodes
 
     assert not any(isinstance(node.op, OpFromGraph) for node in nodes)
-
-
-def test_several_ofg_inlined():
-    x = pt.tensor("x", shape=(None,))
-    y = pt.diag(x)
-
-    # pt.linalg.kron also inherits from OpFromGraph
-    z = pt.linalg.kron(y, pt.eye(2))
-
-    f = pytensor.function([x], z)
-    pytensor.dprint(f)
-
-    nodes = f.maker.fgraph.apply_nodes
-
-    assert not any(isinstance(node.op, OpFromGraph) for node in nodes)

From 6487e61dcb43cf723b516f704f067cc53290fe14 Mon Sep 17 00:00:00 2001
From: Jesse Grabowski
Date: Wed, 17 Jul 2024 12:53:56 +0800
Subject: [PATCH 26/33] Skip inline rewrite tests when `mode == FAST_COMPILE`

---
 tests/tensor/rewriting/test_ofg.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/tests/tensor/rewriting/test_ofg.py b/tests/tensor/rewriting/test_ofg.py
index 79bfe2eeff..9ff1a29d0a 100644
--- a/tests/tensor/rewriting/test_ofg.py
+++ b/tests/tensor/rewriting/test_ofg.py
@@ -1,8 +1,15 @@
+import pytest
+
 import pytensor
 import pytensor.tensor as pt
+from pytensor import config
 from pytensor.compile.builders import OpFromGraph
 
 
+@pytest.mark.skipif(
+    config.mode == "FAST_COMPILE",
+    reason="Rewrite is not applied in FAST_COMPILE mode",
+)
 def test_alloc_diag_inlined():
     x = pt.tensor("x", shape=(None,))

From 43474fd86b35e58a7eaa974d9c5c5c1b467bd267 Mon Sep 17 00:00:00 2001
From: Jesse Grabowski
Grabowski Date: Wed, 17 Jul 2024 14:49:01 +0800 Subject: [PATCH 27/33] Incorporate review feedback --- pytensor/compile/builders.py | 28 -------------------- pytensor/compile/mode.py | 7 ++++- pytensor/tensor/basic.py | 19 ++++++++++++++ pytensor/tensor/rewriting/linalg.py | 31 +++++----------------- pytensor/tensor/rewriting/ofg.py | 38 ++++++++++++++++++++++++++- tests/tensor/rewriting/test_linalg.py | 4 +-- tests/tensor/rewriting/test_ofg.py | 4 +-- 7 files changed, 71 insertions(+), 60 deletions(-) diff --git a/pytensor/compile/builders.py b/pytensor/compile/builders.py index 3221586df0..9087e718b5 100644 --- a/pytensor/compile/builders.py +++ b/pytensor/compile/builders.py @@ -8,7 +8,6 @@ from pytensor.compile.function import function from pytensor.compile.function.pfunc import rebuild_collect_shared -from pytensor.compile.mode import optdb from pytensor.compile.sharedvalue import SharedVariable from pytensor.configdefaults import config from pytensor.gradient import DisconnectedType, Rop, grad @@ -24,7 +23,6 @@ from pytensor.graph.null_type import NullType from pytensor.graph.op import HasInnerGraph, Op from pytensor.graph.replace import clone_replace -from pytensor.graph.rewriting.basic import in2out, node_rewriter from pytensor.graph.utils import MissingInputError @@ -852,29 +850,3 @@ def perform(self, node, inputs, outputs): assert len(variables) == len(outputs) for output, variable in zip(outputs, variables): output[0] = variable - - -@node_rewriter([OpFromGraph]) -def inline_ofg_expansion(fgraph, node): - """ - This optimization expands internal graph of OpFromGraph. - Only performed if node.op.is_inline == True - Doing so can improve optimization at the cost of compilation speed. - """ - op = node.op - if not isinstance(op, OpFromGraph): - return False - if not op.is_inline: - return False - return clone_replace(op.inner_outputs, dict(zip(op.inner_inputs, node.inputs))) - - -# We want to run this before the first merge optimizer -# and before the first scan optimizer. -optdb.register( - "inline_ofg_expansion", - in2out(inline_ofg_expansion), - "fast_compile", - "fast_run", - position=-0.01, -) diff --git a/pytensor/compile/mode.py b/pytensor/compile/mode.py index dcdc227141..16019d4187 100644 --- a/pytensor/compile/mode.py +++ b/pytensor/compile/mode.py @@ -454,7 +454,12 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs): JAXLinker(), RewriteDatabaseQuery( include=["fast_run", "jax"], - exclude=["cxx_only", "BlasOpt", "fusion", "inplace"], + exclude=[ + "cxx_only", + "BlasOpt", + "fusion", + "inplace", + ], ), ) PYTORCH = Mode( diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 3cde6d5f9c..eb7156c799 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -3740,6 +3740,25 @@ def __init__(self, *args, axis1, axis2, **kwargs): super().__init__(*args, **kwargs, strict=True) + @staticmethod + def is_offset_zero(node) -> bool: + """ + Test if an AllocDiag Op has a diagonal offset of zero + + Parameters + ---------- + node + AllocDiag node to test + + Returns + ------- + is_offset_zero: bool + True if the offset is zero (``k = 0``). + """ + + offset = node.inputs[-1] + return isinstance(offset, Constant) and offset.data.item() == 0 + def alloc_diag(diag, offset=0, axis1=0, axis2=1): """Insert a vector on the diagonal of a zero-ed matrix. 
diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py
index 113cd9985d..295ddb1328 100644
--- a/pytensor/tensor/rewriting/linalg.py
+++ b/pytensor/tensor/rewriting/linalg.py
@@ -430,9 +430,11 @@ def _find_diag_from_eye_mul(potential_mul_input):
             return None
         # Check if 1's are being put on the main diagonal only (k = 0)
         # and if the identity matrix is degenerate (column or row matrix)
-        if getattr(
-            inner_eye.owner.inputs[-1], "data", -1
-        ).item() != 0 or inner_eye.broadcastable[-2:] != (False, False):
+        if not (
+            isinstance(inner_eye.owner.inputs[-1], Constant)
+            and inner_eye.owner.inputs[-1].data == 0
+            and inner_eye.broadcastable[-1:] != (False, False)
+        ):
             return None
 
     elif getattr(
@@ -445,27 +447,6 @@ def _find_diag_from_eye_mul(potential_mul_input):
     return eye_input, non_eye_inputs
 
 
-def is_offset_zero(node) -> bool:
-    """
-    Test if an AllocDiag Op has a diagonal offset of zero
-
-    Parameters
-    ----------
-    node
-        AllocDiag node to test
-
-    Returns
-    -------
-    is_offset_zero: bool
-        True if the offset is zero (``k = 0``).
-    """
-    if not isinstance(node.op, AllocDiag):
-        return False
-
-    offset = node.inputs[-1]
-    return isinstance(offset, Constant) and offset.data.item() == 0
-
-
 @register_canonicalize("shape_unsafe")
 @register_stabilize("shape_unsafe")
 @node_rewriter([det])
@@ -497,7 +478,7 @@ def rewrite_det_diag_to_prod_diag(fgraph, node):
     if (
         inputs.owner
         and isinstance(inputs.owner.op, AllocDiag)
-        and is_offset_zero(inputs.owner)
+        and AllocDiag.is_offset_zero(inputs.owner)
     ):
         diag_input = inputs.owner.inputs[0]
         det_val = diag_input.prod(axis=-1)
diff --git a/pytensor/tensor/rewriting/ofg.py b/pytensor/tensor/rewriting/ofg.py
index b0ace33b75..d0bac78782 100644
--- a/pytensor/tensor/rewriting/ofg.py
+++ b/pytensor/tensor/rewriting/ofg.py
@@ -1,9 +1,42 @@
 from pytensor import clone_replace
+from pytensor.compile import optdb
+from pytensor.compile.builders import OpFromGraph
 from pytensor.graph import node_rewriter
+from pytensor.graph.rewriting.basic import copy_stack_trace, in2out
 from pytensor.tensor.basic import AllocDiag
 from pytensor.tensor.rewriting.basic import register_specialize
 
 
+@node_rewriter([OpFromGraph])
+def inline_ofg_expansion(fgraph, node):
+    """
+    This optimization expands internal graph of OpFromGraph.
+    Only performed if node.op.is_inline == True
+    Doing so can improve optimization at the cost of compilation speed.
+    """
+    op = node.op
+    if not isinstance(op, OpFromGraph):
+        return False
+    if not op.is_inline:
+        return False
+
+    new_out = clone_replace(op.inner_outputs, dict(zip(op.inner_inputs, node.inputs)))
+    copy_stack_trace(op.inner_outputs, new_out)
+
+    return new_out
+
+
+# We want to run this before the first merge optimizer
+# and before the first scan optimizer.
+optdb.register(
+    "inline_ofg_expansion",
+    in2out(inline_ofg_expansion),
+    "fast_compile",
+    "fast_run",
+    position=-0.01,
+)
+
+
 @register_specialize("inline_ofg")
 @node_rewriter([AllocDiag])
 def late_inline_OpFromGraph(fgraph, node):
@@ -31,4 +64,7 @@ def late_inline_OpFromGraph(fgraph, node):
     """
     op = node.op
 
-    return clone_replace(op.inner_outputs, dict(zip(op.inner_inputs, node.inputs)))
+    new_out = clone_replace(op.inner_outputs, dict(zip(op.inner_inputs, node.inputs)))
+    copy_stack_trace(op.inner_outputs, new_out)
+
+    return new_out
diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py
index 10f3d6a4a2..c146a5b096 100644
--- a/tests/tensor/rewriting/test_linalg.py
+++ b/tests/tensor/rewriting/test_linalg.py
@@ -368,7 +368,6 @@ def test_local_lift_through_linalg(constructor, f_op, f, g_op, g):
     A, B = list(map(constructor, "ab"))
     X = f(g(A, B))
 
-    # Exclude inline_ofg so we can count KroneckerProduct Ops (these are typically annihilated by rewrites)
     f1 = pytensor.function(
         [A, B], X, mode=get_default_mode().including("local_lift_through_linalg")
     )
@@ -410,8 +409,7 @@ def test_det_diag_from_eye_mul(shape):
     z_det = pt.linalg.det(y)
 
     # REWRITE TEST
-    with pytensor.config.change_flags(optimizer_verbose=True):
-        f_rewritten = function([x], z_det, mode="FAST_RUN")
+    f_rewritten = function([x], z_det, mode="FAST_RUN")
     nodes = f_rewritten.maker.fgraph.apply_nodes
 
     assert not any(
diff --git a/tests/tensor/rewriting/test_ofg.py b/tests/tensor/rewriting/test_ofg.py
index 9ff1a29d0a..6304939562 100644
--- a/tests/tensor/rewriting/test_ofg.py
+++ b/tests/tensor/rewriting/test_ofg.py
@@ -13,10 +13,10 @@
 def test_alloc_diag_inlined():
     x = pt.tensor("x", shape=(None,))
 
-    # pt.diag calls AllocDiag, which inherits from OpFromGrpah
     z = pt.diag(x)
-    f = pytensor.function([x], z)
+    assert isinstance(z.owner.op, OpFromGraph)
 
+    f = pytensor.function([x], z)
     nodes = f.maker.fgraph.apply_nodes
 
     assert not any(isinstance(node.op, OpFromGraph) for node in nodes)

From 6ef408430f33aa596c65f2f3f9df98d5acd29d1d Mon Sep 17 00:00:00 2001
From: Jesse Grabowski
Date: Wed, 17 Jul 2024 15:17:28 +0800
Subject: [PATCH 28/33] Incorporate review feedback

---
 pytensor/tensor/elemwise.py         | 10 ++++++++++
 pytensor/tensor/rewriting/linalg.py |  9 ++++++++-
 pytensor/tensor/rewriting/ofg.py    |  2 --
 3 files changed, 18 insertions(+), 3 deletions(-)

diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py
index 2fdc8e7fd5..386432e2f8 100644
--- a/pytensor/tensor/elemwise.py
+++ b/pytensor/tensor/elemwise.py
@@ -181,6 +181,16 @@ def __init__(self, input_broadcastable, new_order):
         self.augment = sorted([i for i, x in enumerate(new_order) if x == "x"])
         self.drop = drop
 
+        input_ndim = len(input_broadcastable)
+
+        self.is_left_expand_dims = "x" in new_order and (
+            input_ndim == 0 or new_order[:input_ndim] == list(range(input_ndim))
+        )
+
+        self.is_right_expand_dims = "x" in new_order and (
+            input_ndim == 0 or new_order[-input_ndim:] == list(range(input_ndim))
+        )
+
         if self.inplace:
             self.view_map = {0: [0]}
diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py
index 295ddb1328..8d6c191c47 100644
--- a/pytensor/tensor/rewriting/linalg.py
+++ b/pytensor/tensor/rewriting/linalg.py
@@ -413,6 +413,10 @@ def _find_diag_from_eye_mul(potential_mul_input):
         # tensor(shape=(None, 3, 3) * eye(3). This is still potentially valid for diag rewrites.
         (
             isinstance(mul_input.owner.op, DimShuffle)
+            and (
+                mul_input.owner.op.is_left_expand_dims
+                or mul_input.owner.op.is_right_expand_dims
+            )
             and mul_input.owner.inputs[0].owner is not None
             and isinstance(mul_input.owner.inputs[0].owner.op, Eye)
         )
@@ -424,7 +428,10 @@ def _find_diag_from_eye_mul(potential_mul_input):
     eye_input = eye_input[0]
 
     # If this multiplication came from a batched operation, it will be wrapped in a DimShuffle
-    if isinstance(eye_input.owner.op, DimShuffle):
+    if isinstance(eye_input.owner.op, DimShuffle) and (
+        eye_input.owner.op.is_left_expand_dims
+        or eye_input.owner.op.is_right_expand_dims
+    ):
         inner_eye = eye_input.owner.inputs[0]
         if not isinstance(inner_eye.owner.op, Eye):
             return None
diff --git a/pytensor/tensor/rewriting/ofg.py b/pytensor/tensor/rewriting/ofg.py
index d0bac78782..265f3ff2e8 100644
--- a/pytensor/tensor/rewriting/ofg.py
+++ b/pytensor/tensor/rewriting/ofg.py
@@ -15,8 +15,6 @@ def inline_ofg_expansion(fgraph, node):
     Doing so can improve optimization at the cost of compilation speed.
     """
     op = node.op
-    if not isinstance(op, OpFromGraph):
-        return False
     if not op.is_inline:
         return False

From 04ddb466101c6ce228b759cdf8f5cfd81fbaca87 Mon Sep 17 00:00:00 2001
From: Jesse Grabowski
Date: Wed, 17 Jul 2024 15:24:16 +0800
Subject: [PATCH 29/33] Add `is_offset_zero` helper to `Eye`

---
 pytensor/tensor/basic.py            | 19 +++++++++++++
 pytensor/tensor/rewriting/linalg.py | 12 ++++++------
 2 files changed, 25 insertions(+), 6 deletions(-)

diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py
index eb7156c799..ac4c169ebc 100644
--- a/pytensor/tensor/basic.py
+++ b/pytensor/tensor/basic.py
@@ -1334,6 +1334,25 @@ def infer_shape(self, fgraph, node, in_shapes):
     def grad(self, inp, grads):
         return [grad_undefined(self, i, inp[i]) for i in range(3)]
 
+    @staticmethod
+    def is_offset_zero(node) -> bool:
+        """
+        Test if an Eye Op has a diagonal offset of zero
+
+        Parameters
+        ----------
+        node
+            Eye node to test
+
+        Returns
+        -------
+        is_offset_zero: bool
+            True if the offset is zero (``k = 0``).
+        """
+
+        offset = node.inputs[-1]
+        return isinstance(offset, Constant) and offset.data.item() == 0
+
 
 def eye(n, m=None, k=0, dtype=None):
     """Return a 2-D array with ones on the diagonal and zeros elsewhere.
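
For orientation, a small sketch of what the new helper reports — illustrative
only, not part of the patch; it relies on `k` being the last input of the `Eye`
node, as in the hunk above:

    import pytensor.tensor as pt
    from pytensor.tensor.basic import Eye

    main_diag = pt.eye(3)        # k defaults to 0: ones on the main diagonal
    super_diag = pt.eye(3, k=1)  # ones on the first superdiagonal
    assert Eye.is_offset_zero(main_diag.owner)
    assert not Eye.is_offset_zero(super_diag.owner)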
diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py
index 8d6c191c47..61ddbb3ef3 100644
--- a/pytensor/tensor/rewriting/linalg.py
+++ b/pytensor/tensor/rewriting/linalg.py
@@ -3,7 +3,7 @@
 from typing import cast
 
 from pytensor import Variable
-from pytensor.graph import Apply, Constant, FunctionGraph
+from pytensor.graph import Apply, FunctionGraph
 from pytensor.graph.rewriting.basic import (
     copy_stack_trace,
     node_rewriter,
 )
@@ -438,15 +438,15 @@ def _find_diag_from_eye_mul(potential_mul_input):
         # Check if 1's are being put on the main diagonal only (k = 0)
         # and if the identity matrix is degenerate (column or row matrix)
         if not (
-            isinstance(inner_eye.owner.inputs[-1], Constant)
-            and inner_eye.owner.inputs[-1].data == 0
+            Eye.is_offset_zero(inner_eye.owner)
             and inner_eye.broadcastable[-1:] != (False, False)
         ):
             return None
 
-    elif getattr(
-        eye_input.owner.inputs[-1], "data", -1
-    ).item() != 0 or eye_input.broadcastable[-2:] != (False, False):
+    elif not (
+        Eye.is_offset_zero(eye_input.owner)
+        and eye_input.broadcastable[-1:] != (False, False)
+    ):
         return None
 
     # Get all non Eye inputs (scalars/matrices/vectors)

From c0381091aa6875cee4c83c84a6e3f2d1bf549b88 Mon Sep 17 00:00:00 2001
From: Jesse Grabowski
Date: Thu, 18 Jul 2024 09:19:09 +0800
Subject: [PATCH 30/33] Add `is_left_expand_dims` and `is_right_expand_dims`
 attributes to `DimShuffle`

---
 pytensor/tensor/elemwise.py | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py
index 386432e2f8..bd6fc4a96f 100644
--- a/pytensor/tensor/elemwise.py
+++ b/pytensor/tensor/elemwise.py
@@ -182,14 +182,12 @@ def __init__(self, input_broadcastable, new_order):
         self.drop = drop
 
         input_ndim = len(input_broadcastable)
-
-        self.is_left_expand_dims = "x" in new_order and (
-            input_ndim == 0 or new_order[:input_ndim] == list(range(input_ndim))
-        )
-
-        self.is_right_expand_dims = "x" in new_order and (
+        self.is_left_expand_dims = self.augment and (
             input_ndim == 0 or new_order[-input_ndim:] == list(range(input_ndim))
         )
+        self.is_right_expand_dims = self.augment and new_order[:input_ndim] == list(
+            range(input_ndim)
+        )
 
         if self.inplace:
             self.view_map = {0: [0]}

From 19f2895c21348c40711b5ed114168388b99c6c6a Mon Sep 17 00:00:00 2001
From: Jesse Grabowski
Date: Thu, 18 Jul 2024 10:50:12 +0800
Subject: [PATCH 31/33] Seed `test_local_lift_through_linalg` test

---
 tests/tensor/rewriting/test_linalg.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py
index c146a5b096..27c1c82bda 100644
--- a/tests/tensor/rewriting/test_linalg.py
+++ b/tests/tensor/rewriting/test_linalg.py
@@ -362,6 +362,8 @@ def test_invalid_batched_a(self):
     ids=["block_diag", "kron"],
 )
 def test_local_lift_through_linalg(constructor, f_op, f, g_op, g):
+    rng = np.random.default_rng(sum(map(ord, "lift_through_linalg")))
+
     if pytensor.config.floatX.endswith("32"):
         pytest.skip("Test is flaky at half precision")
@@ -387,9 +389,7 @@ def test_local_lift_through_linalg(constructor, f_op, f, g_op, g):
     assert len(f_ops) == 2
     assert len(g_ops) == 1
 
-    test_vals = [
-        np.random.normal(size=(3,) * A.ndim).astype(config.floatX) for _ in range(2)
-    ]
+    test_vals = [rng.normal(size=(3,) * A.ndim).astype(config.floatX) for _ in range(2)]
     test_vals = [x @ np.swapaxes(x, -1, -2) for x in test_vals]
 
     np.testing.assert_allclose(f1(*test_vals), f2(*test_vals), atol=1e-8)
@@ -466,6 +466,8 @@ def test_dont_apply_det_diag_rewrite_for_1_1():
     x_diag = pt.eye(1, 1) * x
     y = pt.linalg.det(x_diag)
     f_rewritten = function([x], y, mode="FAST_RUN")
+    pytensor.dprint(f_rewritten)
+
     nodes = f_rewritten.maker.fgraph.apply_nodes
     assert any(isinstance(node.op, Det) for node in nodes)
@@ -475,6 +477,7 @@ def test_dont_apply_det_diag_rewrite_for_1_1():
     x_test_matrix = np.eye(1, 1) * x_test
     det_val = np.linalg.det(x_test_matrix)
     rewritten_val = f_rewritten(x_test)
+
     assert_allclose(
         det_val,
         rewritten_val,

From fcbccde1dd4e37bd9d36103b9942163d26dca825 Mon Sep 17 00:00:00 2001
From: Jesse Grabowski
Date: Thu, 18 Jul 2024 10:51:05 +0800
Subject: [PATCH 32/33] Fix failing diag_rewrite test

---
 pytensor/tensor/rewriting/linalg.py   | 29 +++++++++++++--------------
 tests/tensor/rewriting/test_linalg.py |  1 -
 2 files changed, 14 insertions(+), 16 deletions(-)

diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py
index 61ddbb3ef3..50a45956d9 100644
--- a/pytensor/tensor/rewriting/linalg.py
+++ b/pytensor/tensor/rewriting/linalg.py
@@ -422,33 +422,33 @@ def _find_diag_from_eye_mul(potential_mul_input):
             )
         )
     ]
+
     if not eye_input:
         return None
 
     eye_input = eye_input[0]
+    # If eye_input is an Eye Op (it's not wrapped in a DimShuffle), check it doesn't have an offset
+    if isinstance(eye_input.owner.op, Eye) and (
+        not Eye.is_offset_zero(eye_input.owner)
+        or eye_input.broadcastable[-2:] != (False, False)
+    ):
+        return None
 
-    # If this multiplication came from a batched operation, it will be wrapped in a DimShuffle
+    # Otherwise, an Eye was found but it is wrapped in a DimShuffle (i.e. there was some broadcasting going on).
+    # We have to look inside DimShuffle to decide if the rewrite can be applied
     if isinstance(eye_input.owner.op, DimShuffle) and (
         eye_input.owner.op.is_left_expand_dims
         or eye_input.owner.op.is_right_expand_dims
     ):
         inner_eye = eye_input.owner.inputs[0]
-        if not isinstance(inner_eye.owner.op, Eye):
-            return None
-        # Check if 1's are being put on the main diagonal only (k = 0)
-        # and if the identity matrix is degenerate (column or row matrix)
-        if not (
-            Eye.is_offset_zero(inner_eye.owner)
-            and inner_eye.broadcastable[-1:] != (False, False)
+        # We can only rewrite when the Eye is on the main diagonal (the offset is zero) and the identity isn't
+        # degenerate
+        if not Eye.is_offset_zero(inner_eye.owner) or inner_eye.broadcastable[-2:] != (
+            False,
+            False,
         ):
             return None
 
-    # Get all non Eye inputs (scalars/matrices/vectors)
     non_eye_inputs = list(set(inputs_to_mul) - {eye_input})
     return eye_input, non_eye_inputs
@@ -493,7 +493,6 @@ def rewrite_det_diag_to_prod_diag(fgraph, node):
 
     # Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix
     inputs_or_none = _find_diag_from_eye_mul(inputs)
-
     if inputs_or_none is None:
         return None
diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py
index 27c1c82bda..0bc064fe65 100644
--- a/tests/tensor/rewriting/test_linalg.py
+++ b/tests/tensor/rewriting/test_linalg.py
@@ -466,7 +466,6 @@ def test_dont_apply_det_diag_rewrite_for_1_1():
     x_diag = pt.eye(1, 1) * x
     y = pt.linalg.det(x_diag)
     f_rewritten = function([x], y, mode="FAST_RUN")
-    pytensor.dprint(f_rewritten)
 
     nodes = f_rewritten.maker.fgraph.apply_nodes

From 56a3ffecee41374c6e69b6015cf7e55070e912a9 Mon Sep 17 00:00:00 2001
From: Jesse Grabowski
Date: Thu, 18 Jul 2024 21:08:14 +0800
Subject: [PATCH 33/33] Revert symbolic offset

---
 pytensor/tensor/basic.py | 19 +++++++------------
 1 file changed, 7 insertions(+), 12 deletions(-)

diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py
index ac4c169ebc..25d6503531 100644
--- a/pytensor/tensor/basic.py
+++ b/pytensor/tensor/basic.py
@@ -3753,9 +3753,10 @@ class AllocDiag(OpFromGraph):
 
     __props__ = ("axis1", "axis2")
 
-    def __init__(self, *args, axis1, axis2, **kwargs):
+    def __init__(self, *args, axis1, axis2, offset, **kwargs):
         self.axis1 = axis1
         self.axis2 = axis2
+        self.offset = offset
 
         super().__init__(*args, **kwargs, strict=True)
 
@@ -3775,8 +3776,7 @@ def is_offset_zero(node) -> bool:
             True if the offset is zero (``k = 0``).
         """
 
-        offset = node.inputs[-1]
-        return isinstance(offset, Constant) and offset.data.item() == 0
+        return node.op.offset == 0
 
 
 def alloc_diag(diag, offset=0, axis1=0, axis2=1):
@@ -3785,10 +3785,8 @@ def alloc_diag(diag, offset=0, axis1=0, axis2=1):
     diagonal(alloc_diag(x)) == x
     """
     from pytensor.tensor import set_subtensor
-    from pytensor.tensor.math import maximum
 
     diag = as_tensor_variable(diag)
-    offset = as_tensor_variable(offset)
 
     axis1, axis2 = normalize_axis_tuple((axis1, axis2), ndim=diag.type.ndim + 1)
     if axis1 > axis2:
@@ -3801,8 +3799,8 @@ def alloc_diag(diag, offset=0, axis1=0, axis2=1):
     # Create slice for diagonal in final 2 axes
     idxs = arange(diag.shape[-1])
     diagonal_slice = (slice(None),) * (len(result_shape) - 2) + (
-        idxs + maximum(0, -offset),
-        idxs + maximum(0, offset),
+        idxs + np.maximum(0, -offset),
+        idxs + np.maximum(0, offset),
     )
 
     # Fill in final 2 axes with diag
@@ -3817,11 +3815,8 @@ def alloc_diag(diag, offset=0, axis1=0, axis2=1):
         result = result.transpose(axes)
 
     return AllocDiag(
-        inputs=[diag, offset],
-        outputs=[result],
-        axis1=axis1,
-        axis2=axis2,
-    )(diag, offset)
+        inputs=[diag], outputs=[result], axis1=axis1, axis2=axis2, offset=offset
+    )(diag)
 
 
 def diag(v, k=0):
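
Taken together, the series leaves `det(diag(x))` reducible to a product over
`x` at compile time, with the `AllocDiag` OpFromGraph inlined only late so that
earlier rewrites can still recognize it. A quick end-to-end check, mirroring
the tests above (illustrative only):

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    x = pt.vector("x")
    # After the rewrites, the compiled graph computes a product rather than a Det
    f = pytensor.function([x], pt.linalg.det(pt.diag(x)), mode="FAST_RUN")
    x_val = np.array([2.0, 3.0, 4.0], dtype=pytensor.config.floatX)
    np.testing.assert_allclose(f(x_val), np.prod(x_val))  # det(diag(x)) == prod(x)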