diff --git a/paddle/fluid/operators/fused/fused_attention_op.cc b/paddle/fluid/operators/fused/fused_attention_op.cc index a286c39f7f8db..6c4ac318264e8 100644 --- a/paddle/fluid/operators/fused/fused_attention_op.cc +++ b/paddle/fluid/operators/fused/fused_attention_op.cc @@ -328,9 +328,206 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker { } }; +class FusedAttentionGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE_EQ( + ctx->Attrs().Get("attn_dropout_is_test"), false, + platform::errors::InvalidArgument( + "GradOp is only callable when attn_dropout_is_test is false")); + + OP_INOUT_CHECK(ctx->HasInput("Ln2Mean"), "Input", "Ln2Mean", + "FusedAttentionGrad"); + OP_INOUT_CHECK(ctx->HasInput("Ln2Variance"), "Input", "Ln2Variance", + "FusedAttentionGrad"); + if (ctx->HasOutput(framework::GradVarName("Ln2Scale"))) { + ctx->SetOutputDim(framework::GradVarName("Ln2Scale"), + ctx->GetInputDim("Ln2Scale")); + } + if (ctx->HasOutput(framework::GradVarName("Ln2Bias"))) { + ctx->SetOutputDim(framework::GradVarName("Ln2Bias"), + ctx->GetInputDim("Ln2Bias")); + } + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedAttentionGrad"); + OP_INOUT_CHECK(ctx->HasInput("LnMean"), "Input", "LnMean", + "FusedAttentionGrad"); + OP_INOUT_CHECK(ctx->HasInput("LnVariance"), "Input", "LnVariance", + "FusedAttentionGrad"); + if (ctx->Attrs().Get("pre_layer_norm") == true) { + OP_INOUT_CHECK(ctx->HasInput("LnOut"), "Input", "LnOut", + "FusedAttentionGrad"); + } + OP_INOUT_CHECK(ctx->HasInput("QKVW"), "Input", "QKVW", + "FusedAttentionGrad"); + OP_INOUT_CHECK(ctx->HasInput("QKVBias"), "Input", "QKVBias", + "FusedAttentionGrad"); + OP_INOUT_CHECK(ctx->HasInput("SrcMask"), "Input", "SrcMask", + "FusedAttentionGrad"); + OP_INOUT_CHECK(ctx->HasInput("OutLinearW"), "Input", "OutLinearW", + "FusedAttentionGrad"); + OP_INOUT_CHECK(ctx->HasInput("OutLinearBias"), "Input", "OutLinearBias", + "FusedAttentionGrad"); + + if (ctx->HasOutput(framework::GradVarName("LnScale"))) { + ctx->SetOutputDim(framework::GradVarName("LnScale"), + ctx->GetInputDim("LnScale")); + } + if (ctx->HasOutput(framework::GradVarName("LnBias"))) { + ctx->SetOutputDim(framework::GradVarName("LnBias"), + ctx->GetInputDim("LnBias")); + } + if (ctx->HasOutput(framework::GradVarName("X"))) { + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + } + + ctx->SetOutputDim(framework::GradVarName("OutLinearBias"), + ctx->GetInputDim("OutLinearBias")); + ctx->SetOutputDim(framework::GradVarName("OutLinearW"), + ctx->GetInputDim("OutLinearW")); + ctx->SetOutputDim(framework::GradVarName("QKVW"), ctx->GetInputDim("QKVW")); + ctx->SetOutputDim(framework::GradVarName("QKVBias"), + ctx->GetInputDim("QKVBias")); + + ctx->SetOutputDim(framework::GradVarName("LnOut"), + ctx->GetInputDim("LnOut")); + ctx->SetOutputDim(framework::GradVarName("FMHAOut"), + ctx->GetInputDim("FMHAOut")); + ctx->SetOutputDim(framework::GradVarName("QKTVOut"), + ctx->GetInputDim("QKTVOut")); + ctx->SetOutputDim(framework::GradVarName("TransposeOut2"), + ctx->GetInputDim("TransposeOut2")); + ctx->SetOutputDim(framework::GradVarName("QKOut"), + ctx->GetInputDim("QKOut")); + ctx->SetOutputDim(framework::GradVarName("SoftmaxOut"), + ctx->GetInputDim("SoftmaxOut")); + ctx->SetOutputDim(framework::GradVarName("AttnDropoutOut"), + ctx->GetInputDim("AttnDropoutOut")); + ctx->SetOutputDim(framework::GradVarName("SrcMaskOut"), + ctx->GetInputDim("SrcMaskOut")); + ctx->SetOutputDim(framework::GradVarName("QKVOut"), + ctx->GetInputDim("QKVOut")); + ctx->SetOutputDim(framework::GradVarName("QKVBiasOut"), + ctx->GetInputDim("QKVBiasOut")); + ctx->SetOutputDim(framework::GradVarName("OutLinearOut"), + ctx->GetInputDim("OutLinearOut")); + ctx->SetOutputDim(framework::GradVarName("BiasDropoutResidualOut"), + ctx->GetInputDim("BiasDropoutResidualOut")); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + auto input = ctx.Input("X"); + auto input_data_type = input->type(); + return framework::OpKernelType(input_data_type, ctx.GetPlace()); + } +}; + +template +class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr op) const override { + op->SetType("fused_attention_grad"); + op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y")); + + // inputs x, parameters and their grad. + op->SetInput("X", this->Input("X")); + op->SetInput("QKVW", this->Input("QKVW")); + op->SetInput("QKVBias", this->Input("QKVBias")); + op->SetInput("SrcMask", this->Input("SrcMask")); + op->SetInput("OutLinearW", this->Input("OutLinearW")); + op->SetInput("OutLinearBias", this->Input("OutLinearBias")); + if (this->HasInput("LnScale")) { + op->SetInput("LnScale", this->Input("LnScale")); + op->SetOutput(framework::GradVarName("LnScale"), + this->InputGrad("LnScale")); + } + if (this->HasInput("LnBias")) { + op->SetInput("LnBias", this->Input("LnBias")); + op->SetOutput(framework::GradVarName("LnBias"), + this->InputGrad("LnBias")); + } + if (this->HasInput("Ln2Scale")) { + op->SetInput("Ln2Scale", this->Input("Ln2Scale")); + op->SetOutput(framework::GradVarName("Ln2Scale"), + this->InputGrad("Ln2Scale")); + } + if (this->HasInput("Ln2Bias")) { + op->SetInput("Ln2Bias", this->Input("Ln2Bias")); + op->SetOutput(framework::GradVarName("Ln2Bias"), + this->InputGrad("Ln2Bias")); + } + + op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + op->SetOutput(framework::GradVarName("QKVW"), this->InputGrad("QKVW")); + op->SetOutput(framework::GradVarName("QKVBias"), + this->InputGrad("QKVBias")); + op->SetOutput(framework::GradVarName("OutLinearBias"), + this->InputGrad("OutLinearBias")); + op->SetOutput(framework::GradVarName("OutLinearW"), + this->InputGrad("OutLinearW")); + + // use forward outputs as backward inputs. + op->SetInput("LnOut", this->Output("LnOut")); + op->SetInput("LnMean", this->Output("LnMean")); + op->SetInput("LnVariance", this->Output("LnVariance")); + op->SetInput("QKVOut", this->Output("QKVOut")); + op->SetInput("QKVBiasOut", this->Output("QKVBiasOut")); + op->SetInput("TransposeOut2", this->Output("TransposeOut2")); + op->SetInput("QKOut", this->Output("QKOut")); + op->SetInput("QKTVOut", this->Output("QKTVOut")); + op->SetInput("SoftmaxOut", this->Output("SoftmaxOut")); + op->SetInput("AttnDropoutMaskOut", this->Output("AttnDropoutMaskOut")); + op->SetInput("AttnDropoutOut", this->Output("AttnDropoutOut")); + op->SetInput("SrcMaskOut", this->Output("SrcMaskOut")); + op->SetInput("FMHAOut", this->Output("FMHAOut")); + op->SetInput("OutLinearOut", this->Output("OutLinearOut")); + + op->SetInput("Ln2Mean", this->Output("Ln2Mean")); + op->SetInput("Ln2Variance", this->Output("Ln2Variance")); + op->SetInput("DropoutMaskOut", this->Output("DropoutMaskOut")); + op->SetInput("BiasDropoutResidualOut", + this->Output("BiasDropoutResidualOut")); + op->SetInput("QKVOut", this->Output("QKVOut")); + + // backward outputs: dinput + op->SetOutput(framework::GradVarName("LnOut"), this->OutputGrad("LnOut")); + op->SetOutput(framework::GradVarName("QKVOut"), this->OutputGrad("QKVOut")); + op->SetOutput(framework::GradVarName("QKVBiasOut"), + this->OutputGrad("QKVBiasOut")); + op->SetOutput(framework::GradVarName("QKTVOut"), + this->OutputGrad("QKTVOut")); + op->SetOutput(framework::GradVarName("TransposeOut2"), + this->OutputGrad("TransposeOut2")); + op->SetOutput(framework::GradVarName("QKOut"), this->OutputGrad("QKOut")); + op->SetOutput(framework::GradVarName("SoftmaxOut"), + this->OutputGrad("SoftmaxOut")); + op->SetOutput(framework::GradVarName("AttnDropoutOut"), + this->OutputGrad("AttnDropoutOut")); + op->SetOutput(framework::GradVarName("SrcMaskOut"), + this->OutputGrad("SrcMaskOut")); + op->SetOutput(framework::GradVarName("FMHAOut"), + this->OutputGrad("FMHAOut")); + op->SetOutput(framework::GradVarName("BiasDropoutResidualOut"), + this->OutputGrad("BiasDropoutResidualOut")); + op->SetOutput(framework::GradVarName("OutLinearOut"), + this->OutputGrad("OutLinearOut")); + + op->SetAttrMap(this->Attrs()); + } +}; + } // namespace operators } // namespace paddle namespace ops = paddle::operators; REGISTER_OPERATOR(fused_attention, ops::FusedAttentionOp, - ops::FusedAttentionOpMaker); + ops::FusedAttentionOpMaker, + ops::FusedAttentionGradOpMaker, + ops::FusedAttentionGradOpMaker); +REGISTER_OPERATOR(fused_attention_grad, ops::FusedAttentionGradOp); diff --git a/paddle/fluid/operators/fused/fused_attention_op.cu b/paddle/fluid/operators/fused/fused_attention_op.cu index 18a42b5c2cee2..95e690cb17ec1 100644 --- a/paddle/fluid/operators/fused/fused_attention_op.cu +++ b/paddle/fluid/operators/fused/fused_attention_op.cu @@ -199,6 +199,237 @@ class FusedAttentionOpKernel : public framework::OpKernel { } }; +template +class FusedAttentionGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + using U = LayerNormParamType; + const auto pre_layer_norm = ctx.Attr("pre_layer_norm"); + const float epsilon = ctx.Attr("epsilon"); + const float ln2epsilon = ctx.Attr("ln_epsilon"); + + float attn_dropout_prob = ctx.Attr("attn_dropout_rate"); + bool is_test_1 = ctx.Attr("attn_dropout_is_test"); + auto &dropout_implementation_1 = + ctx.Attr("attn_dropout_implementation"); + bool is_upscale_in_train_1 = + (dropout_implementation_1 == "upscale_in_train"); + auto *seed_1 = ctx.HasInput("Seed1") ? ctx.Input("Seed1") : nullptr; + bool is_fix_seed_1 = ctx.Attr("attn_dropout_fix_seed"); + int seed_val_1 = ctx.Attr("attn_dropout_seed"); + + // get inputs. + auto *d_y = ctx.Input(framework::GradVarName("Y")); + auto *d_y_data = d_y->data(); + + // fw input + auto *input_x = ctx.Input("X"); + auto *ln_scale = ctx.Input("LnScale"); + auto *ln_2_scale = ctx.Input("Ln2Scale"); + auto *x_data = input_x->data(); + auto *ln_scale_data = (ln_scale == nullptr ? nullptr : ln_scale->data()); + auto *ln_2_scale_data = + (ln_2_scale == nullptr ? nullptr : ln_2_scale->data()); + // fw parameters. + auto *src_mask = ctx.Input("SrcMask"); + auto *qkv_weight = ctx.Input("QKVW"); + auto *qkv_bias = ctx.Input("QKVBias"); + auto *out_linear_weight = ctx.Input("OutLinearW"); + auto *out_linear_bias = ctx.Input("OutLinearBias"); + auto *src_mask_data = (src_mask == nullptr ? nullptr : src_mask->data()); + auto *qkv_weight_data = qkv_weight->data(); + auto *qkv_bias_data = qkv_bias->data(); + auto *out_linear_weight_data = out_linear_weight->data(); + auto *out_linear_bias_data = out_linear_bias->data(); + + // fw output + auto *ln_mean = ctx.Input("LnMean"); + auto *ln_var = ctx.Input("LnVariance"); + auto *ln_out = ctx.Input("LnOut"); + auto *fmha_out = ctx.Input("FMHAOut"); + auto *transpose_out_2 = ctx.Input("TransposeOut2"); + auto *qk_out = ctx.Input("QKOut"); + auto *qktv_out = ctx.Input("QKTVOut"); + auto *softmax_out = ctx.Input("SoftmaxOut"); + auto *attn_dropout_mask_out = ctx.Input("AttnDropoutMaskOut"); + auto *attn_dropout_out = ctx.Input("AttnDropoutOut"); + auto *src_mask_out = ctx.Input("SrcMaskOut"); + auto *out_linear_out = ctx.Input("OutLinearOut"); + auto *ln_2_mean = ctx.Input("Ln2Mean"); + auto *ln_2_var = ctx.Input("Ln2Variance"); + auto *dropout_mask_out = ctx.Input("DropoutMaskOut"); + auto *bias_dropout_residual_out = + ctx.Input("BiasDropoutResidualOut"); + auto *ln_mean_data = ln_mean->data(); + auto *ln_var_data = ln_var->data(); + auto *ln_out_data = ln_out->data(); + auto *fmha_out_data = fmha_out->data(); + auto *transpose_out_2_data = transpose_out_2->data(); + auto *qk_out_data = qk_out->data(); + auto *qktv_out_data = qktv_out->data(); + auto *softmax_out_data = softmax_out->data(); + auto *src_mask_out_data = src_mask_out->data(); + auto *out_linear_out_data = out_linear_out->data(); + auto *ln_2_mean_data = ln_2_mean->data(); + auto *ln_2_var_data = ln_2_var->data(); + auto *dropout_mask_out_data = dropout_mask_out->data(); + auto *bias_dropout_residual_out_data = bias_dropout_residual_out->data(); + + // output's grad + auto *d_x = ctx.Output(framework::GradVarName("X")); + auto *d_ln_out = ctx.Output(framework::GradVarName("LnOut")); + auto *d_qkv_out = ctx.Output(framework::GradVarName("QKVOut")); + auto *d_qkv_bias_out = + ctx.Output(framework::GradVarName("QKVBiasOut")); + auto *d_qktv_out = ctx.Output(framework::GradVarName("QKTVOut")); + auto *d_transpose_out_2 = + ctx.Output(framework::GradVarName("TransposeOut2")); + auto *d_qk_out = ctx.Output(framework::GradVarName("QKOut")); + auto *d_softmax_out = + ctx.Output(framework::GradVarName("SoftmaxOut")); + auto *d_attn_dropout_out = + ctx.Output(framework::GradVarName("AttnDropoutOut")); + auto *d_src_mask_out = + ctx.Output(framework::GradVarName("SrcMaskOut")); + auto *d_fmha_out = ctx.Output(framework::GradVarName("FMHAOut")); + auto *d_out_linear_out = + ctx.Output(framework::GradVarName("OutLinearOut")); + auto *d_bias_dropout_residual_out = + ctx.Output(framework::GradVarName("BiasDropoutResidualOut")); + auto *d_x_data = d_x->mutable_data(ctx.GetPlace()); + auto *d_ln_out_data = d_ln_out->mutable_data(ctx.GetPlace()); + auto *d_qkv_out_data = d_qkv_out->mutable_data(ctx.GetPlace()); + auto *d_qkv_bias_out_data = d_qkv_bias_out->mutable_data(ctx.GetPlace()); + auto *d_qktv_out_data = d_qktv_out->mutable_data(ctx.GetPlace()); + auto *d_transpose_out_2_data = + d_transpose_out_2->mutable_data(ctx.GetPlace()); + auto *d_qk_out_data = d_qk_out->mutable_data(ctx.GetPlace()); + auto *d_softmax_out_data = d_softmax_out->mutable_data(ctx.GetPlace()); + auto *d_attn_dropout_out_data = + d_attn_dropout_out->mutable_data(ctx.GetPlace()); + auto *d_src_mask_out_data = d_src_mask_out->mutable_data(ctx.GetPlace()); + auto *d_fmha_out_data = d_fmha_out->mutable_data(ctx.GetPlace()); + auto *d_out_linear_out_data = + d_out_linear_out->mutable_data(ctx.GetPlace()); + auto *d_bias_dropout_residual_out_data = + d_bias_dropout_residual_out->mutable_data(ctx.GetPlace()); + + // parameter grad + auto *d_ln_scale = ctx.Output(framework::GradVarName("LnScale")); + auto *d_ln_bias = ctx.Output(framework::GradVarName("LnBias")); + auto *d_qkv_weight = ctx.Output(framework::GradVarName("QKVW")); + auto *d_qkv_bias = ctx.Output(framework::GradVarName("QKVBias")); + auto *d_out_linear_weight = + ctx.Output(framework::GradVarName("OutLinearW")); + auto *d_out_linear_bias = + ctx.Output(framework::GradVarName("OutLinearBias")); + auto *d_ln_2_scale = ctx.Output(framework::GradVarName("Ln2Scale")); + auto *d_ln_2_bias = ctx.Output(framework::GradVarName("Ln2Bias")); + auto *d_ln_scale_data = + (d_ln_scale == nullptr ? nullptr + : d_ln_scale->mutable_data(ctx.GetPlace())); + auto *d_ln_bias_data = + (d_ln_bias == nullptr ? nullptr + : d_ln_bias->mutable_data(ctx.GetPlace())); + auto *d_qkv_weight_data = d_qkv_weight->mutable_data(ctx.GetPlace()); + auto *d_qkv_bias_data = d_qkv_bias->mutable_data(ctx.GetPlace()); + auto *d_out_linear_weight_data = + d_out_linear_weight->mutable_data(ctx.GetPlace()); + auto *d_out_linear_bias_data = + d_out_linear_bias->mutable_data(ctx.GetPlace()); + auto *d_ln_2_scale_data = + (d_ln_2_scale == nullptr ? nullptr : d_ln_2_scale->mutable_data( + ctx.GetPlace())); + auto *d_ln_2_bias_data = + (d_ln_2_bias == nullptr ? nullptr + : d_ln_2_bias->mutable_data(ctx.GetPlace())); + + const auto input_x_dims = input_x->dims(); + const auto qkv_w_dims = qkv_weight->dims(); + + int batch_size = input_x_dims[0]; + int max_seq_len = input_x_dims[1]; + int dim_embed = input_x_dims[2]; + int num_head = qkv_w_dims[1]; + int dim_head = qkv_w_dims[2]; + + int bsz_seq = batch_size * max_seq_len; + int hidden_size = num_head * dim_head; + int output_size = 3 * hidden_size; + int input_size = dim_embed; + + Tensor d_residual; + d_residual.Resize(input_x_dims); + T *d_residual_data = d_residual.mutable_data(ctx.GetPlace()); + + bool transA = false; + bool transB = true; + bool compute_bias = true; + auto layer_norm_compute = AttnLayerNorm(ctx.cuda_device_context(), + epsilon, bsz_seq, dim_embed); + auto qkv_compute = + AttnMatMul(ctx.cuda_device_context(), transA, transB, bsz_seq, + output_size, input_size, compute_bias); + AttnDropoutParam attn_dropout_param( + is_test_1, dropout_implementation_1, attn_dropout_prob, + is_upscale_in_train_1, is_fix_seed_1, seed_val_1, seed_1); + auto fmha_ref_compute = + FMHARef(ctx.cuda_device_context(), batch_size, max_seq_len, num_head, + dim_head, attn_dropout_param); + output_size = hidden_size; + transA = false; + transB = false; + compute_bias = false; + auto out_linear_compute = + AttnMatMul(ctx.cuda_device_context(), transA, transB, bsz_seq, + output_size, input_size, compute_bias); + DropoutParam dropout_param2(ctx, 0); + FusedDropoutLayerNormHelper fused_dropout_layernorm_helper( + ctx.cuda_device_context(), bsz_seq, dim_embed, dropout_param2, + ln2epsilon); + + fused_dropout_layernorm_helper.LayernormResidualDropoutBiasGrad( + ctx.cuda_device_context(), d_y_data, bias_dropout_residual_out_data, + dropout_mask_out_data, ln_2_scale_data, ln_2_mean_data, ln_2_var_data, + d_bias_dropout_residual_out_data, d_ln_2_scale_data, d_ln_2_bias_data, + d_out_linear_out_data, d_out_linear_bias_data, d_residual_data); + + out_linear_compute.ComputeBackward(fmha_out_data, out_linear_weight_data, + d_out_linear_out_data, d_fmha_out_data, + d_out_linear_weight_data, nullptr); + fmha_ref_compute.ComputeBackward( + *transpose_out_2, *src_mask, *softmax_out, *attn_dropout_mask_out, + *attn_dropout_out, *qk_out, *src_mask_out, *d_fmha_out, d_qktv_out, + d_attn_dropout_out, d_softmax_out, d_src_mask_out, d_qk_out, + d_transpose_out_2, nullptr, d_qkv_bias_out); + cudaMemcpyAsync(d_qkv_out_data, d_qkv_bias_out_data, + bsz_seq * 3 * num_head * dim_head * sizeof(T), + cudaMemcpyDeviceToDevice); + + if (pre_layer_norm) { + qkv_compute.ComputeBackward(ln_out_data, qkv_weight_data, + d_qkv_bias_out_data, d_ln_out_data, + d_qkv_weight_data, d_qkv_bias_data); + layer_norm_compute.ComputeBackward(x_data, d_ln_out_data, ln_scale_data, + ln_mean_data, ln_var_data, d_x_data, + d_ln_scale_data, d_ln_bias_data); + } else { + qkv_compute.ComputeBackward(x_data, qkv_weight_data, d_qkv_bias_out_data, + d_x_data, d_qkv_weight_data, d_qkv_bias_data); + } + // gradient accumulation + std::vector ins; + std::vector outs; + ins.emplace_back(&d_residual); + ins.emplace_back(d_x); + outs.emplace_back(d_x); + int elewise_add_axis = -1; + LaunchElementwiseCudaKernel( + ctx.cuda_device_context(), ins, &outs, elewise_add_axis, + AddFunctor()); + } +}; + } // namespace operators } // namespace paddle @@ -207,3 +438,7 @@ namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL(fused_attention, ops::FusedAttentionOpKernel, ops::FusedAttentionOpKernel, ops::FusedAttentionOpKernel); +REGISTER_OP_CUDA_KERNEL(fused_attention_grad, + ops::FusedAttentionGradKernel, + ops::FusedAttentionGradKernel, + ops::FusedAttentionGradKernel); diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 624455d3b148e..919ae418ab19b 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -93,6 +93,7 @@ endforeach() if(NOT WITH_GPU) LIST(REMOVE_ITEM TEST_OPS test_fused_feedforward_op) LIST(REMOVE_ITEM TEST_OPS test_fused_attention_op) + LIST(REMOVE_ITEM TEST_OPS test_fused_attention_op_api) endif() if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32) diff --git a/python/paddle/fluid/tests/unittests/test_fused_attention_op.py b/python/paddle/fluid/tests/unittests/test_fused_attention_op.py index 1e0d83f8ac775..7359adff62021 100644 --- a/python/paddle/fluid/tests/unittests/test_fused_attention_op.py +++ b/python/paddle/fluid/tests/unittests/test_fused_attention_op.py @@ -34,6 +34,8 @@ def setUp(self): self.generate_input_data() paddle.set_default_dtype(self.x_type) self.__class__.op_type = "fused_attention" + # use autograd to check grad in this unittest. + self.__class__.no_need_check_grad = True self.q_proj = Linear( self.embed_dim, self.embed_dim, @@ -147,7 +149,9 @@ def GetBaselineOut(self): final_out = self.norm1(residual_out) if self.pre_layer_norm: final_out = self.norm2(residual_out) - return final_out + paddle.autograd.backward( + [final_out], [paddle.to_tensor(self.dout)], retain_graph=True) + return final_out, tensor_query.grad def GetFusedAttentionOut(self): paddle.disable_static(place=paddle.CUDAPlace(0)) @@ -196,13 +200,17 @@ def GetFusedAttentionOut(self): ln1_scale, ln1_bias, ln2_scale, ln2_bias, epsilon, qkv_bias_tensor, out_linear_bias, attn_mask, self.dropout_prob, self.attn_dropout_prob, ln2_epsilon) - return final_out + paddle.autograd.backward( + [final_out], [paddle.to_tensor(self.dout)], retain_graph=True) + return final_out, x.grad def test_fused_attention_op(self): - final_out_ref = self.GetBaselineOut() - final_out = self.GetFusedAttentionOut() + final_out_ref, x_grad_ref = self.GetBaselineOut() + final_out, x_grad = self.GetFusedAttentionOut() np.testing.assert_allclose( final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-5) + np.testing.assert_allclose( + x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-5) class TestFusedAttentionOpFp16(TestFusedAttentionOp): @@ -226,10 +234,12 @@ def config(self): self.key_length, self.value_length = self.query_length, self.query_length def test_fused_attention_op(self): - final_out_ref = self.GetBaselineOut() - final_out = self.GetFusedAttentionOut() + final_out_ref, x_grad_ref = self.GetBaselineOut() + final_out, x_grad = self.GetFusedAttentionOut() np.testing.assert_allclose( final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-1) + np.testing.assert_allclose( + x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-1) if __name__ == "__main__": diff --git a/python/paddle/fluid/tests/unittests/test_fused_attention_op_api.py b/python/paddle/fluid/tests/unittests/test_fused_attention_op_api.py new file mode 100644 index 0000000000000..e59ecc19d05cb --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_fused_attention_op_api.py @@ -0,0 +1,262 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np + +import paddle +import paddle.nn as nn +import paddle.fluid.core as core +import paddle.nn.functional as F +from paddle.incubate.nn.layer.fused_transformer import FusedMultiHeadAttention +from paddle import tensor +from paddle.fluid import layers +from paddle.static import Program, program_guard +import unittest + + +def fc(x, weight): + return np.matmul(x, weight) + + +def softmax(x): + np.seterr(invalid='ignore') + output = np.zeros(x.shape, dtype=np.float64) + for i in range(x.shape[0]): + for j in range(x.shape[1]): + for k in range(x.shape[2]): + x_curr = x[i, j, k, :] + e_x = np.exp(x_curr - np.amax(x_curr)) + output[i, j, k, :] = e_x / np.sum(e_x) + return output + + +def batch_matmul(x, y): + assert x.shape[0] == y.shape[0] + assert x.shape[1] == y.shape[1] + retval = np.zeros( + (x.shape[0], x.shape[1], x.shape[2], y.shape[3]), dtype=np.float64) + for i in range(x.shape[0]): + for j in range(x.shape[1]): + retval[i, j, :, :] = np.matmul(x[i, j, :, :], y[i, j, :, :]) + return retval + + +def layer_norm(x, has_scale, has_bias, weight, bias, epsilon=1e-05): + batch_size, src_len, d_model = x.shape + x = x.reshape((batch_size * src_len, d_model)) + mu = np.mean(x, axis=1, keepdims=True) + sigma_squar = np.sum(np.square(x - mu), axis=1) / d_model + x1_up = (x - mu) + x1_down_1 = sigma_squar + epsilon + x1_down = np.sqrt(x1_down_1) + x1_down = x1_down.reshape((x1_down.shape[0], 1)) + x1 = x1_up / x1_down + x_scaled = x1 + if (has_scale): + x_scaled = weight * x1 + x_scaled_bias = x_scaled + if (has_bias): + x_scaled_bias = x_scaled + bias + x_scaled_bias = x_scaled_bias.reshape((batch_size, src_len, d_model)) + return x_scaled_bias + + +def compute_reference(pre_layer_norm, query, attn_mask, ln_scale, ln_bias, + ln_2_scale, ln_2_bias, qkv_weight, qkv_bias, + out_linear_weight, out_linear_bias): + batch_size = query.shape[0] + seq_len = query.shape[1] + embed_dim = query.shape[2] + + if (pre_layer_norm): + ln_out = layer_norm(query, True, True, ln_scale, ln_bias) + + num_head = qkv_weight.shape[1] + head_dim = qkv_weight.shape[2] + # embed_dim, 3, num_heads, self.head_dim + qkv_weight = qkv_weight.transpose((3, 0, 1, 2)) + qkv_weight = qkv_weight.reshape(qkv_weight.shape[0], qkv_weight.shape[1] * + qkv_weight.shape[2] * qkv_weight.shape[3]) + + if (pre_layer_norm): + ln_out = ln_out.reshape(batch_size * seq_len, embed_dim) + qkv = fc(ln_out, qkv_weight) + ln_out = ln_out.reshape(batch_size, seq_len, embed_dim) + else: + query = query.reshape(batch_size * seq_len, embed_dim) + qkv = fc(query, qkv_weight) + query = query.reshape(batch_size, seq_len, embed_dim) + + qkv = qkv.reshape(batch_size, seq_len, 3, num_head, head_dim) + # q*k^t + qkv = qkv.transpose( + (2, 0, 1, 3, 4)) # 3, batch_size, seq_len, num_head, head_dim + qkv = qkv.transpose( + (0, 1, 3, 2, 4)) # 3, batch_size, num_head, seq_len, head_dim + + q = qkv[0:1, ::] + q = q.reshape(batch_size, num_head, seq_len, head_dim) + k = qkv[1:2, ::] #[1, batch_size, num_head, seq_len, head_dim] + k = k.reshape(batch_size, num_head, seq_len, head_dim) + v = qkv[2::] + v = v.reshape(batch_size, num_head, seq_len, head_dim) + + k = k.transpose([0, 1, 3, 2]) #[batch_size, num_head, head_dim, seq_len] + qkt = batch_matmul(q, k / np.sqrt(head_dim, dtype=np.float64)) + + if attn_mask is not None: + if attn_mask.dtype.name == 'int64': + attn_mask = (attn_mask.astype(qkt.dtype) - 1.0) * 1e9 + else: + attn_mask = attn_mask.astype(qkt.dtype) + qkt += attn_mask + + # softmax + softmax_out = softmax(qkt) + attn_heads = batch_matmul(softmax_out, v) + + attn_heads = attn_heads.transpose( + (0, 2, 1, 3)) # [batch_size, seq_len, num_head, head_dim] + + # out_linear + out_linear_input = attn_heads.reshape(batch_size, seq_len, + num_head * head_dim) + out_linear_out = fc(out_linear_input, out_linear_weight) + + # bias add, dropout, residual add, layer_norm. + out_linear_bias_out = out_linear_out + out_linear_bias + out_linear_bias_dropout_out = out_linear_bias_out + out_linear_bias_dropout_residual_out = query + out_linear_bias_dropout_out + out_linear_bias_dropout_residual_ln_out = layer_norm( + out_linear_bias_dropout_residual_out, True, True, ln_2_scale, ln_2_bias) + return out_linear_bias_dropout_residual_ln_out + + +class TestFusedAttentionAPI(unittest.TestCase): + def setUp(self): + self.config() + self.generate_input_data() + + def config(self): + self.x_type = np.float32 + self.attn_mask_type = np.float64 + self.pre_layer_norm = True + self.training = True + self.need_weight = False + + self.batch_size = 1 + self.query_length = 2 + self.head_dim = 2 + self.num_heads = 2 + self.embed_dim = self.head_dim * self.num_heads + + self.dropout_prob = 0.0 + self.attn_dropout_prob = 0.0 + self.weight_attr = None + self.bias_attr = None + + self.kdim, self.vdim = self.embed_dim, self.embed_dim + self.key_length, self.value_length = self.query_length, self.query_length + + def generate_input_data(self): + self.query = np.random.rand(self.batch_size, self.query_length, + self.embed_dim).astype(self.x_type) + self.attn_mask = np.ones( + (self.batch_size, self.num_heads, self.query_length, + self.key_length), + dtype=self.attn_mask_type) + if self.attn_mask_type == np.int64: + self.attn_mask = np.tril(self.attn_mask) + elif self.attn_mask_type == np.float64: + self.attn_mask = (np.tril(self.attn_mask) - 1.0) * 1e9 + else: + raise ValueError("'attn_mask_type' should be 'int64' or 'float64'.") + self.key, self.value = self.query, self.query + + def run_imperative(self): + fused_attn = FusedMultiHeadAttention( + self.embed_dim, self.num_heads, self.dropout_prob, + self.attn_dropout_prob, self.kdim, self.vdim, self.pre_layer_norm, + self.need_weight, self.weight_attr, self.bias_attr) + out = fused_attn( + paddle.to_tensor(self.query), + paddle.to_tensor(self.query), + paddle.to_tensor(self.query), paddle.to_tensor(self.attn_mask)) + ref_out = compute_reference(self.pre_layer_norm, self.query, + self.attn_mask, + fused_attn.pre_ln_scale.numpy(), + fused_attn.pre_ln_bias.numpy(), + fused_attn.ln_scale.numpy(), + fused_attn.ln_bias.numpy(), + fused_attn.qkv_weight.numpy(), + fused_attn.qkv_bias.numpy(), + fused_attn.linear_weight.numpy(), + fused_attn.linear_bias.numpy()) + self.assertTrue(np.allclose(ref_out, out, rtol=1e-5, atol=1e-5)) + + def run_static(self): + fused_attn = FusedMultiHeadAttention( + self.embed_dim, self.num_heads, self.dropout_prob, + self.attn_dropout_prob, self.kdim, self.vdim, self.pre_layer_norm, + self.need_weight, self.weight_attr, self.bias_attr) + + x = paddle.static.data( + name='X', + shape=[self.batch_size, self.query_length, self.embed_dim], + dtype=self.x_type) + attn_mask = paddle.static.data( + name='SrcMask', + shape=[ + self.batch_size, self.num_heads, self.query_length, + self.key_length + ], + dtype=self.attn_mask_type) + final_out = fused_attn(x, x, x, attn_mask) + + place = paddle.CUDAPlace(0) + exe = paddle.static.Executor(place) + exe.run(paddle.static.default_startup_program()) + out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias = exe.run( + paddle.static.default_main_program(), + feed={"X": self.query, + "SrcMask": self.attn_mask}, + fetch_list=[ + final_out, fused_attn.qkv_weight, fused_attn.qkv_bias, + fused_attn.linear_weight, fused_attn.linear_bias, + fused_attn.pre_ln_scale, fused_attn.pre_ln_bias, + fused_attn.ln_scale, fused_attn.ln_bias + ]) + + return out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias + + def test_static_api(self): + paddle.enable_static() + with paddle.static.program_guard(Program()): + out, qkv_weight, qkv_bias, linear_weight, linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias = self.run_static( + ) + ref_out = compute_reference(self.pre_layer_norm, self.query, + self.attn_mask, ln_scale, ln_bias, + ln_2_scale, ln_2_bias, qkv_weight, qkv_bias, + linear_weight, linear_bias) + self.assertTrue( + np.allclose( + np.array(ref_out), np.array(out), rtol=1e-5, atol=1e-5)) + + def test_dynamic_api(self): + paddle.disable_static(place=paddle.CUDAPlace(0)) + self.run_imperative() + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/incubate/nn/__init__.py b/python/paddle/incubate/nn/__init__.py new file mode 100644 index 0000000000000..aada78e4ec6a4 --- /dev/null +++ b/python/paddle/incubate/nn/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .layer.fused_transformer import FusedMultiHeadAttention # noqa: F401 + +__all__ = [ #noqa + 'FusedMultiHeadAttention', +] diff --git a/python/paddle/incubate/nn/functional/fused_transformer.py b/python/paddle/incubate/nn/functional/fused_transformer.py index 75bf9f10cef31..68109b4ae694a 100644 --- a/python/paddle/incubate/nn/functional/fused_transformer.py +++ b/python/paddle/incubate/nn/functional/fused_transformer.py @@ -15,6 +15,7 @@ from paddle.fluid.layer_helper import LayerHelper from paddle.fluid.framework import in_dygraph_mode from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype +from paddle.fluid import core, dygraph_utils from paddle import _C_ops __all__ = [] @@ -217,8 +218,8 @@ def fused_multi_head_attention(x, `[batch\_size, sequence\_len, embed\_dim]`. qkv_weight (Tensor): The qkv weight tensor. The shape is `[3, num_head, dim_head, dim_embed]`. linear_weight (Tensor): The linear weight tensor. The shape is `[embed_dim, embed_dim]`. - pre_layer_norm (bool, optional): whether it is pre_layer_norm or post_layer_norm architecture. - Default False. + pre_layer_norm (bool, optional): whether it is pre_layer_norm (True) or post_layer_norm architecture + (False). Default False. pre_ln_scale (Tensor, optional): The weight tensor of pre layernorm. Default None. pre_ln_bias (Tensor, optional): The bias tensor of pre layernorm. Default None. ln_scale (Tensor, optional): The weight tensor of layernorm. Default None. @@ -228,13 +229,19 @@ def fused_multi_head_attention(x, qkv_bias (Tensor, optional): The bias of qkv computation. The shape is `[3, num_head, dim_head]`. Default None. linear_bias (Tensor, optional): The bias of linear. The shape is `[embed_dim]`. Default None. - attn_mask (Tensor, optional): + attn_mask (Tensor, optional): A tensor used in multi-head attention to prevents attention to + some unwanted positions, usually the paddings or the subsequent positions. It is a tensor + with shape broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`. When the + data type is bool, the unwanted positions have `False` values and the others have `True` values. + When the data type is int, the unwanted positions have 0 values and the others have 1 values. + When the data type is float, the unwanted positions have `-INF` values and the others have 0 values. + It can be None when nothing wanted or needed to be prevented attention to. Default None. dropout_rate (float, optional): The dropout probability used on attention weights to drop some attention targets for the dropout after attention. - 0 for no dropout. Default 0. + 0 for no dropout. Default 0.5. attn_dropout_rate (float, optional): The dropout probability used on attention weights to drop some attention targets for the dropout in attention. - 0 for no dropout. Default 0. + 0 for no dropout. Default 0.5. ln_epsilon (float, optional): Small float value added to denominator of layer_norm to avoid dividing by zero. Default is 1e-5. @@ -248,9 +255,9 @@ def fused_multi_head_attention(x, # input: [batch_size, seq_len, embed_dim] x = paddle.rand(shape=(2, 4, 128), dtype="float32") - # qkv_weight: [3, num_head, dim_head, dim_embed] + # qkv_weight: [3, num_head, head_dim, embed_dim] qkv_weight = paddle.rand(shape=(3, 4, 32, 128), dtype="float32") - # qkv_bias: [3, num_head, dim_head] + # qkv_bias: [3, num_head, head_dim] qkv_bias = paddle.rand(shape=(3, 4, 32), dtype="float32") # linear_weight: [embed_dim, embed_dim] linear_weight = paddle.rand(shape=(128, 128), dtype="float32") @@ -271,6 +278,12 @@ def fused_multi_head_attention(x, # pre_ln_mean, pre_ln_variance, pre_ln_out, qkv_out, qkv_bias_out, transpose_out, qk_out, # qktv_out, softmax_out, attn_dropout_mask_out, attn_dropout_out, attn_mask_out, fmha_out, # linear_out, dropout_mask_out, ln_mean_out, ln_var_out, bias_dropout_residual_out, final_out + assert len(qkv_weight.shape + ) == 4, "The dims of the shape of qkv_weight should be 4." + assert qkv_weight.shape[ + 0] == 3, "The shape of qkv_weight should be [3, num_head, head_dim, embed_dim]." + assert qkv_weight.shape[3] == x.shape[ + 2], "The 3rd dim of qkv_weight and 2nd dim of x should be the same, i.e., embed_dim." _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, final_out = _C_ops.fused_attention( x, pre_ln_scale, pre_ln_bias, qkv_weight, qkv_bias, attn_mask, linear_weight, linear_bias, ln_scale, ln_bias, 'pre_layer_norm', @@ -278,3 +291,95 @@ def fused_multi_head_attention(x, dropout_rate, 'attn_dropout_rate', attn_dropout_rate, 'ln_epsilon', ln_epsilon) return final_out + else: + helper = LayerHelper('fused_multi_head_attention', **locals()) + dtype = x.dtype + # check dtypes + check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], + 'fused_multihead_attention') + check_dtype(dtype, 'dtype', ['float16', 'float32', 'float64'], + 'fused_multi_head_attention') + + # set inputs + inputs = dict() + inputs['X'] = [x] + if pre_ln_scale: + inputs['LnScale'] = [pre_ln_scale] + if pre_ln_bias: + inputs['LnBias'] = [pre_ln_bias] + inputs['QKVW'] = [qkv_weight] + inputs['QKVBias'] = [qkv_bias] + inputs['SrcMask'] = attn_mask + inputs['OutLinearW'] = [linear_weight] + inputs['OutLinearBias'] = [linear_bias] + if ln_scale: + inputs['Ln2Scale'] = [ln_scale] + if ln_bias: + inputs['Ln2Bias'] = [ln_bias] + + # set attrs + attrs = { + 'pre_layer_norm': pre_layer_norm, + 'epsilon': pre_ln_epsilon, + 'ln_epsilon': ln_epsilon, + 'dropout_rate': dropout_rate, + 'attn_dropout_rate': attn_dropout_rate + } + + # set outputs + pre_ln_mean_out = helper.create_variable_for_type_inference( + dtype=dtype, stop_gradient=True) + pre_ln_variance_out = helper.create_variable_for_type_inference( + dtype=dtype, stop_gradient=True) + pre_ln_out = helper.create_variable_for_type_inference(dtype=dtype) + + qkv_out = helper.create_variable_for_type_inference(dtype=dtype) + qkv_bias_out = helper.create_variable_for_type_inference(dtype=dtype) + + transpose_out = helper.create_variable_for_type_inference(dtype=dtype) + qk_out = helper.create_variable_for_type_inference(dtype=dtype) + qktv_out = helper.create_variable_for_type_inference(dtype=dtype) + softmax_out = helper.create_variable_for_type_inference(dtype=dtype) + attn_dropout_mask_out = helper.create_variable_for_type_inference( + dtype=core.VarDesc.VarType.UINT8, stop_gradient=True) + attn_dropout_out = helper.create_variable_for_type_inference( + dtype=dtype) + attn_mask_out = helper.create_variable_for_type_inference(dtype=dtype) + fmha_out = helper.create_variable_for_type_inference(dtype=dtype) + out_linear_out = helper.create_variable_for_type_inference(dtype=dtype) + dropout_mask_out = helper.create_variable_for_type_inference( + dtype=core.VarDesc.VarType.UINT8, stop_gradient=True) + ln_mean_out = helper.create_variable_for_type_inference( + dtype=dtype, stop_gradient=True) + ln_variance_out = helper.create_variable_for_type_inference( + dtype=dtype, stop_gradient=True) + bias_dropout_residual_out = helper.create_variable_for_type_inference( + dtype=dtype) + final_out = helper.create_variable_for_type_inference(dtype=dtype) + + helper.append_op( + type='fused_attention', + inputs=inputs, + outputs={ + "LnMean": pre_ln_mean_out, + "LnVariance": pre_ln_variance_out, + "LnOut": pre_ln_out, + "QKVOut": qkv_out, + "QKVBiasOut": qkv_bias_out, + "TransposeOut2": transpose_out, + "QKOut": qk_out, + "QKTVOut": qktv_out, + "SoftmaxOut": softmax_out, + "AttnDropoutMaskOut": attn_dropout_mask_out, + "AttnDropoutOut": attn_dropout_out, + "SrcMaskOut": attn_mask_out, + "FMHAOut": fmha_out, + "OutLinearOut": out_linear_out, + "DropoutMaskOut": dropout_mask_out, + "Ln2Mean": ln_mean_out, + "Ln2Variance": ln_variance_out, + "BiasDropoutResidualOut": bias_dropout_residual_out, + 'Y': final_out + }, + attrs=attrs) + return final_out diff --git a/python/paddle/nn/layer/fused_transformer.py b/python/paddle/incubate/nn/layer/fused_transformer.py similarity index 79% rename from python/paddle/nn/layer/fused_transformer.py rename to python/paddle/incubate/nn/layer/fused_transformer.py index 0084f7ff339df..16588dcef3d27 100644 --- a/python/paddle/nn/layer/fused_transformer.py +++ b/python/paddle/incubate/nn/layer/fused_transformer.py @@ -12,27 +12,42 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy +from paddle.nn import functional as F +from paddle.incubate.nn import functional as incubate_f +from paddle.nn import Layer +from paddle.framework import ParamAttr +import paddle +from paddle.nn.layer.transformer import _convert_attention_mask +from paddle.nn.initializer import Constant + +import collections + class FusedMultiHeadAttention(Layer): """ - Attention mapps queries and a set of key-value pairs to outputs, and + Attention mapps queries and a set of key-value pairs to outputs, and Multi-Head Attention performs multiple parallel attention to jointly attending to information from different representation subspaces. - Please refer to `Attention Is All You Need `_ for more details. - Parameters: embed_dim (int): The expected feature size in the input and output. num_heads (int): The number of heads in multi-head attention. - dropout (float, optional): The dropout probability used on attention - weights to drop some attention targets. 0 for no dropout. Default 0 + dropout_rate (float, optional): The dropout probability used on attention + weights to drop some attention targets for the dropout after attention. + 0 for no dropout. Default 0.5. + attn_dropout_rate (float, optional): The dropout probability used on attention + weights to drop some attention targets for the dropout in attention. + 0 for no dropout. Default 0.5. kdim (int, optional): The feature size in key. If None, assumed equal to `embed_dim`. Default None. vdim (int, optional): The feature size in value. If None, assumed equal to `embed_dim`. Default None. + normalize_before (bool, optional): Indicate whether it is pre_layer_norm (True) + or post_layer_norm architecture (False). Default False. need_weights (bool, optional): Indicate whether to return the attention - weights. Default False. + weights. Now, only False is supported. Default False. weight_attr(ParamAttr, optional): To specify the weight parameter property. Default: None, which means the default weight parameter property is used. See usage for details in :code:`ParamAttr` . @@ -40,35 +55,84 @@ class FusedMultiHeadAttention(Layer): Default: None, which means the default bias parameter property is used. If it is set to False, this layer will not have trainable bias parameter. See usage for details in :code:`ParamAttr` . - Examples: - .. code-block:: python - import paddle - - # encoder input: [batch_size, sequence_length, d_model] + # input: [batch_size, sequence_length, embed_dim] query = paddle.rand((2, 4, 128)) # self attention mask: [batch_size, num_heads, query_len, query_len] attn_mask = paddle.rand((2, 2, 4, 4)) - multi_head_attn = paddle.nn.MultiHeadAttention(128, 2) + multi_head_attn = paddle.incubate.nn.FusedMultiHeadAttention(128, 2) output = multi_head_attn(query, None, None, attn_mask=attn_mask) # [2, 4, 128] """ - Cache = collections.namedtuple("Cache", ["k", "v"]) - StaticCache = collections.namedtuple("StaticCache", ["k", "v"]) - def __init__(self, embed_dim, num_heads, - dropout=0., + dropout_rate=0.5, + attn_dropout_rate=0.5, kdim=None, vdim=None, + normalize_before=False, need_weights=False, weight_attr=None, - bias_attr=None): + bias_attr=None, + name=None): super(FusedMultiHeadAttention, self).__init__() - raise NotImplementedError() + + assert embed_dim > 0, ("Expected embed_dim to be greater than 0, " + "but recieved {}".format(embed_dim)) + assert num_heads > 0, ("Expected nhead to be greater than 0, " + "but recieved {}".format(num_heads)) + + attn_dropout_rate = dropout_rate if attn_dropout_rate is None else attn_dropout_rate + self.normalize_before = normalize_before + self._dtype = self._helper.get_default_dtype() + self._weight_attr = weight_attr + self._bias_attr = bias_attr + + self.head_dim = embed_dim // num_heads + assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads" + assert need_weights == False, "Only support need_weight is False now." + + self.qkv_weight = self.create_parameter( + shape=[3, num_heads, self.head_dim, embed_dim], + attr=self._weight_attr, + dtype=self._dtype, + is_bias=False) + self.qkv_bias = self.create_parameter( + shape=[3, num_heads, self.head_dim], + attr=self._bias_attr, + dtype=self._dtype, + is_bias=True) + self.linear_weight = self.create_parameter( + shape=[embed_dim, embed_dim], + attr=self._weight_attr, + dtype=self._dtype, + is_bias=False) + self.linear_bias = self.create_parameter( + shape=[embed_dim], + attr=self._bias_attr, + dtype=self._dtype, + is_bias=True) + + self.pre_ln_scale = self.create_parameter( + attr=self._weight_attr, + shape=[embed_dim], + default_initializer=Constant(value=1.0)) + self.pre_ln_bias = self.create_parameter( + attr=self._bias_attr, shape=[embed_dim], is_bias=True) + self.ln_scale = self.create_parameter( + attr=self._weight_attr, + shape=[embed_dim], + default_initializer=Constant(value=1.0)) + self.ln_bias = self.create_parameter( + attr=self._bias_attr, shape=[embed_dim], is_bias=True) + + self.dropout_rate = dropout_rate + self.attn_dropout_rate = attn_dropout_rate + + self.name = name def forward(self, query, key=None, value=None, attn_mask=None, cache=None): """ @@ -97,30 +161,34 @@ def forward(self, query, key=None, value=None, attn_mask=None, cache=None): `-INF` values and the others have 0 values. It can be None when nothing wanted or needed to be prevented attention to. Default None. cache (MultiHeadAttention.Cache|MultiHeadAttention.StaticCache, optional): - It is a namedtuple with `k` and `v` as fields, and stores tensors - shaped `[batch_size, num_heads, length, embed_dim]` which are results - of linear projection, reshape and transpose calculations in - MultiHeadAttention. If it is an instance of `Cache`, `k` and `v` - fields reserve intermediate results of previous positions, which - mostly used for decoder self attention. If it is an instance of - `StaticCache`, `key` and `value` args would be ignored, `k` and - `v` fields would be used as calculated results on `key` and - `value`, which mostly used for decoder-encoder cross attention. - It is only used for inference and should be None for training. - Default None. + Now, only None is supported. Default None. Returns: Tensor|tuple: It is a tensor that has the same shape and data type \ - as `query`, representing attention output. Or a tuple if \ - `need_weights` is True or `cache` is not None. If `need_weights` \ - is True, except for attention output, the tuple also includes \ - the attention weights tensor shaped `[batch_size, num_heads, query_length, key_length]`. \ - If `cache` is not None, the tuple then includes the new cache \ - having the same type as `cache`, and if it is `StaticCache`, it \ - is same as the input `cache`, if it is `Cache`, the new cache \ - reserves tensors concatanating raw tensors with intermediate \ - results of current query. + as `query`, representing attention output. """ - raise NotImplementedError() + if attn_mask is not None: + # Support bool or int mask + attn_mask = _convert_attention_mask(attn_mask, query.dtype) + + assert cache == None, "Only support cache is None now." + + out = incubate_f.fused_multi_head_attention( + x=query, + qkv_weight=self.qkv_weight, + linear_weight=self.linear_weight, + pre_layer_norm=self.normalize_before, + pre_ln_scale=self.pre_ln_scale, + pre_ln_bias=self.pre_ln_bias, + ln_scale=self.ln_scale, + ln_bias=self.ln_bias, + pre_ln_epsilon=1e-05, + qkv_bias=self.qkv_bias, + linear_bias=self.linear_bias, + attn_mask=attn_mask, + dropout_rate=self.dropout_rate, + attn_dropout_rate=self.attn_dropout_rate, + ln_epsilon=1e-05) + return out class FusedFeedForward(Layer): @@ -186,7 +254,8 @@ class FusedTransformerEncoderLayer(Layer): Examples: .. code-block:: python - + + # required: gpu import paddle from paddle.nn import TransformerEncoderLayer