Support optional residual add in fused_attention and fused_feedforward. (PaddlePaddle#43474)

* Support optional residual add in fused_attention and fused_feedforward.

* Add a version checkpoint, and check that add_residual is true when pre_layer_norm is false.

* Add a TODO and change the Python API to add the add_residual argument.
Xreki authored and sneaxiy committed Jun 27, 2022
1 parent e97369a commit 10fd3ae
Showing 8 changed files with 209 additions and 137 deletions.
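
In short, the commit makes the residual term in the fused epilogues optional via a new add_residual attribute (default true, so existing models keep their old behavior). Writing y for the preceding linear output, the affected epilogues are, in pseudocode, with the bracketed [x +] present only when add_residual is true:

    pre_layer_norm == true:   out = [x +] dropout(y + bias)
    pre_layer_norm == false:  out = layer_norm([x +] dropout(y + bias))

The post-layer_norm form still assumes the residual, so the kernels reject add_residual == false when pre_layer_norm is false (see the TODOs below).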
10 changes: 10 additions & 0 deletions paddle/fluid/operators/fused/fused_attention_op.cc
@@ -16,6 +16,7 @@ limitations under the License. */
 #include <string>
 
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/op_version_registry.h"
 
 namespace paddle {
 namespace operators {
@@ -378,6 +379,7 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
                               "0.0 and 0.001, But received [%s].",
                               ln_epsilon));
         });
+  AddAttr<bool>("add_residual", "Whether to add residual.").SetDefault(true);
   AddAttr<int>(
       "ring_id",
       "ring id for tensor model parallel. distributed training and inference")
@@ -655,3 +657,11 @@ REGISTER_OPERATOR(fused_attention, ops::FusedAttentionOp,
                   ops::FusedAttentionGradOpMaker<paddle::framework::OpDesc>,
                   ops::FusedAttentionGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(fused_attention_grad, ops::FusedAttentionGradOp);
+
+REGISTER_OP_VERSION(fused_attention)
+    .AddCheckpoint(
+        R"ROC(
+              Add a new attribute [add_residual] )ROC",
+        paddle::framework::compatible::OpVersionDesc().NewAttr(
+            "add_residual", "A flag to indicate whether to add residual.",
+            true));
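
Because the attribute defaults to true and a version checkpoint is registered, programs serialized before this change load unchanged: when "add_residual" is absent, the default fills it in. A minimal sketch of that lookup-with-default idea (GetAttrOrDefault is a hypothetical helper for illustration, not Paddle's actual loader code):

#include <map>
#include <string>

// Hypothetical helper mirroring how SetDefault(true) backfills the new
// attribute for programs saved before it existed.
static bool GetAttrOrDefault(const std::map<std::string, bool>& attrs,
                             const std::string& name, bool default_value) {
  auto it = attrs.find(name);
  return it == attrs.end() ? default_value : it->second;
}

// An old program has no "add_residual" entry, so the residual add stays
// enabled and numerics are unchanged:
//   bool add_residual = GetAttrOrDefault(old_attrs, "add_residual", true);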
61 changes: 33 additions & 28 deletions paddle/fluid/operators/fused/fused_attention_op.cu
@@ -246,26 +246,32 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
       // tensor model parallel
       AllReduce<T>(*out_linear_out, ring_id, ctx.cuda_device_context());
 
+    bool add_residual = ctx.Attr<bool>("add_residual");
+    const T *residual_ptr = add_residual ? x_data : nullptr;
     if (pre_layer_norm) {
       // output = (residual + dropout(input + bias))
       fused_dropout_layernorm_helper.ResidualDropoutBias(
-          ctx.cuda_device_context(), out_linear_out_data, x_data,
+          ctx.cuda_device_context(), out_linear_out_data, residual_ptr,
           out_linear_bias_data, final_out_data, dropout_mask_out_data);
     } else {
-      auto *ln_scale_2_data =
-          (ln_scale_2 == nullptr ? nullptr : ln_scale_2->data<U>());
-      auto *ln_bias_2_data =
-          (ln_bias_2 == nullptr ? nullptr : ln_bias_2->data<U>());
-      auto *bias_dropout_residual_out_data =
+      // TODO(Xreki): support post layer_norm case when add_residual is false.
+      PADDLE_ENFORCE_EQ(add_residual, true,
+                        platform::errors::InvalidArgument(
+                            "Attribute add_residual is expected to be true "
+                            "when pre_layer_norm is false."));
+
+      const U *ln_scale_2_ptr = ln_scale_2 ? ln_scale_2->data<U>() : nullptr;
+      const U *ln_bias_2_ptr = ln_bias_2 ? ln_bias_2->data<U>() : nullptr;
+      T *bias_dropout_residual_out_ptr =
           bias_dropout_residual_out->mutable_data<T>(ctx.GetPlace());
-      auto *ln_mean_2_data = ln_mean_2->mutable_data<U>(ctx.GetPlace());
-      auto *ln_var_2_data = ln_var_2->mutable_data<U>(ctx.GetPlace());
+      U *ln_mean_2_ptr = ln_mean_2->mutable_data<U>(ctx.GetPlace());
+      U *ln_var_2_ptr = ln_var_2->mutable_data<U>(ctx.GetPlace());
       // output = layernorm(residual + dropout(input + bias))
       fused_dropout_layernorm_helper.LayernormResidualDropoutBias(
-          ctx.cuda_device_context(), out_linear_out_data, x_data,
-          out_linear_bias_data, ln_scale_2_data, ln_bias_2_data,
-          bias_dropout_residual_out_data, dropout_mask_out_data, final_out_data,
-          ln_mean_2_data, ln_var_2_data);
+          ctx.cuda_device_context(), out_linear_out_data, residual_ptr,
+          out_linear_bias_data, ln_scale_2_ptr, ln_bias_2_ptr,
+          bias_dropout_residual_out_ptr, dropout_mask_out_data, final_out_data,
+          ln_mean_2_ptr, ln_var_2_ptr);
     }
   }
 };
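
For intuition, here is a minimal CPU reference of what ResidualDropoutBias computes once the residual pointer may be null (a sketch assuming upscale-in-train dropout; the real implementation is a fused CUDA kernel):

#include <cstdint>

// out = [residual +] dropout(src + bias), elementwise over n values.
// mask[i] is 1 if an element is kept; kept values are rescaled by keep_prob.
template <typename T>
void ResidualDropoutBiasRef(const T *src, const T *residual, const T *bias,
                            const uint8_t *mask, float keep_prob, int64_t n,
                            T *out) {
  for (int64_t i = 0; i < n; ++i) {
    T dropped = (src[i] + (bias ? bias[i] : T(0))) *
                static_cast<T>(mask[i]) / static_cast<T>(keep_prob);
    out[i] = residual ? residual[i] + dropped : dropped;  // residual optional
  }
}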
@@ -419,16 +425,17 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
     int output_size = 3 * hidden_size;
     int input_size = dim_embed;
 
+    bool add_residual = ctx.Attr<bool>("add_residual");
     Tensor d_residual;
-    d_residual.Resize(input_x_dims);
-    T *d_residual_data = d_residual.mutable_data<T>(ctx.GetPlace());
+    T *d_residual_data = nullptr;
+    if (add_residual) {
+      d_residual.Resize(input_x_dims);
+      d_residual_data = d_residual.mutable_data<T>(ctx.GetPlace());
+    }
 
     bool transA = false;
     bool transB = true;
-    bool compute_qkv_bias = true;
-    if (qkv_bias == nullptr) {
-      compute_qkv_bias = false;
-    }
+    bool compute_qkv_bias = qkv_bias ? true : false;
     auto layer_norm_compute = AttnLayerNorm<T>(ctx.cuda_device_context(),
                                                epsilon, bsz_seq, dim_embed);
     auto qkv_compute =
@@ -539,16 +546,14 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
       // tensor model parallel
       AllReduce<T>(*d_x, ring_id, ctx.cuda_device_context());
     }
-    // gradient accumulation
-    std::vector<const Tensor *> ins;
-    std::vector<Tensor *> outs;
-    ins.emplace_back(&d_residual);
-    ins.emplace_back(d_x);
-    outs.emplace_back(d_x);
-    int elewise_add_axis = -1;
-    phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
-        ctx.cuda_device_context(), ins, &outs, elewise_add_axis,
-        phi::funcs::AddFunctor<T>());
+
+    if (add_residual) {
+      // gradient accumulation
+      std::vector<const Tensor *> ins = {&d_residual, d_x};
+      std::vector<Tensor *> outs = {d_x};
+      phi::funcs::ElementwiseKernel<T>(ctx.cuda_device_context(), ins, &outs,
+                                       phi::funcs::AddFunctor<T>());
+    }
   }
 };

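The accumulation step above folds the residual branch's gradient into d_x: since out = x + f(x), the chain rule gives d_x = d_f(x) + d_out, and the residual part of d_out was stashed in d_residual. Both tensors have the same shape, which is why the commit can replace BroadcastKernel with the cheaper ElementwiseKernel. A plain reference of the computation (a sketch; the real code runs on the GPU):

#include <cstdint>

// d_x += d_residual; identical shapes, so no broadcasting is involved.
template <typename T>
void AccumulateResidualGradRef(const T *d_residual, int64_t n, T *d_x) {
  for (int64_t i = 0; i < n; ++i) {
    d_x[i] += d_residual[i];
  }
}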
7 changes: 4 additions & 3 deletions paddle/fluid/operators/fused/fused_dropout_helper.h
@@ -150,9 +150,10 @@ class FusedDropoutHelper {
     LaunchResidualDropoutBiasGrad<T, uint8_t>(
         d_out, mask, dropout_param_.dropout_prob,
         dropout_param_.is_upscale_in_train, rows_, cols_, d_src, d_bias, ctx);
-    auto cuda_place = ctx.GetPlace();
-    memory::Copy(cuda_place, d_residual, cuda_place, d_out,
-                 rows_ * cols_ * sizeof(T), ctx.stream());
+    if (d_residual) {
+      memory::Copy(ctx.GetPlace(), d_residual, ctx.GetPlace(), d_out,
+                   rows_ * cols_ * sizeof(T), ctx.stream());
+    }
   }

// out = dropout(activation(src + bias))
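Here the residual gradient is just a copy of d_out (the residual add is the identity map), so it is only materialized when the caller allocated a destination. A standalone sketch of the same guard using the plain CUDA runtime API (illustrative; the real code goes through Paddle's memory::Copy):

#include <cuda_runtime.h>

// Copy d_out into d_residual only when the residual gradient was requested
// (d_residual != nullptr); otherwise skip the copy entirely.
template <typename T>
void MaybeCopyResidualGrad(const T *d_out, T *d_residual, size_t count,
                           cudaStream_t stream) {
  if (d_residual != nullptr) {
    cudaMemcpyAsync(d_residual, d_out, count * sizeof(T),
                    cudaMemcpyDeviceToDevice, stream);
  }
}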
9 changes: 9 additions & 0 deletions paddle/fluid/operators/fused/fused_feedforward_op.cc
@@ -194,6 +194,7 @@ class FusedFeedForwardOpMaker : public framework::OpProtoAndCheckerMaker {
         .SetDefault(false);
     AddAttr<int>("dropout1_seed", "Dropout1 random seed.").SetDefault(0);
     AddAttr<int>("dropout2_seed", "Dropout2 random seed.").SetDefault(0);
+    AddAttr<bool>("add_residual", "Whether to add residual.").SetDefault(true);
     AddAttr<int>("ring_id", "ring id for tensor model parallel.")
         .SetDefault(-1);
     AddComment(R"DOC(
@@ -367,3 +368,11 @@ REGISTER_OPERATOR(fused_feedforward, ops::FusedFeedForwardOp,
                   ops::FusedFeedForwardOpGradMaker<paddle::framework::OpDesc>,
                   ops::FusedFeedForwardOpGradMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(fused_feedforward_grad, ops::FusedFeedForwardOpGrad);
+
+REGISTER_OP_VERSION(fused_feedforward)
+    .AddCheckpoint(
+        R"ROC(
+              Add a new attribute [add_residual] )ROC",
+        paddle::framework::compatible::OpVersionDesc().NewAttr(
+            "add_residual", "A flag to indicate whether to add residual.",
+            true));
84 changes: 48 additions & 36 deletions paddle/fluid/operators/fused/fused_feedforward_op.cu
@@ -69,7 +69,8 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> {
     blas.MatMul(a, mat_dim_a, b, mat_dim_b, alpha, c, T(0));
   }
 
-  void FFN(const framework::Tensor& x, const framework::Tensor& linear1_weight,
+  void FFN(const platform::CUDADeviceContext& ctx, const framework::Tensor& x,
+           const framework::Tensor& linear1_weight,
            const framework::Tensor* linear1_bias,
            const framework::Tensor& linear2_weight,
            const framework::Tensor* linear2_bias,
@@ -84,10 +85,9 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> {
            framework::Tensor* dropout1_out, framework::Tensor* dropout2_out,
            const int bsz_seq, const int d_model, const int dim_feedforward,
            const std::string& act_method, const bool pre_layer_norm,
-           const float epsilon1, const float epsilon2, const int ring_id,
-           const DropoutParam& dropout_param1,
-           const DropoutParam& dropout_param2,
-           const platform::CUDADeviceContext& ctx) const {
+           const float epsilon1, const float epsilon2, const bool add_residual,
+           const int ring_id, const DropoutParam& dropout_param1,
+           const DropoutParam& dropout_param2) const {
     FusedDropoutLayerNormHelper<T, uint8_t> pre_layernorm_helper(
         bsz_seq, d_model, epsilon1);
     FusedDropoutHelper<T, uint8_t> fused_act_dropout_helper(
@@ -127,15 +127,22 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> {
       // tensor model parallel
       AllReduce<T>(linear2_out, ring_id, ctx);
 
+    const T* residual_ptr = add_residual ? x.data<T>() : nullptr;
     if (!pre_layer_norm) {
+      // TODO(Xreki): support post layer_norm case when add_residual is false.
+      PADDLE_ENFORCE_EQ(add_residual, true,
+                        platform::errors::InvalidArgument(
+                            "Attribute add_residual is expected to be true "
+                            "when pre_layer_norm is false."));
+
       fused_dropout_layernorm_helper.LayernormResidualDropoutBias(
-          ctx, linear2_out.data<T>(), x.data<T>(), linear2_bias_ptr,
+          ctx, linear2_out.data<T>(), residual_ptr, linear2_bias_ptr,
           ln2_scale_ptr, ln2_bias_ptr, dropout2_out->data<T>(),
           dropout2_mask->data<uint8_t>(), out->data<T>(), ln2_mean->data<U>(),
           ln2_variance->data<U>());
     } else {
       fused_dropout_layernorm_helper.ResidualDropoutBias(
-          ctx, linear2_out.data<T>(), x.data<T>(), linear2_bias_ptr,
+          ctx, linear2_out.data<T>(), residual_ptr, linear2_bias_ptr,
           out->data<T>(), dropout2_mask->data<uint8_t>());
     }
   }
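
For reference, the whole fused feed-forward block now computes, in pseudocode (bias and dropout terms follow the attributes configured above; the bracketed [x +] appears only when add_residual is true):

    h   = dropout1(act(linear1(pre_layer_norm ? layer_norm1(x) : x) + bias1))
    y   = linear2(h)
    out = pre_layer_norm ? [x +] dropout2(y + bias2)             // residual optional
                         : layer_norm2(x + dropout2(y + bias2))  // residual required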
@@ -183,6 +190,7 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> {
     const float epsilon1 = context.Attr<float>("ln1_epsilon");
     const float epsilon2 = context.Attr<float>("ln2_epsilon");
     const int ring_id = context.Attr<int>("ring_id");
+    const bool add_residual = context.Attr<bool>("add_residual");
 
     DropoutParam dropout_param1(context, 1);
     DropoutParam dropout_param2(context, 2);
@@ -214,12 +222,12 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> {
     int dim_feedforward = dim[dim.size() - 1];
     int bsz_seq = mat_dim_x.batch_size_ * mat_dim_x.height_;
 
-    FFN(*x, *linear1_weight, linear1_bias, *linear2_weight, linear2_bias,
-        ln1_scale, ln1_bias, ln2_scale, ln2_bias, out, dropout1_mask,
-        dropout2_mask, ln1_mean, ln1_variance, ln2_mean, ln2_variance,
-        linear1_out, ln1_out, dropout1_out, dropout2_out, bsz_seq, d_model,
-        dim_feedforward, act_method, pre_layer_norm, epsilon1, epsilon2,
-        ring_id, dropout_param1, dropout_param2, context.cuda_device_context());
+    FFN(context.cuda_device_context(), *x, *linear1_weight, linear1_bias,
+        *linear2_weight, linear2_bias, ln1_scale, ln1_bias, ln2_scale, ln2_bias,
+        out, dropout1_mask, dropout2_mask, ln1_mean, ln1_variance, ln2_mean,
+        ln2_variance, linear1_out, ln1_out, dropout1_out, dropout2_out, bsz_seq,
+        d_model, dim_feedforward, act_method, pre_layer_norm, epsilon1,
+        epsilon2, add_residual, ring_id, dropout_param1, dropout_param2);
   }
 };

@@ -243,8 +251,8 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
   }
 
   void FFNGrad(
-      const framework::Tensor& d_out, const framework::Tensor& x,
-      const framework::Tensor& dropout1_mask,
+      const platform::CUDADeviceContext& ctx, const framework::Tensor& d_out,
+      const framework::Tensor& x, const framework::Tensor& dropout1_mask,
       const framework::Tensor& dropout2_mask,
       const framework::Tensor& linear1_out, const framework::Tensor* ln1_out,
       const framework::Tensor& dropout1_out,
Expand All @@ -264,7 +272,7 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
const int dim_feedforward, const DropoutParam& dropout_param1,
const DropoutParam& dropout_param2, const std::string& act_method,
const bool pre_layer_norm, const float epsilon1, const float epsilon2,
const int ring_id, const platform::CUDADeviceContext& ctx) const {
const bool add_residual, const int ring_id) const {
FusedDropoutLayerNormHelper<T, uint8_t> pre_layernorm_helper(
bsz_seq, d_model, epsilon1);
FusedDropoutHelper<T, uint8_t> fused_act_dropout_helper(
@@ -296,19 +304,22 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
     framework::Tensor d_linear2_out, d_dropout2_out, d_residual;
     d_linear2_out.mutable_data<T>({bsz_seq, d_model}, place);
     d_dropout2_out.mutable_data<T>({bsz_seq, d_model}, place);
-    d_residual.mutable_data<T>(d_x->dims(), place);
 
+    T* d_residual_ptr = nullptr;
+    if (add_residual) {
+      d_residual_ptr = d_residual.mutable_data<T>(d_x->dims(), place);
+    }
     if (pre_layer_norm) {
       fused_dropout_layernorm_helper.ResidualDropoutBiasGrad(
           ctx, d_out.data<T>(), dropout2_mask.data<uint8_t>(),
-          d_linear2_out.data<T>(), d_residual.data<T>(), d_linear2_bias_ptr);
+          d_linear2_out.data<T>(), d_residual_ptr, d_linear2_bias_ptr);
     } else {
       fused_dropout_layernorm_helper.LayernormResidualDropoutBiasGrad(
           ctx, d_out.data<T>(), dropout2_out.data<T>(),
           dropout2_mask.data<uint8_t>(), ln2_gamma_ptr, ln2_mean->data<U>(),
           ln2_variance->data<U>(), d_dropout2_out.data<T>(), d_ln2_gamma_ptr,
           d_ln2_beta_ptr, d_linear2_out.data<T>(), d_linear2_bias_ptr,
-          d_residual.data<T>());
+          d_residual_ptr);
     }
 
     framework::Tensor d_dropout1_out;
@@ -339,14 +350,14 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
       // tensor model parallel
       AllReduce<T>(*d_x, ring_id, ctx);
     }
-    std::vector<const Tensor*> ins(2);
-    std::vector<Tensor*> outs(1);
-    ins[0] = &d_residual;
-    ins[1] = d_x;
-    outs[0] = d_x;
-    int elewise_add_axis = -1;
-    phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
-        ctx, ins, &outs, elewise_add_axis, phi::funcs::AddFunctor<T>());
+
+    if (add_residual) {
+      // gradient accumulation
+      std::vector<const Tensor*> ins = {&d_residual, d_x};
+      std::vector<Tensor*> outs = {d_x};
+      phi::funcs::ElementwiseKernel<T>(ctx, ins, &outs,
+                                       phi::funcs::AddFunctor<T>());
+    }
   }
 
void Compute(const framework::ExecutionContext& context) const override {
@@ -410,6 +421,7 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
 
     const float epsilon1 = context.Attr<float>("ln1_epsilon");
     const float epsilon2 = context.Attr<float>("ln2_epsilon");
+    const bool add_residual = context.Attr<bool>("add_residual");
     const int ring_id = context.Attr<int>("ring_id");
     const std::string act_method = context.Attr<std::string>("act_method");
     DropoutParam dropout_param1(context, 1);
@@ -447,15 +459,15 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
     int dim_feedforward = linear1_weight_dim[linear1_weight_dim.size() - 1];
     int bsz_seq = mat_dim_x.batch_size_ * mat_dim_x.height_;
 
-    FFNGrad(d_out, x, dropout1_mask, dropout2_mask, linear1_out, ln1_out,
-            dropout1_out, dropout2_out, linear1_weight, linear1_bias,
-            linear2_weight, ln1_scale, ln1_bias, ln1_mean, ln1_variance,
-            ln2_scale, ln2_bias, ln2_mean, ln2_variance, d_x, d_linear1_weight,
-            d_linear1_bias, d_linear2_weight, d_linear2_bias, d_ln1_scale,
-            d_ln1_bias, d_ln2_scale, d_ln2_bias, bsz_seq, d_model,
-            dim_feedforward, dropout_param1, dropout_param2, act_method,
-            pre_layer_norm, epsilon1, epsilon2, ring_id,
-            context.cuda_device_context());
+    FFNGrad(context.cuda_device_context(), d_out, x, dropout1_mask,
+            dropout2_mask, linear1_out, ln1_out, dropout1_out, dropout2_out,
+            linear1_weight, linear1_bias, linear2_weight, ln1_scale, ln1_bias,
+            ln1_mean, ln1_variance, ln2_scale, ln2_bias, ln2_mean, ln2_variance,
+            d_x, d_linear1_weight, d_linear1_bias, d_linear2_weight,
+            d_linear2_bias, d_ln1_scale, d_ln1_bias, d_ln2_scale, d_ln2_bias,
+            bsz_seq, d_model, dim_feedforward, dropout_param1, dropout_param2,
+            act_method, pre_layer_norm, epsilon1, epsilon2, add_residual,
+            ring_id);
   }
 };
 }  // namespace operators
9 changes: 6 additions & 3 deletions paddle/fluid/operators/fused/fused_residual_dropout_bias.h
@@ -140,9 +140,12 @@ void LaunchResidualDropoutBias(const uint32_t rows, const uint32_t cols,
   // dropout_prob == 1.0f
   if (std::abs(dropout_prob - 1.0f) < 1e-5) {
     if (residual == dst) return;
-    auto cuda_place = ctx.GetPlace();
-    memory::Copy(cuda_place, dst, cuda_place, residual, rows * cols * sizeof(T),
-                 ctx.stream());
+    if (residual) {
+      memory::Copy(ctx.GetPlace(), dst, ctx.GetPlace(), residual,
+                   rows * cols * sizeof(T), ctx.stream());
+    } else {
+      SetZero<T>(ctx, dst, rows * cols);
+    }
     if (!is_test) {
       SetZero<MaskType>(ctx, mask_data, rows * cols);
     }
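The dropout_prob == 1.0 fast path is easy to verify by hand: every element is dropped, so dropout(src + bias) is identically zero and the output collapses to the residual, or to zero when no residual is given. A minimal CPU sketch of that branch (illustrative; the real code zeroes device memory with SetZero):

#include <cstring>

// Fast path for dropout_prob == 1.0: out = residual, or 0 without a residual.
// Assumes T is a plain numeric type whose all-zero bit pattern means 0.
template <typename T>
void DropAllFastPathRef(const T *residual, size_t n, T *dst) {
  if (residual != nullptr) {
    std::memcpy(dst, residual, n * sizeof(T));  // out = residual + 0
  } else {
    std::memset(dst, 0, n * sizeof(T));         // out = 0
  }
}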