Skip to content

Commit

Permalink
[CUBLAS, CUDNN] Support dynamic batch size (#7194)
Browse files Browse the repository at this point in the history
* support cudnn and cublas on dynamic batch size

* added test for cublas

* add comment on algo choice
  • Loading branch information
masahi authored Jan 4, 2021
1 parent 7053235 commit 361f508
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 63 deletions.
81 changes: 51 additions & 30 deletions python/tvm/contrib/cudnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,36 +342,57 @@ def conv_forward(x, w, pad, stride, dilation, conv_mode, tensor_format, algo, co
conv_dtype = x.dtype if conv_dtype is None else conv_dtype
pad, stride, dilation, _, _ = _prepare_global_func_params(dims - 2, pad, stride, dilation)

oshape = conv_output_shape(
tensor_format,
pad,
stride,
dilation,
list(x.shape),
list(w.shape),
x.dtype,
conv_dtype,
groups,
)
if algo == -1:
# For now if we try to call `cudnnFindConvolutionForwardAlgorithm` when
# using INT8 data type, CuDNN will crash down.
# On the other hand, CuDNN only support IMPLICIT_PRECOMP_GEMM at NHWC format
if tensor_format == 1 and conv_dtype == "int32":
algo = 1
else:
algo = conv_find_algo(
tensor_format,
pad,
stride,
dilation,
list(x.shape),
list(w.shape),
oshape,
x.dtype,
conv_dtype,
groups,
)
x_shape = list(x.shape)

if isinstance(x.shape[0], tvm.tir.expr.IntImm):
oshape = conv_output_shape(
tensor_format,
pad,
stride,
dilation,
x_shape,
list(w.shape),
x.dtype,
conv_dtype,
groups,
)
if algo == -1:
# For now if we try to call `cudnnFindConvolutionForwardAlgorithm` when
# using INT8 data type, CuDNN will crash down.
# On the other hand, CuDNN only support IMPLICIT_PRECOMP_GEMM at NHWC format
if tensor_format == 1 and conv_dtype == "int32":
algo = 1
else:
algo = conv_find_algo(
tensor_format,
pad,
stride,
dilation,
list(x.shape),
list(w.shape),
oshape,
x.dtype,
conv_dtype,
groups,
)
else:
# The dynamic batch size case, pretend this is a single batch
x_shape[0] = 1
oshape = conv_output_shape(
tensor_format,
pad,
stride,
dilation,
x_shape,
list(w.shape),
x.dtype,
conv_dtype,
groups,
)
oshape[0] = x.shape[0]
# This picks CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM
# It seems this is the fastest among algorithms that are always applicable
algo = 1

if dims == 4:
return te.extern(
Expand Down
24 changes: 13 additions & 11 deletions python/tvm/topi/cuda/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,17 +96,19 @@ def conv2d_cudnn(
pt, pl, pb, pr = get_pad_tuple(padding, (KH, KW))
OH = (H + pt + pb - KH) // stride_h + 1
OW = (W + pl + pr - KW) // stride_w + 1
cfg.add_flop(
groups
* 2
* N
* OH
* OW
* CO
* CI
* ((KH - 1) * dilation_h + 1)
* ((KW - 1) * dilation_w + 1)
)

if isinstance(N, int):
cfg.add_flop(
groups
* 2
* N
* OH
* OW
* CO
* CI
* ((KH - 1) * dilation_h + 1)
* ((KW - 1) * dilation_w + 1)
)

if data.dtype == "int8" or kernel.dtype == "int8":
if layout == "NCHW":
Expand Down
26 changes: 14 additions & 12 deletions python/tvm/topi/cuda/conv3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,18 +206,20 @@ def conv3d_cudnn(
OD = (D + 2 * pad_d - KD) // stride_d + 1
OH = (H + 2 * pad_h - KH) // stride_h + 1
OW = (W + 2 * pad_w - KW) // stride_w + 1
cfg.add_flop(
2
* N
* OD
* OH
* OW
* CO
* CI
* ((KD - 1) * dilation_d + 1)
* ((KH - 1) * dilation_h + 1)
* ((KW - 1) * dilation_w + 1)
)

if isinstance(N, int):
cfg.add_flop(
2
* N
* OD
* OH
* OW
* CO
* CI
* ((KD - 1) * dilation_d + 1)
* ((KH - 1) * dilation_h + 1)
* ((KW - 1) * dilation_w + 1)
)

return cudnn.conv_forward(
data,
Expand Down
3 changes: 2 additions & 1 deletion python/tvm/topi/cuda/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ def dense_cublas(cfg, data, weight, bias=None, out_dtype=None):
batch, in_dim = data.shape
out_dim, _ = weight.shape
matmul = cublas.matmul(data, weight, False, True)
cfg.add_flop(batch * in_dim * out_dim * 2)
if isinstance(batch, int):
cfg.add_flop(batch * in_dim * out_dim * 2)
if bias is not None:
matmul = te.compute(
(batch, out_dim), lambda i, j: matmul[i, j] + bias[j], tag=tag.BROADCAST
Expand Down
50 changes: 41 additions & 9 deletions tests/python/relay/test_any.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,11 @@ def check_result(
str(e),
str(r),
)
return

if flatten:
r = r.flatten()
e = e.flatten()
tvm.testing.assert_allclose(r, e, atol=2e-6)
else:
if flatten:
r = r.flatten()
e = e.flatten()
tvm.testing.assert_allclose(r, e, atol=2e-6)


def verify_any_broadcast(x_shape, y_shape, x_np_shape, y_np_shape, op, np_op):
Expand Down Expand Up @@ -454,6 +453,7 @@ def verify_any_conv2d(
dilation,
static_data_shape,
ref_out_shape,
use_cudnn=False,
):
mod = tvm.IRModule()
dtype = "float32"
Expand All @@ -463,7 +463,12 @@ def verify_any_conv2d(
mod["main"] = relay.Function([data, kernel], y)
data_np = np.random.uniform(size=static_data_shape).astype(dtype)
kernel_np = np.random.uniform(size=kernel_shape).astype(dtype)
check_result([data_np, kernel_np], mod, ref_out_shape, assert_shape=True)

targets = None
if use_cudnn and tvm.get_global_func("tvm.contrib.cudnn.conv.output_shape", True):
targets = [("cuda -libs=cudnn", tvm.gpu(0))]

check_result([data_np, kernel_np], mod, ref_out_shape, assert_shape=True, targets=targets)


# TODO(@kevinthesun): Support dynamic input height and width.
Expand All @@ -487,6 +492,16 @@ def test_any_conv2d():
(2, 64, 224, 224),
(2, 64, 222, 222),
)
verify_any_conv2d(
(relay.Any(), 64, 224, 224),
(64, 64, 3, 3),
(1, 1),
(1, 1),
(1, 1),
(1, 64, 224, 224),
(1, 64, 224, 224),
use_cudnn=True,
)


def verify_any_conv2d_NCHWc(
Expand Down Expand Up @@ -724,7 +739,13 @@ def test_any_batch_flatten():


def verify_any_dense(
data_shape, weight_shape, units, static_data_shape, static_weight_shape, ref_out_shape
data_shape,
weight_shape,
units,
static_data_shape,
static_weight_shape,
ref_out_shape,
use_cublas=False,
):
mod = tvm.IRModule()
dtype = "float32"
Expand All @@ -734,7 +755,12 @@ def verify_any_dense(
mod["main"] = relay.Function([data, weight], y)
data_np = np.random.uniform(size=static_data_shape).astype(dtype)
weight_np = np.random.uniform(size=static_weight_shape).astype(dtype)
check_result([data_np, weight_np], mod, ref_out_shape, assert_shape=True)

targets = None
if use_cublas and tvm.get_global_func("tvm.contrib.cublas.matmul", True):
targets = [("cuda -libs=cublas", tvm.gpu(0))]

check_result([data_np, weight_np], mod, ref_out_shape, assert_shape=True, targets=targets)


# TODO(tvm-team) Fix dense schedule
Expand All @@ -744,6 +770,12 @@ def test_any_dense():
verify_any_dense(any_dims(2), (50, relay.Any()), 50, (4, 40), (50, 40), (4, 50))


@tvm.testing.uses_gpu
def test_any_dense_dynamic_batch():
verify_any_dense((relay.Any(), 40), (50, 40), 50, (4, 40), (50, 40), (4, 50))
verify_any_dense((relay.Any(), 40), (50, 40), 50, (4, 40), (50, 40), (4, 50), use_cublas=True)


@tvm.testing.uses_gpu
def verify_any_pad(data_shape, pad_width, static_data_shape):
mod = tvm.IRModule()
Expand Down

0 comments on commit 361f508

Please sign in to comment.