Rewrite rank 0 elemwise ops and push scalar constants into elemwise #107

Draft · wants to merge 3 commits into base: main
2 changes: 1 addition & 1 deletion pytensor/_version.py
@@ -92,7 +92,7 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=
env=env,
stdout=subprocess.PIPE,
stderr=(subprocess.PIPE if hide_stderr else None),
**popen_kwargs
**popen_kwargs,
)
break
except OSError:
58 changes: 53 additions & 5 deletions pytensor/compile/mode.py
@@ -250,11 +250,34 @@ def apply(self, fgraph):
# misc special cases for speed that break canonicalization
optdb.register("uncanonicalize", EquilibriumDB(), "fast_run", position=3)

# Turn tensor operations into scalar operations where possible.
# This is currently marked as numba-only, but this could be changed
# in the future.
optdb.register("scalarize", EquilibriumDB(), "numba_only", position=3.1)

# misc special cases for speed that are dependent on the device.
optdb.register(
"specialize_device", EquilibriumDB(), "fast_compile", "fast_run", position=48.6
) # must be after gpu stuff at 48.5

# Must be before add_destroy_handler
optdb.register(
"elemwise_fusion",
SequenceDB(),
"fast_run",
"fusion",
"local_elemwise_fusion",
position=48.7,
)

optdb.register(
"post_fusion",
EquilibriumDB(),
"fast_run",
"fast_compile",
position=48.8,
)

# especially constant merge
optdb.register("merge2", MergeOptimizer(), "fast_run", "merge", position=49)

@@ -441,19 +464,44 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
# FunctionMaker, the Mode will be taken from this dictionary using the
# string as the key
# Use VM_linker to allow lazy evaluation by default.
FAST_COMPILE = Mode(VMLinker(use_cloop=False, c_thunks=False), "fast_compile")
FAST_COMPILE = Mode(
VMLinker(use_cloop=False, c_thunks=False),
RewriteDatabaseQuery(
include=["fast_compile"],
exclude=["numba_only"],
),
)
if config.cxx:
FAST_RUN = Mode("cvm", "fast_run")
FAST_RUN = Mode(
"cvm",
RewriteDatabaseQuery(
include=["fast_run"],
exclude=["numba_only"],
),
)
else:
FAST_RUN = Mode("vm", "fast_run")
FAST_RUN = Mode(
"vm",
RewriteDatabaseQuery(
include=["fast_run"],
exclude=["numba_only"],
),
)

JAX = Mode(
JAXLinker(),
RewriteDatabaseQuery(include=["fast_run", "jax"], exclude=["cxx_only", "BlasOpt"]),
RewriteDatabaseQuery(
include=["fast_run", "jax"],
exclude=["cxx_only", "BlasOpt", "numba_only"],
),
)

NUMBA = Mode(
NumbaLinker(),
RewriteDatabaseQuery(include=["fast_run"], exclude=["cxx_only", "BlasOpt"]),
RewriteDatabaseQuery(
include=["fast_run", "numba_only"],
exclude=["cxx_only", "BlasOpt"],
),
)
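
To illustrate how the new "numba_only" tag is meant to gate these rewrites, here is a minimal sketch (assuming Numba is installed; the compiled graphs described in the comments are illustrative, not output from this PR):

import pytensor
import pytensor.tensor as pt
from pytensor.compile.mode import FAST_RUN, NUMBA

x = pt.dscalar("x")
y = pt.exp(x) + 1.0

# FAST_RUN excludes "numba_only", so the graph keeps its rank-0 Elemwise nodes.
f_c = pytensor.function([x], y, mode=FAST_RUN)

# NUMBA includes "numba_only", so the "scalarize" pass registered at position 3.1
# above is allowed to rewrite the rank-0 Elemwise ops into plain scalar ops.
f_numba = pytensor.function([x], y, mode=NUMBA)

pytensor.dprint(f_numba)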


51 changes: 37 additions & 14 deletions pytensor/tensor/rewriting/basic.py
@@ -186,6 +186,23 @@ def register(inner_rewriter: Union[RewriteDatabase, Rewriter]):
return node_rewriter


def register_scalarize(
node_rewriter: Union[RewriteDatabase, NodeRewriter, str], *tags: str, **kwargs
):
if isinstance(node_rewriter, str):

def register(inner_rewriter: Union[RewriteDatabase, Rewriter]):
return register_scalarize(inner_rewriter, node_rewriter, *tags, **kwargs)

return register
else:
name = kwargs.pop("name", None) or node_rewriter.__name__
compile.optdb["scalarize"].register(
name, node_rewriter, "fast_run", "fast_compile", *tags, **kwargs
)
return node_rewriter


def register_uncanonicalize(
node_rewriter: Union[RewriteDatabase, NodeRewriter, str], *tags: str, **kwargs
):
@@ -226,30 +243,36 @@ def register(inner_rewriter: Union[RewriteDatabase, Rewriter]):

@register_canonicalize
@register_specialize
@register_scalarize
@node_rewriter([TensorFromScalar])
def local_tensor_scalar_tensor(fgraph, node):
"""tensor_from_scalar(scalar_from_tensor(x)) -> x"""
if isinstance(node.op, TensorFromScalar):
s = node.inputs[0]
if s.owner and isinstance(s.owner.op, ScalarFromTensor):
t = s.owner.inputs[0]
s = node.inputs[0]
if s.owner and isinstance(s.owner.op, ScalarFromTensor):
t = s.owner.inputs[0]

# We don't need to copy over any stack traces here
return [t]
# We don't need to copy over any stack traces here
return [t]


@register_canonicalize
@register_specialize
@register_scalarize
@node_rewriter([ScalarFromTensor])
def local_scalar_tensor_scalar(fgraph, node):
"""scalar_from_tensor(tensor_from_scalar(x)) -> x"""
if isinstance(node.op, ScalarFromTensor):
t = node.inputs[0]
if t.owner and isinstance(t.owner.op, TensorFromScalar):
s = t.owner.inputs[0]

# We don't need to copy over any stack traces here
return [s]
"""scalar_from_tensor(tensor_from_scalar(x)) -> x

and scalar_from_tensor(TensorConstant(x)) -> x
"""
t = node.inputs[0]
if t.owner and isinstance(t.owner.op, TensorFromScalar):
s = t.owner.inputs[0]

# We don't need to copy over any stack traces here
return [s]
if isinstance(t, TensorConstant):
assert t.ndim == 0
return [aes.constant(t.value.item(), t.name, t.dtype)]
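
A small illustration of the new constant branch (a sketch; the rewritten graph described in the comments is an assumption, not output from the PR):

import pytensor.tensor as pt
from pytensor.tensor.basic import scalar_from_tensor

s = scalar_from_tensor(pt.constant(2.5))
# Before the rewrite: ScalarFromTensor(TensorConstant{2.5})
# After the rewrite:  a plain scalar constant (aes.constant(2.5)), so scalar-level
# constant folding no longer has to cross a rank-0 tensor boundary.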


@register_specialize("local_alloc_elemwise")
122 changes: 93 additions & 29 deletions pytensor/tensor/rewriting/elemwise.py
@@ -18,8 +18,8 @@
in2out,
node_rewriter,
)
from pytensor.graph.rewriting.db import SequenceDB
from pytensor.graph.utils import InconsistencyError, MethodNotDefined, TestValueError
from pytensor.tensor import as_tensor_variable
from pytensor.tensor.basic import MakeVector, alloc, cast, get_scalar_constant_value
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
@@ -380,6 +380,91 @@ def is_dimshuffle_useless(new_order, input):
return is_useless


@node_rewriter([Elemwise])
def elemwise_to_scalar(fgraph, node):
op = node.op
if not all(input.ndim == 0 for input in node.inputs):
return False

scalars = [aes.as_scalar(input) for input in node.inputs]

# TODO Something like
# copy_stack_trace(node.outputs[0], new_res)
return [as_tensor_variable(out) for out in op.scalar_op.make_node(*scalars).outputs]


compile.optdb["scalarize"].register(
"local_elemwise_to_scalar",
elemwise_to_scalar,
"fast_run",
"fast_compile",
"numba_only",
)
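
Roughly, this rewrite turns an Elemwise whose inputs are all 0-d into the underlying scalar ops. A sketch of the intent (the before/after graphs described in the comments are an assumption, not output from this PR):

import pytensor.tensor as pt

x = pt.scalar("x")
y = pt.scalar("y")
z = pt.exp(x) + y  # built as rank-0 Elemwise{exp} and Elemwise{add}

# Before scalarize: Elemwise{add}(Elemwise{exp}(x), y) on 0-d tensors.
# After elemwise_to_scalar (plus the TensorFromScalar/ScalarFromTensor cleanups):
# the same computation expressed with scalar ops (aes.exp, aes.add), converted
# back to a tensor only at the output boundary.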


@node_rewriter([Elemwise])
def push_elemwise_constants(fgraph, node):
"""Push constant scalars from inputs to elemwise to inputs of the
contained scalar op.
"""
op = node.op
if any(op.inplace_pattern):
Member (reviewer):
It's fine to be inplace because constants are never inplaced. But to not have to deal with it, just register this rewrite before the inplace rewrites.

aseyboldt (Member, Author), Dec 12, 2022:
I was worried that maybe some downstream op is assuming that one of the inputs has in fact changed? It should be running before the inplace passes anyway though...

Member (reviewer):
That shouldn't happen. Inplace rewrites are myopic: they only look at one node at a time. I have never seen a rewrite checking inplace patterns elsewhere.

aseyboldt (Member, Author):
Hm, what about examples like this?

import numpy as np
import pytensor.tensor as pt
import pytensor

x = pt.dvector("x")
y = 2 * x + 1

val = np.ones(3)
input = pytensor.In(x, update=y)
func = pytensor.function([input], [])
pytensor.dprint(func)

Elemwise{Composite{(1.0 + (2.0 * i0))}}[(0, 0)] [id A] 0
 |x [id B]

Replacing the inplace Elemwise with a non-inplace Elemwise would be incorrect here. Still not a problem, because the rewrite is registered before the inplace pass, but still...

ricardoV94 (Member), Dec 13, 2022:
Why is the update supposed to create a problem? I mean, if you are worried about this rewrite ignoring inplacing, you would have to be worried about every other rewrite we have in the library. What is special about your push-constants rewrite?

aseyboldt (Member, Author):
I'm not 100% sure there would be... But it looks like we need this final Elemwise in the graph not for its output but only for its side effect of changing the first input. If we were to replace this node with an Elemwise without the inplace flag, but with the same output, wouldn't the update break? But maybe there is a feature somewhere that prevents this?

return False

if not isinstance(node.op.scalar_op, aes.Composite):
return False

def is_constant_scalar(x):
return isinstance(x, TensorConstant) and all(x.broadcastable)

push_idxs = []
push_values = []
keep_values = []
for i, input in enumerate(node.inputs):
if is_constant_scalar(input):
push_idxs.append(i)
val = input.value
push_values.append(aes.constant(val.item(), dtype=val.dtype))
elif (
input.owner
and isinstance(input.owner.op, DimShuffle)
and is_constant_scalar(input.owner.inputs[0])
):
push_idxs.append(i)
val = input.owner.inputs[0].value
push_values.append(aes.constant(val.item(), dtype=val.dtype))
else:
keep_values.append(input)

if not push_values:
return False

inner_graph = node.op.scalar_op.fgraph
to_replace = [input for i, input in enumerate(inner_graph.inputs) if i in push_idxs]

# Clone the inner graph, it might be used somewhere else
inner_graph, mapping = inner_graph.clone_get_equiv()
inner_graph.replace_all(
(mapping[old], new) for old, new in zip(to_replace, push_values)
)

new_inputs = [
input for i, input in enumerate(inner_graph.inputs) if i not in push_idxs
]
return (
Elemwise(scalar_op=aes.Composite(new_inputs, inner_graph.outputs))
Member (reviewer):
Hmm... I am curious why it would fail. I can have a look at the generated C code.

aseyboldt (Member, Author):
Yeah, figuring out what exactly is going wrong here would be good, I think.

aseyboldt (Member, Author):
Somehow I can't reproduce the segfaults anymore... I'm getting compilation errors, however.

ricardoV94 (Member), Dec 13, 2022:
I tested one simple example locally and it seems to work in the C backend. Can you share the problematic example? The more I look at this PR, the more it seems it shouldn't be made Numba-specific!

.make_node(*keep_values)
.outputs
)


compile.optdb["post_fusion"].register(
"push_elemwise_constants",
push_elemwise_constants,
"numba_only",
)
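
As a rough before/after of what gets pushed (a sketch; the Composite signatures follow the dprint style used in the discussion above and are an assumption, not output from this PR):

import pytensor.tensor as pt

x = pt.dvector("x")
y = 2.0 * x + 1.0

# After elemwise fusion:  Elemwise{Composite{(i2 + (i1 * i0))}}(x, 2.0, 1.0)
# After this rewrite:     Elemwise{Composite{(1.0 + (2.0 * i0))}}(x)
# The broadcast scalar constants become constants inside the scalar Composite
# instead of separate tensor inputs to the Elemwise node.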


@register_canonicalize
@register_specialize
@node_rewriter([DimShuffle])
@@ -898,34 +983,13 @@ def print_profile(cls, stream, prof, level=0):
print(blanc, " time_toposort", prof[7], file=stream)


if config.tensor__local_elemwise_fusion:
# Must be after gpu(48.5) and before AddDestroyHandler(49.5)
fuse_seqopt = SequenceDB()
fuse_seqopt.register(
"composite_elemwise_fusion",
FusionOptimizer(local_elemwise_fusion),
"fast_run",
"fusion",
position=1,
)
compile.optdb.register( # type: ignore
"elemwise_fusion",
fuse_seqopt,
"fast_run",
"fusion",
"local_elemwise_fusion",
"FusionOptimizer",
position=49,
)
else:
compile.optdb.register( # type: ignore
"elemwise_fusion",
FusionOptimizer(local_elemwise_fusion),
"fusion",
"local_elemwise_fusion",
"FusionOptimizer",
position=49,
)
compile.optdb["elemwise_fusion"].register( # type: ignore
"composite_elemwise_fusion",
FusionOptimizer(local_elemwise_fusion),
"fast_run",
"fusion",
position=1,
)


@register_canonicalize
18 changes: 16 additions & 2 deletions pytensor/tensor/rewriting/math.py
@@ -8,6 +8,7 @@

import pytensor.scalar.basic as aes
import pytensor.scalar.math as aes_math
from pytensor import compile
from pytensor.graph.basic import Constant, Variable
from pytensor.graph.rewriting.basic import (
NodeRewriter,
@@ -85,13 +86,14 @@
encompasses_broadcastable,
local_fill_sink,
register_canonicalize,
register_scalarize,
register_specialize,
register_specialize_device,
register_stabilize,
register_uncanonicalize,
register_useless,
)
from pytensor.tensor.rewriting.elemwise import FusionOptimizer, fuse_seqopt
from pytensor.tensor.rewriting.elemwise import FusionOptimizer
from pytensor.tensor.shape import Shape, Shape_i
from pytensor.tensor.subtensor import Subtensor
from pytensor.tensor.type import (
@@ -1567,6 +1569,18 @@ def local_op_of_op(fgraph, node):
return [combined(node_inps.owner.inputs[0])]


@register_scalarize
@node_rewriter([Sum])
def local_sum_of_makevector(fgraph, node):
(array,) = node.inputs
if not array.owner or not isinstance(array.owner.op, MakeVector):
return False

values = array.owner.inputs
summed = aes.add(*values)
return [as_tensor_variable(summed)]
Comment on lines +1572 to +1581

ricardoV94 (Member), Dec 13, 2022:
This seems to touch on #59.

Can we abstract the scalarize part from the "lift reduction operations towards the inputs", which is useful regardless of the backend? Even the scalarize seems useful in both backends. What was the problem with the C backend again?
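
For reference, the pattern local_sum_of_makevector targets looks roughly like this (a sketch; the rewritten graph described in the comments is an assumption, not output from this PR):

import pytensor.tensor as pt
from pytensor.tensor.basic import make_vector

a, b, c = pt.scalars("a", "b", "c")
s = pt.sum(make_vector(a, b, c))

# Before: Sum(MakeVector(a, b, c)) builds a length-3 vector only to reduce it.
# After:  the scalar sum a + b + c, with no intermediate vector allocation.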



ALL_REDUCE = (
[
CAReduce,
@@ -2922,7 +2936,7 @@ def local_add_mul_fusion(fgraph, node):
return [output]


fuse_seqopt.register(
compile.optdb["elemwise_fusion"].register(
"local_add_mul_fusion",
FusionOptimizer(local_add_mul_fusion),
"fast_run",
4 changes: 2 additions & 2 deletions pytensor/tensor/rewriting/subtensor.py
@@ -469,8 +469,8 @@ def local_subtensor_lift(fgraph, node):
return [rbcast_subt_x]


@register_canonicalize
@register_specialize
@register_stabilize("cxx_only")
@register_canonicalize("cxx_only")
Comment on lines +472 to +473

ricardoV94 (Member), Dec 13, 2022:
Let me push back against this approach.

There are three scenarios for this rewrite:

  1. It's not very useful and should be reconsidered, regardless of backend.
  2. It's useful in the context of a larger chain of rewrites, regardless of backend.
  3. It's only useful in one specific backend.

I don't see the reason for 3.

If you start making rewrites exclusive to the C backend, you will forget about 2. But eventually you will want to make Numba the default backend, and you will want the old tests to pass. You will then have made your task much more challenging, because you will have diverged the C and Numba backends, and the latter's test suite is far more myopic.

It's actually a blessing that Theano/Aesara had very extensive test suites and that it was difficult to break things unintentionally. But restricting rewrites to the old, well-tested backend that we eventually want to replace with the new, poorly tested one is opting out of this safety net. In a sense you will just be kicking the can down the road. The decision about the rewrite will have to be made regardless, but by then the Numba rewrite passes may look so different (because they were developed in a much more forgiving test suite) that you cannot even reason about the two and make an informed choice.

In short, I think we should be very, very selective about the rewrites that are backend-specific. For instance, I think we should definitely investigate whether the scalarize changes also make sense for the C and JAX backends.

@node_rewriter([Subtensor])
def local_subtensor_merge(fgraph, node):
"""
7 changes: 1 addition & 6 deletions tests/link/numba/test_scan.py
@@ -10,7 +10,6 @@
from pytensor.scan.op import Scan
from pytensor.scan.utils import until
from pytensor.tensor import log, vector
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.random.utils import RandomStream
from tests import unittest_tools as utt
from tests.link.numba.test_basic import compare_numba_and_py
@@ -437,8 +436,4 @@ def test_inner_graph_optimized():
node for node in f.maker.fgraph.apply_nodes if isinstance(node.op, Scan)
]
inner_scan_nodes = scan_node.op.fgraph.apply_nodes
assert len(inner_scan_nodes) == 1
(inner_scan_node,) = scan_node.op.fgraph.apply_nodes
assert isinstance(inner_scan_node.op, Elemwise) and isinstance(
inner_scan_node.op.scalar_op, Log1p
)
assert any(isinstance(node.op, Log1p) for node in inner_scan_nodes)