11 | 11 | from aesara.graph.basic import Apply, Constant, io_toposort
12 | 12 | from aesara.graph.features import ReplaceValidate
13 | 13 | from aesara.graph.op import compute_test_value, get_test_value
14 |    | -from aesara.graph.rewriting.basic import GraphRewriter, copy_stack_trace, node_rewriter
   | 14 | +from aesara.graph.rewriting.basic import (
   | 15 | +    GraphRewriter,
   | 16 | +    copy_stack_trace,
   | 17 | +    in2out,
   | 18 | +    node_rewriter,
   | 19 | +)
15 | 20 | from aesara.graph.rewriting.db import SequenceDB
16 | 21 | from aesara.graph.utils import InconsistencyError, MethodNotDefined, TestValueError
17 | 22 | from aesara.tensor.basic import MakeVector, alloc, cast, get_scalar_constant_value
18 |    | -from aesara.tensor.elemwise import DimShuffle, Elemwise
   | 23 | +from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise
19 | 24 | from aesara.tensor.exceptions import NotScalarConstantError
20 | 25 | from aesara.tensor.rewriting.basic import register_canonicalize, register_specialize
21 | 26 | from aesara.tensor.shape import shape_padleft
@@ -944,3 +949,54 @@ def local_useless_composite(fgraph, node):
944 | 949 |         c = aes.Composite(inputs=comp.inputs, outputs=new_outputs)
945 | 950 |         e = Elemwise(scalar_op=c)(*node.inputs, return_list=True)
946 | 951 |         return dict(zip([node.outputs[i] for i in idx], e))
    | 952 | +
    | 953 | +
    | 954 | +@node_rewriter([CAReduce])
    | 955 | +def local_careduce_fusion(fgraph, node):
    | 956 | +    """Fuse a `CAReduce` applied to an `Elemwise`."""
    | 957 | +
    | 958 | +    (car_input,) = node.inputs
    | 959 | +    elm_node = car_input.owner
    | 960 | +
    | 961 | +    if elm_node is None or not isinstance(elm_node.op, Elemwise):
    | 962 | +        return False
    | 963 | +
    | 964 | +    elm_inputs = elm_node.inputs
    | 965 | +
    | 966 | +    if len(elm_inputs) > 1:
    | 967 | +        # TODO: Implement the multiple inputs case
    | 968 | +        return False
    | 969 | +
    | 970 | +    # TODO: What about the `dtype`s and other properties in `CAReduceDtype` types?
    | 971 | +    car_axis = node.op.axis
    | 972 | +    car_scalar_op = node.op.scalar_op
    | 973 | +    elm_scalar_op = elm_node.op.scalar_op
    | 974 | +
    | 975 | +    scalar_elm_inputs = [
    | 976 | +        aes.get_scalar_type(inp.type.dtype).make_variable() for inp in elm_inputs
    | 977 | +    ]
    | 978 | +    elm_output = elm_scalar_op(*scalar_elm_inputs)
    | 979 | +    # This input represents the previous value in the `CAReduce` binary reduction
    | 980 | +    carried_car_input = elm_output.type()
    | 981 | +    scalar_fused_outputs = [car_scalar_op(elm_output, carried_car_input)]
    | 982 | +
    | 983 | +    fused_scalar_op = aes.Composite(
    | 984 | +        inputs=[carried_car_input] + scalar_elm_inputs, outputs=scalar_fused_outputs
    | 985 | +    )
    | 986 | +    # The fused `Op` needs to look and behave like a `BinaryScalarOp`
    | 987 | +    # TODO: Generate a new `type` and make this relationship official?
    | 988 | +    fused_scalar_op.identity = car_scalar_op.identity
    | 989 | +    fused_scalar_op.nin = 2
    | 990 | +    fused_scalar_op.nout = 1
    | 991 | +
    | 992 | +    new_car_op = CAReduce(fused_scalar_op, car_axis)
    | 993 | +
    | 994 | +    return [new_car_op(*elm_inputs)]
    | 995 | +
    | 996 | +
    | 997 | +compile.optdb.register(  # type: ignore
    | 998 | +    "local_careduce_fusion",
    | 999 | +    in2out(local_careduce_fusion),
    | 1000 | +    # "fusion",
    | 1001 | +    position=49,
    | 1002 | +)
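
For context, the sketch below is not part of the diff; it illustrates the kind of graph the new rewrite targets and how it would be exercised through the normal compilation pipeline once `local_careduce_fusion` is registered in `compile.optdb`. The variable names `x`, `y`, and `f` are illustrative only, and fusion is only expected for the single-input `Elemwise` case handled above.

```python
import aesara
import aesara.tensor as at

# A reduction applied directly to a single-input elementwise op:
# Elemwise{exp} followed by CAReduce{add} (i.e. `sum`).
x = at.vector("x")
y = at.exp(x).sum()

# Compiling with the default mode runs the rewrites registered in
# `compile.optdb`, which would include `local_careduce_fusion`.
f = aesara.function([x], y)

# Inspect the compiled graph; with the rewrite applied, `exp` and `sum`
# should show up as a single `CAReduce` whose scalar op is a `Composite`.
aesara.dprint(f)
```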