diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index dd4b8a425425..24217184b57c 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -1917,18 +1917,20 @@ def _impl_v13(cls, bb, inputs, attr, params):
         # If possible, directly expand to constant shape.
         if isinstance(shape, relax.Constant):
             new_shape = shape.data.numpy().tolist()
-            # For some reason, onnx allows target shapes to be smaller than input shapes.
-            # We need to go correct it.
+            # ONNX Expand operator requires preserving target rank and broadcasting
+            # according to standard rules. Dimensions are right-aligned.
             data_shape = [dim.value for dim in data.struct_info.shape]
-            # Dimensions are right alignment.
-            data_shape = [1] * (len(new_shape) - len(data_shape)) + data_shape
-            # Fix small target shapes.
-            for i, s in enumerate(new_shape):
-                if i < len(data_shape) and s < data_shape[i]:
+
+            # Right-align the shapes
+            if len(new_shape) > len(data_shape):
+                data_shape = [1] * (len(new_shape) - len(data_shape)) + data_shape
+            else:
+                new_shape = [1] * (len(data_shape) - len(new_shape)) + new_shape
+            # Fix small target shapes - if target dim is smaller than input dim
+            # use the input dim (ONNX-specific behavior).
+            for i in range(len(new_shape)):
+                if new_shape[i] < data_shape[i]:
                     new_shape[i] = data_shape[i]
-            # If the new shape matches the input shape, no transformation is needed.
-            if new_shape == data_shape:
-                return data
             return relax.op.broadcast_to(data, relax.ShapeExpr(new_shape))
 
         # Otherwise handle dynamic shapes.
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 10c185ae09d6..ebc1454c2302 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -1692,6 +1692,12 @@ def _test_expand_dynamic_shapeexpr(name, data, shape_data, shape, ref_data):
         data = np.random.uniform(size=in_shape).astype(np.float32)
         ref_data = np.tile(data, (1, 1, 4))
         _test_expand("expand_with_diff_dim", data, shape, ref_data)
+
+        in_shape = (3, 1)
+        shape = (1, 1, 3, 1)
+        data = np.random.uniform(size=in_shape).astype(np.float32)
+        ref_data = np.tile(data, (1, 1, 1, 1))
+        _test_expand("expand_with_the_same_suffix_dims", data, shape, ref_data)
     else:
         in_shape = (1, 32, 32)
         shape = ("batch", 32, 32)
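
For reference, here is a minimal standalone sketch (not part of the patch) of the shape fixup the new code performs: the two shapes are right-aligned by left-padding the shorter one with 1s, and any target dim smaller than the input dim is replaced by the input dim, matching ONNX Expand's allowance for "too small" targets. The helper name `expand_target_shape` is hypothetical.

```python
import numpy as np

# Hypothetical helper mirroring the fixup in the patch above.
def expand_target_shape(data_shape, new_shape):
    data_shape, new_shape = list(data_shape), list(new_shape)
    # Right-align by left-padding the shorter shape with 1s.
    if len(new_shape) > len(data_shape):
        data_shape = [1] * (len(new_shape) - len(data_shape)) + data_shape
    else:
        new_shape = [1] * (len(data_shape) - len(new_shape)) + new_shape
    # A target dim smaller than the input dim is replaced by the input dim.
    return [max(t, d) for t, d in zip(new_shape, data_shape)]

# A "too small" target dim (1 < 2) is corrected to the input dim.
assert expand_target_shape((2, 3), (1, 3)) == [2, 3]

# The regression exercised by the new test: a (3, 1) input expanded to
# (1, 1, 3, 1) must keep rank 4, even though the trailing dims already match.
target = expand_target_shape((3, 1), (1, 1, 3, 1))
assert target == [1, 1, 3, 1]
assert np.broadcast_to(np.ones((3, 1)), target).shape == (1, 1, 3, 1)
```

The second case is exactly what the removed early return got wrong: after right-alignment the padded input shape `[1, 1, 3, 1]` equals the target, so the old `return data` handed back the rank-2 `(3, 1)` tensor instead of broadcasting to rank 4. The new `expand_with_the_same_suffix_dims` test covers this.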