[Paddle-TRT] Implement MHA fp16 order same as training (PaddlePaddle#32629)

* implement MHA order same as training

* fix fp16 compile issue on old architecture

* fix format

* fix format
zlsh80826 authored and shangzhizhou committed May 7, 2021
1 parent ded39f8 commit 1676c04
Showing 1 changed file with 16 additions and 1 deletion.
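
In brief: the plugin previously applied the attention scale (scale_, typically 1/sqrt(head_size)) after the Q·K^T GEMM, by passing half(scale_) into MultiHeadGPUComputeFunctor. This commit moves the scaling in front of the GEMM: a small apply_scale kernel pre-multiplies the query values in the transposed QKV buffer (tptr) by scale_, and the functor then receives half(1.). Pre-scaling Q is the operation order used during training, so fp16 inference now rounds the same way; a numeric illustration of why the order matters follows the diff.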
paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu (16 additions, 1 deletion)

@@ -225,6 +225,14 @@ nvinfer1::DataType QkvToContextPluginDynamic::getOutputDataType(
   return input_types[0];
 }
 
+template <typename T>
+__global__ void apply_scale(T *data, T scale, int n) {
+#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
+  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  if (tid < n) data[tid] = data[tid] * scale;  // guard the tail of the last block
+#endif
+}
+
 int QkvToContextPluginDynamic::enqueue(
     const nvinfer1::PluginTensorDesc *input_desc,
     const nvinfer1::PluginTensorDesc *output_desc, const void *const *inputs,
@@ -291,10 +299,17 @@ int QkvToContextPluginDynamic::enqueue(
         platform::DeviceContextPool::Instance().Get(
             platform::CUDAPlace(device_id)));
 
+    int n_q = seq_len * head_number_ * head_size_ * batch;
+    constexpr int threads = 128;
+    int blocks = (n_q + threads - 1) / threads;
+
+    apply_scale<<<blocks, threads, 0, stream>>>(tptr, static_cast<half>(scale_),
+                                                n_q);
+
     const platform::CUDADeviceContext &dev_ctx = *device_ctx;
     operators::math::MultiHeadGPUComputeFunctor<half> multihead_compute_func;
     multihead_compute_func(dev_ctx, batch, seq_len, head_number_, head_size_,
-                           qkptr, input1_data, tptr, half(scale_), half(0.0));
+                           qkptr, input1_data, tptr, half(1.), half(0.0));
 
     int grid = batch * head_number_ * seq_len;
     int block = head_size_;
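
Why the order matters: fp16 rounds after every multiply, so "scale Q, then multiply by K" and "multiply Q by K, then scale" can disagree in the last bit, and those one-ulp differences are then summed over head_size and pushed through softmax. The sketch below is not part of the commit; it is a host-only illustration that emulates correctly rounded fp16 multiplies with the cuda_fp16.h conversion intrinsics and counts how often the two orders diverge. The helper hmul and all constants are hypothetical.

// order_demo.cu -- host-only sketch; build with: nvcc order_demo.cu -o order_demo
#include <cstdio>
#include <cuda_fp16.h>

// Emulate a correctly rounded fp16 multiply on the host. The float
// product of two halves is exact (11-bit significands fit in float's
// 24 bits), so rounding it back to half matches a hardware __hmul.
static __half hmul(__half a, __half b) {
  return __float2half(__half2float(a) * __half2float(b));
}

int main() {
  // Hypothetical attention scale: 1/sqrt(48), deliberately not a power
  // of two, so that scaling is itself a rounding step.
  const __half s = __float2half(0.14433757f);
  const __half k = __float2half(0.333f);  // an arbitrary key element

  int diverged = 0;
  for (int i = 1; i < 1000; ++i) {
    __half q = __float2half(i * 1e-3f);  // sweep of query elements
    __half pre = hmul(hmul(q, s), k);    // training order: scale Q first
    __half post = hmul(hmul(q, k), s);   // old plugin order: scale the product
    if (__half2float(pre) != __half2float(post)) ++diverged;
  }
  printf("%d of 999 products round differently\n", diverged);
  return 0;
}

On typical values a noticeable fraction of the products disagree by one ulp; matching the training order removes that systematic difference between training and TensorRT fp16 inference, which is what the commit title promises.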
