Add a pass to flag arrays only differing in tags #420
Where might be a good place to insert this pass? (I'm not very familiar with the overall structure of pytato yet.)
I think it would come down to adding a function in
Here's one way to do it:

```
(py311_env) $ cat remove_tags_and_merge.py
import pytato as pt
import numpy as np
from pytools.tag import Tag


def remove_tag_t(expr, tag_t):
    # Drop every tag of type tag_t from each array node.
    def _rec_remove_tag_t(expr):
        if isinstance(expr, pt.Array):
            if tags_to_remove := expr.tags_of_type(tag_t):
                return expr.without_tags(tags_to_remove,
                                         verify_existence=False)
            else:
                return expr
        else:
            return expr

    expr = pt.transform.map_and_copy(expr, _rec_remove_tag_t)
    return pt.transform.BranchMorpher()(expr)


x = pt.make_placeholder("x", (10, 4), np.float64)
y = pt.make_placeholder("y", (10, 4), np.float64)
tmp = x + y
tmp1 = tmp.tagged(pt.tags.ImplStored())
out = 2*tmp + 3*tmp1
print(pt.analysis.get_num_nodes(out))
print(pt.analysis.get_num_nodes(remove_tag_t(out, tag_t=Tag)))
(py311_env) $ python remove_tags_and_merge.py
8
7
```
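For readers without pytato at hand, the effect of the script above can be mimicked in a toy model. This is only a sketch under assumptions: `Node`, `count_nodes`, and `strip_tags` are hypothetical stand-ins rather than pytato API, equality is structural and includes tags (as the discussion here assumes for pytato arrays), and the counts are not comparable to what `pt.analysis.get_num_nodes` reports.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class Node:
    # Hypothetical toy expression node. Equality and hashing are
    # structural and include the tags, so a tag-only difference
    # makes two otherwise identical nodes distinct.
    op: str
    children: tuple = ()
    tags: frozenset = frozenset()


def count_nodes(root):
    """Count distinct nodes (by equality) reachable from root."""
    seen = set()
    stack = [root]
    while stack:
        node = stack.pop()
        if node not in seen:
            seen.add(node)
            stack.extend(node.children)
    return len(seen)


def strip_tags(root, memo=None):
    """Return a copy of the DAG with all tags removed (memoized)."""
    if memo is None:
        memo = {}
    if root in memo:
        return memo[root]
    result = replace(root,
                     children=tuple(strip_tags(c, memo) for c in root.children),
                     tags=frozenset())
    memo[root] = result
    return result


x = Node("x")
y = Node("y")
tmp = Node("add", (x, y))
tmp1 = replace(tmp, tags=frozenset({"stored"}))  # same as tmp, plus a tag
out = Node("add", (Node("mul2", (tmp,)), Node("mul3", (tmp1,))))

print(count_nodes(out))              # 7
print(count_nodes(strip_tags(out)))  # 6: tmp and tmp1 merge
```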
This is true, but if it's just one node that differs in its tags, then something else must be wrong here: the subexpressions of the diverging nodes would still be shared, so the relative difference in runtime/compile time should have been insignificant.
Are you sure? Wouldn't the dependent nodes necessarily also compare non-equal?
Thanks for providing that! It's quick, but it has a few downsides: it performs quite a few traversals, and it doesn't explicitly identify the offending nodes.
Ah, fair. I was only thinking of the predecessors, not the successors. Thanks for the correction!
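The point about dependents can be seen in a toy structural-equality model (a hypothetical frozen-dataclass `Node`, not pytato's array classes). The two variants share their predecessors, but every successor built on top of them inherits the difference, so the duplication spans the whole downstream subgraph.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class Node:
    # Hypothetical toy node; tags participate in equality, as assumed
    # for pytato arrays in this discussion.
    op: str
    children: tuple = ()
    tags: frozenset = frozenset()


x = Node("x")
y = Node("y")
tmp = Node("add", (x, y))
tmp1 = replace(tmp, tags=frozenset({"stored"}))

# The same computation applied to each variant.
f_tmp = Node("exp", (Node("mul2", (tmp,)),))
f_tmp1 = Node("exp", (Node("mul2", (tmp1,)),))

print(tmp.children == tmp1.children)  # True: predecessors are shared
print(f_tmp == f_tmp1)                # False: successors are duplicated
```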
Yep, it's a starting point. However, extending it to cover the functionality you point out shouldn't take more than another 50 lines, I think :).
FWIW, this is more in line with what you suggested:

```python
import pytato as pt
import numpy as np
from typing import Dict


class MyWalkMapper(pt.transform.CachedWalkMapper):
    def __init__(self):
        super().__init__()
        # Maps each metadata-stripped array to the first array seen with it.
        self.stripped_ary_to_ary: Dict[pt.Array, pt.Array] = {}

    def get_cache_key(self, expr):
        return id(expr)

    def post_visit(self, expr: pt.transform.ArrayOrNames):
        if isinstance(expr, pt.Array):
            from pytato.array import (_get_default_tags,
                                      _get_default_axes)
            # Copy of expr with tags and axes reset to their defaults.
            tagless_expr = expr.copy(
                tags=_get_default_tags(),
                axes=_get_default_axes(expr.ndim))
            try:
                # The parentheses matter: without them, the walrus would
                # bind colliding_expr to the boolean result of "!=".
                if (colliding_expr := self.stripped_ary_to_ary[tagless_expr]) != expr:
                    raise ValueError(f"Arrays '{colliding_expr}' and '{expr}'"
                                     " are semantically the same array except"
                                     " for the attached metadata => will lead"
                                     " to inefficient generated code.")
            except KeyError:
                self.stripped_ary_to_ary[tagless_expr] = expr


x = pt.make_placeholder("x", (10, 4), np.float64)
y = pt.make_placeholder("y", (10, 4), np.float64)
tmp = x + y
tmp1 = tmp.tagged(pt.tags.ImplStored())
out = 2*tmp + 3*tmp1
MyWalkMapper()(out)  # expected to raise: tmp and tmp1 differ only in tags
```
@majosm reported a situation where a large compile-time difference was observed depending on whether an array had a tag or not. This is plausible, as even just differing tags can lead to arrays not being viewed as equal and therefore failing to be merged in common subexpression elimination. This means that the value (and all its dependents, if both versions are used) is computed multiple times. If the pattern occurs repeatedly, this could lead to exponential growth of the DAG size.
All of this is almost always unintended, so we should at least warn about it (if not raise an error). What I have in mind is a pass that strips all tags and flags the case where that process produces multiple versions of the same array that compare equal after stripping.
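A minimal sketch of the proposed pass, on a hypothetical toy `Node` type rather than on pytato arrays (it ignores axes and any other metadata besides tags): visit each distinct node once, strip that node's own tags, and group the originals by their stripped form. Any group with more than one member is a set of arrays that differ only in tags. Because children are not stripped, dependents of a diverging node still compare unequal, so only the node where the divergence starts is reported. Returning all groups instead of raising on the first hit also identifies every offending node. `find_tag_only_duplicates` is an illustrative name, not an existing pytato function.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class Node:
    # Hypothetical toy node; tags participate in equality.
    op: str
    children: tuple = ()
    tags: frozenset = frozenset()


def find_tag_only_duplicates(root):
    """Return groups of distinct nodes that become equal once their
    own tags are removed. Each distinct node is visited once."""
    stripped_to_originals = {}
    visited = set()
    stack = [root]
    while stack:
        node = stack.pop()
        if node in visited:
            continue
        visited.add(node)
        stack.extend(node.children)
        # Strip only this node's tags; its children are left untouched.
        key = replace(node, tags=frozenset())
        stripped_to_originals.setdefault(key, []).append(node)
    return [group for group in stripped_to_originals.values()
            if len(group) > 1]


x = Node("x")
y = Node("y")
tmp = Node("add", (x, y))
tmp1 = replace(tmp, tags=frozenset({"stored"}))
out = Node("add", (Node("mul2", (tmp,)), Node("mul3", (tmp1,))))

for group in find_tag_only_duplicates(out):
    print("differ only in tags:", group)
```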