From 2ac87749e6625c744f3a295477cd210e3a6dc968 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 13 Jul 2023 14:56:58 +0200 Subject: [PATCH] Simplify rewrites by assuming Elemwise / Alloc shapes are correct --- pytensor/tensor/rewriting/basic.py | 165 ++++++++------------------- tests/tensor/rewriting/test_basic.py | 71 ++++++------ 2 files changed, 86 insertions(+), 150 deletions(-) diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py index 23cb429e37..b190a4ea80 100644 --- a/pytensor/tensor/rewriting/basic.py +++ b/pytensor/tensor/rewriting/basic.py @@ -23,7 +23,7 @@ """ import logging -from typing import TYPE_CHECKING, Optional, Union +from typing import Union import numpy as np @@ -65,21 +65,17 @@ ) from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.exceptions import NotScalarConstantError -from pytensor.tensor.extra_ops import broadcast_shape, broadcast_to +from pytensor.tensor.extra_ops import broadcast_arrays from pytensor.tensor.math import Sum, add from pytensor.tensor.math import all as at_all from pytensor.tensor.math import eq -from pytensor.tensor.shape import Shape_i +from pytensor.tensor.shape import Shape_i, shape_padleft from pytensor.tensor.sort import TopKOp from pytensor.tensor.type import DenseTensorType, TensorType from pytensor.tensor.var import TensorConstant, TensorVariable from pytensor.utils import NoDuplicateOptWarningFilter -if TYPE_CHECKING: - from pytensor.tensor.rewriting.shape import ShapeFeature - - _logger = logging.getLogger("pytensor.tensor.rewriting.basic") _logger.addFilter(NoDuplicateOptWarningFilter()) @@ -261,31 +257,16 @@ def local_scalar_tensor_scalar(fgraph, node): def local_elemwise_alloc(fgraph, node): r"""Remove unnecessary `Alloc`\s that occur as inputs of `Elemwise` `Op`\s. - `Alloc`\s are effectively a type of `Elemwise` operation - (e.g. 
``Elemwise{second}(y, x)`` is the same as ``Alloc(x, *y.shape)``), so - this rewrite uses that fact to reduce `Elemwise`\s on `Alloc`\s to - `Elemwise`\s of the `Alloc`\s first/value input (i.e. the value it - broadcasts). - - In other words, this rewrite causes `Elemwise` `Op`\s to "absorb" redundant - `Alloc`\s. - The rewrite essentially performs the following replacement: - ``Elemwise{op}(..., Alloc(x, s), ..., y, ...) -> Elemwise{op}(..., x, ..., y, ...)``, - when ``y.shape`` for some input ``y`` (or the combined shapes of the - non-`Alloc`\s) is sufficient to maintain the same/correct output shape. + ``Elemwise{op}(..., Alloc(x, s), ..., y, ...) -> Elemwise{op}(..., x, ..., y, ...)`` - In it's current form, it also explicitly accounts for `DimShuffle`\s of + In its current form, it also explicitly accounts for `DimShuffle`\s of `Alloc`\s. This is largely due to `local_alloc_sink_dimshuffle`, which introduces them as a canonicalization of `Alloc`'s with leading broadcastable dimensions. """ - # Rewrite is only applicable when there are at least two inputs if len(node.inputs) == 1: - return False - - if len(node.outputs) > 1: - return False + return None def dimshuffled_alloc(i): return ( @@ -305,76 +286,40 @@ def dimshuffled_alloc(i): if len(alloc_idxs) == 0: return False - # Search for a non `Alloc` or `DimShuffle` of `Alloc` input that we can use as a - # baseline for the dimensions. - ref_var_idx = None - for idx, i in enumerate(node.inputs): - if i.type.broadcastable == node.outputs[0].type.broadcastable: - # Prefer an input that is not an `Alloc` nor a `DimShuffle` of an - # `Alloc`, so that all `Alloc`s can be rewritten. 
- if idx not in alloc_idxs: - ref_var_idx = idx - break - - # If only `Alloc` and `DimShuffle` of `Alloc` exist, we pick the first suitable one - if ref_var_idx is None: - for idx, i in enumerate(node.inputs): - # XXX: This broadcastable comparison doesn't work - if ( - i.type.broadcastable == node.outputs[0].type.broadcastable - ) and idx in alloc_idxs: - ref_var_idx = idx - break - - if not hasattr(fgraph, "shape_feature"): - return False - - input_shapes = [ - tuple(fgraph.shape_feature.get_shape(i, j) for j in range(i.type.ndim)) - for i in node.inputs - ] - bcasted_shape = broadcast_shape( - *input_shapes, - arrays_are_shapes=True, - ) - new_inputs = list(node.inputs) for idx in alloc_idxs: i = node.inputs[idx] - # Remove `Alloc` + # Remove simple `Alloc` if isinstance(i.owner.op, Alloc): - new_alloc = broadcast_to(i.owner.inputs[0], bcasted_shape) + new_inp = i.owner.inputs[0] - # TODO FIXME: This shouldn't be handled here. - # `DimShuffle`s should be lifted through `Alloc`s - # by other, more general rewrites. - # Remove `Alloc` in `DimShuffle` + # Remove `Dimshuffle(Alloc)` elif isinstance(i.owner.op, DimShuffle): old_alloc = i.owner.inputs[0] - new_alloc = old_alloc.owner.inputs[0] + old_alloc_inp = old_alloc.owner.inputs[0] + missing_ndims = old_alloc.type.ndim - old_alloc_inp.type.ndim + if missing_ndims > 0: + # The `Alloc` added new dimensions to the left. + # We replace those cases with a `DimShuffle` here. + # Nested dimshuffles will be merged later by other rewrites. + old_alloc_inp = shape_padleft(old_alloc_inp, missing_ndims) # We need to keep the old `DimShuffle`. It could swap axes or # add dimensions anywhere. - if new_alloc.ndim != old_alloc.ndim: - # The `Alloc` can add dimensions to the value. - # We replace those cases with a `DimShuffle` here. 
- nb_dim_to_add = old_alloc.ndim - new_alloc.ndim - new_alloc = new_alloc.dimshuffle( - ["x"] * nb_dim_to_add + list(range(new_alloc.ndim)) - ) - new_alloc = broadcast_to(i.owner.op(new_alloc), bcasted_shape) + new_inp = i.owner.op(old_alloc_inp) - copy_stack_trace(i, new_alloc) - new_inputs[idx] = new_alloc + copy_stack_trace(i, new_inp) + new_inputs[idx] = new_inp - # If this assert is triggered, it means we are recreating an equivalent graph - # which would result in cyclical merge rewrites. - if all(new is old for new, old in zip(new_inputs, node.inputs)): - return + new_outs = node.op(*new_inputs, return_list=True) - ret = node.op(*new_inputs, return_list=True) - copy_stack_trace(node.outputs, ret) - return ret + if new_outs[0].type.broadcastable != node.outputs[0].type.broadcastable: + new_outs = [ + alloc_like(new_out, node.outputs[0], fgraph) for new_out in new_outs + ] + + copy_stack_trace(node.outputs, new_outs) + return new_outs @@ -406,6 +351,7 @@ def local_fill_sink(fgraph, node): # The newly created node c doesn't has 'clients', # so this iteration is took place with node.outputs[0] + # TODO: This should just be a WalkingGraphRewrite! replacements = {node.outputs[0]: c} for client, cl_idx in fgraph.clients[node.outputs[0]]: if ( @@ -438,9 +384,8 @@ def local_fill_to_alloc(fgraph, node): with their dependencies on those tensors' shapes, and sometimes those shapes can be computed without needing to compute the tensors themselves. - XXX: This rewrite can produce inconsistent results, so do *not* consider - making it a canonicalization until those inconsistencies are - resolved/justified. + Like `local_fill_sink` this rewrite assumes non-broadcastable shapes are equivalent, + which could mask shape errors. 
""" shape_ref, values_ref = node.inputs out_type = node.outputs[0].type @@ -448,13 +393,6 @@ def local_fill_to_alloc(fgraph, node): if values_ref.type.broadcastable == out_type.broadcastable: # The assumption here is that `values_ref` already has the same shape # as `shape_ref`, so a `fill`/`Alloc` is unnecessary. - - # XXX FIXME TODO: The only way this can be determined is if one - # absolutely knows that the shapes of `shape_ref` and `values_ref` are - # equal. - # This is an old rewrite, and it's only a - # "specialization/stabilization", so we're going to leave it be for - # now. return [values_ref] if shape_ref.type.broadcastable == out_type.broadcastable: @@ -465,6 +403,9 @@ def local_fill_to_alloc(fgraph, node): copy_stack_trace(node.outputs[0], o) return [o] + # The case that is not covered is when `shape_ref` is broadcasted by `values_ref` + # TODO: Return broadcast_to(values_ref, broadcast_shapes(values_ref.shape, shape_ref.shape)) + return @@ -1014,36 +955,30 @@ def local_sum_make_vector(fgraph, node): return [element_sum] -@register_useless("local_remove_switch_const_cond") -@register_canonicalize("fast_compile", "local_remove_switch_const_cond") -@register_specialize -@node_rewriter([Elemwise]) +@register_useless("shape_unsafe") +@register_canonicalize("fast_compile", "shape_unsafe") +@register_specialize("shape_unsafe") +@node_rewriter([switch]) def local_useless_switch(fgraph, node): """ This rewrite makes the following changes in a graph: - at.switch(cond, left, right) -> - if cond is constant and cond == 0: right - if cond is constant and cond != 0: left - if left is right -> left + switch(cond, left, right) -> + if cond is constant and cond == 0: right + if cond is constant and cond != 0: left + if left is right -> left and - at.switch(le(shape_i{id}(X), 0), 0, shape_i{id}(X)) -> shape_i{id}(X) + switch(le(shape_i{id}(X), 0), 0, shape_i{id}(X)) -> shape_i{id}(X) """ - if not isinstance(node.op.scalar_op, aes.Switch): - return False - - 
shape_feature: Optional["ShapeFeature"] = getattr(fgraph, "shape_feature", None) - - if shape_feature is None: - return False left = node.inputs[1] right = node.inputs[2] cond_var = node.inputs[0] cond = extract_constant(cond_var, only_process_constants=True) + out_bcast = node.outputs[0].type.broadcastable if (isinstance(cond, np.ndarray) and cond.ndim == 0) or isinstance( cond, (np.number, np.bool_) @@ -1058,14 +993,8 @@ def local_useless_switch(fgraph, node): else: out = correct_out - input_shapes = [ - tuple(shape_feature.get_shape(inp, i) for i in range(inp.type.ndim)) - for inp in node.inputs - ] - - out_shape = broadcast_shape(*input_shapes, arrays_are_shapes=True) - - out = alloc(out, *out_shape) + if out.type.broadcastable != out_bcast: + out = broadcast_arrays(out, *node.inputs)[0] # Copy over stacktrace from selected output to new output copy_stack_trace(node.outputs + correct_out, out) @@ -1075,10 +1004,10 @@ def local_useless_switch(fgraph, node): if left == right: # Note: No need to copy over stacktrace, because the input node # already has its own stacktrace - if cond.type.is_super(left.type): + if left.type.broadcastable == out_bcast: return [left] - ret = fill(cond, left) + ret = broadcast_arrays(left, cond)[0] # Copy over stacktrace from switch output and correct branch copy_stack_trace(node.outputs + left, ret) diff --git a/tests/tensor/rewriting/test_basic.py b/tests/tensor/rewriting/test_basic.py index eb2aa8eeb1..5d364da6fd 100644 --- a/tests/tensor/rewriting/test_basic.py +++ b/tests/tensor/rewriting/test_basic.py @@ -1013,7 +1013,7 @@ def test_broadcasting_1(self): z = at.switch(1, x, y) f = function([x, y], z, mode=self.mode) - start_var = f.maker.fgraph.outputs[0].owner.inputs[0] + start_var = f.maker.fgraph.outputs[0] assert isinstance(start_var.owner.op, Elemwise) assert isinstance(start_var.owner.op.scalar_op, aes.basic.Cast) assert not any(node.op == at.switch for node in f.maker.fgraph.toposort()) @@ -1698,45 +1698,50 @@ def 
verify_op_count(f, count, cls): ) @pytest.mark.parametrize( - "expr, x_shape, y_shape", + "expr, x_shape, y_shape, needs_alloc", [ - (lambda x, y: at.mul(at.alloc(1, *y.shape), x), (1, 2), (3, 2)), - (lambda x, y: at.mul(at.alloc(1, *y.shape), x), (1, 1), (1, 1)), - (lambda x, y: at.mul(x, at.alloc(y, 2, 3)), (1, 3), (2, 3)), + (lambda x, y: at.mul(at.alloc(1, *y.shape), x), (1, 2), (3, 2), True), + (lambda x, y: at.mul(at.alloc(1, *y.shape), x), (1, 1), (1, 1), False), + (lambda x, y: at.mul(x, at.alloc(y, 2, 3)), (1, 3), (2, 3), False), ( lambda x, y: at.mul( at.alloc(x, 3).dimshuffle("x", 0), y.dimshuffle("x", "x") ), (), (), + True, ), - (lambda x, y: at.mul(y, at.alloc(1, x)), (), ()), - (lambda x, y: at.mul(at.alloc(x, 15, 1), y), (15, 1), (15, 1)), - (lambda x, y: at.mul(at.alloc(x, 15, 2), y), (15, 2), (15, 2)), + (lambda x, y: at.mul(y, at.alloc(1, x)), (), (), True), + (lambda x, y: at.mul(at.alloc(x, 15, 1), y), (15, 1), (15, 1), False), + (lambda x, y: at.mul(at.alloc(x, 15, 2), y), (15, 2), (15, 2), False), ( lambda x, y: at.mul(at.alloc(x, 15, 1), at.alloc(y, 15, 1)), (15, 1), (15, 1), + False, ), ( lambda x, y: at.mul(at.alloc(x, 15, 2), at.alloc(y, 15, 2)), (15, 2), (15, 2), + False, ), ( lambda x, y: at.mul(at.alloc(x, 15, 2).dimshuffle(1, 0), y), (15, 2), (2, 15), + False, ), - (lambda x, y: at.mul(at.alloc(x, 1, 15, 2), y), (15, 2), (15, 2)), + (lambda x, y: at.mul(at.alloc(x, 1, 15, 2), y), (15, 2), (15, 2), False), ( lambda x, y: at.mul(at.alloc(x, 1, 15, 2).dimshuffle(0, 2, 1), y), (15, 2), (2, 15), + False, ), ], ) - def test_basic(self, expr, x_shape, y_shape): + def test_basic(self, expr, x_shape, y_shape, needs_alloc): x = at.tensor( dtype="int64", shape=(1 if val == 1 else None for val in x_shape), name="x" ) @@ -1752,10 +1757,16 @@ def test_basic(self, expr, x_shape, y_shape): on_unused_input="ignore", ) - assert not any( - isinstance(node.op, Alloc) for node in z_opt.maker.fgraph.toposort() - ) + nodes = z_opt.maker.fgraph.toposort() + 
if needs_alloc: + # When the final result needs an Alloc, this should be the last node + # x = scalar; y = vector; mul(x, ones_like(y)) -> alloc(x, y.shape) + assert isinstance(nodes[-1].op, Alloc) + nodes = nodes[:-1] + + assert not any(isinstance(node.op, Alloc) for node in nodes) + # Check results are the same without the optimization z_no_opt = pytensor.function( [x, y], z, @@ -1799,7 +1810,7 @@ def test_remove_alloc_wo_dimshuffle(self): [self.vec, self.mat], self.alloc_wo_dep + self.mat, mode=self.fast_run_mode ) self.verify_op_count(func, 0, Alloc) - self.verify_op_count(func, 2, Assert) + self.verify_op_count(func, 1, SpecifyShape) func = function( [self.vec, self.mat], @@ -1807,7 +1818,7 @@ def test_remove_alloc_wo_dimshuffle(self): mode=self.fast_run_mode, ) self.verify_op_count(func, 0, Alloc) - self.verify_op_count(func, 1, Assert) + self.verify_op_count(func, 1, SpecifyShape) # No optimization on alloc without assert func = function( @@ -1839,7 +1850,10 @@ def test_remove_alloc_wo_dimshuffle(self): self.alloc_w_dep_broad2 + self.mat, mode=self.fast_run_mode, ) - self.verify_op_count(func, 0, Alloc) + # This graph requires one outer Alloc and an Assert + # To make sure `mat` is square since we end up doing + # broadcast_to(x, mat[..., None].shape) + mat[None, ...] + self.verify_op_count(func, 1, Alloc) self.verify_op_count(func, 1, Assert) def test_remove_alloc_w_dimshuffle(self): @@ -1851,16 +1865,13 @@ def test_remove_alloc_w_dimshuffle(self): self.verify_op_count(func, 1, Alloc) self.verify_op_count(func, 0, Assert) - # TODO FIXME: The `BroadcastTo` shapes should use the constants - # provided by the first/`Alloc` term, and not the unknown values from - # the `tens` term. 
func = function( [self.vec, self.tens], self.alloc_wo_dep.dimshuffle(0, 1, "x") + self.tens, mode=self.fast_run_mode, ) self.verify_op_count(func, 0, Alloc) - self.verify_op_count(func, 2, Assert) + self.verify_op_count(func, 1, SpecifyShape) func = function( [self.vec, self.tens], @@ -1888,16 +1899,13 @@ def test_multi_input_single_alloc(self): self.verify_op_count(func, 2, Alloc) self.verify_op_count(func, 0, Assert) - # Optimization on dimshuffle with assert - # TODO: When we support static shape constraints like `shape[i] != 1`, - # reproduce this with such a constraint on `mat` and make sure the - # `BroadcastTo` is removed. func = function( [self.vec, self.mat], self.tv_wo_dep + self.tm_wo_dep, mode=self.fast_run_mode, ) - self.verify_op_count(func, 0, Alloc) + # It still needs an outer alloc to broadcast final shape + self.verify_op_count(func, 1, Alloc) self.verify_op_count(func, 0, Assert) # No optimization on dimshuffle without assert @@ -1909,25 +1917,24 @@ def test_multi_input_single_alloc(self): self.verify_op_count(func, 2, Alloc) self.verify_op_count(func, 0, Assert) - # Optimization on dimshuffle without assert func = function( [self.vec, self.mat, self.s], self.tv_w_dep + self.tm_w_dep, mode=self.fast_run_mode, ) - self.verify_op_count(func, 0, Alloc) - # The second assert is from the shape check... 
- self.verify_op_count(func, 2, Assert) + # It still needs an outer alloc to broadcast final shape + self.verify_op_count(func, 1, Alloc) + self.verify_op_count(func, 0, Assert) def test_misc(self): - x = row(dtype=self.dtype) - y = tensor(dtype=self.dtype, shape=(None, None, 1)) + x = row("x", dtype=self.dtype) + y = tensor("y", dtype=self.dtype, shape=(None, None, 1)) out = at.alloc(x, 5, 5).dimshuffle(0, 1, "x") + y func = function([y, x], out, mode=self.fast_run_mode) self.verify_op_count(func, 0, Alloc) - self.verify_op_count(func, 2, Assert) + self.verify_op_count(func, 1, SpecifyShape) y_val = np.random.random((5, 5, 1)).astype(self.dtype) x_val = np.random.random((1, 5)).astype(self.dtype)