Fuse Elemwise graphs that have multiple outputs and clients #121

Merged · 13 commits · Feb 8, 2023
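At a high level, the diff below touches five files: a typo fix in the rewrite database, debugprint support for Elemwise ops whose scalar_op has an inner graph, a Composite constructor that takes an explicit name and de-duplicates merged outputs, a **kwargs passthrough in the 2F1 gradient op's __call__, and corrected ufunc operand-count checks in Elemwise. A rough sketch of the kind of graph this PR targets (variable names are illustrative; the exact fused Composite printed depends on the rewrite pipeline):

import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
exp_x = pt.exp(x)  # `exp_x` has two clients below
out1 = exp_x + 1
out2 = exp_x * 2

# After this PR, the fusion rewrite can emit a single multi-output
# Elemwise(Composite{...}) computing both outputs, instead of bailing
# out because `exp_x` has multiple clients.
fn = pytensor.function([x], [out1, out2])
pytensor.dprint(fn)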
2 changes: 1 addition & 1 deletion pytensor/graph/rewriting/db.py
@@ -427,7 +427,7 @@ def query(
             position_cutoff = tags[0].position_cutoff

             # The RewriteDatabaseQuery instance might contain extra rewrites which need
-            # to be added the the sequence of rewrites (don't alter the
+            # to be added to the sequence of rewrites (don't alter the
             # original dictionary)
             if len(tags[0].extra_rewrites) > 0:
                 position_dict = position_dict.copy()
24 changes: 17 additions & 7 deletions pytensor/printing.py
@@ -312,7 +312,11 @@ def debugprint(
         ):

             if hasattr(var.owner, "op"):
-                if isinstance(var.owner.op, HasInnerGraph) and var not in inner_graph_vars:
+                if (
+                    isinstance(var.owner.op, HasInnerGraph)
+                    or hasattr(var.owner.op, "scalar_op")
+                    and isinstance(var.owner.op.scalar_op, HasInnerGraph)
+                ) and var not in inner_graph_vars:
                     inner_graph_vars.append(var)
                     if print_op_info:
                         op_information.update(op_debug_information(var.owner.op, var.owner))
@@ -355,8 +359,12 @@ def debugprint(
             inner_inputs = inner_fn.maker.fgraph.inputs
             inner_outputs = inner_fn.maker.fgraph.outputs
         else:
-            inner_inputs = ig_var.owner.op.inner_inputs
-            inner_outputs = ig_var.owner.op.inner_outputs
+            if hasattr(ig_var.owner.op, "scalar_op"):
+                inner_inputs = ig_var.owner.op.scalar_op.inner_inputs
+                inner_outputs = ig_var.owner.op.scalar_op.inner_outputs
+            else:
+                inner_inputs = ig_var.owner.op.inner_inputs
+                inner_outputs = ig_var.owner.op.inner_outputs

         outer_inputs = ig_var.owner.inputs

@@ -422,8 +430,9 @@ def debugprint(

             if (
                 isinstance(getattr(out.owner, "op", None), HasInnerGraph)
-                and out not in inner_graph_vars
-            ):
+                or hasattr(getattr(out.owner, "op", None), "scalar_op")
+                and isinstance(out.owner.op.scalar_op, HasInnerGraph)
+            ) and out not in inner_graph_vars:
                 inner_graph_vars.append(out)

             _debugprint(
@@ -664,8 +673,9 @@ def get_id_str(
         if hasattr(in_var, "owner") and hasattr(in_var.owner, "op"):
             if (
                 isinstance(in_var.owner.op, HasInnerGraph)
-                and in_var not in inner_graph_ops
-            ):
+                or hasattr(in_var.owner.op, "scalar_op")
+                and isinstance(in_var.owner.op.scalar_op, HasInnerGraph)
+            ) and in_var not in inner_graph_ops:
                 inner_graph_ops.append(in_var)

             _debugprint(
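A hedged sketch of what the printing change enables: previously only ops that are themselves HasInnerGraph (e.g. Scan, OpFromGraph) were expanded, while a fused Elemwise hides its Composite one level down in scalar_op. Assuming the default rewrites produce a fused graph, something like:

import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
fn = pytensor.function([x], pt.exp(x) + pt.log(x))

# The Elemwise op is not HasInnerGraph itself, but its scalar_op
# (a Composite) is; debugprint now also prints that inner graph.
pytensor.dprint(fn)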
42 changes: 17 additions & 25 deletions pytensor/scalar/basic.py
@@ -4000,7 +4000,8 @@ class Composite(ScalarOp, HasInnerGraph):

     init_param: Tuple[str, ...] = ("inputs", "outputs")

-    def __init__(self, inputs, outputs):
+    def __init__(self, inputs, outputs, name="Composite"):
+        self.name = name
         # We need to clone the graph as sometimes its nodes already
         # contain a reference to an fgraph. As we want the Composite
         # to be pickable, we can't have reference to fgraph.
@@ -4106,30 +4107,6 @@ def _perform(*inputs, outputs=[[None]]):
         self._py_perform_fn = fgraph_to_python(self.fgraph, python_convert)
         return self._py_perform_fn

-    @property
-    def name(self):
-        if hasattr(self, "_name"):
-            return self._name
-
-        # TODO FIXME: Just implement pretty printing for the `Op`; don't do
-        # this redundant, outside work in the `Op` itself.
-        for i, r in enumerate(self.fgraph.inputs):
-            r.name = f"i{int(i)}"
-        for i, r in enumerate(self.fgraph.outputs):
-            r.name = f"o{int(i)}"
-        io = set(self.fgraph.inputs + self.fgraph.outputs)
-        for i, r in enumerate(self.fgraph.variables):
-            if r not in io and len(self.fgraph.clients[r]) > 1:
-                r.name = f"t{int(i)}"
-        outputs_str = ", ".join([pprint(output) for output in self.fgraph.outputs])
-        rval = f"Composite{{{outputs_str}}}"
-        self._name = rval
-        return self._name
-
-    @name.setter
-    def name(self, name):
-        self._name = name
-
     @property
     def fgraph(self):
         if hasattr(self, "_fgraph"):
@@ -4146,6 +4123,21 @@ def fgraph(self):
                 "The fgraph to Composite must be exclusively"
                 " composed of ScalarOp instances."
             )
+
+        # Clone identical outputs that have been merged
+        if len(set(fgraph.outputs)) != len(self.outputs):
+            old_outputs = fgraph.outputs
+            new_outputs = []
+            for output in old_outputs:
+                if output not in new_outputs:
+                    new_outputs.append(output)
+                else:
+                    node = output.owner
+                    output_idx = node.outputs.index(output)
+                    new_output = node.clone().outputs[output_idx]
+                    new_outputs.append(new_output)
+            fgraph = FunctionGraph(fgraph.inputs, new_outputs, clone=False)
+
         self._fgraph = fgraph
         return self._fgraph
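Two behavioral changes here: Composite.__init__ now takes an explicit name in place of the lazily pretty-printed one, and the fgraph property clones outputs that FunctionGraph merged, so the fgraph always has one output per declared output. A minimal sketch, assuming the usual pytensor.scalar API:

import pytensor.scalar as ps

x = ps.float64("x")
y = x * 2

# Duplicate outputs are allowed: the fgraph property clones the merged
# output so fgraph.outputs keeps one entry per declared output.
comp = ps.Composite(inputs=[x], outputs=[y, y], name="TwiceTwice")
assert comp.name == "TwiceTwice"
assert len(comp.fgraph.outputs) == 2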
4 changes: 2 additions & 2 deletions pytensor/scalar/math.py
@@ -1638,9 +1638,9 @@ def compute_grad_2f1(a, b, c, z, wrt):

         return compute_grad_2f1(a, b, c, z, wrt=wrt)

-    def __call__(self, a, b, c, z, wrt):
+    def __call__(self, a, b, c, z, wrt, **kwargs):
         # This allows wrt to be a keyword argument
-        return super().__call__(a, b, c, z, wrt)
+        return super().__call__(a, b, c, z, wrt, **kwargs)

     def c_code(self, *args, **kwargs):
         raise NotImplementedError()
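The change is small but easy to miss: without forwarding **kwargs, generic Op.__call__ options such as return_list would be silently dropped whenever they are combined with the keyword-friendly signature. A self-contained toy illustration of the pattern (the PickWrt op is invented for this sketch and is not part of pytensor):

import pytensor.tensor as pt
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op

class PickWrt(Op):
    """Toy op that returns the input selected by `wrt`."""

    def make_node(self, a, b, wrt):
        a, b, wrt = map(pt.as_tensor_variable, (a, b, wrt))
        return Apply(self, [a, b, wrt], [a.type()])

    def perform(self, node, inputs, output_storage):
        a, b, wrt = inputs
        output_storage[0][0] = a if wrt == 0 else b

    def __call__(self, a, b, wrt, **kwargs):
        # Forward **kwargs so base-class options still work when
        # `wrt` is given as a keyword argument.
        return super().__call__(a, b, wrt, **kwargs)

outs = PickWrt()(pt.scalar("a"), pt.scalar("b"), wrt=0, return_list=True)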
22 changes: 6 additions & 16 deletions pytensor/tensor/elemwise.py
@@ -652,10 +652,10 @@ def transform(r):

     def prepare_node(self, node, storage_map, compute_map, impl):
         # Postpone the ufunc building to the last minutes due to:
-        # - NumPy ufunc support only up to 31 inputs.
+        # - NumPy ufunc support only up to 32 operands (inputs and outputs)
         #   But our c code support more.
         # - nfunc is reused for scipy and scipy is optional
-        if len(node.inputs) > 32 and self.ufunc and impl == "py":
+        if (len(node.inputs) + len(node.outputs)) > 32 and impl == "py":
             impl = "c"

         if getattr(self, "nfunc_spec", None) and impl != "c":
@@ -677,7 +677,7 @@ def prepare_node(self, node, storage_map, compute_map, impl):
                 self.nfunc = module

         if (
-            len(node.inputs) < 32
+            (len(node.inputs) + len(node.outputs)) <= 32
             and (self.nfunc is None or self.scalar_op.nin != len(node.inputs))
             and self.ufunc is None
             and impl == "py"
@@ -727,28 +727,18 @@ def prepare_node(self, node, storage_map, compute_map, impl):
             self.scalar_op.prepare_node(node.tag.fake_node, None, None, impl)

     def perform(self, node, inputs, output_storage):
-        if len(node.inputs) >= 32:
+        if (len(node.inputs) + len(node.outputs)) > 32:
             # Some versions of NumPy will segfault, other will raise a
-            # ValueError, if the number of inputs to a ufunc is 32 or more.
+            # ValueError, if the number of operands in an ufunc is more than 32.
             # In that case, the C version should be used, or Elemwise fusion
             # should be disabled.
+            # FIXME: This no longer calls the C implementation!
             super().perform(node, inputs, output_storage)

         for d, dim_shapes in enumerate(zip(*(i.shape for i in inputs))):
             if len(set(dim_shapes) - {1}) > 1:
                 raise ValueError(f"Shapes on dimension {d} do not match: {dim_shapes}")

-        # Determine the shape of outputs
-        out_shape = []
-        for values in zip(*[input.shape for input in inputs]):
-            if any(v == 0 for v in values):
-                # All non-broadcasted dimensions should be zero
-                assert max(values) <= 1
-                out_shape.append(0)
-            else:
-                out_shape.append(max(values))
-        out_shape = tuple(out_shape)
-
         ufunc_args = inputs
         ufunc_kwargs = {}
         # We supported in the past calling manually op.perform.
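The corrected threshold matches NumPy's actual limit: NPY_MAXARGS caps a ufunc at 32 operands counting inputs and outputs together, not 32 inputs. A quick way to observe this (the exact message may differ across NumPy versions):

import numpy as np

# 31 inputs + 2 outputs = 33 operands, over NumPy's 32-operand cap.
try:
    np.frompyfunc(lambda *args: (args[0], args[1]), 31, 2)
except ValueError as exc:
    print(exc)  # e.g. "cannot construct a ufunc with more than 32 operands"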