From 122ef538b9e1bcc95eec6fb4f2ea8ffb22c04148 Mon Sep 17 00:00:00 2001 From: Sebastian Mossburger Date: Wed, 6 Aug 2025 09:58:12 +0200 Subject: [PATCH 1/4] Add test case for reverse step rewrite --- onnxscript/rewriter/collapse_slices_test.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/onnxscript/rewriter/collapse_slices_test.py b/onnxscript/rewriter/collapse_slices_test.py index ce803b8a4f..52b59f9037 100644 --- a/onnxscript/rewriter/collapse_slices_test.py +++ b/onnxscript/rewriter/collapse_slices_test.py @@ -100,3 +100,22 @@ def test_slice_equal_dynamic_shape(self): model = ir.serde.deserialize_model(model_proto) count = collapse_slices.rules.apply_to_model(model) self.assertEqual(count, 1) + + def test_slice_equal_dynamic_shape_but_step_reverse(self): + model_proto = onnx.parser.parse_model( + f""" + + agraph (float[L, M, N] data) => (float[L, M, N] output) + {{ + starts = Constant() + ends = Constant() + axes = Constant() + steps = Constant() + output = Slice (data, starts, ends, axes, steps) + }} + """ + ) + model = ir.serde.deserialize_model(model_proto) + count = collapse_slices.rules.apply_to_model(model) + # Should not change the output shape if we did not use the default step of 1 + self.assertEqual(count, 0) From 16ce711111fcd2d4832d5fef3ceccbd79cde4edc Mon Sep 17 00:00:00 2001 From: Sebastian Mossburger Date: Wed, 6 Aug 2025 10:06:50 +0200 Subject: [PATCH 2/4] Add potential fix --- onnxscript/rewriter/collapse_slices.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/onnxscript/rewriter/collapse_slices.py b/onnxscript/rewriter/collapse_slices.py index e38f0f443d..a83ac29f38 100644 --- a/onnxscript/rewriter/collapse_slices.py +++ b/onnxscript/rewriter/collapse_slices.py @@ -76,10 +76,14 @@ def _potential_redundant_slice(op, data, starts, ends, axes, steps): return op.Slice(data, starts, ends, axes, steps, _outputs=["slice_output"]) -def _same_shape(op, data: ir.Value, slice_output: ir.Value, **_): +def _same_shape(op, data: ir.Value, slice_output: ir.Value, steps: ir.Value, **_): """Check if the shape of the slice output is the same as the data.""" if data.shape is None or slice_output.shape is None: return False + + if not (steps.const_value.numpy() == 1).all(): + return False + return data.shape == slice_output.shape From 11b3648aeb791e05b226d30d613ff7b887112875 Mon Sep 17 00:00:00 2001 From: Sebastian Mossburger Date: Thu, 7 Aug 2025 09:18:17 +0200 Subject: [PATCH 3/4] Use is_singleton_value for the check from onnxscript.rewriter --- onnxscript/rewriter/collapse_slices.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/onnxscript/rewriter/collapse_slices.py b/onnxscript/rewriter/collapse_slices.py index a83ac29f38..d2fbf3ac43 100644 --- a/onnxscript/rewriter/collapse_slices.py +++ b/onnxscript/rewriter/collapse_slices.py @@ -6,6 +6,7 @@ from onnxscript import ir from onnxscript.rewriter._rewrite_rule import RewriteRule, RewriteRuleSet +from onnxscript.rewriter._ir_utils import is_singleton_value logger = logging.getLogger(__name__) _INT64_MAX = 9223372036854775807 @@ -81,7 +82,7 @@ def _same_shape(op, data: ir.Value, slice_output: ir.Value, steps: ir.Value, **_ if data.shape is None or slice_output.shape is None: return False - if not (steps.const_value.numpy() == 1).all(): + if not is_singleton_value(steps, 1): return False return data.shape == slice_output.shape From 42bc95a9030ba19510c1b062ebd66d7a42245e95 Mon Sep 17 00:00:00 2001 From: Sebastian Mossburger Date: Wed, 13 Aug 2025 11:46:45 +0200 Subject: [PATCH 4/4] Run lintrunner f --- onnxscript/rewriter/collapse_slices.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/rewriter/collapse_slices.py b/onnxscript/rewriter/collapse_slices.py index d2fbf3ac43..291128157d 100644 --- a/onnxscript/rewriter/collapse_slices.py +++ b/onnxscript/rewriter/collapse_slices.py @@ -5,8 +5,8 @@ import logging from onnxscript import ir -from onnxscript.rewriter._rewrite_rule import RewriteRule, RewriteRuleSet from onnxscript.rewriter._ir_utils import is_singleton_value +from onnxscript.rewriter._rewrite_rule import RewriteRule, RewriteRuleSet logger = logging.getLogger(__name__) _INT64_MAX = 9223372036854775807