From 9c23236a7a1630e5048a54f4c67624a2c07932a8 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 27 May 2025 12:42:56 -0700 Subject: [PATCH 1/5] [pass] Update LiftSubgraphInitializersToMainGraphPass to avoid variable shadowing --- .../ir/passes/common/constant_manipulation.py | 32 ++++++++++++------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/onnxscript/ir/passes/common/constant_manipulation.py b/onnxscript/ir/passes/common/constant_manipulation.py index b76c3c0802..8a8cba7a65 100644 --- a/onnxscript/ir/passes/common/constant_manipulation.py +++ b/onnxscript/ir/passes/common/constant_manipulation.py @@ -140,8 +140,28 @@ class LiftSubgraphInitializersToMainGraphPass(ir.passes.InPlacePass): """ def call(self, model: ir.Model) -> ir.passes.PassResult: + # 1. Ensure all initializer names are unique + registered_initializer_names: set[str] = set() + duplicated_initializers: list[ir.Value] = [] + for graph in model.graphs(): + for initializer in graph.initializers.values(): + if initializer.name is None: + raise ValueError( + f"Initializer name is None. Please ensure all initializers have unique names: {initializer!r}" + ) + if initializer.name in registered_initializer_names: + duplicated_initializers.append(initializer) + else: + registered_initializer_names.add(initializer.name) + if duplicated_initializers: + raise ValueError( + "Found duplicated initializers in the model. " + "Please ensure all initializers have unique names. Duplicated: " + f"{duplicated_initializers!r}" + ) + + # 2. Lift the initializers count = 0 - registered_initializer_names: dict[str, int] = {} for graph in model.graphs(): if graph is model.graph: continue @@ -156,16 +176,6 @@ def call(self, model: ir.Model) -> ir.passes.PassResult: continue # Remove the initializer from the subgraph graph.initializers.pop(name) - # To avoid name conflicts, we need to rename the initializer - # to a unique name in the main graph - if name in registered_initializer_names: - name_count = registered_initializer_names[name] - initializer.name = f"{name}_{name_count}" - registered_initializer_names[name] = name_count + 1 - else: - assert initializer.name is not None - registered_initializer_names[initializer.name] = 1 - model.graph.register_initializer(initializer) count += 1 logger.debug( "Lifted initializer '%s' from subgraph '%s' to main graph", From 8df300858b98333fec96cf7ffbcd7e505ff5611d Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 27 May 2025 12:45:54 -0700 Subject: [PATCH 2/5] test --- onnxscript/ir/passes/common/constant_manipulation.py | 1 + onnxscript/ir/passes/common/constant_manipulation_test.py | 7 +++++++ 2 files changed, 8 insertions(+) diff --git a/onnxscript/ir/passes/common/constant_manipulation.py b/onnxscript/ir/passes/common/constant_manipulation.py index 8a8cba7a65..4033e48ad4 100644 --- a/onnxscript/ir/passes/common/constant_manipulation.py +++ b/onnxscript/ir/passes/common/constant_manipulation.py @@ -156,6 +156,7 @@ def call(self, model: ir.Model) -> ir.passes.PassResult: if duplicated_initializers: raise ValueError( "Found duplicated initializers in the model. " + "Initializer name must be unique across the main graph and subgraphs. " "Please ensure all initializers have unique names. Duplicated: " f"{duplicated_initializers!r}" ) diff --git a/onnxscript/ir/passes/common/constant_manipulation_test.py b/onnxscript/ir/passes/common/constant_manipulation_test.py index d02933136b..f1410727b6 100644 --- a/onnxscript/ir/passes/common/constant_manipulation_test.py +++ b/onnxscript/ir/passes/common/constant_manipulation_test.py @@ -390,6 +390,13 @@ def test_pass_does_not_lift_initialized_inputs_in_subgraph( graph=main_graph, ir_version=10, ) + if then_initializer_name == else_initializer_name: + with self.assertRaisesRegex( + ValueError, + "Initializer name must be unique across the main graph and subgraphs", + ): + constant_manipulation.LiftSubgraphInitializersToMainGraphPass()(model) + return result = constant_manipulation.LiftSubgraphInitializersToMainGraphPass()(model) self.assertTrue(result.modified) From c53655a16f31f8f82c85118c784215ddc7d350d0 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 27 May 2025 12:47:46 -0700 Subject: [PATCH 3/5] update --- onnxscript/ir/passes/common/constant_manipulation_test.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/onnxscript/ir/passes/common/constant_manipulation_test.py b/onnxscript/ir/passes/common/constant_manipulation_test.py index f1410727b6..879711b652 100644 --- a/onnxscript/ir/passes/common/constant_manipulation_test.py +++ b/onnxscript/ir/passes/common/constant_manipulation_test.py @@ -311,6 +311,13 @@ def test_pass_with_lifting_constants_to_initializers_within_subgraph( graph=main_graph, ir_version=10, ) + if then_initializer_name == else_initializer_name: + with self.assertRaisesRegex( + ValueError, + "Initializer name must be unique across the main graph and subgraphs", + ): + constant_manipulation.LiftSubgraphInitializersToMainGraphPass()(model) + return result = constant_manipulation.LiftSubgraphInitializersToMainGraphPass()(model) self.assertTrue(result.modified) From 3a92c83d2eeb31639889d0145bd3bc81d1b0184a Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 27 May 2025 13:18:01 -0700 Subject: [PATCH 4/5] test --- .../ir/passes/common/constant_manipulation.py | 16 +++++++++++----- .../passes/common/constant_manipulation_test.py | 16 ++++++++-------- 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/onnxscript/ir/passes/common/constant_manipulation.py b/onnxscript/ir/passes/common/constant_manipulation.py index 4033e48ad4..06533bb099 100644 --- a/onnxscript/ir/passes/common/constant_manipulation.py +++ b/onnxscript/ir/passes/common/constant_manipulation.py @@ -137,16 +137,21 @@ class LiftSubgraphInitializersToMainGraphPass(ir.passes.InPlacePass): This pass lifts the initializers of a subgraph to the main graph. It is used to ensure that the initializers are available in the main graph for further processing or optimization. + + Initializers that are also graph inputs will not be lifted. + + Preconditions: + - All initializers in the model must have unique names across the main graph and subgraphs. """ - def call(self, model: ir.Model) -> ir.passes.PassResult: - # 1. Ensure all initializer names are unique + def requires(self, model: ir.Model) -> None: + # Ensure all initializer names are unique registered_initializer_names: set[str] = set() duplicated_initializers: list[ir.Value] = [] for graph in model.graphs(): for initializer in graph.initializers.values(): if initializer.name is None: - raise ValueError( + raise ir.passes.PreconditionError( f"Initializer name is None. Please ensure all initializers have unique names: {initializer!r}" ) if initializer.name in registered_initializer_names: @@ -154,14 +159,14 @@ def call(self, model: ir.Model) -> ir.passes.PassResult: else: registered_initializer_names.add(initializer.name) if duplicated_initializers: - raise ValueError( + raise ir.passes.PreconditionError( "Found duplicated initializers in the model. " "Initializer name must be unique across the main graph and subgraphs. " "Please ensure all initializers have unique names. Duplicated: " f"{duplicated_initializers!r}" ) - # 2. Lift the initializers + def call(self, model: ir.Model) -> ir.passes.PassResult: count = 0 for graph in model.graphs(): if graph is model.graph: @@ -177,6 +182,7 @@ def call(self, model: ir.Model) -> ir.passes.PassResult: continue # Remove the initializer from the subgraph graph.initializers.pop(name) + model.graph.register_initializer(initializer) count += 1 logger.debug( "Lifted initializer '%s' from subgraph '%s' to main graph", diff --git a/onnxscript/ir/passes/common/constant_manipulation_test.py b/onnxscript/ir/passes/common/constant_manipulation_test.py index 879711b652..5f8e93661a 100644 --- a/onnxscript/ir/passes/common/constant_manipulation_test.py +++ b/onnxscript/ir/passes/common/constant_manipulation_test.py @@ -248,12 +248,12 @@ def test_not_lifting_constants_to_initializers_when_it_is_output(self): class TestLiftSubgraphInitializersToMainGraphPass(unittest.TestCase): @parameterized.parameterized.expand( [ - ("then_initializer", "else_initializer"), - ("initializer", "initializer"), + ("unique_init_names", "then_initializer", "else_initializer"), + ("duplicated_init_names", "initializer", "initializer"), ] ) def test_pass_with_lifting_constants_to_initializers_within_subgraph( - self, then_initializer_name, else_initializer_name + self, _: str, then_initializer_name: str, else_initializer_name: str ): input_value = ir.Value( name="input", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3)) @@ -313,7 +313,7 @@ def test_pass_with_lifting_constants_to_initializers_within_subgraph( ) if then_initializer_name == else_initializer_name: with self.assertRaisesRegex( - ValueError, + ir.passes.PreconditionError, "Initializer name must be unique across the main graph and subgraphs", ): constant_manipulation.LiftSubgraphInitializersToMainGraphPass()(model) @@ -332,12 +332,12 @@ def test_pass_with_lifting_constants_to_initializers_within_subgraph( @parameterized.parameterized.expand( [ - ("then_initializer", "else_initializer"), - ("initializer", "initializer"), + ("unique_init_names", "then_initializer", "else_initializer"), + ("duplicated_init_names", "initializer", "initializer"), ] ) def test_pass_does_not_lift_initialized_inputs_in_subgraph( - self, then_initializer_name, else_initializer_name + self, _: str, then_initializer_name: str, else_initializer_name: str ): input_value = ir.Value( name="input", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3)) @@ -399,7 +399,7 @@ def test_pass_does_not_lift_initialized_inputs_in_subgraph( ) if then_initializer_name == else_initializer_name: with self.assertRaisesRegex( - ValueError, + ir.passes.PreconditionError, "Initializer name must be unique across the main graph and subgraphs", ): constant_manipulation.LiftSubgraphInitializersToMainGraphPass()(model) From 91953aae33b0856b6594d03726cfdbfbdd9a5ddf Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 27 May 2025 13:22:56 -0700 Subject: [PATCH 5/5] Update onnxscript/ir/passes/common/constant_manipulation.py --- onnxscript/ir/passes/common/constant_manipulation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/ir/passes/common/constant_manipulation.py b/onnxscript/ir/passes/common/constant_manipulation.py index 06533bb099..bbe614c1b9 100644 --- a/onnxscript/ir/passes/common/constant_manipulation.py +++ b/onnxscript/ir/passes/common/constant_manipulation.py @@ -145,7 +145,7 @@ class LiftSubgraphInitializersToMainGraphPass(ir.passes.InPlacePass): """ def requires(self, model: ir.Model) -> None: - # Ensure all initializer names are unique + """Ensure all initializer names are unique.""" registered_initializer_names: set[str] = set() duplicated_initializers: list[ir.Value] = [] for graph in model.graphs():