From 6c02e2651e472ec07badd0916b24a0fc542e525d Mon Sep 17 00:00:00 2001 From: Shawn-Inspur <56216438+Shawn-Inspur@users.noreply.github.com> Date: Sat, 28 Mar 2020 07:20:40 +0800 Subject: [PATCH] [TOPI][Tensor Core] Conv2d and Dense ops support on Tensor Core (#5099) * [TOPI][Tensor Core] Optimization of CNNs on Tensor Core #6004 * update conv2d test * # pylint: dense_tensorcore.py * modify * modify conv2d * modify the unclear comment,add shape assertion in conv2d compute,combine general gemm intrinsic * add shape assertion in conv2d compute, combine general gemm intrinsic Co-authored-by: libaihong Co-authored-by: libaihong <61525430+libaihong@users.noreply.github.com> --- python/tvm/relay/op/strategy/cuda.py | 39 ++- python/tvm/relay/testing/resnet.py | 59 +++- src/driver/driver_api.cc | 7 + topi/python/topi/cuda/__init__.py | 2 + topi/python/topi/cuda/conv2d.py | 35 +- topi/python/topi/cuda/conv2d_nhwc.py | 131 ++++++++ .../topi/cuda/conv2d_nhwc_tensorcore.py | 318 ++++++++++++++++++ topi/python/topi/cuda/dense_tensorcore.py | 252 ++++++++++++++ topi/python/topi/cuda/tensor_intrin.py | 147 +++++++- topi/tests/python/test_topi_conv2d_nhwc.py | 12 +- .../test_topi_conv2d_nhwc_tensorcore.py | 126 +++++++ .../python/test_topi_dense_tensorcore.py | 91 +++++ 12 files changed, 1172 insertions(+), 47 deletions(-) create mode 100644 topi/python/topi/cuda/conv2d_nhwc.py create mode 100644 topi/python/topi/cuda/conv2d_nhwc_tensorcore.py create mode 100644 topi/python/topi/cuda/dense_tensorcore.py create mode 100644 topi/tests/python/test_topi_conv2d_nhwc_tensorcore.py create mode 100644 topi/tests/python/test_topi_dense_tensorcore.py diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index f52a7d5f2dd1..db03c5965470 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -17,7 +17,9 @@ """Definition of CUDA/GPU operator strategy.""" # pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import import topi +import tvm from tvm.te import SpecializedCondition +from tvm.contrib import nvcc from .generic import * from .. import op as _op from .... 
import get_global_func @@ -112,13 +114,23 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target): wrap_compute_conv2d(topi.cuda.conv2d_hwcn), wrap_topi_schedule(topi.cuda.schedule_conv2d_hwcn), name="conv2d_hwcn.cuda") - # TODO(@alexgl-github): Re-enable this after fix the conv2d_nhwc for cuda - # elif layout == "NHWC": - # assert kernel_layout == "HWIO" - # strategy.add_implementation( - # wrap_compute_conv2d(topi.cuda.conv2d_nhwc), - # wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc), - # name="conv2d_nhwc.cuda") + elif layout == "NHWC": + assert kernel_layout == "HWIO" + strategy.add_implementation( + wrap_compute_conv2d(topi.cuda.conv2d_nhwc), + wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc), + name="conv2d_nhwc.cuda") + N, _, _, _ = get_const_tuple(data.shape) + _, _, CI, CO = get_const_tuple(kernel.shape) + if nvcc.have_tensorcore(tvm.gpu(0).compute_version): + if (N % 16 == 0 and CI % 16 == 0 and CO % 16 == 0) or \ + (N % 8 == 0 and CI % 16 == 0 and CO % 32 == 0) or \ + (N % 32 == 0 and CI % 16 == 0 and CO % 8 == 0): + strategy.add_implementation( + wrap_compute_conv2d(topi.cuda.conv2d_nhwc_tensorcore), + wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc_tensorcore), + name="conv2d_nhwc_tensorcore.cuda", + plevel=20) elif layout == "NCHW4c" and data.dtype in ["int8", "uint8"]: assert kernel_layout == "OIHW4o4i" strategy.add_implementation( @@ -279,6 +291,9 @@ def conv1d_transpose_strategy_cuda(attrs, inputs, out_type, target): def dense_strategy_cuda(attrs, inputs, out_type, target): """dense cuda strategy""" strategy = _op.OpStrategy() + data, weights = inputs + b, i = get_const_tuple(data.shape) + o, _ = get_const_tuple(weights.shape) if out_type.dtype == "int8": strategy.add_implementation( wrap_compute_dense(topi.cuda.dense_int8), @@ -289,13 +304,21 @@ def dense_strategy_cuda(attrs, inputs, out_type, target): wrap_compute_dense(topi.cuda.dense_small_batch), wrap_topi_schedule(topi.cuda.schedule_dense_small_batch), name="dense_small_batch.cuda") - b = inputs[0].shape[0] with SpecializedCondition(b >= 32): strategy.add_implementation( wrap_compute_dense(topi.cuda.dense_large_batch), wrap_topi_schedule(topi.cuda.schedule_dense_large_batch), name="dense_large_batch.cuda", plevel=5) + if nvcc.have_tensorcore(tvm.gpu(0).compute_version): + if(i % 16 == 0 and b % 16 == 0 and o % 16 == 0) \ + or (i % 16 == 0 and b % 8 == 0 and o % 32 == 0) \ + or (i % 16 == 0 and b % 32 == 0 and o % 8 == 0): + strategy.add_implementation( + wrap_compute_dense(topi.cuda.dense_tensorcore), + wrap_topi_schedule(topi.cuda.schedule_dense_tensorcore), + name="dense_tensorcore.cuda", + plevel=20) if target.target_name == "cuda" and "cublas" in target.libs: strategy.add_implementation( wrap_compute_dense(topi.cuda.dense_cublas), diff --git a/python/tvm/relay/testing/resnet.py b/python/tvm/relay/testing/resnet.py index 97b6bdc7e617..b431dd096f9d 100644 --- a/python/tvm/relay/testing/resnet.py +++ b/python/tvm/relay/testing/resnet.py @@ -32,7 +32,10 @@ def residual_unit(data, stride, dim_match, name, - bottle_neck=True): + bottle_neck=True, + data_layout="NCHW", + kernel_layout="IOHW" + ): """Return ResNet Unit symbol for building ResNet Parameters @@ -67,42 +70,50 @@ def residual_unit(data, kernel_size=(1, 1), strides=stride, padding=(0, 0), - name=name + '_conv1') + name=name + '_conv1', + data_layout=data_layout, + kernel_layout=kernel_layout) bn2 = layers.batch_norm_infer(data=conv1, epsilon=2e-5, name=name + '_bn2') act2 = relay.nn.relu(data=bn2) conv2 = layers.conv2d( data=act2, 
channels=int(num_filter*0.25), kernel_size=(3, 3), - strides=(1, 1), padding=(1, 1), name=name + '_conv2') + strides=(1, 1), padding=(1, 1), name=name + '_conv2', + data_layout=data_layout, kernel_layout=kernel_layout) bn3 = layers.batch_norm_infer(data=conv2, epsilon=2e-5, name=name + '_bn3') act3 = relay.nn.relu(data=bn3) conv3 = layers.conv2d( data=act3, channels=num_filter, kernel_size=(1, 1), - strides=(1, 1), padding=(0, 0), name=name + '_conv3') + strides=(1, 1), padding=(0, 0), name=name + '_conv3', + data_layout=data_layout, kernel_layout=kernel_layout) if dim_match: shortcut = data else: shortcut = layers.conv2d( data=act1, channels=num_filter, kernel_size=(1, 1), - strides=stride, name=name+'_sc') + strides=stride, name=name+'_sc', + data_layout=data_layout, kernel_layout=kernel_layout) return relay.add(conv3, shortcut) bn1 = layers.batch_norm_infer(data=data, epsilon=2e-5, name=name + '_bn1') act1 = relay.nn.relu(data=bn1) conv1 = layers.conv2d( data=act1, channels=num_filter, kernel_size=(3, 3), - strides=stride, padding=(1, 1), name=name + '_conv1') + strides=stride, padding=(1, 1), name=name + '_conv1', + data_layout=data_layout, kernel_layout=kernel_layout) bn2 = layers.batch_norm_infer(data=conv1, epsilon=2e-5, name=name + '_bn2') act2 = relay.nn.relu(data=bn2) conv2 = layers.conv2d( data=act2, channels=num_filter, kernel_size=(3, 3), - strides=(1, 1), padding=(1, 1), name=name + '_conv2') + strides=(1, 1), padding=(1, 1), name=name + '_conv2', + data_layout=data_layout, kernel_layout=kernel_layout) if dim_match: shortcut = data else: shortcut = layers.conv2d( data=act1, channels=num_filter, kernel_size=(1, 1), - strides=stride, name=name+'_sc') + strides=stride, name=name+'_sc', + data_layout=data_layout, kernel_layout=kernel_layout) return relay.add(conv2, shortcut) @@ -112,6 +123,7 @@ def resnet(units, num_classes, data_shape, bottle_neck=True, + layout="NCHW", dtype="float32"): """Return ResNet Program. @@ -135,9 +147,16 @@ def resnet(units, bottle_neck : bool Whether apply bottleneck transformation. + layout: str + The data layout for conv2d + dtype : str The global data type. 
""" + + data_layout = layout + kernel_layout = "OIHW" if layout == "NCHW" else "HWIO" + num_unit = len(units) assert num_unit == num_stages data = relay.var("data", shape=data_shape, dtype=dtype) @@ -146,27 +165,32 @@ def resnet(units, if height <= 32: # such as cifar10 body = layers.conv2d( data=data, channels=filter_list[0], kernel_size=(3, 3), - strides=(1, 1), padding=(1, 1), name="conv0") + strides=(1, 1), padding=(1, 1), name="conv0", + data_layout=data_layout, kernel_layout=kernel_layout) else: # often expected to be 224 such as imagenet body = layers.conv2d( data=data, channels=filter_list[0], kernel_size=(7, 7), - strides=(2, 2), padding=(3, 3), name="conv0") + strides=(2, 2), padding=(3, 3), name="conv0", + data_layout=data_layout, kernel_layout=kernel_layout) body = layers.batch_norm_infer(data=body, epsilon=2e-5, name='bn0') body = relay.nn.relu(data=body) - body = relay.nn.max_pool2d(data=body, pool_size=(3, 3), strides=(2, 2), padding=(1, 1)) + body = relay.nn.max_pool2d(data=body, pool_size=(3, 3), strides=(2, 2), padding=(1, 1), + layout=data_layout) for i in range(num_stages): body = residual_unit( body, filter_list[i+1], (1 if i == 0 else 2, 1 if i == 0 else 2), - False, name='stage%d_unit%d' % (i + 1, 1), bottle_neck=bottle_neck) + False, name='stage%d_unit%d' % (i + 1, 1), bottle_neck=bottle_neck, + data_layout=data_layout, kernel_layout=kernel_layout) for j in range(units[i]-1): body = residual_unit( body, filter_list[i+1], (1, 1), True, - name='stage%d_unit%d' % (i + 1, j + 2), bottle_neck=bottle_neck) + name='stage%d_unit%d' % (i + 1, j + 2), bottle_neck=bottle_neck, + data_layout=data_layout, kernel_layout=kernel_layout) bn1 = layers.batch_norm_infer(data=body, epsilon=2e-5, name='bn1') relu1 = relay.nn.relu(data=bn1) # Although kernel is not used here when global_pool=True, we should put one - pool1 = relay.nn.global_avg_pool2d(data=relu1) + pool1 = relay.nn.global_avg_pool2d(data=relu1, layout=data_layout) flat = relay.nn.batch_flatten(data=pool1) fc1 = layers.dense_add_bias(data=flat, units=num_classes, name='fc1') net = relay.nn.softmax(data=fc1) @@ -177,6 +201,7 @@ def get_net(batch_size, num_classes, num_layers=50, image_shape=(3, 224, 224), + layout="NCHW", dtype="float32", **kwargs): """ @@ -229,6 +254,7 @@ def get_net(batch_size, num_classes=num_classes, data_shape=data_shape, bottle_neck=bottle_neck, + layout=layout, dtype=dtype) @@ -236,6 +262,7 @@ def get_workload(batch_size=1, num_classes=1000, num_layers=18, image_shape=(3, 224, 224), + layout="NCHW", dtype="float32", **kwargs): """Get benchmark workload for resnet @@ -254,6 +281,9 @@ def get_workload(batch_size=1, image_shape : tuple, optional The input image shape + layout: str + The data layout for conv2d + dtype : str, optional The data type @@ -273,5 +303,6 @@ def get_workload(batch_size=1, num_layers=num_layers, image_shape=image_shape, dtype=dtype, + layout=layout, **kwargs) return create_workload(net) diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 774c47666b17..0f56f9d654ae 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -201,6 +201,7 @@ Array > split_dev_host_funcs(const Array& funcs, func = tir::ThreadSync(func, "shared"); func = tir::ThreadSync(func, "warp"); + func = tir::InferFragment(func); func = tir::LowerThreadAllreduce(func, target->thread_warp_size); auto fsplits = tir::SplitHostDevice(func); fhost.push_back(fsplits[0]); @@ -244,6 +245,12 @@ Array > split_dev_host_funcs(const Array& funcs, << "\n"; } + for (size_t i = 0; i < 
fdevice.size(); ++i) { + auto func = fdevice[i]; + func = tir::LowerDeviceStorageAccessInfo(func); + fdevice.Set(i, func); + } + for (size_t i = 0; i < fhost.size(); ++i) { auto func = fhost[i]; func = tir::BindDeviceType(func, target->device_type); diff --git a/topi/python/topi/cuda/__init__.py b/topi/python/topi/cuda/__init__.py index 6e38318a0062..302171ee6466 100644 --- a/topi/python/topi/cuda/__init__.py +++ b/topi/python/topi/cuda/__init__.py @@ -43,3 +43,5 @@ from .nms import get_valid_counts, non_max_suppression from .rcnn import * from .sort import * +from .conv2d_nhwc_tensorcore import * +from .dense_tensorcore import * diff --git a/topi/python/topi/cuda/conv2d.py b/topi/python/topi/cuda/conv2d.py index c24789307340..c7df3dc96a5e 100644 --- a/topi/python/topi/cuda/conv2d.py +++ b/topi/python/topi/cuda/conv2d.py @@ -24,6 +24,7 @@ from ..nn.util import get_pad_tuple from ..util import get_const_tuple, traverse_inline from .conv2d_direct import schedule_direct_cuda +from .conv2d_nhwc import schedule_conv2d_nhwc_direct @autotvm.register_topi_compute("conv2d_nchw.cuda") @@ -46,24 +47,22 @@ def _callback(op): return s -# TODO(@alexgl-github): It's invalid to call schedule_direct_cuda for NHWC layout -# as it assumes the input layout to be NCHW. Please fix this. -# @autotvm.register_topi_compute("conv2d_nhwc.cuda") -# def conv2d_nhwc(cfg, data, kernel, strides, padding, dilation, out_dtype='float32'): -# return nn.conv2d_nhwc(data, kernel, strides, padding, dilation, out_dtype) -# -# -# @autotvm.register_topi_schedule("conv2d_nhwc.cuda") -# def schedule_conv2d_nhwc(cfg, outs): -# outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs -# s = te.create_schedule([x.op for x in outs]) -# -# def _callback(op): -# if op.tag == 'conv2d_nhwc': -# schedule_direct_cuda(cfg, s, op.output(0)) -# -# traverse_inline(s, outs[0].op, _callback) -# return s +@autotvm.register_topi_compute("conv2d_nhwc.cuda") +def conv2d_nhwc(cfg, data, kernel, strides, padding, dilation, out_dtype='float32'): + """Compute conv2d with NHWC layout""" + return nn.conv2d_nhwc(data, kernel, strides, padding, dilation, out_dtype) + + +@autotvm.register_topi_schedule("conv2d_nhwc.cuda") +def schedule_conv2d_nhwc(cfg, outs): + """Create the schedule for conv2d_nhwc""" + outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs + s = te.create_schedule([x.op for x in outs]) + def _callback(op): + if op.tag == 'conv2d_nhwc': + schedule_conv2d_nhwc_direct(cfg, s, op.output(0)) + traverse_inline(s, outs[0].op, _callback) + return s @autotvm.register_topi_compute("conv2d_cudnn.cuda") diff --git a/topi/python/topi/cuda/conv2d_nhwc.py b/topi/python/topi/cuda/conv2d_nhwc.py new file mode 100644 index 000000000000..55714b2d80a0 --- /dev/null +++ b/topi/python/topi/cuda/conv2d_nhwc.py @@ -0,0 +1,131 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, too-many-locals, too-many-statements, unused-argument +"""Direct conv2d in NHWC layout""" +import tvm +from tvm import te +from tvm import autotvm +from ..util import get_const_tuple + + +def schedule_conv2d_nhwc_direct(cfg, s, Conv): + """schedule optimized for NHWC direct conv2d""" + pad_data, kernel = s[Conv].op.input_tensors + s[pad_data].compute_inline() + + if isinstance(kernel.op, tvm.te.ComputeOp) and 'dilate' in kernel.op.tag: + s[kernel].compute_inline() + + if Conv.op in s.outputs: + output = Conv + OL = s.cache_write(Conv, 'local') + else: + output = s.outputs[0].output(0) + s[Conv].set_scope('local') + OL = Conv + # create cache stage + AA = s.cache_read(pad_data, 'shared', [OL]) + WW = s.cache_read(kernel, "shared", [OL]) + AL = s.cache_read(AA, "local", [OL]) + WL = s.cache_read(WW, "local", [OL]) + + # Schedule for autotvm + cfg.define_knob("tile_n", [2, 4, 8]) + cfg.define_knob("tile_c", [2, 4, 8]) + cfg.define_knob("num_thread_n", [4, 8, 16]) + cfg.define_knob("num_thread_c", [4, 8, 16]) + cfg.define_knob("vthread_n", [1, 2]) + cfg.define_knob("vthread_c", [1, 2]) + cfg.define_knob("step", [16, 3, 32, 64]) + + # fallback support + target = tvm.target.Target.current() + if cfg.is_fallback: + ref_log = autotvm.tophub.load_reference_log( + target.target_name, target.model, 'conv2d_nhwc.cuda') + cfg.fallback_with_reference_log(ref_log) + + tile_n = cfg["tile_n"].val + tile_c = cfg["tile_c"].val + num_thread_n = cfg["num_thread_n"].val + num_thread_c = cfg["num_thread_c"].val + vthread_n = cfg["vthread_n"].val + vthread_c = cfg["vthread_c"].val + step = cfg["step"].val + block_factor_c = tile_c * num_thread_c * vthread_c + + offset = 8 + A_align = step + offset + W_align = block_factor_c + offset + + block_x = te.thread_axis("blockIdx.x") + block_y = te.thread_axis("blockIdx.y") + block_z = te.thread_axis("blockIdx.z") + thread_x = te.thread_axis((0, num_thread_c), "threadIdx.x") + thread_y = te.thread_axis((0, num_thread_n), "threadIdx.y") + thread_xz = te.thread_axis((0, vthread_c), "vthread", name="vx") + thread_yz = te.thread_axis((0, vthread_n), "vthread", name="vy") + + # Schedule for output + ni, hi, wi, fi = s[output].op.axis + bz = s[output].fuse(hi, wi) + tx, fi = s[output].split(fi, factor=tile_c) + txz, tx = s[output].split(tx, factor=num_thread_c) + bx, txz = s[output].split(txz, factor=vthread_c) + ty, ni = s[output].split(ni, factor=tile_n) + tyz, ty = s[output].split(ty, factor=num_thread_n) + by, tyz = s[output].split(tyz, factor=vthread_n) + s[output].reorder(bz, by, bx, tyz, txz, ty, tx, ni, fi) + s[output].bind(bz, block_z) + s[output].bind(by, block_y) + s[output].bind(bx, block_x) + s[output].bind(tyz, thread_yz) + s[output].bind(txz, thread_xz) + s[output].bind(ty, thread_y) + s[output].bind(tx, thread_x) + # Schedule local computation + s[OL].compute_at(s[output], tx) + ni, yi, xi, fi = s[OL].op.axis + ry, rx, rc = s[OL].op.reduce_axis + rco, rci = s[OL].split(rc, factor=step) + s[OL].reorder(rco, ry, rx, rci, ni, fi) + + s[AA].compute_at(s[OL], rx) + s[WW].compute_at(s[OL], rx) + s[AL].compute_at(s[OL], rci) + s[WL].compute_at(s[OL], rci) + # Schedule for data's share memory + ni, yi, xi, ci = s[AA].op.axis + s[AA].reorder(yi, xi, ni, ci) + s[AA].storage_align(xi, A_align - 1, A_align) + t = s[AA].fuse(ni, ci) + ty, tx = s[AA].split(t, factor=num_thread_c) + _, ty = s[AA].split(ty, factor=num_thread_n) + 
s[AA].bind(tx, thread_x) + s[AA].bind(ty, thread_y) + # Schedule for kernel's share memory + _, _, ic, o = s[WW].op.axis + t = s[WW].fuse(ic, o) + s[WW].storage_align(ic, W_align - 1, W_align) + ty, tx = s[WW].split(t, factor=num_thread_c) + _, ty = s[WW].split(ty, factor=num_thread_n) + s[WW].bind(tx, thread_x) + s[WW].bind(ty, thread_y) + + N, OH, OW, CO = get_const_tuple(output.shape) + KH, KW, CI, _ = get_const_tuple(kernel.shape) + cfg.add_flop(2 * N * OH * OW * CO * CI * KH * KW) diff --git a/topi/python/topi/cuda/conv2d_nhwc_tensorcore.py b/topi/python/topi/cuda/conv2d_nhwc_tensorcore.py new file mode 100644 index 000000000000..8f8f93d00a8f --- /dev/null +++ b/topi/python/topi/cuda/conv2d_nhwc_tensorcore.py @@ -0,0 +1,318 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, too-many-locals, too-many-function-args +# pylint: disable=too-many-statements, unused-argument, too-many-arguments +"""Tensorcore template for cuda backend""" +import numpy as np +import tvm +from tvm import te +from tvm import autotvm +from ..util import get_const_tuple, traverse_inline, simplify +from ..nn.pad import pad +from ..nn.util import get_pad_tuple +from .tensor_intrin import intrin_wmma_load_matrix_A +from .tensor_intrin import intrin_wmma_load_matrix_W +from .tensor_intrin import intrin_wmma_store_matrix +from .tensor_intrin import intrin_wmma_gemm + + +def nhwc_tensorcore_cuda(cfg, Input, Filter, stride, padding, dilation, out_dtype): + """Compute declaration for tensorcore""" + assert isinstance(stride, int) or len(stride) == 2 + assert isinstance(dilation, int) or len(dilation) == 2 + + if isinstance(stride, int): + stride_h = stride_w = stride + else: + stride_h, stride_w = stride + + if isinstance(dilation, int): + dilation_h = dilation_w = dilation + else: + dilation_h, dilation_w = dilation + + batch, in_height, in_width, in_channel = get_const_tuple(Input.shape) + kernel_h, kernel_w, _, num_filter = get_const_tuple(Filter.shape) + assert (batch % 16 == 0 and in_channel % 16 == 0 and num_filter % 16 == 0) or \ + (batch % 8 == 0 and in_channel % 16 == 0 and num_filter % 32 == 0) or \ + (batch % 32 == 0 and in_channel % 16 == 0 and num_filter % 8 == 0), \ + "The shape of (batch, in_channel, num_filter) "\ + "must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32) for now" + + # compute the output shape + dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 + dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 + pad_top, pad_left, pad_down, pad_right = get_pad_tuple( + padding, (dilated_kernel_h, dilated_kernel_w)) + out_channel = num_filter + out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1) + out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1) + 
pad_before = [0, pad_top, pad_left, 0] + pad_after = [0, pad_down, pad_right, 0] + PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput") + rc = te.reduce_axis((0, in_channel), name='rc') + ry = te.reduce_axis((0, kernel_h), name='ry') + rx = te.reduce_axis((0, kernel_w), name='rx') + # convert data type of input feature maps and weights + TransPaddedInput = te.compute( + PaddedInput.shape, + lambda h, w, i, o: PaddedInput[h, w, i, o].astype('float16')) + TransFilter = te.compute( + Filter.shape, lambda h, w, i, o: Filter[h, w, i, o].astype('float16')) + Output = te.compute( + (batch, out_height, out_width, out_channel), + lambda nn, yy, xx, ff: te.sum( + TransPaddedInput[nn, yy * stride_h + ry * dilation_h, + xx * stride_w + rx * dilation_w, rc].astype(out_dtype) * + TransFilter[ry, rx, rc, ff].astype(out_dtype), axis=[ry, rx, rc]), + name="Conv2dOutput", tag="conv2d_nhwc_tensorcore") + return Output + + +def schedule_nhwc_tensorcore_cuda(cfg, s, Conv): + """Schedule tensorcore template""" + kh, kw, ic = s[Conv].op.reduce_axis + out_dtype = Conv.dtype + trans_paddata, kernel = s[Conv].op.input_tensors + in_dtype = trans_paddata.dtype + batch, _, _, _ = get_const_tuple(Conv.shape) + _, _, _, out_channels = get_const_tuple(kernel.shape) + paddata = s[trans_paddata].op.input_tensors + + # inline the pad and dtype transform + s[trans_paddata].compute_inline() + s[kernel].compute_inline() + s[paddata[0]].compute_inline() + + # Designate the memory hierarchy + AS = s.cache_read(trans_paddata, 'shared', [Conv]) + WS = s.cache_read(kernel, 'shared', [Conv]) + AF = s.cache_read(AS, 'wmma.matrix_a', [Conv]) + WF = s.cache_read(WS, 'wmma.matrix_b', [Conv]) + ConvF = s.cache_write(Conv, 'wmma.accumulator') + + if Conv.op in s.outputs: + output = Conv + ConvS = s.cache_read(ConvF, 'shared', [Conv]) + OL = ConvS + else: + output = s.outputs[0].output(0) + s[Conv].set_scope('shared') + OL = Conv + + # Schedule for autotvm + cfg.define_knob("block_row_warps", [1, 2, 4]) + cfg.define_knob("block_col_warps", [1, 2, 4]) + cfg.define_knob("warp_row_tiles", [1, 2, 4]) + cfg.define_knob("warp_col_tiles", [1, 2, 4]) + cfg.define_knob("chunk", [1, 2, 4, 8]) + cfg.define_knob("offset", [0, 8]) + cfg.define_knob("vector_width", [1, 2, 4, 8]) + + if (batch % 16 == 0 and out_channels % 16 == 0): + cfg.define_knob("wmma_m", [16, 8, 32]) + elif (batch % 8 == 0 and out_channels % 32 == 0): + cfg.define_knob("wmma_m", [8, 16, 32]) + elif (batch % 32 == 0 and out_channels % 8 == 0): + cfg.define_knob("wmma_m", [32, 16, 8]) + + # fallback support + target = tvm.target.Target.current() + if cfg.is_fallback: + ref_log = autotvm.tophub.load_reference_log( + target.target_name, target.model, 'conv2d_nhwc_tensorcore.cuda') + cfg.fallback_with_reference_log(ref_log) + + block_row_warps = cfg["block_row_warps"].val + block_col_warps = cfg["block_col_warps"].val + warp_row_tiles = cfg["warp_row_tiles"].val + warp_col_tiles = cfg["warp_col_tiles"].val + chunk = cfg["chunk"].val + offset = cfg["offset"].val + wmma_m = cfg["wmma_m"].val + vector_width = cfg["vector_width"].val + + wmma_k = 16 + if wmma_m == 16: + wmma_n = 16 + elif wmma_m == 8: + wmma_n = 32 + elif wmma_m == 32: + wmma_n = 8 + + warp_size = 32 + + block_x = te.thread_axis('blockIdx.x') + block_y = te.thread_axis('blockIdx.y') + block_z = te.thread_axis('blockIdx.z') + thread_x = te.thread_axis('threadIdx.x') + thread_y = te.thread_axis('threadIdx.y') + thread_z = te.thread_axis('threadIdx.z') + + # Define the intrin strides + def get_strides(extents): + 
return [np.prod(extents[i:]).tolist() for i in range(len(extents))] + + AS_align = chunk * wmma_k + offset + WS_align = warp_col_tiles * block_col_warps * wmma_n + offset + block_factor_n = wmma_m * warp_row_tiles * block_row_warps + block_factor_o = wmma_n * warp_col_tiles * block_col_warps + CS_align = block_factor_o + offset + AS_strides = get_strides([1, 1, AS_align, 1]) + AL_strides = get_strides([1, 1, wmma_k, 1]) + WS_strides = get_strides([WS_align, 1]) + WL_strides = get_strides([wmma_n * warp_col_tiles, 1]) + CL_strides = get_strides([1, 1, wmma_n * warp_col_tiles, 1]) + CS_strides = get_strides([1, 1, CS_align, 1]) + + # Schedule for output + nc, hc, wc, oc = output.op.axis + block_k = s[output].fuse(hc, wc) + s[output].bind(block_k, block_z) + block_i, nc = s[output].split(nc, factor=block_factor_n) + block_j, oc = s[output].split(oc, factor=block_factor_o) + s[output].reorder(block_k, block_i, block_j, nc, oc) + t = s[output].fuse(nc, oc) + t, ti = s[output].split(t, factor=vector_width) + t, tx = s[output].split(t, factor=warp_size) + t, ty = s[output].split(t, factor=block_row_warps) + t, tz = s[output].split(t, factor=block_col_warps) + s[output].bind(block_i, block_x) + s[output].bind(block_j, block_y) + s[output].bind(tz, thread_z) + s[output].bind(ty, thread_y) + s[output].bind(tx, thread_x) + s[output].vectorize(ti) + + # Schedule wmma store + s[OL].compute_at(s[output], block_j) + nc, hc, wc, oc = OL.op.axis + s[OL].reorder(hc, wc, nc, oc) + s[OL].storage_align(wc, CS_align - 1, CS_align) + oc, ooc = s[OL].split(oc, factor=wmma_n) + oc, oci = s[OL].split(oc, factor=warp_col_tiles) + _, oc = s[OL].split(oc, factor=block_col_warps) + nc, nnc = s[OL].split(nc, factor=wmma_m) + nc, nci = s[OL].split(nc, factor=warp_row_tiles) + _, nc = s[OL].split(nc, factor=block_row_warps) + s[OL].reorder(nc, oc, nci, oci, nnc, ooc) + s[OL].bind(nc, thread_y) + s[OL].bind(oc, thread_z) + + # Schedule wmma computation + s[ConvF].compute_at(s[OL], oc) + n, h, w, o = ConvF.op.axis + n, nnf = s[ConvF].split(n, factor=wmma_m) + o, oof = s[ConvF].split(o, factor=wmma_n) + ic, ii = s[ConvF].split(ic, factor=wmma_k) + ko, ki = s[ConvF].split(ic, factor=chunk) + s[ConvF].reorder(kh, kw, ko, ki, n, o, nnf, oof, ii) + + s[AF].compute_at(s[ConvF], ki) + s[WF].compute_at(s[ConvF], ki) + + # Schedule wmma load + n, h, w, i = AF.op.axis + n, nn = s[AF].split(n, factor=wmma_m) + i, ii = s[AF].split(i, factor=wmma_k) + s[AF].reorder(n, i, nn, ii) + + kh, kw, i, o = WF.op.axis + i, ii = s[WF].split(i, factor=wmma_k) + o, oo = s[WF].split(o, factor=wmma_n) + s[WF].reorder(o, i, oo) + s[WF].reorder(i, o, ii, oo) + + s[WS].compute_at(s[ConvF], ko) + s[AS].compute_at(s[ConvF], ko) + + # Schedule for data's share memory + n, h, w, i = AS.op.axis + s[AS].reorder(h, w, n, i) + s[AS].storage_align(w, AS_align - 1, AS_align) + t = s[AS].fuse(n, i) + t, ti = s[AS].split(t, factor=vector_width) + t, tx = s[AS].split(t, factor=warp_size) + t, ty = s[AS].split(t, factor=block_row_warps) + _, tz = s[AS].split(t, factor=block_col_warps) + s[AS].bind(ty, thread_y) + s[AS].bind(tz, thread_z) + s[AS].bind(tx, thread_x) + s[AS].vectorize(ti) + + # Schedule for kernel's share memory + kh, kw, ic, o = WS.op.axis + t = s[WS].fuse(ic, o) + s[WS].storage_align(ic, WS_align - 1, WS_align) + t, ti = s[WS].split(t, factor=vector_width) + t, tx = s[WS].split(t, factor=warp_size) + t, ty = s[WS].split(t, factor=block_row_warps) + _, tz = s[WS].split(t, factor=block_col_warps) + s[WS].bind(ty, thread_y) + s[WS].bind(tz, thread_z) + 
s[WS].bind(tx, thread_x) + s[WS].vectorize(ti) + + shape = (wmma_m, wmma_n, wmma_k) + + # tensorize the wmma process + AS_shape = (wmma_m, 1, 1, wmma_k) + AL_shape = (wmma_m, 1, 1, wmma_k) + WS_shape = (wmma_k, wmma_n) + WL_shape = (wmma_k, wmma_n) + CL_shape = (wmma_m, 1, 1, wmma_n) + CS_shape = (wmma_m, 1, 1, wmma_n) + + AL_gemm = te.placeholder(AL_shape, name='A', dtype=in_dtype) + WL_gemm = te.placeholder(WL_shape, name='B', dtype=in_dtype) + k_gemm = te.reduce_axis((0, wmma_k), name="k") + CL_compute = te.compute(CL_shape, lambda ii, t0, t1, jj: + te.sum(AL_gemm[ii, t0, t1, k_gemm].astype(out_dtype) * \ + WL_gemm[k_gemm, jj].astype(out_dtype), axis=k_gemm), + name='C') + + s[AF].tensorize(nn, intrin_wmma_load_matrix_A(AL_strides, AS_strides, shape, + "row_major", AS_shape, AL_shape, in_dtype)) + s[WF].tensorize(ii, intrin_wmma_load_matrix_W(WL_strides, WS_strides, shape, + "row_major", WS_shape, WL_shape, in_dtype)) + s[OL].tensorize(nnc, intrin_wmma_store_matrix(CS_strides, CL_strides, + shape, out_dtype, CL_shape, CS_shape)) + s[ConvF].tensorize(nnf, intrin_wmma_gemm(AL_gemm, WL_gemm, CL_compute, AL_strides, + WL_strides, CL_strides, shape)) + + N, OH, OW, CO = get_const_tuple(output.shape) + KH, KW, CI, _ = get_const_tuple(kernel.shape) + cfg.add_flop(2 * N * OH * OW * CO * CI * KH * KW) + + +@autotvm.register_topi_compute("conv2d_nhwc_tensorcore.cuda") +def conv2d_nhwc_tensorcore(cfg, data, kernel, strides, padding, dilation, out_dtype): + """Compute conv2d with tensorcore for NCHW layout""" + return nhwc_tensorcore_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype) + + +@autotvm.register_topi_schedule("conv2d_nhwc_tensorcore.cuda") +def schedule_conv2d_nhwc_tensorcore(cfg, outs): + """TOPI schedule callback""" + s = te.create_schedule([x.op for x in outs]) + + def _callback(op): + if 'conv2d_nhwc_tensorcore' in op.tag: + schedule_nhwc_tensorcore_cuda(cfg, s, op.output(0)) + + traverse_inline(s, outs[0].op, _callback) + return s diff --git a/topi/python/topi/cuda/dense_tensorcore.py b/topi/python/topi/cuda/dense_tensorcore.py new file mode 100644 index 000000000000..3546847bd268 --- /dev/null +++ b/topi/python/topi/cuda/dense_tensorcore.py @@ -0,0 +1,252 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, too-many-locals, too-many-statements, unused-argument +"""Compute and Schedule definition for dense tensorcore with cuda backend""" +from __future__ import absolute_import as _abs +import tvm +from tvm import te +import tvm.autotvm as autotvm +from .. 
import tag +from ..util import traverse_inline, get_const_tuple +from .tensor_intrin import intrin_wmma_load_matrix_A, \ + intrin_wmma_load_matrix_W, intrin_wmma_store_matrix, intrin_wmma_gemm + + +@autotvm.register_topi_compute("dense_tensorcore.cuda") +def dense_tensorcore(cfg, data, weight, bias=None, out_dtype=None): + """Dense tensorcore operator on CUDA""" + matmul = dense_tensorcore_cuda(data, weight, bias, out_dtype) + return matmul + + +@autotvm.register_topi_schedule("dense_tensorcore.cuda") +def schedule_dense_tensorcore(cfg, outs): + """Schedule dense operator using Tensorcore""" + outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs + s = te.create_schedule([x.op for x in outs]) + + def _callback(op): + if op.tag == 'dense_tensorcore': + _schedule_dense_tensorcore(cfg, s, op.output(0)) + traverse_inline(s, outs[0].op, _callback) + return s + + +def dense_tensorcore_cuda(data, weight, bias=None, out_dtype=None): + """Dense tensorcore operator on CUDA""" + assert len(data.shape) == 2 and len(weight.shape) == 2, \ + "only support 2-dim dense" + if bias is not None: + assert len(bias.shape) == 1 + if out_dtype is None: + out_dtype = data.dtype + batch, in_dim = get_const_tuple(data.shape) + out_dim, _ = get_const_tuple(weight.shape) + assert ((batch % 8 == 0 and in_dim % 16 == 0 and out_dim % 32 == 0) or \ + (batch % 16 == 0 and in_dim % 16 == 0 and out_dim % 16 == 0) or \ + (batch % 32 == 0 and in_dim % 16 == 0 and out_dim % 8 == 0)), \ + "The shape of (batch, in_dim, out_dim) "\ + "must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32) for now" + k = te.reduce_axis((0, in_dim), name='k') + data_16 = te.compute((batch, in_dim), lambda b, i: data[b, i].astype('float16')) + weight_16 = te.compute((out_dim, in_dim), lambda o, i: weight[o, i].astype('float16')) + matmul = te.compute((batch, out_dim), \ + lambda i, j: te.sum(data_16[i, k].astype(out_dtype) * \ + weight_16[j, k].astype(out_dtype), axis=k), \ + name='T_dense', tag='dense_tensorcore') + if bias is not None: + matmul = te.compute((batch, out_dim), \ + lambda i, j: matmul[i, j] + bias[j].astype(out_dtype), \ + tag=tag.BROADCAST) + return matmul + + +def _schedule_dense_tensorcore(cfg, s, C): + """Schedule dense operator using Tensorcore""" + A, B = s[C].op.input_tensors + batch, out_dim = get_const_tuple(C.shape) + out_dtype = C.dtype + s[A].compute_inline() + s[B].compute_inline() + + # Explicit memory access + AS = s.cache_read(A, 'shared', [C]) + BS = s.cache_read(B, 'shared', [C]) + AF = s.cache_read(AS, 'wmma.matrix_a', [C]) + BF = s.cache_read(BS, 'wmma.matrix_b', [C]) + CF = s.cache_write(C, 'wmma.accumulator') + CS = s.cache_read(CF, 'shared', [C]) + + # fallback support + target = tvm.target.Target.current() + if cfg.is_fallback: + ref_log = autotvm.tophub.load_reference_log( + target.target_name, target.model, 'dense_tensorcore.cuda') + cfg.fallback_with_reference_log(ref_log) + + # Deal with op fusion, such as bias and relu + if C.op not in s.outputs: + s[C].compute_inline() + C = s.outputs[0].output(0) + + # create tuning space + cfg.define_knob("block_row_warps", [1, 2, 4]) + cfg.define_knob("block_col_warps", [1, 2, 4]) + cfg.define_knob("warp_row_tiles", [1, 2, 4]) + cfg.define_knob("warp_col_tiles", [1, 2, 4]) + cfg.define_knob("chunk", [1, 2, 4, 8]) + cfg.define_knob("offset", [0, 8]) + cfg.define_knob("offsetCS", [0, 8]) + cfg.define_knob("vec", [1, 2, 4, 8]) + + #Ensure that the default parameters are applicable when autotvm is not in use + if (batch % 32 == 0 and out_dim % 8 == 0): + 
cfg.define_knob("wmma_m", [32, 16, 8]) + elif (batch%16 == 0 and out_dim % 16 == 0): + cfg.define_knob("wmma_m", [16, 8, 32]) + elif (batch % 8 == 0 and out_dim % 32 == 0): + cfg.define_knob("wmma_m", [8, 16, 32]) + + warp_size = 32 + wmma_k = 16 + block_row_warps = cfg["block_row_warps"].val + block_col_warps = cfg["block_col_warps"].val + warp_row_tiles = cfg["warp_row_tiles"].val + warp_col_tiles = cfg["warp_col_tiles"].val + chunk = cfg["chunk"].val + offset = cfg["offset"].val + offsetCS = cfg["offsetCS"].val + wmma_m = cfg["wmma_m"].val + vec = cfg["vec"].val + + if wmma_m == 16: + wmma_n = 16 + elif wmma_m == 8: + wmma_n = 32 + elif wmma_m == 32: + wmma_n = 8 + + #Define the stride of intrin functions + AS_align = chunk * wmma_k + offset + BS_align = chunk * wmma_k + offset + CS_align = warp_col_tiles * block_col_warps * wmma_n + offsetCS + AS_stride = [AS_align, 1] + BS_stride = [BS_align, 1] + AF_stride = [wmma_k, 1] + BF_stride = [wmma_k, 1] + CF_stride = [warp_col_tiles * wmma_n, 1] + CS_stride = [CS_align, 1] + + block_x = te.thread_axis('blockIdx.x') + block_y = te.thread_axis('blockIdx.y') + thread_x = te.thread_axis('threadIdx.x') + thread_y = te.thread_axis('threadIdx.y') + thread_z = te.thread_axis('threadIdx.z') + + #Schedule for dense computation + block_factor_b = wmma_m * warp_row_tiles * block_row_warps + block_factor_o = wmma_n * warp_col_tiles * block_col_warps + b, o = C.op.axis + block_i, bc = s[C].split(b, factor=block_factor_b) + block_j, oc = s[C].split(o, factor=block_factor_o) + s[C].reorder(block_i, block_j, bc, oc) + t = s[C].fuse(bc, oc) + t, vi = s[C].split(t, factor=vec) + t, tx = s[C].split(t, factor=warp_size) + t, ty = s[C].split(t, factor=block_row_warps) + t, tz = s[C].split(t, factor=block_col_warps) + s[C].bind(block_i, block_x) + s[C].bind(block_j, block_y) + s[C].bind(tz, thread_z) + s[C].bind(ty, thread_y) + s[C].bind(tx, thread_x) + s[C].vectorize(vi) + + #Schedule for wmma store + s[CS].compute_at(s[C], block_j) + bb, oo = CS.op.axis + s[CS].storage_align(bb, CS_align - 1, CS_align) + bb, bbi = s[CS].split(bb, factor=wmma_m) + oo, ooi = s[CS].split(oo, factor=wmma_n) + bb, bbii = s[CS].split(bb, factor=warp_row_tiles) + oo, ooii = s[CS].split(oo, factor=warp_col_tiles) + s[CS].reorder(bb, oo, bbii, ooii, bbi, ooi) + + #Schedule for wmma computation + s[CF].compute_at(s[CS], oo) + warp_i, warp_j = CF.op.axis + warp_i, _ii = s[CF].split(warp_i, factor=wmma_m) + warp_j, _jj = s[CF].split(warp_j, factor=wmma_n) + k, = CF.op.reduce_axis + k, _k = s[CF].split(k, factor=wmma_k) + ko, ki = s[CF].split(k, factor=chunk) + s[CF].reorder(ko, ki, warp_i, warp_j, _ii, _jj, _k) + + #Schedule for wmma_matrix_a load + s[AF].compute_at(s[CF], ki) + b, i = AF.op.axis + b, b_ii = s[AF].split(b, factor=wmma_m) + i, i_jj = s[AF].split(i, factor=wmma_k) + s[AF].reorder(b, i, b_ii, i_jj) + + #Schedule for wmma_matrix_b load + s[BF].compute_at(s[CF], ki) + o, i = BF.op.axis + o, o_ii = s[BF].split(o, factor=wmma_n) + i, i_ii = s[BF].split(i, factor=wmma_k) + s[BF].reorder(o, i, o_ii, i_ii) + + #Schedule for A's(B's) shared memory load + def shared_shedule(stage, strides): + s[stage].compute_at(s[CF], ko) + xo, yo = stage.op.axis + s[stage].storage_align(xo, strides - 1, strides) + t = s[stage].fuse(xo, yo) + t, vi = s[stage].split(t, factor=vec) + t, tx = s[stage].split(t, factor=warp_size) + t, ty = s[stage].split(t, factor=block_row_warps) + _, tz = s[stage].split(t, factor=block_col_warps) + s[stage].bind(ty, thread_y) + s[stage].bind(tz, thread_z) + 
s[stage].bind(tx, thread_x) + s[stage].vectorize(vi) + + shared_shedule(AS, AS_align) + shared_shedule(BS, BS_align) + + shape = (wmma_m, wmma_n, wmma_k) + in_dtype = 'float16' + AL_gemm = te.placeholder((wmma_m, wmma_k), name='AL_gemm', dtype=in_dtype) + BL_gemm = te.placeholder((wmma_n, wmma_k), name='BL_gemm', dtype=in_dtype) + k_gemm = te.reduce_axis((0, wmma_k), name='k_gemm') + CL_compute = te.compute((wmma_m, wmma_n), lambda ii, jj: + te.sum(AL_gemm[ii, k_gemm].astype(out_dtype) *\ + BL_gemm[jj, k_gemm].astype(out_dtype),\ + axis=k_gemm), name='CL_compute') + + #lower the computation loops down to TensorCore hardware intrinsics + #by mapping the dense tensorcore to tensor intrinsics + s[AF].tensorize(b_ii, intrin_wmma_load_matrix_A( \ + AF_stride, AS_stride, shape, "row_major",\ + (wmma_m, wmma_k), (wmma_m, wmma_k), 'float16')) + s[BF].tensorize(o_ii, intrin_wmma_load_matrix_W( \ + BF_stride, BS_stride, shape, "col_major",\ + (wmma_n, wmma_k), (wmma_n, wmma_k), 'float16')) + s[CF].tensorize(_ii, intrin_wmma_gemm( \ + AL_gemm, BL_gemm, CL_compute, AF_stride, BF_stride, CF_stride, shape)) + s[CS].tensorize(bbi, intrin_wmma_store_matrix( \ + CS_stride, CF_stride, shape, out_dtype, (wmma_m, wmma_n), (wmma_m, wmma_n))) diff --git a/topi/python/topi/cuda/tensor_intrin.py b/topi/python/topi/cuda/tensor_intrin.py index 468e2cd21fa8..f8fce342e212 100644 --- a/topi/python/topi/cuda/tensor_intrin.py +++ b/topi/python/topi/cuda/tensor_intrin.py @@ -14,8 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# pylint: disable=invalid-name, unnecessary-lambda, too-many-arguments """Tensor intrinsics on CUDA.""" -#pylint: disable=invalid-name import tvm from tvm import te @@ -77,3 +77,148 @@ def _instr(index): scope=scopes[t]) for t in [x, y, z]} return te.decl_tensor_intrin(z.op, _intrin_func, binds=binds) + + +def intrin_wmma_load_matrix_A(strides_dst, strides_from, shape, layout, A_shape, C_shape, in_dtype): + """Intrin function for loading data from shared memory to wmma.matrix_a""" + wmma_m, wmma_n, wmma_k = shape + + A = te.placeholder(A_shape, name='A', dtype=in_dtype) + BA = tvm.tir.decl_buffer(A.shape, A.dtype, + scope='shared', strides=strides_from, + data_alignment=32, offset_factor=8) + C = te.compute(C_shape, lambda *i: A(*i), name='C') + BC = tvm.tir.decl_buffer(C.shape, C.dtype, + scope="wmma.matrix_a", strides=strides_dst, + data_alignment=32, offset_factor=8) + + def intrin_func(ins, outs): + ib = tvm.tir.ir_builder.create() + + BA = ins[0] + BC = outs[0] + row = wmma_m * wmma_k + warp_index = BC.elem_offset // row + BC.elem_offset % row // wmma_k + ib.emit(tvm.tir.call_intrin('handle', 'tvm_load_matrix_sync', + BC.data, wmma_m, wmma_n, wmma_k, warp_index, + BA.access_ptr('r'), strides_from[0], layout)) + return ib.get() + + return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC}) + + +def intrin_wmma_load_matrix_W(strides_dst, strides_from, shape, layout, A_shape, C_shape, in_dtype): + """Intrin function for loading data from shared memory to wmma.matrix_b""" + wmma_m, wmma_n, wmma_k = shape + + A = te.placeholder(A_shape, name='A', dtype=in_dtype) + BA = tvm.tir.decl_buffer(A.shape, A.dtype, + scope='shared', strides=strides_from, + data_alignment=32, offset_factor=8) + C = te.compute(C_shape, lambda *i: A(*i), name='C') + BC = tvm.tir.decl_buffer(C.shape, C.dtype, + scope="wmma.matrix_b", strides=strides_dst, + data_alignment=32, offset_factor=8) + + def intrin_func(ins, 
outs): + ib = tvm.tir.ir_builder.create() + + BA = ins[0] + BC = outs[0] + row = wmma_n * wmma_k + warp_index = BC.elem_offset // row + BC.elem_offset % row // wmma_n + ib.emit(tvm.tir.call_intrin('handle', 'tvm_load_matrix_sync', + BC.data, wmma_m, wmma_n, wmma_k, warp_index, + BA.access_ptr('r'), strides_from[0], layout)) + return ib.get() + + return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC}) + + +def intrin_wmma_store_matrix(strides_dst, strides_from, shape, out_dtype, A_shape, C_shape): + """Intrin function for storing the results from wmma.accumulator to shared""" + wmma_m, wmma_n, wmma_k = shape + A = te.placeholder(A_shape, name='A', dtype=out_dtype) + BA = tvm.tir.decl_buffer(A.shape, A.dtype, + scope='wmma.accumulator', + strides=strides_from, data_alignment=32, + offset_factor=8) + C = te.compute(C_shape, lambda *i: A(*i), name='C') + BC = tvm.tir.decl_buffer(C.shape, C.dtype, + scope='shared', strides=strides_dst, + data_alignment=32, offset_factor=8) + + def intrin_func(ins, outs): + ib = tvm.tir.ir_builder.create() + + BA = ins[0] + BC = outs[0] + row = wmma_m * wmma_n + warp_index = BA.elem_offset // row + BA.elem_offset % row // wmma_n + ib.emit(tvm.tir.call_intrin('handle', 'tvm_store_matrix_sync', + BA.data, wmma_m, wmma_n, wmma_k, warp_index, + BC.access_ptr('w'), strides_dst[0], 'row_major')) + return ib.get() + + return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC}) + + +def intrin_wmma_gemm(AL_gemm, WL_gemm, CL_compute, strides_A, + strides_W, strides_Conv, shape): + """Intrin for wmma fill_fragment and mma_sync + + Parameters + ---------- + AL_gemm : tvm.te.placeholder + wmma matrix A + WL_gemm : tvm.te.placeholder + wmma matrix B + CL_compute : tvm.te.compute + The definition of wmma gemm + """ + wmma_m, wmma_n, wmma_k = shape + A = AL_gemm + B = WL_gemm + C = CL_compute + + BA = tvm.tir.decl_buffer(A.shape, A.dtype, name='BA', + scope='wmma.matrix_a', data_alignment=32, + offset_factor=8, strides=strides_A) + BB = tvm.tir.decl_buffer(B.shape, B.dtype, name='BB', + scope='wmma.matrix_b', data_alignment=32, + offset_factor=8, strides=strides_W) + BC = tvm.tir.decl_buffer(C.shape, C.dtype, name='BC', + scope='wmma.accumulator', data_alignment=32, + offset_factor=8, strides=strides_Conv) + + def intrin_func(ins, outs): + BA, BB = ins + BC, = outs + + def warp_idnex(offset, row, col): + row = row * col + return offset // row + offset % row // col + + warp_index_A = warp_idnex(BA.elem_offset, wmma_m, wmma_k) + warp_index_B = warp_idnex(BB.elem_offset, wmma_k, wmma_n) + warp_index_C = warp_idnex(BC.elem_offset, wmma_m, wmma_n) + + def init(): + ib = tvm.tir.ir_builder.create() + ib.emit( + tvm.tir.call_intrin('handle', 'tvm_fill_fragment', BC.data, wmma_m, wmma_n, wmma_k, + warp_index_C, 0.0)) + return ib.get() + + def update(): + ib = tvm.tir.ir_builder.create() + ib.emit(tvm.tir.call_intrin('handle', 'tvm_mma_sync', + BC.data, warp_index_C, + BA.data, warp_index_A, + BB.data, warp_index_B, + BC.data, warp_index_C)) + return ib.get() + + return update(), init(), update() + + return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, B: BB, C: BC}) diff --git a/topi/tests/python/test_topi_conv2d_nhwc.py b/topi/tests/python/test_topi_conv2d_nhwc.py index 814fd45e0636..e027d5a7ccd9 100644 --- a/topi/tests/python/test_topi_conv2d_nhwc.py +++ b/topi/tests/python/test_topi_conv2d_nhwc.py @@ -27,7 +27,8 @@ _conv2d_nhwc_implement = { - "generic": (topi.nn.conv2d_nhwc, topi.generic.schedule_conv2d_nhwc), + "llvm": (topi.nn.conv2d_nhwc, 
topi.generic.schedule_conv2d_nhwc), + "cuda": (topi.cuda.conv2d_nhwc, topi.cuda.schedule_conv2d_nhwc), "cpu": (topi.nn.conv2d_nhwc, topi.x86.schedule_conv2d_nhwc), "arm_cpu": (topi.arm_cpu.conv2d_nhwc_spatial_pack, topi.arm_cpu.schedule_conv2d_nhwc_spatial_pack), @@ -60,9 +61,9 @@ def check_device(device): return print("Running on target: %s" % device) with tvm.target.create(device): - B = topi.nn.conv2d(A, W, (stride, stride), padding, - (dilation, dilation), layout='NHWC', out_dtype=dtype) - s = topi.generic.schedule_conv2d_nhwc([B]) + fcompute, fschedule = topi.testing.dispatch(device, _conv2d_nhwc_implement) + B = fcompute(A, W, stride, padding, dilation, dtype) + s = fschedule([B]) ctx = tvm.context(device, 0) a = tvm.nd.array(a_np, ctx) w = tvm.nd.array(w_np, ctx) @@ -71,8 +72,7 @@ def check_device(device): func(a, w, b) tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) - # TODO(@alexgl-github): add cuda back after fix conv2d_nhwc for cuda - for device in ['llvm']: + for device in ['llvm', 'cuda']: check_device(device) diff --git a/topi/tests/python/test_topi_conv2d_nhwc_tensorcore.py b/topi/tests/python/test_topi_conv2d_nhwc_tensorcore.py new file mode 100644 index 000000000000..cc327849caea --- /dev/null +++ b/topi/tests/python/test_topi_conv2d_nhwc_tensorcore.py @@ -0,0 +1,126 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=invalid-name, too-many-locals, too-many-arguments +"""Example code to do convolution.""" + +import numpy as np +import tvm +import topi +import topi.testing +from tvm import te +from tvm.contrib.pickle_memoize import memoize +from tvm.contrib import nvcc +from topi.nn.util import get_pad_tuple +from topi.util import get_const_tuple + + +_conv2d_nhwc_tensorcore_implement = { + "cuda": (topi.cuda.conv2d_nhwc_tensorcore, topi.cuda.schedule_conv2d_nhwc_tensorcore) +} + + +def verify_conv2d_nhwc(batch, in_channel, in_size, num_filter, kernel, stride, + padding, dilation=1, add_bias=False, add_relu=False, devices='cuda'): + """Test the conv2d with tensorcore for nhwc layout""" + pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel)) + padding_sum = pad_top + pad_left + pad_bottom + pad_right + print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % ( + batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)) + + in_height = in_width = in_size + + A = te.placeholder((batch, in_height, in_width, in_channel), name='A') + W = te.placeholder((kernel, kernel, in_channel, num_filter), name='W') + bias = te.placeholder((1, 1, 1, num_filter), name='bias') + + a_shape = get_const_tuple(A.shape) + w_shape = get_const_tuple(W.shape) + bias_shape = get_const_tuple(bias.shape) + dtype = A.dtype + + @memoize("topi.tests.test_topi_conv2d_nhwc.verify_conv2d_nhwc") + def get_ref_data(): + a_np = np.random.uniform(size=a_shape).astype(dtype) + w_np = np.random.uniform(size=w_shape).astype(dtype) + b_np = np.random.uniform(size=bias_shape).astype(dtype) + dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation)) + c_np = topi.testing.conv2d_nhwc_python(a_np, dw_np, stride, padding) + if add_bias: + b_np = np.random.uniform(size=bias_shape).astype(dtype) + c_np += b_np + if add_relu: + c_np = np.maximum(c_np, 0) + return a_np, w_np, b_np, c_np + + a_np, w_np, b_np, c_np = get_ref_data() + + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + if not nvcc.have_tensorcore(ctx.compute_version): + print("skip because gpu does not support Tensor Cores") + return + print("Running on target: %s" % device) + with tvm.target.create(device): + fcompute, fschedule = topi.testing.dispatch(device, _conv2d_nhwc_tensorcore_implement) + C = fcompute(A, W, stride, padding, dilation, 'float32') + if add_bias: + C = topi.add(C, bias) + if add_relu: + C = topi.nn.relu(C) + s = fschedule([C]) + + a = tvm.nd.array(a_np, ctx) + w = tvm.nd.array(w_np, ctx) + b = tvm.nd.array(b_np, ctx) + c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx) + if add_bias: + func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % ( + batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)) + func(a, w, b, c) + else: + func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % ( + batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)) + func(a, w, c) + + rtol = 1e-3 + tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=rtol) + + check_device(devices) + + +def test_conv2d_nhwc_tensorcore(): + """Test the conv2d with tensorcore for nhwc layout""" + verify_conv2d_nhwc(16, 16, 14, 16, 3, 1, 1) + verify_conv2d_nhwc(16, 128, 7, 128, 7, 1, 3) + verify_conv2d_nhwc(16, 160, 7, 160, 7, 1, 3) + + verify_conv2d_nhwc(32, 64, 14, 64, 3, 1, 1, add_bias=True) + verify_conv2d_nhwc(32, 64, 14, 64, 3, 1, 
1, add_relu=True)
+    verify_conv2d_nhwc(32, 64, 14, 64, 3, 1, 1, add_relu=True, add_bias=True)
+
+    verify_conv2d_nhwc(16, 64, 17, 64, 7, 1, (3, 3, 2, 2))
+    verify_conv2d_nhwc(16, 64, 17, 64, 7, 1, "SAME")
+    verify_conv2d_nhwc(16, 48, 35, 48, 5, 1, "VALID")
+    verify_conv2d_nhwc(16, 48, 56, 48, 3, 1, (1, 1, 1, 1))
+    verify_conv2d_nhwc(16, 64, 28, 64, 3, 1, (1, 1, 1, 1))
+
+
+if __name__ == "__main__":
+    test_conv2d_nhwc_tensorcore()
diff --git a/topi/tests/python/test_topi_dense_tensorcore.py b/topi/tests/python/test_topi_dense_tensorcore.py
new file mode 100644
index 000000000000..f74f31e740bc
--- /dev/null
+++ b/topi/tests/python/test_topi_dense_tensorcore.py
@@ -0,0 +1,91 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, too-many-locals, too-many-statements, unused-argument
+"""Test code for dense tensorcore operator"""
+import numpy as np
+import tvm
+import topi
+import topi.testing
+from topi.util import get_const_tuple
+from tvm import te
+from tvm.contrib.pickle_memoize import memoize
+from tvm.contrib import nvcc
+
+
+_dense_implement = {
+    "gpu": [(topi.cuda.dense_tensorcore, topi.cuda.schedule_dense_tensorcore)]
+}
+
+def verify_dense(batch, in_dim, out_dim, use_bias=True):
+    """Dense tensorcore verify function"""
+    A = te.placeholder((batch, in_dim), name='A')
+    B = te.placeholder((out_dim, in_dim), name='B')
+    C = te.placeholder((out_dim,), name='C')
+    dtype = A.dtype
+
+    # use memoize to pickle the test data for next time use
+    @memoize("topi.tests.test_topi_dense_tensorcore")
+    def get_ref_data():
+        a_np = np.random.uniform(size=(batch, in_dim)).astype(dtype)
+        b_np = np.random.uniform(size=(out_dim, in_dim)).astype(dtype)
+        c_np = np.random.uniform(size=(out_dim,)).astype(dtype)
+        if use_bias:
+            d_np = np.maximum(np.dot(a_np, b_np.T) + c_np, 0.0)
+        else:
+            d_np = np.maximum(np.dot(a_np, b_np.T), 0.0)
+        return (a_np, b_np, c_np, d_np)
+    # get the test data
+    a_np, b_np, c_np, d_np = get_ref_data()
+
+    def check_device(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
+            print("Skip because %s is not enabled" % device)
+            return
+        if not nvcc.have_tensorcore(ctx.compute_version):
+            print("skip because gpu does not support Tensor Cores")
+            return
+        print("Running on target: %s" % device)
+        for fcompute, fschedule in topi.testing.dispatch(device, _dense_implement):
+            with tvm.target.create(device):
+                D = fcompute(A, B, C if use_bias else None)
+                D = topi.nn.relu(D)
+                s = fschedule([D])
+            a = tvm.nd.array(a_np, ctx)
+            b = tvm.nd.array(b_np, ctx)
+            c = tvm.nd.array(c_np, ctx)
+            d = tvm.nd.array(np.zeros(get_const_tuple(D.shape), dtype=dtype), ctx)
+            f = tvm.build(s, [A, B, C, D], device, name="dense")
+            f(a, b, c, d)
+            tvm.testing.assert_allclose(d.asnumpy(), d_np, rtol=1e-3)
+
+
+    for device in ['cuda']:
+        check_device(device)
+
+
+def test_dense_tensorcore():
+    """Test cases"""
+    verify_dense(8, 16, 32, use_bias=True)
+    verify_dense(16, 32, 16, use_bias=True)
+    verify_dense(256, 1024, 1024, use_bias=True)
+    verify_dense(1000, 1024, 1024, use_bias=False)
+    verify_dense(256, 2048, 1000, use_bias=False)
+
+
+if __name__ == "__main__":
+    test_dense_tensorcore()
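
Usage sketch (not part of the diff above): a minimal, hypothetical example of driving the new dense Tensor Core op directly through TOPI, mirroring verify_dense from test_topi_dense_tensorcore.py. It assumes a CUDA GPU with Tensor Core support and shapes that are multiples of 16, as required by the assertion in dense_tensorcore_cuda; the conv2d path is exercised the same way via topi.cuda.conv2d_nhwc_tensorcore and topi.cuda.schedule_conv2d_nhwc_tensorcore, as shown in test_topi_conv2d_nhwc_tensorcore.py.

import numpy as np
import tvm
from tvm import te
from tvm.contrib import nvcc
import topi

# Shapes chosen as multiples of 16 to satisfy the tensorcore shape assertion.
batch, in_dim, out_dim = 16, 64, 64
A = te.placeholder((batch, in_dim), name="A")
B = te.placeholder((out_dim, in_dim), name="B")

ctx = tvm.gpu(0)
if ctx.exist and nvcc.have_tensorcore(ctx.compute_version):
    with tvm.target.create("cuda"):
        # Compute is declared in float32; inputs are cast to float16 internally.
        D = topi.cuda.dense_tensorcore(A, B, None)
        s = topi.cuda.schedule_dense_tensorcore([D])
    func = tvm.build(s, [A, B, D], "cuda")
    a_np = np.random.uniform(size=(batch, in_dim)).astype("float32")
    b_np = np.random.uniform(size=(out_dim, in_dim)).astype("float32")
    a = tvm.nd.array(a_np, ctx)
    b = tvm.nd.array(b_np, ctx)
    d = tvm.nd.array(np.zeros((batch, out_dim), dtype="float32"), ctx)
    func(a, b, d)
    # Loose tolerance for the float16 multiply path, matching the rtol used by the tests above.
    tvm.testing.assert_allclose(d.asnumpy(), np.dot(a_np, b_np.T), rtol=1e-3)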