diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 897c6a022594..c423598a2ee7 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -1560,34 +1560,26 @@ class Where(OnnxOpConverter):
 
     @classmethod
     def _impl_v9(cls, inputs, attr, params):
-        condition_shape = infer_shape(inputs[0])
-        x_shape = infer_shape(inputs[1])
-        y_shape = infer_shape(inputs[2])
-
-        # condition, x, and y can all be broadcasted.
-        # broadcast each of them to the longest shape.
-        # if two shapes have the same number of dimensions,
-        # try to choose the one that doesn't have "1" as
-        # a dimension.
-        shapes = [condition_shape, x_shape, y_shape]
-        shape_lens = [len(shape) for shape in shapes]
-        max_size = max(shape_lens)
-        max_size_idxs = [i for i, x in enumerate(shape_lens) if x == max_size]
-        broadcast_idx = max_size_idxs[0]
-        if len(max_size_idxs) > 1:
-            for idx in max_size_idxs:
-                if 1 not in shapes[idx]:
-                    broadcast_idx = idx
-
-        broadcast_shape = shapes[broadcast_idx]
-
-        if condition_shape != broadcast_shape:
-            inputs[0] = _op.broadcast_to(inputs[0], broadcast_shape)
-        if x_shape != broadcast_shape:
-            inputs[1] = _op.broadcast_to(inputs[1], broadcast_shape)
-        if y_shape != broadcast_shape:
-            inputs[2] = _op.broadcast_to(inputs[2], broadcast_shape)
-        return _op.where(inputs[0], inputs[1], inputs[2])
+        condition_rank = len(infer_shape(inputs[0]))
+        x_rank = len(infer_shape(inputs[1]))
+        y_rank = len(infer_shape(inputs[2]))
+        ranks = [condition_rank, x_rank, y_rank]
+
+        # If one rank is longer than others, then we can broadcast
+        # to that shape.
+        max_rank = max(ranks)
+        max_rank_idxs = [i for i, x in enumerate(ranks) if x == max_rank]
+        broadcast_shape = _op.shape_of(inputs[max_rank_idxs[0]])
+        # If two or more inputs have the same rank, compute the broadcast
+        # shape by taking the maximum value of each dimension.
+        if len(max_rank_idxs) > 1:
+            for idx in max_rank_idxs:
+                broadcast_shape = _op.maximum(broadcast_shape, _op.shape_of(inputs[idx]))
+
+        condition = _op.broadcast_to(inputs[0], broadcast_shape)
+        x = _op.broadcast_to(inputs[1], broadcast_shape)
+        y = _op.broadcast_to(inputs[2], broadcast_shape)
+        return _op.where(condition, x, y)
 
 
 class Or(Elemwise):
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 56d1dd5a5265..515fc32ef88d 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -2107,10 +2107,18 @@ def test_erf():
     verify_erf(x, z)
 
 
-def verify_where(condition, x, y, dtype, outdata):
-    node = helper.make_node("Where", inputs=["condition", "x", "y"], outputs=["out"])
+def verify_where(condition, x, y, dtype, outdata, dynamic=False):
+    node_list = []
+    where_inputs = ["condition", "x", "y"]
+    if dynamic:
+        shape_node = helper.make_node("Shape", ["x"], ["shape"])
+        reshape_node = helper.make_node("Reshape", ["x", "shape"], ["X"])
+        where_inputs[1] = "X"
+        node_list += [shape_node, reshape_node]
+    node = helper.make_node("Where", inputs=where_inputs, outputs=["out"])
+    node_list.append(node)
     graph = helper.make_graph(
-        [node],
+        node_list,
         "where_test",
         inputs=[
             helper.make_tensor_value_info("condition", TensorProto.BOOL, list(condition.shape)),
@@ -2120,7 +2128,7 @@ def verify_where(condition, x, y, dtype, outdata):
         outputs=[helper.make_tensor_value_info("out", dtype, list(outdata.shape))],
     )
     model = helper.make_model(graph, producer_name="where_test")
-    verify_with_ort_with_inputs(model, [condition, x, y], [outdata.shape])
+    verify_with_ort_with_inputs(model, [condition, x, y], [outdata.shape], use_vm=True)
 
 
 @tvm.testing.uses_gpu
@@ -2156,6 +2164,7 @@ def test_where():
     y = np.array([[1], [7]], dtype=np.float32)
     outdata = np.where(condition, x, y)
     verify_where(condition, x, y, TensorProto.FLOAT, outdata)
+    verify_where(condition, x, y, TensorProto.FLOAT, outdata, dynamic=True)
 
 
 def verify_or(indata, dtype):