Skip to content

Commit

Permalink
Add fused attention op backward and python layer. (#36498) (#36752)
Browse files Browse the repository at this point in the history
功能:本PR的目标是提高attention模块的计算性能。
为了减少框架层对op的调度开销,本PR通过在C++层手动实现attention模块,对外提供attention 大op;
为了减少防存开销,本PR采取了两种优化方法:
(1)在q,k,v计算时通过共享输入X,将该处的gemm,transpose和bias add从三次调用减少为一次;
(2)使用kernel融合优化技术,在不同cuda kernel之间通过寄存器传输数据;
  • Loading branch information
limin2021 authored Oct 27, 2021
1 parent 211cf20 commit 64643d5
Show file tree
Hide file tree
Showing 8 changed files with 952 additions and 54 deletions.
199 changes: 198 additions & 1 deletion paddle/fluid/operators/fused/fused_attention_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -328,9 +328,206 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
}
};

class FusedAttentionGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->Attrs().Get<bool>("attn_dropout_is_test"), false,
platform::errors::InvalidArgument(
"GradOp is only callable when attn_dropout_is_test is false"));

OP_INOUT_CHECK(ctx->HasInput("Ln2Mean"), "Input", "Ln2Mean",
"FusedAttentionGrad");
OP_INOUT_CHECK(ctx->HasInput("Ln2Variance"), "Input", "Ln2Variance",
"FusedAttentionGrad");
if (ctx->HasOutput(framework::GradVarName("Ln2Scale"))) {
ctx->SetOutputDim(framework::GradVarName("Ln2Scale"),
ctx->GetInputDim("Ln2Scale"));
}
if (ctx->HasOutput(framework::GradVarName("Ln2Bias"))) {
ctx->SetOutputDim(framework::GradVarName("Ln2Bias"),
ctx->GetInputDim("Ln2Bias"));
}
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedAttentionGrad");
OP_INOUT_CHECK(ctx->HasInput("LnMean"), "Input", "LnMean",
"FusedAttentionGrad");
OP_INOUT_CHECK(ctx->HasInput("LnVariance"), "Input", "LnVariance",
"FusedAttentionGrad");
if (ctx->Attrs().Get<bool>("pre_layer_norm") == true) {
OP_INOUT_CHECK(ctx->HasInput("LnOut"), "Input", "LnOut",
"FusedAttentionGrad");
}
OP_INOUT_CHECK(ctx->HasInput("QKVW"), "Input", "QKVW",
"FusedAttentionGrad");
OP_INOUT_CHECK(ctx->HasInput("QKVBias"), "Input", "QKVBias",
"FusedAttentionGrad");
OP_INOUT_CHECK(ctx->HasInput("SrcMask"), "Input", "SrcMask",
"FusedAttentionGrad");
OP_INOUT_CHECK(ctx->HasInput("OutLinearW"), "Input", "OutLinearW",
"FusedAttentionGrad");
OP_INOUT_CHECK(ctx->HasInput("OutLinearBias"), "Input", "OutLinearBias",
"FusedAttentionGrad");

if (ctx->HasOutput(framework::GradVarName("LnScale"))) {
ctx->SetOutputDim(framework::GradVarName("LnScale"),
ctx->GetInputDim("LnScale"));
}
if (ctx->HasOutput(framework::GradVarName("LnBias"))) {
ctx->SetOutputDim(framework::GradVarName("LnBias"),
ctx->GetInputDim("LnBias"));
}
if (ctx->HasOutput(framework::GradVarName("X"))) {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}

ctx->SetOutputDim(framework::GradVarName("OutLinearBias"),
ctx->GetInputDim("OutLinearBias"));
ctx->SetOutputDim(framework::GradVarName("OutLinearW"),
ctx->GetInputDim("OutLinearW"));
ctx->SetOutputDim(framework::GradVarName("QKVW"), ctx->GetInputDim("QKVW"));
ctx->SetOutputDim(framework::GradVarName("QKVBias"),
ctx->GetInputDim("QKVBias"));

ctx->SetOutputDim(framework::GradVarName("LnOut"),
ctx->GetInputDim("LnOut"));
ctx->SetOutputDim(framework::GradVarName("FMHAOut"),
ctx->GetInputDim("FMHAOut"));
ctx->SetOutputDim(framework::GradVarName("QKTVOut"),
ctx->GetInputDim("QKTVOut"));
ctx->SetOutputDim(framework::GradVarName("TransposeOut2"),
ctx->GetInputDim("TransposeOut2"));
ctx->SetOutputDim(framework::GradVarName("QKOut"),
ctx->GetInputDim("QKOut"));
ctx->SetOutputDim(framework::GradVarName("SoftmaxOut"),
ctx->GetInputDim("SoftmaxOut"));
ctx->SetOutputDim(framework::GradVarName("AttnDropoutOut"),
ctx->GetInputDim("AttnDropoutOut"));
ctx->SetOutputDim(framework::GradVarName("SrcMaskOut"),
ctx->GetInputDim("SrcMaskOut"));
ctx->SetOutputDim(framework::GradVarName("QKVOut"),
ctx->GetInputDim("QKVOut"));
ctx->SetOutputDim(framework::GradVarName("QKVBiasOut"),
ctx->GetInputDim("QKVBiasOut"));
ctx->SetOutputDim(framework::GradVarName("OutLinearOut"),
ctx->GetInputDim("OutLinearOut"));
ctx->SetOutputDim(framework::GradVarName("BiasDropoutResidualOut"),
ctx->GetInputDim("BiasDropoutResidualOut"));
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input = ctx.Input<Tensor>("X");
auto input_data_type = input->type();
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};

template <typename T>
class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("fused_attention_grad");
op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));

// inputs x, parameters and their grad.
op->SetInput("X", this->Input("X"));
op->SetInput("QKVW", this->Input("QKVW"));
op->SetInput("QKVBias", this->Input("QKVBias"));
op->SetInput("SrcMask", this->Input("SrcMask"));
op->SetInput("OutLinearW", this->Input("OutLinearW"));
op->SetInput("OutLinearBias", this->Input("OutLinearBias"));
if (this->HasInput("LnScale")) {
op->SetInput("LnScale", this->Input("LnScale"));
op->SetOutput(framework::GradVarName("LnScale"),
this->InputGrad("LnScale"));
}
if (this->HasInput("LnBias")) {
op->SetInput("LnBias", this->Input("LnBias"));
op->SetOutput(framework::GradVarName("LnBias"),
this->InputGrad("LnBias"));
}
if (this->HasInput("Ln2Scale")) {
op->SetInput("Ln2Scale", this->Input("Ln2Scale"));
op->SetOutput(framework::GradVarName("Ln2Scale"),
this->InputGrad("Ln2Scale"));
}
if (this->HasInput("Ln2Bias")) {
op->SetInput("Ln2Bias", this->Input("Ln2Bias"));
op->SetOutput(framework::GradVarName("Ln2Bias"),
this->InputGrad("Ln2Bias"));
}

op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetOutput(framework::GradVarName("QKVW"), this->InputGrad("QKVW"));
op->SetOutput(framework::GradVarName("QKVBias"),
this->InputGrad("QKVBias"));
op->SetOutput(framework::GradVarName("OutLinearBias"),
this->InputGrad("OutLinearBias"));
op->SetOutput(framework::GradVarName("OutLinearW"),
this->InputGrad("OutLinearW"));

// use forward outputs as backward inputs.
op->SetInput("LnOut", this->Output("LnOut"));
op->SetInput("LnMean", this->Output("LnMean"));
op->SetInput("LnVariance", this->Output("LnVariance"));
op->SetInput("QKVOut", this->Output("QKVOut"));
op->SetInput("QKVBiasOut", this->Output("QKVBiasOut"));
op->SetInput("TransposeOut2", this->Output("TransposeOut2"));
op->SetInput("QKOut", this->Output("QKOut"));
op->SetInput("QKTVOut", this->Output("QKTVOut"));
op->SetInput("SoftmaxOut", this->Output("SoftmaxOut"));
op->SetInput("AttnDropoutMaskOut", this->Output("AttnDropoutMaskOut"));
op->SetInput("AttnDropoutOut", this->Output("AttnDropoutOut"));
op->SetInput("SrcMaskOut", this->Output("SrcMaskOut"));
op->SetInput("FMHAOut", this->Output("FMHAOut"));
op->SetInput("OutLinearOut", this->Output("OutLinearOut"));

op->SetInput("Ln2Mean", this->Output("Ln2Mean"));
op->SetInput("Ln2Variance", this->Output("Ln2Variance"));
op->SetInput("DropoutMaskOut", this->Output("DropoutMaskOut"));
op->SetInput("BiasDropoutResidualOut",
this->Output("BiasDropoutResidualOut"));
op->SetInput("QKVOut", this->Output("QKVOut"));

// backward outputs: dinput
op->SetOutput(framework::GradVarName("LnOut"), this->OutputGrad("LnOut"));
op->SetOutput(framework::GradVarName("QKVOut"), this->OutputGrad("QKVOut"));
op->SetOutput(framework::GradVarName("QKVBiasOut"),
this->OutputGrad("QKVBiasOut"));
op->SetOutput(framework::GradVarName("QKTVOut"),
this->OutputGrad("QKTVOut"));
op->SetOutput(framework::GradVarName("TransposeOut2"),
this->OutputGrad("TransposeOut2"));
op->SetOutput(framework::GradVarName("QKOut"), this->OutputGrad("QKOut"));
op->SetOutput(framework::GradVarName("SoftmaxOut"),
this->OutputGrad("SoftmaxOut"));
op->SetOutput(framework::GradVarName("AttnDropoutOut"),
this->OutputGrad("AttnDropoutOut"));
op->SetOutput(framework::GradVarName("SrcMaskOut"),
this->OutputGrad("SrcMaskOut"));
op->SetOutput(framework::GradVarName("FMHAOut"),
this->OutputGrad("FMHAOut"));
op->SetOutput(framework::GradVarName("BiasDropoutResidualOut"),
this->OutputGrad("BiasDropoutResidualOut"));
op->SetOutput(framework::GradVarName("OutLinearOut"),
this->OutputGrad("OutLinearOut"));

op->SetAttrMap(this->Attrs());
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(fused_attention, ops::FusedAttentionOp,
ops::FusedAttentionOpMaker);
ops::FusedAttentionOpMaker,
ops::FusedAttentionGradOpMaker<paddle::framework::OpDesc>,
ops::FusedAttentionGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(fused_attention_grad, ops::FusedAttentionGradOp);
Loading

0 comments on commit 64643d5

Please sign in to comment.