Rewrite rank 0 elemwise ops and push scalar constants into elemwise #107
base: main
@@ -18,8 +18,8 @@
    in2out,
    node_rewriter,
)
from pytensor.graph.rewriting.db import SequenceDB
from pytensor.graph.utils import InconsistencyError, MethodNotDefined, TestValueError
from pytensor.tensor import as_tensor_variable
from pytensor.tensor.basic import MakeVector, alloc, cast, get_scalar_constant_value
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
@@ -380,6 +380,91 @@ def is_dimshuffle_useless(new_order, input):
    return is_useless


@node_rewriter([Elemwise])
def elemwise_to_scalar(fgraph, node):
    op = node.op
    if not all(input.ndim == 0 for input in node.inputs):
        return False

    scalars = [aes.as_scalar(input) for input in node.inputs]

    # TODO Something like
    # copy_stack_trace(node.outputs[0], new_res)
    return [as_tensor_variable(out) for out in op.scalar_op.make_node(*scalars).outputs]


compile.optdb["scalarize"].register(
    "local_elemwise_to_scalar",
    elemwise_to_scalar,
    "fast_run",
    "fast_compile",
    "numba_only",
)
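For intuition, a minimal sketch of the kind of graph this rewrite targets (not part of the diff; variable names are illustrative):

```python
import pytensor.tensor as pt

x = pt.scalar("x")   # rank-0 tensor variables
y = pt.scalar("y")
z = pt.exp(x) + y    # built as Elemwise{exp} and Elemwise{add} nodes

# Since every input is 0-d, the rewrite can replace each Elemwise node with the
# underlying scalar op (aes.exp, aes.add) and wrap the result back into a
# tensor variable via as_tensor_variable, so a backend such as Numba can emit
# plain scalar code without any broadcasting or looping machinery.
```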

@node_rewriter([Elemwise])
def push_elemwise_constants(fgraph, node):
    """Push constant scalars from inputs to elemwise to inputs of the
    contained scalar op.
    """
    op = node.op
    if any(op.inplace_pattern):
        return False

    if not isinstance(node.op.scalar_op, aes.Composite):
        return False

    def is_constant_scalar(x):
        return isinstance(x, TensorConstant) and all(x.broadcastable)

    push_idxs = []
    push_values = []
    keep_values = []
    for i, input in enumerate(node.inputs):
        if is_constant_scalar(input):
            push_idxs.append(i)
            val = input.value
            push_values.append(aes.constant(val.item(), dtype=val.dtype))
        elif (
            input.owner
            and isinstance(input.owner.op, DimShuffle)
            and is_constant_scalar(input.owner.inputs[0])
        ):
            push_idxs.append(i)
            val = input.owner.inputs[0].value
            push_values.append(aes.constant(val.item(), dtype=val.dtype))
        else:
            keep_values.append(input)

    if not push_values:
        return False

    inner_graph = node.op.scalar_op.fgraph
    to_replace = [input for i, input in enumerate(inner_graph.inputs) if i in push_idxs]

    # Clone the inner graph, it might be used somewhere else
    inner_graph, mapping = inner_graph.clone_get_equiv()
    inner_graph.replace_all(
        (mapping[old], new) for old, new in zip(to_replace, push_values)
    )

    new_inputs = [
        input for i, input in enumerate(inner_graph.inputs) if i not in push_idxs
    ]
    return (
        Elemwise(scalar_op=aes.Composite(new_inputs, inner_graph.outputs))
        .make_node(*keep_values)
        .outputs
    )

Review thread on this line:

Hmm... I am curious why it would fail. I can have a look at the generated C code.

Yeah, figuring out what exactly is going wrong here would be good, I think.

Somehow I can't reproduce the segfaults anymore...

I tested one simple example locally and it seems to work in the C backend. Can you share the problematic example? The more I look at this PR, the more it seems it shouldn't be made Numba-specific!

compile.optdb["post_fusion"].register(
    "push_elemwise_constants",
    push_elemwise_constants,
    "numba_only",
)
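As an illustration (not part of the diff; names are only for the example), the kind of node this rewrite acts on after fusion:

```python
import pytensor.tensor as pt

v = pt.vector("v")
out = 2.0 * v + 1.0
# After the fusion pass this typically becomes a single fused Elemwise whose
# scalar_op is a Composite like (i0 * i1) + i2, and whose constant inputs
# (2.0 and 1.0) arrive as broadcasted TensorConstants or DimShuffles of
# constants. push_elemwise_constants replaces the matching inputs of the inner
# Composite graph with aes.constant values, so the rewritten Elemwise only
# takes `v` as a runtime input and the constants are baked into the scalar
# kernel.
```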

@register_canonicalize
@register_specialize
@node_rewriter([DimShuffle])
@@ -898,34 +983,13 @@ def print_profile(cls, stream, prof, level=0):
        print(blanc, " time_toposort", prof[7], file=stream)


if config.tensor__local_elemwise_fusion:
    # Must be after gpu(48.5) and before AddDestroyHandler(49.5)
    fuse_seqopt = SequenceDB()
    fuse_seqopt.register(
        "composite_elemwise_fusion",
        FusionOptimizer(local_elemwise_fusion),
        "fast_run",
        "fusion",
        position=1,
    )
    compile.optdb.register(  # type: ignore
        "elemwise_fusion",
        fuse_seqopt,
        "fast_run",
        "fusion",
        "local_elemwise_fusion",
        "FusionOptimizer",
        position=49,
    )
else:
    compile.optdb.register(  # type: ignore
        "elemwise_fusion",
        FusionOptimizer(local_elemwise_fusion),
        "fusion",
        "local_elemwise_fusion",
        "FusionOptimizer",
        position=49,
    )
compile.optdb["elemwise_fusion"].register(  # type: ignore
    "composite_elemwise_fusion",
    FusionOptimizer(local_elemwise_fusion),
    "fast_run",
    "fusion",
    position=1,
)


@register_canonicalize
@@ -8,6 +8,7 @@

import pytensor.scalar.basic as aes
import pytensor.scalar.math as aes_math
from pytensor import compile
from pytensor.graph.basic import Constant, Variable
from pytensor.graph.rewriting.basic import (
    NodeRewriter,
@@ -85,13 +86,14 @@
    encompasses_broadcastable,
    local_fill_sink,
    register_canonicalize,
    register_scalarize,
    register_specialize,
    register_specialize_device,
    register_stabilize,
    register_uncanonicalize,
    register_useless,
)
from pytensor.tensor.rewriting.elemwise import FusionOptimizer, fuse_seqopt
from pytensor.tensor.rewriting.elemwise import FusionOptimizer
from pytensor.tensor.shape import Shape, Shape_i
from pytensor.tensor.subtensor import Subtensor
from pytensor.tensor.type import (
@@ -1567,6 +1569,18 @@ def local_op_of_op(fgraph, node):
        return [combined(node_inps.owner.inputs[0])]


@register_scalarize
@node_rewriter([Sum])
def local_sum_of_makevector(fgraph, node):
    (array,) = node.inputs
    if not array.owner or not isinstance(array.owner.op, MakeVector):
        return False

    values = array.owner.inputs
    summed = aes.add(*values)
    return [as_tensor_variable(summed)]
Comment on lines +1572 to +1581: This seems to touch on #59. Can we abstract the scalarize part from the "lift reduction operations towards the inputs", which is useful regardless of the backend? Even the scalarize seems useful in both backends. What was the problem with the C backend again?
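For reference, a small sketch of the graph this rewrite matches (not from the diff; names are illustrative):

```python
import pytensor.tensor as pt

a = pt.scalar("a")
b = pt.scalar("b")
c = pt.scalar("c")

# Stacking 0-d variables produces a MakeVector node, so summing it yields
# Sum(MakeVector(a, b, c)).
total = pt.stack([a, b, c]).sum()

# local_sum_of_makevector replaces the reduction with a direct scalar
# addition, aes.add(a, b, c), wrapped back into a tensor variable.
```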

ALL_REDUCE = (
    [
        CAReduce,
@@ -2922,7 +2936,7 @@ def local_add_mul_fusion(fgraph, node):
        return [output]


fuse_seqopt.register(
compile.optdb["elemwise_fusion"].register(
    "local_add_mul_fusion",
    FusionOptimizer(local_add_mul_fusion),
    "fast_run",
@@ -469,8 +469,8 @@ def local_subtensor_lift(fgraph, node):
    return [rbcast_subt_x]


@register_canonicalize
@register_specialize
@register_stabilize("cxx_only")
@register_canonicalize("cxx_only")
Comment on lines +472 to +473: Let me push against this approach. There are three scenarios for this rewrite: […] I don't see the reason for 3. If you start making rewrites exclusive to the C backend you will forget about 2. But eventually you will want to make Numba the default backend, and you will want the old tests to pass. You will then have made your task much more challenging, because you diverged the C and Numba backends, and the latter's test suite is way more myopic. It's actually a blessing that Theano/Aesara had very extensive test suites and it was difficult to break things unintentionally. But restricting rewrites to the old, well-tested backend that we want to eventually replace with the new, poorly tested one is opting out of this safety net. In a sense you will just be kicking the can down the road: the decision about the rewrite will have to be made regardless, but by then the Numba rewrite passes may look so different (because they were developed against a much more forgiving test suite) that you cannot even reason about the two and make an informed choice.

In short, I think we should be very, very selective about the rewrites that are backend specific. For instance, I think we should definitely investigate whether the scalarize changes also make sense for the C and JAX backends.
@node_rewriter([Subtensor])
def local_subtensor_merge(fgraph, node):
    """
Review discussion:

It's fine to be inplace because constants are never inplaced. But to avoid having to deal with it, just register this rewrite before the inplace rewrites.

I was worried that maybe some downstream op is assuming that one of the inputs has in fact changed? It should be running before the inplace passes anyway, though...

That shouldn't happen. Inplace rewrites are myopic; they only look at one node at a time. I have never seen a rewrite checking inplace patterns elsewhere.

Hm, what about examples like this? Replacing the inplace Elemwise with a non-inplace Elemwise would be incorrect here. Still not a problem because the rewrite is registered before the inplace pass, but still...
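A sketch of the kind of update graph this exchange seems to be about (assuming the standard pytensor shared/function API; the variable names and values are made up):

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

state = pytensor.shared(np.zeros(3), name="state")
step = pt.dscalar("step")

# The update's final node is an Elemwise{add}; the inplace pass may turn it
# into an in-place add that writes its result directly into state's buffer.
f = pytensor.function([step], [], updates=[(state, state + step)])
f(1.0)

# The question raised above is whether swapping such an inplace node for a
# plain Elemwise with the same output could break the update; since
# push_elemwise_constants is registered before the inplace pass, that
# situation should not arise for this rewrite.
```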
Why is the update supposed to create a problem? I mean, if you are worried about this rewrite ignoring inplacing, you would have to be worried about every other rewrite we have in the library. What is special about your push-constants rewrite?

I'm not 100% sure there would be... But it looks like we need this final Elemwise in the graph not for its output but only for its side effect of changing the first input. If we were to replace this node with an Elemwise without the inplace flag, but with the same output, wouldn't the update break? But maybe there is a feature somewhere that prevents this?