diff --git a/python/tvm/relay/op/contrib/cudnn.py b/python/tvm/relay/op/contrib/cudnn.py
index 9714a0b87dcf..e3c256f7e38a 100644
--- a/python/tvm/relay/op/contrib/cudnn.py
+++ b/python/tvm/relay/op/contrib/cudnn.py
@@ -16,7 +16,7 @@
 # under the License.
 # pylint: disable=unused-argument
 """cuDNN Relay integration."""
-from typing import Callable, List, Tuple, Dict, Optional
+from typing import Callable, List, Tuple
 
 import tvm
 import tvm.ir
@@ -24,7 +24,6 @@
 from tvm import te
 from tvm.relay import transform
 from tvm.contrib import cudnn
-from tvm.relay.build_module import bind_params_by_name
 
 from ...dataflow_pattern import is_op, wildcard
 from .te_target import lower_composite, relay_to_runtime
@@ -34,25 +33,19 @@
 tvm._ffi.register_func("relay.ext.cudnn", relay_to_runtime(tvm.target.cuda()))
 
 
-def partition_for_cudnn(
-    mod: tvm.IRModule, params: Optional[Dict[str, tvm.runtime.NDArray]] = None
-) -> tvm.IRModule:
+def partition_for_cudnn(mod: tvm.IRModule) -> tvm.IRModule:
     """Partition the graph to offload for cuDNN.
 
     Parameters
     ----------
     mod : tvm.IRModule
        The module to partition.
-    params : Optional[Dict[str, tvm.runtime.NDArray]]
-        Constant input parameters.
 
     Returns
    -------
     tvm.IRModule
         The partitioned module.
     """
-    if params:
-        mod["main"] = bind_params_by_name(mod["main"], params)
 
     seq = tvm.transform.Sequential(
         [
@@ -82,6 +75,12 @@ def conv2d_pattern() -> relay.Pattern:
         """Create pattern for conv2d."""
         return is_op("nn.conv2d")(wildcard(), wildcard())
 
+    def conv2d_bias_act_pattern() -> relay.Pattern:
+        """Create pattern for fused conv2d+bias+activation."""
+        conv2d = is_op("nn.conv2d")(wildcard(), wildcard())
+        bias = is_op("nn.bias_add")(conv2d, wildcard())
+        return bias.optional(is_op("nn.relu"))
+
     def check_softmax(matched: relay.Call) -> bool:
         """Check if softmax is supported by cuDNN."""
         if matched.args[0].checked_type.dtype not in ["float64", "float32", "float16"]:
@@ -115,9 +114,13 @@ def check_conv2d(matched: relay.Call) -> bool:
 
         return True
 
+    def check_conv2d_bias_act(matched: relay.Call) -> bool:
+        return True
+
     return [
         ("cudnn.softmax", softmax_pattern(), check_softmax),
         ("cudnn.log_softmax", log_softmax_pattern(), check_log_softmax),
+        ("cudnn.conv2d_bias_act", conv2d_bias_act_pattern(), check_conv2d_bias_act),
         ("cudnn.conv2d", conv2d_pattern(), check_conv2d),
     ]
 
@@ -134,6 +137,64 @@ def _lower_log_softmax(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor:
     return cudnn.log_softmax(inputs[0], axis=op.attrs["axis"])
 
 
+@lower_composite("cudnn.conv2d_bias_act")
+def _lower_conv2d_bias_act(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor:
+    """Lower a fused conv2d+bias+activation using cuDNN."""
+    conv_dtype = op.checked_type.dtype
+    if op.op.name == "nn.relu":
+        activation_mode = 1  # Relu
+        conv2d = op.args[0].args[0]
+    else:
+        activation_mode = 5  # Identity
+        conv2d = op.args[0]
+
+    conv_mode = 1
+    tensor_format = 0
+    algo = 1
+    pad = conv2d.attrs["padding"]
+    strides = conv2d.attrs["strides"]
+    dilation = conv2d.attrs["dilation"]
+    groups = conv2d.attrs["groups"]
+
+    oshape = cudnn.conv_output_shape(
+        tensor_format,
+        pad,
+        strides,
+        dilation,
+        inputs[0].shape,
+        inputs[1].shape,
+        inputs[0].dtype,
+        conv_dtype,
+        groups,
+    )
+
+    return te.extern(
+        oshape,
+        inputs,
+        lambda ins, outs: tvm.tir.call_packed(
+            "tvm.contrib.cudnn.conv2d+bias+act.forward",
+            conv_mode,
+            tensor_format,
+            algo,
+            pad[0],
+            pad[1],
+            strides[0],
+            strides[1],
+            dilation[0],
+            dilation[1],
+            activation_mode,
+            0,
+            ins[0],
+            ins[1],
+            ins[2],
+            outs[0],
+            conv_dtype,
+            groups,
+        ),
+        name="y",
+    )
+
+
 @lower_composite("cudnn.conv2d")
 def _lower_conv2d(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor:
     """Lower a conv2d using cuDNN."""
diff --git a/src/runtime/contrib/cudnn/conv_forward.cc b/src/runtime/contrib/cudnn/conv_forward.cc
index f5e5ee889c55..626d356da4bf 100644
--- a/src/runtime/contrib/cudnn/conv_forward.cc
+++ b/src/runtime/contrib/cudnn/conv_forward.cc
@@ -60,6 +60,44 @@ void ConvolutionForward(int mode, int format, int algo, int dims, int groups, co
                                      entry_ptr->conv_entry.output_desc, y->data));
 }
 
+void ConvolutionBiasActivationForward(int mode, int format, int algo, int dims, int groups, int act,
+                                      double coef, const int pad[], const int stride[],
+                                      const int dilation[], DLTensor* x, DLTensor* w, DLTensor* y,
+                                      DLTensor* bias, const std::string& conv_dtype) {
+  CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
+  // Set Mode
+  entry_ptr->conv_entry.mode = static_cast<cudnnConvolutionMode_t>(mode);
+  CUDNN_CALL(cudnnSetActivationDescriptor(entry_ptr->conv_entry.activation_desc,
+                                          static_cast<cudnnActivationMode_t>(act),
+                                          cudnnNanPropagation_t::CUDNN_NOT_PROPAGATE_NAN, coef));
+  CUDNN_CALL(cudnnSetTensor4dDescriptor(
+      entry_ptr->conv_entry.bias_desc, entry_ptr->conv_entry.tensor_format,
+      CuDNNDataType::DLTypeToCuDNNType(bias->dtype), 1, static_cast<int>(w->shape[0]), 1, 1));
+
+  SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, x->shape, w->shape,
+                     y->shape, x->dtype, conv_dtype);
+  // Set Device
+  entry_ptr->conv_entry.device = x->device;
+  // Set Algo
+  entry_ptr->conv_entry.fwd_algo = static_cast<cudnnConvolutionFwdAlgo_t>(algo);
+
+  // Set workspace
+  size_t workspace_size = 0;
+  CUDNN_CALL(cudnnGetConvolutionForwardWorkspaceSize(
+      entry_ptr->handle, entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.filter_desc,
+      entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.output_desc,
+      entry_ptr->conv_entry.fwd_algo, &workspace_size));
+  entry_ptr->conv_entry.UpdateWorkspace(workspace_size);
+  CUDNN_CALL(cudnnConvolutionBiasActivationForward(
+      entry_ptr->handle, CuDNNDataType::GetConst<1>(entry_ptr->conv_entry.data_type),
+      entry_ptr->conv_entry.input_desc, x->data, entry_ptr->conv_entry.filter_desc, w->data,
+      entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.fwd_algo,
+      entry_ptr->conv_entry.workspace, workspace_size,
+      CuDNNDataType::GetConst<0>(entry_ptr->conv_entry.data_type),
+      entry_ptr->conv_entry.output_desc, y->data, entry_ptr->conv_entry.bias_desc, bias->data,
+      entry_ptr->conv_entry.activation_desc, entry_ptr->conv_entry.output_desc, y->data));
+}
+
 void FindAlgo(int format, int dims, int groups, const int pad[], const int stride[],
               const int dilation[], const int x_dim[], const int w_dim[], const int y_dim[],
               const std::string& data_dtype, const std::string& conv_dtype, TVMRetValue* ret) {
@@ -126,6 +164,30 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.forward")
                          conv_dtype);
     });
 
+TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d+bias+act.forward")
+    .set_body([](TVMArgs args, TVMRetValue* ret) {
+      int mode = args[0];
+      int format = args[1];
+      int algo = args[2];
+      int pad_v[2], stride_v[2], dilation_v[2];
+      for (int i = 0; i < 2; i++) {
+        pad_v[i] = args[3 + i];
+        stride_v[i] = args[5 + i];
+        dilation_v[i] = args[7 + i];
+      }
+      int act = args[9];
+      double coef = args[10];
+      DLTensor* x = args[11];
+      DLTensor* w = args[12];
+      DLTensor* bias = args[13];
+      DLTensor* y = args[14];
+      std::string conv_dtype = args[15];
+      int groups = args[16];
+
+      ConvolutionBiasActivationForward(mode, format, algo, 2, groups, act, coef, pad_v, stride_v,
+                                       dilation_v, x, w, y, bias, conv_dtype);
+    });
+
 TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv3d.forward")
     .set_body([](TVMArgs args, TVMRetValue* ret) {
       int mode = args[0];
diff --git a/src/runtime/contrib/cudnn/cudnn_utils.cc b/src/runtime/contrib/cudnn/cudnn_utils.cc
index e39c47339c7f..68d5902c06d2 100644
--- a/src/runtime/contrib/cudnn/cudnn_utils.cc
+++ b/src/runtime/contrib/cudnn/cudnn_utils.cc
@@ -140,6 +140,8 @@ ConvEntry::ConvEntry() {
   CUDNN_CALL(cudnnCreateFilterDescriptor(&filter_desc));
   CUDNN_CALL(cudnnCreateTensorDescriptor(&input_desc));
   CUDNN_CALL(cudnnCreateTensorDescriptor(&output_desc));
+  CUDNN_CALL(cudnnCreateTensorDescriptor(&bias_desc));
+  CUDNN_CALL(cudnnCreateActivationDescriptor(&activation_desc));
 }
 
 ConvEntry::~ConvEntry() {
@@ -147,6 +149,8 @@ ConvEntry::~ConvEntry() {
   CUDNN_CALL(cudnnDestroyConvolutionDescriptor(conv_desc));
   CUDNN_CALL(cudnnDestroyTensorDescriptor(input_desc));
   CUDNN_CALL(cudnnDestroyTensorDescriptor(output_desc));
+  CUDNN_CALL(cudnnDestroyTensorDescriptor(bias_desc));
+  CUDNN_CALL(cudnnDestroyActivationDescriptor(activation_desc));
   CleanWorkspace();
 }
 
diff --git a/src/runtime/contrib/cudnn/cudnn_utils.h b/src/runtime/contrib/cudnn/cudnn_utils.h
index 426ccfdf37af..871fb35dd470 100644
--- a/src/runtime/contrib/cudnn/cudnn_utils.h
+++ b/src/runtime/contrib/cudnn/cudnn_utils.h
@@ -71,6 +71,8 @@ struct ConvEntry {
   cudnnTensorFormat_t tensor_format;
   cudnnTensorDescriptor_t input_desc;
   cudnnFilterDescriptor_t filter_desc;
+  cudnnTensorDescriptor_t bias_desc;
+  cudnnActivationDescriptor_t activation_desc;
   cudnnTensorDescriptor_t output_desc;
   cudnnConvolutionFwdAlgo_t fwd_algo;
   cudnnConvolutionBwdDataAlgo_t bwd_data_algo;
diff --git a/tests/python/contrib/test_cudnn.py b/tests/python/contrib/test_cudnn.py
index 8ca3df343dad..cdbe424710c6 100644
--- a/tests/python/contrib/test_cudnn.py
+++ b/tests/python/contrib/test_cudnn.py
@@ -461,10 +461,12 @@ def _verify_cudnn_relay(expr):
     for param in func.params:
         shape = [int(x) for x in param.checked_type.shape]
         input_data.append(
-            (param.name_hint, np.random.uniform(0, 32, size=shape).astype(param.checked_type.dtype))
+            (
+                param.name_hint,
+                np.random.uniform(-32, 32, size=shape).astype(param.checked_type.dtype),
+            )
         )
 
-    # Test against CPU reference
     cuda_config = (tvm.target.cuda(), tvm.cuda(), cudnn_mod)
     cpu_config = (tvm.target.Target("llvm"), tvm.cpu(), mod)
     outputs = []
@@ -484,7 +486,8 @@ def _verify_cudnn_relay(expr):
     tvm.testing.assert_allclose(
         outputs[0],
         outputs[1],
-        rtol=1e-2,
+        rtol=1e-3,
+        atol=30,
     )
 
 
@@ -577,5 +580,47 @@ def test_relay_cudnn_conv2d(n, h, w, ci, co, kh, kw, strides, dilation, padding,
     _verify_cudnn_relay(conv2d)
 
 
+@tvm.testing.requires_cuda
+@pytest.mark.parametrize(
+    "n,h,w,ci,co,groups",
+    [
+        (1, 16, 20, 8, 16, 1),
+        (10, 17, 19, 16, 8, 4),
+    ],
+)
+@pytest.mark.parametrize(
+    "kh,kw,padding,strides,dilation,dtype",
+    [
+        (1, 1, (3, 1, 3, 1), (1, 1), (1, 1), "float32"),
+        (3, 3, (1, 2), (2, 1), (2, 2), "float16"),
+        (7, 2, (0, 0), (3, 3), (1, 2), "float64"),
+    ],
+)
+@pytest.mark.parametrize("activation", [True, False])
+def test_relay_cudnn_conv2d_bias_act(
+    n, h, w, ci, co, kh, kw, strides, dilation, padding, groups, dtype, activation
+):
+    data = tvm.relay.var("data", tvm.relay.TensorType((n, ci, h, w), dtype))
+    weight = tvm.relay.var("weight", tvm.relay.TensorType((co, ci // groups, kh, kw), dtype))
+    bias = relay.var("bias", relay.TensorType((co,), dtype))
+    conv2d = relay.op.nn.conv2d(
+        data,
+        weight,
+        groups=groups,
+        channels=co,
+        kernel_size=(kh, kw),
+        strides=strides,
+        dilation=dilation,
+        padding=padding,
+        data_layout="NCHW",
+        kernel_layout="OIHW",
+    )
+    out = relay.op.nn.bias_add(conv2d, bias)
+    if activation:
+        out = relay.op.nn.relu(out)
+
+    _verify_cudnn_relay(out)
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main(sys.argv))
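Usage note (not part of the patch): the following is a minimal sketch of how the new fused offload could be exercised end to end, assuming a CUDA- and cuDNN-enabled build of TVM. It mirrors what the _verify_cudnn_relay helper in the tests above does; the tensor shapes and random inputs are illustrative choices, not taken from the patch.

    import numpy as np
    import tvm
    from tvm import relay
    from tvm.contrib import graph_executor
    from tvm.relay.op.contrib.cudnn import partition_for_cudnn

    dtype = "float32"
    data = relay.var("data", relay.TensorType((1, 8, 16, 20), dtype))
    weight = relay.var("weight", relay.TensorType((16, 8, 3, 3), dtype))
    bias = relay.var("bias", relay.TensorType((16,), dtype))

    # Build the conv2d -> bias_add -> relu chain that the new pattern matches.
    conv = relay.op.nn.conv2d(data, weight, channels=16, kernel_size=(3, 3), padding=(1, 1))
    out = relay.op.nn.relu(relay.op.nn.bias_add(conv, bias))
    func = relay.Function(relay.analysis.free_vars(out), out)
    mod = tvm.IRModule.from_expr(func)

    # After partitioning, the chain should appear as a "cudnn.conv2d_bias_act"
    # composite inside a function offloaded to the cudnn external codegen.
    mod = partition_for_cudnn(mod)
    print(mod)

    # Build for CUDA; the offloaded region is lowered through the
    # tvm.contrib.cudnn.conv2d+bias+act.forward packed function added above.
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target=tvm.target.cuda())

    dev = tvm.cuda()
    runtime = graph_executor.GraphModule(lib["default"](dev))
    runtime.set_input("data", np.random.uniform(-1, 1, (1, 8, 16, 20)).astype(dtype))
    runtime.set_input("weight", np.random.uniform(-1, 1, (16, 8, 3, 3)).astype(dtype))
    runtime.set_input("bias", np.random.uniform(-1, 1, (16,)).astype(dtype))
    runtime.run()
    result = runtime.get_output(0).numpy()

Because partition_for_cudnn no longer binds constant parameters itself, callers that want weights folded into the offloaded regions should apply tvm.relay.build_module.bind_params_by_name to the main function before partitioning.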