diff --git a/theano/tensor/opt.py b/theano/tensor/opt.py
index 565b12e923..97ddf04f69 100644
--- a/theano/tensor/opt.py
+++ b/theano/tensor/opt.py
@@ -7609,43 +7609,52 @@ def check_input(inputs):
     """
 
 
-# ###############
-# # Loop fusion #
-# ###############
-def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 32, maker=None):
-    """
-    We parametrize it to make it work for Elemwise and GpuElemwise op.
+def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None):
+    """Create a recursive function that fuses `Elemwise` `Op`s.
+
+    The basic idea is that we loop through an `Elemwise` node's inputs, find
+    other `Elemwise` nodes, determine the scalar input types for all of the
+    `Elemwise` `Op`s, construct a new scalar `Op` using the scalar input types
+    and each `Elemwise`'s scalar `Op`, and use the composite scalar `Op` in a
+    new "fused" `Elemwise`.
+
+    It's parameterized in order to work for `Elemwise` and `GpuElemwise` `Op`s.
 
     Parameters
     ----------
-    OP
-        GpuElemwise or Elemwise class (the one that we want to fuse)
-    max_input_fct
-        A function that returns the maximum number of inputs
-        that this elemwise can take (useful for GpuElemwise).
-        GPU kernel currently has a limit of 256 bytes for
-        the size of all parameters passed to it. As currently
-        we pass many information only by parameter, we must
-        limit how many ops we fuse together to avoid busting
-        that 256 limit.
+    op_class : type
+        `GpuElemwise` or `Elemwise` class (the one that we want to fuse)
+    max_input_fct : callable
+        A function that returns the maximum number of inputs that this `Elemwise`
+        can take (useful for `GpuElemwise`). The GPU kernel currently has a
+        limit of 256 bytes for the size of all parameters passed to it. As
+        currently we pass a lot of information only by parameter, we must limit how
+        many `Op`s we fuse together to avoid busting that 256 limit.
 
-        On the CPU we limit to 32 input variables
-        since that is the maximum numpy support.
+        On the CPU we limit to 32 input variables since that is the maximum
+        NumPy supports.
+
+    maker: callable
+        A function with the signature `(node, *args)` that constructs an
+        `op_class` instance (e.g. `op_class(*args)`).
 
     """
     if maker is None:
 
         def maker(node, scalar_op):
-            return OP(scalar_op)
+            return op_class(scalar_op)
 
     def local_fuse(node):
-        """
-        As part of specialization, we fuse two consecutive elemwise Ops of the
+        """Fuse `Elemwise` `Op`s in a node.
+
+
+        As part of specialization, we fuse two consecutive elemwise `Op`s of the
         same shape.
 
-        For mixed dtype, we let the Composite op do the cast. It lets the C
+        For mixed dtype, we let the `Composite` `Op` do the cast. It lets the C
         compiler do the cast.
-        The number of dimensions is validated at call time by theano itself.
+
+        The number of dimensions is validated at call time by Theano itself.
 
         """
         # META TODO: PUT THESE THINGS IN TRAC, NOT TODO NOTES!!
@@ -7672,12 +7681,13 @@ def local_fuse(node):
         # worthwhile if the summation axis doesn't line up with a
         # contiguous dimension)
 
-        if type(node.op) is not OP:
+        if type(node.op) is not op_class:
             return False
 
         if len(node.outputs) > 1:
-            # We don't support the fusion for node with multiple outputs.
+            # We don't support fusion for nodes with multiple outputs.
             return
+
         inputs = []  # inputs of the new Elemwise op.
         s_inputs = []  # inputs of the new scalar op used by the Composite.
         # Inputs of the new scalar op that represents the current node.
@@ -7710,7 +7720,7 @@ def local_fuse(node):
             # we still want to fusion. So we take the set.
             if (
                 i.owner
-                and isinstance(i.owner.op, OP)
+                and isinstance(i.owner.op, op_class)
                 and len(set([n for n, idx in i.clients])) == 1
                 and
                 # Do not merge elemwise that don't have the same
@@ -7736,9 +7746,11 @@ def local_fuse(node):
                                     tmp.tag.test_value = tv
                             except AttributeError:
                                 pass
+
                             tmp_s_input.append(tmp)
                             tmp_input.append(ii)
                             tmp_scalar.append(tmp_s_input[-1])
+
                     s_op = i.owner.op.scalar_op(*tmp_s_input, return_list=True)
 
                     # if the scalar_op don't have a c implementation,
@@ -7786,8 +7798,8 @@ def local_fuse(node):
                 s_inputs.extend(tmp_scalar)
                 s_g.extend(s_op)
             else:
-                # We must support the case where the same variable appear many
-                # time in the inputs
+                # We must support the case where the same variable appears many
+                # times within the inputs
                 if inputs.count(i) == node.inputs.count(i):
                     s = s_inputs[inputs.index(i)]
                 else:
@@ -7834,15 +7846,16 @@ def local_fuse(node):
         )
 
         # create the composite op.
-        C = scalar.Composite(s_inputs, s_new_out)
+        composite_op = scalar.Composite(s_inputs, s_new_out)
 
         # create the new node.
         # Do not call make_node to have test_value
-        n = maker(node, C)(*inputs).owner
-        assert len(n.outputs) == 1
-        assert node.outputs[0].dtype == n.outputs[0].dtype
+        new_node = maker(node, composite_op)(*inputs).owner
+
+        assert len(new_node.outputs) == 1
+        assert node.outputs[0].dtype == new_node.outputs[0].dtype
 
-        if len(n.inputs) > max_nb_input:
+        if len(new_node.inputs) > max_nb_input:
             _logger.warning(
                 "loop fusion failed because Op would exceed" " kernel argument limit."
             )
@@ -7851,16 +7864,15 @@ def local_fuse(node):
         # we fuse as many that we can at the same time to make debug mode faster
         # debug mode will be faster as it won't test all intermediate step.
         while True:
-            ret = local_fuse(n)
+            ret = local_fuse(new_node)
             if ret is not False and ret is not None:
-                # print n,ret
-                assert len(ret) == len(n.outputs)
+                assert len(ret) == len(new_node.outputs)
                 assert len(ret) == 1
-                n = ret[0].owner
+                new_node = ret[0].owner
            else:
                 break
 
-        return n.outputs
+        return new_node.outputs
 
     return local_fuse
 
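For reviewers, a minimal sketch (not part of the patch) of how the fusion built by local_elemwise_fusion_op can be observed from user code. It assumes a default FAST_RUN compilation, where the elemwise fusion pass runs during specialization, and uses only public Theano APIs; the exact debugprint output may vary.

import numpy as np

import theano
import theano.tensor as tt

x = tt.vector("x")
y = tt.vector("y")

# A chain of elemwise operations; during specialization the fusion pass
# should collapse them into a single Elemwise whose scalar_op is a Composite.
z = tt.exp(x + y) * tt.cos(x - y)

f = theano.function([x, y], z)

# If fusion fired, the compiled graph contains a node that prints as
# something like `Elemwise{Composite{...}}`.
theano.printing.debugprint(f)

a = np.ones(4, dtype=theano.config.floatX)
b = np.ones(4, dtype=theano.config.floatX)
print(f(a, b))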