Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay/TOPI][Op] Add batch_matmul in relay and TOPI #2561

Merged
merged 17 commits into from
Mar 1, 2019
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/api/python/topi.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ List of operators
topi.nn.upsampling
topi.nn.softmax
topi.nn.dense
topi.nn.batch_matmul
topi.nn.log_softmax
topi.nn.conv2d_nchw
topi.nn.conv2d_hwcn
Expand Down Expand Up @@ -134,6 +135,7 @@ topi.nn
.. autofunction:: topi.nn.upsampling
.. autofunction:: topi.nn.softmax
.. autofunction:: topi.nn.dense
.. autofunction:: topi.nn.batch_matmul
.. autofunction:: topi.nn.log_softmax
.. autofunction:: topi.nn.conv2d_nchw
.. autofunction:: topi.nn.conv2d_hwcn
Expand Down
2 changes: 2 additions & 0 deletions docs/langref/relay_op.rst
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ This level support backpropagation of broadcast operators. It is temporary.
tvm.relay.device_copy
tvm.relay.annotation.on_device
tvm.relay.reverse_reshape
tvm.relay.nn.batch_matmul


Level 1 Definitions
Expand Down Expand Up @@ -260,3 +261,4 @@ Level 10 Definitions
.. autofunction:: tvm.relay.device_copy
.. autofunction:: tvm.relay.annotation.on_device
.. autofunction:: tvm.relay.reverse_reshape
.. autofunction:: tvm.relay.nn.batch_matmul
11 changes: 10 additions & 1 deletion python/tvm/relay/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,15 @@ def _mx_multibox_detection(inputs, attrs):
return _op.vision.nms(ret[0], ret[1], **new_attrs1)


def _mx_batch_dot(inputs, attrs):
transpose_a = attrs.get_bool("transpose_a", False)
transpose_b = attrs.get_bool("transpose_b", False)
if transpose_a is True or transpose_b is False:
raise RuntimeError("batch_dot: only support transpose_a=False and "
"transpose_b=True")
return _op.batch_matmul(inputs[0], inputs[1])


# Note: due to attribute conversion constraint
# ops in the identity set must be attribute free
_identity_list = [
Expand Down Expand Up @@ -351,6 +360,7 @@ def _mx_multibox_detection(inputs, attrs):
"expand_dims" : _mx_expand_dims,
"Concat" : _mx_concat,
"concat" : _mx_concat,
"batch_dot" : _mx_batch_dot,
"LeakyReLU" : _mx_leaky_relu,
"SoftmaxOutput" : _mx_softmax_output,
"SoftmaxActivation" : _mx_softmax_activation,
Expand All @@ -363,7 +373,6 @@ def _mx_multibox_detection(inputs, attrs):
# "broadcast_to",
# "gather_nd",
# "Crop" : _crop_like,

}

# set identity list
Expand Down
15 changes: 15 additions & 0 deletions python/tvm/relay/op/nn/_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,21 @@ def schedule_dense(attrs, outputs, target):
reg.register_pattern("nn.dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE)


# batch_matmul
@reg.register_compute("nn.batch_matmul")
def compute_batch_matmul(attrs, inputs, out_type, target):
"""Compute definition of batch_matmul"""
return [topi.nn.batch_matmul(inputs[0], inputs[1])]

@reg.register_schedule("nn.batch_matmul")
def schedule_batch_matmul(attrs, outputs, target):
"""Schedule definition of batch_matmul"""
with target:
return topi.generic.schedule_batch_matmul(outputs)

reg.register_pattern("nn.batch_matmul", reg.OpPattern.OUT_ELEMWISE_FUSABLE)


# conv2d
@reg.register_compute("nn.conv2d")
def compute_conv2d(attrs, inputs, out_type, target):
Expand Down
25 changes: 25 additions & 0 deletions python/tvm/relay/op/nn/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -767,6 +767,31 @@ def batch_norm(data,
return TupleWrapper(result, 3)


def batch_matmul(x, y):
r"""
Computes batch matrix multiplication of `x` and `y` when `x` and `y` are data
in batch.

.. math::

\mbox{batch_matmul}(x, y)[i, :, :] = \mbox{matmul}(x[i, :, :], y[i, :, :]^T)

Parameters
----------
x : tvm.relay.Expr
The first input.

y : tvm.relay.Expr
The second input.

Returns
-------
result: tvm.relay.Expr
The computed result.
"""
return _make.batch_matmul(x, y)


def contrib_conv2d_winograd_without_weight_transform(data,
weight,
tile_size,
Expand Down
63 changes: 63 additions & 0 deletions src/relay/op/nn/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -654,5 +654,68 @@ axis to be the last item in the input shape.
.set_support_level(1)
.add_type_rel("BatchNorm", BatchNormRel);


// relay.nn.batch_matmul
bool BatchMatmulRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
const auto* x = types[0].as<TensorTypeNode>();
const auto* y = types[1].as<TensorTypeNode>();
if (x == nullptr || y == nullptr) return false;
if (x->shape.size() != 3 || y->shape.size() != 3) return false;
CHECK(reporter->AssertEQ(x->shape[0], y->shape[0]))
<< "BatchDot: batch dimension doesn't match, "
<< " x shape=" << x->shape
<< ", y shape=" << y->shape;
CHECK(reporter->AssertEQ(x->shape[2], y->shape[2]))
<< "BatchDot: shapes of x and y is inconsistent, "
<< " x shape=" << x->shape
<< ", y shape=" << y->shape;

Array<tvm::Expr> oshape = x->shape;
oshape.Set(2, y->shape[1]);

// assign output type
reporter->Assign(types[2], TensorTypeNode::make(oshape, x->dtype));
return true;
}


// Positional relay function to create batch_matmul operator used by frontend FFI.
Expr MakeBatchMatmul(Expr x,
Expr y) {
static const Op& op = Op::Get("nn.batch_matmul");
return CallNode::make(op, {x, y}, Attrs(), {});
}


TVM_REGISTER_API("relay.op.nn._make.batch_matmul")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeBatchMatmul, args, rv);
});


RELAY_REGISTER_OP("nn.batch_matmul")
.describe(R"code(Computes matrix multiplication of `x` and `y` when `x` and `y`
are data in batch.

.. math::

batch\_matmul(x, y)[i, :, :] = matmul(x[i, :, :], y[i, :, :]^T)

- **x**: `(b, m, k)`
- **y**: `(b, n, k)`
- **out**: `(b, m, n)`.

)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
.add_argument("x", "3D Tensor", "First input.")
.add_argument("y", "3D Tensor", "Second input.")
.set_support_level(10)
.add_type_rel("BatchMatmul", BatchMatmulRel);


} // namespace relay
} // namespace tvm
1 change: 0 additions & 1 deletion tests/python/relay/test_op_level1.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,6 @@ def test_dense():
tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5)



if __name__ == "__main__":
test_concatenate()
test_bias_add()
Expand Down
36 changes: 35 additions & 1 deletion tests/python/relay/test_op_level10.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import tvm
from tvm import relay
from tvm.relay.testing import ctx_list
import topi
import topi.testing

def test_collapse_sum_like():
shape = (3, 4, 5, 6)
Expand Down Expand Up @@ -126,7 +128,6 @@ def verify_reverse_reshape(shape, newshape, oshape):
x = relay.var("x", relay.TensorType(shape, "float32"))
z = relay.reverse_reshape(x, newshape=newshape)
zz = relay.ir_pass.infer_type(z)
print(zz.checked_type)
assert "newshape=" in z.astext()
assert zz.checked_type == relay.ty.TensorType(oshape, "float32")

Expand All @@ -144,8 +145,41 @@ def verify_reverse_reshape(shape, newshape, oshape):
verify_reverse_reshape((2, 3, 4), (-1, 0), (6, 4))
verify_reverse_reshape((2, 3, 4), (0, -3), (2, 12))

def verify_batch_matmul(x_shape, y_shape, out_shape, dtype="float32"):
x = relay.var("x", relay.TensorType(x_shape, dtype))
y = relay.var("y", relay.TensorType(y_shape, dtype))
z = relay.nn.batch_matmul(x, y)
zz = relay.ir_pass.infer_type(z)
assert zz.checked_type == relay.ty.TensorType(out_shape, dtype)

func = relay.Function([x, y], z)
x_np = np.random.uniform(size=x_shape).astype(dtype)
y_np = np.random.uniform(size=y_shape).astype(dtype)
z_np = topi.testing.batch_matmul(x_np, y_np)

for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
z = intrp.evaluate(func)(x_np, y_np)
tvm.testing.assert_allclose(z.asnumpy(), z_np, rtol=1e-5)

def test_batch_matmul():
b, m, n, k = tvm.var("b"), tvm.var("m"), tvm.var("n"), tvm.var("k")
x = relay.var("x", relay.TensorType((b, m, k), "float32"))
y = relay.var("y", relay.TensorType((b, n, k), "float32"))
z = relay.nn.batch_matmul(x, y)
zz = relay.ir_pass.infer_type(z)
assert zz.checked_type == relay.TensorType((b, m, n), "float32")

verify_batch_matmul((1, 16, 32), (1, 16, 32), (1, 16, 16))
verify_batch_matmul((5, 16, 32), (5, 16, 32), (5, 16, 16))
verify_batch_matmul((5, 16, 32), (5, 20, 32), (5, 16, 20))
verify_batch_matmul((30, 16, 32), (30, 20, 32), (30, 16, 20))


if __name__ == "__main__":
test_collapse_sum_like()
test_broadcast_to_like()
test_slice_like()
test_reverse_reshape()
test_batch_matmul()
49 changes: 49 additions & 0 deletions topi/include/topi/nn/batch_matmul.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*!
* Copyright (c) 2019 by Contributors
* \brief Batch matmul op constructions
* \file nn/batch_matmul.h
*/
#ifndef TOPI_NN_BATCH_MATMUL_H_
#define TOPI_NN_BATCH_MATMUL_H_

#include <string>

#include "topi/tags.h"
#include "tvm/tvm.h"

namespace topi {
namespace nn {
using namespace tvm;

/*!
* \brief Creates an operation that calculates matrix multiplication in batch.
*
* \param x Tensor with shape [batch, M, K]
* \param y Tensor with shape [batch, N, K]
*
* \return Tensor with shape [batch, M, N]
*/
inline tvm::Tensor batch_matmul(const tvm::Tensor& x,
const tvm::Tensor& y) {
CHECK_EQ(x->shape.size(), 3) << "batch_matmul requires 3-D data";
CHECK_EQ(y->shape.size(), 3) << "batch_matmul requires 3-D data";

auto batch = x->shape[0];
auto M = x->shape[1];
auto K = x->shape[2];
auto N = y->shape[1];

auto k = tvm::reduce_axis(Range(0, K), "k");
auto result = tvm::compute(
{ batch, M, N },
[&](Var b, Var i, Var j) {
return tvm::sum(x(b, i, k) * y(b, j, k), { k });
}, "tensor", "batch_matmul");

return result;
}

} // namespace nn
} // namespace topi

#endif // TOPI_NN_BATCH_MATMUL_H_
1 change: 1 addition & 0 deletions topi/python/topi/cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from .pooling import schedule_pool, schedule_global_pool
from .extern import schedule_extern
from .nn import schedule_lrn, schedule_l2_normalize
from .batch_matmul import schedule_batch_matmul
from .vision import *
from . import ssd
from .ssd import *
Expand Down
89 changes: 89 additions & 0 deletions topi/python/topi/cuda/batch_matmul.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# pylint: disable=invalid-name,too-many-locals,unused-variable
"""cuda batch_matmul operators"""
from __future__ import absolute_import as _abs
import tvm

from .. import generic
from ..util import traverse_inline, get_const_tuple, get_max_power2_factor


@generic.schedule_batch_matmul.register(["cuda", "gpu"])
def schedule_batch_matmul(outs):
"""Schedule for batch_matmul

Parameters
----------
outs: Array of Tensor
The computation graph description of batch_matmul
in the format of an array of tensors.

Returns
-------
s: Schedule
The computation schedule for the op.
"""
s = tvm.create_schedule([x.op for x in outs])

def _schedule(op):
C = op.output(0)
A, B = s[C].op.input_tensors
_, M, N = get_const_tuple(C.shape)
AA = s.cache_read(A, "shared", [C])
AL = s.cache_read(AA, "local", [C])
BB = s.cache_read(B, "shared", [C])
BL = s.cache_read(BB, "local", [C])
CC = s.cache_write(C, "local")

b, y, x = s[C].op.axis
y_bn = get_max_power2_factor(M, 64)
x_bn = get_max_power2_factor(N, 64)
by, y = s[C].split(y, y_bn)
bx, x = s[C].split(x, x_bn)
y_nthreads = min(y_bn, 8)
x_nthreads = min(x_bn, 8)
ty, yi = s[C].split(y, nparts=y_nthreads)
tx, xi = s[C].split(x, nparts=x_nthreads)
thread_x = tvm.thread_axis((0, x_nthreads), "threadIdx.x")
thread_y = tvm.thread_axis((0, y_nthreads), "threadIdx.y")

s[C].reorder(b, by, bx, ty, tx, yi, xi)
s[C].bind(b, tvm.thread_axis("blockIdx.z"))
s[C].bind(by, tvm.thread_axis("blockIdx.y"))
s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
s[C].bind(ty, thread_y)
s[C].bind(tx, thread_x)
s[C].pragma(yi, "auto_unroll_max_step", 16)

s[CC].compute_at(s[C], tx)
_, yi, xi = s[CC].op.axis
k, = s[CC].op.reduce_axis
ko, ki = s[CC].split(k, 8)
s[CC].reorder(ko, ki, yi, xi)
s[CC].pragma(ki, "auto_unroll_max_step", 16)

s[AA].compute_at(s[CC], ko)
s[AL].compute_at(s[CC], ki)
s[BB].compute_at(s[CC], ko)
s[BL].compute_at(s[CC], ki)
_, y, k = s[AA].op.axis
ty, yi = s[AA].split(y, nparts=y_nthreads)
tx, ki = s[AA].split(k, nparts=x_nthreads)
s[AA].reorder(ty, tx, yi, ki)
s[AA].bind(ty, thread_y)
s[AA].bind(tx, thread_x)
s[AA].pragma(yi, "auto_unroll_max_step", 16)

_, x, k = s[BB].op.axis
ty, xi = s[BB].split(x, nparts=y_nthreads)
tx, ki = s[BB].split(k, nparts=x_nthreads)
s[BB].bind(ty, thread_y)
s[BB].bind(tx, thread_x)
s[BB].reorder(ty, tx, xi, ki)
s[BB].pragma(xi, "auto_unroll_max_step", 16)

def _callback(op):
if "batch_matmul" in op.tag:
_schedule(op)

traverse_inline(s, outs[0].op, _callback)
return s
Loading