Skip to content

Commit

Permalink
Properly handle ignored axes during tag propagation (#569)
Browse files Browse the repository at this point in the history
  • Loading branch information
a-alveyblanc authored Dec 13, 2024
1 parent 48e8d61 commit 67a087d
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 5 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ dependencies = [
"bidict",
"immutabledict",
"loopy>=2020.2",
"pytools>=2024.1.14",
"pytools>=2024.1.21",
"pymbolic>=2024.2",
"typing_extensions>=4",
]
Expand Down
16 changes: 12 additions & 4 deletions pytato/transform/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,11 +716,19 @@ def unify_axes_tags(
equations_collector.equations
)

for tag, var in equations_collector.known_tag_to_var.items():
if isinstance(tag, AxisIgnoredForPropagationTag):
continue
ignored_vars = set({
tag_var for tag, tag_var in equations_collector.known_tag_to_var.items()
if isinstance(tag, AxisIgnoredForPropagationTag)
})

ignored_vars.update({
ax_var for (ary, ax), ax_var in equations_collector.axis_to_var.items()
if ary.axes[ax].tags_of_type(AxisIgnoredForPropagationTag)
})

reachable_nodes = get_reachable_nodes(propagation_graph, var)
for tag, var in equations_collector.known_tag_to_var.items():
reachable_nodes = get_reachable_nodes(propagation_graph, var,
exclude_nodes=ignored_vars)
for reachable_var in (reachable_nodes - known_tag_vars):
axis_to_solved_tags.setdefault(
equations_collector.axis_to_var.inverse[reachable_var],
Expand Down
36 changes: 36 additions & 0 deletions test/test_pytato.py
Original file line number Diff line number Diff line change
Expand Up @@ -1324,6 +1324,42 @@ def test_unify_axes_tags():
# }}}


def test_ignoring_axes_during_propagation():
from pytools.tag import UniqueTag

from pytato.transform.metadata import AxisIgnoredForPropagationTag

class ElementAxisTag(UniqueTag):
pass

class DOFAxisTagX(UniqueTag):
pass

class DOFAxisTagY(UniqueTag):
pass

a = pt.make_placeholder("a", (4, 4))
a = a.with_tagged_axis(0, AxisIgnoredForPropagationTag())
a = a.with_tagged_axis(1, AxisIgnoredForPropagationTag())

u = pt.make_placeholder("u", (128, 4, 4))
u = u.with_tagged_axis(0, ElementAxisTag())
u = u.with_tagged_axis(1, DOFAxisTagX())
u = u.with_tagged_axis(2, DOFAxisTagY())

u_x = pt.einsum("il,elj->eij", a, u)
u_y = pt.einsum("jl,eil->eij", a, u)

expr = u_x + u_y

unified = pt.unify_axes_tags(expr)
iax_to_tags = {i: unified.axes[i].tags for i in range(len(unified.axes))}

assert iax_to_tags[0] == frozenset([ElementAxisTag()])
assert iax_to_tags[1] == frozenset([DOFAxisTagX()])
assert iax_to_tags[2] == frozenset([DOFAxisTagY()])


def test_rewrite_einsums_with_no_broadcasts():
a = pt.make_placeholder("a", (10, 4, 1))
b = pt.make_placeholder("b", (10, 1, 4))
Expand Down

0 comments on commit 67a087d

Please sign in to comment.