From 0ac9864cac07029770d249857c228f60b3df83fc Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Wed, 11 Dec 2024 19:11:49 -0600 Subject: [PATCH 1/6] change propagation to properly ignore ignored axes --- pytato/transform/metadata.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index 130db8b7a..4eeddb030 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -716,11 +716,17 @@ def unify_axes_tags( equations_collector.equations ) - for tag, var in equations_collector.known_tag_to_var.items(): - if isinstance(tag, AxisIgnoredForPropagationTag): - continue + ignored_vars = set() + for (ary, ax), ax_var in equations_collector.axis_to_var.items(): + tags = ary.axes[ax].tags_of_type(AxisIgnoredForPropagationTag) + if tags: + ignored_vars.add(ax_var) + for tag in tags: + ignored_vars.add(equations_collector.known_tag_to_var[tag]) - reachable_nodes = get_reachable_nodes(propagation_graph, var) + for tag, var in equations_collector.known_tag_to_var.items(): + reachable_nodes = get_reachable_nodes(propagation_graph, var, + ignored_vars) for reachable_var in (reachable_nodes - known_tag_vars): axis_to_solved_tags.setdefault( equations_collector.axis_to_var.inverse[reachable_var], From 2b42195fee1d135b0f5b68a686afb8ac203c6bc9 Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Thu, 12 Dec 2024 22:08:05 -0600 Subject: [PATCH 2/6] bump pytools version requirement --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 665a83d81..b91a3568e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,7 @@ dependencies = [ "bidict", "immutabledict", "loopy>=2020.2", - "pytools>=2024.1.14", + "pytools>=2024.1.21", "pymbolic>=2024.2", "typing_extensions>=4", ] From db0eb201c6825638e0ba316ae1f13dcf8f5fc520 Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Thu, 12 Dec 2024 22:20:03 -0600 Subject: [PATCH 3/6] improve clarity of unify_axes_tags --- pytato/transform/metadata.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index 4eeddb030..8cd520d94 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -716,17 +716,19 @@ def unify_axes_tags( equations_collector.equations ) - ignored_vars = set() - for (ary, ax), ax_var in equations_collector.axis_to_var.items(): - tags = ary.axes[ax].tags_of_type(AxisIgnoredForPropagationTag) - if tags: - ignored_vars.add(ax_var) - for tag in tags: - ignored_vars.add(equations_collector.known_tag_to_var[tag]) + ignored_vars = set({ + tag_var for tag, tag_var in equations_collector.known_tag_to_var.items() + if isinstance(tag, AxisIgnoredForPropagationTag) + }) + + ignored_vars.update({ + ax_var for (ary, ax), ax_var in equations_collector.axis_to_var.items() + if ary.axes[ax].tags_of_type(AxisIgnoredForPropagationTag) + }) for tag, var in equations_collector.known_tag_to_var.items(): reachable_nodes = get_reachable_nodes(propagation_graph, var, - ignored_vars) + exclude_nodes=ignored_vars) for reachable_var in (reachable_nodes - known_tag_vars): axis_to_solved_tags.setdefault( equations_collector.axis_to_var.inverse[reachable_var], From 52f1b3ceea76f196edb0b26b74e1943be4684d86 Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Fri, 13 Dec 2024 11:25:15 -0600 Subject: [PATCH 4/6] add test --- test/test_pytato.py | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/test/test_pytato.py b/test/test_pytato.py index 38e2fda5e..a2c8739dc 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -1324,6 +1324,41 @@ def test_unify_axes_tags(): # }}} +def test_ignoring_axes_during_propagation(): + from pytools.tag import UniqueTag + from pytato.transform.metadata import AxisIgnoredForPropagationTag + + class ElementAxisTag(UniqueTag): + pass + + class DOFAxisTagX(UniqueTag): + pass + + class DOFAxisTagY(UniqueTag): + pass + + a = pt.make_placeholder("a", (4, 4)) + a = a.with_tagged_axis(0, AxisIgnoredForPropagationTag()) + a = a.with_tagged_axis(1, AxisIgnoredForPropagationTag()) + + u = pt.make_placeholder("u", (128, 4, 4)) + u = u.with_tagged_axis(0, ElementAxisTag()) + u = u.with_tagged_axis(1, DOFAxisTagX()) + u = u.with_tagged_axis(2, DOFAxisTagY()) + + u_x = pt.einsum("il,elj->eij", a, u) + u_y = pt.einsum("jl,eil->eij", a, u) + + expr = u_x + u_y + + unified = pt.unify_axes_tags(expr) + iax_to_tags = {i: unified.axes[i].tags for i in range(len(unified.axes))} + + assert iax_to_tags[0] == frozenset([ElementAxisTag()]) + assert iax_to_tags[1] == frozenset([DOFAxisTagX()]) + assert iax_to_tags[2] == frozenset([DOFAxisTagY()]) + + def test_rewrite_einsums_with_no_broadcasts(): a = pt.make_placeholder("a", (10, 4, 1)) b = pt.make_placeholder("b", (10, 1, 4)) From ddad374560d58050b5c63e4e2e8cca24870e80a0 Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Fri, 13 Dec 2024 11:28:16 -0600 Subject: [PATCH 5/6] fix ruff complaints --- test/test_pytato.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_pytato.py b/test/test_pytato.py index a2c8739dc..8783e84c6 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -1325,8 +1325,8 @@ def test_unify_axes_tags(): def test_ignoring_axes_during_propagation(): - from pytools.tag import UniqueTag from pytato.transform.metadata import AxisIgnoredForPropagationTag + from pytools.tag import UniqueTag class ElementAxisTag(UniqueTag): pass From 871a835c0c00870cf32ede6b1fa950294773dbe6 Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Fri, 13 Dec 2024 11:35:06 -0600 Subject: [PATCH 6/6] actually fix ruff complaints --- test/test_pytato.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_pytato.py b/test/test_pytato.py index 8783e84c6..45d333c32 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -1325,9 +1325,10 @@ def test_unify_axes_tags(): def test_ignoring_axes_during_propagation(): - from pytato.transform.metadata import AxisIgnoredForPropagationTag from pytools.tag import UniqueTag + from pytato.transform.metadata import AxisIgnoredForPropagationTag + class ElementAxisTag(UniqueTag): pass