[TOPI][Tensor Core] Conv2d and Dense ops support on Tensor Core (apache#5099)

* [TOPI][Tensor Core] Optimization of CNNs on Tensor Core apache#6004

* update conv2d test

* # pylint: dense_tensorcore.py

* modify

* modify conv2d

* modify the unclear comment, add shape assertion in conv2d compute, combine general gemm intrinsic

* add shape assertion in conv2d compute, combine general gemm intrinsic

Co-authored-by: libaihong <libaihong@inspur.com>
Co-authored-by: libaihong <61525430+libaihong@users.noreply.github.com>
3 people authored and zhiics committed Apr 17, 2020
1 parent 85b2e09 commit 6c02e26
Showing 12 changed files with 1,172 additions and 47 deletions.
39 changes: 31 additions & 8 deletions python/tvm/relay/op/strategy/cuda.py
@@ -17,7 +17,9 @@
"""Definition of CUDA/GPU operator strategy."""
# pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import
import topi
import tvm
from tvm.te import SpecializedCondition
from tvm.contrib import nvcc
from .generic import *
from .. import op as _op
from .... import get_global_func
@@ -112,13 +114,23 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
wrap_compute_conv2d(topi.cuda.conv2d_hwcn),
wrap_topi_schedule(topi.cuda.schedule_conv2d_hwcn),
name="conv2d_hwcn.cuda")
# TODO(@alexgl-github): Re-enable this after fix the conv2d_nhwc for cuda
# elif layout == "NHWC":
# assert kernel_layout == "HWIO"
# strategy.add_implementation(
# wrap_compute_conv2d(topi.cuda.conv2d_nhwc),
# wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc),
# name="conv2d_nhwc.cuda")
elif layout == "NHWC":
assert kernel_layout == "HWIO"
strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.conv2d_nhwc),
wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc),
name="conv2d_nhwc.cuda")
N, _, _, _ = get_const_tuple(data.shape)
_, _, CI, CO = get_const_tuple(kernel.shape)
if nvcc.have_tensorcore(tvm.gpu(0).compute_version):
if (N % 16 == 0 and CI % 16 == 0 and CO % 16 == 0) or \
(N % 8 == 0 and CI % 16 == 0 and CO % 32 == 0) or \
(N % 32 == 0 and CI % 16 == 0 and CO % 8 == 0):
strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.conv2d_nhwc_tensorcore),
wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc_tensorcore),
name="conv2d_nhwc_tensorcore.cuda",
plevel=20)
elif layout == "NCHW4c" and data.dtype in ["int8", "uint8"]:
assert kernel_layout == "OIHW4o4i"
strategy.add_implementation(
@@ -279,6 +291,9 @@ def conv1d_transpose_strategy_cuda(attrs, inputs, out_type, target):
def dense_strategy_cuda(attrs, inputs, out_type, target):
"""dense cuda strategy"""
strategy = _op.OpStrategy()
data, weights = inputs
b, i = get_const_tuple(data.shape)
o, _ = get_const_tuple(weights.shape)
if out_type.dtype == "int8":
strategy.add_implementation(
wrap_compute_dense(topi.cuda.dense_int8),
@@ -289,13 +304,21 @@ def dense_strategy_cuda(attrs, inputs, out_type, target):
wrap_compute_dense(topi.cuda.dense_small_batch),
wrap_topi_schedule(topi.cuda.schedule_dense_small_batch),
name="dense_small_batch.cuda")
b = inputs[0].shape[0]
with SpecializedCondition(b >= 32):
strategy.add_implementation(
wrap_compute_dense(topi.cuda.dense_large_batch),
wrap_topi_schedule(topi.cuda.schedule_dense_large_batch),
name="dense_large_batch.cuda",
plevel=5)
if nvcc.have_tensorcore(tvm.gpu(0).compute_version):
if(i % 16 == 0 and b % 16 == 0 and o % 16 == 0) \
or (i % 16 == 0 and b % 8 == 0 and o % 32 == 0) \
or (i % 16 == 0 and b % 32 == 0 and o % 8 == 0):
strategy.add_implementation(
wrap_compute_dense(topi.cuda.dense_tensorcore),
wrap_topi_schedule(topi.cuda.schedule_dense_tensorcore),
name="dense_tensorcore.cuda",
plevel=20)
if target.target_name == "cuda" and "cublas" in target.libs:
strategy.add_implementation(
wrap_compute_dense(topi.cuda.dense_cublas),
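The tensor core implementations registered above at plevel=20 are only added when the GEMM dimensions line up with a wmma fragment shape supported by the intrinsics: (m, n, k) = (16, 16, 16), (8, 32, 16) or (32, 8, 16), where the batch dimension maps to m, the output channels or units to n, and the input channels or reduction axis to k. A minimal standalone sketch of that eligibility check (a hypothetical helper, not part of this diff):

def tensorcore_shape_ok(batch, in_dim, out_dim):
    """Mirror the divisibility check used by the conv2d NHWC and dense
    tensor core strategies above; each clause corresponds to one wmma
    fragment shape: (16, 16, 16), (8, 32, 16) or (32, 8, 16)."""
    return ((batch % 16 == 0 and in_dim % 16 == 0 and out_dim % 16 == 0) or
            (batch % 8 == 0 and in_dim % 16 == 0 and out_dim % 32 == 0) or
            (batch % 32 == 0 and in_dim % 16 == 0 and out_dim % 8 == 0))

For conv2d the same test runs with N, CI and CO in place of batch, in_dim and out_dim, and in both strategies it is attempted only when nvcc.have_tensorcore reports a tensor-core-capable compute version for GPU 0.
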
59 changes: 45 additions & 14 deletions python/tvm/relay/testing/resnet.py
@@ -32,7 +32,10 @@ def residual_unit(data,
stride,
dim_match,
name,
bottle_neck=True):
bottle_neck=True,
data_layout="NCHW",
kernel_layout="IOHW"
):
"""Return ResNet Unit symbol for building ResNet
Parameters
@@ -67,42 +70,50 @@ def residual_unit(data,
kernel_size=(1, 1),
strides=stride,
padding=(0, 0),
name=name + '_conv1')
name=name + '_conv1',
data_layout=data_layout,
kernel_layout=kernel_layout)
bn2 = layers.batch_norm_infer(data=conv1, epsilon=2e-5, name=name + '_bn2')
act2 = relay.nn.relu(data=bn2)
conv2 = layers.conv2d(
data=act2, channels=int(num_filter*0.25), kernel_size=(3, 3),
strides=(1, 1), padding=(1, 1), name=name + '_conv2')
strides=(1, 1), padding=(1, 1), name=name + '_conv2',
data_layout=data_layout, kernel_layout=kernel_layout)
bn3 = layers.batch_norm_infer(data=conv2, epsilon=2e-5, name=name + '_bn3')
act3 = relay.nn.relu(data=bn3)
conv3 = layers.conv2d(
data=act3, channels=num_filter, kernel_size=(1, 1),
strides=(1, 1), padding=(0, 0), name=name + '_conv3')
strides=(1, 1), padding=(0, 0), name=name + '_conv3',
data_layout=data_layout, kernel_layout=kernel_layout)
if dim_match:
shortcut = data
else:
shortcut = layers.conv2d(
data=act1, channels=num_filter, kernel_size=(1, 1),
strides=stride, name=name+'_sc')
strides=stride, name=name+'_sc',
data_layout=data_layout, kernel_layout=kernel_layout)
return relay.add(conv3, shortcut)

bn1 = layers.batch_norm_infer(data=data, epsilon=2e-5, name=name + '_bn1')
act1 = relay.nn.relu(data=bn1)
conv1 = layers.conv2d(
data=act1, channels=num_filter, kernel_size=(3, 3),
strides=stride, padding=(1, 1), name=name + '_conv1')
strides=stride, padding=(1, 1), name=name + '_conv1',
data_layout=data_layout, kernel_layout=kernel_layout)
bn2 = layers.batch_norm_infer(data=conv1, epsilon=2e-5, name=name + '_bn2')
act2 = relay.nn.relu(data=bn2)
conv2 = layers.conv2d(
data=act2, channels=num_filter, kernel_size=(3, 3),
strides=(1, 1), padding=(1, 1), name=name + '_conv2')
strides=(1, 1), padding=(1, 1), name=name + '_conv2',
data_layout=data_layout, kernel_layout=kernel_layout)

if dim_match:
shortcut = data
else:
shortcut = layers.conv2d(
data=act1, channels=num_filter, kernel_size=(1, 1),
strides=stride, name=name+'_sc')
strides=stride, name=name+'_sc',
data_layout=data_layout, kernel_layout=kernel_layout)
return relay.add(conv2, shortcut)


@@ -112,6 +123,7 @@ def resnet(units,
num_classes,
data_shape,
bottle_neck=True,
layout="NCHW",
dtype="float32"):
"""Return ResNet Program.
@@ -135,9 +147,16 @@ def resnet(units,
bottle_neck : bool
Whether apply bottleneck transformation.
layout: str
The data layout for conv2d
dtype : str
The global data type.
"""

data_layout = layout
kernel_layout = "OIHW" if layout == "NCHW" else "HWIO"

num_unit = len(units)
assert num_unit == num_stages
data = relay.var("data", shape=data_shape, dtype=dtype)
@@ -146,27 +165,32 @@ def resnet(units,
if height <= 32: # such as cifar10
body = layers.conv2d(
data=data, channels=filter_list[0], kernel_size=(3, 3),
strides=(1, 1), padding=(1, 1), name="conv0")
strides=(1, 1), padding=(1, 1), name="conv0",
data_layout=data_layout, kernel_layout=kernel_layout)
else: # often expected to be 224 such as imagenet
body = layers.conv2d(
data=data, channels=filter_list[0], kernel_size=(7, 7),
strides=(2, 2), padding=(3, 3), name="conv0")
strides=(2, 2), padding=(3, 3), name="conv0",
data_layout=data_layout, kernel_layout=kernel_layout)
body = layers.batch_norm_infer(data=body, epsilon=2e-5, name='bn0')
body = relay.nn.relu(data=body)
body = relay.nn.max_pool2d(data=body, pool_size=(3, 3), strides=(2, 2), padding=(1, 1))
body = relay.nn.max_pool2d(data=body, pool_size=(3, 3), strides=(2, 2), padding=(1, 1),
layout=data_layout)

for i in range(num_stages):
body = residual_unit(
body, filter_list[i+1], (1 if i == 0 else 2, 1 if i == 0 else 2),
False, name='stage%d_unit%d' % (i + 1, 1), bottle_neck=bottle_neck)
False, name='stage%d_unit%d' % (i + 1, 1), bottle_neck=bottle_neck,
data_layout=data_layout, kernel_layout=kernel_layout)
for j in range(units[i]-1):
body = residual_unit(
body, filter_list[i+1], (1, 1), True,
name='stage%d_unit%d' % (i + 1, j + 2), bottle_neck=bottle_neck)
name='stage%d_unit%d' % (i + 1, j + 2), bottle_neck=bottle_neck,
data_layout=data_layout, kernel_layout=kernel_layout)
bn1 = layers.batch_norm_infer(data=body, epsilon=2e-5, name='bn1')
relu1 = relay.nn.relu(data=bn1)
# Although kernel is not used here when global_pool=True, we should put one
pool1 = relay.nn.global_avg_pool2d(data=relu1)
pool1 = relay.nn.global_avg_pool2d(data=relu1, layout=data_layout)
flat = relay.nn.batch_flatten(data=pool1)
fc1 = layers.dense_add_bias(data=flat, units=num_classes, name='fc1')
net = relay.nn.softmax(data=fc1)
@@ -177,6 +201,7 @@ def get_net(batch_size,
num_classes,
num_layers=50,
image_shape=(3, 224, 224),
layout="NCHW",
dtype="float32",
**kwargs):
"""
@@ -229,13 +254,15 @@ def get_net(batch_size,
num_classes=num_classes,
data_shape=data_shape,
bottle_neck=bottle_neck,
layout=layout,
dtype=dtype)


def get_workload(batch_size=1,
num_classes=1000,
num_layers=18,
image_shape=(3, 224, 224),
layout="NCHW",
dtype="float32",
**kwargs):
"""Get benchmark workload for resnet
@@ -254,6 +281,9 @@ def get_workload(batch_size=1,
image_shape : tuple, optional
The input image shape
layout: str
The data layout for conv2d
dtype : str, optional
The data type
@@ -273,5 +303,6 @@ def get_workload(batch_size=1,
num_layers=num_layers,
image_shape=image_shape,
dtype=dtype,
layout=layout,
**kwargs)
return create_workload(net)
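
With the layout argument now threaded through residual_unit, resnet, get_net and get_workload, the ResNet test network can be generated directly in NHWC for the tensor core schedules. A usage sketch, not taken from this diff; the HWC image shape and float16 dtype are assumptions chosen with the tensor core path in mind:

from tvm.relay.testing import resnet

# image_shape must agree with the requested layout: (224, 224, 3) is the
# HWC counterpart of the NCHW default (3, 224, 224). get_workload returns
# the Relay module and its randomly initialized parameters.
mod, params = resnet.get_workload(batch_size=16,
                                  num_classes=1000,
                                  num_layers=50,
                                  image_shape=(224, 224, 3),
                                  layout="NHWC",
                                  dtype="float16")
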
7 changes: 7 additions & 0 deletions src/driver/driver_api.cc
@@ -201,6 +201,7 @@ Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs,

func = tir::ThreadSync(func, "shared");
func = tir::ThreadSync(func, "warp");
func = tir::InferFragment(func);
func = tir::LowerThreadAllreduce(func, target->thread_warp_size);
auto fsplits = tir::SplitHostDevice(func);
fhost.push_back(fsplits[0]);
@@ -244,6 +245,12 @@ Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs,
<< "\n";
}

for (size_t i = 0; i < fdevice.size(); ++i) {
auto func = fdevice[i];
func = tir::LowerDeviceStorageAccessInfo(func);
fdevice.Set(i, func);
}

for (size_t i = 0; i < fhost.size(); ++i) {
auto func = fhost[i];
func = tir::BindDeviceType(func, target->device_type);
2 changes: 2 additions & 0 deletions topi/python/topi/cuda/__init__.py
@@ -43,3 +43,5 @@
from .nms import get_valid_counts, non_max_suppression
from .rcnn import *
from .sort import *
from .conv2d_nhwc_tensorcore import *
from .dense_tensorcore import *
35 changes: 17 additions & 18 deletions topi/python/topi/cuda/conv2d.py
@@ -24,6 +24,7 @@
from ..nn.util import get_pad_tuple
from ..util import get_const_tuple, traverse_inline
from .conv2d_direct import schedule_direct_cuda
from .conv2d_nhwc import schedule_conv2d_nhwc_direct


@autotvm.register_topi_compute("conv2d_nchw.cuda")
@@ -46,24 +47,22 @@ def _callback(op):
return s


# TODO(@alexgl-github): It's invalid to call schedule_direct_cuda for NHWC layout
# as it assumes the input layout to be NCHW. Please fix this.
# @autotvm.register_topi_compute("conv2d_nhwc.cuda")
# def conv2d_nhwc(cfg, data, kernel, strides, padding, dilation, out_dtype='float32'):
# return nn.conv2d_nhwc(data, kernel, strides, padding, dilation, out_dtype)
#
#
# @autotvm.register_topi_schedule("conv2d_nhwc.cuda")
# def schedule_conv2d_nhwc(cfg, outs):
# outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
# s = te.create_schedule([x.op for x in outs])
#
# def _callback(op):
# if op.tag == 'conv2d_nhwc':
# schedule_direct_cuda(cfg, s, op.output(0))
#
# traverse_inline(s, outs[0].op, _callback)
# return s
@autotvm.register_topi_compute("conv2d_nhwc.cuda")
def conv2d_nhwc(cfg, data, kernel, strides, padding, dilation, out_dtype='float32'):
"""Compute conv2d with NHWC layout"""
return nn.conv2d_nhwc(data, kernel, strides, padding, dilation, out_dtype)


@autotvm.register_topi_schedule("conv2d_nhwc.cuda")
def schedule_conv2d_nhwc(cfg, outs):
"""Create the schedule for conv2d_nhwc"""
outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
s = te.create_schedule([x.op for x in outs])
def _callback(op):
if op.tag == 'conv2d_nhwc':
schedule_conv2d_nhwc_direct(cfg, s, op.output(0))
traverse_inline(s, outs[0].op, _callback)
return s


@autotvm.register_topi_compute("conv2d_cudnn.cuda")
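With conv2d_nhwc re-enabled, the direct NHWC compute and schedule can also be exercised straight from TOPI, outside of the Relay strategy. A rough sketch against the TOPI/AutoTVM API of this era; the shapes are arbitrary and the untuned fallback config is assumed:

import tvm
from tvm import te
import topi

# NHWC input and HWIO kernel, as conv2d_nhwc expects.
data = te.placeholder((16, 56, 56, 64), name="data", dtype="float32")
kernel = te.placeholder((3, 3, 64, 64), name="kernel", dtype="float32")

with tvm.target.cuda():
    # Positional args: strides=1, padding=1, dilation=1, out_dtype="float32";
    # without a tuning log, AutoTVM falls back to a default config.
    out = topi.cuda.conv2d_nhwc(data, kernel, 1, 1, 1, "float32")
    s = topi.cuda.schedule_conv2d_nhwc([out])
    print(tvm.lower(s, [data, kernel, out], simple_mode=True))
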