diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 5470c911d30b..7a4a65df6ec5 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -1910,15 +1910,47 @@ def _impl_v13(cls, bb, inputs, attr, params):
         if isinstance(shape, relax.ShapeExpr):
             data_shape = list(data.struct_info.shape)
             target_shape = list(shape.values)
+            original_data_shape = [
+                dim.value if hasattr(dim, "value") else str(dim) for dim in data_shape
+            ]
+            original_target_shape = [
+                dim.value if hasattr(dim, "value") else str(dim) for dim in target_shape
+            ]
             data_shape = [1] * (len(target_shape) - len(data_shape)) + data_shape
             assert len(data_shape) == len(target_shape)
-            # Fix small target shapes or target shapes assigned to -1
+            # Apply ONNX v13 Expand broadcasting rules
             for i, s in enumerate(target_shape):
-                if isinstance(s, tvm.tir.IntImm) and (
-                    (isinstance(data_shape[i], tvm.tir.IntImm) and s < data_shape[i])
-                    or s.value == -1
-                ):
-                    target_shape[i] = data_shape[i]
+                if isinstance(s, tvm.tir.IntImm):
+                    if s.value == -1:
+                        # -1 means preserve the input dimension
+                        target_shape[i] = data_shape[i]
+                    elif isinstance(data_shape[i], tvm.tir.IntImm) and data_shape[i].value == 1:
+                        # Input dimension is 1, so it can broadcast to any target dimension >= 1
+                        if s.value < 1:
+                            raise ValueError(
+                                f"ONNX Expand: Invalid target dimension {s.value} "
+                                f"at position {i}. Target dimensions must be >= 1."
+                            )
+                    elif (
+                        isinstance(data_shape[i], tvm.tir.IntImm) and s.value == data_shape[i].value
+                    ):
+                        # Dimensions match, no change needed
+                        pass
+                    elif s.value == 1:
+                        # Target dimension is 1 but input dimension is not 1
+                        # This would "squeeze" the dimension, so preserve the input for safety
+                        target_shape[i] = data_shape[i]
+                    else:
+                        if isinstance(data_shape[i], tvm.tir.IntImm):
+                            raise ValueError(
+                                f"ONNX Expand: Cannot broadcast input shape {original_data_shape} "
+                                f"to target shape {original_target_shape}. "
+                                f"At dimension {i}: input size {data_shape[i].value} is "
+                                f"incompatible with target size {s.value}. "
+                                f"ONNX broadcasting requires corresponding dimensions to have "
+                                f"the same value or one of them to be 1."
+                            )
+                        # For dynamic shapes, let broadcast_to handle it
             if target_shape == data_shape:
                 return data
             return relax.op.broadcast_to(data, relax.ShapeExpr(target_shape))
@@ -1929,6 +1961,8 @@ def _impl_v13(cls, bb, inputs, attr, params):
             # ONNX Expand operator requires preserving target rank and broadcasting
             # according to standard rules. Dimensions are right-aligned.
             data_shape = [dim.value for dim in data.struct_info.shape]
+            original_data_shape = data_shape.copy()
+            original_new_shape = new_shape.copy()
 
             # Right-align the shapes
             if len(new_shape) > len(data_shape):
@@ -1938,8 +1972,32 @@ def _impl_v13(cls, bb, inputs, attr, params):
             # Fix small target shapes - if target dim is smaller than input dim
             # use the input dim (ONNX-specific behavior).
             for i in range(len(new_shape)):
-                if new_shape[i] < data_shape[i]:
+                if new_shape[i] == -1:
+                    # -1 means preserve the input dimension
+                    new_shape[i] = data_shape[i]
+                elif data_shape[i] == 1:
+                    # Input dimension is 1, so it can broadcast to any target dimension >= 1
+                    if new_shape[i] < 1:
+                        raise ValueError(
+                            f"ONNX Expand: Invalid target dimension {new_shape[i]} "
+                            f"at position {i}. Target dimensions must be >= 1."
+                        )
+                elif new_shape[i] == data_shape[i]:
+                    # Dimensions match, no change needed
+                    pass
+                elif new_shape[i] == 1:
+                    # Target dimension is 1 but input dimension is not 1
+                    # This would "squeeze" the dimension, so preserve the input for safety
                     new_shape[i] = data_shape[i]
+                else:
+                    raise ValueError(
+                        f"ONNX Expand: Cannot broadcast input shape {original_data_shape} "
+                        f"to target shape {original_new_shape}. "
+                        f"At dimension {i}: input size {data_shape[i]} is incompatible "
+                        f"with target size {new_shape[i]}. "
+                        f"ONNX broadcasting requires corresponding dimensions to have the same "
+                        f"value or one of them to be 1."
+                    )
             return relax.op.broadcast_to(data, relax.ShapeExpr(new_shape))
 
         # Otherwise handle dynamic shapes.
@@ -1956,7 +2014,18 @@ def _impl_v13(cls, bb, inputs, attr, params):
         for i in range(shape_ndim):
             shape_vars.append(tvm.tir.Var("x_%d" % i, "int64"))
         bb.match_cast(shape_dataflow_var, relax.ShapeStructInfo(shape_vars))
-        return bb.normalize(relax.op.broadcast_to(data, relax.ShapeExpr(shape_vars)))
+
+        # Apply broadcasting rules for dynamic shapes
+        data_shape = list(data.struct_info.shape)
+        data_ndim = len(data_shape)
+        target_ndim = shape_ndim
+        padded_data = data
+
+        if target_ndim > data_ndim:
+            padded_data_shape = [tir.IntImm("int64", 1)] * (target_ndim - data_ndim) + data_shape
+            padded_data = bb.normalize(relax.op.reshape(data, relax.ShapeExpr(padded_data_shape)))
+
+        return bb.normalize(relax.op.broadcast_to(padded_data, relax.ShapeExpr(shape_vars)))
 
 
 class Attention(OnnxOpConverter):
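
Not part of the patch: for review convenience, here is a minimal standalone sketch of the static-shape rule both branches above implement, written on plain Python ints. The helper name `resolve_expand_shape` is hypothetical and used only for this illustration.

    def resolve_expand_shape(data_shape, target_shape):
        """Resolve an ONNX Expand target shape against a fully static input shape."""
        # Right-align by padding the input shape with leading 1s to the target rank.
        data_shape = [1] * (len(target_shape) - len(data_shape)) + list(data_shape)
        result = []
        for i, (d, t) in enumerate(zip(data_shape, target_shape)):
            if t == -1:
                result.append(d)  # -1 preserves the input dimension
            elif d == 1:
                if t < 1:
                    raise ValueError(f"invalid target dimension {t} at position {i}")
                result.append(t)  # a size-1 input dim broadcasts to any target >= 1
            elif t in (1, d):
                # Equal dims, or a would-be "squeeze" (target 1, input > 1):
                # the patch preserves the input dimension in both cases.
                result.append(d)
            else:
                raise ValueError(
                    f"cannot broadcast {data_shape} to {target_shape}: "
                    f"dim {i} is {d} vs {t}, and neither is 1"
                )
        return result

    assert resolve_expand_shape([1, 25], [2, 25]) == [2, 25]
    assert resolve_expand_shape([25], [1, 1, 1, 25]) == [1, 1, 1, 25]
    # resolve_expand_shape([25], [1, 1, 1, 56]) raises ValueError (the issue-17769 shapes)
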
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 625cdebf7f61..d2f5a65593e4 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -1909,6 +1909,106 @@ def _test_expand_dynamic_shapeexpr(name, data, shape_data, shape, ref_data):
     _test_expand_dynamic_shapeexpr("expand_with_dynamic_dim", data, shape_data, shape, ref_data)
 
 
+def test_expand_incompatible_broadcasting():
+    """
+    Reproduces the error from issue 17769, where an input dimension of size 25
+    meets a target dimension of size 56, which violates ONNX broadcasting rules.
+    """
+
+    def _test_expand_error_case(name, data_shape, target_shape_vals):
+        data = np.random.uniform(size=data_shape).astype(np.float32)
+
+        shape_array = np.array(target_shape_vals, dtype=np.int64)
+        shape_node = onnx.helper.make_node(
+            "Constant",
+            inputs=[],
+            outputs=["shape"],
+            value=onnx.helper.make_tensor(
+                name="const_tensor",
+                data_type=onnx.TensorProto.INT64,
+                dims=shape_array.shape,
+                vals=shape_array.flatten(),
+            ),
+        )
+
+        expand_node = helper.make_node("Expand", ["in", "shape"], ["out"])
+
+        graph = helper.make_graph(
+            [shape_node, expand_node],
+            "expand_error_test",
+            inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(data.shape))],
+            outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, target_shape_vals)],
+        )
+
+        model = helper.make_model(graph, producer_name=name)
+
+        with pytest.raises(ValueError) as exc_info:
+            from_onnx(model, keep_params_in_input=True)
+
+        error_msg = str(exc_info.value)
+        assert (
+            "broadcast" in error_msg.lower() or "incompatible" in error_msg.lower()
+        ), f"Expected broadcasting error, but got: {error_msg}"
+
+    # Test case 1: Reproduce the exact error from issue-17769
+    # Input shape: (25,), target shape: (1, 1, 1, 56)
+    # This should fail because the input dim (25) != target dim 3 (56) and neither is 1
+    _test_expand_error_case(
+        "expand_incompatible_25_to_56",
+        data_shape=(25,),
+        target_shape_vals=(1, 1, 1, 56),
+    )
+
+    # Test case 2: Another incompatible case
+    # Input shape: (1, 25), target shape: (1, 1, 1, 56)
+    # After right-alignment, input (1, 1, 1, 25) vs. target (1, 1, 1, 56)
+    # This should fail because 25 != 56 and neither is 1
+    _test_expand_error_case(
+        "expand_incompatible_aligned_25_to_56",
+        data_shape=(1, 25),
+        target_shape_vals=(1, 1, 1, 56),
+    )
+
+    # Test case 3: Valid case for comparison - should not raise an error
+    def _test_expand_valid_case():
+        """Test a valid expand case to ensure the fix doesn't break valid operations"""
+        data_shape = (1, 25)
+        target_shape_vals = [2, 25]  # Valid: input (1, 25) can broadcast to (2, 25)
+
+        data = np.random.uniform(size=data_shape).astype(np.float32)
+        shape_array = np.array(target_shape_vals, dtype=np.int64)
+
+        shape_node = onnx.helper.make_node(
+            "Constant",
+            inputs=[],
+            outputs=["shape"],
+            value=onnx.helper.make_tensor(
+                name="const_tensor",
+                data_type=onnx.TensorProto.INT64,
+                dims=shape_array.shape,
+                vals=shape_array.flatten(),
+            ),
+        )
+
+        expand_node = helper.make_node("Expand", ["in", "shape"], ["out"])
+
+        graph = helper.make_graph(
+            [shape_node, expand_node],
+            "expand_valid_test",
+            inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(data.shape))],
+            outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, target_shape_vals)],
+        )
+
+        model = helper.make_model(graph, producer_name="expand_valid_test_case")
+
+        try:
+            from_onnx(model, keep_params_in_input=True)
+        except Exception as e:
+            pytest.fail(f"Valid expand case should not fail, but got error: {e}")
+
+    _test_expand_valid_case()
+
+
 # TODO(jwfromm) Current approach to dynamic expand is technically not well formed. Reenable once fixed.
 @pytest.mark.skip("Produces ill-formed IR")
 def test_constantofshape():
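
A side note for reviewers, not part of the patch: for non-negative target dims, ONNX Expand matches numpy's multidirectional broadcasting, so numpy can serve as a quick external oracle for the shapes exercised above. A sketch assuming numpy >= 1.20 (which provides np.broadcast_shapes); note numpy has no counterpart for the ONNX-specific -1 convention.

    import numpy as np

    print(np.broadcast_shapes((1, 25), (2, 25)))      # (2, 25): the valid case above
    print(np.broadcast_shapes((25,), (1, 1, 1, 25)))  # (1, 1, 1, 25): rank padding alone is fine
    try:
        np.broadcast_shapes((25,), (1, 1, 1, 56))     # the issue-17769 shapes
    except ValueError as err:
        print("rejected, matching the frontend's new ValueError:", err)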