[FIX][ONNX][RELAX] Add support for dynamic ShapeExpr in Slice, Squeeze and Flatten (#17490)
PatrikPerssonInceptron authored Oct 31, 2024
1 parent e3e27f5 commit de93c37
Showing 2 changed files with 256 additions and 32 deletions.
59 changes: 48 additions & 11 deletions python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -1199,14 +1199,29 @@ class Squeeze(OnnxOpConverter):
 
     @classmethod
     def _impl_v13(cls, bb, inputs, attr, params):
+        data = inputs[0]
         axis = get_constant(inputs[1], params)
         if isinstance(axis, relax.Constant):
-            axis = [int(x) for x in axis.data.numpy()]
+            axis = tuple([int(x) for x in axis.data.numpy()])
 
         # If data is constant, perform computation directly.
-        if isinstance(inputs[0], relax.Constant):
-            out_data = _np.squeeze(inputs[0].data.numpy(), axis)
-            return relax.const(out_data, inputs[0].struct_info.dtype)
-        return relax.op.squeeze(inputs[0], axis)
+        if isinstance(data, relax.Constant):
+            if isinstance(axis, (tuple, type(None))):
+                out_data = _np.squeeze(data.data.numpy(), axis)
+            else:
+                raise NotImplementedError("Squeeze with symbolic axes not supported")
+
+            return relax.const(out_data, data.struct_info.dtype)
+
+        if isinstance(data, relax.ShapeExpr):
+            if axis == (0,):
+                return relax.PrimValue(data[0])
+            else:
+                raise NotImplementedError(
+                    "Squeeze with symbolic axes and non-zero axes is not supported."
+                )
+
+        return relax.op.squeeze(data, axis)
 
 
 class Constant(OnnxOpConverter):
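
Why the new ShapeExpr branch matters (an illustrative sketch, not code from the commit): when the input to Squeeze is itself a shape, squeezing axis 0 of a one-element shape such as [N] yields the scalar dimension as a relax.PrimValue, which is exactly what a Shape -> Slice -> Squeeze chain over a dynamic input produces.

from tvm import relax, tir

n = tir.Var("n", "int64")
shape = relax.ShapeExpr([n])        # stands in for `data` in the converter
batch = relax.PrimValue(shape[0])   # what the branch returns for axis == (0,)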
@@ -1559,12 +1574,12 @@ def _impl_v13(cls, bb, inputs, attr, params):
         splits_rank = splits.checked_type.ndim
         if splits is not None and splits_rank > 0:
             if isinstance(splits, relax.Constant):
-                splits = splits.data.asnumpy()
+                splits = splits.data.numpy()
                 indices = []
                 index = 0
                 for i in splits[:-1]:
                     index += i
-                    indices.append(index)
+                    indices.append(index.item())
             else:
                 raise ValueError("Dynamic Split not yet supported")
         # When splits isnt specified divide evenly over axis.
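
Context for the .item() change (a standalone sketch with hypothetical values): iterating a NumPy int64 array yields NumPy scalar types, and accumulating them keeps that type, so .item() unwraps the running index to a plain Python int before the indices reach the Relax split op.

import numpy as np

splits = np.array([2, 3, 5], dtype="int64")
indices = []
index = 0
for i in splits[:-1]:
    index += i                    # numpy.int64, not a Python int
    indices.append(index.item())  # .item() unwraps to a plain int
print(indices)                    # [2, 5]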
@@ -1611,11 +1626,16 @@ def _impl_v13(cls, bb, inputs, attr, params):
             steps = [1] * len(axes)
         # If input is a shape tensor, we can directly extract it.
         if isinstance(data, relax.ShapeExpr):
-            shape_data = [dim.value for dim in data]
+            shape_data = list(data)
             # Starts, ends, and steps must be 1-d for shape operation.
             assert all(len(i) == 1 for i in [starts, ends, steps])
             sliced_values = shape_data[starts[0] : ends[0] : steps[0]]
-            return relax.const(sliced_values, "int64")
+
+            if all([isinstance(val, (tir.IntImm, int)) for val in sliced_values]):
+                return relax.const([x.value for x in sliced_values], "int64")
+            else:
+                return relax.ShapeExpr(sliced_values)
 
         # If all `starts`, `ends`, and `steps` are constant, use strict mode
         # Otherwise, we assume the slice is inbound.
         assume_inbound = not all(
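
A minimal sketch of the new fallback (names illustrative, not from the commit): slicing a shape that mixes static and symbolic dims can no longer fold to a constant, so the result is rebuilt as a ShapeExpr instead of failing on the missing .value.

from tvm import relax, tir

n = tir.Var("n", "int64")
shape_data = list(relax.ShapeExpr([n, 3, 224, 224]))
sliced_values = shape_data[0:2]   # [n, 3]: one symbolic, one static dim

if all(isinstance(val, (tir.IntImm, int)) for val in sliced_values):
    out = relax.const([x.value for x in sliced_values], "int64")
else:
    out = relax.ShapeExpr(sliced_values)  # symbolic dims survive here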
@@ -2237,8 +2257,24 @@ class Flatten(OnnxOpConverter):
     @classmethod
     def _impl_v13(cls, bb, inputs, attr, params):
         axis = attr.get("axis", 1)
-        data_shape = [i.value for i in inputs[0].struct_info.shape]
-        new_shape = (1, -1) if axis == 0 else (_np.prod(data_shape[0:axis]).astype("int64"), -1)
+        data_shape = list(inputs[0].struct_info.shape)
+
+        if axis == 0:
+            new_shape = (1, -1)
+        else:
+            shape_flags = [isinstance(x, tvm.script.tir.IntImm) for x in data_shape[0:axis]]
+
+            if all(shape_flags):
+                data_shape = [x.value for x in data_shape[0:axis]]
+                new_shape = (_np.prod(data_shape).astype("int64"), -1)
+            else:
+                batch_size = 1
+
+                for el in data_shape[0:axis]:
+                    batch_size = batch_size * el
+
+                new_shape = (batch_size, -1)
+
         return relax.op.reshape(inputs[0], new_shape)


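Why the else-branch multiplies dims one by one (a sketch with illustrative names): _np.prod cannot fold a symbolic dimension, but multiplying TIR expressions keeps the product symbolic, and the resulting PrimExpr can still be used in the target shape for reshape.

from tvm import tir

n = tir.Var("n", "int64")                    # symbolic batch dim
data_shape = [n, tir.IntImm("int64", 3), tir.IntImm("int64", 8)]
axis = 2

batch_size = 1
for el in data_shape[0:axis]:
    batch_size = batch_size * el             # builds the PrimExpr n * 3

new_shape = (batch_size, -1)                 # fed to relax.op.reshape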
@@ -3220,6 +3256,7 @@ def _construct_nodes(self, graph: onnx.onnx_ml_pb2.GraphProto):
             "Equal",
             "Where",
             "Cast",
+            "Squeeze",
         ]
         return_tuple_ops = [
             "SequenceConstruct",
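
An end-to-end sketch of the pattern these changes unlock (graph and names invented for illustration, loosely in the style of the tests this commit adds): a Shape -> Slice -> Squeeze chain over an input with a symbolic batch dimension, imported through the Relax ONNX frontend.

import onnx
from onnx import TensorProto, helper

from tvm.relax.frontend.onnx import from_onnx

# x has a symbolic batch dim "N"; the graph extracts it as a scalar.
nodes = [
    helper.make_node("Shape", ["x"], ["xshape"]),
    helper.make_node("Slice", ["xshape", "starts", "ends"], ["batch_1d"]),
    helper.make_node("Squeeze", ["batch_1d", "axes"], ["batch"]),
]
graph = helper.make_graph(
    nodes,
    "dynamic_shape_example",
    inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, ["N", 16])],
    outputs=[helper.make_tensor_value_info("batch", TensorProto.INT64, [])],
    initializer=[
        helper.make_tensor("starts", TensorProto.INT64, [1], [0]),
        helper.make_tensor("ends", TensorProto.INT64, [1], [1]),
        helper.make_tensor("axes", TensorProto.INT64, [1], [0]),
    ],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])

# Before this commit, Slice assumed every dim had a concrete .value and
# Squeeze was not registered as a shape-compatible op; now the import succeeds.
mod = from_onnx(model)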