From 86a1305c89362b2c3d2fa12b8c5a4cd2465c1c36 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sat, 12 Apr 2025 17:30:36 -0700 Subject: [PATCH 1/7] [passes] Create topological sort pass --- .../ir/passes/common/topological_sort.py | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 onnxscript/ir/passes/common/topological_sort.py diff --git a/onnxscript/ir/passes/common/topological_sort.py b/onnxscript/ir/passes/common/topological_sort.py new file mode 100644 index 0000000000..f4b06ba801 --- /dev/null +++ b/onnxscript/ir/passes/common/topological_sort.py @@ -0,0 +1,33 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Pass for topologically sorting the graphs.""" + +from __future__ import annotations + +__all__ = [ + "TopologicalSortPass", +] + + +from onnxscript import ir + + +class TopologicalSortPass(ir.passes.InPlacePass): + """Topologically sort graphs and functions in a model.""" + + def call(self, model: ir.Model) -> ir.passes.PassResult: + nodes = list(model.graph) + model.graph.sort() + new_nodes = list(model.graph) + for function in model.functions.values(): + nodes.extend(function) + function.sort() + new_nodes.extend(function) + + # Compare node orders to determine if any changes were made + modified = False + for node, new_node in zip(nodes, new_nodes): + if node is not new_node: + modified = True + break + return ir.passes.PassResult(model=model, modified=modified) From a25a8c77501a6d95376724bdd965170b2bcc59e7 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sun, 13 Apr 2025 07:53:32 -0700 Subject: [PATCH 2/7] Add unit tests for TopologicalSortPass * **Test for modified=True**: Add a test to check if `modified` is True when the input model needs sorting. * **Test for modified=False**: Add a test to check if `modified` is False when the input model is already sorted. --- .../ir/passes/common/topological_sort_test.py | 42 +++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100644 onnxscript/ir/passes/common/topological_sort_test.py diff --git a/onnxscript/ir/passes/common/topological_sort_test.py b/onnxscript/ir/passes/common/topological_sort_test.py new file mode 100644 index 0000000000..530a061481 --- /dev/null +++ b/onnxscript/ir/passes/common/topological_sort_test.py @@ -0,0 +1,42 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Unit tests for the TopologicalSortPass.""" + +import unittest + +from onnxscript import ir +from onnxscript.ir.passes.common import topological_sort + + +class TopologicalSortPassTest(unittest.TestCase): + def setUp(self): + self.node_a = ir.Node("", "A", inputs=[], num_outputs=1, name="node_a") + self.node_b = ir.Node("", "B", inputs=[self.node_a.outputs[0]], num_outputs=1, name="node_b") + self.node_c = ir.Node("", "C", inputs=[self.node_b.outputs[0]], num_outputs=1, name="node_c") + + def test_topological_sort_modified_true(self): + graph = ir.Graph( + inputs=self.node_a.inputs, + outputs=self.node_c.outputs, + nodes=[self.node_c, self.node_b, self.node_a], # Unsorted nodes + name="test_graph", + ) + model = ir.Model(graph, ir_version=10) + pass_result = topological_sort.TopologicalSortPass(model) + self.assertTrue(pass_result.modified) + + def test_topological_sort_modified_false(self): + """Test that modified is False when the input model is already sorted.""" + sorted_graph = ir.Graph( + inputs=self.node_a.inputs, + outputs=self.node_c.outputs, + nodes=[self.node_a, self.node_b, self.node_c], # Sorted nodes + name="test_graph", + ) + sorted_model = ir.Model(sorted_graph, ir_version=10) + pass_result = topological_sort.TopologicalSortPass().call(sorted_model) + self.assertFalse(pass_result.modified) + + +if __name__ == "__main__": + unittest.main() From df9b5bf2f548b493ef26d4a9092b88043b0ae866 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sun, 13 Apr 2025 07:58:14 -0700 Subject: [PATCH 3/7] lint --- onnxscript/ir/passes/common/topological_sort_test.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/onnxscript/ir/passes/common/topological_sort_test.py b/onnxscript/ir/passes/common/topological_sort_test.py index 530a061481..2f30b1fb5d 100644 --- a/onnxscript/ir/passes/common/topological_sort_test.py +++ b/onnxscript/ir/passes/common/topological_sort_test.py @@ -10,9 +10,9 @@ class TopologicalSortPassTest(unittest.TestCase): def setUp(self): - self.node_a = ir.Node("", "A", inputs=[], num_outputs=1, name="node_a") - self.node_b = ir.Node("", "B", inputs=[self.node_a.outputs[0]], num_outputs=1, name="node_b") - self.node_c = ir.Node("", "C", inputs=[self.node_b.outputs[0]], num_outputs=1, name="node_c") + self.node_a = ir.node("A", inputs=[], name="node_a") + self.node_b = ir.node("B", inputs=[self.node_a.outputs[0]], name="node_b") + self.node_c = ir.node("C", inputs=[self.node_b.outputs[0]], name="node_c") def test_topological_sort_modified_true(self): graph = ir.Graph( @@ -22,7 +22,7 @@ def test_topological_sort_modified_true(self): name="test_graph", ) model = ir.Model(graph, ir_version=10) - pass_result = topological_sort.TopologicalSortPass(model) + pass_result = topological_sort.TopologicalSortPass()(model) self.assertTrue(pass_result.modified) def test_topological_sort_modified_false(self): @@ -34,7 +34,7 @@ def test_topological_sort_modified_false(self): name="test_graph", ) sorted_model = ir.Model(sorted_graph, ir_version=10) - pass_result = topological_sort.TopologicalSortPass().call(sorted_model) + pass_result = topological_sort.TopologicalSortPass()(sorted_model) self.assertFalse(pass_result.modified) From c104e76b437bd161520e556025e84174ce5661a5 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sun, 13 Apr 2025 08:01:05 -0700 Subject: [PATCH 4/7] test order --- onnxscript/ir/passes/common/topological_sort_test.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/onnxscript/ir/passes/common/topological_sort_test.py b/onnxscript/ir/passes/common/topological_sort_test.py index 2f30b1fb5d..753aa0dfeb 100644 --- a/onnxscript/ir/passes/common/topological_sort_test.py +++ b/onnxscript/ir/passes/common/topological_sort_test.py @@ -24,6 +24,10 @@ def test_topological_sort_modified_true(self): model = ir.Model(graph, ir_version=10) pass_result = topological_sort.TopologicalSortPass()(model) self.assertTrue(pass_result.modified) + self.assertEqual( + tuple(pass_result.model.graph), + (self.node_a, self.node_b, self.node_c), + ) def test_topological_sort_modified_false(self): """Test that modified is False when the input model is already sorted.""" @@ -36,6 +40,10 @@ def test_topological_sort_modified_false(self): sorted_model = ir.Model(sorted_graph, ir_version=10) pass_result = topological_sort.TopologicalSortPass()(sorted_model) self.assertFalse(pass_result.modified) + self.assertEqual( + tuple(pass_result.model.graph), + (self.node_a, self.node_b, self.node_c), + ) if __name__ == "__main__": From bb9a501cfbac7b41df1215b98b61ff62b117b136 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sun, 13 Apr 2025 08:01:48 -0700 Subject: [PATCH 5/7] naming --- onnxscript/ir/passes/common/topological_sort_test.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/onnxscript/ir/passes/common/topological_sort_test.py b/onnxscript/ir/passes/common/topological_sort_test.py index 753aa0dfeb..07d078fe23 100644 --- a/onnxscript/ir/passes/common/topological_sort_test.py +++ b/onnxscript/ir/passes/common/topological_sort_test.py @@ -22,10 +22,10 @@ def test_topological_sort_modified_true(self): name="test_graph", ) model = ir.Model(graph, ir_version=10) - pass_result = topological_sort.TopologicalSortPass()(model) - self.assertTrue(pass_result.modified) + result = topological_sort.TopologicalSortPass()(model) + self.assertTrue(result.modified) self.assertEqual( - tuple(pass_result.model.graph), + tuple(result.model.graph), (self.node_a, self.node_b, self.node_c), ) @@ -38,10 +38,10 @@ def test_topological_sort_modified_false(self): name="test_graph", ) sorted_model = ir.Model(sorted_graph, ir_version=10) - pass_result = topological_sort.TopologicalSortPass()(sorted_model) - self.assertFalse(pass_result.modified) + result = topological_sort.TopologicalSortPass()(sorted_model) + self.assertFalse(result.modified) self.assertEqual( - tuple(pass_result.model.graph), + tuple(result.model.graph), (self.node_a, self.node_b, self.node_c), ) From 28044cb24adc04034ed377dacb10d7e8de44b548 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sun, 13 Apr 2025 08:02:52 -0700 Subject: [PATCH 6/7] test --- onnxscript/ir/passes/common/topological_sort_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxscript/ir/passes/common/topological_sort_test.py b/onnxscript/ir/passes/common/topological_sort_test.py index 07d078fe23..ca9d1377f0 100644 --- a/onnxscript/ir/passes/common/topological_sort_test.py +++ b/onnxscript/ir/passes/common/topological_sort_test.py @@ -11,8 +11,8 @@ class TopologicalSortPassTest(unittest.TestCase): def setUp(self): self.node_a = ir.node("A", inputs=[], name="node_a") - self.node_b = ir.node("B", inputs=[self.node_a.outputs[0]], name="node_b") - self.node_c = ir.node("C", inputs=[self.node_b.outputs[0]], name="node_c") + self.node_b = ir.node("B", inputs=self.node_a.outputs, name="node_b") + self.node_c = ir.node("C", inputs=self.node_b.outputs, name="node_c") def test_topological_sort_modified_true(self): graph = ir.Graph( From ea39dfc3c0cd6ec306954cd5d71273c1bad22a54 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sun, 13 Apr 2025 08:59:20 -0700 Subject: [PATCH 7/7] Rename variables for clarity in topological sort --- onnxscript/ir/passes/common/topological_sort.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/onnxscript/ir/passes/common/topological_sort.py b/onnxscript/ir/passes/common/topological_sort.py index f4b06ba801..9be183cf01 100644 --- a/onnxscript/ir/passes/common/topological_sort.py +++ b/onnxscript/ir/passes/common/topological_sort.py @@ -16,17 +16,17 @@ class TopologicalSortPass(ir.passes.InPlacePass): """Topologically sort graphs and functions in a model.""" def call(self, model: ir.Model) -> ir.passes.PassResult: - nodes = list(model.graph) + original_nodes = list(model.graph) model.graph.sort() - new_nodes = list(model.graph) + sorted_nodes = list(model.graph) for function in model.functions.values(): - nodes.extend(function) + original_nodes.extend(function) function.sort() - new_nodes.extend(function) + sorted_nodes.extend(function) # Compare node orders to determine if any changes were made modified = False - for node, new_node in zip(nodes, new_nodes): + for node, new_node in zip(original_nodes, sorted_nodes): if node is not new_node: modified = True break