@node_rewriter([CAReduce])
def local_careduce_fusion(fgraph, node):
    """Fuse a `CAReduce` applied to a single-input `Elemwise`.

    The `Elemwise`'s scalar `Op` and the `CAReduce`'s scalar `Op` are combined
    into one `Composite` scalar `Op`, which is then used as the scalar `Op` of
    a new `CAReduce` applied directly to the `Elemwise`'s input.

    Parameters
    ----------
    fgraph
        The graph being rewritten; only its ``clients`` mapping is consulted.
    node
        The `CAReduce` node to rewrite.

    Returns
    -------
    ``False`` when the rewrite does not apply, otherwise a one-element list
    holding the output of the fused `CAReduce`.

    """
    (car_input,) = node.inputs
    elm_node = car_input.owner

    if elm_node is None or not isinstance(elm_node.op, Elemwise):
        return False

    elm_inputs = elm_node.inputs
    elm_outputs = elm_node.outputs

    if len(elm_inputs) > 1 or len(elm_outputs) > 1:
        # TODO: Implement the multiple inputs case
        return False

    # The `Elemwise` result is needed elsewhere, so it can't be absorbed
    if len(fgraph.clients[elm_outputs[0]]) > 1:
        return False

    # Don't form the fusion when the target language is Python
    if get_target_language() == ("py",):
        return False

    elm_scalar_op = elm_node.op.scalar_op
    car_scalar_op = node.op.scalar_op

    # Both scalar `Op`s need C implementations; probe for them with
    # placeholder variable names.
    try:
        elm_scalar_op.c_code(
            elm_node,
            "test_presence_of_c_code",
            ["x"] * len(elm_inputs),
            ["z"] * len(elm_outputs),
            {"fail": "%(fail)s"},
        )

        # The reduction `Op` is applied to two scalars below (the carried value
        # and the `Elemwise` result), so it is probed with two input names.
        # A `c_code` implementation that unpacks two inputs would otherwise
        # fail with an exception that isn't handled here.
        car_scalar_op.c_code(
            node,
            "test_presence_of_c_code",
            ["x", "y"],
            ["z"] * len(node.outputs),
            {"fail": "%(fail)s"},
        )
    except (NotImplementedError, MethodNotDefined):
        return False

    car_axis = node.op.axis

    scalar_elm_inputs = [
        aes.get_scalar_type(inp.type.dtype).make_variable() for inp in elm_inputs
    ]
    elm_output = elm_scalar_op(*scalar_elm_inputs)
    # This input represents the previous value in the `CAReduce` binary reduction
    carried_car_input = elm_output.type()
    scalar_fused_outputs = [car_scalar_op(carried_car_input, elm_output)]

    fused_scalar_op = aes.Composite(
        inputs=[carried_car_input] + scalar_elm_inputs, outputs=scalar_fused_outputs
    )

    # The fused `Op` needs to look and behave like a `BinaryScalarOp`
    # TODO: Generate a new `type` and make this relationship official?
    fused_scalar_op.identity = car_scalar_op.identity
    fused_scalar_op.nin = 2
    fused_scalar_op.nout = 1

    new_car_op = CAReduce(fused_scalar_op, car_axis)

    return [new_car_op(*elm_inputs)]


compile.optdb.register(  # type: ignore
    "local_careduce_fusion",
    in2out(local_careduce_fusion),
    "fusion",
    position=49,
)
@pytest.mark.xfail(reason="Not implemented")
@pytest.mark.parametrize("linker", ["cvm", "py"])
@pytest.mark.parametrize("axis", [None, 0, 1, (0, 1), (0, 1, 2)])
def test_CAReduce_multiple_inputs(self, linker, axis):
    """Check that a `CAReduce` over a two-input `Elemwise` gets fused.

    Currently expected to fail (see the `xfail` marker).
    """
    # Explicitly enable the fusion rewrite for the requested linker
    mode = Mode(linker=linker)
    mode._optimizer = mode._optimizer.including(
        "local_careduce_fusion",
        "canonicalize",
        "inplace",
    )

    x = tensor("floatX", shape=(None, None, None), name="x")
    y = tensor("floatX", shape=(None, None, None), name="y")
    summed = (x + y).sum(axis=axis)

    compiled = function([x, y], summed, mode=mode)

    # The compiled graph should consist of exactly one node, and its scalar
    # `Op` should be a `Composite`
    (fused_node,) = compiled.maker.fgraph.toposort()
    assert isinstance(fused_node.op.scalar_op, aes.basic.Composite)

    # Compare against the equivalent NumPy computation
    rng = np.random.default_rng(2320)
    x_val = rng.random((4, 3, 2), dtype=config.floatX)
    y_val = rng.random((4, 3, 2), dtype=config.floatX)
    expected = (x_val + y_val).sum(axis=axis)

    actual = compiled(x_val, y_val)
    assert actual.shape == expected.shape
    assert np.allclose(actual, expected)