[Paddle Inference] optimize transfer_layout kernel in fp16 model #48692

Merged: 1 commit merged on Dec 7, 2022
81 changes: 76 additions & 5 deletions paddle/phi/kernels/funcs/math_function.cu
@@ -27,11 +27,83 @@ limitations under the License. */
namespace phi {
namespace funcs {

// The following part of the code refers to NVIDIA-cutlass
// https://github.com/NVIDIA/cutlass/blob/master/tools/util/include/cutlass/util/device_nchw_to_nhwc.h
// Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights
// reserved. SPDX-License-Identifier: BSD-3-Clause
template <typename T>
__global__ void batch_transpose_kernel(
T* output, const T* input, const int batch, const int M, const int N) {
const int num = M * N;
// "+1" to avoid smem bank conflict
__shared__ T shbuf[32 * (32 + 1)];
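// The padded row stride of 33 staggers the column-strided accesses to shbuf
// below across shared-memory banks instead of serializing them on one bank.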
const int32_t tid = threadIdx.y * blockDim.x + threadIdx.x;
const int32_t wid = tid / 32;
const int32_t lid = tid % 32;
const int32_t batch_i = blockIdx.z;
const int32_t mi0 = blockIdx.y * 32;
Comment from the PR author: mi0 means that row 0 of the tile handled by this thread block is global row mi0. Note that mi0 indexes rows, so it must be computed as blockIdx.y * 32; thread blocks are dispatched along the x dimension first, so the blocks along x should process all the columns of those rows.

const int32_t ni0 = blockIdx.x * 32;

const size_t input_idx = batch_i * num + (mi0 + wid) * N + ni0;
const T* A = input + input_idx;
if (ni0 + lid < N) {
const int lid_x_33 = lid * 33;
if ((mi0 + 32) <= M) {
int mi = wid; // between 0 and 7
#pragma unroll
for (int mLoopIdx = 0; mLoopIdx < 4; mLoopIdx++) {
shbuf[lid_x_33 + mi] = A[lid];
A = &A[8 * N];
mi += 8;
}
} else {
for (int mi = wid; mi < 32; mi += 8) {
if ((mi + mi0) < M) {
shbuf[lid_x_33 + mi] = A[lid];
}
A = &A[8 * N];
}
}
}
__syncthreads();

const int32_t miOut = mi0 + lid;
output = &output[batch_i * num + miOut];
if (miOut < M) {
if (ni0 + 32 < N) {
int nI = wid;
#pragma unroll
for (int nLoopIdx = 0; nLoopIdx < 4; ++nLoopIdx) {
output[(ni0 + nI) * M] = shbuf[(nI)*33 + lid];
nI += 8;
}
} else {
for (int nI = wid; nI < 32; nI += 8) {
if (ni0 + nI < N) {
output[(ni0 + nI) * M] = shbuf[(nI)*33 + lid];
}
}
}
}
}

template <typename T>
void BatchTranspose(T* output, const T* input, int batch, int m, int n) {
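// Each thread block transposes one 32 x 32 tile of a single [m, n] matrix;
// with a 32 x 8 thread block, each thread stages up to four elements
// through the padded shared-memory buffer.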
dim3 grid((n + 31) / 32, (m + 31) / 32, batch);
dim3 block(32, 8);
batch_transpose_kernel<<<grid, block>>>(output, input, batch, m, n);
}

using float16 = phi::dtype::float16;
zhoutianzi666 marked this conversation as resolved.
Show resolved Hide resolved
using bfloat16 = phi::dtype::bfloat16;

template void BatchTranspose(
    float16* output, const float16* input, int batch, int m, int n);
template void BatchTranspose(
    float* output, const float* input, int batch, int m, int n);

template struct SetConstant<phi::GPUContext, float16>;
template struct SetConstant<phi::GPUContext, bfloat16>;
template struct SetConstant<phi::GPUContext, float>;
template struct SetConstant<phi::GPUContext, double>;
template struct SetConstant<phi::GPUContext, uint8_t>;
@@ -42,10 +114,9 @@ template struct SetConstant<phi::GPUContext, bool>;
template struct SetConstant<phi::GPUContext, phi::dtype::complex<float>>;
template struct SetConstant<phi::GPUContext, phi::dtype::complex<double>>;

template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, float16>;
template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext,
                            bfloat16>;
template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, float>;
template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, double>;
template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, uint8_t>;
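For readers outside the Paddle codebase, a minimal host-side sketch of how the new helper is meant to be called (the wrapper function name, the device buffers, and the explicit synchronization are illustrative assumptions, not part of this PR):

#include "paddle/phi/kernels/funcs/math_function.h"
#include <cuda_runtime.h>

// NCHW -> NHWC: view the tensor as `n` matrices of shape [c, h * w] and
// transpose each one. Assumes d_nchw and d_nhwc are device buffers holding
// n * c * h * w floats.
void NchwToNhwcSketch(const float* d_nchw, float* d_nhwc,
                      int n, int c, int h, int w) {
  phi::funcs::BatchTranspose(d_nhwc, d_nchw, n, c, h * w);
  cudaDeviceSynchronize();  // the launcher itself does not synchronize
}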
3 changes: 3 additions & 0 deletions paddle/phi/kernels/funcs/math_function.h
@@ -29,6 +29,9 @@ limitations under the License. */
namespace phi {
namespace funcs {

template <typename T>
void BatchTranspose(T* output, const T* input, int batch, int m, int n);

template <typename DeviceContext, typename T>
struct TransposeNormal {
// for dims >= 7 situation
26 changes: 26 additions & 0 deletions paddle/phi/kernels/transfer_layout_kernel.cc
@@ -70,6 +70,32 @@ void TransferLayoutGeneral(const Context& dev_ctx,
out->Resize(phi::make_ddim(dst_dim));
dev_ctx.Alloc(out, x.dtype());

// In GPU fp16 models, conv2d_fusion_layout_transfer_pass inserts many
// transfer_layout ops, so we optimize this kernel on the GPU.
if (std::is_same<Context, phi::GPUContext>::value) {
std::vector<int> axis_nchw_nhwc = {0, 2, 3, 1};
std::vector<int> axis_nhwc_nchw = {0, 3, 1, 2};
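// Both supported permutations reduce to a batched 2-D transpose:
// NCHW -> NHWC transposes `batch` matrices of shape [C, H * W], while
// NHWC -> NCHW transposes matrices of shape [H * W, C].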
const int batch = src_dim[0];
int row_len = src_dim[1];
int col_len = src_dim[2] * src_dim[3];
if (axis == axis_nhwc_nchw) {
row_len = src_dim[1] * src_dim[2];
col_len = src_dim[3];
}
if (x.dtype() == phi::DataType::FLOAT16) {
funcs::BatchTranspose(out->data<phi::dtype::float16>(),
x.data<phi::dtype::float16>(),
batch,
row_len,
col_len);
return;
} else if (x.dtype() == phi::DataType::FLOAT32) {
funcs::BatchTranspose(
out->data<float>(), x.data<float>(), batch, row_len, col_len);
return;
}
}

PD_VISIT_ALL_TYPES(x.dtype(), "CastDataLayout", ([&] {
CastDataLayout<data_t, Context>(dev_ctx, x, axis, out);
}));
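For intuition, a small CPU-only sketch of the shape mapping used in the fast path above; the concrete dimensions and the naive reference loops are illustrative assumptions, not code from this PR:

#include <cassert>
#include <vector>

// Naive NCHW -> NHWC permutation, used only as a reference.
std::vector<float> NchwToNhwcRef(const std::vector<float>& x,
                                 int n, int c, int h, int w) {
  std::vector<float> y(x.size());
  for (int in = 0; in < n; ++in)
    for (int ic = 0; ic < c; ++ic)
      for (int ih = 0; ih < h; ++ih)
        for (int iw = 0; iw < w; ++iw)
          y[((in * h + ih) * w + iw) * c + ic] =
              x[((in * c + ic) * h + ih) * w + iw];
  return y;
}

// Batched row-major transpose of `batch` matrices of shape [m, n],
// i.e. what BatchTranspose computes on the GPU.
std::vector<float> BatchTransposeRef(const std::vector<float>& x,
                                     int batch, int m, int n) {
  std::vector<float> y(x.size());
  for (int b = 0; b < batch; ++b)
    for (int i = 0; i < m; ++i)
      for (int j = 0; j < n; ++j)
        y[b * m * n + j * m + i] = x[b * m * n + i * n + j];
  return y;
}

int main() {
  const int n = 2, c = 3, h = 4, w = 5;
  std::vector<float> x(n * c * h * w);
  for (size_t i = 0; i < x.size(); ++i) x[i] = static_cast<float>(i);
  // axis = {0, 2, 3, 1}: batch = n, row_len = c, col_len = h * w.
  assert(NchwToNhwcRef(x, n, c, h, w) == BatchTransposeRef(x, n, c, h * w));
  return 0;
}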