Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FIX][ONNX][RELAX] Add support for dynamic ShapeExpr in Slice, Squeeze and Flatten #17490

Merged
merged 9 commits into from
Oct 31, 2024
59 changes: 48 additions & 11 deletions python/tvm/relax/frontend/onnx/onnx_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1199,14 +1199,29 @@ class Squeeze(OnnxOpConverter):

@classmethod
def _impl_v13(cls, bb, inputs, attr, params):
data = inputs[0]
axis = get_constant(inputs[1], params)
if isinstance(axis, relax.Constant):
axis = [int(x) for x in axis.data.numpy()]
axis = tuple([int(x) for x in axis.data.numpy()])

# If data is constant, perform computation directly.
if isinstance(inputs[0], relax.Constant):
out_data = _np.squeeze(inputs[0].data.numpy(), axis)
return relax.const(out_data, inputs[0].struct_info.dtype)
return relax.op.squeeze(inputs[0], axis)
if isinstance(data, relax.Constant):
if isinstance(axis, (tuple, type(None))):
out_data = _np.squeeze(data.data.numpy(), axis)
else:
raise NotImplementedError("Squeeze with symbolic axes not supported")

return relax.const(out_data, data.struct_info.dtype)

if isinstance(data, relax.ShapeExpr):
if axis == (0,):
return relax.PrimValue(data[0])
else:
raise NotImplementedError(
"Squeeze with symbolic axes and non-zero axes is not supported."
)

return relax.op.squeeze(data, axis)


class Constant(OnnxOpConverter):
Expand Down Expand Up @@ -1559,12 +1574,12 @@ def _impl_v13(cls, bb, inputs, attr, params):
splits_rank = splits.checked_type.ndim
if splits is not None and splits_rank > 0:
if isinstance(splits, relax.Constant):
splits = splits.data.asnumpy()
splits = splits.data.numpy()
indices = []
index = 0
for i in splits[:-1]:
index += i
indices.append(index)
indices.append(index.item())
else:
raise ValueError("Dynamic Split not yet supported")
# When splits isnt specified divide evenly over axis.
Expand Down Expand Up @@ -1611,11 +1626,16 @@ def _impl_v13(cls, bb, inputs, attr, params):
steps = [1] * len(axes)
# If input is a shape tensor, we can directly extract it.
if isinstance(data, relax.ShapeExpr):
shape_data = [dim.value for dim in data]
shape_data = list(data)
# Starts, ends, and steps must be 1-d for shape operation.
assert all(len(i) == 1 for i in [starts, ends, steps])
sliced_values = shape_data[starts[0] : ends[0] : steps[0]]
return relax.const(sliced_values, "int64")

if all([isinstance(val, (tir.IntImm, int)) for val in sliced_values]):
return relax.const([x.value for x in sliced_values], "int64")
else:
return relax.ShapeExpr(sliced_values)

# If all `starts`, `ends`, and `steps` are constant, use strict mode
# Otherwise, we assume the slice is inbound.
assume_inbound = not all(
Expand Down Expand Up @@ -2237,8 +2257,24 @@ class Flatten(OnnxOpConverter):
@classmethod
def _impl_v13(cls, bb, inputs, attr, params):
axis = attr.get("axis", 1)
data_shape = [i.value for i in inputs[0].struct_info.shape]
new_shape = (1, -1) if axis == 0 else (_np.prod(data_shape[0:axis]).astype("int64"), -1)
data_shape = list(inputs[0].struct_info.shape)

if axis == 0:
new_shape = (1, -1)
else:
shape_flags = [isinstance(x, tvm.script.tir.IntImm) for x in data_shape[0:axis]]

if all(shape_flags):
data_shape = [x.value for x in data_shape[0:axis]]
new_shape = (_np.prod(data_shape).astype("int64"), -1)
else:
batch_size = 1

for el in data_shape[0:axis]:
batch_size = batch_size * el

new_shape = (batch_size, -1)

return relax.op.reshape(inputs[0], new_shape)


Expand Down Expand Up @@ -3220,6 +3256,7 @@ def _construct_nodes(self, graph: onnx.onnx_ml_pb2.GraphProto):
"Equal",
"Where",
"Cast",
"Squeeze",
]
return_tuple_ops = [
"SequenceConstruct",
Expand Down
Loading
Loading