diff --git a/pytensor/compile/builders.py b/pytensor/compile/builders.py index b535540425..9087e718b5 100644 --- a/pytensor/compile/builders.py +++ b/pytensor/compile/builders.py @@ -8,7 +8,6 @@ from pytensor.compile.function import function from pytensor.compile.function.pfunc import rebuild_collect_shared -from pytensor.compile.mode import optdb from pytensor.compile.sharedvalue import SharedVariable from pytensor.configdefaults import config from pytensor.gradient import DisconnectedType, Rop, grad @@ -24,7 +23,6 @@ from pytensor.graph.null_type import NullType from pytensor.graph.op import HasInnerGraph, Op from pytensor.graph.replace import clone_replace -from pytensor.graph.rewriting.basic import in2out, node_rewriter from pytensor.graph.utils import MissingInputError @@ -575,7 +573,7 @@ def lop_overrides(inps, grads): for inp_grad in input_grads if not isinstance(inp_grad.type, DisconnectedType | NullType) ] - lop_op = type(self)( + lop_op = OpFromGraph( inputs=inner_inputs + connected_inner_outputs + connected_output_grads, outputs=connected_input_grads, inline=self.is_inline, @@ -669,7 +667,7 @@ def _build_and_cache_rop_op(self): for out_grad in output_grads if not isinstance(out_grad.type, DisconnectedType | NullType) ] - rop_op = type(self)( + rop_op = OpFromGraph( inputs=inner_inputs + eval_points, outputs=filtered_output_grads, inline=self.is_inline, @@ -852,29 +850,3 @@ def perform(self, node, inputs, outputs): assert len(variables) == len(outputs) for output, variable in zip(outputs, variables): output[0] = variable - - -@node_rewriter([OpFromGraph]) -def inline_ofg_expansion(fgraph, node): - """ - This optimization expands internal graph of OpFromGraph. - Only performed if node.op.is_inline == True - Doing so can improve optimization at the cost of compilation speed. - """ - op = node.op - if not isinstance(op, OpFromGraph): - return False - if not op.is_inline: - return False - return clone_replace(op.inner_outputs, dict(zip(op.inner_inputs, node.inputs))) - - -# We want to run this before the first merge optimizer -# and before the first scan optimizer. -optdb.register( - "inline_ofg_expansion", - in2out(inline_ofg_expansion), - "fast_compile", - "fast_run", - position=-0.01, -) diff --git a/pytensor/link/jax/dispatch/basic.py b/pytensor/link/jax/dispatch/basic.py index b35759f837..bd559ee716 100644 --- a/pytensor/link/jax/dispatch/basic.py +++ b/pytensor/link/jax/dispatch/basic.py @@ -1,10 +1,13 @@ import warnings +from collections.abc import Callable from functools import singledispatch import jax import jax.numpy as jnp import numpy as np +from pytensor.compile import JAX +from pytensor.compile.builders import OpFromGraph from pytensor.compile.ops import DeepCopyOp, ViewOp from pytensor.configdefaults import config from pytensor.graph.fg import FunctionGraph @@ -114,3 +117,24 @@ def viewop(x): return x return viewop + + +@jax_funcify.register(OpFromGraph) +def jax_funcify_OpFromGraph(ofg: OpFromGraph, node=None, **kwargs) -> Callable: + _ = kwargs.pop("storage_map", None) + + # Apply inner rewrites + JAX.optimizer(ofg.fgraph) + fgraph_fn = jax_funcify(ofg.fgraph, **kwargs) + + if len(ofg.fgraph.outputs) == 1: + + def opfromgraph(*inputs): + return fgraph_fn(*inputs)[0] + + else: + + def opfromgraph(*inputs): + return fgraph_fn(*inputs) + + return opfromgraph diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 414e3b6ed2..25d6503531 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -21,6 +21,7 @@ import pytensor.scalar.sharedvar from pytensor import compile, config, printing from pytensor import scalar as ps +from pytensor.compile.builders import OpFromGraph from pytensor.gradient import DisconnectedType, grad_undefined from pytensor.graph import RewriteDatabaseQuery from pytensor.graph.basic import Apply, Constant, Variable, equal_computations @@ -1333,6 +1334,25 @@ def infer_shape(self, fgraph, node, in_shapes): def grad(self, inp, grads): return [grad_undefined(self, i, inp[i]) for i in range(3)] + @staticmethod + def is_offset_zero(node) -> bool: + """ + Test if an Eye Op has a diagonal offset of zero + + Parameters + ---------- + node + Eye node to test + + Returns + ------- + is_offset_zero: bool + True if the offset is zero (``k = 0``). + """ + + offset = node.inputs[-1] + return isinstance(offset, Constant) and offset.data.item() == 0 + def eye(n, m=None, k=0, dtype=None): """Return a 2-D array with ones on the diagonal and zeros elsewhere. @@ -3726,109 +3746,37 @@ def trace(a, offset=0, axis1=0, axis2=1): return diagonal(a, offset=offset, axis1=axis1, axis2=axis2).sum(-1) -class AllocDiag(Op): - """An `Op` that copies a vector to the diagonal of a zero-ed matrix.""" +class AllocDiag(OpFromGraph): + """ + Wrapper Op for alloc_diag graphs + """ - __props__ = ("offset", "axis1", "axis2") + __props__ = ("axis1", "axis2") - def __init__(self, offset=0, axis1=0, axis2=1): - """ - Parameters - ---------- - offset: int - Offset of the diagonal from the main diagonal defined by `axis1` - and `axis2`. Can be positive or negative. Defaults to main - diagonal (i.e. 0). - axis1: int - Axis to be used as the first axis of the 2-D sub-arrays to which - the diagonals will be allocated. Defaults to first axis (i.e. 0). - axis2: int - Axis to be used as the second axis of the 2-D sub-arrays to which - the diagonals will be allocated. Defaults to second axis (i.e. 1). - """ - warnings.warn( - "AllocDiag is deprecated. Use `alloc_diag` instead", - FutureWarning, - ) - self.offset = offset - if axis1 < 0 or axis2 < 0: - raise NotImplementedError("AllocDiag does not support negative axis") - if axis1 == axis2: - raise ValueError("axis1 and axis2 cannot be the same") + def __init__(self, *args, axis1, axis2, offset, **kwargs): self.axis1 = axis1 self.axis2 = axis2 + self.offset = offset - def make_node(self, diag): - diag = as_tensor_variable(diag) - if diag.type.ndim < 1: - raise ValueError( - "AllocDiag needs an input with 1 or more dimensions", diag.type - ) - return Apply( - self, - [diag], - [diag.type.clone(shape=(None,) * (diag.ndim + 1))()], - ) - - def perform(self, node, inputs, outputs): - (x,) = inputs - (z,) = outputs - - axis1 = np.minimum(self.axis1, self.axis2) - axis2 = np.maximum(self.axis1, self.axis2) - offset = self.offset - - # Create array with one extra dimension for resulting matrix - result_shape = x.shape[:-1] + (x.shape[-1] + abs(offset),) * 2 - result = np.zeros(result_shape, dtype=x.dtype) - - # Create slice for diagonal in final 2 axes - idxs = np.arange(x.shape[-1]) - diagonal_slice = (len(result_shape) - 2) * [slice(None)] + [ - idxs + np.maximum(0, -offset), - idxs + np.maximum(0, offset), - ] - - # Fill in final 2 axes with x - result[tuple(diagonal_slice)] = x - - if len(x.shape) > 1: - # Re-order axes so they correspond to diagonals at axis1, axis2 - axes = list(range(len(x.shape[:-1]))) - last_idx = axes[-1] - axes = axes[:axis1] + [last_idx + 1] + axes[axis1:] - axes = axes[:axis2] + [last_idx + 2] + axes[axis2:] - result = result.transpose(axes) - - z[0] = result - - def grad(self, inputs, gout): - (gz,) = gout - return [diagonal(gz, offset=self.offset, axis1=self.axis1, axis2=self.axis2)] - - def infer_shape(self, fgraph, nodes, shapes): - (x_shape,) = shapes - axis1 = np.minimum(self.axis1, self.axis2) - axis2 = np.maximum(self.axis1, self.axis2) + super().__init__(*args, **kwargs, strict=True) - result_shape = list(x_shape[:-1]) - diag_shape = x_shape[-1] + abs(self.offset) - result_shape = result_shape[:axis1] + [diag_shape] + result_shape[axis1:] - result_shape = result_shape[:axis2] + [diag_shape] + result_shape[axis2:] - return [tuple(result_shape)] + @staticmethod + def is_offset_zero(node) -> bool: + """ + Test if an AllocDiag Op has a diagonal offset of zero - def __setstate__(self, state): - if "view_map" in state: - del state["view_map"] + Parameters + ---------- + node + AllocDiag node to test - self.__dict__.update(state) + Returns + ------- + is_offset_zero: bool + True if the offset is zero (``k = 0``). + """ - if "offset" not in state: - self.offset = 0 - if "axis1" not in state: - self.axis1 = 0 - if "axis2" not in state: - self.axis2 = 1 + return node.op.offset == 0 def alloc_diag(diag, offset=0, axis1=0, axis2=1): @@ -3839,6 +3787,7 @@ def alloc_diag(diag, offset=0, axis1=0, axis2=1): from pytensor.tensor import set_subtensor diag = as_tensor_variable(diag) + axis1, axis2 = normalize_axis_tuple((axis1, axis2), ndim=diag.type.ndim + 1) if axis1 > axis2: axis1, axis2 = axis2, axis1 @@ -3865,7 +3814,9 @@ def alloc_diag(diag, offset=0, axis1=0, axis2=1): axes = axes[:axis2] + [last_idx + 2] + axes[axis2:] result = result.transpose(axes) - return result + return AllocDiag( + inputs=[diag], outputs=[result], axis1=axis1, axis2=axis2, offset=offset + )(diag) def diag(v, k=0): diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py index 2fdc8e7fd5..bd6fc4a96f 100644 --- a/pytensor/tensor/elemwise.py +++ b/pytensor/tensor/elemwise.py @@ -181,6 +181,14 @@ def __init__(self, input_broadcastable, new_order): self.augment = sorted([i for i, x in enumerate(new_order) if x == "x"]) self.drop = drop + input_ndim = len(input_broadcastable) + self.is_left_expand_dims = self.augment and ( + input_ndim == 0 or new_order[-input_ndim:] == list(range(input_ndim)) + ) + self.is_right_expand_dims = self.augment and new_order[:input_ndim] == list( + range(input_ndim) + ) + if self.inplace: self.view_map = {0: [0]} diff --git a/pytensor/tensor/rewriting/__init__.py b/pytensor/tensor/rewriting/__init__.py index 617eab04fa..168b636041 100644 --- a/pytensor/tensor/rewriting/__init__.py +++ b/pytensor/tensor/rewriting/__init__.py @@ -10,6 +10,7 @@ import pytensor.tensor.rewriting.jax import pytensor.tensor.rewriting.linalg import pytensor.tensor.rewriting.math +import pytensor.tensor.rewriting.ofg import pytensor.tensor.rewriting.shape import pytensor.tensor.rewriting.special import pytensor.tensor.rewriting.subtensor diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index 3c98834c94..50a45956d9 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -5,12 +5,16 @@ from pytensor import Variable from pytensor.graph import Apply, FunctionGraph from pytensor.graph.rewriting.basic import ( - PatternNodeRewriter, copy_stack_trace, node_rewriter, ) from pytensor.scalar.basic import Mul -from pytensor.tensor.basic import ARange, Eye, TensorVariable, alloc, diagonal +from pytensor.tensor.basic import ( + AllocDiag, + Eye, + TensorVariable, + diagonal, +) from pytensor.tensor.blas import Dot22 from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.elemwise import DimShuffle, Elemwise @@ -41,7 +45,6 @@ solve, solve_triangular, ) -from pytensor.tensor.subtensor import advanced_set_subtensor logger = logging.getLogger(__name__) @@ -401,30 +404,68 @@ def _find_diag_from_eye_mul(potential_mul_input): eye_input = [ mul_input for mul_input in inputs_to_mul - if mul_input.owner and isinstance(mul_input.owner.op, Eye) + if mul_input.owner + and ( + isinstance(mul_input.owner.op, Eye) + or + # This whole condition checks if there is an Eye hiding inside a DimShuffle. + # This arises from batched elementwise multiplication between a tensor and an eye, e.g.: + # tensor(shape=(None, 3, 3) * eye(3). This is still potentially valid for diag rewrites. + ( + isinstance(mul_input.owner.op, DimShuffle) + and ( + mul_input.owner.op.is_left_expand_dims + or mul_input.owner.op.is_right_expand_dims + ) + and mul_input.owner.inputs[0].owner is not None + and isinstance(mul_input.owner.inputs[0].owner.op, Eye) + ) + ) ] - # Check if 1's are being put on the main diagonal only (k = 0) - if eye_input and getattr(eye_input[0].owner.inputs[-1], "data", -1).item() != 0: + if not eye_input: return None - # If the broadcast pattern of eye_input is not (False, False), we do not get a diagonal matrix and thus, dont need to apply the rewrite - if eye_input and eye_input[0].broadcastable[-2:] != (False, False): + eye_input = eye_input[0] + # If eye_input is an Eye Op (it's not wrapped in a DimShuffle), check it doesn't have an offset + if isinstance(eye_input.owner.op, Eye) and ( + not Eye.is_offset_zero(eye_input.owner) + or eye_input.broadcastable[-2:] != (False, False) + ): return None + # Otherwise, an Eye was found but it is wrapped in a DimShuffle (i.e. there was some broadcasting going on). + # We have to look inside DimShuffle to decide if the rewrite can be applied + if isinstance(eye_input.owner.op, DimShuffle) and ( + eye_input.owner.op.is_left_expand_dims + or eye_input.owner.op.is_right_expand_dims + ): + inner_eye = eye_input.owner.inputs[0] + # We can only rewrite when the Eye is on the main diagonal (the offset is zero) and the identity isn't + # degenerate + if not Eye.is_offset_zero(inner_eye.owner) or inner_eye.broadcastable[-2:] != ( + False, + False, + ): + return None + # Get all non Eye inputs (scalars/matrices/vectors) - non_eye_inputs = list(set(inputs_to_mul) - set(eye_input)) + non_eye_inputs = list(set(inputs_to_mul) - {eye_input}) return eye_input, non_eye_inputs @register_canonicalize("shape_unsafe") @register_stabilize("shape_unsafe") @node_rewriter([det]) -def rewrite_det_diag_from_eye_mul(fgraph, node): +def rewrite_det_diag_to_prod_diag(fgraph, node): """ - This rewrite takes advantage of the fact that for a diagonal matrix, the determinant value is the product of its diagonal elements. + This rewrite takes advantage of the fact that for a diagonal matrix, the determinant value is the product of its + diagonal elements. - The presence of a diagonal matrix is detected by inspecting the graph. This rewrite can identify diagonal matrices that arise as the result of elementwise multiplication with an identity matrix. Specialized computation is used to make this rewrite as efficient as possible, depending on whether the multiplication was with a scalar, vector or a matrix. + The presence of a diagonal matrix is detected by inspecting the graph. This rewrite can identify diagonal matrices + that arise as the result of elementwise multiplication with an identity matrix. Specialized computation is used to + make this rewrite as efficient as possible, depending on whether the multiplication was with a scalar, + vector or a matrix. Parameters ---------- @@ -438,53 +479,45 @@ def rewrite_det_diag_from_eye_mul(fgraph, node): list of Variable, optional List of optimized variables, or None if no optimization was performed """ - potential_mul_input = node.inputs[0] - eye_non_eye_inputs = _find_diag_from_eye_mul(potential_mul_input) - if eye_non_eye_inputs is None: + inputs = node.inputs[0] + + # Check for use of pt.diag first + if ( + inputs.owner + and isinstance(inputs.owner.op, AllocDiag) + and AllocDiag.is_offset_zero(inputs.owner) + ): + diag_input = inputs.owner.inputs[0] + det_val = diag_input.prod(axis=-1) + return [det_val] + + # Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix + inputs_or_none = _find_diag_from_eye_mul(inputs) + if inputs_or_none is None: return None - eye_input, non_eye_inputs = eye_non_eye_inputs + + eye_input, non_eye_inputs = inputs_or_none # Dealing with only one other input if len(non_eye_inputs) != 1: return None - useful_eye, useful_non_eye = eye_input[0], non_eye_inputs[0] + eye_input, non_eye_input = eye_input[0], non_eye_inputs[0] # Checking if original x was scalar/vector/matrix - if useful_non_eye.type.broadcastable[-2:] == (True, True): + if non_eye_input.type.broadcastable[-2:] == (True, True): # For scalar - det_val = useful_non_eye.squeeze(axis=(-1, -2)) ** (useful_eye.shape[0]) - elif useful_non_eye.type.broadcastable[-2:] == (False, False): + det_val = non_eye_input.squeeze(axis=(-1, -2)) ** (eye_input.shape[0]) + elif non_eye_input.type.broadcastable[-2:] == (False, False): # For Matrix - det_val = useful_non_eye.diagonal(axis1=-1, axis2=-2).prod(axis=-1) + det_val = non_eye_input.diagonal(axis1=-1, axis2=-2).prod(axis=-1) else: # For vector - det_val = useful_non_eye.prod(axis=(-1, -2)) + det_val = non_eye_input.prod(axis=(-1, -2)) det_val = det_val.astype(node.outputs[0].type.dtype) return [det_val] -arange = ARange("int64") -det_diag_from_diag = PatternNodeRewriter( - ( - det, - ( - advanced_set_subtensor, - (alloc, 0, "sh1", "sh2"), - "x", - (arange, 0, "stop", 1), - (arange, 0, "stop", 1), - ), - ), - (prod, "x"), - name="det_diag_from_diag", - allow_multiple_clients=True, -) -register_canonicalize(det_diag_from_diag) -register_stabilize(det_diag_from_diag) -register_specialize(det_diag_from_diag) - - @register_canonicalize @register_stabilize @register_specialize diff --git a/pytensor/tensor/rewriting/ofg.py b/pytensor/tensor/rewriting/ofg.py new file mode 100644 index 0000000000..265f3ff2e8 --- /dev/null +++ b/pytensor/tensor/rewriting/ofg.py @@ -0,0 +1,68 @@ +from pytensor import clone_replace +from pytensor.compile import optdb +from pytensor.compile.builders import OpFromGraph +from pytensor.graph import node_rewriter +from pytensor.graph.rewriting.basic import copy_stack_trace, in2out +from pytensor.tensor.basic import AllocDiag +from pytensor.tensor.rewriting.basic import register_specialize + + +@node_rewriter([OpFromGraph]) +def inline_ofg_expansion(fgraph, node): + """ + This optimization expands internal graph of OpFromGraph. + Only performed if node.op.is_inline == True + Doing so can improve optimization at the cost of compilation speed. + """ + op = node.op + if not op.is_inline: + return False + + new_out = clone_replace(op.inner_outputs, dict(zip(op.inner_inputs, node.inputs))) + copy_stack_trace(op.inner_outputs, new_out) + + return new_out + + +# We want to run this before the first merge optimizer +# and before the first scan optimizer. +optdb.register( + "inline_ofg_expansion", + in2out(inline_ofg_expansion), + "fast_compile", + "fast_run", + position=-0.01, +) + + +@register_specialize("inline_ofg") +@node_rewriter([AllocDiag]) +def late_inline_OpFromGraph(fgraph, node): + """ + Inline `OpFromGraph` nodes. + + OpFromGraph nodes are used to compactly represent the output of a function graph. Certain `Ops`, like, einsum, + diag, and kron, are implemented using pytensor `Op`s. As a result, their outputs are not a single `Op`, but a + graph. To allow rewrites to easily spot and manipulate these "composite functions", we use the `OpFromGraph` node. + This node is a thin wrapper around the output graph. It is not, however, meant to be included in the final + program, because it hides the inner graph from certain optimizations. + + This rewrite specifies that all `OpFromGraph` nodes should be replaced by their inner graphs by setting the + `inplace=True` flag. + + Parameters + ---------- + fgraph: FunctionGraph + The function graph being rewritten + node: Apply + Node of the function graph to be optimized + + Returns + ------- + + """ + op = node.op + new_out = clone_replace(op.inner_outputs, dict(zip(op.inner_inputs, node.inputs))) + copy_stack_trace(op.inner_outputs, new_out) + + return new_out diff --git a/tests/link/jax/test_basic.py b/tests/link/jax/test_basic.py index 76c8b4b329..5cd2bd54c6 100644 --- a/tests/link/jax/test_basic.py +++ b/tests/link/jax/test_basic.py @@ -4,6 +4,7 @@ import numpy as np import pytest +from pytensor.compile.builders import OpFromGraph from pytensor.compile.function import function from pytensor.compile.mode import get_mode from pytensor.compile.sharedvalue import SharedVariable, shared @@ -13,7 +14,7 @@ from pytensor.graph.op import Op, get_test_value from pytensor.ifelse import ifelse from pytensor.raise_op import assert_op -from pytensor.tensor.type import dscalar, scalar, vector +from pytensor.tensor.type import dscalar, matrices, scalar, vector @pytest.fixture(scope="module", autouse=True) @@ -209,3 +210,19 @@ def test_jax_checkandraise(): def set_test_value(x, v): x.tag.test_value = v return x + + +def test_OpFromGraph(): + x, y, z = matrices("xyz") + ofg_1 = OpFromGraph([x, y], [x + y], inline=False) + ofg_2 = OpFromGraph([x, y], [x * y, x - y], inline=False) + + o1, o2 = ofg_2(y, z) + out = ofg_1(x, o1) + o2 + out_fg = FunctionGraph([x, y, z], [out]) + + xv = np.ones((2, 2), dtype=config.floatX) + yv = np.ones((2, 2), dtype=config.floatX) * 3 + zv = np.ones((2, 2), dtype=config.floatX) * 5 + + compare_jax_and_py(out_fg, [xv, yv, zv]) diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index d59e3cc88f..0bc064fe65 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -362,6 +362,8 @@ def test_invalid_batched_a(self): ids=["block_diag", "kron"], ) def test_local_lift_through_linalg(constructor, f_op, f, g_op, g): + rng = np.random.default_rng(sum(map(ord, "lift_through_linalg"))) + if pytensor.config.floatX.endswith("32"): pytest.skip("Test is flaky at half precision") @@ -371,6 +373,7 @@ def test_local_lift_through_linalg(constructor, f_op, f, g_op, g): f1 = pytensor.function( [A, B], X, mode=get_default_mode().including("local_lift_through_linalg") ) + f2 = pytensor.function( [A, B], X, mode=get_default_mode().excluding("local_lift_through_linalg") ) @@ -386,9 +389,7 @@ def test_local_lift_through_linalg(constructor, f_op, f, g_op, g): assert len(f_ops) == 2 assert len(g_ops) == 1 - test_vals = [ - np.random.normal(size=(3,) * A.ndim).astype(config.floatX) for _ in range(2) - ] + test_vals = [rng.normal(size=(3,) * A.ndim).astype(config.floatX) for _ in range(2)] test_vals = [x @ np.swapaxes(x, -1, -2) for x in test_vals] np.testing.assert_allclose(f1(*test_vals), f2(*test_vals), atol=1e-8) @@ -403,13 +404,18 @@ def test_det_diag_from_eye_mul(shape): # Initializing x based on scalar/vector/matrix x = pt.tensor("x", shape=shape) y = pt.eye(7) * x + # Calculating determinant value using pt.linalg.det z_det = pt.linalg.det(y) # REWRITE TEST f_rewritten = function([x], z_det, mode="FAST_RUN") nodes = f_rewritten.maker.fgraph.apply_nodes - assert not any(isinstance(node.op, Det) for node in nodes) + + assert not any( + isinstance(node.op, Det) or isinstance(getattr(node.op, "core_op", None), Det) + for node in nodes + ) # NUMERIC VALUE TEST if len(shape) == 0: @@ -418,6 +424,7 @@ def test_det_diag_from_eye_mul(shape): x_test = np.random.rand(*shape).astype(config.floatX) else: x_test = np.random.rand(*shape).astype(config.floatX) + x_test_matrix = np.eye(7) * x_test det_val = np.linalg.det(x_test_matrix) rewritten_val = f_rewritten(x_test) @@ -459,6 +466,7 @@ def test_dont_apply_det_diag_rewrite_for_1_1(): x_diag = pt.eye(1, 1) * x y = pt.linalg.det(x_diag) f_rewritten = function([x], y, mode="FAST_RUN") + nodes = f_rewritten.maker.fgraph.apply_nodes assert any(isinstance(node.op, Det) for node in nodes) @@ -468,6 +476,7 @@ def test_dont_apply_det_diag_rewrite_for_1_1(): x_test_matrix = np.eye(1, 1) * x_test det_val = np.linalg.det(x_test_matrix) rewritten_val = f_rewritten(x_test) + assert_allclose( det_val, rewritten_val, diff --git a/tests/tensor/rewriting/test_ofg.py b/tests/tensor/rewriting/test_ofg.py new file mode 100644 index 0000000000..6304939562 --- /dev/null +++ b/tests/tensor/rewriting/test_ofg.py @@ -0,0 +1,22 @@ +import pytest + +import pytensor +import pytensor.tensor as pt +from pytensor import config +from pytensor.compile.builders import OpFromGraph + + +@pytest.mark.skipif( + config.mode == "FAST_COMPILE", + reason="Rewrite is not applied in FAST_COMPILE mode", +) +def test_alloc_diag_inlined(): + x = pt.tensor("x", shape=(None,)) + + z = pt.diag(x) + assert isinstance(z.owner.op, OpFromGraph) + + f = pytensor.function([x], z) + nodes = f.maker.fgraph.apply_nodes + + assert not any(isinstance(node.op, OpFromGraph) for node in nodes)