37 changes: 37 additions & 0 deletions onnxscript/ir/passes/common/constant_manipulation.py
@@ -6,6 +6,7 @@

__all__ = [
    "LiftConstantsToInitializersPass",
    "LiftSubgraphInitializersToMainGraphPass",
]

import logging
@@ -126,3 +127,39 @@ def _constant_node_attribute_to_tensor(
        )
        return None
    return tensor


class LiftSubgraphInitializersToMainGraphPass(ir.passes.InPlacePass):
    """Lift subgraph initializers to the main graph.

    This pass moves every initializer registered on a subgraph up to the main
    graph so that it is available there for further processing or optimization.
    Initializers are renamed when necessary to avoid name collisions, and the
    subgraphs are left without initializers of their own.
    """

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        count = 0
        registered_initializer_names: dict[str, int] = {}
        for graph in model.graphs():
            if graph is model.graph:
                continue
            for name, initializer in graph.initializers.items():
                # To avoid name conflicts, we need to rename the initializer
                # to a unique name in the main graph
                if name in registered_initializer_names:
                    name_count = registered_initializer_names[name]
                    initializer.name = f"{name}_{name_count}"
                    registered_initializer_names[name] = name_count + 1
                else:
                    assert initializer.name is not None
                    registered_initializer_names[initializer.name] = 1
                model.graph.register_initializer(initializer)
                count += 1
                logger.debug(
                    "Lifted initializer '%s' from subgraph '%s' to main graph",
                    initializer.name,
                    graph.name,
                )
            # Remove the initializers from the subgraph
            graph.initializers.clear()
        return ir.passes.PassResult(model, modified=bool(count))
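
For orientation, a minimal usage sketch of the new pass on its own follows. It assumes the standard onnxscript.ir load/save helpers and hypothetical model paths; the pass class and the PassResult fields (model, modified) are the ones exercised in this diff.

# Hedged sketch: applying LiftSubgraphInitializersToMainGraphPass directly.
# "model.onnx" and "model_lifted.onnx" are hypothetical paths.
from onnxscript import ir
from onnxscript.ir.passes.common import constant_manipulation

model = ir.load("model.onnx")
result = constant_manipulation.LiftSubgraphInitializersToMainGraphPass()(model)
if result.modified:
    # Subgraph initializers are now registered on the main graph and the
    # subgraphs' own initializer tables have been cleared.
    ir.save(result.model, "model_lifted.onnx")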
82 changes: 82 additions & 0 deletions onnxscript/ir/passes/common/constant_manipulation_test.py
@@ -251,5 +251,87 @@ def test_not_lifting_constants_to_initializers_when_it_is_output(self):
        self.assertEqual(len(result.model.graph.initializers), 0)


class TestLiftSubgraphInitializersToMainGraphPass(unittest.TestCase):
    @parameterized.parameterized.expand(
        [
            ("then_initializer", "else_initializer"),
            ("initializer", "initializer"),
        ]
    )
    def test_pass_with_lifting_constants_to_initializers_within_subgraph(
        self, then_initializer_name, else_initializer_name
    ):
        input_value = ir.Value(
            name="input", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3))
        )

        then_initializer_tensor = ir.tensor(np.random.rand(2, 3).astype(np.float32))
        then_initializer_value = ir.Value(
            name=then_initializer_name,
            shape=then_initializer_tensor.shape,
            type=ir.TensorType(ir.DataType.FLOAT),
            const_value=then_initializer_tensor,
        )

        # then branch adds the constant to the input
        # else branch multiplies the input by the constant
        add_node = ir.node("Add", inputs=[input_value, then_initializer_value])
        then_graph = ir.Graph(
            inputs=[input_value, then_initializer_value],
            outputs=[add_node.outputs[0]],
            nodes=[add_node],
            opset_imports={"": 20},
            initializers=[then_initializer_value],
        )
        else_initializer_tensor = ir.tensor(np.random.rand(2, 3).astype(np.float32))
        else_initializer_value = ir.Value(
            name=else_initializer_name,
            shape=else_initializer_tensor.shape,
            type=ir.TensorType(ir.DataType.FLOAT),
            const_value=else_initializer_tensor,
        )
        mul_node = ir.node("Mul", inputs=[input_value, else_initializer_value])
        else_graph = ir.Graph(
            inputs=[input_value],
            outputs=[mul_node.outputs[0]],
            nodes=[mul_node],
            opset_imports={"": 20},
            initializers=[else_initializer_value],
        )
        # create a conditional node that uses the then and else graphs
        cond_node = ir.node(
            "If",
            inputs=[input_value],
            attributes={"then_branch": then_graph, "else_branch": else_graph},
            num_outputs=1,
        )
        # construct the model
        main_graph = ir.Graph(
            inputs=[input_value],
            outputs=cond_node.outputs,
            nodes=[cond_node],
            opset_imports={"": 20},
        )
        main_graph.sort()
        model = ir.Model(
            graph=main_graph,
            ir_version=10,
        )
        result = constant_manipulation.LiftSubgraphInitializersToMainGraphPass()(model)
        self.assertTrue(result.modified)

        self.assertEqual(len(else_graph.initializers), 0)
        self.assertEqual(len(then_graph.initializers), 0)
        self.assertEqual(len(main_graph.initializers), 2)
        for value, tensor in zip(
            main_graph.initializers.values(),
            [then_initializer_tensor, else_initializer_tensor],
        ):
            self.assertIs(
                value.const_value,
                tensor,
            )


if __name__ == "__main__":
    unittest.main()
1 change: 1 addition & 0 deletions onnxscript/optimizer/_optimizer.py
@@ -54,6 +54,7 @@ def optimize_ir(
        ),
        onnxscript.ir.passes.common.unused_removal.RemoveUnusedNodesPass(),
        onnxscript.ir.passes.common.constant_manipulation.LiftConstantsToInitializersPass(),
        onnxscript.ir.passes.common.constant_manipulation.LiftSubgraphInitializersToMainGraphPass(),
    ]
    if inline:
        # Inline all functions first before optimizing
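
The one-line registration above slots the new pass directly after LiftConstantsToInitializersPass in optimize_ir's default pass list. A rough sketch of running that sequence by hand, assuming an ir.Model already in memory as model (the pass names and ordering mirror the registration in this diff):

from onnxscript.ir.passes.common import constant_manipulation, unused_removal

# Same relative order as the optimize_ir pass list shown above.
passes = [
    unused_removal.RemoveUnusedNodesPass(),
    constant_manipulation.LiftConstantsToInitializersPass(),
    constant_manipulation.LiftSubgraphInitializersToMainGraphPass(),
]
for p in passes:
    model = p(model).model  # each pass call returns an ir.passes.PassResult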