Skip to content

Commit f42c2bb

Browse files
authored
Improve redundant slice removal (#2441)
Improve the optimization for removal of redundant slices. (It doesn't currently handle dynamic shapes.) The optimization is fairly simple, and eliminates the slice when the input and output shapes are same. --------- Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
1 parent e63a16b commit f42c2bb

File tree

2 files changed

+38
-5
lines changed

2 files changed

+38
-5
lines changed

onnxscript/rewriter/collapse_slices.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,14 @@ def _identity_to_itself(op, data, **_):
7373

7474
def _potential_redundant_slice(op, data, starts, ends, axes, steps):
7575
"""To identify a slice op"""
76-
return op.Slice(data, starts, ends, axes, steps)
76+
return op.Slice(data, starts, ends, axes, steps, _outputs=["slice_output"])
77+
78+
79+
def _same_shape(op, data: ir.Value, slice_output: ir.Value, **_):
80+
"""Check if the shape of the slice output is the same as the data."""
81+
if data.shape is None or slice_output.shape is None:
82+
return False
83+
return data.shape == slice_output.shape
7784

7885

7986
# Register the rewrite rules
@@ -83,5 +90,13 @@ def _potential_redundant_slice(op, data, starts, ends, axes, steps):
8390
_check_if_redundant_slice,
8491
)
8592

86-
# NOTE: The order of the rules is important. Larger pattern should be checked first.
87-
rules = RewriteRuleSet([remove_redundant_slice])
93+
remove_redundant_slice2 = RewriteRule(
94+
_potential_redundant_slice,
95+
_identity_to_itself,
96+
_same_shape,
97+
)
98+
99+
# NOTE: The second rule subsumes the first one. So, we may be able to remove the first one,
100+
# provided shape-inference is run before the rewriter and computes the shape of the slice output.
101+
102+
rules = RewriteRuleSet([remove_redundant_slice, remove_redundant_slice2])

onnxscript/rewriter/collapse_slices_test.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,11 @@ def test_slice_is_redundant_when_ends_reaches_int64_max(self):
6565
(np.random.rand(512, 16, 112).astype(np.float32),),
6666
)
6767

68-
def test_slice_pattern_is_not_matched_when_input_is_dynamic(self):
68+
def test_slice_unequal_dynamic_shape(self):
6969
model_proto = onnx.parser.parse_model(
7070
f"""
7171
<ir_version: 7, opset_import: [ "" : 17]>
72-
agraph (float[L, M, N] data) => (float[L, M, N] output)
72+
agraph (float[L, M, N] data) => (float[P, M, N] output)
7373
{{
7474
starts = Constant<value: tensor = int64[1] {{0}}>()
7575
ends = Constant<value: tensor = int64[1] {{{9}}}>()
@@ -82,3 +82,21 @@ def test_slice_pattern_is_not_matched_when_input_is_dynamic(self):
8282
model = ir.serde.deserialize_model(model_proto)
8383
count = collapse_slices.rules.apply_to_model(model)
8484
self.assertEqual(count, 0)
85+
86+
def test_slice_equal_dynamic_shape(self):
87+
model_proto = onnx.parser.parse_model(
88+
f"""
89+
<ir_version: 7, opset_import: [ "" : 17]>
90+
agraph (float[L, M, N] data) => (float[L, M, N] output)
91+
{{
92+
starts = Constant<value: tensor = int64[1] {{0}}>()
93+
ends = Constant<value: tensor = int64[1] {{{9}}}>()
94+
axes = Constant<value: tensor = int64[1] {{0}}>()
95+
steps = Constant<value: tensor = int64[1] {{1}}>()
96+
output = Slice (data, starts, ends, axes, steps)
97+
}}
98+
"""
99+
)
100+
model = ir.serde.deserialize_model(model_proto)
101+
count = collapse_slices.rules.apply_to_model(model)
102+
self.assertEqual(count, 1)

0 commit comments

Comments
 (0)