From 4374fa4e145d8e9d31961f9f13dc8bc205e45122 Mon Sep 17 00:00:00 2001
From: Reese Wang
Date: Tue, 27 Apr 2021 21:22:19 +0800
Subject: [PATCH 1/4] implement MHA order same as training

---
 .../tensorrt/plugin/qkv_to_context_plugin.cu   | 14 +++++++++++++-
 1 file changed, 13 insertions(+), 1 deletion(-)

diff --git a/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu
index a5fc9e73c5f27..0790f7e63a142 100644
--- a/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu
+++ b/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu
@@ -225,6 +225,12 @@ nvinfer1::DataType QkvToContextPluginDynamic::getOutputDataType(
   return input_types[0];
 }
 
+template <typename T>
+__global__ void apply_scale(T* data, T scale, int n) {
+  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  data[tid] = data[tid] * scale;
+}
+
 int QkvToContextPluginDynamic::enqueue(
     const nvinfer1::PluginTensorDesc *input_desc,
     const nvinfer1::PluginTensorDesc *output_desc, const void *const *inputs,
@@ -291,10 +297,16 @@ int QkvToContextPluginDynamic::enqueue(
         platform::DeviceContextPool::Instance().Get(
             platform::CUDAPlace(device_id)));
 
+    int n_q = seq_len * head_number_ * head_size_;
+    constexpr int threads = 128;
+    int blocks = (n_q + threads - 1) / threads;
+
+    apply_scale<<<blocks, threads, 0, stream>>>(tptr, static_cast<half>(scale_), n_q);
+
     const platform::CUDADeviceContext &dev_ctx = *device_ctx;
     operators::math::MultiHeadGPUComputeFunctor<half> multihead_compute_func;
     multihead_compute_func(dev_ctx, batch, seq_len, head_number_, head_size_,
-                           qkptr, input1_data, tptr, half(scale_), half(0.0));
+                           qkptr, input1_data, tptr, half(1.), half(0.0));
 
     int grid = batch * head_number_ * seq_len;
     int block = head_size_;

From 6206a2180662015d6531e2dd8127073f29353085 Mon Sep 17 00:00:00 2001
From: Reese Wang
Date: Wed, 28 Apr 2021 16:36:08 +0800
Subject: [PATCH 2/4] fix fp16 compile issue on old architecture

---
 paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu
index 0790f7e63a142..9c199be396859 100644
--- a/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu
+++ b/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu
@@ -227,8 +227,10 @@ nvinfer1::DataType QkvToContextPluginDynamic::getOutputDataType(
 
 template <typename T>
 __global__ void apply_scale(T* data, T scale, int n) {
+#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
   data[tid] = data[tid] * scale;
+#endif
 }
 
 int QkvToContextPluginDynamic::enqueue(

From 757129b98602f5a8936dd6bf6d7777908d18ce6f Mon Sep 17 00:00:00 2001
From: Reese Wang
Date: Wed, 28 Apr 2021 20:34:55 +0800
Subject: [PATCH 3/4] fix format

---
 paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu | 1 +
 1 file changed, 1 insertion(+)

diff --git a/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu
index 9c199be396859..970f679ee8fb0 100644
--- a/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu
+++ b/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu
@@ -233,6 +233,7 @@ __global__ void apply_scale(T* data, T scale, int n) {
 #endif
 }
 
+
 int QkvToContextPluginDynamic::enqueue(
     const nvinfer1::PluginTensorDesc *input_desc,
     const nvinfer1::PluginTensorDesc *output_desc, const void *const *inputs,
From 4fd31186fd62ec6031776ec3625e01597ad90807 Mon Sep 17 00:00:00 2001
From: Reese Wang
Date: Wed, 28 Apr 2021 20:35:43 +0800
Subject: [PATCH 4/4] fix format

---
 .../inference/tensorrt/plugin/qkv_to_context_plugin.cu     | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu
index 970f679ee8fb0..214e1a81e7dc0 100644
--- a/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu
+++ b/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu
@@ -226,14 +226,13 @@ nvinfer1::DataType QkvToContextPluginDynamic::getOutputDataType(
 }
 
 template <typename T>
-__global__ void apply_scale(T* data, T scale, int n) {
+__global__ void apply_scale(T *data, T scale, int n) {
 #if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
   data[tid] = data[tid] * scale;
 #endif
 }
 
-
 int QkvToContextPluginDynamic::enqueue(
     const nvinfer1::PluginTensorDesc *input_desc,
     const nvinfer1::PluginTensorDesc *output_desc, const void *const *inputs,
@@ -304,7 +303,8 @@ int QkvToContextPluginDynamic::enqueue(
     constexpr int threads = 128;
     int blocks = (n_q + threads - 1) / threads;
 
-    apply_scale<<<blocks, threads, 0, stream>>>(tptr, static_cast<half>(scale_), n_q);
+    apply_scale<<<blocks, threads, 0, stream>>>(tptr, static_cast<half>(scale_),
+                                                n_q);
 
     const platform::CUDADeviceContext &dev_ctx = *device_ctx;
     operators::math::MultiHeadGPUComputeFunctor<half> multihead_compute_func;
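
A note on the kernel this series introduces: patch 1/4 stops folding the attention
scale into the batched QK^T computation (the alpha argument changes from
half(scale_) to half(1.)) and instead pre-scales the query data with the
standalone apply_scale kernel, so inference multiplies in the same order as
training. One detail worth flagging: the launch rounds n_q up to a whole number
of 128-thread blocks, but the kernel body never compares tid against its n
parameter, so whenever n_q is not a multiple of 128 the surplus threads of the
last block scale elements beyond the intended n_q (here, data belonging to the
rest of the QKV buffer; for a buffer of exactly n elements, out-of-bounds
memory). Below is a minimal standalone sketch of a guarded variant; the
apply_scale_guarded name and the main harness are illustrative only, not part
of the patch.

#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

// Element-wise scaling as in apply_scale, plus a bounds check so the
// rounded-up grid cannot touch memory beyond the n valid elements.
template <typename T>
__global__ void apply_scale_guarded(T *data, T scale, int n) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid < n) {
    data[tid] = data[tid] * scale;
  }
}

int main() {
  const int n = 1000;  // deliberately not a multiple of the block size
  constexpr int threads = 128;
  const int blocks = (n + threads - 1) / threads;  // 8 blocks = 1024 threads

  std::vector<float> host(n, 2.0f);
  float *dev = nullptr;
  cudaMalloc(&dev, n * sizeof(float));
  cudaMemcpy(dev, host.data(), n * sizeof(float), cudaMemcpyHostToDevice);

  apply_scale_guarded<<<blocks, threads>>>(dev, 0.125f, n);

  cudaMemcpy(host.data(), dev, n * sizeof(float), cudaMemcpyDeviceToHost);
  cudaFree(dev);
  printf("host[0] = %f, host[%d] = %f\n", host[0], n - 1, host[n - 1]);
  return 0;
}

With the guard in place, the rounded-up launch for n = 1000 (8 blocks of 128
threads) leaves the 24 surplus threads idle instead of letting them scale
whatever happens to follow the first 1000 elements.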