Merge consecutive reduces #888

Merged · 3 commits · Oct 8, 2024
Changes from all commits
184 changes: 82 additions & 102 deletions pytensor/tensor/rewriting/math.py
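
For orientation before the raw diff: this PR folds the old Sum/Prod-only local_op_of_op rewrite into a generic local_reduce_chain that merges any pair of directly nested CAReduce ops sharing the same scalar op, as shown in the hunks below. A minimal sketch of the intended effect, not part of the diff, assuming a standard PyTensor install (the exact debug-print output is version dependent):

# Illustrative sketch, not part of the diff.
import pytensor
import pytensor.tensor as pt

x = pt.tensor3("x")
out = x.sum(axis=1).sum(axis=0)  # two chained Sum nodes before rewriting

fn = pytensor.function([x], out)
# After canonicalization the two reductions should be merged into a single
# Sum over axes (0, 1) of x instead of two chained Sums.
pytensor.dprint(fn)
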
@@ -42,13 +42,8 @@
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.extra_ops import broadcast_arrays
 from pytensor.tensor.math import (
-    All,
-    Any,
     Dot,
-    FixedOpCAReduce,
-    NonZeroDimsCAReduce,
     Prod,
-    ProdWithoutZeros,
     Sum,
     _conj,
     add,
@@ -96,6 +91,7 @@
     register_uncanonicalize,
     register_useless,
 )
+from pytensor.tensor.rewriting.elemwise import apply_local_dimshuffle_lift
 from pytensor.tensor.shape import Shape, Shape_i
 from pytensor.tensor.subtensor import Subtensor
 from pytensor.tensor.type import (
@@ -105,7 +101,11 @@
     values_eq_approx_remove_inf_nan,
     values_eq_approx_remove_nan,
 )
-from pytensor.tensor.variable import TensorConstant, get_unique_constant_value
+from pytensor.tensor.variable import (
+    TensorConstant,
+    TensorVariable,
+    get_unique_constant_value,
+)
 
 
 def scalarconsts_rest(inputs, elemwise=True, only_process_constants=False):
@@ -1580,130 +1580,110 @@
 
 
 @register_canonicalize
-@node_rewriter([Sum, Prod])
-def local_op_of_op(fgraph, node):
+@node_rewriter([CAReduce])
+def local_reduce_chain(fgraph, node) -> list[TensorVariable] | None:
     """
     Prod(Prod()) -> single Prod()
     or
     Sum(Sum()) -> single Sum()
+    or any CAReduce(Careduce(x)) of the same type
 
     """
-    op_type = Sum if isinstance(node.op, Sum) else Prod
-    (node_inps,) = node.inputs
-    out_dtype = node.op.dtype
-    # This is done to make sure the rewrite doesn't affect other
-    # computations.
-    if len(fgraph.clients[node_inps]) == 1:
-        if node_inps.owner and (isinstance(node_inps.owner.op, node.op.__class__)):
-            # check to see either the inner or outer prod is doing a
-            # product over all axis, in which case we can remove it
-            if node_inps.owner.op.axis is None or node.op.axis is None:
-                return [op_type(None, dtype=out_dtype)(node_inps.owner.inputs[0])]
-
-            # figure out which axes were in the original sum
-            newaxis = list(node_inps.owner.op.axis)
-            for i in node.op.axis:
-                new_i = i
-                for ii in node_inps.owner.op.axis:
-                    if new_i >= ii:
-                        new_i += 1
-                assert new_i not in newaxis
-                newaxis.append(new_i)
-
-            assert len(newaxis) == len(
-                list(node_inps.owner.op.axis) + list(node.op.axis)
-            )
+    [inner_reduce] = node.inputs
+    if not (inner_reduce.owner and isinstance(inner_reduce.owner.op, CAReduce)):
+        return None
 
-            combined = op_type(newaxis, dtype=out_dtype)
-            return [combined(node_inps.owner.inputs[0])]
+    # Don't apply rewrite if inner_reduce is used elsewhere
+    if len(fgraph.clients[inner_reduce]) > 1:
+        return None
 
+    # Check if CAReduces have the same scalar op
+    outer_op: CAReduce = node.op
+    inner_op = inner_reduce.owner.op
 
-ALL_REDUCE = [
-    CAReduce,
-    All,
-    Any,
-    Sum,
-    Prod,
-    ProdWithoutZeros,
-    *CAReduce.__subclasses__(),
-    *FixedOpCAReduce.__subclasses__(),
-    *NonZeroDimsCAReduce.__subclasses__(),
-]
+    if outer_op.scalar_op != inner_op.scalar_op:
+        return None
 
+    outer_axis = outer_op.axis
+    inner_axis = inner_op.axis
+    [x] = inner_reduce.owner.inputs
+    # check to see either the inner or outer prod is doing a
+    # product over all axis, in which case we can remove it
+    if outer_axis is None or inner_axis is None:
+        return [outer_op.clone(axis=None)(x)]
+
+    # Merge axis
+    newaxis = list(inner_axis)
+    for i in outer_axis:
+        new_i = i
+        for ii in inner_axis:
+            if new_i >= ii:
+                new_i += 1
+        assert new_i not in newaxis
+        newaxis.append(new_i)
+
+    assert len(newaxis) == len(inner_axis) + len(outer_axis)
+    return [outer_op.clone(axis=sorted(newaxis))(x)]
 
 
 @register_canonicalize
 @register_uncanonicalize  # Needed for MaxAndArgmax -> CAReduce
-@node_rewriter(ALL_REDUCE)
+@node_rewriter([CAReduce])
 def local_reduce_join(fgraph, node):
     """
-    CAReduce{scalar.op}(Join(axis=0, a, b), axis=0) -> Elemwise{scalar.op}(a, b)
+    CAReduce{scalar.op}(Join(axis=x, a, b), axis=x) -> Elemwise{scalar.op}(a, b)
 
     Notes
     -----
     Supported scalar.op are Maximum, Minimum in some cases and Add and Mul in
     all cases.
 
-    Currently we must reduce on axis 0. It is probably extensible to the case
-    where we join and reduce on the same set of axis.
+    When a, b have a dim length of 1 along the join axis
 
     """
-    if node.inputs[0].owner and isinstance(node.inputs[0].owner.op, Join):
-        join_node = node.inputs[0].owner
-        if extract_constant(join_node.inputs[0], only_process_constants=True) != 0:
-            return
+    if not (node.inputs[0].owner and isinstance(node.inputs[0].owner.op, Join)):
+        return None
 
-        if isinstance(node.op.scalar_op, ps.ScalarMaximum | ps.ScalarMinimum):
-            # Support only 2 inputs for now
-            if len(join_node.inputs) != 3:
-                return
-        elif not isinstance(node.op.scalar_op, ps.Add | ps.Mul):
-            return
-        elif len(join_node.inputs) <= 2:
-            # This is a useless join that should get removed by another rewrite?
-            return
+    [joined_out] = node.inputs
+    joined_node = joined_out.owner
+    join_axis_tensor, *joined_inputs = joined_node.inputs
 
-        new_inp = []
-        for inp in join_node.inputs[1:]:
-            inp = inp.owner
-            if not inp:
-                return
-            if not isinstance(inp.op, DimShuffle) or inp.op.new_order != (
-                "x",
-                *range(inp.inputs[0].ndim),
-            ):
-                return
-            new_inp.append(inp.inputs[0])
-        ret = Elemwise(node.op.scalar_op)(*new_inp)
+    n_joined_inputs = len(joined_inputs)
+    if n_joined_inputs < 2:
+        # Let some other rewrite get rid of this useless Join
+        return None
 
+    if n_joined_inputs > 2 and not isinstance(node.op.scalar_op, ps.Add | ps.Mul):
+        # We don't rewrite if a single Elemwise cannot take all inputs at once
+        return None
 
-        if ret.dtype != node.outputs[0].dtype:
-            # The reduction do something about the dtype.
-            return
+    if not isinstance(join_axis_tensor, Constant):
+        return None
 
+    join_axis = join_axis_tensor.data
 
-        reduce_axis = node.op.axis
-        if reduce_axis is None:
-            reduce_axis = tuple(range(node.inputs[0].ndim))
+    # Check whether reduction happens on joined axis
+    reduce_op = node.op
+    reduce_axis = reduce_op.axis
+    if reduce_axis is None:
+        if joined_out.type.ndim > 1:
+            return None
+    elif reduce_axis != (join_axis,):
+        return None
 
-        if len(reduce_axis) != 1 or 0 not in reduce_axis:
-            return
+    # Check all inputs are broadcastable along the join axis and squeeze those dims away
+    new_inputs = []
+    for inp in joined_inputs:
+        if not inp.type.broadcastable[join_axis]:
+            return None
+        # Most times inputs to join have an expand_dims, we eagerly clean up those here
+        new_input = apply_local_dimshuffle_lift(None, inp.squeeze(join_axis))
+        new_inputs.append(new_input)
 
-        # We add the new check late to don't add extra warning.
-        try:
-            join_axis = get_underlying_scalar_constant_value(
-                join_node.inputs[0], only_process_constants=True
-            )
+    ret = Elemwise(node.op.scalar_op)(*new_inputs)
 
-            if join_axis != reduce_axis[0]:
-                return
-        except NotScalarConstantError:
-            return
+    if ret.dtype != node.outputs[0].dtype:
+        # The reduction do something about the dtype.
+        return None
 
-        return [ret]
+    return [ret]
 
 
 @register_infer_shape
 @register_canonicalize("fast_compile", "local_cut_useless_reduce")
 @register_useless("local_cut_useless_reduce")
-@node_rewriter(ALL_REDUCE)
+@node_rewriter([CAReduce])
 def local_useless_reduce(fgraph, node):
     """Sum(a, axis=[]) -> a"""
     (summed,) = node.inputs
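
The reworked local_reduce_join above drops the old axis-0 restriction: it only requires that the reduction runs over the join axis and that every joined input has length 1 along that axis, which is exactly the pattern produced by pt.stack. A hedged sketch of the expected behaviour, not part of this diff, assuming a standard PyTensor install:

# Illustrative sketch, not part of the diff.
import pytensor
import pytensor.tensor as pt

a = pt.vector("a")
b = pt.vector("b")

# stack() joins expand_dims(a) and expand_dims(b) along a new axis 0, and the
# reduction runs over that same axis, so the rewrite should apply.
out = pt.stack([a, b]).sum(axis=0)

fn = pytensor.function([a, b], out)
# Expected: a graph equivalent to add(a, b), with no Join or Sum left.
pytensor.dprint(fn)
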
@@ -1715,7 +1695,7 @@
 @register_canonicalize
 @register_uncanonicalize
 @register_specialize
-@node_rewriter(ALL_REDUCE)
+@node_rewriter([CAReduce])
 def local_reduce_broadcastable(fgraph, node):
     """Remove reduction over broadcastable dimensions."""
     (reduced,) = node.inputs
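
The same decorator change applies to local_reduce_broadcastable, which now registers for any CAReduce. As an illustrative sketch (again not part of the diff, assuming a standard PyTensor install), reducing over a dimension that is statically known to have length 1 should leave no reduction node in the compiled graph:

# Illustrative sketch, not part of the diff.
import pytensor
import pytensor.tensor as pt

r = pt.row("r")      # matrix with static shape (1, None)
out = r.sum(axis=0)  # reduce over the broadcastable axis

fn = pytensor.function([r], out)
# Expected: no Sum node, typically just a DimShuffle dropping the unit axis.
pytensor.dprint(fn)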