diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index 5ffb7f1dea67..4b019cfcbccc 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -481,7 +481,7 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target): """batch_matmul cuda strategy""" strategy = _op.OpStrategy() strategy.add_implementation( - wrap_compute_batch_matmul(topi.nn.batch_matmul), + wrap_compute_batch_matmul(topi.cuda.batch_matmul), wrap_topi_schedule(topi.cuda.schedule_batch_matmul), name="batch_matmul.cuda", plevel=10) diff --git a/topi/python/topi/cuda/batch_matmul.py b/topi/python/topi/cuda/batch_matmul.py index bf801820d25a..7d92edfb97b7 100644 --- a/topi/python/topi/cuda/batch_matmul.py +++ b/topi/python/topi/cuda/batch_matmul.py @@ -14,13 +14,24 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=invalid-name,too-many-locals,unused-variable +# pylint: disable=invalid-name,too-many-locals,unused-variable,unused-argument """cuda batch_matmul operators""" +import tvm +from tvm import autotvm from tvm import te from tvm.contrib import cublas +from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity +from .. import nn from ..util import traverse_inline, get_const_tuple, get_max_power2_factor -def schedule_batch_matmul(outs): +@autotvm.register_topi_compute("batch_matmul.cuda") +def batch_matmul(cfg, x, y): + """Compute conv2d with NCHW layout""" + return nn.batch_matmul(x, y) + + +@autotvm.register_topi_schedule("batch_matmul.cuda") +def schedule_batch_matmul(cfg, outs): """Schedule for batch_matmul Parameters @@ -37,7 +48,7 @@ def schedule_batch_matmul(outs): outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs s = te.create_schedule([x.op for x in outs]) - def _schedule(op): + def _schedule(cfg, op): C = op.output(0) A, B = s[C].op.input_tensors _, M, N = get_const_tuple(C.shape) @@ -51,16 +62,34 @@ def _schedule(op): C = s.outputs[0].output(0) b, y, x = s[C].op.axis - y_bn = get_max_power2_factor(M, 64) - x_bn = get_max_power2_factor(N, 64) - by, y = s[C].split(y, y_bn) - bx, x = s[C].split(x, x_bn) - y_nthreads = min(y_bn, 8) - x_nthreads = min(x_bn, 8) - ty, yi = s[C].split(y, nparts=y_nthreads) - tx, xi = s[C].split(x, nparts=x_nthreads) - thread_x = te.thread_axis((0, x_nthreads), "threadIdx.x") - thread_y = te.thread_axis((0, y_nthreads), "threadIdx.y") + k, = s[CC].op.reduce_axis + + cfg.define_split("tile_y", y, num_outputs=3) + cfg.define_split("tile_x", x, num_outputs=3) + cfg.define_split("tile_k", k, num_outputs=2) + cfg.define_knob("auto_unroll_max_step", [8, 16, 32, 64]) + target = tvm.target.Target.current() + if target.target_name in ['nvptx', 'rocm']: + # llvm-based backends cannot do non-explicit unrolling + cfg.define_knob("unroll_explicit", [1]) + else: + cfg.define_knob("unroll_explicit", [0, 1]) + + if cfg.is_fallback: + y_bn = get_max_power2_factor(M, 64) + x_bn = get_max_power2_factor(N, 64) + y_nthreads = min(y_bn, 8) + x_nthreads = min(x_bn, 8) + cfg['tile_x'] = SplitEntity([-1, x_nthreads, x_bn // x_nthreads]) + cfg['tile_y'] = SplitEntity([-1, y_nthreads, y_bn // y_nthreads]) + cfg['tile_k'] = SplitEntity([-1, 8]) + cfg['auto_unroll_max_step'] = OtherOptionEntity(16) + + by, ty, yi = cfg["tile_y"].apply(s, C, y) + bx, tx, xi = cfg["tile_x"].apply(s, C, x) + + thread_x = te.thread_axis("threadIdx.x") + thread_y = te.thread_axis("threadIdx.y") s[C].reorder(b, by, bx, ty, tx, yi, xi) s[C].bind(b, te.thread_axis("blockIdx.z")) @@ -68,38 +97,41 @@ def _schedule(op): s[C].bind(bx, te.thread_axis("blockIdx.x")) s[C].bind(ty, thread_y) s[C].bind(tx, thread_x) - s[C].pragma(yi, "auto_unroll_max_step", 16) + s[C].pragma(yi, "auto_unroll_max_step", cfg['auto_unroll_max_step'].val) + s[C].pragma(yi, 'unroll_explicit', cfg['unroll_explicit'].val) s[CC].compute_at(s[C], tx) _, yi, xi = s[CC].op.axis - k, = s[CC].op.reduce_axis - ko, ki = s[CC].split(k, 8) + ko, ki = cfg["tile_k"].apply(s, CC, k) s[CC].reorder(ko, ki, yi, xi) - s[CC].pragma(ki, "auto_unroll_max_step", 16) + s[CC].pragma(ki, "auto_unroll_max_step", cfg['auto_unroll_max_step'].val) + s[CC].pragma(ki, 'unroll_explicit', cfg['unroll_explicit'].val) s[AA].compute_at(s[CC], ko) s[AL].compute_at(s[CC], ki) s[BB].compute_at(s[CC], ko) s[BL].compute_at(s[CC], ki) _, y, k = s[AA].op.axis - ty, yi = s[AA].split(y, nparts=y_nthreads) - tx, ki = s[AA].split(k, nparts=x_nthreads) + ty, yi = s[AA].split(y, nparts=cfg["tile_y"].size[1]) + tx, ki = s[AA].split(k, nparts=cfg["tile_x"].size[1]) s[AA].reorder(ty, tx, yi, ki) s[AA].bind(ty, thread_y) s[AA].bind(tx, thread_x) - s[AA].pragma(yi, "auto_unroll_max_step", 16) + s[AA].pragma(yi, "auto_unroll_max_step", cfg['auto_unroll_max_step'].val) + s[AA].pragma(yi, 'unroll_explicit', cfg['unroll_explicit'].val) _, x, k = s[BB].op.axis - ty, xi = s[BB].split(x, nparts=y_nthreads) - tx, ki = s[BB].split(k, nparts=x_nthreads) + ty, xi = s[BB].split(x, nparts=cfg["tile_y"].size[1]) + tx, ki = s[BB].split(k, nparts=cfg["tile_x"].size[1]) s[BB].bind(ty, thread_y) s[BB].bind(tx, thread_x) s[BB].reorder(ty, tx, xi, ki) - s[BB].pragma(xi, "auto_unroll_max_step", 16) + s[BB].pragma(xi, "auto_unroll_max_step", cfg['auto_unroll_max_step'].val) + s[BB].pragma(xi, 'unroll_explicit', cfg['unroll_explicit'].val) def _callback(op): if "batch_matmul" in op.tag: - _schedule(op) + _schedule(cfg, op) traverse_inline(s, outs[0].op, _callback) return s diff --git a/topi/tests/python/test_topi_batch_matmul.py b/topi/tests/python/test_topi_batch_matmul.py index b8c854746847..716f40700339 100644 --- a/topi/tests/python/test_topi_batch_matmul.py +++ b/topi/tests/python/test_topi_batch_matmul.py @@ -28,7 +28,7 @@ _batch_matmul_implement = { "generic": (topi.nn.batch_matmul, topi.generic.schedule_batch_matmul), "cpu": (topi.x86.batch_matmul, topi.x86.schedule_batch_matmul), - "gpu": (topi.nn.batch_matmul, topi.cuda.schedule_batch_matmul), + "gpu": (topi.cuda.batch_matmul, topi.cuda.schedule_batch_matmul), } def verify_batch_matmul(batch, M, N, K):