From 903a202be9b58ea3fde23152962de8ccd10b95ef Mon Sep 17 00:00:00 2001 From: Ti-Tai Wang Date: Thu, 1 May 2025 22:56:41 +0000 Subject: [PATCH 1/3] add the pass --- .../ir/passes/common/constant_manipulation.py | 40 +++++++++ .../common/constant_manipulation_test.py | 82 +++++++++++++++++++ onnxscript/optimizer/_optimizer.py | 1 + 3 files changed, 123 insertions(+) diff --git a/onnxscript/ir/passes/common/constant_manipulation.py b/onnxscript/ir/passes/common/constant_manipulation.py index 888053a8f5..4a2c752f0f 100644 --- a/onnxscript/ir/passes/common/constant_manipulation.py +++ b/onnxscript/ir/passes/common/constant_manipulation.py @@ -6,6 +6,7 @@ __all__ = [ "LiftConstantsToInitializersPass", + "LiftSubgraphInitializersToMainGraphPass", ] import logging @@ -126,3 +127,42 @@ def _constant_node_attribute_to_tensor( ) return None return tensor + + +class LiftSubgraphInitializersToMainGraphPass(ir.passes.InPlacePass): + """Lift subgraph initializers to main graph. + + This pass lifts the initializers of a subgraph to the main graph. + It is used to ensure that the initializers are available in the main graph + for further processing or optimization. 
+ """ + + def call(self, model: ir.Model) -> ir.passes.PassResult: + count = 0 + registered_initializer_names: dict[str, int] = {} + for node in ir.traversal.RecursiveGraphIterator(model.graph): + assert node.graph is not None + if node.graph == model.graph: + continue + if len(node.graph.initializers) == 0: + continue + for initializer in node.graph.initializers.values(): + # To avoid name conflicts, we need to rename the initializer + # to a unique name in the main graph + if initializer.name in registered_initializer_names: + name_count = registered_initializer_names[initializer.name] + initializer.name = f"{initializer.name}_{name_count}" + registered_initializer_names[initializer.name] = name_count + 1 + else: + assert initializer.name is not None + registered_initializer_names[initializer.name] = 1 + model.graph.register_initializer(initializer) + count += 1 + logger.debug( + "Lifted initializer '%s' from subgraph '%s' to main graph", + initializer.name, + node.graph.name, + ) + # Remove the initializer from the subgraph + node.graph._initializers.clear() # pylint: disable=protected-access + return ir.passes.PassResult(model, modified=bool(count)) diff --git a/onnxscript/ir/passes/common/constant_manipulation_test.py b/onnxscript/ir/passes/common/constant_manipulation_test.py index bb84582e31..84fed948ac 100644 --- a/onnxscript/ir/passes/common/constant_manipulation_test.py +++ b/onnxscript/ir/passes/common/constant_manipulation_test.py @@ -251,5 +251,87 @@ def test_not_lifting_constants_to_initializers_when_it_is_output(self): self.assertEqual(len(result.model.graph.initializers), 0) +class TestLiftSubgraphInitializersToMainGraphPass(unittest.TestCase): + @parameterized.parameterized.expand( + [ + ("then_initializer", "else_initializer"), + ("initializer", "initializer"), + ] + ) + def test_pass_with_lifting_constants_to_initializers_within_subgraph( + self, then_initializer_name, else_initializer_name + ): + input_value = ir.Value( + name="input", 
type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3)) + ) + + then_initializer_tensor = ir.tensor(np.random.rand(2, 3).astype(np.float32)) + then_initializer_value = ir.Value( + name=then_initializer_name, + shape=then_initializer_tensor.shape, + type=ir.TensorType(ir.DataType.FLOAT), + const_value=then_initializer_tensor, + ) + + # then branch adds the constant to the input + # else branch multiplies the input by the constant + add_node = ir.node("Add", inputs=[input_value, then_initializer_value]) + then_graph = ir.Graph( + inputs=[input_value, then_initializer_value], + outputs=[add_node.outputs[0]], + nodes=[add_node], + opset_imports={"": 20}, + initializers=[then_initializer_value], + ) + else_initializer_tensor = ir.tensor(np.random.rand(2, 3).astype(np.float32)) + else_initializer_value = ir.Value( + name=else_initializer_name, + shape=else_initializer_tensor.shape, + type=ir.TensorType(ir.DataType.FLOAT), + const_value=else_initializer_tensor, + ) + mul_node = ir.node("Mul", inputs=[input_value, else_initializer_value]) + else_graph = ir.Graph( + inputs=[input_value], + outputs=[mul_node.outputs[0]], + nodes=[mul_node], + opset_imports={"": 20}, + initializers=[else_initializer_value], + ) + # create a conditional node that uses the then and else graphs + cond_node = ir.node( + "If", + inputs=[input_value], + attributes={"then_branch": then_graph, "else_branch": else_graph}, + num_outputs=1, + ) + # construct the model + main_graph = ir.Graph( + inputs=[input_value], + outputs=cond_node.outputs, + nodes=[cond_node], + opset_imports={"": 20}, + ) + main_graph.sort() + model = ir.Model( + graph=main_graph, + ir_version=10, + ) + result = constant_manipulation.LiftSubgraphInitializersToMainGraphPass()(model) + self.assertTrue(result.modified) + + self.assertEqual(len(else_graph.initializers), 0) + self.assertEqual(len(then_graph.initializers), 0) + self.assertEqual(len(main_graph.initializers), 2) + for value, tensor in zip( + 
main_graph.initializers.values(), + [then_initializer_tensor, else_initializer_tensor], + ): + self.assertIs( + value.const_value, + tensor, + ) + + if __name__ == "__main__": unittest.main() diff --git a/onnxscript/optimizer/_optimizer.py b/onnxscript/optimizer/_optimizer.py index 3aaba1b057..f8994bd741 100644 --- a/onnxscript/optimizer/_optimizer.py +++ b/onnxscript/optimizer/_optimizer.py @@ -54,6 +54,7 @@ def optimize_ir( ), onnxscript.ir.passes.common.unused_removal.RemoveUnusedNodesPass(), onnxscript.ir.passes.common.constant_manipulation.LiftConstantsToInitializersPass(), + onnxscript.ir.passes.common.constant_manipulation.LiftSubgraphInitializersToMainGraphPass(), ] if inline: # Inline all functions first before optimizing From 6781e9950e0c0ef420b8eac8695f8b3210b022c6 Mon Sep 17 00:00:00 2001 From: Ti-Tai Wang Date: Fri, 2 May 2025 17:53:44 +0000 Subject: [PATCH 2/3] update for loop --- .../ir/passes/common/constant_manipulation.py | 21 ++++++++----------- 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/onnxscript/ir/passes/common/constant_manipulation.py b/onnxscript/ir/passes/common/constant_manipulation.py index 4a2c752f0f..f431c57d93 100644 --- a/onnxscript/ir/passes/common/constant_manipulation.py +++ b/onnxscript/ir/passes/common/constant_manipulation.py @@ -140,19 +140,16 @@ class LiftSubgraphInitializersToMainGraphPass(ir.passes.InPlacePass): def call(self, model: ir.Model) -> ir.passes.PassResult: count = 0 registered_initializer_names: dict[str, int] = {} - for node in ir.traversal.RecursiveGraphIterator(model.graph): - assert node.graph is not None - if node.graph == model.graph: - continue - if len(node.graph.initializers) == 0: + for graph in model.graphs(): + if graph is model.graph: continue - for initializer in node.graph.initializers.values(): + for name, initializer in graph.initializers.items(): # To avoid name conflicts, we need to rename the initializer # to a unique name in the main graph - if initializer.name in 
registered_initializer_names: - name_count = registered_initializer_names[initializer.name] - initializer.name = f"{initializer.name}_{name_count}" - registered_initializer_names[initializer.name] = name_count + 1 + if name in registered_initializer_names: + name_count = registered_initializer_names[name] + initializer.name = f"{name}_{name_count}" + registered_initializer_names[name] = name_count + 1 else: assert initializer.name is not None registered_initializer_names[initializer.name] = 1 @@ -161,8 +158,8 @@ def call(self, model: ir.Model) -> ir.passes.PassResult: logger.debug( "Lifted initializer '%s' from subgraph '%s' to main graph", initializer.name, - node.graph.name, + graph.name, ) # Remove the initializer from the subgraph - node.graph._initializers.clear() # pylint: disable=protected-access + graph._initializers.clear() # pylint: disable=protected-access return ir.passes.PassResult(model, modified=bool(count)) From 37a66945852925b5c5dec695a1b0634cb1d247b4 Mon Sep 17 00:00:00 2001 From: Ti-Tai Wang Date: Fri, 2 May 2025 19:17:14 +0000 Subject: [PATCH 3/3] address review --- onnxscript/ir/passes/common/constant_manipulation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/ir/passes/common/constant_manipulation.py b/onnxscript/ir/passes/common/constant_manipulation.py index f431c57d93..124e787b5c 100644 --- a/onnxscript/ir/passes/common/constant_manipulation.py +++ b/onnxscript/ir/passes/common/constant_manipulation.py @@ -161,5 +161,5 @@ def call(self, model: ir.Model) -> ir.passes.PassResult: graph.name, ) # Remove the initializer from the subgraph - graph._initializers.clear() # pylint: disable=protected-access + graph.initializers.clear() return ir.passes.PassResult(model, modified=bool(count))