
Improve variable names and comments/docstring in local_elemwise_fusion_op
brandonwillard committed Oct 17, 2020
1 parent 9df70fd commit 74ee82b
88 changes: 50 additions & 38 deletions theano/tensor/opt.py
@@ -7609,43 +7609,52 @@ def check_input(inputs):
"""


# ###############
# # Loop fusion #
# ###############
-def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 32, maker=None):
-"""
-We parametrize it to make it work for Elemwise and GpuElemwise op.
+def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None):
+"""Create a recursive function that fuses `Elemwise` `Op`s.
+The basic idea is that we loop through an `Elemwise` node's inputs, find
+other `Elemwise` nodes, determine the scalars input types for all of the
+`Elemwise` `Op`s, construct a new scalar `Op` using the scalar input types
+and each `Elemwise`'s scalar `Op`, and use the composite scalar `Op` in a
+new "fused" `Elemwise`.
+It's parameterized in order to work for `Elemwise` and `GpuElemwise` `Op`s.
Parameters
----------
-OP
-GpuElemwise or Elemwise class (the one that we want to fuse)
-max_input_fct
-A function that returns the maximum number of inputs
-that this elemwise can take (useful for GpuElemwise).
-GPU kernel currently has a limit of 256 bytes for
-the size of all parameters passed to it. As currently
-we pass many information only by parameter, we must
-limit how many ops we fuse together to avoid busting
-that 256 limit.
+op_class : type
+`GpuElemwise` or `Elemwise` class (the one that we want to fuse)
+max_input_fct : callable
+A function that returns the maximum number of inputs that this `Elemwise`
+can take (useful for `GpuElemwise`). The GPU kernel currently has a
+limit of 256 bytes for the size of all parameters passed to it. As
+currently we pass a lot of information only by parameter, we must limit how
+many `Op`s we fuse together to avoid busting that 256 limit.
-On the CPU we limit to 32 input variables
-since that is the maximum numpy support.
+On the CPU we limit to 32 input variables since that is the maximum
+NumPy support.
+maker: callable
+A function with the signature `(node, *args)` that constructs an
+`op_class` instance (e.g. `op_class(*args)`).
"""
if maker is None:

def maker(node, scalar_op):
-return OP(scalar_op)
+return op_class(scalar_op)

def local_fuse(node):
"""
As part of specialization, we fuse two consecutive elemwise Ops of the
"""Fuse `Elemwise` `Op`s in a node.
As part of specialization, we fuse two consecutive elemwise `Op`s of the
same shape.
-For mixed dtype, we let the Composite op do the cast. It lets the C
+For mixed dtype, we let the `Composite` `Op` do the cast. It lets the C
compiler do the cast.
-The number of dimensions is validated at call time by theano itself.
+The number of dimensions is validated at call time by Theano itself.
"""
# META TODO: PUT THESE THINGS IN TRAC, NOT TODO NOTES!!
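
To make the fusion described in these docstrings concrete, here is a minimal sketch (not part of this commit) of the kind of graph the rewrite targets: two chained `Elemwise` nodes that the optimizer can collapse into one `Elemwise` wrapping a scalar `Composite`. The exact name printed for the fused `Op` is an assumption about how such graphs are usually rendered.

    import theano
    import theano.tensor as tt

    x = tt.vector("x")
    y = tt.vector("y")
    z = tt.exp(x + y)  # two Elemwise nodes: Elemwise{add} feeding Elemwise{exp}

    # Compiling with the default fast_run mode runs the fusion pass, so the
    # optimized graph typically contains a single node along the lines of
    # Elemwise{Composite{exp((i0 + i1))}} instead of two separate Elemwise nodes.
    f = theano.function([x, y], z)
    theano.printing.debugprint(f)
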
@@ -7672,12 +7681,13 @@ def local_fuse(node):
# worthwhile if the summation axis doesn't line up with a
# contiguous dimension)

-if type(node.op) is not OP:
+if type(node.op) is not op_class:
return False

if len(node.outputs) > 1:
-# We don't support the fusion for node with multiple outputs.
+# We don't support fusion for nodes with multiple outputs.
return

inputs = [] # inputs of the new Elemwise op.
s_inputs = [] # inputs of the new scalar op used by the Composite.
# Inputs of the new scalar op that represents the current node.
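
The early exits above (`return False` and the bare `return`) follow the convention for Theano local optimizers: anything falsy means "leave the node alone", while a list of variables is taken as the replacement for the node's outputs. A hedged, stand-alone illustration of that contract (the import paths and the trivial condition are assumptions, not code from this file):

    from theano.gof.opt import local_optimizer
    from theano.tensor.elemwise import Elemwise

    @local_optimizer([Elemwise])
    def example_local_opt(node):
        # A local optimizer receives one Apply node and returns False or None
        # to signal "no change", or a list of variables that replace
        # node.outputs one-for-one.
        if len(node.outputs) > 1:
            return False
        return None
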
@@ -7710,7 +7720,7 @@ def local_fuse(node):
# we still want to fusion. So we take the set.
if (
i.owner
-and isinstance(i.owner.op, OP)
+and isinstance(i.owner.op, op_class)
and len(set([n for n, idx in i.clients])) == 1
and
# Do not merge elemwise that don't have the same
@@ -7736,9 +7746,11 @@ def local_fuse(node):
tmp.tag.test_value = tv
except AttributeError:
pass

tmp_s_input.append(tmp)
tmp_input.append(ii)
tmp_scalar.append(tmp_s_input[-1])

s_op = i.owner.op.scalar_op(*tmp_s_input, return_list=True)

# if the scalar_op don't have a c implementation,
@@ -7786,8 +7798,8 @@ def local_fuse(node):
s_inputs.extend(tmp_scalar)
s_g.extend(s_op)
else:
-# We must support the case where the same variable appear many
-# time in the inputs
+# We must support the case where the same variable appears many
+# times within the inputs
if inputs.count(i) == node.inputs.count(i):
s = s_inputs[inputs.index(i)]
else:
@@ -7834,15 +7846,16 @@ def local_fuse(node):
)

# create the composite op.
-C = scalar.Composite(s_inputs, s_new_out)
+composite_op = scalar.Composite(s_inputs, s_new_out)

# create the new node.
# Do not call make_node to have test_value
-n = maker(node, C)(*inputs).owner
-assert len(n.outputs) == 1
-assert node.outputs[0].dtype == n.outputs[0].dtype
+new_node = maker(node, composite_op)(*inputs).owner
+
+assert len(new_node.outputs) == 1
+assert node.outputs[0].dtype == new_node.outputs[0].dtype

-if len(n.inputs) > max_nb_input:
+if len(new_node.inputs) > max_nb_input:
_logger.warning(
"loop fusion failed because Op would exceed" " kernel argument limit."
)
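
The `composite_op` / `new_node` construction in this hunk boils down to the following pattern. This is a hedged, self-contained sketch of the underlying API (built by hand here rather than by `local_fuse`, so the variable names are illustrative only):

    import theano.scalar as ts
    import theano.tensor as tt
    from theano.tensor.elemwise import Elemwise

    # Scalar stand-ins for the tensor inputs, combined into one scalar graph.
    a = ts.float64("a")
    b = ts.float64("b")
    comp = ts.Composite([a, b], [ts.exp(a + b)])

    # The Composite becomes the scalar_op of a single fused Elemwise node,
    # which is what `maker(node, composite_op)(*inputs)` produces above.
    x = tt.vector("x")
    y = tt.vector("y")
    fused = Elemwise(comp)(x, y)  # one node computing exp(x + y) elementwise
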
@@ -7851,16 +7864,15 @@ def local_fuse(node):
# we fuse as many that we can at the same time to make debug mode faster
# debug mode will be faster as it won't test all intermediate step.
while True:
-ret = local_fuse(n)
+ret = local_fuse(new_node)
if ret is not False and ret is not None:
-# print n,ret
-assert len(ret) == len(n.outputs)
+assert len(ret) == len(new_node.outputs)
assert len(ret) == 1
-n = ret[0].owner
+new_node = ret[0].owner
else:
break

-return n.outputs
+return new_node.outputs

return local_fuse
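
For context, the factory above is consumed elsewhere in opt.py roughly as follows (the registration details are outside this diff, so treat this as a recalled sketch rather than verbatim source):

    from theano.tensor.elemwise import Elemwise

    # Build the CPU fusion local optimizer; a GpuElemwise variant would pass a
    # stricter max_input_fct to respect the 256-byte kernel-parameter limit
    # mentioned in the docstring.
    local_elemwise_fusion = local_elemwise_fusion_op(Elemwise)

The returned `local_fuse` closure is then wrapped in a graph-level fusion optimizer and registered so it runs as part of the fast_run optimizations.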
