diff --git a/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu
index a5fc9e73c5f27..214e1a81e7dc0 100644
--- a/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu
+++ b/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu
@@ -225,6 +225,14 @@ nvinfer1::DataType QkvToContextPluginDynamic::getOutputDataType(
   return input_types[0];
 }
 
+template <typename T>
+__global__ void apply_scale(T *data, T scale, int n) {
+#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
+  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  if (tid < n) data[tid] = data[tid] * scale;
+#endif
+}
+
 int QkvToContextPluginDynamic::enqueue(
     const nvinfer1::PluginTensorDesc *input_desc,
     const nvinfer1::PluginTensorDesc *output_desc, const void *const *inputs,
@@ -291,10 +299,17 @@ int QkvToContextPluginDynamic::enqueue(
         platform::DeviceContextPool::Instance().Get(
             platform::CUDAPlace(device_id)));
 
+    int n_q = seq_len * head_number_ * head_size_ * batch;
+    constexpr int threads = 128;
+    int blocks = (n_q + threads - 1) / threads;
+
+    apply_scale<<<blocks, threads, 0, stream>>>(tptr, static_cast<half>(scale_),
+                                                n_q);
+
     const platform::CUDADeviceContext &dev_ctx = *device_ctx;
     operators::math::MultiHeadGPUComputeFunctor<half> multihead_compute_func;
     multihead_compute_func(dev_ctx, batch, seq_len, head_number_, head_size_,
-                           qkptr, input1_data, tptr, half(scale_), half(0.0));
+                           qkptr, input1_data, tptr, half(1.), half(0.0));
 
     int grid = batch * head_number_ * seq_len;
     int block = head_size_;
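
For readers outside the TensorRT plugin, the change folds the attention scaling factor into an elementwise pre-scale of the Q slice of the transposed QKV buffer, which lets the fused softmax(QK^T)V functor run with a neutral scale of half(1.). Below is a minimal standalone sketch of that elementwise-scale pattern; the scale_q kernel, the float toy buffer, and the sizes are illustrative stand-ins, not part of the plugin code.

#include <cstdio>

#include <cuda_runtime.h>

// One thread per element, bounds-checked so a rounded-up grid is safe.
__global__ void scale_q(float *q, float scale, int n) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid < n) q[tid] *= scale;
}

int main() {
  // Toy size standing in for batch * head_number * seq_len * head_size.
  const int n = 1000;
  const float scale = 0.125f;  // e.g. 1 / sqrt(head_size) for head_size = 64

  float *q = nullptr;
  cudaMallocManaged(&q, n * sizeof(float));
  for (int i = 0; i < n; ++i) q[i] = 1.0f;

  constexpr int threads = 128;
  int blocks = (n + threads - 1) / threads;  // round up; the kernel guards n
  scale_q<<<blocks, threads>>>(q, scale, n);
  cudaDeviceSynchronize();

  printf("q[0] = %f  q[n-1] = %f\n", q[0], q[n - 1]);  // both print 0.125000
  cudaFree(q);
  return 0;
}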