
[Relay][Frontend][ONNX] Add support for broadcasting to Where and MatMul (#4267)
soiferj authored and tqchen committed Nov 7, 2019
1 parent 14a5a35 commit 5bcd331
Showing 2 changed files with 25 additions and 5 deletions.
15 changes: 15 additions & 0 deletions python/tvm/relay/frontend/onnx.py
@@ -298,6 +298,12 @@ def _impl_v1(cls, inputs, attr, params):
# Convert a and b into 3 dimensional tensors.
a = _op.reshape(inputs[0], [-1, a_shape[-2], a_shape[-1]])
b = _op.reshape(inputs[1], [-1, b_shape[-2], b_shape[-1]])
# Broadcast b to match batch size of a
new_b_shape = list(infer_shape(b))
new_a_shape = infer_shape(a)
if new_a_shape[0] > new_b_shape[0]:
new_b_shape[0] = new_a_shape[0]
b = _op.broadcast_to(b, new_b_shape)
# Transpose matrix dimensions of b.
b = _op.transpose(b, [0, 2, 1])
# Perform a batch matmul.
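The lines added above let a lower-rank (or single-batch) b participate in a batched MatMul by expanding its batch dimension to match a before the transpose and nn.batch_matmul. A minimal NumPy sketch of the equivalent shape handling (illustrative only, not part of the commit; it assumes a has at least 3 dimensions and skips the transpose that nn.batch_matmul requires):

import numpy as np

def matmul_with_batch_broadcast(a, b):
    # Flatten leading batch dims into one, mirroring the converter's reshape.
    a3 = a.reshape(-1, a.shape[-2], a.shape[-1])
    b3 = b.reshape(-1, b.shape[-2], b.shape[-1])
    # Broadcast b's batch dim up to a's batch dim, as the new lines do.
    if a3.shape[0] > b3.shape[0]:
        b3 = np.broadcast_to(b3, (a3.shape[0],) + b3.shape[1:])
    out = np.matmul(a3, b3)
    # Restore a's leading batch dims on the result.
    return out.reshape(a.shape[:-2] + (a.shape[-2], b.shape[-1]))

a = np.random.uniform(size=(2, 3, 4, 3)).astype('float32')
b = np.random.uniform(size=(3, 4)).astype('float32')
assert matmul_with_batch_broadcast(a, b).shape == (2, 3, 4, 4)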
@@ -987,6 +993,14 @@ class Where(OnnxOpConverter):
"""
@classmethod
def _impl_v9(cls, inputs, attr, params):
# x and y can be broadcasted
condition_shape = infer_shape(inputs[0])
x_shape = infer_shape(inputs[1])
y_shape = infer_shape(inputs[2])
if len(condition_shape) > len(x_shape):
inputs[1] = _op.broadcast_to(inputs[1], condition_shape)
if len(condition_shape) > len(y_shape):
inputs[2] = _op.broadcast_to(inputs[2], condition_shape)
return _op.where(inputs[0], inputs[1], inputs[2])

class Or(Elemwise):
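The rank check added to Where pre-broadcasts x (inputs[1]) and y (inputs[2]) to the condition's shape whenever the condition has more dimensions, so all three inputs reach _op.where with matching shapes. A small NumPy sketch of the same rule (illustrative only, not part of the commit):

import numpy as np

condition = np.array([[1, 0], [1, 1]], dtype=bool)  # rank 2
x = np.array(1, dtype=np.float32)                    # rank 0
y = np.array([2], dtype=np.float32)                  # rank 1
# Mirror the converter: if the condition has higher rank, broadcast the
# branch value up to the condition's shape before calling where.
if condition.ndim > x.ndim:
    x = np.broadcast_to(x, condition.shape)
if condition.ndim > y.ndim:
    y = np.broadcast_to(y, condition.shape)
out = np.where(condition, x, y)
assert out.shape == condition.shape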
@@ -996,6 +1010,7 @@ class Or(Elemwise):
def _impl_v7(cls, inputs, attr, params):
return _op.logical_or(inputs[0], inputs[1])


# compatible operators that do NOT require any conversion.
_identity_list = []

15 changes: 10 additions & 5 deletions tests/python/frontend/onnx/test_forward.py
@@ -498,11 +498,7 @@ def test_matmul():
model, [a_array, b_array], target, ctx, out_np.shape)
tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5)


def test_batch_matmul():
a_shape = (2, 3, 4, 3)
b_shape = (2, 3, 3, 4)

def verify_batch_matmul(a_shape, b_shape):
a_array = np.random.uniform(size=a_shape).astype('float32')
b_array = np.random.uniform(size=b_shape).astype('float32')
out_np = np.matmul(a_array, b_array)
@@ -525,6 +521,10 @@ def test_batch_matmul():
model, [a_array, b_array], target, ctx, out_np.shape)
tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5)

def test_batch_matmul():
verify_batch_matmul((2, 3, 4, 3), (2, 3, 3, 4))
verify_batch_matmul((2, 4, 3), (3, 4))
verify_batch_matmul((2, 3, 4, 3), (3, 4))
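For reference, the output shapes NumPy produces for the two new broadcasting cases, which the test compares against (an illustrative check, not part of the test file):

import numpy as np
assert np.matmul(np.zeros((2, 4, 3)), np.zeros((3, 4))).shape == (2, 4, 4)
assert np.matmul(np.zeros((2, 3, 4, 3)), np.zeros((3, 4))).shape == (2, 3, 4, 4)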

def verify_lrn(shape, nsize, dtype, alpha=None, beta=None, bias=None):
in_array = np.random.uniform(size=shape).astype(dtype)
@@ -1600,6 +1600,11 @@ def test_where():
outdata = np.where(condition, x, y)
verify_where(condition, x, y, TensorProto.FLOAT, outdata)

x = np.array(1, dtype=np.float32)
y = np.array([2], dtype=np.float32)
outdata = np.where(condition, x, y)
verify_where(condition, x, y, TensorProto.FLOAT, outdata)


def verify_or(indata, dtype):
x = indata[0].astype(dtype)
