Commit
Add checkpoint and add a check that add_residual is true when pre_layer_norm is false.
Xreki committed Jun 14, 2022
1 parent 85d5041 commit de26128
Showing 4 changed files with 27 additions and 0 deletions.
paddle/fluid/operators/fused/fused_attention_op.cc (9 additions, 0 deletions)
@@ -16,6 +16,7 @@ limitations under the License. */
#include <string>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"

namespace paddle {
namespace operators {
@@ -656,3 +657,11 @@ REGISTER_OPERATOR(fused_attention, ops::FusedAttentionOp,
                  ops::FusedAttentionGradOpMaker<paddle::framework::OpDesc>,
                  ops::FusedAttentionGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(fused_attention_grad, ops::FusedAttentionGradOp);

REGISTER_OP_VERSION(fused_attention)
    .AddCheckpoint(
        R"ROC(
              Add a new attribute [add_residual] )ROC",
        paddle::framework::compatible::OpVersionDesc().NewAttr(
            "add_residual", "A flag to indicate whether to add residual.",
            true));
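A note on the checkpoint above: REGISTER_OP_VERSION bumps the operator's version and records, via OpVersionDesc, that a new attribute add_residual with default true now exists. Programs serialized before this commit carry no such attribute; on load, the compatibility layer fills in the recorded default, so old models keep the previous always-add-residual behavior. The sketch below shows the general shape of this API; my_op, use_fast_path, and scale are hypothetical names for illustration, not part of this commit:

    // Hedged sketch of Paddle's op version registry. Each AddCheckpoint
    // increments the op's registered version by one; OpVersionDesc records
    // the upgrade to apply to programs saved at an older version.
    REGISTER_OP_VERSION(my_op)
        .AddCheckpoint(
            R"ROC(Add a new attribute [use_fast_path])ROC",
            paddle::framework::compatible::OpVersionDesc().NewAttr(
                "use_fast_path", "Whether to take the fused code path.", false))
        .AddCheckpoint(
            R"ROC(Add a new attribute [scale])ROC",
            paddle::framework::compatible::OpVersionDesc().NewAttr(
                "scale", "Scaling factor applied to the output.", 1.0f));

Downstream passes can then compare a loaded program's op version against the current one and apply the recorded upgrades in order.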
paddle/fluid/operators/fused/fused_attention_op.cu (5 additions, 0 deletions)
@@ -254,6 +254,11 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
          ctx.cuda_device_context(), out_linear_out_data, residual_ptr,
          out_linear_bias_data, final_out_data, dropout_mask_out_data);
    } else {
      PADDLE_ENFORCE_EQ(add_residual, true,
                        platform::errors::InvalidArgument(
                            "Attribute add_residual is expected to be true "
                            "when pre_layer_norm is false."));

      const U *ln_scale_2_ptr = ln_scale_2 ? ln_scale_2->data<U>() : nullptr;
      const U *ln_bias_2_ptr = ln_bias_2 ? ln_bias_2->data<U>() : nullptr;
      T *bias_dropout_residual_out_ptr =
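The enforce added in this hunk guards the post-layer-norm branch: there the epilogue is computed by one fused kernel of the form LayerNorm(residual + dropout(out_linear_out + bias)), so the residual is a mandatory operand rather than an optional add. Below is a minimal sketch of the dispatch this check protects; DispatchEpilogue and the two Launch* helpers are hypothetical names, not Paddle's actual signatures:

    #include <stdexcept>

    void DispatchEpilogue(bool pre_layer_norm, bool add_residual,
                          const float *residual /* null iff !add_residual */) {
      if (!pre_layer_norm) {
        // Post-LN fuses the residual into the LayerNorm kernel itself,
        // so a null residual pointer has no meaning on this path.
        if (!add_residual) {
          throw std::invalid_argument(
              "add_residual must be true when pre_layer_norm is false.");
        }
        // LaunchLayernormResidualDropoutBias(residual, ...);
      } else {
        // Pre-LN: the residual add is optional; the kernel skips it
        // when the pointer is null.
        // LaunchResidualDropoutBias(residual /* may be null */, ...);
      }
    }

Failing fast with InvalidArgument turns what would otherwise surface as a crash or silently wrong result inside the fused kernel into a clear error at op execution time.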
paddle/fluid/operators/fused/fused_feedforward_op.cc (8 additions, 0 deletions)
@@ -368,3 +368,11 @@ REGISTER_OPERATOR(fused_feedforward, ops::FusedFeedForwardOp,
                  ops::FusedFeedForwardOpGradMaker<paddle::framework::OpDesc>,
                  ops::FusedFeedForwardOpGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(fused_feedforward_grad, ops::FusedFeedForwardOpGrad);

REGISTER_OP_VERSION(fused_feedforward)
    .AddCheckpoint(
        R"ROC(
              Add a new attribute [add_residual] )ROC",
        paddle::framework::compatible::OpVersionDesc().NewAttr(
            "add_residual", "A flag to indicate whether to add residual.",
            true));
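This is the same checkpoint pattern as for fused_attention above: fused_feedforward's op version is bumped, and programs serialized before this commit are upgraded on load with add_residual defaulted to true, preserving the old always-add-residual behavior.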
paddle/fluid/operators/fused/fused_feedforward_op.cu (5 additions, 0 deletions)
@@ -129,6 +129,11 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> {

    const T* residual_ptr = add_residual ? x.data<T>() : nullptr;
    if (!pre_layer_norm) {
      PADDLE_ENFORCE_EQ(add_residual, true,
                        platform::errors::InvalidArgument(
                            "Attribute add_residual is expected to be true "
                            "when pre_layer_norm is false."));

      fused_dropout_layernorm_helper.LayernormResidualDropoutBias(
          ctx, linear2_out.data<T>(), residual_ptr, linear2_bias_ptr,
          ln2_scale_ptr, ln2_bias_ptr, dropout2_out->data<T>(),
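Here the hunk makes the hazard explicit: residual_ptr is nullptr whenever add_residual is false, and LayernormResidualDropoutBias consumes it unconditionally on the post-LN path. Summarizing the configurations (a simplified reading of the kernel that elides dropout masks, biases, and activation details):

    pre_layer_norm = true,  add_residual = true  : out = x + dropout(linear2(...))
    pre_layer_norm = true,  add_residual = false : out = dropout(linear2(...))
    pre_layer_norm = false, add_residual = true  : out = LayerNorm(x + dropout(linear2(...)))
    pre_layer_norm = false, add_residual = false : rejected by the new enforce

The only combination the fused kernels cannot express is exactly the one the check rejects.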
