[PaddlePaddle Hackathon 3 No.31] Optimize the GPU performance of the dist op for Paddle #44946
@@ -0,0 +1,176 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/dist_kernel.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/elementwise_subtract_kernel.h"
#include "paddle/phi/kernels/funcs/math_cuda_utils.h"
#include "paddle/phi/kernels/gpu/reduce.h"
#include "paddle/phi/kernels/p_norm_kernel.h"

namespace phi {

#define FULL_MASK 0xffffffff

template <typename T>
struct ZeroOrderFunctor {
 public:
  __device__ T operator()(const T& x, const T& y) const {
    return static_cast<T>((x - y) != 0);
  }
};

template <typename T>
struct OtherOrderFunctor {
  explicit OtherOrderFunctor(const T& p_order) : p_order_(p_order) {}
  __device__ T operator()(const T& x, const T& y) const {
    return static_cast<T>(pow(abs(x - y), p_order_));
  }

 private:
  T p_order_;
};

template <typename T>
struct PowFunctor {
  explicit PowFunctor(const T& p_order) : p_order_(p_order) {}
  HOSTDEVICE inline T operator()(const T x) const {
    return static_cast<T>(pow(x, p_order_));
  }
  T p_order_;
};

template <typename T, typename Functor>
__global__ void ReduceSumWithSubtract(
    const T* x, const T* y, T* out, int64_t N, Functor func) {
  T sum_val = 0;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
       i += blockDim.x * gridDim.x) {
    sum_val += func(x[i], y[i]);
  }

  __syncthreads();
  sum_val = phi::funcs::blockReduceSum<T>(sum_val, FULL_MASK);
  if (threadIdx.x == 0) {
    out[blockIdx.x] = sum_val;
  }
}

template <typename T>
__global__ void ReduceMaxWithSubtract(const T* x,
                                      const T* y,
                                      T* out,
                                      int64_t N) {
  T max_val = -1e10f;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
       i += blockDim.x * gridDim.x) {
    max_val = max(max_val, abs(x[i] - y[i]));
  }

  __syncthreads();
  max_val = phi::funcs::blockReduceMax<T>(max_val, FULL_MASK);
  if (threadIdx.x == 0) {
    out[blockIdx.x] = max_val;
  }
}

template <typename T>
__global__ void ReduceMinWithSubtract(const T* x,
                                      const T* y,
                                      T* out,
                                      int64_t N) {
  T min_val = 1e10f;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
       i += blockDim.x * gridDim.x) {
    min_val = min(min_val, abs(x[i] - y[i]));
  }

  __syncthreads();
  min_val = phi::funcs::blockReduceMin(min_val, FULL_MASK);
  if (threadIdx.x == 0) {
    out[blockIdx.x] = min_val;
  }
}

template <typename T, typename Context>
void DistKernel(const Context& dev_ctx,
                const DenseTensor& x,
                const DenseTensor& y,
                float p,
                DenseTensor* out) {
  DenseTensor intermediate;
  const T* x_ptr = x.data<T>();
  const T* y_ptr = y.data<T>();
  T* o_ptr = dev_ctx.template Alloc<T>(out);
  auto stream = dev_ctx.stream();

  auto xdim = x.dims();
  if (xdim == y.dims()) {  // same shape
    auto n = x.numel();
    auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, n);
    intermediate.Resize(phi::make_ddim({config.block_per_grid.x}));
    T* i_ptr = dev_ctx.template Alloc<T>(&intermediate);
Review comment: The code at line 258-261 is reused at line 300-303; it could be hoisted outside the if statement.
Reply: Although this code is functionally reused, the inputs differ. The former works on the original input x and its numel, while the latter works on the tensor produced by subtract, whose numel may change after broadcasting. The if branch exists mainly to optimize the case where subtract does not require broadcasting; hoisting the code outside the if would mean computing subtract every time, which adds unnecessary overhead. Please correct me if I have misunderstood. @JamesLim-sy
    std::vector<int64_t> axis_dims = {static_cast<int64_t>(-1)};
    std::vector<int> reduce_axis =
        funcs::details::GetReduceDim(axis_dims, xdim.size(), true);

    if (p == 0) {
      ReduceSumWithSubtract<T>
          <<<config.block_per_grid.x, config.thread_per_block.x, 0, stream>>>(
              x_ptr, y_ptr, i_ptr, n, ZeroOrderFunctor<T>());
      phi::funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
          dev_ctx, intermediate, out, kps::IdentityFunctor<T>(), reduce_axis);

    } else if (p == INFINITY) {
      ReduceMaxWithSubtract<T>
          <<<config.block_per_grid.x, config.thread_per_block.x, 0, stream>>>(
              x_ptr, y_ptr, i_ptr, n);
      phi::funcs::ReduceKernel<T, T, kps::MaxFunctor, kps::IdentityFunctor<T>>(
          dev_ctx, intermediate, out, kps::IdentityFunctor<T>(), reduce_axis);

    } else if (p == -INFINITY) {
      ReduceMinWithSubtract<T>
          <<<config.block_per_grid.x, config.thread_per_block.x, 0, stream>>>(
              x_ptr, y_ptr, i_ptr, n);

      phi::funcs::ReduceKernel<T, T, kps::MinFunctor, kps::IdentityFunctor<T>>(
          dev_ctx, intermediate, out, kps::IdentityFunctor<T>(), reduce_axis);

    } else {
      T p_order = static_cast<T>(p);
      ReduceSumWithSubtract<T>
          <<<config.block_per_grid.x, config.thread_per_block.x, 0, stream>>>(
              x_ptr, y_ptr, i_ptr, n, OtherOrderFunctor<T>(p_order));
      phi::funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
          dev_ctx, intermediate, out, kps::IdentityFunctor<T>(), reduce_axis);

      const DenseTensor* tmp_norm = out;
      std::vector<const DenseTensor*> ins = {tmp_norm};
      std::vector<DenseTensor*> outs = {out};
      T p_order_ = static_cast<T>(1. / p_order);
      phi::funcs::ElementwiseKernel<T>(
          dev_ctx, ins, &outs, PowFunctor<T>(p_order_));
    }

  } else {
    auto t = Subtract<T, Context>(dev_ctx, x, y);
    PNormKernel<T, Context>(dev_ctx, t, p, -1, 1e-12, false, true, out);
  }
}

}  // namespace phi

PD_REGISTER_KERNEL(dist, GPU, ALL_LAYOUT, phi::DistKernel, float, double) {}
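For reference, a minimal usage sketch of the Python-level op this kernel backs. It assumes the public paddle.dist API (which dispatches to phi::DistKernel on GPU); the shapes and p values below are illustrative only, not taken from this PR.

import paddle

# Same-shape inputs: on GPU these take the fused subtract-and-reduce path above.
x = paddle.rand([4, 1024], dtype="float32")
y = paddle.rand([4, 1024], dtype="float32")

print(paddle.dist(x, y, p=0))              # number of elements where x != y
print(paddle.dist(x, y, p=2))              # Euclidean distance
print(paddle.dist(x, y, p=float("inf")))   # max |x - y|
print(paddle.dist(x, y, p=float("-inf")))  # min |x - y|

# Inputs with different shapes need broadcasting, so DistKernel falls back
# to Subtract + PNormKernel.
z = paddle.rand([1, 1024], dtype="float32")
print(paddle.dist(x, z, p=2))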
Overall, the main differences between the optimized and the original implementation are the following:
@ZzSean Following the suggestion, p_norm is now used when the shapes differ; testing confirms that p_norm is comparable or slightly better in that case. For identical shapes, the new approach does give a measurable improvement.
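A rough way to reproduce the comparison described above (a hypothetical micro-benchmark, not part of the PR; tensor shapes, p values, and iteration counts are arbitrary):

import time
import paddle

def bench(x, y, p, iters=100):
    # Warm up, then time paddle.dist on the GPU with explicit synchronization.
    for _ in range(10):
        paddle.dist(x, y, p)
    paddle.device.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        paddle.dist(x, y, p)
    paddle.device.cuda.synchronize()
    return (time.time() - start) / iters

x = paddle.rand([64, 2048], dtype="float32")
y_same = paddle.rand([64, 2048], dtype="float32")  # same shape: fused kernels
y_bcast = paddle.rand([1, 2048], dtype="float32")  # broadcast: Subtract + PNormKernel

for p in (0.0, 2.0, float("inf")):
    print(p, bench(x, y_same, p), bench(x, y_bcast, p))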