85 changes: 77 additions & 8 deletions python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -1910,15 +1910,47 @@ def _impl_v13(cls, bb, inputs, attr, params):
if isinstance(shape, relax.ShapeExpr):
data_shape = list(data.struct_info.shape)
target_shape = list(shape.values)
original_data_shape = [
dim.value if hasattr(dim, "value") else str(dim) for dim in data_shape
]
original_target_shape = [
dim.value if hasattr(dim, "value") else str(dim) for dim in target_shape
]
data_shape = [1] * (len(target_shape) - len(data_shape)) + data_shape
assert len(data_shape) == len(target_shape)
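# e.g. an input shape (25,) with target (1, 1, 1, 56) right-aligns to
# (1, 1, 1, 25) before the per-dimension checks below (shapes borrowed
# from the regression tests; illustrative only)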
# Apply ONNX v13 Expand broadcasting rules
for i, s in enumerate(target_shape):
if isinstance(s, tvm.tir.IntImm):
if s.value == -1:
# -1 means preserve the input dimension
target_shape[i] = data_shape[i]
elif isinstance(data_shape[i], tvm.tir.IntImm) and data_shape[i].value == 1:
# Input dimension is 1, can broadcast to any target dimension >= 1
if s.value < 1:
raise ValueError(
f"ONNX Expand: Invalid target dimension {s.value} "
f"at position {i}. Target dimensions must be >= 1."
)
elif (
isinstance(data_shape[i], tvm.tir.IntImm) and s.value == data_shape[i].value
):
# Dimensions match, no change needed
pass
elif s.value == 1:
# Target dimension is 1 but input dimension is not 1
# This would "squeeze" the dimension - preserve input for safety
target_shape[i] = data_shape[i]
else:
if isinstance(data_shape[i], tvm.tir.IntImm):
raise ValueError(
f"ONNX Expand: Cannot broadcast input shape {original_data_shape} "
f"to target shape {original_target_shape}. "
f"At dimension {i}: input size {data_shape[i].value} is "
f"incompatible with target size {s.value}. "
f"ONNX broadcasting requires corresponding dimensions to have "
f"the same value or one of them to be 1."
)
# For dynamic shapes, let broadcast_to handle it
if target_shape == data_shape:
return data
return relax.op.broadcast_to(data, relax.ShapeExpr(target_shape))
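For readers skimming the diff, the per-dimension rule the branch above enforces can be condensed into a small standalone helper. This is a plain-Python illustration of the same rules, not code from the PR; the name resolve_expand_dim is hypothetical:

```python
def resolve_expand_dim(input_dim: int, target_dim: int) -> int:
    """Resolve one right-aligned dimension pair under ONNX Expand rules."""
    if target_dim == -1:
        # -1 preserves the input dimension
        return input_dim
    if target_dim < 1:
        raise ValueError(f"invalid target dimension {target_dim}")
    if input_dim == 1:
        # a size-1 input broadcasts to any target >= 1
        return target_dim
    if target_dim in (1, input_dim):
        # dimensions match, or a squeezing target of 1 keeps the input dim
        return input_dim
    raise ValueError(
        f"input size {input_dim} is incompatible with target size {target_dim}"
    )
```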
@@ -1929,6 +1961,8 @@ def _impl_v13(cls, bb, inputs, attr, params):
# ONNX Expand operator requires preserving target rank and broadcasting
# according to standard rules. Dimensions are right-aligned.
data_shape = [dim.value for dim in data.struct_info.shape]
original_data_shape = data_shape.copy()
original_new_shape = new_shape.copy()

# Right-align the shapes
if len(new_shape) > len(data_shape):
@@ -1938,8 +1972,32 @@ def _impl_v13(cls, bb, inputs, attr, params):
# Apply ONNX Expand broadcasting rules per dimension: -1 and squeezing
# targets preserve the input dim; incompatible dims raise an error.
for i in range(len(new_shape)):
if new_shape[i] == -1:
# -1 means preserve the input dimension
new_shape[i] = data_shape[i]
elif data_shape[i] == 1:
# Input dimension is 1, can broadcast to any target dimension >= 1
if new_shape[i] < 1:
raise ValueError(
f"ONNX Expand: Invalid target dimension {new_shape[i]} "
f"at position {i}. Target dimensions must be >= 1."
)
elif new_shape[i] == data_shape[i]:
# Dimensions match, no change needed
pass
elif new_shape[i] == 1:
# Target dimension is 1 but input dimension is not 1
# This would "squeeze" the dimension - preserve input for safety
new_shape[i] = data_shape[i]
else:
raise ValueError(
f"ONNX Expand: Cannot broadcast input shape {original_data_shape} "
f"to target shape {original_new_shape}. "
f"At dimension {i}: input size {data_shape[i]} is incompatible "
f"with target size {new_shape[i]}. "
f"ONNX broadcasting requires corresponding dimensions to have the same "
f"value or one of them to be 1."
)
return relax.op.broadcast_to(data, relax.ShapeExpr(new_shape))
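# Worked example (shapes from the regression tests below; illustration
# only): an input (1, 25) against target (1, 1, 1, 56) right-aligns to
# (1, 1, 1, 25); dims 0-2 pair 1 with 1, but dim 3 pairs 25 with 56,
# where neither side is 1, so the loop above raises ValueError instead
# of emitting an invalid broadcast_to.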

# Otherwise handle dynamic shapes.
@@ -1956,7 +2014,18 @@ def _impl_v13(cls, bb, inputs, attr, params):
for i in range(shape_ndim):
shape_vars.append(tvm.tir.Var("x_%d" % i, "int64"))
bb.match_cast(shape_dataflow_var, relax.ShapeStructInfo(shape_vars))
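# match_cast binds the symbolic vars x_0..x_{n-1} to the runtime shape
# value, so broadcast_to below receives a fully symbolic target shape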

# Apply rank alignment for dynamic target shapes; broadcast_to enforces
# the broadcasting rules
data_shape = list(data.struct_info.shape)
data_ndim = len(data_shape)
target_ndim = shape_ndim
padded_data = data

if target_ndim > data_ndim:
padded_data_shape = [tir.IntImm("int64", 1)] * (target_ndim - data_ndim) + data_shape
padded_data = bb.normalize(relax.op.reshape(data, relax.ShapeExpr(padded_data_shape)))

return bb.normalize(relax.op.broadcast_to(padded_data, relax.ShapeExpr(shape_vars)))


class Attention(OnnxOpConverter):
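As a quick sanity check on the semantics the converter now enforces, numpy's multidirectional broadcasting helper agrees with both the failing and the valid shapes used in the tests below (np.broadcast_shapes requires numpy >= 1.20; note the -1 preservation rule is converter-specific and has no numpy counterpart):

```python
import numpy as np

print(np.broadcast_shapes((1, 25), (2, 25)))   # (2, 25): the valid test case
try:
    np.broadcast_shapes((25,), (1, 1, 1, 56))  # 25 vs 56, neither is 1
except ValueError as err:
    print("rejected:", err)
```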
100 changes: 100 additions & 0 deletions tests/python/relax/test_frontend_onnx.py
@@ -1909,6 +1909,106 @@ def _test_expand_dynamic_shapeexpr(name, data, shape_data, shape, ref_data):
_test_expand_dynamic_shapeexpr("expand_with_dynamic_dim", data, shape_data, shape, ref_data)


def test_expand_incompatible_broadcasting():
"""
Reproduces issue-17769: the input dimension of size 25 is right-aligned
against the target dimension of size 56 (dim 3); the sizes differ and
neither is 1, which violates ONNX broadcasting rules.
"""

def _test_expand_error_case(name, data_shape, target_shape_vals):
data = np.random.uniform(size=data_shape).astype(np.float32)

shape_array = np.array(target_shape_vals, dtype=np.int64)
shape_node = onnx.helper.make_node(
"Constant",
inputs=[],
outputs=["shape"],
value=onnx.helper.make_tensor(
name="const_tensor",
data_type=onnx.TensorProto.INT64,
dims=shape_array.shape,
vals=shape_array.flatten(),
),
)

expand_node = helper.make_node("Expand", ["in", "shape"], ["out"])

graph = helper.make_graph(
[shape_node, expand_node],
"expand_error_test",
inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(data.shape))],
outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, target_shape_vals)],
)

model = helper.make_model(graph, producer_name=name)

with pytest.raises(ValueError) as exc_info:
from_onnx(model, keep_params_in_input=True)

error_msg = str(exc_info.value)
assert (
"broadcast" in error_msg.lower() or "incompatible" in error_msg.lower()
), f"Expected broadcasting error, but got: {error_msg}"

# Test case 1: Reproduce the exact error from issue-17769
# Input shape: (25,), target shape: (1, 1, 1, 56)
# This should fail because the input size 25 right-aligns against
# target dim 3 (56), they differ, and neither is 1
_test_expand_error_case(
"expand_incompatible_25_to_56",
data_shape=(25,),
target_shape_vals=(1, 1, 1, 56),
)

# Test case 2: Another incompatible case
# Input shape: (1, 25), target shape: (1, 1, 1, 56)
# After right-alignment, input (1, 1, 1, 25) vs. target (1, 1, 1, 56)
# This should fail because 25 != 56 and neither is 1
_test_expand_error_case(
"expand_incompatible_aligned_25_to_56",
data_shape=(1, 25),
target_shape_vals=(1, 1, 1, 56),
)

# Test case 3: Valid case for comparison - should not raise error
def _test_expand_valid_case():
"""Test a valid expand case to ensure our fix doesn't break valid operations"""
data_shape = (1, 25)
target_shape_vals = [2, 25] # Valid: input (1, 25) can broadcast to (2, 25)

data = np.random.uniform(size=data_shape).astype(np.float32)
shape_array = np.array(target_shape_vals, dtype=np.int64)

shape_node = onnx.helper.make_node(
"Constant",
inputs=[],
outputs=["shape"],
value=onnx.helper.make_tensor(
name="const_tensor",
data_type=onnx.TensorProto.INT64,
dims=shape_array.shape,
vals=shape_array.flatten(),
),
)

expand_node = helper.make_node("Expand", ["in", "shape"], ["out"])

graph = helper.make_graph(
[shape_node, expand_node],
"expand_valid_test",
inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(data.shape))],
outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, target_shape_vals)],
)

model = helper.make_model(graph, producer_name="expand_valid_test_case")

try:
from_onnx(model, keep_params_in_input=True)
except Exception as e:
pytest.fail(f"Valid expand case should not fail, but got error: {e}")

_test_expand_valid_case()
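# To exercise just these cases locally, a pytest keyword filter works:
#   pytest tests/python/relax/test_frontend_onnx.py -k expand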


# TODO(jwfromm) Current approach to dynamic expand is technically not well formed. Reenable once fixed.
@pytest.mark.skip("Produces ill-formed IR")
def test_constantofshape():