diff --git a/onnxscript/ir/passes/common/topological_sort_test.py b/onnxscript/ir/passes/common/topological_sort_test.py index ca9d1377f0..8680761f1e 100644 --- a/onnxscript/ir/passes/common/topological_sort_test.py +++ b/onnxscript/ir/passes/common/topological_sort_test.py @@ -45,6 +45,41 @@ def test_topological_sort_modified_false(self): (self.node_a, self.node_b, self.node_c), ) + def test_topological_sort_on_functions(self): + """Test that TopologicalSortPass works on functions in a model.""" + # Create a function with unsorted nodes + func_graph = ir.Graph( + inputs=self.node_a.inputs, + outputs=self.node_c.outputs, + nodes=[self.node_c, self.node_b, self.node_a], # Unsorted nodes + ) + function = ir.Function( + domain="test_domain", + name="test_function", + graph=func_graph, + attributes=[], + ) + + # Create a model with the function + graph = ir.Graph( + inputs=[], + outputs=[], + nodes=[], + name="test_graph", + ) + model = ir.Model(graph, ir_version=10, functions=[function]) + + # Apply the TopologicalSortPass + result = topological_sort.TopologicalSortPass()(model) + + # Verify that the nodes in the function are sorted + sorted_func_nodes = (self.node_a, self.node_b, self.node_c) + self.assertTrue(result.modified) + self.assertEqual( + tuple(result.model.functions[function.identifier()]), + sorted_func_nodes, + ) + if __name__ == "__main__": unittest.main()