Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion onnxscript/rewriter/collapse_slices.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import logging

from onnxscript import ir
from onnxscript.rewriter._ir_utils import is_singleton_value
from onnxscript.rewriter._rewrite_rule import RewriteRule, RewriteRuleSet

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -76,10 +77,14 @@ def _potential_redundant_slice(op, data, starts, ends, axes, steps):
return op.Slice(data, starts, ends, axes, steps, _outputs=["slice_output"])


def _same_shape(op, data: ir.Value, slice_output: ir.Value, **_):
def _same_shape(op, data: ir.Value, slice_output: ir.Value, steps: ir.Value, **_):
"""Check if the shape of the slice output is the same as the data."""
if data.shape is None or slice_output.shape is None:
return False

if not is_singleton_value(steps, 1):
return False

return data.shape == slice_output.shape


Expand Down
19 changes: 19 additions & 0 deletions onnxscript/rewriter/collapse_slices_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,22 @@ def test_slice_equal_dynamic_shape(self):
model = ir.serde.deserialize_model(model_proto)
count = collapse_slices.rules.apply_to_model(model)
self.assertEqual(count, 1)

def test_slice_equal_dynamic_shape_but_step_reverse(self):
model_proto = onnx.parser.parse_model(
f"""
<ir_version: 7, opset_import: [ "" : 17]>
agraph (float[L, M, N] data) => (float[L, M, N] output)
{{
starts = Constant<value: tensor = int64[1] {{0}}>()
ends = Constant<value: tensor = int64[1] {{{9}}}>()
axes = Constant<value: tensor = int64[1] {{0}}>()
steps = Constant<value: tensor = int64[1] {{-1}}>()
output = Slice (data, starts, ends, axes, steps)
}}
"""
)
model = ir.serde.deserialize_model(model_proto)
count = collapse_slices.rules.apply_to_model(model)
# Should not change the output shape if we did not use the default step of 1
self.assertEqual(count, 0)
Loading