diff --git a/python/tvm/autotvm/task/relay_integration.py b/python/tvm/autotvm/task/relay_integration.py
index 6ee8bc01cb54..345da66f1112 100644
--- a/python/tvm/autotvm/task/relay_integration.py
+++ b/python/tvm/autotvm/task/relay_integration.py
@@ -117,6 +117,7 @@ def extract_from_multiple_program(funcs, params, ops, target, target_host=None):
         tvm.relay.op.nn.conv2d: [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw,
                                  topi.nn.group_conv2d_nchw, topi.nn.conv2d_NCHWc],
         tvm.relay.op.nn.conv2d_transpose: [topi.nn.conv2d_transpose_nchw],
         tvm.relay.op.nn.dense: [topi.nn.dense],
+        tvm.relay.op.nn.batch_matmul: [topi.nn.batch_matmul],
         tvm.relay.op.nn.deformable_conv2d: [topi.nn.deformable_conv2d_nchw],
     }
diff --git a/python/tvm/autotvm/task/topi_integration.py b/python/tvm/autotvm/task/topi_integration.py
index ac4683d4ae0b..7bfc313de6e9 100644
--- a/python/tvm/autotvm/task/topi_integration.py
+++ b/python/tvm/autotvm/task/topi_integration.py
@@ -87,6 +87,7 @@ def __init__(self, allow_duplicate=False):
             topi.nn.conv2d_NCHWc: "topi_x86_conv2d_NCHWc",
             topi.nn.conv2d_NCHWc_int8: "topi_x86_conv2d_NCHWc_int8",
             topi.nn.dense: "topi_nn_dense",
+            topi.nn.batch_matmul: "topi_nn_batch_matmul",
             topi.nn.bitserial_conv2d_nchw: "topi_nn_bitserial_conv2d_nchw",
             topi.nn.bitserial_conv2d_nhwc: "topi_nn_bitserial_conv2d_nhwc",
             topi.nn.bitserial_dense: "topi_nn_bitserial_dense",
@@ -103,6 +104,7 @@ def __init__(self, allow_duplicate=False):
             topi.nn.conv2d_NCHWc: [topi.generic.schedule_conv2d_NCHWc],
             topi.nn.conv2d_NCHWc_int8: [topi.generic.schedule_conv2d_NCHWc_int8],
             topi.nn.dense: [topi.generic.schedule_dense],
+            topi.nn.batch_matmul: [topi.generic.schedule_batch_matmul],
             topi.nn.bitserial_conv2d_nchw: [topi.generic.schedule_bitserial_conv2d_nchw],
             topi.nn.bitserial_conv2d_nhwc: [topi.generic.schedule_bitserial_conv2d_nhwc],
             topi.nn.bitserial_dense: [topi.generic.schedule_bitserial_dense],
@@ -118,6 +120,7 @@ def __init__(self, allow_duplicate=False):
             topi.nn.group_conv2d_nchw: lambda x: setattr(topi.nn, 'group_conv2d_nchw', x),
             topi.nn.conv2d_transpose_nchw: lambda x: setattr(topi.nn, 'conv2d_transpose_nchw', x),
             topi.nn.dense: lambda x: setattr(topi.nn, 'dense', x),
+            topi.nn.batch_matmul: lambda x: setattr(topi.nn, 'batch_matmul', x),
             topi.nn.bitserial_conv2d_nchw: lambda x: setattr(topi.nn, 'bitserial_conv2d_nchw', x),
             topi.nn.bitserial_conv2d_nhwc: lambda x: setattr(topi.nn, 'bitserial_conv2d_nhwc', x),
             topi.nn.bitserial_dense: lambda x: setattr(topi.nn, 'bitserial_dense', x),
@@ -226,6 +229,15 @@ def _topi_nn_dense(*args, **kwargs):
                 return s, [data, weight, bias, C]
             return s, [data, weight, C]

+        @register("topi_nn_batch_matmul")
+        def _topi_nn_batch_matmul(*args, **kwargs):
+            assert not kwargs, "Do not support kwargs in template function call"
+            args = deserialize_args(args)
+            A, B = args
+            C = topi.nn.batch_matmul(A, B)
+            s = topi.generic.schedule_batch_matmul([C])
+            return s, [A, B, C]
+
         @register("topi_nn_bitserial_conv2d_nhwc")
         def _topi_bitserial_conv2d_nhwc(*args, **kwargs):
             args = deserialize_args(args)
diff --git a/topi/python/topi/x86/batch_matmul.py b/topi/python/topi/x86/batch_matmul.py
index 047e97fa2e51..b505cbfabb55 100644
--- a/topi/python/topi/x86/batch_matmul.py
+++ b/topi/python/topi/x86/batch_matmul.py
@@ -18,24 +18,26 @@
 """x86 batch_matmul operators"""
 from __future__ import absolute_import as _abs
 import tvm
+from tvm import autotvm
+from tvm.autotvm.task.space import SplitEntity
 from tvm.contrib import cblas
-from topi.nn import batch_matmul, batch_matmul_default
-from .. import generic
+from .. import generic, nn
 from ..util import traverse_inline, get_const_tuple, get_max_power2_factor

-@batch_matmul.register(["cpu"])
-def batch_matmul_x86(x, y):
+
+@autotvm.register_topi_compute(nn.batch_matmul, "cpu", "direct")
+def _declaration_batch_matmul_nopack(cfg, x, y):
     """Computes batch matrix multiplication of `x` and `y` when `x` and `y`
     are data in batch.

     Parameters
     ----------
+    cfg : ConfigSpace
+        AutoTVM tuning space config file
     x : tvm.Tensor
         3-D with shape [batch, M, K]
-
     y : tvm.Tensor
         3-D with shape [batch, N, K]
-
     Returns
     -------
     output : tvm.Tensor
@@ -44,17 +46,37 @@ def batch_matmul_x86(x, y):
     target = tvm.target.current_target()
     if "cblas" in target.libs:
         return cblas.batch_matmul(x, y, False, True)
-    return batch_matmul_default(x, y)

-@generic.schedule_batch_matmul.register(["cpu"])
-def schedule_batch_matmul(outs):
+    assert len(x.shape) == 3 and len(
+        y.shape) == 3, "only support 3-dim batch_matmul"
+    XB, M, XK = get_const_tuple(x.shape)
+    YB, N, YK = get_const_tuple(y.shape)
+    assert XB == YB, "batch dimension doesn't match"
+    assert XK == YK, "shapes of x and y are inconsistent"
+    B = XB
+    K = XK
+    if cfg.is_fallback:
+        _default_batch_matmul_nopack_config(cfg, M, N, K)
+
+    k = tvm.reduce_axis((0, K), name='k')
+    C = tvm.compute(
+        (B, M, N),
+        lambda b, i, j: tvm.sum(x[b, i, k] * y[b, j, k], axis=k),
+        tag='batch_matmul')
+    return C
+
+
+@autotvm.register_topi_schedule(generic.schedule_batch_matmul, "cpu", "direct")
+def schedule_batch_matmul(cfg, outs):
     """Schedule for batch_matmul

     Parameters
     ----------
-    outs: Array of Tensor
-        The computation graph description of batch_matmul
-        in the format of an array of tensors.
+    cfg : ConfigSpace
+        AutoTVM tuning space config file.
+    outs : Array of Tensor
+        The computation graph description of batch_matmul
+        in the format of an array of tensors.

     Returns
     -------
@@ -71,16 +93,22 @@ def _callback(op):
         if "batch_matmul" in op.tag:
             C = op.output(0)
             A, B = s[C].op.input_tensors
-            _, M, N = get_const_tuple(C.shape)
+            _, M, K = get_const_tuple(A.shape)
+            _, _, N = get_const_tuple(C.shape)
+
+            # create tuning space
+            cfg.define_split("tile_y", M, num_outputs=2)
+            cfg.define_split("tile_x", N, num_outputs=2)
+            cfg.define_split("tile_k", K, num_outputs=2)
+
             k, = s[C].op.reduce_axis
-            ko, ki = s[C].split(k, 16)
+
+            ko, ki = cfg["tile_k"].apply(s, C, k)
             CC = s.rfactor(C, ki)
             b, y, x = s[C].op.axis
-            y_bn = get_max_power2_factor(M, 8)
-            x_bn = get_max_power2_factor(N, 8)
-            yo, yi = s[C].split(y, y_bn)
-            xo, xi = s[C].split(x, x_bn)
+            yo, yi = cfg["tile_y"].apply(s, C, y)
+            xo, xi = cfg["tile_x"].apply(s, C, x)
             s[C].reorder(b, yo, xo, yi, xi)
             bxyo = s[C].fuse(b, yo, xo)
             s[C].parallel(bxyo)
@@ -94,3 +122,11 @@ def _callback(op):

     traverse_inline(s, outs[0].op, _callback)
     return s
+
+
+def _default_batch_matmul_nopack_config(cfg, M, N, K):
+    cfg["tile_k"] = SplitEntity([K // 16, 16])
+    x_bn = get_max_power2_factor(N, 8)
+    cfg["tile_x"] = SplitEntity([N // x_bn, x_bn])
+    y_bn = get_max_power2_factor(M, 8)
+    cfg["tile_y"] = SplitEntity([M // y_bn, y_bn])
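
Usage sketch (not part of the patch): a minimal example of tuning the "topi_nn_batch_matmul" template registered above. The shapes, trial count, and log file name are illustrative assumptions; TaskExtractEnv.get() is called only to make sure the TOPI task templates are registered before the task is created.

    import tvm
    from tvm import autotvm

    # Ensure the TOPI templates (including "topi_nn_batch_matmul") are registered.
    autotvm.task.TaskExtractEnv.get()

    # Illustrative shapes: x is [batch, M, K], y is [batch, N, K].
    B, M, N, K = 8, 128, 128, 256
    task = autotvm.task.create(
        "topi_nn_batch_matmul",
        args=(("TENSOR", (B, M, K), "float32"), ("TENSOR", (B, N, K), "float32")),
        target="llvm")
    print(task.config_space)  # the tile_y, tile_x, tile_k splits defined in _callback

    # Random search over the space; log results for later use with apply_history_best.
    measure_option = autotvm.measure_option(
        builder=autotvm.LocalBuilder(),
        runner=autotvm.LocalRunner(number=5))
    tuner = autotvm.tuner.RandomTuner(task)
    tuner.tune(n_trial=20,
               measure_option=measure_option,
               callbacks=[autotvm.callback.log_to_file("batch_matmul.log")])

Note that when the target enables cblas (for example "llvm -libs=cblas"), the compute declaration returns the cblas kernel directly, so there is no schedule to tune.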