Skip to content

Commit

Permalink
add compat check for skip_layernorm (#33505)
Browse files Browse the repository at this point in the history
  • Loading branch information
jiweibo authored Jun 16, 2021
1 parent 34b79d9 commit 4ddd595
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 0 deletions.
5 changes: 5 additions & 0 deletions paddle/fluid/framework/ir/skip_layernorm_fuse_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,11 @@ void SkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
return;
}

if (!IsCompat(subgraph, graph)) {
LOG(WARNING) << "skip_layernorm pass in op compat failed.";
return;
}

VLOG(4) << "handle SkipLayerNorm fuse";
GET_IR_NODE_FROM_SUBGRAPH(elementwise, elementwise, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out, fused_pattern);
Expand Down
43 changes: 43 additions & 0 deletions paddle/fluid/framework/ir/skip_layernorm_fuse_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,49 @@ class Graph;

class SkipLayerNormFusePass : public FusePassBase {
public:
SkipLayerNormFusePass() {
AddOpCompat(OpCompat("elementwise_add"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsIntIn({0, -1})
.End();

AddOpCompat(OpCompat("layer_norm"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Scale")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.End()
.AddOutput("Y")
.IsTensor()
.End()
.AddOutput("Mean")
.IsTensor()
.End()
.AddOutput("Variance")
.IsTensor()
.End()
.AddAttr("epsilon")
.IsNumGE(0.0f)
.IsNumLE(0.001f)
.End()
.AddAttr("begin_norm_axis")
.IsNumGT(0)
.End();
}

virtual ~SkipLayerNormFusePass() {}

protected:
Expand Down

0 comments on commit 4ddd595

Please sign in to comment.