Skip to content

Commit

Permalink
[Paddle-Inference] fix special_slice plugin (#39875)
Browse files Browse the repository at this point in the history
* fix plugin: special slice for ernie
  • Loading branch information
Wangzheee authored Feb 24, 2022
1 parent ce207c3 commit 1255e7d
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.cu
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,12 @@ nvinfer1::DataType SpecialSlicePluginDynamic::getOutputDataType(
template <typename T>
__global__ void SpecialSliceKernel(const T* slice_input,
const int32_t* cu_seqlens, T* output) {
const int hidden = blockDim.x * gridDim.y;
const int batch = blockIdx.x;
const int local_idx = blockIdx.y * blockDim.y + threadIdx.x;
const int hidden = blockDim.x * gridDim.x;
const int hidden_id = blockIdx.x * blockDim.x + threadIdx.x;
const int batch_id = blockIdx.y;

output[batch * hidden + local_idx] =
slice_input[cu_seqlens[batch] * hidden + local_idx];
output[batch_id * hidden + hidden_id] =
slice_input[cu_seqlens[batch_id] * hidden + hidden_id];
}

int SpecialSlicePluginDynamic::enqueue(
Expand All @@ -137,15 +137,16 @@ int SpecialSlicePluginDynamic::enqueue(
"hidden should be multiple of 128."));

constexpr int num_threads = 128;
const dim3 blocks(out_dims.d[0], hidden / num_threads);

const half* slice_input = static_cast<const half*>(inputs[0]);
const int32_t* cu_seqlens = static_cast<const int32_t*>(inputs[1]);
half* output = static_cast<half*>(outputs[0]);

SpecialSliceKernel<<<blocks, num_threads, 0, stream>>>(slice_input,
cu_seqlens, output);
const int32_t num_blocks_x = hidden / num_threads;
const int32_t num_blocks_y = out_dims.d[0]; // batchs
const dim3 num_blocks(num_blocks_x, num_blocks_y); // blocks

SpecialSliceKernel<<<num_blocks, num_threads, 0, stream>>>(
slice_input, cu_seqlens, output);
return cudaGetLastError() != cudaSuccess;
}

Expand Down

0 comments on commit 1255e7d

Please sign in to comment.