diff --git a/onnxscript/rewriter/collapse_slices.py b/onnxscript/rewriter/collapse_slices.py index f1fda00849..e38f0f443d 100644 --- a/onnxscript/rewriter/collapse_slices.py +++ b/onnxscript/rewriter/collapse_slices.py @@ -73,7 +73,14 @@ def _identity_to_itself(op, data, **_): def _potential_redundant_slice(op, data, starts, ends, axes, steps): """To identify a slice op""" - return op.Slice(data, starts, ends, axes, steps) + return op.Slice(data, starts, ends, axes, steps, _outputs=["slice_output"]) + + +def _same_shape(op, data: ir.Value, slice_output: ir.Value, **_): + """Check if the shape of the slice output is the same as the data.""" + if data.shape is None or slice_output.shape is None: + return False + return data.shape == slice_output.shape # Register the rewrite rules @@ -83,5 +90,13 @@ def _potential_redundant_slice(op, data, starts, ends, axes, steps): _check_if_redundant_slice, ) -# NOTE: The order of the rules is important. Larger pattern should be checked first. -rules = RewriteRuleSet([remove_redundant_slice]) +remove_redundant_slice2 = RewriteRule( + _potential_redundant_slice, + _identity_to_itself, + _same_shape, +) + +# NOTE: The second rule subsumes the first one. So, we may be able to remove the first one, +# provided shape-inference is run before the rewriter and computes the shape of the slice output. + +rules = RewriteRuleSet([remove_redundant_slice, remove_redundant_slice2]) diff --git a/onnxscript/rewriter/collapse_slices_test.py b/onnxscript/rewriter/collapse_slices_test.py index 7e7a4c15c4..ce803b8a4f 100644 --- a/onnxscript/rewriter/collapse_slices_test.py +++ b/onnxscript/rewriter/collapse_slices_test.py @@ -65,11 +65,11 @@ def test_slice_is_redundant_when_ends_reaches_int64_max(self): (np.random.rand(512, 16, 112).astype(np.float32),), ) - def test_slice_pattern_is_not_matched_when_input_is_dynamic(self): + def test_slice_unequal_dynamic_shape(self): model_proto = onnx.parser.parse_model( f""" - agraph (float[L, M, N] data) => (float[L, M, N] output) + agraph (float[L, M, N] data) => (float[P, M, N] output) {{ starts = Constant() ends = Constant() @@ -82,3 +82,21 @@ def test_slice_pattern_is_not_matched_when_input_is_dynamic(self): model = ir.serde.deserialize_model(model_proto) count = collapse_slices.rules.apply_to_model(model) self.assertEqual(count, 0) + + def test_slice_equal_dynamic_shape(self): + model_proto = onnx.parser.parse_model( + f""" + + agraph (float[L, M, N] data) => (float[L, M, N] output) + {{ + starts = Constant() + ends = Constant() + axes = Constant() + steps = Constant() + output = Slice (data, starts, ends, axes, steps) + }} + """ + ) + model = ir.serde.deserialize_model(model_proto) + count = collapse_slices.rules.apply_to_model(model) + self.assertEqual(count, 1)