diff --git a/python/tvm/contrib/cudnn.py b/python/tvm/contrib/cudnn.py index 6dc04c9f58fd1..0e22e0c09274b 100644 --- a/python/tvm/contrib/cudnn.py +++ b/python/tvm/contrib/cudnn.py @@ -342,36 +342,57 @@ def conv_forward(x, w, pad, stride, dilation, conv_mode, tensor_format, algo, co conv_dtype = x.dtype if conv_dtype is None else conv_dtype pad, stride, dilation, _, _ = _prepare_global_func_params(dims - 2, pad, stride, dilation) - oshape = conv_output_shape( - tensor_format, - pad, - stride, - dilation, - list(x.shape), - list(w.shape), - x.dtype, - conv_dtype, - groups, - ) - if algo == -1: - # For now if we try to call `cudnnFindConvolutionForwardAlgorithm` when - # using INT8 data type, CuDNN will crash down. - # On the other hand, CuDNN only support IMPLICIT_PRECOMP_GEMM at NHWC format - if tensor_format == 1 and conv_dtype == "int32": - algo = 1 - else: - algo = conv_find_algo( - tensor_format, - pad, - stride, - dilation, - list(x.shape), - list(w.shape), - oshape, - x.dtype, - conv_dtype, - groups, - ) + x_shape = list(x.shape) + + if isinstance(x.shape[0], tvm.tir.expr.IntImm): + oshape = conv_output_shape( + tensor_format, + pad, + stride, + dilation, + x_shape, + list(w.shape), + x.dtype, + conv_dtype, + groups, + ) + if algo == -1: + # For now if we try to call `cudnnFindConvolutionForwardAlgorithm` when + # using INT8 data type, CuDNN will crash down. + # On the other hand, CuDNN only support IMPLICIT_PRECOMP_GEMM at NHWC format + if tensor_format == 1 and conv_dtype == "int32": + algo = 1 + else: + algo = conv_find_algo( + tensor_format, + pad, + stride, + dilation, + list(x.shape), + list(w.shape), + oshape, + x.dtype, + conv_dtype, + groups, + ) + else: + # The dynamic batch size case, pretend this is a single batch + x_shape[0] = 1 + oshape = conv_output_shape( + tensor_format, + pad, + stride, + dilation, + x_shape, + list(w.shape), + x.dtype, + conv_dtype, + groups, + ) + oshape[0] = x.shape[0] + # This picks CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM + # It seems this is the fastest among algorithms that are always applicable + algo = 1 if dims == 4: return te.extern( diff --git a/python/tvm/topi/cuda/conv2d.py b/python/tvm/topi/cuda/conv2d.py index ce9cebc3c9635..63c7c93082848 100644 --- a/python/tvm/topi/cuda/conv2d.py +++ b/python/tvm/topi/cuda/conv2d.py @@ -96,17 +96,19 @@ def conv2d_cudnn( pt, pl, pb, pr = get_pad_tuple(padding, (KH, KW)) OH = (H + pt + pb - KH) // stride_h + 1 OW = (W + pl + pr - KW) // stride_w + 1 - cfg.add_flop( - groups - * 2 - * N - * OH - * OW - * CO - * CI - * ((KH - 1) * dilation_h + 1) - * ((KW - 1) * dilation_w + 1) - ) + + if isinstance(N, int): + cfg.add_flop( + groups + * 2 + * N + * OH + * OW + * CO + * CI + * ((KH - 1) * dilation_h + 1) + * ((KW - 1) * dilation_w + 1) + ) if data.dtype == "int8" or kernel.dtype == "int8": if layout == "NCHW": diff --git a/python/tvm/topi/cuda/conv3d.py b/python/tvm/topi/cuda/conv3d.py index e5a3a53a89ff9..530df31ed3dc3 100644 --- a/python/tvm/topi/cuda/conv3d.py +++ b/python/tvm/topi/cuda/conv3d.py @@ -206,18 +206,20 @@ def conv3d_cudnn( OD = (D + 2 * pad_d - KD) // stride_d + 1 OH = (H + 2 * pad_h - KH) // stride_h + 1 OW = (W + 2 * pad_w - KW) // stride_w + 1 - cfg.add_flop( - 2 - * N - * OD - * OH - * OW - * CO - * CI - * ((KD - 1) * dilation_d + 1) - * ((KH - 1) * dilation_h + 1) - * ((KW - 1) * dilation_w + 1) - ) + + if isinstance(N, int): + cfg.add_flop( + 2 + * N + * OD + * OH + * OW + * CO + * CI + * ((KD - 1) * dilation_d + 1) + * ((KH - 1) * dilation_h + 1) + * ((KW - 1) * dilation_w + 1) + ) return cudnn.conv_forward( data, diff --git a/python/tvm/topi/cuda/dense.py b/python/tvm/topi/cuda/dense.py index 47b9db4f390a5..85b9b19bdb024 100644 --- a/python/tvm/topi/cuda/dense.py +++ b/python/tvm/topi/cuda/dense.py @@ -42,7 +42,8 @@ def dense_cublas(cfg, data, weight, bias=None, out_dtype=None): batch, in_dim = data.shape out_dim, _ = weight.shape matmul = cublas.matmul(data, weight, False, True) - cfg.add_flop(batch * in_dim * out_dim * 2) + if isinstance(batch, int): + cfg.add_flop(batch * in_dim * out_dim * 2) if bias is not None: matmul = te.compute( (batch, out_dim), lambda i, j: matmul[i, j] + bias[j], tag=tag.BROADCAST diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index e6812aa3bbfad..cb3b5d42e5538 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -72,12 +72,11 @@ def check_result( str(e), str(r), ) - return - - if flatten: - r = r.flatten() - e = e.flatten() - tvm.testing.assert_allclose(r, e, atol=2e-6) + else: + if flatten: + r = r.flatten() + e = e.flatten() + tvm.testing.assert_allclose(r, e, atol=2e-6) def verify_any_broadcast(x_shape, y_shape, x_np_shape, y_np_shape, op, np_op): @@ -454,6 +453,7 @@ def verify_any_conv2d( dilation, static_data_shape, ref_out_shape, + use_cudnn=False, ): mod = tvm.IRModule() dtype = "float32" @@ -463,7 +463,12 @@ def verify_any_conv2d( mod["main"] = relay.Function([data, kernel], y) data_np = np.random.uniform(size=static_data_shape).astype(dtype) kernel_np = np.random.uniform(size=kernel_shape).astype(dtype) - check_result([data_np, kernel_np], mod, ref_out_shape, assert_shape=True) + + targets = None + if use_cudnn and tvm.get_global_func("tvm.contrib.cudnn.conv.output_shape", True): + targets = [("cuda -libs=cudnn", tvm.gpu(0))] + + check_result([data_np, kernel_np], mod, ref_out_shape, assert_shape=True, targets=targets) # TODO(@kevinthesun): Support dynamic input height and width. @@ -487,6 +492,16 @@ def test_any_conv2d(): (2, 64, 224, 224), (2, 64, 222, 222), ) + verify_any_conv2d( + (relay.Any(), 64, 224, 224), + (64, 64, 3, 3), + (1, 1), + (1, 1), + (1, 1), + (1, 64, 224, 224), + (1, 64, 224, 224), + use_cudnn=True, + ) def verify_any_conv2d_NCHWc( @@ -724,7 +739,13 @@ def test_any_batch_flatten(): def verify_any_dense( - data_shape, weight_shape, units, static_data_shape, static_weight_shape, ref_out_shape + data_shape, + weight_shape, + units, + static_data_shape, + static_weight_shape, + ref_out_shape, + use_cublas=False, ): mod = tvm.IRModule() dtype = "float32" @@ -734,7 +755,12 @@ def verify_any_dense( mod["main"] = relay.Function([data, weight], y) data_np = np.random.uniform(size=static_data_shape).astype(dtype) weight_np = np.random.uniform(size=static_weight_shape).astype(dtype) - check_result([data_np, weight_np], mod, ref_out_shape, assert_shape=True) + + targets = None + if use_cublas and tvm.get_global_func("tvm.contrib.cublas.matmul", True): + targets = [("cuda -libs=cublas", tvm.gpu(0))] + + check_result([data_np, weight_np], mod, ref_out_shape, assert_shape=True, targets=targets) # TODO(tvm-team) Fix dense schedule @@ -744,6 +770,12 @@ def test_any_dense(): verify_any_dense(any_dims(2), (50, relay.Any()), 50, (4, 40), (50, 40), (4, 50)) +@tvm.testing.uses_gpu +def test_any_dense_dynamic_batch(): + verify_any_dense((relay.Any(), 40), (50, 40), 50, (4, 40), (50, 40), (4, 50)) + verify_any_dense((relay.Any(), 40), (50, 40), 50, (4, 40), (50, 40), (4, 50), use_cublas=True) + + @tvm.testing.uses_gpu def verify_any_pad(data_shape, pad_width, static_data_shape): mod = tvm.IRModule()