[PaddlePaddle Hackathon 3 No.31] Optimize the GPU compute performance of the dist op for Paddle #44946

Merged: 16 commits, Sep 2, 2022
4 changes: 0 additions & 4 deletions paddle/phi/kernels/dist_kernel.cc
@@ -34,7 +34,3 @@ void DistKernel(const Context& dev_ctx,
} // namespace phi

PD_REGISTER_KERNEL(dist, CPU, ALL_LAYOUT, phi::DistKernel, float, double) {}

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(dist, GPU, ALL_LAYOUT, phi::DistKernel, float, double) {}
Contributor:
Overall, the optimized version differs from the original in two main ways:

  1. When x and y have the same shape, the subtraction is effectively fused with the subsequent reduce. In theory the fusion should clearly give a performance benefit, so I think this part is worth keeping as a separate path;
  2. When x and y have different shapes, the logic is basically the same as the original pnorm (dist is just the p-norm of x - y; see the formula sketch after this list), but it still uses some rewritten reduce kernels, and whether this part actually improves performance needs to be verified by benchmarks. If a gain is confirmed, I think this optimization could be applied directly to pnorm as well (pnorm is also part of this round of operator optimizations), which would avoid some duplicated work; if the gain is not significant, then to keep the code simple I would keep the previous approach of calling pnorm directly.
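For reference, dist(x, y, p) is the p-norm of the elementwise difference; the special cases below correspond to the separate p == 0 and p == ±INFINITY branches in the kernel (standard definition, consistent with the code in this PR):

$$
\operatorname{dist}(x, y, p) = \lVert x - y \rVert_p =
\begin{cases}
\#\{\, i : x_i \neq y_i \,\}, & p = 0,\\
\max_i \lvert x_i - y_i \rvert, & p = +\infty,\\
\min_i \lvert x_i - y_i \rvert, & p = -\infty,\\
\left(\sum_i \lvert x_i - y_i \rvert^{p}\right)^{1/p}, & \text{otherwise.}
\end{cases}
$$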

Contributor Author:
@ZzSean Following the suggestion, pnorm is now used for the different-shape case; in my tests pnorm is indeed comparable or slightly better there. For the same-shape case, the new approach does give a measurable improvement.

#endif
176 changes: 176 additions & 0 deletions paddle/phi/kernels/gpu/dist_kernel.cu
@@ -0,0 +1,176 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/dist_kernel.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/elementwise_subtract_kernel.h"
#include "paddle/phi/kernels/funcs/math_cuda_utils.h"
#include "paddle/phi/kernels/gpu/reduce.h"
#include "paddle/phi/kernels/p_norm_kernel.h"

namespace phi {

#define FULL_MASK 0xffffffff

template <typename T>
struct ZeroOrderFunctor {
public:
__device__ T operator()(const T& x, const T& y) const {
return static_cast<T>((x - y) != 0);
}
};

template <typename T>
struct OtherOrderFunctor {
explicit OtherOrderFunctor(const T& p_order) : p_order_(p_order) {}
__device__ T operator()(const T& x, const T& y) const {
return static_cast<T>(pow(abs(x - y), p_order_));
}

private:
T p_order_;
};

template <typename T>
struct PowFunctor {
explicit PowFunctor(const T& p_order) : p_order_(p_order) {}
HOSTDEVICE inline T operator()(const T x) const {
return static_cast<T>(pow(x, p_order_));
}
T p_order_;
};

template <typename T, typename Functor>
__global__ void ReduceSumWithSubtract(
const T* x, const T* y, T* out, int64_t N, Functor func) {
T sum_val = 0;
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
i += blockDim.x * gridDim.x) {
sum_val += func(x[i], y[i]);
}

__syncthreads();
sum_val = phi::funcs::blockReduceSum<T>(sum_val, FULL_MASK);
if (threadIdx.x == 0) {
out[blockIdx.x] = sum_val;
}
}

template <typename T>
__global__ void ReduceMaxWithSubtract(const T* x,
const T* y,
T* out,
int64_t N) {
T max_val = -1e10f;
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
i += blockDim.x * gridDim.x) {
max_val = max(max_val, abs(x[i] - y[i]));
}

__syncthreads();
max_val = phi::funcs::blockReduceMax<T>(max_val, FULL_MASK);
if (threadIdx.x == 0) {
out[blockIdx.x] = max_val;
}
}

template <typename T>
__global__ void ReduceMinWithSubtract(const T* x,
const T* y,
T* out,
int64_t N) {
T min_val = 1e10f;
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
i += blockDim.x * gridDim.x) {
min_val = min(min_val, abs(x[i] - y[i]));
}

__syncthreads();
min_val = phi::funcs::blockReduceMin(min_val, FULL_MASK);
if (threadIdx.x == 0) {
out[blockIdx.x] = min_val;
}
}

template <typename T, typename Context>
void DistKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
float p,
DenseTensor* out) {
DenseTensor intermediate;
const T* x_ptr = x.data<T>();
const T* y_ptr = y.data<T>();
T* o_ptr = dev_ctx.template Alloc<T>(out);
auto stream = dev_ctx.stream();

auto xdim = x.dims();
if (xdim == y.dims()) { // same shape
auto n = x.numel();
auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, n);
intermediate.Resize(phi::make_ddim({config.block_per_grid.x}));
T* i_ptr = dev_ctx.template Alloc<T>(&intermediate);
Contributor:

The code at lines 258-261 is reused at lines 300-303; it could be hoisted out of the if statement.

Contributor Author:

Although this code is functionally reused, the inputs differ: the former takes the original input x and uses its numel, while the latter operates on the tensor produced by the subtract, whose numel may be different after broadcasting. The if branch here exists mainly to optimize the case where the subtract needs no broadcasting; if the code were hoisted out of the if, a subtract would have to be computed every time, which would incur unnecessary overhead. If I have misunderstood, please correct me. @JamesLim-sy
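To illustrate the point about numel changing under broadcasting, here is a small self-contained C++ sketch (illustrative only; BroadcastShape and Numel are hypothetical helpers, not Paddle APIs):

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <stdexcept>
#include <vector>

// Right-align the two shapes and take the larger extent per axis;
// an extent of 1 broadcasts against any other extent.
std::vector<int64_t> BroadcastShape(std::vector<int64_t> a,
                                    std::vector<int64_t> b) {
  if (a.size() < b.size()) a.insert(a.begin(), b.size() - a.size(), 1);
  if (b.size() < a.size()) b.insert(b.begin(), a.size() - b.size(), 1);
  std::vector<int64_t> out(a.size());
  for (size_t i = 0; i < a.size(); ++i) {
    if (a[i] != b[i] && a[i] != 1 && b[i] != 1)
      throw std::invalid_argument("shapes are not broadcastable");
    out[i] = std::max(a[i], b[i]);
  }
  return out;
}

int64_t Numel(const std::vector<int64_t>& dims) {
  int64_t n = 1;
  for (int64_t d : dims) n *= d;
  return n;
}

int main() {
  // x: [4, 1, 8], y: [3, 8]  ->  x - y broadcasts to [4, 3, 8]
  std::vector<int64_t> x_dims = {4, 1, 8}, y_dims = {3, 8};
  std::vector<int64_t> sub_dims = BroadcastShape(x_dims, y_dims);
  std::printf("x numel = %lld, (x - y) numel = %lld\n",
              static_cast<long long>(Numel(x_dims)),     // 32
              static_cast<long long>(Numel(sub_dims)));  // 96
  return 0;
}
```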


std::vector<int64_t> axis_dims = {static_cast<int64_t>(-1)};
std::vector<int> reduce_axis =
funcs::details::GetReduceDim(axis_dims, xdim.size(), true);

if (p == 0) {
ReduceSumWithSubtract<T>
<<<config.block_per_grid.x, config.thread_per_block.x, 0, stream>>>(
x_ptr, y_ptr, i_ptr, n, ZeroOrderFunctor<T>());
phi::funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
dev_ctx, intermediate, out, kps::IdentityFunctor<T>(), reduce_axis);

} else if (p == INFINITY) {
ReduceMaxWithSubtract<T>
<<<config.block_per_grid.x, config.thread_per_block.x, 0, stream>>>(
x_ptr, y_ptr, i_ptr, n);
phi::funcs::ReduceKernel<T, T, kps::MaxFunctor, kps::IdentityFunctor<T>>(
dev_ctx, intermediate, out, kps::IdentityFunctor<T>(), reduce_axis);

} else if (p == -INFINITY) {
ReduceMinWithSubtract<T>
<<<config.block_per_grid.x, config.thread_per_block.x, 0, stream>>>(
x_ptr, y_ptr, i_ptr, n);

phi::funcs::ReduceKernel<T, T, kps::MinFunctor, kps::IdentityFunctor<T>>(
dev_ctx, intermediate, out, kps::IdentityFunctor<T>(), reduce_axis);

} else {
T p_order = static_cast<T>(p);
ReduceSumWithSubtract<T>
<<<config.block_per_grid.x, config.thread_per_block.x, 0, stream>>>(
x_ptr, y_ptr, i_ptr, n, OtherOrderFunctor<T>(p_order));
phi::funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
dev_ctx, intermediate, out, kps::IdentityFunctor<T>(), reduce_axis);

const DenseTensor* tmp_norm = out;
std::vector<const DenseTensor*> ins = {tmp_norm};
std::vector<DenseTensor*> outs = {out};
T p_order_ = static_cast<T>(1. / p_order);
phi::funcs::ElementwiseKernel<T>(
dev_ctx, ins, &outs, PowFunctor<T>(p_order_));
}

} else {
auto t = Subtract<T, Context>(dev_ctx, x, y);
PNormKernel<T, Context>(dev_ctx, t, p, -1, 1e-12, false, true, out);
}
}

} // namespace phi

PD_REGISTER_KERNEL(dist, GPU, ALL_LAYOUT, phi::DistKernel, float, double) {}
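For readers who want to see the fused pattern used by ReduceSumWithSubtract in isolation, below is a minimal standalone sketch (not part of the PR): a grid-stride loop accumulates |x[i] - y[i]|^p per thread, then a warp-shuffle block reduction emits one partial sum per block. WarpReduceSum and BlockReduceSum here are simplified stand-ins for phi::funcs::blockReduceSum, and the final combination of per-block partials is done on the host instead of phi::funcs::ReduceKernel.

```cuda
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

#define FULL_MASK 0xffffffff

__device__ float WarpReduceSum(float val) {
  for (int offset = warpSize / 2; offset > 0; offset /= 2)
    val += __shfl_down_sync(FULL_MASK, val, offset);
  return val;
}

__device__ float BlockReduceSum(float val) {
  __shared__ float shared[32];             // one partial per warp (<= 1024 threads)
  int lane = threadIdx.x % warpSize;
  int wid = threadIdx.x / warpSize;
  val = WarpReduceSum(val);                // 1) reduce inside each warp
  if (lane == 0) shared[wid] = val;        // 2) stash per-warp partials
  __syncthreads();
  int num_warps = (blockDim.x + warpSize - 1) / warpSize;
  val = (threadIdx.x < num_warps) ? shared[lane] : 0.f;
  if (wid == 0) val = WarpReduceSum(val);  // 3) first warp combines partials
  return val;
}

// One pass over x and y: fuse the subtract with the reduction, writing one
// partial sum per block instead of materializing x - y.
__global__ void FusedSubPowSum(const float* x, const float* y, float* partial,
                               int64_t n, float p_order) {
  float sum_val = 0.f;
  for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += int64_t(blockDim.x) * gridDim.x)
    sum_val += powf(fabsf(x[i] - y[i]), p_order);
  sum_val = BlockReduceSum(sum_val);
  if (threadIdx.x == 0) partial[blockIdx.x] = sum_val;
}

int main() {
  const int64_t n = 1 << 20;
  const int threads = 256, blocks = 128;
  std::vector<float> hx(n, 1.5f), hy(n, 1.0f), hp(blocks);
  float *dx, *dy, *dp;
  cudaMalloc(&dx, n * sizeof(float));
  cudaMalloc(&dy, n * sizeof(float));
  cudaMalloc(&dp, blocks * sizeof(float));
  cudaMemcpy(dx, hx.data(), n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(dy, hy.data(), n * sizeof(float), cudaMemcpyHostToDevice);
  FusedSubPowSum<<<blocks, threads>>>(dx, dy, dp, n, 2.0f);
  cudaMemcpy(hp.data(), dp, blocks * sizeof(float), cudaMemcpyDeviceToHost);
  double sum = 0;
  for (float v : hp) sum += v;               // finish the reduction on the host
  std::printf("sum |x-y|^2 = %f (expect %f)\n", sum, 0.25 * n);
  cudaFree(dx); cudaFree(dy); cudaFree(dp);
  return 0;
}
```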