Rewrite rank 0 elemwise ops and push scalar constants into elemwise #107
base: main
Conversation
Not sure about the Elemwise related changes. Indeed Theano rewriting is built all around tensor nodes.
I am very interested in the inlining of scalar Constants in Composite. I don't see why it would fail with the C backend ...
return [as_tensor_variable(out) for out in op.scalar_op.make_node(*scalars).outputs]


compile.optdb["specialize"].register(
"specialize" is optional; if Numba will fail without this rewrite, we should add a new non-optional rewrite phase at the end.
I think we should probably make it so that elemwise doesn't fail in numba either way.
To be honest, I'm having a hard time seeing how we could expect everything to work nicely if users pick and choose rewrites. I think that is a testing nightmare...
I'll have to rethink the phases a bit for sure though. Maybe this actually belongs in "uncanonicalize" or so? Or we could have a new phase "scalarize"...
I moved it to a new "scalarize" phase
if not isinstance(op, Elemwise):
    return False

if any(op.inplace_pattern):
It's fine to be inplace because constants are never inplaced. But to avoid having to deal with it, just register this rewrite before the inplace rewrites.
I was worried that maybe some downstream op is assuming that one of the inputs has in fact changed? It should be running before the inplace passes anyway though...
That shouldn't happen. Inplace rewrites are myopic: they only look at one node at a time. I have never seen a rewrite checking inplace patterns elsewhere.
Hm, what about examples like this?
import numpy as np
import pytensor.tensor as pt
import pytensor
x = pt.dvector("x")
y = 2 * x + 1
val = np.ones(3)
input = pytensor.In(x, update=y)
func = pytensor.function([input], [])
pytensor.dprint(func)
Elemwise{Composite{(1.0 + (2.0 * i0))}}[(0, 0)] [id A] 0
|x [id B]
Replacing the inplace Elemwise with a non-inplace Elemwise would be incorrect here.
Still not a problem because the rewrite is registered before the inplace pass, but still...
Why is the update supposed to create a problem? I mean, if you are worried about this rewrite ignoring inplacing, you would have to be worried about every other rewrite we have in the library. What is special about your push constants rewrite?
I'm not 100% sure there would be...
But it looks like we need this final Elemwise in the graph not for its output but only for its side effect of changing the first input. If we were to replace this node with an Elemwise without the inplace flag, but the same output, wouldn't the update break? But maybe there is a feature somewhere that prevents this?
    input for i, input in enumerate(inner_graph.inputs) if i not in push_idxs
]
return (
    Elemwise(scalar_op=aes.Composite(new_inputs, inner_graph.outputs))
Hmm... I am curious why it would fail. I can have a look at the generated C code.
Yeah, figuring out what exactly is going wrong here would be good I think.
Somehow I can't reproduce the segfaults anymore...
I'm getting compilation errors however.
I tested one simple example locally and it seems to work in the C-backend. Can you share the problematic example? The more I look at this PR the more it seems it shouldn't be made Numba specific!
@@ -380,6 +380,99 @@ def is_dimshuffle_useless(new_order, input):
    return is_useless


@node_rewriter([Elemwise])
def local_elemwise_lift_scalars(fgraph, node):
This rewrite is not really lifting scalars (they are sandwiched between the same inputs and outputs), so maybe call it "elemwise_to_scalar"?
Anyway, wouldn't it be easier to create a different Numba function when dispatching? Don't you still have the same problem with mixed rank 0 and non rank 0 inputs?
Yes, elemwise_to_scalar is better :-)
I could have a separate numba function to solve the immediate scalar issue in elemwise...
I'll write a bit more about my motivation for this rewrite below...
@ricardoV94 I expanded the description a bit to add a bit of the motivation. I'm curious to hear what you think :-)
I totally agree with specializing for scalar graphs. I would do it at a later rewrite phase like you suggested. That way we can keep coverage from our rewrites without duplicated work. If JAX and C show speedups we could also include that phase in those backends. We really need to start a benchmark suite to guide performance related changes!
There is no need for an Elemwise Op if all inputs have rank 0. And we don't need to use scalar constants as inputs of the Elemwise; they can be inputs for the scalar_op.
Force-pushed from d651753 to 011d3f6, then from 011d3f6 to b0c4462.
@register_scalarize
@node_rewriter([Sum])
def local_sum_of_makevector(fgraph, node):
    (array,) = node.inputs
    if not array.owner or not isinstance(array.owner.op, MakeVector):
        return False

    values = array.owner.inputs
    summed = aes.add(*values)
    return [as_tensor_variable(summed)]
This seems to touch on #59
Can we abstract the scalarize part from the "lift reduction operations towards the inputs", which is useful regardless of the backend? Even the scalarize seems useful in both backends. What was the problem with the C backend again?
@register_stabilize("cxx_only")
@register_canonicalize("cxx_only")
Let me push against this approach.
There are three scenarios for this rewrite:
- It's not very useful and should be reconsidered, regardless of backend
- It's useful in the context of a larger chain of rewrites, regardless of backend
- It's only useful in one specific backend.
I don't see a reason for scenario 3.
If you start making rewrites exclusive to the C backend you will forget about 2. But eventually you will want to make Numba the default backend, and you will want the old tests to pass. By then you will have made your task much more challenging, because you will have diverged the C and Numba backends, and the latter's test suite is way more myopic.
It's actually a blessing that Theano/Aesara had very extensive test suites and that it was difficult to break things unintentionally. Restricting rewrites to the old, well tested backend that we eventually want to replace with the new, poorly tested one is opting out of this safety net. In a sense you will just be kicking the can down the road. The decision about the rewrite will have to be made regardless, but by then the Numba rewrite passes may look so different (because they were developed with a much more forgiving test suite) that you cannot even reason about the two and make an informed choice.
In short I think we should be very very selective about the rewrites that are backend specific. For instance I think we should definitely investigate if the scalarize changes also make sense for the C and JAX backends.
I think #349 might be a good solution?
A bit more groundwork for #92, to remove some cases where Elemwise ops are not actually needed.
This adds two rewrites:
local_elemwise_lift_scalars
This rewrite is meant to remove Elemwise ops that do not actually vectorize anything because all inputs are rank 0 tensors, and replaces them by a TensorFromScalar(Composite). I put this into the "specialize" phase, because I think it interacts badly with stabilization rewrites (which come before the specialize phase). I'm not really sure I like the way this works yet. It seems that currently most rewrites assume that during canonicalization we always deal with elemwise ops instead of scalar ops, and many rewrites then only apply to those. So for instance if we avoid all elemwise ops and only use scalars, even basic rewrites are not applied at all: a scalar log(1 + x) is not rewritten to use log1p, because those rewrites only target the tensor log and add ops, not the scalar versions.
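A minimal sketch of what I mean (illustrative only; aes is the usual alias for the scalar module, as in the diff above):
import pytensor
import pytensor.scalar as aes
import pytensor.tensor as pt

# Tensor-level graph: the log1p stabilization rewrite can match this.
x = pt.dvector("x")
f = pytensor.function([x], pt.log(1 + x))
pytensor.dprint(f)  # the compiled graph should contain log1p

# The same expression built directly from scalar ops has no matching
# rewrite, since the pattern is registered for the tensor log/add only.
s = aes.float64("s")
scalar_out = aes.log(aes.add(aes.constant(1.0), s))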
So if we put this new rewrite in canonicalize (where I think it maybe should belong?), it would break a lot of other rewrites that assume that everything is wrapped in an Elemwise.
I wonder if we can change the pattern rewriter so that it figures this out automatically and also applies rewrites to matching scalar ops? I think this would give us quite a bit more flexibility, because then we could change things to scalar ops, which should usually compile and execute faster.
push_elemwise_constants
I made it so this rewrite only applies to Numba (see below), and it is applied after Elemwise fusion. We often have ops that use scalar constants in an Elemwise, but those constants are provided as rank 0 inputs to the Elemwise and then broadcast, when they could just be ScalarConstants inside the inner Composite op. So we rewrite such an Elemwise into one where the constants are pushed into the scalar Composite.
I ran into segfaults when I tried this with the C backend, so this only applies to Numba for now.
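A rough sketch of the intended transformation (the dprint forms in the comments are illustrative of the pattern, not verbatim output):
import pytensor
import pytensor.tensor as pt

x = pt.dvector("x")
f = pytensor.function([x], 2.0 * x + 1.0)
pytensor.dprint(f)
# Without the rewrite, fusion produces roughly
#   Elemwise{Composite{(i1 + (i2 * i0))}}(x, TensorConstant{1.0}, TensorConstant{2.0})
# with the constants passed in as rank 0 tensor inputs. With
# push_elemwise_constants they become scalar constants inside the Composite:
#   Elemwise{Composite{(1.0 + (2.0 * i0))}}(x)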
I also made some small changes to the rewrites: previously the elemwise fusion db was created in a different file than all the other basic rewrite dbs (mode.py), so I moved it next to the others to make things more consistent. I then also created a new rewrite phase post_fusion that is executed right after the elemwise_fusion rewrites and currently only contains push_elemwise_constants.
Update
I removed local_subtensor_merge from the canonicalize pass. This rewrite is supposed to simplify chained Subtensor ops, but I'm not sure we really should have it be so aggressive. There are cases where I'd say it makes things more complicated instead of simpler, which was especially apparent in combination with local_elemwise_lift_scalars: it would end up rewriting a chained slice whose result b is just x into a much more involved index expression (a stand-in example is sketched below). Maybe we should change this rewrite so that it only does something if it knows statically that start and end are non-negative? (And I guess this also suggests we could use a rewrite that teaches graphs that shapes are non-negative, or they should just return an unsigned int?)
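A made-up stand-in for the kind of graph I mean here (the original example is not reproduced exactly; names are illustrative):
import pytensor
import pytensor.tensor as pt

x = pt.dvector("x")
start = pt.iscalar("start")
stop = pt.iscalar("stop")

# Two chained Subtensor ops. local_subtensor_merge folds them into a single
# Subtensor whose start/stop expressions have to guard against negative and
# out-of-range values, which can make the graph noticeably more complicated
# than the original chain.
b = x[start:stop][1:]
f = pytensor.function([x, start, stop], b)
pytensor.dprint(f)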
Update
A bit more motivation for this:
I think the Numba backend at least can really profit from turning more things into scalar ops, both for compile time and run time. This clashes a bit with what Theano thinks of as the "canonical form", where pretty much everything is a tensor.
Maybe a nice compromise would be to leave the canonicalize and specialize phases exactly as they are, so that the tensor form stays the canonical way of representing everything, and all the rewrites in those phases keep working with that.
But at some later stage (not sure exactly when...) we could add a phase (possibly Numba specific) that tries to turn as much as possible into scalars. So elemwise ops, shape_i, sum and a bunch of others could be rewritten to return scalars.
The reason I think scalars are better (where possible) in Numba is that tensors produce a lot of code, which slows down compilation, adds lots of allocations, and makes the code in general much less transparent to LLVM, which I think leads to lots of missed optimizations.
I benchmarked a small radon model a bit, and in that model we spend about 5% of the time in allocation code. That might not sound like much, but I think it is way too much for comfort, given that most of the cost of those allocations (missed optimizations, cache misses, ref counting and so on) will be hidden.
After the rewrites here I see a lot of code in logp graphs where small vectors are built with MakeVector only to be reduced right away. There is quite a bit of potential to rewrite this further: a Sum of a MakeVector, for instance, should just be the sum of the elements. Those sums can easily be represented as a scalar add node, so that we never have to allocate those tensors.
So why do I think that rank 0 tensors are something to avoid in Numba? Take this addition for instance:
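Here is a hedged stand-in for the comparison (not the exact functions I benchmarked; the second function just makes the hidden output allocation explicit):
import numba
import numpy as np


@numba.njit
def add_scalars(a, b):
    # Scalar case: LLVM sees a single floating point add.
    return a + b


@numba.njit
def add_with_allocation(a, b):
    # Stand-in for the rank 0 tensor case: the result has to live in a freshly
    # allocated array, which goes through Numba's runtime (NRT).
    out = np.empty(1)
    out[0] = a + b
    return out


add_scalars(1.0, 2.0)
add_with_allocation(1.0, 2.0)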
The scalar version is compiled down to a single fadd instruction that is really easy for LLVM to reason about. In contrast, the rank 0 tensor version (after optimization) contains a call to @NRT_MemInfo_alloc_aligned(i64 8, i32 32), which allocates memory and sets up refcounting. LLVM doesn't know what this function does, so its presence prevents a lot of optimizations. For instance, if we just call np.asarray but never use the result in any way, it still can't optimize that call away.
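Again a hedged stand-in for the kind of snippet I mean:
import numba
import numpy as np


@numba.njit
def unused_asarray(x):
    # The 0-d array is never used, but since the allocation happens behind an
    # opaque NRT call, LLVM cannot prove it is dead and remove it.
    np.asarray(x)
    return x + 1.0


unused_asarray(3.0)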