diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 1bd8673cd3aa..b275e85939c6 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1105,14 +1105,33 @@ class Where(OnnxOpConverter): """ @classmethod def _impl_v9(cls, inputs, attr, params): - # x and y can be broadcasted condition_shape = infer_shape(inputs[0]) x_shape = infer_shape(inputs[1]) y_shape = infer_shape(inputs[2]) - if len(condition_shape) > len(x_shape): - inputs[1] = _op.broadcast_to(inputs[1], condition_shape) - if len(condition_shape) > len(y_shape): - inputs[2] = _op.broadcast_to(inputs[2], condition_shape) + + # condition, x, and y can all be broadcasted. + # broadcast each of them to the longest shape. + # if two shapes have the same number of dimensions, + # try to choose the one that doesn't have "1" as + # a dimension. + shapes = [condition_shape, x_shape, y_shape] + shape_lens = [len(shape) for shape in shapes] + max_size = max(shape_lens) + max_size_idxs = [i for i, x in enumerate(shape_lens) if x == max_size] + broadcast_idx = max_size_idxs[0] + if len(max_size_idxs) > 1: + for idx in max_size_idxs: + if 1 not in shapes[idx]: + broadcast_idx = idx + + broadcast_shape = shapes[broadcast_idx] + + if condition_shape != broadcast_shape: + inputs[0] = _op.broadcast_to(inputs[0], broadcast_shape) + if x_shape != broadcast_shape: + inputs[1] = _op.broadcast_to(inputs[1], broadcast_shape) + if y_shape != broadcast_shape: + inputs[2] = _op.broadcast_to(inputs[2], broadcast_shape) return _op.where(inputs[0], inputs[1], inputs[2]) class Or(Elemwise): diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 7eb09493df8c..fc05e7a011a9 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -1684,6 +1684,22 @@ def test_where(): outdata = np.where(condition, x, y) verify_where(condition, x, y, TensorProto.FLOAT, outdata) + x = 
np.array([2], dtype=np.float32) + y = np.array(1, dtype=np.float32) + outdata = np.where(condition, x, y) + verify_where(condition, x, y, TensorProto.FLOAT, outdata) + + condition = np.array(1, dtype=bool) + x = np.array([[1, 2], [3, 4]], dtype=np.float32) + y = np.array([[5, 6], [7, 8]], dtype=np.float32) + outdata = np.where(condition, x, y) + verify_where(condition, x, y, TensorProto.FLOAT, outdata) + + x = np.array([[1, 2], [3, 4]], dtype=np.float32) + y = np.array([[1], [7]], dtype=np.float32) + outdata = np.where(condition, x, y) + verify_where(condition, x, y, TensorProto.FLOAT, outdata) + def verify_or(indata, dtype): x = indata[0].astype(dtype)