Skip to content

Commit d3830c4

Browse files
Add a fusion rewrite for CAReduces with Elemwise inputs
1 parent 63788d6 commit d3830c4

File tree

3 files changed

+168
-3
lines changed

3 files changed

+168
-3
lines changed

aesara/scalar/basic.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4048,7 +4048,7 @@ def __init__(self, inputs, outputs):
40484048

40494049
@property
40504050
def fn(self):
4051-
return self._fn
4051+
return None
40524052

40534053
@property
40544054
def inner_inputs(self):

aesara/tensor/rewriting/elemwise.py

+87-2
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,21 @@
77
import aesara
88
import aesara.scalar.basic as aes
99
from aesara import compile
10+
from aesara.compile.mode import get_target_language
1011
from aesara.configdefaults import config
1112
from aesara.graph.basic import Apply, Constant, io_toposort
1213
from aesara.graph.features import ReplaceValidate
1314
from aesara.graph.op import compute_test_value, get_test_value
14-
from aesara.graph.rewriting.basic import GraphRewriter, copy_stack_trace, node_rewriter
15+
from aesara.graph.rewriting.basic import (
16+
GraphRewriter,
17+
copy_stack_trace,
18+
in2out,
19+
node_rewriter,
20+
)
1521
from aesara.graph.rewriting.db import SequenceDB
1622
from aesara.graph.utils import InconsistencyError, MethodNotDefined, TestValueError
1723
from aesara.tensor.basic import MakeVector, alloc, cast, get_scalar_constant_value
18-
from aesara.tensor.elemwise import DimShuffle, Elemwise
24+
from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise
1925
from aesara.tensor.exceptions import NotScalarConstantError
2026
from aesara.tensor.rewriting.basic import register_canonicalize, register_specialize
2127
from aesara.tensor.shape import shape_padleft
@@ -944,3 +950,82 @@ def local_useless_composite(fgraph, node):
944950
c = aes.Composite(inputs=comp.inputs, outputs=new_outputs)
945951
e = Elemwise(scalar_op=c)(*node.inputs, return_list=True)
946952
return dict(zip([node.outputs[i] for i in idx], e))
953+
954+
955+
@node_rewriter([CAReduce])
def local_careduce_fusion(fgraph, node):
    """Fuse a `CAReduce` applied to an `Elemwise`.

    The `Elemwise`'s scalar `Op` is absorbed into the `CAReduce`'s scalar
    reduction via a `Composite`, so the element-wise computation and the
    reduction happen in a single pass over the input.

    Parameters
    ----------
    fgraph
        The function graph being rewritten.
    node
        A `CAReduce` `Apply` node.

    Returns
    -------
    A one-element list containing the fused `CAReduce` output, or
    ``False`` when the rewrite does not apply.
    """
    # Don't form the fusion when the target language is Python: the fusion
    # relies on the C implementations below.  This check is cheap and
    # graph-independent, so perform it first.
    if get_target_language() == ("py",):
        return False

    (car_input,) = node.inputs
    elm_node = car_input.owner

    if elm_node is None or not isinstance(elm_node.op, Elemwise):
        return False

    elm_inputs = elm_node.inputs
    elm_outputs = elm_node.outputs

    if len(elm_inputs) > 1 or len(elm_outputs) > 1:
        # TODO: Implement the multiple inputs case
        return False

    # Only fuse when the `Elemwise` output feeds the `CAReduce` alone;
    # otherwise the `Elemwise` result is needed elsewhere and would have to
    # be computed anyway.
    if len(fgraph.clients[elm_outputs[0]]) > 1:
        return False

    elm_scalar_op = elm_node.op.scalar_op
    car_scalar_op = node.op.scalar_op

    # Both scalar `Op`s must provide C code for the fused `Op` to compile.
    try:
        elm_scalar_op.c_code(
            elm_node,
            "test_presence_of_c_code",
            ["x" for x in elm_inputs],
            ["z" for z in elm_outputs],
            {"fail": "%(fail)s"},
        )

        car_scalar_op.c_code(
            node,
            "test_presence_of_c_code",
            ["x" for x in node.inputs],
            ["z" for z in node.outputs],
            {"fail": "%(fail)s"},
        )
    except (NotImplementedError, MethodNotDefined):
        return False

    car_axis = node.op.axis

    scalar_elm_inputs = [
        aes.get_scalar_type(inp.type.dtype).make_variable() for inp in elm_inputs
    ]
    elm_output = elm_scalar_op(*scalar_elm_inputs)
    # This input represents the previous value in the `CAReduce` binary reduction
    carried_car_input = elm_output.type()
    scalar_fused_outputs = [car_scalar_op(carried_car_input, elm_output)]

    fused_scalar_op = aes.Composite(
        inputs=[carried_car_input] + scalar_elm_inputs, outputs=scalar_fused_outputs
    )

    # The fused `Op` needs to look and behave like a `BinaryScalarOp`
    # TODO: Generate a new `type` and make this relationship official?
    fused_scalar_op.identity = car_scalar_op.identity
    fused_scalar_op.nin = 2
    fused_scalar_op.nout = 1

    new_car_op = CAReduce(fused_scalar_op, car_axis)

    new_out = new_car_op(*elm_inputs)
    # Preserve debugging/trace information from the replaced output.
    copy_stack_trace(node.outputs[0], new_out)

    return [new_out]


compile.optdb.register(  # type: ignore
    "local_careduce_fusion",
    in2out(local_careduce_fusion),
    "fusion",
    position=49,
)

tests/tensor/rewriting/test_elemwise.py

+80
Original file line numberDiff line numberDiff line change
@@ -1113,6 +1113,86 @@ def test_test_values(self, test_value):
11131113
f.maker.fgraph.outputs[0].tag.test_value, np.c_[[2.0]]
11141114
)
11151115

1116+
@pytest.mark.parametrize("linker", ["cvm", "py"])
@pytest.mark.parametrize("axis", [None, 0, 1, (0, 1), (0, 1, 2)])
def test_CAReduce_single_input(self, linker, axis):
    """Check the `CAReduce`/`Elemwise` fusion rewrite on a single-input graph."""

    mode = Mode(linker=linker)
    mode._optimizer = mode._optimizer.including(
        "local_careduce_fusion",
        "canonicalize",
        "inplace",
    )

    x = tensor("floatX", shape=(None, None, None), name="x")
    out = exp(x).sum(axis=axis)

    fn = function([x], out, mode=mode)

    if linker != "py":
        # The whole graph should have collapsed into a single fused node.
        (fused_node,) = fn.maker.fgraph.toposort()
        assert isinstance(getattr(fused_node.op, "scalar_op"), aes.basic.Composite)

        rng = np.random.default_rng(2320)
        x_val = rng.random((4, 3, 2), dtype=config.floatX)

        expected = np.exp(x_val).sum(axis=axis)
        result = fn(x_val)
        assert result.shape == expected.shape
        assert np.allclose(result, expected)
    else:
        # No C target, so the fusion must not have been formed.
        composites = [
            n
            for n in fn.maker.fgraph.toposort()
            if hasattr(n.op, "scalar_op")
            and isinstance(n.op.scalar_op, aes.basic.Composite)
        ]
        assert not composites

    # `Elemwise`s with more than one client shouldn't be rewritten
    # NOTE(review): `exp(x)` is rebuilt here instead of reusing `exp_x`;
    # presumably the merge/canonicalize passes collapse the duplicates so
    # the surviving node has two clients — confirm.
    x = tensor("floatX", shape=(None, None, None), name="x")
    exp_x = exp(x)
    out = exp_x.sum(axis=axis) + exp(x)

    fn = function([x], out, mode=mode)
    composites = [
        n
        for n in fn.maker.fgraph.toposort()
        if hasattr(n.op, "scalar_op")
        and isinstance(n.op.scalar_op, aes.basic.Composite)
    ]
    assert not composites
1165+
1166+
@pytest.mark.xfail(reason="Not implemented")
@pytest.mark.parametrize("linker", ["cvm", "py"])
@pytest.mark.parametrize("axis", [None, 0, 1, (0, 1), (0, 1, 2)])
def test_CAReduce_multiple_inputs(self, linker, axis):
    """Check the `CAReduce`/`Elemwise` fusion rewrite on a multiple-input graph."""

    mode = Mode(linker=linker)
    mode._optimizer = mode._optimizer.including(
        "local_careduce_fusion",
        "canonicalize",
        "inplace",
    )

    x = tensor("floatX", shape=(None, None, None), name="x")
    y = tensor("floatX", shape=(None, None, None), name="y")
    out = (x + y).sum(axis=axis)

    fn = function([x, y], out, mode=mode)

    # The whole graph should have collapsed into a single fused node.
    (fused_node,) = fn.maker.fgraph.toposort()
    assert isinstance(getattr(fused_node.op, "scalar_op"), aes.basic.Composite)

    rng = np.random.default_rng(2320)
    x_val = rng.random((4, 3, 2), dtype=config.floatX)
    y_val = rng.random((4, 3, 2), dtype=config.floatX)

    expected = (x_val + y_val).sum(axis=axis)
    result = fn(x_val, y_val)
    assert result.shape == expected.shape
    assert np.allclose(result, expected)
1195+
11161196

11171197
class TimesN(aes.basic.UnaryScalarOp):
11181198
"""

0 commit comments

Comments
 (0)