Skip to content

Commit bef99e0

Browse files
committed
Fix
1 parent 9667405 commit bef99e0

File tree

5 files changed

+16
-8
lines changed

5 files changed

+16
-8
lines changed

paddle/fluid/framework/ir/onednn/cpu_bfloat16_pass_tester.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ void SetOp(ProgramDesc* prog,
2828
const std::string& onednn_data_type = "float32") {
2929
auto* op = prog->MutableBlock(0)->AppendOp();
3030
op->SetType(type);
31-
op->SetAttr("use_mkldnn", use_onednn);
31+
op->SetAttr("use_onednn", use_onednn);
3232
op->SetAttr("name", name);
3333

3434
if (type == "conv2d") {

paddle/fluid/framework/ir/onednn/multi_gru_fuse_pass.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ MultiGRUFusePass::MultiGRUFusePass() {
186186
.AddAttr("origin_mode")
187187
.IsType<bool>()
188188
.End()
189-
.AddAttr("use_mkldnn")
189+
.AddAttr("use_onednn")
190190
.IsType<bool>()
191191
.End()
192192
.AddAttr("mkldnn_data_type")

paddle/fluid/framework/ir/onednn/shuffle_channel_onednn_detect_pass.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ void ShuffleChannelMKLDNNDetectPass::ApplyImpl(ir::Graph* graph) const {
208208
new_op_desc.SetOutput("Out", {output_name});
209209

210210
new_op_desc.SetAttr("group", group);
211-
new_op_desc.SetAttr("use_mkldnn", true);
211+
new_op_desc.SetAttr("use_onednn", true);
212212
new_op_desc.Flush();
213213

214214
// Create a new node for the fused op.

paddle/fluid/framework/ir/onednn/shuffle_channel_onednn_detect_pass_tester.cc

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,11 @@ void MainTest() {
6666
for (const auto* node : graph->Nodes()) {
6767
if (node->IsOp() && node->Op()->Type() == "shuffle_channel") {
6868
const auto* op = node->Op();
69-
ASSERT_TRUE(op->HasAttr("use_mkldnn"));
70-
EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")));
69+
ASSERT_TRUE(op->HasAttr("use_mkldnn") || op->HasAttr("use_onednn"));
70+
EXPECT_TRUE((op->HasAttr("use_mkldnn") &&
71+
PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn"))) ||
72+
(op->HasAttr("use_onednn") &&
73+
PADDLE_GET_CONST(bool, op->GetAttr("use_onednn"))));
7174
}
7275
}
7376
}

paddle/fluid/framework/ir/onednn/squeeze2_transpose2_onednn_fuse_pass.cc

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,15 @@ void FuseSqueeze2Transpose2OneDNNPass::ApplyImpl(Graph *graph) const {
4747
GET_IR_NODE_FROM_SUBGRAPH(
4848
transpose2_op, transpose2_op, squeeze2_transpose2_pattern);
4949

50-
if (!transpose2_op->Op()->HasAttr("use_mkldnn") ||
50+
bool use_mkldnn_not =
51+
!transpose2_op->Op()->HasAttr("use_mkldnn") ||
5152
(transpose2_op->Op()->HasAttr("use_mkldnn") &&
52-
!(PADDLE_GET_CONST(bool,
53-
transpose2_op->Op()->GetAttr("use_mkldnn"))))) {
53+
!(PADDLE_GET_CONST(bool, transpose2_op->Op()->GetAttr("use_mkldnn"))));
54+
bool use_onednn_not =
55+
!transpose2_op->Op()->HasAttr("use_onednn") ||
56+
(transpose2_op->Op()->HasAttr("use_onednn") &&
57+
!(PADDLE_GET_CONST(bool, transpose2_op->Op()->GetAttr("use_onednn"))));
58+
if (use_mkldnn_not && use_onednn_not) {
5459
VLOG(4) << "Only oneDNN version of transpose2 can be fused after with "
5560
"squeeze2.";
5661
return;

0 commit comments

Comments
 (0)