From f43fbb5584ec341a2e71bdf575af9c2fb6cca509 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Tue, 6 Oct 2020 14:07:26 -0700 Subject: [PATCH] [Topi] Allow batch_matmul to broadcast along batch dimension. (#6616) * Allow batch_matmul to broadcast along batch dimension. * Added typerel checking. * Fix style issue and respond to feedback. * Fix style. * More formatting issues :( * Fix issues after merge. * Comment update. * Small tweak. --- python/tvm/relay/frontend/onnx.py | 9 --------- python/tvm/topi/nn/batch_matmul.py | 22 ++++++++++++---------- src/relay/op/nn/nn.cc | 5 +++-- 3 files changed, 15 insertions(+), 21 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 06f711b21611..1bf8b7b9e0d6 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -528,15 +528,6 @@ def flatten_to_3d(x, x_shape): # Convert a and b into 3 dimensional tensors. a = flatten_to_3d(inputs[0], a_shape) b = flatten_to_3d(inputs[1], b_shape) - # Broadcast b to match batch size of a - new_b_shape = _op.concatenate( - [ - _op.strided_slice(_op.shape_of(a), [0], [1]), - _op.strided_slice(_op.shape_of(b), [1], [3]), - ], - 0, - ) - b = _op.broadcast_to(b, new_b_shape) # Transpose matrix dimensions of b. b = _op.transpose(b, [0, 2, 1]) # Perform a batch matmul. diff --git a/python/tvm/topi/nn/batch_matmul.py b/python/tvm/topi/nn/batch_matmul.py index b653f495ef8d..9b926a1182d8 100644 --- a/python/tvm/topi/nn/batch_matmul.py +++ b/python/tvm/topi/nn/batch_matmul.py @@ -41,19 +41,21 @@ def batch_matmul(x, y, oshape=None): output : tvm.te.Tensor 3-D with shape [batch, M, N] """ + assert len(x.shape) == 3 and len(y.shape) == 3, "only support 3-dim batch_matmul" + x_shape = get_const_tuple(x.shape) + y_shape = get_const_tuple(y.shape) + XB = x_shape[0] + YB = y_shape[0] + _, M, K = x.shape + k = te.reduce_axis((0, K), name="k") if oshape is None: - assert len(x.shape) == 3 and len(y.shape) == 3, "only support 3-dim batch_matmul" - x_shape = get_const_tuple(x.shape) - y_shape = get_const_tuple(y.shape) - assert x_shape[0] == y_shape[0], "batch dimension doesn't match" + assert XB == YB or XB == 1 or YB == 1, "batch dimension doesn't match" assert x_shape[2] == y_shape[2], "shapes of x and y is inconsistant" - batch, M, K = x.shape + batch = max(XB, YB) N = y.shape[1] - k = te.reduce_axis((0, K), name="k") oshape = (batch, M, N) - else: - _, _, K = x.shape - k = te.reduce_axis((0, K), name="k") return te.compute( - oshape, lambda b, i, j: te.sum(x[b, i, k] * y[b, j, k], axis=k), tag="batch_matmul" + oshape, + lambda b, i, j: te.sum(x[b if XB != 1 else 0, i, k] * y[b if YB != 1 else 0, j, k], axis=k), + tag="batch_matmul", ) diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index 564c6da1fb06..1de7ca003772 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -863,8 +863,9 @@ bool BatchMatmulRel(const Array& types, int num_inputs, const Attrs& attrs } } if (!is_dyn) { - CHECK(reporter->AssertEQ(x->shape[0], y->shape[0])) - << "BatchDot: batch dimension doesn't match, " + CHECK(reporter->AssertEQ(x->shape[0], y->shape[0]) || reporter->AssertEQ(x->shape[0], 1) || + reporter->AssertEQ(y->shape[0], 1)) + << "BatchDot: batch dimensions don't match, " << " x shape=" << x->shape << ", y shape=" << y->shape; CHECK(reporter->AssertEQ(x->shape[2], y->shape[2])) << "BatchDot: shapes of x and y is inconsistent, "