Skip to content

Commit 6d33d22

Browse files
[pass] Remove unused initialized inputs in DCE (#2212)
Fix #2211 This pull request enhances the functionality of the `RemoveUnusedNodesPass` class and its associated methods by introducing an option to remove unused initialized inputs. It also updates the corresponding tests to validate this new behavior. The changes improve the flexibility of the unused node removal process and ensure the model input signature remains consistent unless explicitly modified. ### Enhancements to `RemoveUnusedNodesPass`: * Added a new `remove_initialized_inputs` attribute to the `RemoveUnusedNodesPass` class, allowing the removal of unused initialized inputs when enabled. This change modifies the model input signature if unused inputs are removed (`onnxscript/ir/passes/common/unused_removal.py`). --------- Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent bc7671c commit 6d33d22

File tree

3 files changed

+84
-8
lines changed

3 files changed

+84
-8
lines changed

onnxscript/ir/passes/common/unused_removal.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,10 +93,27 @@ def _remove_unused_nodes_in_graph_like(function_or_graph: ir.Function | ir.Graph
9393

9494

9595
class RemoveUnusedNodesPass(ir.passes.InPlacePass):
96+
"""Pass for removing unused nodes and initializers.
97+
98+
Attributes:
99+
remove_initialized_inputs: When an unused initializer is simultaneously a graph input,
100+
remove that input as well. Note that this will change the model input signature.
101+
"""
102+
103+
def __init__(self, remove_initialized_inputs: bool = False):
104+
super().__init__()
105+
self.remove_initialized_inputs = remove_initialized_inputs
106+
96107
def call(self, model: ir.Model) -> ir.passes.PassResult:
97108
count = _remove_unused_nodes_in_graph_like(model.graph)
98109
graph_outputs = frozenset(model.graph.outputs)
99110
initializers = model.graph.initializers
111+
if self.remove_initialized_inputs:
112+
graph_inputs = model.graph.inputs
113+
for i, inp in reversed(list(enumerate(graph_inputs))):
114+
if inp.name in initializers and not (inp in graph_outputs or inp.uses()):
115+
del graph_inputs[i]
116+
count += 1
100117
for init in list(initializers.values()):
101118
if not (init in graph_outputs or init.uses()):
102119
assert init.name is not None

onnxscript/ir/passes/common/unused_removal_test.py

Lines changed: 58 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,15 @@
1313
class RemoveUnusedTest(unittest.TestCase):
1414
using_ir: bool
1515

16-
def remove_unused_nodes(self, model: onnx.ModelProto):
16+
def remove_unused_nodes(
17+
self, model: onnx.ModelProto, remove_initialized_inputs: bool = False
18+
):
1719
if self.using_ir:
1820
model_ir = ir.serde.deserialize_model(model)
19-
onnxscript.optimizer.remove_unused_nodes(model_ir)
21+
onnxscript.optimizer.remove_unused_nodes(model_ir, remove_initialized_inputs)
2022
model = ir.serde.serialize_model(model_ir)
2123
return model
22-
onnxscript.optimizer.remove_unused_nodes(model)
24+
onnxscript.optimizer.remove_unused_nodes(model, remove_initialized_inputs)
2325
return model
2426

2527
def test_remove_unused_nodes(self):
@@ -54,6 +56,59 @@ def test_remove_unused_initializers(self):
5456
self.assertEqual(model.graph.node[0].op_type, "Mul")
5557
self.assertEqual(len(model.graph.initializer), 0)
5658

59+
def test_unused_initialized_inputs_are_removed_when_requested(self):
60+
# https://github.com/microsoft/onnxscript/issues/2211
61+
model = onnx.parser.parse_model(
62+
"""
63+
<ir_version: 10, opset_import: [ "" : 17]>
64+
agraph (float[N] x, float[N] two) => (float[N] z)
65+
<float two = {2.0,2.0}> {
66+
four = Add(two, two)
67+
z = Mul(x, x)
68+
}
69+
"""
70+
)
71+
model = self.remove_unused_nodes(model, remove_initialized_inputs=True)
72+
self.assertEqual(len(model.graph.node), 1)
73+
self.assertEqual(model.graph.node[0].op_type, "Mul")
74+
self.assertEqual(len(model.graph.input), 1)
75+
76+
def test_unused_initialized_inputs_are_kept_by_default(self):
77+
model = onnx.parser.parse_model(
78+
"""
79+
<ir_version: 10, opset_import: [ "" : 17]>
80+
agraph (float[N] x, float[N] two) => (float[N] z)
81+
<float two = {2.0,2.0}> {
82+
four = Add(two, two)
83+
z = Mul(x, x)
84+
}
85+
"""
86+
)
87+
model = self.remove_unused_nodes(model)
88+
self.assertEqual(len(model.graph.node), 1)
89+
self.assertEqual(model.graph.node[0].op_type, "Mul")
90+
self.assertEqual(len(model.graph.input), 2)
91+
92+
@parameterized.parameterized.expand([True, False])
93+
def test_unused_inputs_are_not_removed(self, remove_initialized_inputs: bool):
94+
# preserve inputs as part of interface
95+
model = onnx.parser.parse_model(
96+
"""
97+
<ir_version: 10, opset_import: [ "" : 17]>
98+
agraph (float[N] x, float[N] two) => (float[N] z)
99+
{
100+
four = Add(two, two)
101+
z = Mul(x, x)
102+
}
103+
"""
104+
)
105+
model = self.remove_unused_nodes(
106+
model, remove_initialized_inputs=remove_initialized_inputs
107+
)
108+
self.assertEqual(len(model.graph.node), 1)
109+
self.assertEqual(model.graph.node[0].op_type, "Mul")
110+
self.assertEqual(len(model.graph.input), 2)
111+
57112
def test_partially_used_nodes(self):
58113
model = onnx.parser.parse_model(
59114
"""

onnxscript/optimizer/__init__.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -112,15 +112,19 @@ def fold_constants(
112112
return result
113113

114114

115-
def remove_unused_nodes(model: ir.Model | onnx.ModelProto) -> None:
115+
def remove_unused_nodes(
116+
model: ir.Model | onnx.ModelProto, remove_initialized_inputs: bool = False
117+
) -> None:
116118
"""Removes unused nodes from a model inplace."""
117119
if isinstance(model, ir.Model):
118-
onnxscript.ir.passes.common.unused_removal.RemoveUnusedNodesPass()(model)
120+
onnxscript.ir.passes.common.unused_removal.RemoveUnusedNodesPass(
121+
remove_initialized_inputs=remove_initialized_inputs
122+
)(model)
119123
else:
120124
model_ir = ir.serde.deserialize_model(model)
121-
model_ir = onnxscript.ir.passes.common.unused_removal.RemoveUnusedNodesPass()(
122-
model_ir
123-
).model
125+
model_ir = onnxscript.ir.passes.common.unused_removal.RemoveUnusedNodesPass(
126+
remove_initialized_inputs=remove_initialized_inputs
127+
)(model_ir).model
124128
new_proto = ir.serde.serialize_model(model_ir)
125129
model.Clear()
126130
model.CopyFrom(new_proto)

0 commit comments

Comments
 (0)