diff --git a/src/onnx_ir/passes/common/constant_manipulation.py b/src/onnx_ir/passes/common/constant_manipulation.py index 685017c0..da104118 100644 --- a/src/onnx_ir/passes/common/constant_manipulation.py +++ b/src/onnx_ir/passes/common/constant_manipulation.py @@ -139,35 +139,11 @@ class LiftSubgraphInitializersToMainGraphPass(ir.passes.InPlacePass): for further processing or optimization. Initializers that are also graph inputs will not be lifted. - - Preconditions: - - All initializers in the model must have unique names across the main graph and subgraphs. """ - def requires(self, model: ir.Model) -> None: - """Ensure all initializer names are unique.""" - registered_initializer_names: set[str] = set() - duplicated_initializers: list[ir.Value] = [] - for graph in model.graphs(): - for initializer in graph.initializers.values(): - if initializer.name is None: - raise ir.passes.PreconditionError( - f"Initializer name is None. Please ensure all initializers have unique names: {initializer!r}" - ) - if initializer.name in registered_initializer_names: - duplicated_initializers.append(initializer) - else: - registered_initializer_names.add(initializer.name) - if duplicated_initializers: - raise ir.passes.PreconditionError( - "Found duplicated initializers in the model. " - "Initializer name must be unique across the main graph and subgraphs. " - "Please ensure all initializers have unique names. Duplicated: " - f"{duplicated_initializers!r}" - ) - def call(self, model: ir.Model) -> ir.passes.PassResult: count = 0 + registered_initializer_names: dict[str, int] = {} for graph in model.graphs(): if graph is model.graph: continue @@ -182,6 +158,15 @@ def call(self, model: ir.Model) -> ir.passes.PassResult: continue # Remove the initializer from the subgraph graph.initializers.pop(name) + # To avoid name conflicts, we need to rename the initializer + # to a unique name in the main graph + if name in registered_initializer_names: + name_count = registered_initializer_names[name] + initializer.name = f"{name}_{name_count}" + registered_initializer_names[name] = name_count + 1 + else: + assert initializer.name is not None + registered_initializer_names[initializer.name] = 1 model.graph.register_initializer(initializer) count += 1 logger.debug( diff --git a/src/onnx_ir/passes/common/constant_manipulation_test.py b/src/onnx_ir/passes/common/constant_manipulation_test.py index afb528a8..f3862917 100644 --- a/src/onnx_ir/passes/common/constant_manipulation_test.py +++ b/src/onnx_ir/passes/common/constant_manipulation_test.py @@ -311,13 +311,6 @@ def test_pass_with_lifting_constants_to_initializers_within_subgraph( graph=main_graph, ir_version=10, ) - if then_initializer_name == else_initializer_name: - with self.assertRaisesRegex( - ir.passes.PreconditionError, - "Initializer name must be unique across the main graph and subgraphs", - ): - constant_manipulation.LiftSubgraphInitializersToMainGraphPass()(model) - return result = constant_manipulation.LiftSubgraphInitializersToMainGraphPass()(model) self.assertTrue(result.modified) @@ -397,13 +390,6 @@ def test_pass_does_not_lift_initialized_inputs_in_subgraph( graph=main_graph, ir_version=10, ) - if then_initializer_name == else_initializer_name: - with self.assertRaisesRegex( - ir.passes.PreconditionError, - "Initializer name must be unique across the main graph and subgraphs", - ): - constant_manipulation.LiftSubgraphInitializersToMainGraphPass()(model) - return result = constant_manipulation.LiftSubgraphInitializersToMainGraphPass()(model) self.assertTrue(result.modified)