From 0e6000f899c391da5fcc4afdd96506f3460f5fca Mon Sep 17 00:00:00 2001
From: zhink <771809832@qq.com>
Date: Wed, 3 Apr 2024 14:52:29 +0800
Subject: [PATCH] inference gets cutlass info & improve coding efficiency

---
 .../fusion/cutlass/conv2d/conv2d_common.py   | 20 ++++++++++++--------
 .../fusion/cutlass/conv2d/conv2d_decl.h      | 22 ++++++++++------------
 .../cutlass/fused_conv2d_add_act_kernel.cu   | 14 +++++++++++---
 3 files changed, 33 insertions(+), 23 deletions(-)

diff --git a/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_common.py b/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_common.py
index 29f9e443d9c536..5d2425fe4059bf 100644
--- a/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_common.py
+++ b/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_common.py
@@ -57,7 +57,7 @@
   ${element_c} *bias = (${element_c} *)(params.bias);
   ${element_c} *output = (${element_c} *)(params.output);
   // only used by conv2d_bias_residual
-auto residual = (${element_c} *)(params.residual);
+  auto residual = (${element_c} *)(params.residual);
 
   int batch = params.batch;
   int ic = params.ic;
@@ -96,8 +96,8 @@
   ImplicitGemm implicit_gemm_op;
   size_t bytes = implicit_gemm_op.get_workspace_size(arguments);
 
-auto stream = params.stream;
-void *workspace = params.workspace;
+  auto stream = params.stream;
+  void *workspace = params.workspace;
 
   cutlass::Status status = implicit_gemm_op.can_implement(arguments);
   CUTLASS_CHECK(status);
@@ -125,7 +125,7 @@
 std::map<std::vector<int>, int> map_problem_${func_name};
 std::mutex ${func_name}_mutex;
 
-void ${func_name}(ConvAllParams params) {
+bool ${func_name}(ConvAllParams params) {
   int batch = params.batch;
   int ic = params.ic;
   int ih = params.ih;
@@ -145,7 +145,7 @@
   if (map_problem_${func_name}.count(problem_size)) {
     ${func_name}_all_func[map_problem_${func_name}.at(problem_size)](
         params);
-    return;
+    return true;
   }
 
   int best_config_index = ProfileToGetBestConfig(
@@ -155,6 +155,7 @@
 
   map_problem_${func_name}[problem_size] = best_config_index;
   ${func_name}_all_func[best_config_index](params);
+  return true;
 }
 """
 
@@ -164,8 +165,8 @@
 
 # this function is invoked by phi kernel
 CommonWrapperForPhi = """
-void ${op_name}(ConvAllParams params) {
- ${dispatch_body}
+bool ${op_name}(ConvAllParams params) {
+  ${dispatch_body}
 }
 """
 
@@ -177,12 +178,14 @@ def convert_c_data_type(dtype):
         return "Conv2dDataType::bf16"
     elif dtype == "fp32":
         return "Conv2dDataType::fp32"
+    else:
+        return None
 
 
 CommonDispatchTemp = '''
     if (params.sm_version == ${sm_code} && params.data_type == ${data_type}) {
-      ${op_name_with_sm}(params);
+      return ${op_name_with_sm}(params);
     }
 '''
 
 
@@ -215,6 +218,7 @@ def GenerateFunctionForPhi(
                 + data_type
             )
             dispatch_body += SubstituteTemplate(CommonDispatchTemp, sm_dicts)
+        dispatch_body += ''' return false;'''
         op_dicts = {}
         op_dicts["dispatch_body"] = dispatch_body
         op_dicts["op_name"] = camel_names[epi_func]
diff --git a/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_decl.h b/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_decl.h
index b29ce65f5230a2..a36495ca6abfb3 100644
--- a/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_decl.h
+++ b/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_decl.h
@@ -59,19 +59,17 @@ typedef struct {
 } ConvAllParams;
 
 // Below functions are provided by cutlass, they are called by phi.
-extern "C" void Conv2dBiasAddRelu(ConvAllParams params); -extern "C" void Conv2dBiasRelu(ConvAllParams params); -extern "C" void Conv2dBiasLeakyRelu(ConvAllParams params); -extern "C" void Conv2dBiasSilu(ConvAllParams params); -extern "C" void Conv2dBias(ConvAllParams params); -extern "C" void Conv2dBiasSigmoid(ConvAllParams params); +extern "C" bool Conv2dBiasAddRelu(ConvAllParams params); +extern "C" bool Conv2dBiasRelu(ConvAllParams params); +extern "C" bool Conv2dBiasLeakyRelu(ConvAllParams params); +extern "C" bool Conv2dBiasSilu(ConvAllParams params); +extern "C" bool Conv2dBias(ConvAllParams params); +extern "C" bool Conv2dBiasSigmoid(ConvAllParams params); -extern "C" void Conv2dDepthwiseBias(ConvAllParams params); -extern "C" void Conv2dDepthwiseBiasRelu(ConvAllParams params); -extern "C" void Conv2dDepthwiseBiasSigmoid(ConvAllParams params); -extern "C" void Conv2dDepthwiseBiasSilu(ConvAllParams params); - -extern "C" int HelloFromCutlassConv2d(int a, int b); +extern "C" bool Conv2dDepthwiseBias(ConvAllParams params); +extern "C" bool Conv2dDepthwiseBiasRelu(ConvAllParams params); +extern "C" bool Conv2dDepthwiseBiasSigmoid(ConvAllParams params); +extern "C" bool Conv2dDepthwiseBiasSilu(ConvAllParams params); } // namespace cutlass_internal } // namespace fusion diff --git a/paddle/phi/kernels/fusion/cutlass/fused_conv2d_add_act_kernel.cu b/paddle/phi/kernels/fusion/cutlass/fused_conv2d_add_act_kernel.cu index 79057bee76219b..d496dfc0d22a9e 100644 --- a/paddle/phi/kernels/fusion/cutlass/fused_conv2d_add_act_kernel.cu +++ b/paddle/phi/kernels/fusion/cutlass/fused_conv2d_add_act_kernel.cu @@ -26,7 +26,7 @@ namespace phi { namespace fusion { namespace cutlass_internal { -typedef void (*func)(phi::fusion::cutlass_internal::ConvAllParams); +typedef bool (*func)(phi::fusion::cutlass_internal::ConvAllParams); template void FusedConv2dAddActKernel(const Context& ctx, @@ -230,7 +230,11 @@ void FusedConv2dAddActKernel(const Context& ctx, "Cutlass conv2d_depthwise does not support this activation: %s.", activation.c_str())); } - conv_func(params); + + if (!conv_func(params)) { + PADDLE_THROW(phi::errors::Fatal("no fused_conv2d_add_act cutlass kernel ")); + } + output->set_layout(DataLayout::NHWC); return; } @@ -265,7 +269,11 @@ void FusedConv2dAddActKernel(const Context& ctx, PADDLE_THROW(phi::errors::InvalidArgument( "Cutlass does not support this activation: %s.", activation.c_str())); } - conv_func(params); + + if (!conv_func(params)) { + PADDLE_THROW(phi::errors::Fatal("no fused_conv2d_add_act cutlass kernel ")); + } + output->set_layout(DataLayout::NHWC); } } // namespace cutlass_internal