
Simplify rewrites by assuming Elemwise / Alloc shapes are correct
ricardoV94 committed Aug 7, 2023
1 parent 548c14a commit 2ac8774
Showing 2 changed files with 86 additions and 150 deletions.
165 changes: 47 additions & 118 deletions pytensor/tensor/rewriting/basic.py
@@ -23,7 +23,7 @@
"""

import logging
from typing import TYPE_CHECKING, Optional, Union
from typing import Union

import numpy as np

@@ -65,21 +65,17 @@
)
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.extra_ops import broadcast_shape, broadcast_to
from pytensor.tensor.extra_ops import broadcast_arrays
from pytensor.tensor.math import Sum, add
from pytensor.tensor.math import all as at_all
from pytensor.tensor.math import eq
from pytensor.tensor.shape import Shape_i
from pytensor.tensor.shape import Shape_i, shape_padleft
from pytensor.tensor.sort import TopKOp
from pytensor.tensor.type import DenseTensorType, TensorType
from pytensor.tensor.var import TensorConstant, TensorVariable
from pytensor.utils import NoDuplicateOptWarningFilter


if TYPE_CHECKING:
from pytensor.tensor.rewriting.shape import ShapeFeature


_logger = logging.getLogger("pytensor.tensor.rewriting.basic")
_logger.addFilter(NoDuplicateOptWarningFilter())
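
The import swap above (`broadcast_shape`/`broadcast_to` out, `broadcast_arrays` in) is what lets the rewrites below stop querying a `ShapeFeature`: inputs are simply broadcast against each other and the resulting pattern is trusted. A minimal sketch of that usage, assuming `broadcast_arrays` follows NumPy broadcasting semantics (the variables here are illustrative, not taken from the diff):

    import pytensor.tensor as pt
    from pytensor.tensor.extra_ops import broadcast_arrays

    x = pt.vector("x")  # broadcastable pattern (False,)
    y = pt.matrix("y")  # broadcastable pattern (False, False)

    # No static shapes are needed: the first result is x expanded to the
    # broadcast pattern of (x, y), i.e. a 2-d variable matching y.
    x_bcast = broadcast_arrays(x, y)[0]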

@@ -261,31 +257,16 @@ def local_scalar_tensor_scalar(fgraph, node):
def local_elemwise_alloc(fgraph, node):
r"""Remove unnecessary `Alloc`\s that occur as inputs of `Elemwise` `Op`\s.
`Alloc`\s are effectively a type of `Elemwise` operation
(e.g. ``Elemwise{second}(y, x)`` is the same as ``Alloc(x, *y.shape)``), so
this rewrite uses that fact to reduce `Elemwise`\s on `Alloc`\s to
`Elemwise`\s of the `Alloc`\s first/value input (i.e. the value it
broadcasts).
In other words, this rewrite causes `Elemwise` `Op`\s to "absorb" redundant
`Alloc`\s.
The rewrite essentially performs the following replacement:
``Elemwise{op}(..., Alloc(x, s), ..., y, ...) -> Elemwise{op}(..., x, ..., y, ...)``,
when ``y.shape`` for some input ``y`` (or the combined shapes of the
non-`Alloc`\s) is sufficient to maintain the same/correct output shape.
``Elemwise{op}(..., Alloc(x, s), ..., y, ...) -> Elemwise{op}(..., x, ..., y, ...)``
In it's current form, it also explicitly accounts for `DimShuffle`\s of
In its current form, it also explicitly accounts for `DimShuffle`\s of
`Alloc`\s. This is largely due to `local_alloc_sink_dimshuffle`, which
introduces them as a canonicalization of `Alloc`'s with leading
broadcastable dimensions.
"""
# Rewrite is only applicable when there are at least two inputs
if len(node.inputs) == 1:
return False

if len(node.outputs) > 1:
return False
return None

def dimshuffled_alloc(i):
return (
@@ -305,76 +286,40 @@ def dimshuffled_alloc(i):
if len(alloc_idxs) == 0:
return False

# Search for a non `Alloc` or `DimShuffle` of `Alloc` input that we can use as a
# baseline for the dimensions.
ref_var_idx = None
for idx, i in enumerate(node.inputs):
if i.type.broadcastable == node.outputs[0].type.broadcastable:
# Prefer an input that is not an `Alloc` nor a `DimShuffle` of an
# `Alloc`, so that all `Alloc`s can be rewritten.
if idx not in alloc_idxs:
ref_var_idx = idx
break

# If only `Alloc` and `DimShuffle` of `Alloc` exist, we pick the first suitable one
if ref_var_idx is None:
for idx, i in enumerate(node.inputs):
# XXX: This broadcastable comparison doesn't work
if (
i.type.broadcastable == node.outputs[0].type.broadcastable
) and idx in alloc_idxs:
ref_var_idx = idx
break

if not hasattr(fgraph, "shape_feature"):
return False

input_shapes = [
tuple(fgraph.shape_feature.get_shape(i, j) for j in range(i.type.ndim))
for i in node.inputs
]
bcasted_shape = broadcast_shape(
*input_shapes,
arrays_are_shapes=True,
)

new_inputs = list(node.inputs)
for idx in alloc_idxs:
i = node.inputs[idx]

# Remove `Alloc`
# Remove simple `Alloc`
if isinstance(i.owner.op, Alloc):
new_alloc = broadcast_to(i.owner.inputs[0], bcasted_shape)
new_inp = i.owner.inputs[0]

# TODO FIXME: This shouldn't be handled here.
# `DimShuffle`s should be lifted through `Alloc`s
# by other, more general rewrites.
# Remove `Alloc` in `DimShuffle`
# Remove `Dimshuffle(Alloc)`
elif isinstance(i.owner.op, DimShuffle):
old_alloc = i.owner.inputs[0]
new_alloc = old_alloc.owner.inputs[0]
old_alloc_inp = old_alloc.owner.inputs[0]
missing_ndims = old_alloc.type.ndim - old_alloc_inp.type.ndim
if missing_ndims > 0:
# The `Alloc` added new dimensions to the left.
# We replace those cases with a `DimShuffle` here.
# Nested dimshuffles will be merged later by other rewrites.
old_alloc_inp = shape_padleft(old_alloc_inp, missing_ndims)
# We need to keep the old `DimShuffle`. It could swap axes or
# add dimensions anywhere.
if new_alloc.ndim != old_alloc.ndim:
# The `Alloc` can add dimensions to the value.
# We replace those cases with a `DimShuffle` here.
nb_dim_to_add = old_alloc.ndim - new_alloc.ndim
new_alloc = new_alloc.dimshuffle(
["x"] * nb_dim_to_add + list(range(new_alloc.ndim))
)
new_alloc = broadcast_to(i.owner.op(new_alloc), bcasted_shape)
new_inp = i.owner.op(old_alloc_inp)

copy_stack_trace(i, new_alloc)
new_inputs[idx] = new_alloc
copy_stack_trace(i, new_inp)
new_inputs[idx] = new_inp

# If this assert is triggered, it means we are recreating an equivalent graph
# which would result in cyclical merge rewrites.
if all(new is old for new, old in zip(new_inputs, node.inputs)):
return
new_outs = node.op(*new_inputs, return_list=True)

ret = node.op(*new_inputs, return_list=True)
copy_stack_trace(node.outputs, ret)
return ret
if new_outs[0].type.broadcastable != node.outputs[0].type.broadcastable:
new_outs = [
alloc_like(new_out, node.outputs[0], fgraph) for new_out in new_outs
]

copy_stack_trace(node.outputs, new_outs)
return new_outs
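
For reference, a minimal sketch of the pattern `local_elemwise_alloc` targets, using illustrative variables that are not taken from the diff; the rewrite is expected to drop the `Alloc` and let ordinary broadcasting of `x` produce the same, already-correct output shape:

    import pytensor.tensor as pt

    x = pt.vector("x")
    y = pt.matrix("y")

    # Elemwise{add}(Alloc(x, *y.shape), y): the Alloc only repeats x along axis 0.
    out = pt.alloc(x, *y.shape) + y

    # After the rewrite this is equivalent to Elemwise{add}(x, y), with x
    # broadcast against y at runtime rather than materialized by the Alloc.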


@register_canonicalize("shape_unsafe")
@@ -406,6 +351,7 @@ def local_fill_sink(fgraph, node):

# The newly created node c doesn't have 'clients',
# so this iteration takes place with node.outputs[0]
# TODO: This should just be a WalkingGraphRewrite!
replacements = {node.outputs[0]: c}
for client, cl_idx in fgraph.clients[node.outputs[0]]:
if (
@@ -438,23 +384,15 @@ def local_fill_to_alloc(fgraph, node):
with their dependencies on those tensors' shapes, and sometimes those
shapes can be computed without needing to compute the tensors themselves.
XXX: This rewrite can produce inconsistent results, so do *not* consider
making it a canonicalization until those inconsistencies are
resolved/justified.
Like `local_fill_sink`, this rewrite assumes non-broadcastable shapes are equivalent,
which could mask shape errors.
"""
shape_ref, values_ref = node.inputs
out_type = node.outputs[0].type

if values_ref.type.broadcastable == out_type.broadcastable:
# The assumption here is that `values_ref` already has the same shape
# as `shape_ref`, so a `fill`/`Alloc` is unnecessary.

# XXX FIXME TODO: The only way this can be determined is if one
# absolutely knows that the shapes of `shape_ref` and `values_ref` are
# equal.
# This is an old rewrite, and it's only a
# "specialization/stabilization", so we're going to leave it be for
# now.
return [values_ref]

if shape_ref.type.broadcastable == out_type.broadcastable:
@@ -465,6 +403,9 @@ def local_fill_to_alloc(fgraph, node):
copy_stack_trace(node.outputs[0], o)
return [o]

# The case that is not covered is when `shape_ref` is broadcasted by `values_ref`
# TODO: Return broadcast_to(values_ref, broadcast_shapes(values_ref.shape, shape_ref.shape))

return
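
A small sketch of the two cases handled above, with illustrative variables (assuming `fill(template, value)` keeps its usual "value broadcast to the template's shape" meaning):

    import pytensor.tensor as pt

    template = pt.matrix("template")
    value = pt.scalar("value")

    out = pt.fill(template, value)

    # Since template's broadcastable pattern matches the output's, this becomes
    # alloc(value, *template.shape): only template's shape is needed, not its values.
    explicit = pt.alloc(value, *template.shape)

    # If the value already had the output's pattern (e.g. a matrix), the fill
    # would instead be dropped entirely and the value returned as-is.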


@@ -1014,36 +955,30 @@ def local_sum_make_vector(fgraph, node):
return [element_sum]


@register_useless("local_remove_switch_const_cond")
@register_canonicalize("fast_compile", "local_remove_switch_const_cond")
@register_specialize
@node_rewriter([Elemwise])
@register_useless("shape_unsafe")
@register_canonicalize("fast_compile", "shape_unsafe")
@register_specialize("shape_unsafe")
@node_rewriter([switch])
def local_useless_switch(fgraph, node):
"""
This rewrite makes the following changes in a graph:
at.switch(cond, left, right) ->
if cond is constant and cond == 0: right
if cond is constant and cond != 0: left
if left is right -> left
switch(cond, left, right) ->
if cond is constant and cond == 0: right
if cond is constant and cond != 0: left
if left is right -> left
and
at.switch(le(shape_i{id}(X), 0), 0, shape_i{id}(X)) -> shape_i{id}(X)
switch(le(shape_i{id}(X), 0), 0, shape_i{id}(X)) -> shape_i{id}(X)
"""
if not isinstance(node.op.scalar_op, aes.Switch):
return False

shape_feature: Optional["ShapeFeature"] = getattr(fgraph, "shape_feature", None)

if shape_feature is None:
return False

left = node.inputs[1]
right = node.inputs[2]
cond_var = node.inputs[0]
cond = extract_constant(cond_var, only_process_constants=True)
out_bcast = node.outputs[0].type.broadcastable

if (isinstance(cond, np.ndarray) and cond.ndim == 0) or isinstance(
cond, (np.number, np.bool_)
@@ -1058,14 +993,8 @@ def local_useless_switch(fgraph, node):
else:
out = correct_out

input_shapes = [
tuple(shape_feature.get_shape(inp, i) for i in range(inp.type.ndim))
for inp in node.inputs
]

out_shape = broadcast_shape(*input_shapes, arrays_are_shapes=True)

out = alloc(out, *out_shape)
if out.type.broadcastable != out_bcast:
out = broadcast_arrays(out, *node.inputs)[0]

# Copy over stacktrace from selected output to new output
copy_stack_trace(node.outputs + correct_out, out)
Expand All @@ -1075,10 +1004,10 @@ def local_useless_switch(fgraph, node):
if left == right:
# Note: No need to copy over stacktrace, because the input node
# already has its own stacktrace
if cond.type.is_super(left.type):
if left.type.broadcastable == out_bcast:
return [left]

ret = fill(cond, left)
ret = broadcast_arrays(left, cond)[0]

# Copy over stacktrace from switch output and correct branch
copy_stack_trace(node.outputs + left, ret)
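
Finally, a small sketch of the `local_useless_switch` cases touched in this file, with illustrative variables: a constant condition selects a branch directly, and identical branches collapse to that branch; in both cases the result is only re-broadcast (now via `broadcast_arrays`) when its pattern differs from the original output's.

    import pytensor.tensor as pt

    a = pt.matrix("a")
    b = pt.matrix("b")
    cond = pt.scalar("cond", dtype="bool")

    out_false = pt.switch(0, a, b)    # constant condition == 0 -> b
    out_true = pt.switch(1, a, b)     # constant condition != 0 -> a
    out_same = pt.switch(cond, a, a)  # identical branches      -> a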