Skip to content

Commit

Permalink
[Topi] Allow batch_matmul to broadcast along batch dimension. (apache…
Browse files Browse the repository at this point in the history
…#6616)

* Allow batch_matmul to broadcast along batch dimension.

* Added typerel checking.

* Fix style issue and respond to feedback.

* Fix style.

* More formatting issues :(

* Fix issues after merge.

* Comment update.

* Small tweak.
  • Loading branch information
jwfromm authored and trevor-m committed Oct 19, 2020
1 parent 7e3bfbd commit fcc9be5
Show file tree
Hide file tree
Showing 9 changed files with 43 additions and 112 deletions.
67 changes: 0 additions & 67 deletions include/tvm/topi/nn/batch_matmul.h

This file was deleted.

9 changes: 0 additions & 9 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,15 +539,6 @@ def flatten_to_3d(x, x_shape):
# Convert a and b into 3 dimensional tensors.
a = flatten_to_3d(inputs[0], a_shape)
b = flatten_to_3d(inputs[1], b_shape)
# Broadcast b to match batch size of a
new_b_shape = _op.concatenate(
[
_op.strided_slice(_op.shape_of(a), [0], [1]),
_op.strided_slice(_op.shape_of(b), [1], [3]),
],
0,
)
b = _op.broadcast_to(b, new_b_shape)
# Transpose matrix dimensions of b.
b = _op.transpose(b, [0, 2, 1])
# Perform a batch matmul.
Expand Down
28 changes: 17 additions & 11 deletions python/tvm/topi/nn/batch_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

def batch_matmul(x, y, oshape=None):
    """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
    data in batch. Supports broadcasting for batch dimension.

    Parameters
    ----------
    x : tvm.te.Tensor
        3-D with shape [batch, M, K]

    y : tvm.te.Tensor
        3-D with shape [batch, N, K]

    oshape : List[Optional]
        Explicit intended output shape of the computation. Can be useful in cases
        with dynamic input shapes.

    Returns
    -------
    output : tvm.te.Tensor
        3-D with shape [batch, M, N]
    """
    assert len(x.shape) == 3 and len(y.shape) == 3, "only support 3-dim batch_matmul"
    x_shape = get_const_tuple(x.shape)
    y_shape = get_const_tuple(y.shape)
    XB = x_shape[0]
    YB = y_shape[0]
    _, M, K = x.shape
    k = te.reduce_axis((0, K), name="k")
    if oshape is None:
        # Infer the output shape. Batch dimensions may broadcast: either side
        # may be 1, in which case the other side's batch size is used.
        assert XB == YB or XB == 1 or YB == 1, "batch dimension doesn't match"
        assert x_shape[2] == y_shape[2], "shapes of x and y is inconsistant"
        batch = max(XB, YB)
        N = y.shape[1]
        oshape = (batch, M, N)
    return te.compute(
        oshape,
        # Index the batch dimension with 0 for a broadcast (batch == 1) operand.
        lambda b, i, j: te.sum(x[b if XB != 1 else 0, i, k] * y[b if YB != 1 else 0, j, k], axis=k),
        tag="batch_matmul",
    )
7 changes: 4 additions & 3 deletions python/tvm/topi/testing/batch_matmul.py
Original file line number Diff line number Diff line change
def batch_matmul(x, y):
    """Reference batch matrix multiplication implemented in numpy.

    Supports broadcasting in the batch dimension: either input may have a
    batch size of 1, in which case its single matrix is reused for every
    batch of the other input.

    Parameters
    ----------
    x : numpy.ndarray
        3-D with shape [batch, M, K]

    y : numpy.ndarray
        3-D with shape [batch, N, K]

    Returns
    -------
    out : numpy.ndarray
        3-D with shape [batch, M, N]
    """
    XB, M, _ = x.shape
    YB, N, _ = y.shape
    batch = max(XB, YB)
    out = np.zeros((batch, M, N)).astype(x.dtype)
    for i in range(batch):
        # Use index 0 for an input whose batch dimension broadcasts (size 1).
        out[i] = np.dot(x[i if XB != 1 else 0], y[i if YB != 1 else 0].T)
    return out
10 changes: 6 additions & 4 deletions python/tvm/topi/x86/batch_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
@autotvm.register_topi_compute("batch_matmul.x86")
def batch_matmul(cfg, x, y, out_shape=None):
"""Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
data in batch.
data in batch. Supports broadcasting in batch dimension.
Parameters
----------
Expand All @@ -45,9 +45,9 @@ def batch_matmul(cfg, x, y, out_shape=None):
assert len(x.shape) == 3 and len(y.shape) == 3, "only support 3-dim batch_matmul"
XB, M, XK = get_const_tuple(x.shape)
YB, N, YK = get_const_tuple(y.shape)
assert XB == YB, "batch dimension doesn't match"
assert (XB == YB) or (YB == 1) or (XB == 1), "batch dimension doesn't match"
assert XK == YK, "shapes of x and y is inconsistant"
B = XB
B = max(XB, YB)
K = XK
if out_shape is not None:
assert out_shape[0] == B, "got invalid output shape"
Expand All @@ -58,7 +58,9 @@ def batch_matmul(cfg, x, y, out_shape=None):

k = te.reduce_axis((0, K), name="k")
C = te.compute(
(B, M, N), lambda b, i, j: te.sum(x[b, i, k] * y[b, j, k], axis=k), tag="batch_matmul"
(B, M, N),
lambda b, i, j: te.sum(x[b if XB != 1 else 0, i, k] * y[b if YB != 1 else 0, j, k], axis=k),
tag="batch_matmul",
)
return C

Expand Down
6 changes: 4 additions & 2 deletions src/relay/op/nn/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include <tvm/topi/nn/flatten.h>
#include <tvm/topi/nn/softmax.h>

#include <algorithm>
#include <string>
#include <vector>

Expand Down Expand Up @@ -862,8 +863,9 @@ bool BatchMatmulRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
}
}
if (!is_dyn) {
CHECK(reporter->AssertEQ(x->shape[0], y->shape[0]))
<< "BatchDot: batch dimension doesn't match, "
CHECK(reporter->AssertEQ(x->shape[0], y->shape[0]) || reporter->AssertEQ(x->shape[0], 1) ||
reporter->AssertEQ(y->shape[0], 1))
<< "BatchDot: batch dimensions don't match, "
<< " x shape=" << x->shape << ", y shape=" << y->shape;
CHECK(reporter->AssertEQ(x->shape[2], y->shape[2]))
<< "BatchDot: shapes of x and y is inconsistent, "
Expand Down
6 changes: 0 additions & 6 deletions src/topi/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <tvm/topi/nn.h>
#include <tvm/topi/nn/batch_matmul.h>
#include <tvm/topi/nn/bias_add.h>
#include <tvm/topi/nn/bnn.h>
#include <tvm/topi/nn/dense.h>
Expand Down Expand Up @@ -68,11 +67,6 @@ TVM_REGISTER_GLOBAL("topi.nn.bias_add").set_body([](TVMArgs args, TVMRetValue* r
*rv = nn::bias_add(args[0], args[1], args[2]);
});

/* Ops from nn/batch_matmul.h */
TVM_REGISTER_GLOBAL("topi.nn.batch_matmul").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = nn::batch_matmul(args[0], args[1]);
});

/* Ops from nn/dilate.h */
TVM_REGISTER_GLOBAL("topi.nn.dilate").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = nn::dilate(args[0], args[1], args[2]);
Expand Down
1 change: 0 additions & 1 deletion tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -3628,7 +3628,6 @@ def verify_roi_align(
test_clip_min_max_as_inputs()
test_onehot()
test_matmul()
test_batch_matmul()
test_gather()
test_gatherelements()
test_gather_nd()
Expand Down
21 changes: 12 additions & 9 deletions tests/python/topi/python/test_topi_batch_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,16 @@
}


def verify_batch_matmul(batch, M, N, K):
x = te.placeholder((batch, M, K), name="x")
y = te.placeholder((batch, N, K), name="y")
def verify_batch_matmul(x_batch, y_batch, M, N, K):
x = te.placeholder((x_batch, M, K), name="x")
y = te.placeholder((y_batch, N, K), name="y")
dtype = x.dtype

# use memoize to pickle the test data for next time use
@memoize("topi.tests.test_topi_batch_matmul")
def get_ref_data():
a_np = np.random.uniform(size=(batch, M, K)).astype(dtype)
b_np = np.random.uniform(size=(batch, N, K)).astype(dtype)
a_np = np.random.uniform(size=(x_batch, M, K)).astype(dtype)
b_np = np.random.uniform(size=(y_batch, N, K)).astype(dtype)
c_np = tvm.topi.testing.batch_matmul(a_np, b_np)
return (a_np, b_np, c_np)

Expand All @@ -67,10 +67,13 @@ def check_device(device, ctx):

@tvm.testing.uses_gpu
def test_batch_matmul():
    # Equal batch sizes (no broadcasting). verify_batch_matmul takes
    # (x_batch, y_batch, M, N, K); the stale 4-argument calls from before the
    # broadcasting change are removed — they no longer match the signature.
    verify_batch_matmul(1, 1, 16, 16, 32)
    verify_batch_matmul(5, 5, 16, 16, 32)
    verify_batch_matmul(5, 5, 16, 20, 32)
    verify_batch_matmul(30, 30, 16, 20, 32)
    # Test batch broadcasting.
    verify_batch_matmul(1, 5, 16, 16, 32)
    verify_batch_matmul(5, 1, 16, 16, 32)


if __name__ == "__main__":
Expand Down

0 comments on commit fcc9be5

Please sign in to comment.