[cherry pick] add op: fused_feedforward(backward) (#36730)
* add op: fused_feedforward(backward) (#35611)

This PR adds the backward pass of fused_feedforward.

Related kernel implementations: fused_dropout_act_bias, fused_residual_dropout_bias, fused_layernorm_residual_dropout_bias.

fused_feedforward is a fused operator: it fuses and wraps the operators of the transformer feed-forward layer so that the frontend exposes a single interface. The fusion cuts part of the memory-access and kernel-launch overhead and thereby improves performance.
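
For context, a minimal sketch of the unfused computation that the fused operator covers, written with standard paddle.nn.functional calls. It assumes the pre_layer_norm=True path with ReLU activation; the parameter names and dropout rates are illustrative only, not the fused op's exact interface.

    # Illustrative reference (not part of this PR) of the unfused feed-forward
    # block that fused_feedforward collapses into a single operator.
    import paddle.nn.functional as F

    def feedforward_reference(x, w1, b1, w2, b2, ln1_scale, ln1_bias,
                              dropout1_rate=0.1, dropout2_rate=0.1,
                              ln_epsilon=1e-5):
        # pre-LayerNorm on the input of the residual branch
        ln1_out = F.layer_norm(x, x.shape[-1:], weight=ln1_scale,
                               bias=ln1_bias, epsilon=ln_epsilon)
        # linear1 -> activation -> dropout1
        hidden = F.dropout(F.relu(F.linear(ln1_out, w1, b1)), p=dropout1_rate)
        # linear2 -> dropout2 -> residual add
        return x + F.dropout(F.linear(hidden, w2, b2), p=dropout2_rate)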

* Move fused_attention and fused_feedforward functional api path to incubate (#36704)

Move the Python API interfaces added in PRs #35905 and #35843 into the incubate directory.
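
Assuming the relocated interface is exposed as paddle.incubate.nn.functional.fused_feedforward (the module path, argument order, and defaults are assumptions based on this PR's description, not confirmed by the files shown here), usage would look roughly like:

    # Hypothetical call into the relocated incubate API; signature is assumed.
    import paddle
    import paddle.incubate.nn.functional as incubate_f

    x  = paddle.randn([2, 128, 512], dtype='float32')  # [batch, seq_len, d_model]
    w1 = paddle.randn([512, 2048], dtype='float32')    # linear1 weight [d_model, dim_feedforward]
    w2 = paddle.randn([2048, 512], dtype='float32')    # linear2 weight [dim_feedforward, d_model]

    out = incubate_f.fused_feedforward(x, w1, w2)      # biases and layer-norm params left at defaults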
zhangkaihuo authored Oct 26, 2021
1 parent 5b357e0 commit 76c1bae
Showing 9 changed files with 417 additions and 34 deletions.
2 changes: 0 additions & 2 deletions paddle/fluid/operators/fused/CMakeLists.txt
@@ -80,10 +80,8 @@ if (WITH_GPU OR WITH_ROCM)
nv_test(test_fused_dropout_act_bias SRCS fused_dropout_act_bias_test.cu DEPS tensor op_registry dropout_op layer_norm_op device_context generator memory)
nv_test(test_fused_layernorm_residual_dropout_bias SRCS fused_layernorm_residual_dropout_bias_test.cu DEPS tensor op_registry dropout_op layer_norm_op device_context generator memory)


op_library(fused_feedforward_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fused_feedforward);\n")

# fused_attention_op
op_library(fused_attention_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fused_attention);\n")
147 changes: 146 additions & 1 deletion paddle/fluid/operators/fused/fused_feedforward_op.cc
@@ -206,9 +206,154 @@ class FusedFeedForwardOpMaker : public framework::OpProtoAndCheckerMaker {
}
};

class FusedFeedForwardOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

protected:
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(ctx->Attrs().Get<bool>("dropout1_is_test"), false,
platform::errors::InvalidArgument(
"GradOp is only callable when is_test is false"));
PADDLE_ENFORCE_EQ(ctx->Attrs().Get<bool>("dropout2_is_test"), false,
platform::errors::InvalidArgument(
"GradOp is only callable when is_test is false"));
OP_INOUT_CHECK(ctx->HasInput("Dropout1Mask"), "Input", "Dropout1Mask",
"FusedFeedForwardGrad");
OP_INOUT_CHECK(ctx->HasInput("Dropout2Mask"), "Input", "Dropout1Mask",
"FusedFeedForwardGrad");
OP_INOUT_CHECK(ctx->HasInput("Linear1Out"), "Input", "Linear1Out",
"FusedFeedForwardGrad");
OP_INOUT_CHECK(ctx->HasInput("Ln1Out"), "Input", "Ln1Out",
"FusedFeedForwardGrad");
OP_INOUT_CHECK(ctx->HasInput("Dropout1Out"), "Input", "Dropout1Out",
"FusedFeedForwardGrad");
OP_INOUT_CHECK(ctx->HasInput("Dropout2Out"), "Input", "Dropout2Out",
"FusedFeedForwardGrad");
OP_INOUT_CHECK(ctx->HasInput("Linear1Weight"), "Input", "Linear1Weight",
"FusedFeedForwardGrad");
OP_INOUT_CHECK(ctx->HasInput("Linear2Weight"), "Input", "Linear2Weight",
"FusedFeedForwardGrad");
OP_INOUT_CHECK(ctx->HasInput("Ln1Mean"), "Input", "Ln1Mean",
"FusedFeedForwardGrad");
OP_INOUT_CHECK(ctx->HasInput("Ln1Variance"), "Input", "Ln1Variance",
"FusedFeedForwardGrad");
OP_INOUT_CHECK(ctx->HasInput("Ln2Mean"), "Input", "Ln2Mean",
"FusedFeedForwardGrad");
OP_INOUT_CHECK(ctx->HasInput("Ln2Variance"), "Input", "Ln2Variance",
"FusedFeedForwardGrad");

OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
framework::GradVarName("Out"), "FusedFeedForwardGrad");

auto d_out_dim = ctx->GetInputDim(framework::GradVarName("Out"));
ctx->SetOutputDim(framework::GradVarName("X"), d_out_dim);
if (ctx->HasOutput(framework::GradVarName("Ln1Scale"))) {
ctx->SetOutputDim(framework::GradVarName("Ln1Scale"),
ctx->GetInputDim("Ln1Scale"));
}
if (ctx->HasOutput(framework::GradVarName("Ln1Bias"))) {
ctx->SetOutputDim(framework::GradVarName("Ln1Bias"),
ctx->GetInputDim("Ln1Bias"));
}
if (ctx->HasOutput(framework::GradVarName("Ln2Scale"))) {
ctx->SetOutputDim(framework::GradVarName("Ln2Scale"),
ctx->GetInputDim("Ln2Scale"));
}
if (ctx->HasOutput(framework::GradVarName("Ln2Bias"))) {
ctx->SetOutputDim(framework::GradVarName("Ln2Bias"),
ctx->GetInputDim("Ln2Bias"));
}
ctx->SetOutputDim(framework::GradVarName("Linear1Weight"),
ctx->GetInputDim("Linear1Weight"));
if (ctx->HasOutput(framework::GradVarName("Linear1Bias"))) {
ctx->SetOutputDim(framework::GradVarName("Linear1Bias"),
ctx->GetInputDim("Linear1Bias"));
}
ctx->SetOutputDim(framework::GradVarName("Linear2Weight"),
ctx->GetInputDim("Linear2Weight"));
if (ctx->HasOutput(framework::GradVarName("Linear2Bias"))) {
ctx->SetOutputDim(framework::GradVarName("Linear2Bias"),
ctx->GetInputDim("Linear2Bias"));
}
}

framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input = ctx.Input<Tensor>("X");
auto input_data_type = input->type();
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};

template <typename T>
class FusedFeedForwardOpGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("fused_feedforward_grad");
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetInput("X", this->Input("X"));
op->SetInput("Linear1Weight", this->Input("Linear1Weight"));
op->SetInput("Linear1Bias", this->Input("Linear1Bias"));
op->SetInput("Linear2Weight", this->Input("Linear2Weight"));
op->SetInput("Ln1Scale", this->Input("Ln1Scale"));
op->SetInput("Ln1Bias", this->Input("Ln1Bias"));
op->SetInput("Ln2Scale", this->Input("Ln2Scale"));
op->SetInput("Ln2Bias", this->Input("Ln2Bias"));
op->SetInput("Dropout1Mask", this->Output("Dropout1Mask"));
op->SetInput("Dropout2Mask", this->Output("Dropout2Mask"));
op->SetInput("Linear1Out", this->Output("Linear1Out"));
op->SetInput("Ln1Out", this->Output("Ln1Out"));
op->SetInput("Ln1Mean", this->Output("Ln1Mean"));
op->SetInput("Ln1Variance", this->Output("Ln1Variance"));
op->SetInput("Ln2Mean", this->Output("Ln2Mean"));
op->SetInput("Ln2Variance", this->Output("Ln2Variance"));
op->SetInput("Dropout1Out", this->Output("Dropout1Out"));
op->SetInput("Dropout2Out", this->Output("Dropout2Out"));

op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetOutput(framework::GradVarName("Ln1Scale"),
this->InputGrad("Ln1Scale"));
op->SetOutput(framework::GradVarName("Ln1Bias"),
this->InputGrad("Ln1Bias"));
op->SetOutput(framework::GradVarName("Ln2Scale"),
this->InputGrad("Ln2Scale"));
op->SetOutput(framework::GradVarName("Ln2Bias"),
this->InputGrad("Ln2Bias"));
op->SetOutput(framework::GradVarName("Linear1Weight"),
this->InputGrad("Linear1Weight"));
op->SetOutput(framework::GradVarName("Linear1Bias"),
this->InputGrad("Linear1Bias"));
op->SetOutput(framework::GradVarName("Linear2Weight"),
this->InputGrad("Linear2Weight"));
if (this->HasInput("Linear2Bias")) {
op->SetInput("Linear2Bias", this->Input("Linear2Bias"));
op->SetOutput(framework::GradVarName("Linear2Bias"),
this->InputGrad("Linear2Bias"));
}

op->SetAttrMap(this->Attrs());
}
};

template <typename T>
class FusedFeedForwardOpDoubleGradMaker
: public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

protected:
void Apply(GradOpPtr<T> grad_op) const override {}
};
} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(fused_feedforward, ops::FusedFeedForwardOp,
ops::FusedFeedForwardOpMaker);
ops::FusedFeedForwardOpMaker,
ops::FusedFeedForwardOpGradMaker<paddle::framework::OpDesc>,
ops::FusedFeedForwardOpGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(fused_feedforward_grad, ops::FusedFeedForwardOpGrad);
211 changes: 211 additions & 0 deletions paddle/fluid/operators/fused/fused_feedforward_op.cu
@@ -171,6 +171,210 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> {
}
};

template <typename DeviceContext, typename T>
class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
public:
// Gradients of out = a * b: d_a = d_out * b^T and d_b = a^T * d_out.
void MatMulGrad(const platform::CUDADeviceContext& ctx,
const framework::Tensor& d_out, const framework::Tensor& a,
const framework::Tensor& b, framework::Tensor* d_a,
framework::Tensor* d_b) const {
auto blas = math::GetBlas<DeviceContext, T>(ctx);
auto a_2d = FoldInitDims(a);
auto b_2d = FoldInitDims(b);
auto mat_dim_a = math::CreateMatrixDescriptor(a_2d.dims(), 0, true);
auto mat_dim_b = math::CreateMatrixDescriptor(b_2d.dims(), 0, true);
auto mat_dim_dout = math::CreateMatrixDescriptor(d_out.dims(), 0, false);
T alpha = static_cast<T>(1.0);
blas.MatMul(d_out, mat_dim_dout, b, mat_dim_b, alpha, d_a, T(0));
blas.MatMul(a, mat_dim_a, d_out, mat_dim_dout, alpha, d_b, T(0));
}

// Backward of the fused FFN block, applied in reverse order of the forward pass:
// dropout2+residual(+layernorm2) grad -> linear2 grad -> dropout1+activation+bias1
// grad -> linear1 grad (-> pre-layernorm grad when pre_layer_norm is true).
void FFNGrad(
const framework::Tensor& d_out, const framework::Tensor& x,
const framework::Tensor& dropout1_mask,
const framework::Tensor& dropout2_mask,
const framework::Tensor& linear1_out, const framework::Tensor& ln1_out,
const framework::Tensor& dropout1_out,
const framework::Tensor& dropout2_out,
const framework::Tensor& linear1_weight,
const framework::Tensor* linear1_bias,
const framework::Tensor& linear2_weight,
const framework::Tensor* ln1_gamma, const framework::Tensor* ln1_beta,
const framework::Tensor& ln1_mean, const framework::Tensor& ln1_variance,
const framework::Tensor* ln2_gamma, const framework::Tensor* ln2_beta,
const framework::Tensor& ln2_mean, const framework::Tensor& ln2_variance,
framework::Tensor* d_x, framework::Tensor* d_linear1_weight,
framework::Tensor* d_linear1_bias, framework::Tensor* d_linear2_weight,
framework::Tensor* d_linear2_bias, framework::Tensor* d_ln1_gamma,
framework::Tensor* d_ln1_beta, framework::Tensor* d_ln2_gamma,
framework::Tensor* d_ln2_beta, const int bsz_seq, const int d_model,
const int dim_feedforward, const DropoutParam& dropout_param1,
const DropoutParam& dropout_param2, const std::string& act_method,
const bool pre_layer_norm, const float epsilon1, const float epsilon2,
const platform::CUDADeviceContext& ctx) const {
FusedDropoutLayerNormHelper<T, uint8_t> pre_layernorm_helper(
bsz_seq, d_model, epsilon1);
FusedDropoutHelper<T, uint8_t> fused_act_dropout_helper(
ctx, bsz_seq, dim_feedforward, dropout_param1);
FusedDropoutLayerNormHelper<T, uint8_t> fused_dropout_layernorm_helper(
ctx, bsz_seq, d_model, dropout_param2, epsilon2);

auto place = ctx.GetPlace();
using U = LayerNormParamType<T>;
const U* ln1_gamma_ptr =
ln1_gamma == nullptr ? nullptr : ln1_gamma->data<U>();
const U* ln1_beta_ptr = ln1_beta == nullptr ? nullptr : ln1_beta->data<U>();
const U* ln2_gamma_ptr =
ln2_gamma == nullptr ? nullptr : ln2_gamma->data<U>();
const U* ln2_beta_ptr = ln2_beta == nullptr ? nullptr : ln2_beta->data<U>();
const T* linear1_bias_ptr =
linear1_bias == nullptr ? nullptr : linear1_bias->data<T>();
T* d_linear1_bias_ptr =
d_linear1_bias == nullptr ? nullptr : d_linear1_bias->data<T>();
T* d_linear2_bias_ptr =
d_linear2_bias == nullptr ? nullptr : d_linear2_bias->data<T>();
U* d_ln1_gamma_ptr =
d_ln1_gamma == nullptr ? nullptr : d_ln1_gamma->data<U>();
U* d_ln1_beta_ptr = d_ln1_beta == nullptr ? nullptr : d_ln1_beta->data<U>();
U* d_ln2_gamma_ptr =
d_ln2_gamma == nullptr ? nullptr : d_ln2_gamma->data<U>();
U* d_ln2_beta_ptr = d_ln2_beta == nullptr ? nullptr : d_ln2_beta->data<U>();

framework::Tensor d_linear2_out, d_dropout2_out, d_residual;
d_linear2_out.mutable_data<T>({bsz_seq, d_model}, place);
d_dropout2_out.mutable_data<T>({bsz_seq, d_model}, place);
d_residual.mutable_data<T>({bsz_seq, d_model}, place);

if (pre_layer_norm) {
fused_dropout_layernorm_helper.ResidualDropoutBiasGrad(
ctx, d_out.data<T>(), dropout2_mask.data<uint8_t>(),
d_linear2_out.data<T>(), d_residual.data<T>(), d_linear2_bias_ptr);
} else {
fused_dropout_layernorm_helper.LayernormResidualDropoutBiasGrad(
ctx, d_out.data<T>(), dropout2_out.data<T>(),
dropout2_mask.data<uint8_t>(), ln2_gamma_ptr, ln2_mean.data<U>(),
ln2_variance.data<U>(), d_dropout2_out.data<T>(), d_ln2_gamma_ptr,
d_ln2_beta_ptr, d_linear2_out.data<T>(), d_linear2_bias_ptr,
d_residual.data<T>());
}

framework::Tensor d_dropout1_out;
d_dropout1_out.mutable_data<T>({bsz_seq, dim_feedforward}, place);
MatMulGrad(ctx, d_linear2_out, dropout1_out, linear2_weight,
&d_dropout1_out, d_linear2_weight);

framework::Tensor d_linear1_out;
d_linear1_out.mutable_data<T>({bsz_seq, dim_feedforward}, place);
fused_act_dropout_helper.DropoutActBiasGrad(
ctx, d_dropout1_out.data<T>(), linear1_out.data<T>(), linear1_bias_ptr,
dropout1_mask.data<uint8_t>(), d_linear1_out.data<T>(),
d_linear1_bias_ptr, act_method);

if (pre_layer_norm) {
framework::Tensor d_ln1_out;
d_ln1_out.mutable_data<T>({bsz_seq, d_model}, place);
MatMulGrad(ctx, d_linear1_out, ln1_out, linear1_weight, &d_ln1_out,
d_linear1_weight);

pre_layernorm_helper.LayerNormGrad(ctx, d_ln1_out.data<T>(), x.data<T>(),
ln1_gamma_ptr, ln1_mean.data<U>(),
ln1_variance.data<U>(), d_x->data<T>(),
d_ln1_gamma_ptr, d_ln1_beta_ptr);
} else {
MatMulGrad(ctx, d_linear1_out, x, linear1_weight, d_x, d_linear1_weight);
}
}

void Compute(const framework::ExecutionContext& context) const override {
using U = LayerNormParamType<T>;
auto d_out =
*context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto x = *context.Input<framework::Tensor>("X");
auto dropout1_mask = *context.Input<framework::Tensor>("Dropout1Mask");
auto dropout2_mask = *context.Input<framework::Tensor>("Dropout2Mask");
auto linear1_out = *context.Input<framework::Tensor>("Linear1Out");
auto ln1_out = *context.Input<framework::Tensor>("Ln1Out");
auto dropout1_out = *context.Input<framework::Tensor>("Dropout1Out");
auto dropout2_out = *context.Input<framework::Tensor>("Dropout2Out");
auto linear1_weight = *context.Input<framework::Tensor>("Linear1Weight");
auto* linear1_bias = context.Input<framework::Tensor>("Linear1Bias");
auto linear2_weight = *context.Input<framework::Tensor>("Linear2Weight");
auto ln1_mean = *context.Input<framework::Tensor>("Ln1Mean");
auto ln1_variance = *context.Input<framework::Tensor>("Ln1Variance");
auto* ln1_scale = context.Input<framework::Tensor>("Ln1Scale");
auto* ln1_bias = context.Input<framework::Tensor>("Ln1Bias");
auto ln2_mean = *context.Input<framework::Tensor>("Ln2Mean");
auto ln2_variance = *context.Input<framework::Tensor>("Ln2Variance");
auto* ln2_scale = context.Input<framework::Tensor>("Ln2Scale");
auto* ln2_bias = context.Input<framework::Tensor>("Ln2Bias");

auto* d_x = context.Output<framework::Tensor>(framework::GradVarName("X"));
auto* d_ln1_scale =
context.Output<framework::Tensor>(framework::GradVarName("Ln1Scale"));
auto* d_ln1_bias =
context.Output<framework::Tensor>(framework::GradVarName("Ln1Bias"));
auto* d_ln2_scale =
context.Output<framework::Tensor>(framework::GradVarName("Ln2Scale"));
auto* d_ln2_bias =
context.Output<framework::Tensor>(framework::GradVarName("Ln2Bias"));
auto* d_linear1_weight = context.Output<framework::Tensor>(
framework::GradVarName("Linear1Weight"));
auto* d_linear1_bias = context.Output<framework::Tensor>(
framework::GradVarName("Linear1Bias"));
auto* d_linear2_weight = context.Output<framework::Tensor>(
framework::GradVarName("Linear2Weight"));
auto* d_linear2_bias = context.Output<framework::Tensor>(
framework::GradVarName("Linear2Bias"));

const float epsilon1 = context.Attr<float>("ln1_epsilon");
const float epsilon2 = context.Attr<float>("ln2_epsilon");
const bool pre_layer_norm = context.Attr<bool>("pre_layer_norm");
const std::string act_method = context.Attr<std::string>("act_method");
DropoutParam dropout_param1(context, 1);
DropoutParam dropout_param2(context, 2);

auto place = context.GetPlace();
d_x->mutable_data<T>(place);
if (d_ln1_scale) {
d_ln1_scale->mutable_data<U>(place);
}
if (d_ln1_bias) {
d_ln1_bias->mutable_data<U>(place);
}
if (d_ln2_scale) {
d_ln2_scale->mutable_data<U>(place);
}
if (d_ln2_bias) {
d_ln2_bias->mutable_data<U>(place);
}
if (d_linear1_bias) {
d_linear1_bias->mutable_data<T>(place);
}
if (d_linear2_bias) {
d_linear2_bias->mutable_data<T>(place);
}
d_linear1_weight->mutable_data<T>(place);
d_linear2_weight->mutable_data<T>(place);

auto x_dim = x.dims();
auto mat_dim_x =
math::CreateMatrixDescriptor(RowMatrixFromVector(x_dim), 0, false);

auto linear1_weight_dim = linear1_weight.dims();
int d_model = linear1_weight_dim[0];
int dim_feedforward = linear1_weight_dim[linear1_weight_dim.size() - 1];
int bsz_seq = mat_dim_x.batch_size_ * mat_dim_x.height_;

FFNGrad(d_out, x, dropout1_mask, dropout2_mask, linear1_out, ln1_out,
dropout1_out, dropout2_out, linear1_weight, linear1_bias,
linear2_weight, ln1_scale, ln1_bias, ln1_mean, ln1_variance,
ln2_scale, ln2_bias, ln2_mean, ln2_variance, d_x, d_linear1_weight,
d_linear1_bias, d_linear2_weight, d_linear2_bias, d_ln1_scale,
d_ln1_bias, d_ln2_scale, d_ln2_bias, bsz_seq, d_model,
dim_feedforward, dropout_param1, dropout_param2, act_method,
pre_layer_norm, epsilon1, epsilon2, context.cuda_device_context());
}
};
} // namespace operators
} // namespace paddle

@@ -181,3 +385,10 @@ REGISTER_OP_CUDA_KERNEL(
ops::FusedFeedForwardKernel<paddle::platform::CUDADeviceContext, double>,
ops::FusedFeedForwardKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>);
REGISTER_OP_CUDA_KERNEL(
fused_feedforward_grad,
ops::FusedFeedForwardGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::FusedFeedForwardGradKernel<paddle::platform::CUDADeviceContext,
double>,
ops::FusedFeedForwardGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>);
1 change: 0 additions & 1 deletion python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -91,7 +91,6 @@ foreach(TEST_OP ${MIXED_DIST_TEST_OPS})
endforeach()

if(NOT WITH_GPU)

LIST(REMOVE_ITEM TEST_OPS test_fused_feedforward_op)
LIST(REMOVE_ITEM TEST_OPS test_fused_attention_op)
endif()