Skip to content

Commit a9d8ca0

Browse files
Add a fusion rewrite for CAReduces with Elemwise inputs
1 parent 4bec443 commit a9d8ca0

File tree

3 files changed

+118
-3
lines changed

3 files changed

+118
-3
lines changed

aesara/scalar/basic.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4048,7 +4048,7 @@ def __init__(self, inputs, outputs):
40484048

40494049
@property
def fn(self):
    # NOTE(review): this diff replaces ``return self._fn`` with
    # ``return None``, so the property now always reports that no
    # compiled function is cached.  Callers of ``Composite.fn`` must
    # tolerate a ``None`` value — confirm this cache-disabling behavior
    # is intentional and that no caller still assumes a callable here.
    return None
40524052

40534053
@property
40544054
def inner_inputs(self):

aesara/tensor/rewriting/elemwise.py

+58-2
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,16 @@
1111
from aesara.graph.basic import Apply, Constant, io_toposort
1212
from aesara.graph.features import ReplaceValidate
1313
from aesara.graph.op import compute_test_value, get_test_value
14-
from aesara.graph.rewriting.basic import GraphRewriter, copy_stack_trace, node_rewriter
14+
from aesara.graph.rewriting.basic import (
15+
GraphRewriter,
16+
copy_stack_trace,
17+
in2out,
18+
node_rewriter,
19+
)
1520
from aesara.graph.rewriting.db import SequenceDB
1621
from aesara.graph.utils import InconsistencyError, MethodNotDefined, TestValueError
1722
from aesara.tensor.basic import MakeVector, alloc, cast, get_scalar_constant_value
18-
from aesara.tensor.elemwise import DimShuffle, Elemwise
23+
from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise
1924
from aesara.tensor.exceptions import NotScalarConstantError
2025
from aesara.tensor.rewriting.basic import register_canonicalize, register_specialize
2126
from aesara.tensor.shape import shape_padleft
@@ -944,3 +949,54 @@ def local_useless_composite(fgraph, node):
944949
c = aes.Composite(inputs=comp.inputs, outputs=new_outputs)
945950
e = Elemwise(scalar_op=c)(*node.inputs, return_list=True)
946951
return dict(zip([node.outputs[i] for i in idx], e))
952+
953+
954+
@node_rewriter([CAReduce])
def local_careduce_fusion(fgraph, node):
    """Fuse a `CAReduce` applied to an `Elemwise`.

    The element-wise computation feeding the reduction is inlined into the
    reduction itself by building a scalar `Composite` that applies the
    `Elemwise`'s scalar op and then the `CAReduce`'s binary scalar op in a
    single step.

    Parameters
    ----------
    fgraph
        The function graph being rewritten (required by the
        `node_rewriter` interface; not used directly).
    node
        The `CAReduce` application node under consideration.

    Returns
    -------
    list or bool
        A one-element list containing the fused replacement variable, or
        ``False`` when the rewrite does not apply.
    """
    (car_input,) = node.inputs
    elm_node = car_input.owner

    # Only fire when the reduced value is produced by an `Elemwise`.
    if elm_node is None or not isinstance(elm_node.op, Elemwise):
        return False

    elm_inputs = elm_node.inputs

    if len(elm_inputs) > 1:
        # TODO: Implement the multiple inputs case.
        # BUG FIX: this previously read ``raise False``, which raises a
        # ``TypeError`` at runtime (only ``BaseException`` subclasses can
        # be raised); a node rewriter declines by returning ``False``.
        return False

    car_axis = node.op.axis
    car_scalar_op = node.op.scalar_op
    elm_scalar_op = elm_node.op.scalar_op

    # Fresh scalar variables standing in for each `Elemwise` tensor input.
    scalar_elm_inputs = [
        aes.get_scalar_type(inp.type.dtype).make_variable() for inp in elm_inputs
    ]
    elm_output = elm_scalar_op(*scalar_elm_inputs)
    # This input represents the previous value in the `CAReduce` binary reduction
    carried_car_input = elm_output.type()
    scalar_fused_outputs = [car_scalar_op(carried_car_input, elm_output)]

    fused_scalar_op = aes.Composite(
        inputs=[carried_car_input] + scalar_elm_inputs, outputs=scalar_fused_outputs
    )

    # The fused `Op` needs to look and behave like a `BinaryScalarOp`
    # TODO: Generate a new `type` and make this relationship official?
    fused_scalar_op.identity = car_scalar_op.identity
    fused_scalar_op.nin = 2
    fused_scalar_op.nout = 1

    new_car_op = CAReduce(fused_scalar_op, car_axis)

    return [new_car_op(*elm_inputs)]
995+
996+
997+
# Register the CAReduce/Elemwise fusion as an in2out pass under the
# "fusion" tag.  NOTE(review): position=49 presumably sequences this just
# ahead of the main Elemwise fusion passes — confirm against the `optdb`
# ordering before relying on it.
compile.optdb.register( # type: ignore
    "local_careduce_fusion",
    in2out(local_careduce_fusion),
    "fusion",
    position=49,
)

tests/tensor/rewriting/test_elemwise.py

+59
Original file line numberDiff line numberDiff line change
@@ -1105,6 +1105,65 @@ def test_test_values(self, test_value):
11051105
f.maker.fgraph.outputs[0].tag.test_value, np.c_[[2.0]]
11061106
)
11071107

1108+
@pytest.mark.parametrize("linker", ["cvm", "py"])
@pytest.mark.parametrize("axis", [None, 0, 1, (0, 1), (0, 1, 2)])
def test_CAReduce_single_input(self, linker, axis):
    """Make sure that `CAReduce` and `Elemwise` fusions work with a single input."""
    run_mode = Mode(linker=linker)
    run_mode._optimizer = run_mode._optimizer.including(
        "local_careduce_fusion", "canonicalize", "inplace"
    )

    x = tensor("floatX", shape=(None, None, None), name="x")
    fused_fn = function([x], exp(x).sum(axis=axis), mode=run_mode)

    # After fusion, the whole graph collapses to a single node whose
    # scalar op is a `Composite`.
    (only_node,) = fused_fn.maker.fgraph.toposort()
    assert isinstance(only_node.op.scalar_op, aes.basic.Composite)

    rng = np.random.default_rng(2320)
    x_test = rng.random((4, 3, 2), dtype=config.floatX)

    expected = np.exp(x_test).sum(axis=axis)
    actual = fused_fn(x_test)

    assert actual.shape == expected.shape
    assert np.allclose(actual, expected)
1136+
1137+
@pytest.mark.xfail(reason="Not implemented")
@pytest.mark.parametrize("linker", ["cvm", "py"])
@pytest.mark.parametrize("axis", [None, 0, 1, (0, 1), (0, 1, 2)])
def test_CAReduce_multiple_inputs(self, linker, axis):
    """Make sure that `CAReduce` and `Elemwise` fusions work with multiple inputs."""
    run_mode = Mode(linker=linker)
    run_mode._optimizer = run_mode._optimizer.including(
        "local_careduce_fusion", "canonicalize", "inplace"
    )

    x = tensor("floatX", shape=(None, None, None), name="x")
    y = tensor("floatX", shape=(None, None, None), name="y")
    fused_fn = function([x, y], (x + y).sum(axis=axis), mode=run_mode)

    # After fusion, the whole graph collapses to a single node whose
    # scalar op is a `Composite`.
    (only_node,) = fused_fn.maker.fgraph.toposort()
    assert isinstance(only_node.op.scalar_op, aes.basic.Composite)

    rng = np.random.default_rng(2320)
    x_test = rng.random((4, 3, 2), dtype=config.floatX)
    y_test = rng.random((4, 3, 2), dtype=config.floatX)

    expected = (x_test + y_test).sum(axis=axis)
    actual = fused_fn(x_test, y_test)

    assert actual.shape == expected.shape
    assert np.allclose(actual, expected)
1166+
11081167

11091168
class TimesN(aes.basic.UnaryScalarOp):
11101169
"""

0 commit comments

Comments
 (0)