Skip to content

Commit 89ef16c

Browse files
authored
[Pass] Support lifting subgraph initializers to main graph (#2266)
Fix #2157
1 parent 34e7ba8 commit 89ef16c

File tree

3 files changed

+120
-0
lines changed

3 files changed

+120
-0
lines changed

onnxscript/ir/passes/common/constant_manipulation.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
__all__ = [
88
"LiftConstantsToInitializersPass",
9+
"LiftSubgraphInitializersToMainGraphPass",
910
]
1011

1112
import logging
@@ -126,3 +127,39 @@ def _constant_node_attribute_to_tensor(
126127
)
127128
return None
128129
return tensor
130+
131+
132+
class LiftSubgraphInitializersToMainGraphPass(ir.passes.InPlacePass):
133+
"""Lift subgraph initializers to main graph.
134+
135+
This pass lifts the initializers of a subgraph to the main graph.
136+
It is used to ensure that the initializers are available in the main graph
137+
for further processing or optimization.
138+
"""
139+
140+
def call(self, model: ir.Model) -> ir.passes.PassResult:
141+
count = 0
142+
registered_initializer_names: dict[str, int] = {}
143+
for graph in model.graphs():
144+
if graph is model.graph:
145+
continue
146+
for name, initializer in graph.initializers.items():
147+
# To avoid name conflicts, we need to rename the initializer
148+
# to a unique name in the main graph
149+
if name in registered_initializer_names:
150+
name_count = registered_initializer_names[name]
151+
initializer.name = f"{name}_{name_count}"
152+
registered_initializer_names[name] = name_count + 1
153+
else:
154+
assert initializer.name is not None
155+
registered_initializer_names[initializer.name] = 1
156+
model.graph.register_initializer(initializer)
157+
count += 1
158+
logger.debug(
159+
"Lifted initializer '%s' from subgraph '%s' to main graph",
160+
initializer.name,
161+
graph.name,
162+
)
163+
# Remove the initializer from the subgraph
164+
graph.initializers.clear()
165+
return ir.passes.PassResult(model, modified=bool(count))

onnxscript/ir/passes/common/constant_manipulation_test.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,5 +251,87 @@ def test_not_lifting_constants_to_initializers_when_it_is_output(self):
251251
self.assertEqual(len(result.model.graph.initializers), 0)
252252

253253

254+
class TestLiftSubgraphInitializersToMainGraphPass(unittest.TestCase):
255+
@parameterized.parameterized.expand(
256+
[
257+
("then_initializer", "else_initializer"),
258+
("initializer", "initializer"),
259+
]
260+
)
261+
def test_pass_with_lifting_constants_to_initializers_within_subgraph(
262+
self, then_initializer_name, else_initializer_name
263+
):
264+
input_value = ir.Value(
265+
name="input", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3))
266+
)
267+
268+
then_initializer_tensor = ir.tensor(np.random.rand(2, 3).astype(np.float32))
269+
then_initializer_value = ir.Value(
270+
name=then_initializer_name,
271+
shape=then_initializer_tensor.shape,
272+
type=ir.TensorType(ir.DataType.FLOAT),
273+
const_value=then_initializer_tensor,
274+
)
275+
276+
# then branch adds the constant to the input
277+
# else branch multiplies the input by the constant
278+
add_node = ir.node("Add", inputs=[input_value, then_initializer_value])
279+
then_graph = ir.Graph(
280+
inputs=[input_value, then_initializer_value],
281+
outputs=[add_node.outputs[0]],
282+
nodes=[add_node],
283+
opset_imports={"": 20},
284+
initializers=[then_initializer_value],
285+
)
286+
else_initializer_tensor = ir.tensor(np.random.rand(2, 3).astype(np.float32))
287+
else_initializer_value = ir.Value(
288+
name=else_initializer_name,
289+
shape=else_initializer_tensor.shape,
290+
type=ir.TensorType(ir.DataType.FLOAT),
291+
const_value=else_initializer_tensor,
292+
)
293+
mul_node = ir.node("Mul", inputs=[input_value, else_initializer_value])
294+
else_graph = ir.Graph(
295+
inputs=[input_value],
296+
outputs=[mul_node.outputs[0]],
297+
nodes=[mul_node],
298+
opset_imports={"": 20},
299+
initializers=[else_initializer_value],
300+
)
301+
# create a conditional node that uses the then and else graphs
302+
cond_node = ir.node(
303+
"If",
304+
inputs=[input_value],
305+
attributes={"then_branch": then_graph, "else_branch": else_graph},
306+
num_outputs=1,
307+
)
308+
# construnct the model
309+
main_graph = ir.Graph(
310+
inputs=[input_value],
311+
outputs=cond_node.outputs,
312+
nodes=[cond_node],
313+
opset_imports={"": 20},
314+
)
315+
main_graph.sort()
316+
model = ir.Model(
317+
graph=main_graph,
318+
ir_version=10,
319+
)
320+
result = constant_manipulation.LiftSubgraphInitializersToMainGraphPass()(model)
321+
self.assertTrue(result.modified)
322+
323+
self.assertEqual(len(else_graph.initializers), 0)
324+
self.assertEqual(len(then_graph.initializers), 0)
325+
self.assertEqual(len(main_graph.initializers), 2)
326+
for value, tensor in zip(
327+
main_graph.initializers.values(),
328+
[then_initializer_tensor, else_initializer_tensor],
329+
):
330+
self.assertIs(
331+
value.const_value,
332+
tensor,
333+
)
334+
335+
254336
if __name__ == "__main__":
255337
unittest.main()

onnxscript/optimizer/_optimizer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def optimize_ir(
5454
),
5555
onnxscript.ir.passes.common.unused_removal.RemoveUnusedNodesPass(),
5656
onnxscript.ir.passes.common.constant_manipulation.LiftConstantsToInitializersPass(),
57+
onnxscript.ir.passes.common.constant_manipulation.LiftSubgraphInitializersToMainGraphPass(),
5758
]
5859
if inline:
5960
# Inline all functions first before optimizing

0 commit comments

Comments
 (0)