Add fused elemwise gelu and optimize performance #33480

Merged
179 changes: 103 additions & 76 deletions paddle/fluid/operators/elementwise/elementwise_op_function.h
@@ -57,6 +57,10 @@ constexpr int ELEMWISE_MAX_BLOCK_DIM = 1024;
*mod = dividend_copy % divisor; \
} while (0)

+#define DIVUP(x, y) (((x) + (y)-1) / (y))
+
+#define ROUNDUP(x, y) (DIVUP((x), (y)) * (y))

namespace paddle {
namespace operators {
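Editor's note: the two new macros are ceiling division and round-up-to-a-multiple; `full_w = ROUNDUP(w, BLOCK_X)` in the gradient kernel below pads the logical width so every thread of a block runs the same number of outer-loop iterations. A quick compile-time sanity check (illustrative only, not part of the diff):

```cpp
#define DIVUP(x, y) (((x) + (y)-1) / (y))
#define ROUNDUP(x, y) (DIVUP((x), (y)) * (y))

static_assert(DIVUP(10, 4) == 3, "ceil(10 / 4) == 3");
static_assert(ROUNDUP(10, 4) == 12, "next multiple of 4 at or above 10");
static_assert(ROUNDUP(12, 4) == 12, "exact multiples are unchanged");
```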

@@ -2152,10 +2156,10 @@ template <typename T, typename CompoundFunctor, bool BcastY,
static __global__ void FusedElemwiseAndActBroadcast1CUDAKernel(
const T *x, const T *y, int h, int w, CompoundFunctor compound_functor,
T *out, T *intermediate_out) {
-  int j = blockIdx.x;
-  int i = threadIdx.x;
+  int i = blockIdx.x;
+  int j = threadIdx.x;

-  while (i < h) {
+  while (j < w) {
    int offset = i * w + j;

    T y_val = BcastY ? y[j] : y[offset];
@@ -2181,7 +2185,7 @@ static __global__ void FusedElemwiseAndActBroadcast1CUDAKernel(
      out[offset] = compound_functor.GetOut(x_val, y_val);
    }

-    i += ELEMWISE_MAX_BLOCK_DIM;
+    j += ELEMWISE_MAX_BLOCK_DIM;
  }
}

@@ -2192,8 +2196,8 @@ static void FusedElemwiseAndActBroadcast1CUDA(gpuStream_t stream, const T *x,
                                              CompoundFunctor compound_functor,
                                              int h, int w, T *out,
                                              T *intermediate_out) {
-  int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h);
-  int gird_size = w;
+  int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, w);
+  int gird_size = h;
  FusedElemwiseAndActBroadcast1CUDAKernel<
      T, CompoundFunctor, BcastY, KeepIntermediateOut,
      SameShapeOfIntermediateOutAndOut><<<gird_size, block_size, 0, stream>>>(
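Editor's note: this index swap is the heart of the forward-kernel optimization. Before, `threadIdx.x` walked the h dimension, so the threads of a warp touched addresses w elements apart; now `threadIdx.x` walks the w dimension, and consecutive threads read and write consecutive addresses, which the hardware coalesces into far fewer memory transactions. A minimal sketch of the new mapping (simplified to float addition, not the PR's templated kernel):

```cuda
// One block per row (i); threads stride across columns (j), so a warp's
// loads of x[offset] and stores to out[offset] hit adjacent words.
__global__ void Broadcast1Sketch(const float* x, const float* y, float* out,
                                 int h, int w) {
  int i = blockIdx.x;
  for (int j = threadIdx.x; j < w; j += blockDim.x) {
    int offset = i * w + j;
    out[offset] = x[offset] + y[j];  // y is broadcast along each row
  }
}
```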
@@ -2581,106 +2585,129 @@ static __global__ void FusedElemwiseAndActGradBroadcast1CUDAKernel(
    const T *x, const T *y, const T *intermediate_out, const T *out,
    const T *dout, int h, int w, DX_OP dx_op, DY_OP dy_op,
    DIntermediate_OP dintermediate_op, T *dx, T *dy, T *d_intermediate) {
-  int j = blockIdx.x;
-  int i = threadIdx.x;
-  int tid = threadIdx.x;
-  T val(0), inter_val(0);
-  int64_t tmp_out_idx, x_idx, y_idx;
+  __shared__ T sdata[BLOCK_Y][BLOCK_X];
+  size_t idx = threadIdx.x + BLOCK_X * blockIdx.x;
+  size_t width_stride = gridDim.x * BLOCK_X;
+
+  size_t full_w = ROUNDUP(w, BLOCK_X);

  T zero = static_cast<T>(0);

-  do {
-    int offset = i * w + j;
+  for (size_t j = idx; j < full_w; j += width_stride) {
+    T val(0), inter_val(0);
+    if (j < w) {
+      for (size_t i = threadIdx.y; i < h; i += BLOCK_Y) {
+        size_t offset = i * w + j;

-    tmp_out_idx = BcastY ? j : offset;
-    y_idx = BcastY ? j : offset;
-    x_idx = BcastY ? offset : j;
-    T x_val = (x == nullptr) ? zero : x[x_idx];
-    T y_val = (y == nullptr) ? zero : y[y_idx];
+        size_t tmp_out_idx = BcastY ? j : offset;
+        size_t y_idx = BcastY ? j : offset;
+        size_t x_idx = BcastY ? offset : j;
+        T x_val = (x == nullptr) ? zero : x[x_idx];
+        T y_val = (y == nullptr) ? zero : y[y_idx];

-    if (SameShapeOfIntermediateOutAndOut) {
-      tmp_out_idx = offset;
-    }
+        if (SameShapeOfIntermediateOutAndOut) {
+          tmp_out_idx = offset;
+        }

-    if (dx != nullptr) {
-      T tmp = UseIntermediateOut
+        if (dx != nullptr) {
+          T tmp =
+              UseIntermediateOut
                  ? dx_op.UseIntermediateOut(x_val, y_val,
                                             intermediate_out[tmp_out_idx],
                                             out[offset], dout[offset])
                  : dx_op.Recompute(x_val, y_val, out[offset], dout[offset]);

-      if (BcastY) {
-        dx[x_idx] = tmp;
-      } else {
-        val += tmp;
-      }
-    }
-    if (dy != nullptr) {
-      T tmp = UseIntermediateOut
+          if (BcastY) {
+            dx[x_idx] = tmp;
+          } else {
+            val += tmp;
+          }
+        }
+        if (dy != nullptr) {
+          T tmp =
+              UseIntermediateOut
                  ? dy_op.UseIntermediateOut(x_val, y_val,
                                             intermediate_out[tmp_out_idx],
                                             out[offset], dout[offset])
                  : dy_op.Recompute(x_val, y_val, out[offset], dout[offset]);
-      if (BcastY) {
-        val += tmp;
-      } else {
-        dy[y_idx] = tmp;
-      }
-    }
-    if (d_intermediate != nullptr) {
-      T tmp = UseIntermediateOut
-                  ? dintermediate_op.UseIntermediateOut(
-                        y[y_idx], intermediate_out[tmp_out_idx], out[offset],
-                        dout[offset])
-                  : dintermediate_op.Recompute(x_val, y_val, out[offset],
-                                               dout[offset]);
-      if (SameShapeOfIntermediateOutAndOut) {
-        d_intermediate[tmp_out_idx] = tmp;
-      } else {
-        inter_val += tmp;
+          if (BcastY) {
+            val += tmp;
+          } else {
+            dy[y_idx] = tmp;
+          }
+        }
+        if (d_intermediate != nullptr) {
+          T tmp = UseIntermediateOut
+                      ? dintermediate_op.UseIntermediateOut(
+                            y[y_idx], intermediate_out[tmp_out_idx],
+                            out[offset], dout[offset])
+                      : dintermediate_op.Recompute(x_val, y_val, out[offset],
+                                                   dout[offset]);
+          if (SameShapeOfIntermediateOutAndOut) {
+            d_intermediate[tmp_out_idx] = tmp;
+          } else {
+            inter_val += tmp;
+          }
+        }
      }
    }

-    i += ELEMWISE_MAX_BLOCK_DIM;
-  } while (i < h);
+    // transpose, for ReduceSum with warp
+    sdata[threadIdx.y][threadIdx.x] = val;
+    __syncthreads();
+    val = sdata[threadIdx.x][threadIdx.y];
+#pragma unroll
+    for (int i = BLOCK_X >> 1; i > 0; i >>= 1) {
+      // reduce sum with warp
+      val += platform::CudaShuffleXorSync(0xFFFFFFFF, val, i);
+    }

-  h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
-  if (BcastY) {
-    if (dy) {
-      val = paddle::platform::reduceSum(val, tid, h);
-      if (threadIdx.x == 0) {
-        dy[j] = val;
-      }
-    }
-  } else {
-    if (dx) {
-      val = paddle::platform::reduceSum(val, tid, h);
-      if (threadIdx.x == 0) {
-        dx[j] = val;
-      }
-    }
-  }
-  if (!SameShapeOfIntermediateOutAndOut) {
-    if (d_intermediate) {
-      inter_val = paddle::platform::reduceSum(inter_val, tid, h);
-      if (threadIdx.x == 0) {
-        d_intermediate[j] = inter_val;
-      }
-    }
-  }
+    size_t idx_j = j + threadIdx.y;
+    if (BcastY) {
+      if (dy) {
+        if (threadIdx.x == 0 && (idx_j < w)) dy[idx_j] = val;
+      }
+    } else {
+      if (dx) {
+        if (threadIdx.x == 0 && (idx_j < w)) dx[idx_j] = val;
+      }
+    }
+
+    if (!SameShapeOfIntermediateOutAndOut) {
+      if (d_intermediate) {
+        sdata[threadIdx.y][threadIdx.x] = inter_val;
+        __syncthreads();
+        inter_val = sdata[threadIdx.x][threadIdx.y];
+#pragma unroll
+        for (int i = BLOCK_X >> 1; i > 0; i >>= 1) {
+          // reduce sum with warp
+          inter_val += platform::CudaShuffleXorSync(0xFFFFFFFF, inter_val, i);
+        }
+        if (threadIdx.x == 0 && (idx_j < w)) d_intermediate[idx_j] = inter_val;
+      }
+    }
+  }  // end for
}
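Editor's note: the rewritten kernel no longer reduces one column per block via `paddle::platform::reduceSum`. Each block now covers BLOCK_X columns and BLOCK_Y rows, accumulates partial column sums in registers, transposes them through shared memory so the BLOCK_Y partials of one column land in a single warp, and finishes with a shuffle butterfly that needs no further shared-memory traffic. A hedged, standalone sketch of that transpose-plus-shuffle step, assuming BLOCK_X == BLOCK_Y == 32 so one warp spans a full tile row:

```cuda
#define BLOCK_X 32
#define BLOCK_Y 32

// Each thread enters with a partial sum for column threadIdx.x. After the
// butterfly, every lane of warp c holds the full sum of tile column c; the
// caller writes it out from lane 0 (threadIdx.x == 0), as the kernel does.
__device__ float TileColumnSum(float val) {
  __shared__ float sdata[BLOCK_Y][BLOCK_X];
  sdata[threadIdx.y][threadIdx.x] = val;  // store partials
  __syncthreads();
  val = sdata[threadIdx.x][threadIdx.y];  // transposed read: one column per warp
#pragma unroll
  for (int offset = BLOCK_X >> 1; offset > 0; offset >>= 1) {
    val += __shfl_xor_sync(0xFFFFFFFF, val, offset);  // butterfly reduction
  }
  return val;
}
```

The transpose is what makes the shuffle legal: warp shuffles only exchange data between lanes of the same warp, so the 32 partial sums of a column must first be gathered into one warp.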

template <typename T, typename DX_OP, typename DY_OP, typename DIntermediate_OP,
bool UseIntermediateOut, bool BcastY,
bool SameShapeOfIntermediateOutAndOut>
static void FusedElemwiseAndActGradBroadcast1CUDA(
-    gpuStream_t stream, const T *x, const T *y, const T *intermediate_out,
-    const T *out, const T *dout, int h, int w, DX_OP dx_op, DY_OP dy_op,
-    DIntermediate_OP dintermediate_op, T *dx, T *dy, T *d_intermediate) {
-  int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h);
-  int gird_size = w;
+    const framework::ExecutionContext &ctx, const T *x, const T *y,
+    const T *intermediate_out, const T *out, const T *dout, int h, int w,
+    DX_OP dx_op, DY_OP dy_op, DIntermediate_OP dintermediate_op, T *dx, T *dy,
+    T *d_intermediate) {
+  gpuStream_t stream = ctx.cuda_device_context().stream();
+
+  dim3 blocks(BLOCK_X, BLOCK_Y);
+  int max_gpu_threads = ctx.cuda_device_context().GetMaxPhysicalThreadCount();
+  int max_blocks = std::max(max_gpu_threads / (BLOCK_X * BLOCK_Y), 1);
+  int theory_block = (w + BLOCK_X - 1) / BLOCK_X;
+  dim3 grids(std::min(theory_block, max_blocks));
+
  FusedElemwiseAndActGradBroadcast1CUDAKernel<
      T, DX_OP, DY_OP, DIntermediate_OP, UseIntermediateOut, BcastY,
-      SameShapeOfIntermediateOutAndOut><<<gird_size, block_size, 0, stream>>>(
+      SameShapeOfIntermediateOutAndOut><<<grids, blocks, 0, stream>>>(
      x, y, intermediate_out, out, dout, h, w, dx_op, dy_op, dintermediate_op,
      dx, dy, d_intermediate);
}
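Editor's note: grid sizing is now capped at what the device can keep resident. Blocks are BLOCK_X × BLOCK_Y threads, and the grid is the smaller of "blocks needed to cover w" and "blocks that fit on the GPU"; the kernel's grid-stride loop over `full_w` picks up any leftover columns. A worked example under assumed numbers (a hypothetical GPU with 80 SMs × 2048 resident threads, BLOCK_X = BLOCK_Y = 32, w = 768):

```cpp
#include <algorithm>
#include <cstdio>

int main() {
  int max_gpu_threads = 80 * 2048;  // assumed GetMaxPhysicalThreadCount()
  int max_blocks = std::max(max_gpu_threads / (32 * 32), 1);  // 160
  int theory_block = (768 + 32 - 1) / 32;  // 24 blocks cover w = 768
  std::printf("grid.x = %d\n", std::min(theory_block, max_blocks));  // 24
  return 0;
}
```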
@@ -2832,7 +2859,7 @@ void FusedElemwiseAndActGradComputeWithBroadcast(
    FusedElemwiseAndActGradBroadcast1CUDA<T, DX_OP, DY_OP, DIntermediate_OP,
                                          UseIntermediateOut, BcastY,
                                          SameShapeOfIntermediateOutAndOut>(
-        ctx.template device_context<DeviceContext>().stream(), x_data, y_data,
+        ctx, x_data, y_data,
        intermediate_out == nullptr ? nullptr : intermediate_out->data<T>(),
        out->data<T>(), dout->data<T>(), h, w, dx_op, dy_op, dintermediate_op,
        dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
@@ -69,7 +69,7 @@ static bool IsSupportedCompound(const std::vector<std::string> &functors) {
                                    functors.size(), 2));

  static std::unordered_set<std::string> unary_fun = {"scale", "relu", "tanh",
-                                                      "sigmoid"};
+                                                      "sigmoid", "gelu"};
  static std::unordered_set<std::string> binary_fun = {"elementwise_add",
                                                       "elementwise_mul"};
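Editor's note: with "gelu" added to the unary whitelist, a functor list pairing it with a registered binary op now passes the check. A standalone mimic of the intent (not Paddle's code; the real function also validates the list length and reports errors through PADDLE_ENFORCE):

```cpp
#include <string>
#include <unordered_set>
#include <vector>

// A two-element functor list is supported when one name is a known unary
// activation and the other a known binary elementwise op.
static bool IsSupportedCompoundSketch(const std::vector<std::string>& f) {
  static const std::unordered_set<std::string> unary = {
      "scale", "relu", "tanh", "sigmoid", "gelu"};
  static const std::unordered_set<std::string> binary = {"elementwise_add",
                                                         "elementwise_mul"};
  return (unary.count(f[0]) && binary.count(f[1])) ||
         (binary.count(f[0]) && unary.count(f[1]));
}
// IsSupportedCompoundSketch({"gelu", "elementwise_add"}) -> true
// IsSupportedCompoundSketch({"gelu", "elementwise_sub"}) -> false
```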
17 changes: 17 additions & 0 deletions paddle/fluid/operators/fused/fused_elemwise_activation_op.h
@@ -275,6 +275,13 @@ static void RunFunctors(const framework::ExecutionContext &ctx,
paddle::operators::math::SigmoidFunctor<T>>(
ctx, paddle::operators::math::MulFunctor<T>(),
paddle::operators::math::SigmoidFunctor<T>(), in_x, in_y, outputs);
} else if (funcs_str == "gelu,elementwise_add") {
// Z = Unary(Binary(X, Y))
RunUnaryCompoundFunctors<DeviceContext, T,
paddle::operators::math::GeluFunctor<T>,
paddle::operators::math::AddFunctor<T>>(
ctx, paddle::operators::math::GeluFunctor<T>(),
paddle::operators::math::AddFunctor<T>(), in_x, in_y, outputs);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s has not been implemented.", funcs_str));
@@ -374,6 +381,16 @@ static void RunGradFunctors(
paddle::operators::math::SigmoidFunctor<T>(),
paddle::operators::math::SigmoidGradFunctor<T>(), in_x, in_y, in_out,
in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
} else if (funcs_str == "gelu_grad,elementwise_add_grad") {
// The backward of Z = Unary(Binary(X, Y))
RunUnaryCompoundGradFunctors<
DeviceContext, T, paddle::operators::math::GeluGradFunctor<T>,
paddle::operators::math::AddFunctor<T>,
paddle::operators::math::AddGradFunctor<T>, InPlace>(
ctx, paddle::operators::math::GeluGradFunctor<T>(),
paddle::operators::math::AddFunctor<T>(),
paddle::operators::math::AddGradFunctor<T>(), in_x, in_y, in_out,
in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s has not been implemented.", funcs_str));
58 changes: 58 additions & 0 deletions paddle/fluid/operators/math/functors.h
@@ -14,6 +14,7 @@ limitations under the License. */

#pragma once

#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/math.h"

namespace paddle {
@@ -130,6 +131,63 @@ struct SigmoidGradFunctor {
}
};

+template <typename T>
+struct GeluFunctor {
+  using MT = typename details::MPTypeTrait<T>::Type;
+  inline HOSTDEVICE T operator()(T x) {
+    // this function is the tanh approximation of gelu
+    // the actual gelu is:
+    // x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
+    MT mx = static_cast<MT>(x);
+    MT out = mx * static_cast<MT>(0.5) *
+             (static_cast<MT>(1.0) +
+              tanh(static_cast<MT>(0.79788456) * mx *
+                   (static_cast<MT>(1) + static_cast<MT>(0.044715) * mx * mx)));
+    return static_cast<T>(out);
+  }
+};
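Editor's note: GeluFunctor computes in `MPTypeTrait<T>::Type`, so float16 inputs are promoted to float for the tanh evaluation. How close the tanh form is to the exact erf-based GELU mentioned in the comment (an illustrative check, not part of the PR):

```cpp
#include <cmath>
#include <cstdio>

int main() {
  const double xs[] = {-2.0, -0.5, 0.0, 0.5, 2.0};
  for (double x : xs) {
    double approx =
        x * 0.5 *
        (1.0 + std::tanh(0.79788456 * x * (1.0 + 0.044715 * x * x)));
    double exact = x * 0.5 * (1.0 + std::erf(x * 0.70710678));
    std::printf("x=%5.2f  tanh=%.6f  erf=%.6f\n", x, approx, exact);
  }
  return 0;
}
```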

+template <typename T>
+struct GeluGradFunctor {
+  using MT = typename details::MPTypeTrait<T>::Type;
+  inline HOSTDEVICE T UseX(T x) {
+    MT mx = static_cast<MT>(x);
+    MT tanh_out =
+        tanh(static_cast<MT>(0.79788456) * mx *
+             (static_cast<MT>(1) + static_cast<MT>(0.044715) * mx * mx));
+    MT ans = static_cast<MT>(0.5) * mx *
+                 ((static_cast<MT>(1) - tanh_out * tanh_out) *
+                  (static_cast<MT>(0.79788456) +
+                   static_cast<MT>(0.1070322243) * mx * mx)) +
+             static_cast<MT>(0.5) * (static_cast<MT>(1) + tanh_out);
+    return static_cast<T>(ans);
+  }
+  inline HOSTDEVICE T UseOut(T x) {
[Inline review on UseOut]
Contributor: What does UseOut stand for? Could a comment be added?
Author: It means using Out to compute the gradient; this path is currently unused.
+    MT mx = static_cast<MT>(x);
+    MT tanh_out =
+        tanh(static_cast<MT>(0.79788456) * mx *
+             (static_cast<MT>(1) + static_cast<MT>(0.044715) * mx * mx));
+    MT ans = static_cast<MT>(0.5) * mx *
+                 ((static_cast<MT>(1) - tanh_out * tanh_out) *
+                  (static_cast<MT>(0.79788456) +
+                   static_cast<MT>(0.1070322243) * mx * mx)) +
+             static_cast<MT>(0.5) * (static_cast<MT>(1) + tanh_out);
+    return static_cast<T>(ans);
+  }
+  inline HOSTDEVICE T UseXAndOut(T x, T out) {
+    MT mx = static_cast<MT>(x);
+    MT tanh_out =
+        tanh(static_cast<MT>(0.79788456) * mx *
+             (static_cast<MT>(1) + static_cast<MT>(0.044715) * mx * mx));
+    MT ans = static_cast<MT>(0.5) * mx *
+                 ((static_cast<MT>(1) - tanh_out * tanh_out) *
+                  (static_cast<MT>(0.79788456) +
+                   static_cast<MT>(0.1070322243) * mx * mx)) +
+             static_cast<MT>(0.5) * (static_cast<MT>(1) + tanh_out);
+    return static_cast<T>(ans);
+  }
+};
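Editor's note: all three members share one closed form. With u = 0.79788456·x·(1 + 0.044715·x²) and t = tanh(u), the derivative is gelu'(x) = 0.5·x·(1 − t²)·(0.79788456 + 0.1070322243·x²) + 0.5·(1 + t), where 0.1070322243 ≈ 3 · 0.79788456 · 0.044715 comes from differentiating u. A finite-difference spot check of that formula (illustrative only):

```cpp
#include <cmath>
#include <cstdio>

static double Gelu(double x) {
  return x * 0.5 *
         (1.0 + std::tanh(0.79788456 * x * (1.0 + 0.044715 * x * x)));
}
static double GeluGrad(double x) {  // same closed form as UseX above
  double t = std::tanh(0.79788456 * x * (1.0 + 0.044715 * x * x));
  return 0.5 * x * ((1.0 - t * t) * (0.79788456 + 0.1070322243 * x * x)) +
         0.5 * (1.0 + t);
}
int main() {
  const double x = 0.7, eps = 1e-6;
  double numeric = (Gelu(x + eps) - Gelu(x - eps)) / (2.0 * eps);
  std::printf("analytic=%.8f  numeric=%.8f\n", GeluGrad(x), numeric);
  return 0;
}
```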

} // namespace math
} // namespace operators
} // namespace paddle