Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Paddle-TRT] Implement MHA fp16 order same as training #32629

Merged
merged 4 commits into from
Apr 29, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,14 @@ nvinfer1::DataType QkvToContextPluginDynamic::getOutputDataType(
return input_types[0];
}

template <typename T>
zlsh80826 marked this conversation as resolved.
Show resolved Hide resolved
__global__ void apply_scale(T *data, T scale, int n) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
int tid = blockIdx.x * blockDim.x + threadIdx.x;
data[tid] = data[tid] * scale;
#endif
}

int QkvToContextPluginDynamic::enqueue(
const nvinfer1::PluginTensorDesc *input_desc,
const nvinfer1::PluginTensorDesc *output_desc, const void *const *inputs,
Expand Down Expand Up @@ -291,10 +299,17 @@ int QkvToContextPluginDynamic::enqueue(
platform::DeviceContextPool::Instance().Get(
platform::CUDAPlace(device_id)));

int n_q = seq_len * head_number_ * head_size_;
constexpr int threads = 128;
int blocks = (n_q + threads - 1) / threads;

apply_scale<<<blocks, threads, 0, stream>>>(tptr, static_cast<half>(scale_),
n_q);

const platform::CUDADeviceContext &dev_ctx = *device_ctx;
operators::math::MultiHeadGPUComputeFunctor<half> multihead_compute_func;
multihead_compute_func(dev_ctx, batch, seq_len, head_number_, head_size_,
qkptr, input1_data, tptr, half(scale_), half(0.0));
qkptr, input1_data, tptr, half(1.), half(0.0));

int grid = batch * head_number_ * seq_len;
int block = head_size_;
Expand Down