From e53706ca89d79ca58d95bb24aea1ea0177b7e0d7 Mon Sep 17 00:00:00 2001 From: jiweibo Date: Thu, 10 Jun 2021 10:32:50 +0000 Subject: [PATCH 1/3] add compat check for skip_layernorm --- .../framework/ir/skip_layernorm_fuse_pass.cc | 5 +++ .../framework/ir/skip_layernorm_fuse_pass.h | 39 +++++++++++++++++++ 2 files changed, 44 insertions(+) diff --git a/paddle/fluid/framework/ir/skip_layernorm_fuse_pass.cc b/paddle/fluid/framework/ir/skip_layernorm_fuse_pass.cc index 232e1d8da4ded..3c851f13b4d4d 100644 --- a/paddle/fluid/framework/ir/skip_layernorm_fuse_pass.cc +++ b/paddle/fluid/framework/ir/skip_layernorm_fuse_pass.cc @@ -129,6 +129,11 @@ void SkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const { return; } + if (!IsCompat(subgraph, graph)) { + LOG(WARNING) << "skip_layernorm pass in op compat failed."; + return; + } + VLOG(4) << "handle SkipLayerNorm fuse"; GET_IR_NODE_FROM_SUBGRAPH(elementwise, elementwise, fused_pattern); GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out, fused_pattern); diff --git a/paddle/fluid/framework/ir/skip_layernorm_fuse_pass.h b/paddle/fluid/framework/ir/skip_layernorm_fuse_pass.h index 3a3e50052396a..fba357473cfae 100644 --- a/paddle/fluid/framework/ir/skip_layernorm_fuse_pass.h +++ b/paddle/fluid/framework/ir/skip_layernorm_fuse_pass.h @@ -33,6 +33,45 @@ class Graph; class SkipLayerNormFusePass : public FusePassBase { public: + SkipLayerNormFusePass() { + AddOpCompat(OpCompat("elementwise_add")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") // unconstrained + .End(); + + AddOpCompat(OpCompat("layer_norm")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Scale") + .IsTensor() + .End() + .AddInput("Bias") + .IsTensor() + .End() + .AddOutput("Y") + .IsTensor() + .End() + .AddOutput("Mean") + .IsTensor() + .End() + .AddOutput("Variance") + .IsTensor() + .End() + .AddAttr("epsilon") // unconstrained + .End() + .AddAttr("begin_norm_axis") // unconstrained + .End(); + } + virtual ~SkipLayerNormFusePass() {} protected: From dee3c307f56b80b95bef1ce56162da6ac165c3dd Mon Sep 17 00:00:00 2001 From: jiweibo Date: Wed, 16 Jun 2021 02:46:42 +0000 Subject: [PATCH 2/3] update attr --- paddle/fluid/framework/ir/skip_layernorm_fuse_pass.h | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/framework/ir/skip_layernorm_fuse_pass.h b/paddle/fluid/framework/ir/skip_layernorm_fuse_pass.h index fba357473cfae..4aaaf8f639205 100644 --- a/paddle/fluid/framework/ir/skip_layernorm_fuse_pass.h +++ b/paddle/fluid/framework/ir/skip_layernorm_fuse_pass.h @@ -44,7 +44,8 @@ class SkipLayerNormFusePass : public FusePassBase { .AddOutput("Out") .IsTensor() .End() - .AddAttr("axis") // unconstrained + .AddAttr("axis") + .IsNumEQ(0) .End(); AddOpCompat(OpCompat("layer_norm")) @@ -66,9 +67,12 @@ class SkipLayerNormFusePass : public FusePassBase { .AddOutput("Variance") .IsTensor() .End() - .AddAttr("epsilon") // unconstrained + .AddAttr("epsilon") + .IsNumGE(0.0f) + .IsNumLE(0.001f) .End() - .AddAttr("begin_norm_axis") // unconstrained + .AddAttr("begin_norm_axis") + .IsNumGT(0) .End(); } From a5078267f68d7f34aa32c85407c46d44197637e4 Mon Sep 17 00:00:00 2001 From: jiweibo Date: Wed, 16 Jun 2021 09:04:24 +0000 Subject: [PATCH 3/3] update attr --- paddle/fluid/framework/ir/skip_layernorm_fuse_pass.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/framework/ir/skip_layernorm_fuse_pass.h b/paddle/fluid/framework/ir/skip_layernorm_fuse_pass.h index 4aaaf8f639205..804d0abdd6f06 100644 --- a/paddle/fluid/framework/ir/skip_layernorm_fuse_pass.h +++ b/paddle/fluid/framework/ir/skip_layernorm_fuse_pass.h @@ -45,7 +45,7 @@ class SkipLayerNormFusePass : public FusePassBase { .IsTensor() .End() .AddAttr("axis") - .IsNumEQ(0) + .IsIntIn({0, -1}) .End(); AddOpCompat(OpCompat("layer_norm"))