
Commit

[Topi] Allow batch_matmul to broadcast along batch dimension. (apache#6616)

* Allow batch_matmul to broadcast along batch dimension.

* Added typerel checking.

* Fix style issue and respond to feedback.

* Fix style.

* More formatting issues :(

* Fix issues after merge.

* Comment update.

* Small tweak.
jwfromm authored and Tushar Dey committed Oct 15, 2020
1 parent 8c6ab1b commit f43fbb5
Showing 3 changed files with 15 additions and 21 deletions.
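The change lets batch_matmul accept inputs whose batch dimensions differ, as long as one of them is 1. A rough NumPy sketch of the intended semantics (shapes below are illustrative, not taken from the commit):

```python
# Reference semantics for batch_matmul with batch broadcasting: inputs are
# [batch, M, K] and [batch, N, K]; a batch of 1 on either side is reused
# against the other side's batch dimension.
import numpy as np

x = np.random.rand(4, 16, 32).astype("float32")  # [XB, M, K]
y = np.random.rand(1, 20, 32).astype("float32")  # [YB, N, K], YB == 1 broadcasts

# out[b, i, j] = sum_k x[b if XB != 1 else 0, i, k] * y[b if YB != 1 else 0, j, k]
out = np.matmul(x, np.swapaxes(y, 1, 2))  # NumPy broadcasts the batch dim itself
print(out.shape)  # (4, 16, 20)
```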
9 changes: 0 additions & 9 deletions python/tvm/relay/frontend/onnx.py
@@ -528,15 +528,6 @@ def flatten_to_3d(x, x_shape):
             # Convert a and b into 3 dimensional tensors.
             a = flatten_to_3d(inputs[0], a_shape)
             b = flatten_to_3d(inputs[1], b_shape)
-            # Broadcast b to match batch size of a
-            new_b_shape = _op.concatenate(
-                [
-                    _op.strided_slice(_op.shape_of(a), [0], [1]),
-                    _op.strided_slice(_op.shape_of(b), [1], [3]),
-                ],
-                0,
-            )
-            b = _op.broadcast_to(b, new_b_shape)
             # Transpose matrix dimensions of b.
             b = _op.transpose(b, [0, 2, 1])
             # Perform a batch matmul.
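With batch_matmul now handling the broadcast itself, the frontend no longer needs to materialize b at a's batch size, which is why the broadcast_to block above is removed. A minimal sketch (not part of this commit; the shapes and tensor names are made up) of an ONNX MatMul whose inputs should now import directly:

```python
# Build a tiny ONNX graph with mismatched batch dims (4 vs. 1) and import it;
# the MatMul should lower to nn.batch_matmul without an explicit broadcast_to.
import onnx
from onnx import TensorProto, helper

from tvm import relay

node = helper.make_node("MatMul", ["a", "b"], ["out"])
graph = helper.make_graph(
    [node],
    "broadcast_matmul",
    inputs=[
        helper.make_tensor_value_info("a", TensorProto.FLOAT, [4, 16, 32]),
        helper.make_tensor_value_info("b", TensorProto.FLOAT, [1, 32, 20]),
    ],
    outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, [4, 16, 20])],
)
model = helper.make_model(graph)

mod, params = relay.frontend.from_onnx(model, {"a": (4, 16, 32), "b": (1, 32, 20)})
print(mod)  # expect nn.batch_matmul with no broadcast_to on b
```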
22 changes: 12 additions & 10 deletions python/tvm/topi/nn/batch_matmul.py
@@ -41,19 +41,21 @@ def batch_matmul(x, y, oshape=None):
     output : tvm.te.Tensor
         3-D with shape [batch, M, N]
     """
+    assert len(x.shape) == 3 and len(y.shape) == 3, "only support 3-dim batch_matmul"
+    x_shape = get_const_tuple(x.shape)
+    y_shape = get_const_tuple(y.shape)
+    XB = x_shape[0]
+    YB = y_shape[0]
+    _, M, K = x.shape
+    k = te.reduce_axis((0, K), name="k")
     if oshape is None:
-        assert len(x.shape) == 3 and len(y.shape) == 3, "only support 3-dim batch_matmul"
-        x_shape = get_const_tuple(x.shape)
-        y_shape = get_const_tuple(y.shape)
-        assert x_shape[0] == y_shape[0], "batch dimension doesn't match"
+        assert XB == YB or XB == 1 or YB == 1, "batch dimension doesn't match"
         assert x_shape[2] == y_shape[2], "shapes of x and y is inconsistant"
-        batch, M, K = x.shape
+        batch = max(XB, YB)
         N = y.shape[1]
-        k = te.reduce_axis((0, K), name="k")
         oshape = (batch, M, N)
-    else:
-        _, _, K = x.shape
-        k = te.reduce_axis((0, K), name="k")
     return te.compute(
-        oshape, lambda b, i, j: te.sum(x[b, i, k] * y[b, j, k], axis=k), tag="batch_matmul"
+        oshape,
+        lambda b, i, j: te.sum(x[b if XB != 1 else 0, i, k] * y[b if YB != 1 else 0, j, k], axis=k),
+        tag="batch_matmul",
     )
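The compute indexes x and y with the batch iterator b, except that a batch dimension of 1 is pinned to index 0, which is what realizes the broadcast. A small sketch exercising that path (assumes an LLVM-enabled TVM build from around this commit; shapes are illustrative):

```python
# Exercise the broadcast path: y has batch 1 and is reused for every batch of x.
import numpy as np

import tvm
from tvm import te, topi

batch, M, N, K = 4, 16, 20, 32
X = te.placeholder((batch, M, K), name="X")
Y = te.placeholder((1, N, K), name="Y")  # batch dim of 1 broadcasts against 4

C = topi.nn.batch_matmul(X, Y)
s = te.create_schedule(C.op)
f = tvm.build(s, [X, Y, C], "llvm")

dev = tvm.cpu(0)
x_np = np.random.uniform(size=(batch, M, K)).astype("float32")
y_np = np.random.uniform(size=(1, N, K)).astype("float32")
c_nd = tvm.nd.array(np.zeros((batch, M, N), dtype="float32"), dev)
f(tvm.nd.array(x_np, dev), tvm.nd.array(y_np, dev), c_nd)

# Reference: out[b] = x[b] @ y[0].T, since y is laid out as [batch, N, K].
ref = np.einsum("bmk,nk->bmn", x_np, y_np[0])
np.testing.assert_allclose(c_nd.asnumpy(), ref, rtol=1e-5)
```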
5 changes: 3 additions & 2 deletions src/relay/op/nn/nn.cc
@@ -863,8 +863,9 @@ bool BatchMatmulRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
     }
   }
   if (!is_dyn) {
-    CHECK(reporter->AssertEQ(x->shape[0], y->shape[0]))
-        << "BatchDot: batch dimension doesn't match, "
+    CHECK(reporter->AssertEQ(x->shape[0], y->shape[0]) || reporter->AssertEQ(x->shape[0], 1) ||
+          reporter->AssertEQ(y->shape[0], 1))
+        << "BatchDot: batch dimensions don't match, "
         << " x shape=" << x->shape << ", y shape=" << y->shape;
     CHECK(reporter->AssertEQ(x->shape[2], y->shape[2]))
         << "BatchDot: shapes of x and y is inconsistent, "
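At the Relay level, the relaxed type relation means inference now accepts batch dimensions that differ when one of them is 1. A sketch with illustrative shapes:

```python
# Type inference should now pass with batch dims 4 and 1 on batch_matmul inputs.
import tvm
from tvm import relay

x = relay.var("x", shape=(4, 16, 32), dtype="float32")
y = relay.var("y", shape=(1, 20, 32), dtype="float32")  # batch 1 vs. batch 4
out = relay.nn.batch_matmul(x, y)

mod = tvm.IRModule.from_expr(relay.Function([x, y], out))
mod = relay.transform.InferType()(mod)
print(mod["main"].ret_type)  # expected: Tensor[(4, 16, 20), float32]
```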
