Support optional residual add in fused_attention and fused_feedforward. #43474

Merged · 4 commits · Jun 17, 2022

Changes from all commits
10 changes: 10 additions & 0 deletions paddle/fluid/operators/fused/fused_attention_op.cc
@@ -16,6 +16,7 @@ limitations under the License. */
#include <string>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"

namespace paddle {
namespace operators {
@@ -378,6 +379,7 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
"0.0 and 0.001, But received [%s].",
ln_epsilon));
});
AddAttr<bool>("add_residual", "Whether to add residual.").SetDefault(true);
AddAttr<int>(
"ring_id",
"ring id for tensor model parallel. distributed training and inference")
@@ -655,3 +657,11 @@ REGISTER_OPERATOR(fused_attention, ops::FusedAttentionOp,
ops::FusedAttentionGradOpMaker<paddle::framework::OpDesc>,
ops::FusedAttentionGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(fused_attention_grad, ops::FusedAttentionGradOp);

REGISTER_OP_VERSION(fused_attention)
.AddCheckpoint(
R"ROC(
Add a new attribute [add_residual] )ROC",
paddle::framework::compatible::OpVersionDesc().NewAttr(
"add_residual", "A flag to indicate whether to add residual.",
true));
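
The new attribute toggles the residual term in the op's fused epilogue: in the pre_layer_norm branch the output is dropout(out_linear_out + bias) plus the input x only when add_residual is true (see the .cu change below). A minimal CPU sketch of that branch, assuming upscale-in-train dropout; the function name and the plain loop are illustrative assumptions, not the actual FusedDropoutLayerNormHelper kernel:

#include <cstddef>
#include <cstdint>
#include <vector>

// CPU sketch of the pre_layer_norm epilogue controlled by add_residual:
//   out = dropout(linear_out + bias) + (add_residual ? x : 0)
// keep_mask/keep_prob model upscale-in-train dropout; all names here are
// assumptions of this sketch, not the real kernel helpers.
std::vector<float> ResidualDropoutBiasRef(const std::vector<float>& linear_out,
                                          const std::vector<float>& bias,
                                          const std::vector<float>& x,
                                          const std::vector<uint8_t>& keep_mask,
                                          float keep_prob, bool add_residual) {
  std::vector<float> out(linear_out.size());
  for (std::size_t i = 0; i < out.size(); ++i) {
    float v = keep_mask[i]
                  ? (linear_out[i] + bias[i % bias.size()]) / keep_prob
                  : 0.0f;
    out[i] = add_residual ? x[i] + v : v;  // residual skipped when false
  }
  return out;
}

Because the attribute defaults to true and is recorded by the REGISTER_OP_VERSION checkpoint above, programs saved before this change keep their original residual-adding behavior.
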
61 changes: 33 additions & 28 deletions paddle/fluid/operators/fused/fused_attention_op.cu
@@ -246,26 +246,32 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
// tensor model parallel
AllReduce<T>(*out_linear_out, ring_id, ctx.cuda_device_context());

bool add_residual = ctx.Attr<bool>("add_residual");
const T *residual_ptr = add_residual ? x_data : nullptr;
if (pre_layer_norm) {
// output = (residual + dropout(input + bias))
fused_dropout_layernorm_helper.ResidualDropoutBias(
ctx.cuda_device_context(), out_linear_out_data, x_data,
ctx.cuda_device_context(), out_linear_out_data, residual_ptr,
out_linear_bias_data, final_out_data, dropout_mask_out_data);
} else {
auto *ln_scale_2_data =
(ln_scale_2 == nullptr ? nullptr : ln_scale_2->data<U>());
auto *ln_bias_2_data =
(ln_bias_2 == nullptr ? nullptr : ln_bias_2->data<U>());
auto *bias_dropout_residual_out_data =
// TODO(Xreki): support post layer_norm case when add_residual is false.
PADDLE_ENFORCE_EQ(add_residual, true,
platform::errors::InvalidArgument(
"Attribute add_residual is expected to be true "
"when pre_layer_norm is false."));

const U *ln_scale_2_ptr = ln_scale_2 ? ln_scale_2->data<U>() : nullptr;
const U *ln_bias_2_ptr = ln_bias_2 ? ln_bias_2->data<U>() : nullptr;
T *bias_dropout_residual_out_ptr =
bias_dropout_residual_out->mutable_data<T>(ctx.GetPlace());
auto *ln_mean_2_data = ln_mean_2->mutable_data<U>(ctx.GetPlace());
auto *ln_var_2_data = ln_var_2->mutable_data<U>(ctx.GetPlace());
U *ln_mean_2_ptr = ln_mean_2->mutable_data<U>(ctx.GetPlace());
U *ln_var_2_ptr = ln_var_2->mutable_data<U>(ctx.GetPlace());
// output = layernorm(residual + dropout(input + bias))
fused_dropout_layernorm_helper.LayernormResidualDropoutBias(
ctx.cuda_device_context(), out_linear_out_data, x_data,
out_linear_bias_data, ln_scale_2_data, ln_bias_2_data,
bias_dropout_residual_out_data, dropout_mask_out_data, final_out_data,
ln_mean_2_data, ln_var_2_data);
ctx.cuda_device_context(), out_linear_out_data, residual_ptr,
out_linear_bias_data, ln_scale_2_ptr, ln_bias_2_ptr,
bias_dropout_residual_out_ptr, dropout_mask_out_data, final_out_data,
ln_mean_2_ptr, ln_var_2_ptr);
}
}
};
@@ -419,16 +425,17 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
int output_size = 3 * hidden_size;
int input_size = dim_embed;

bool add_residual = ctx.Attr<bool>("add_residual");
Tensor d_residual;
d_residual.Resize(input_x_dims);
T *d_residual_data = d_residual.mutable_data<T>(ctx.GetPlace());
T *d_residual_data = nullptr;
if (add_residual) {
d_residual.Resize(input_x_dims);
d_residual_data = d_residual.mutable_data<T>(ctx.GetPlace());
}

bool transA = false;
bool transB = true;
bool compute_qkv_bias = true;
if (qkv_bias == nullptr) {
compute_qkv_bias = false;
}
bool compute_qkv_bias = qkv_bias ? true : false;
auto layer_norm_compute = AttnLayerNorm<T>(ctx.cuda_device_context(),
epsilon, bsz_seq, dim_embed);
auto qkv_compute =
@@ -539,16 +546,14 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
// tensor model parallel
AllReduce<T>(*d_x, ring_id, ctx.cuda_device_context());
}
// gradient accumulation
std::vector<const Tensor *> ins;
std::vector<Tensor *> outs;
ins.emplace_back(&d_residual);
ins.emplace_back(d_x);
outs.emplace_back(d_x);
int elewise_add_axis = -1;
phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
ctx.cuda_device_context(), ins, &outs, elewise_add_axis,
phi::funcs::AddFunctor<T>());

if (add_residual) {
// gradient accumulation
std::vector<const Tensor *> ins = {&d_residual, d_x};
std::vector<Tensor *> outs = {d_x};
phi::funcs::ElementwiseKernel<T>(ctx.cuda_device_context(), ins, &outs,
phi::funcs::AddFunctor<T>());
}
}
};

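In the backward kernel above, the residual branch is an identity add, so its gradient d_residual has to be accumulated into dX, but only when that branch existed in the forward pass. A minimal CPU sketch of the guarded phi::funcs::ElementwiseKernel + AddFunctor call (the function name is an assumption of this sketch):

#include <cstddef>
#include <vector>

// CPU sketch of the guarded gradient accumulation performed with
// AddFunctor in the grad kernel above:
//   d_x += d_residual, but only if the residual was added in forward.
void AccumulateResidualGrad(std::vector<float>* d_x,
                            const std::vector<float>& d_residual,
                            bool add_residual) {
  if (!add_residual) return;  // no identity branch, nothing to accumulate
  for (std::size_t i = 0; i < d_x->size(); ++i) {
    (*d_x)[i] += d_residual[i];
  }
}
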
7 changes: 4 additions & 3 deletions paddle/fluid/operators/fused/fused_dropout_helper.h
@@ -150,9 +150,10 @@ class FusedDropoutHelper {
LaunchResidualDropoutBiasGrad<T, uint8_t>(
d_out, mask, dropout_param_.dropout_prob,
dropout_param_.is_upscale_in_train, rows_, cols_, d_src, d_bias, ctx);
auto cuda_place = ctx.GetPlace();
memory::Copy(cuda_place, d_residual, cuda_place, d_out,
rows_ * cols_ * sizeof(T), ctx.stream());
if (d_residual) {
memory::Copy(ctx.GetPlace(), d_residual, ctx.GetPlace(), d_out,
rows_ * cols_ * sizeof(T), ctx.stream());
}
}

// out = dropout(activation(src + bias))
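The guard added above reflects the same fact on the helper side: the gradient flowing into the residual input is just d_out, and when the op runs with add_residual=false the caller passes nullptr so the device copy is skipped. A CPU sketch of the guarded copy, with an illustrative function name (the helper itself uses memory::Copy on the CUDA stream):

#include <algorithm>
#include <cstddef>

// CPU sketch of the null-guarded residual gradient copy in
// FusedDropoutHelper::ResidualDropoutBiasGrad: d_residual = d_out.
void ResidualGradCopy(const float* d_out, float* d_residual, std::size_t n) {
  if (d_residual == nullptr) return;  // add_residual == false: no residual input
  std::copy(d_out, d_out + n, d_residual);
}
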
9 changes: 9 additions & 0 deletions paddle/fluid/operators/fused/fused_feedforward_op.cc
@@ -194,6 +194,7 @@ class FusedFeedForwardOpMaker : public framework::OpProtoAndCheckerMaker {
.SetDefault(false);
AddAttr<int>("dropout1_seed", "Dropout1 random seed.").SetDefault(0);
AddAttr<int>("dropout2_seed", "Dropout2 random seed.").SetDefault(0);
AddAttr<bool>("add_residual", "Whether to add residual.").SetDefault(true);
AddAttr<int>("ring_id", "ring id for tensor model parallel.")
.SetDefault(-1);
AddComment(R"DOC(
@@ -367,3 +368,11 @@ REGISTER_OPERATOR(fused_feedforward, ops::FusedFeedForwardOp,
ops::FusedFeedForwardOpGradMaker<paddle::framework::OpDesc>,
ops::FusedFeedForwardOpGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(fused_feedforward_grad, ops::FusedFeedForwardOpGrad);

REGISTER_OP_VERSION(fused_feedforward)
.AddCheckpoint(
R"ROC(
Add a new attribute [add_residual] )ROC",
paddle::framework::compatible::OpVersionDesc().NewAttr(
"add_residual", "A flag to indicate whether to add residual.",
true));
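
fused_feedforward gains the same attribute, and its .cu change below adds the same restriction as fused_attention: the post-layer_norm path still requires the residual, because its epilogue is layer_norm(x + dropout(linear_out + bias)). A minimal single-row CPU sketch of that epilogue, with dropout and the layer-norm scale/shift omitted; the function name is an assumption of this sketch:

#include <cmath>
#include <cstddef>
#include <vector>

// CPU sketch of the post-layer_norm epilogue shared by both fused ops:
//   out = layer_norm(x + linear_out + bias)
// The residual add is mandatory in this branch, matching the
// PADDLE_ENFORCE_EQ(add_residual, true, ...) checks in the .cu files.
std::vector<float> LayerNormResidualBiasRef(
    const std::vector<float>& linear_out, const std::vector<float>& x,
    float bias, float epsilon) {
  std::vector<float> y(linear_out.size());
  float mean = 0.0f;
  for (std::size_t i = 0; i < y.size(); ++i) {
    y[i] = x[i] + linear_out[i] + bias;  // residual + biased linear output
    mean += y[i];
  }
  mean /= static_cast<float>(y.size());
  float var = 0.0f;
  for (std::size_t i = 0; i < y.size(); ++i) {
    var += (y[i] - mean) * (y[i] - mean);
  }
  var /= static_cast<float>(y.size());
  const float inv_std = 1.0f / std::sqrt(var + epsilon);
  for (std::size_t i = 0; i < y.size(); ++i) {
    y[i] = (y[i] - mean) * inv_std;  // normalize; scale/shift omitted
  }
  return y;
}
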
84 changes: 48 additions & 36 deletions paddle/fluid/operators/fused/fused_feedforward_op.cu
@@ -69,7 +69,8 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> {
blas.MatMul(a, mat_dim_a, b, mat_dim_b, alpha, c, T(0));
}

void FFN(const framework::Tensor& x, const framework::Tensor& linear1_weight,
void FFN(const platform::CUDADeviceContext& ctx, const framework::Tensor& x,
const framework::Tensor& linear1_weight,
const framework::Tensor* linear1_bias,
const framework::Tensor& linear2_weight,
const framework::Tensor* linear2_bias,
@@ -84,10 +85,9 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> {
framework::Tensor* dropout1_out, framework::Tensor* dropout2_out,
const int bsz_seq, const int d_model, const int dim_feedforward,
const std::string& act_method, const bool pre_layer_norm,
const float epsilon1, const float epsilon2, const int ring_id,
const DropoutParam& dropout_param1,
const DropoutParam& dropout_param2,
const platform::CUDADeviceContext& ctx) const {
const float epsilon1, const float epsilon2, const bool add_residual,
const int ring_id, const DropoutParam& dropout_param1,
const DropoutParam& dropout_param2) const {
FusedDropoutLayerNormHelper<T, uint8_t> pre_layernorm_helper(
bsz_seq, d_model, epsilon1);
FusedDropoutHelper<T, uint8_t> fused_act_dropout_helper(
@@ -127,15 +127,22 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> {
// tensor model parallel
AllReduce<T>(linear2_out, ring_id, ctx);

const T* residual_ptr = add_residual ? x.data<T>() : nullptr;
if (!pre_layer_norm) {
// TODO(Xreki): support post layer_norm case when add_residual is false.
PADDLE_ENFORCE_EQ(add_residual, true,
platform::errors::InvalidArgument(
"Attribute add_residual is expected to be true "
"when pre_layer_norm is false."));

fused_dropout_layernorm_helper.LayernormResidualDropoutBias(
ctx, linear2_out.data<T>(), x.data<T>(), linear2_bias_ptr,
ctx, linear2_out.data<T>(), residual_ptr, linear2_bias_ptr,
ln2_scale_ptr, ln2_bias_ptr, dropout2_out->data<T>(),
dropout2_mask->data<uint8_t>(), out->data<T>(), ln2_mean->data<U>(),
ln2_variance->data<U>());
} else {
fused_dropout_layernorm_helper.ResidualDropoutBias(
ctx, linear2_out.data<T>(), x.data<T>(), linear2_bias_ptr,
ctx, linear2_out.data<T>(), residual_ptr, linear2_bias_ptr,
out->data<T>(), dropout2_mask->data<uint8_t>());
}
}
@@ -183,6 +190,7 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> {
const float epsilon1 = context.Attr<float>("ln1_epsilon");
const float epsilon2 = context.Attr<float>("ln2_epsilon");
const int ring_id = context.Attr<int>("ring_id");
const bool add_residual = context.Attr<bool>("add_residual");

DropoutParam dropout_param1(context, 1);
DropoutParam dropout_param2(context, 2);
@@ -214,12 +222,12 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> {
int dim_feedforward = dim[dim.size() - 1];
int bsz_seq = mat_dim_x.batch_size_ * mat_dim_x.height_;

FFN(*x, *linear1_weight, linear1_bias, *linear2_weight, linear2_bias,
ln1_scale, ln1_bias, ln2_scale, ln2_bias, out, dropout1_mask,
dropout2_mask, ln1_mean, ln1_variance, ln2_mean, ln2_variance,
linear1_out, ln1_out, dropout1_out, dropout2_out, bsz_seq, d_model,
dim_feedforward, act_method, pre_layer_norm, epsilon1, epsilon2,
ring_id, dropout_param1, dropout_param2, context.cuda_device_context());
FFN(context.cuda_device_context(), *x, *linear1_weight, linear1_bias,
*linear2_weight, linear2_bias, ln1_scale, ln1_bias, ln2_scale, ln2_bias,
out, dropout1_mask, dropout2_mask, ln1_mean, ln1_variance, ln2_mean,
ln2_variance, linear1_out, ln1_out, dropout1_out, dropout2_out, bsz_seq,
d_model, dim_feedforward, act_method, pre_layer_norm, epsilon1,
epsilon2, add_residual, ring_id, dropout_param1, dropout_param2);
}
};

@@ -243,8 +251,8 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
}

void FFNGrad(
const framework::Tensor& d_out, const framework::Tensor& x,
const framework::Tensor& dropout1_mask,
const platform::CUDADeviceContext& ctx, const framework::Tensor& d_out,
const framework::Tensor& x, const framework::Tensor& dropout1_mask,
const framework::Tensor& dropout2_mask,
const framework::Tensor& linear1_out, const framework::Tensor* ln1_out,
const framework::Tensor& dropout1_out,
@@ -264,7 +272,7 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
const int dim_feedforward, const DropoutParam& dropout_param1,
const DropoutParam& dropout_param2, const std::string& act_method,
const bool pre_layer_norm, const float epsilon1, const float epsilon2,
const int ring_id, const platform::CUDADeviceContext& ctx) const {
const bool add_residual, const int ring_id) const {
FusedDropoutLayerNormHelper<T, uint8_t> pre_layernorm_helper(
bsz_seq, d_model, epsilon1);
FusedDropoutHelper<T, uint8_t> fused_act_dropout_helper(
@@ -296,19 +304,22 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
framework::Tensor d_linear2_out, d_dropout2_out, d_residual;
d_linear2_out.mutable_data<T>({bsz_seq, d_model}, place);
d_dropout2_out.mutable_data<T>({bsz_seq, d_model}, place);
d_residual.mutable_data<T>(d_x->dims(), place);

T* d_residual_ptr = nullptr;
if (add_residual) {
d_residual_ptr = d_residual.mutable_data<T>(d_x->dims(), place);
}
if (pre_layer_norm) {
fused_dropout_layernorm_helper.ResidualDropoutBiasGrad(
ctx, d_out.data<T>(), dropout2_mask.data<uint8_t>(),
d_linear2_out.data<T>(), d_residual.data<T>(), d_linear2_bias_ptr);
d_linear2_out.data<T>(), d_residual_ptr, d_linear2_bias_ptr);
} else {
fused_dropout_layernorm_helper.LayernormResidualDropoutBiasGrad(
ctx, d_out.data<T>(), dropout2_out.data<T>(),
dropout2_mask.data<uint8_t>(), ln2_gamma_ptr, ln2_mean->data<U>(),
ln2_variance->data<U>(), d_dropout2_out.data<T>(), d_ln2_gamma_ptr,
d_ln2_beta_ptr, d_linear2_out.data<T>(), d_linear2_bias_ptr,
d_residual.data<T>());
d_residual_ptr);
}

framework::Tensor d_dropout1_out;
@@ -339,14 +350,14 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
// tensor model parallel
AllReduce<T>(*d_x, ring_id, ctx);
}
std::vector<const Tensor*> ins(2);
std::vector<Tensor*> outs(1);
ins[0] = &d_residual;
ins[1] = d_x;
outs[0] = d_x;
int elewise_add_axis = -1;
phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
ctx, ins, &outs, elewise_add_axis, phi::funcs::AddFunctor<T>());

if (add_residual) {
// gradient accumulation
std::vector<const Tensor*> ins = {&d_residual, d_x};
std::vector<Tensor*> outs = {d_x};
phi::funcs::ElementwiseKernel<T>(ctx, ins, &outs,
phi::funcs::AddFunctor<T>());
}
}

void Compute(const framework::ExecutionContext& context) const override {
@@ -410,6 +421,7 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {

const float epsilon1 = context.Attr<float>("ln1_epsilon");
const float epsilon2 = context.Attr<float>("ln2_epsilon");
const bool add_residual = context.Attr<bool>("add_residual");
const int ring_id = context.Attr<int>("ring_id");
const std::string act_method = context.Attr<std::string>("act_method");
DropoutParam dropout_param1(context, 1);
@@ -447,15 +459,15 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
int dim_feedforward = linear1_weight_dim[linear1_weight_dim.size() - 1];
int bsz_seq = mat_dim_x.batch_size_ * mat_dim_x.height_;

FFNGrad(d_out, x, dropout1_mask, dropout2_mask, linear1_out, ln1_out,
dropout1_out, dropout2_out, linear1_weight, linear1_bias,
linear2_weight, ln1_scale, ln1_bias, ln1_mean, ln1_variance,
ln2_scale, ln2_bias, ln2_mean, ln2_variance, d_x, d_linear1_weight,
d_linear1_bias, d_linear2_weight, d_linear2_bias, d_ln1_scale,
d_ln1_bias, d_ln2_scale, d_ln2_bias, bsz_seq, d_model,
dim_feedforward, dropout_param1, dropout_param2, act_method,
pre_layer_norm, epsilon1, epsilon2, ring_id,
context.cuda_device_context());
FFNGrad(context.cuda_device_context(), d_out, x, dropout1_mask,
dropout2_mask, linear1_out, ln1_out, dropout1_out, dropout2_out,
linear1_weight, linear1_bias, linear2_weight, ln1_scale, ln1_bias,
ln1_mean, ln1_variance, ln2_scale, ln2_bias, ln2_mean, ln2_variance,
d_x, d_linear1_weight, d_linear1_bias, d_linear2_weight,
d_linear2_bias, d_ln1_scale, d_ln1_bias, d_ln2_scale, d_ln2_bias,
bsz_seq, d_model, dim_feedforward, dropout_param1, dropout_param2,
act_method, pre_layer_norm, epsilon1, epsilon2, add_residual,
ring_id);
}
};
} // namespace operators
9 changes: 6 additions & 3 deletions paddle/fluid/operators/fused/fused_residual_dropout_bias.h
@@ -140,9 +140,12 @@ void LaunchResidualDropoutBias(const uint32_t rows, const uint32_t cols,
// dropout_prob == 1.0f
if (std::abs(dropout_prob - 1.0f) < 1e-5) {
if (residual == dst) return;
auto cuda_place = ctx.GetPlace();
memory::Copy(cuda_place, dst, cuda_place, residual, rows * cols * sizeof(T),
ctx.stream());
if (residual) {
memory::Copy(ctx.GetPlace(), dst, ctx.GetPlace(), residual,
rows * cols * sizeof(T), ctx.stream());
} else {
SetZero<T>(ctx, dst, rows * cols);
}
if (!is_test) {
SetZero<MaskType>(ctx, mask_data, rows * cols);
}
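The final hunk covers the dropout_prob == 1.0 fast path: the dropout term is identically zero, so the destination is either a copy of the residual or, when the residual pointer is null because add_residual is false, all zeros. A CPU sketch of that branch (SetZero and memory::Copy are the real device-side calls; the function name here is illustrative):

#include <algorithm>
#include <cstddef>

// CPU sketch of LaunchResidualDropoutBias when dropout_prob == 1.0:
// dropout(src + bias) is identically zero, so dst is either the residual
// or zero, depending on whether a residual pointer was passed.
void FullDropoutFastPath(const float* residual, float* dst, std::size_t n) {
  if (residual != nullptr) {
    std::copy(residual, residual + n, dst);  // dst = residual + 0
  } else {
    std::fill(dst, dst + n, 0.0f);           // add_residual == false: dst = 0
  }
}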