Skip to content

Commit fa6a3bd

Browse files
committed
Merge branch 'i43' into i41
2 parents 9befe50 + fbbbe2d commit fa6a3bd

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

paddle/fluid/framework/ir/onednn/interpolate_onednn_pass.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ void InterpolateOneDNNPass::ApplyImpl(ir::Graph* graph) const {
3131
PADDLE_ENFORCE_NOT_NULL(graph,
3232
common::errors::InvalidArgument(
3333
"Pointer to graph argument should not be NULL."));
34-
if (!(graph->Has("use_mkldnn") && graph->Get<bool>("use_mkldnn"))) {
34+
if (!(graph->Has("use_mkldnn") && graph->Get<bool>("use_mkldnn")) &&
35+
!(graph->Has("use_onednn") && graph->Get<bool>("use_onednn"))) {
3536
VLOG(3) << "Do not handle interpolate_onednn_pass";
3637
return;
3738
}
@@ -53,7 +54,7 @@ void InterpolateOneDNNPass::ApplyImpl(ir::Graph* graph) const {
5354
interpolate_op_types.end(),
5455
node->Name()) != interpolate_op_types.end()) {
5556
auto* op_desc = node->Op();
56-
op_desc->SetAttr("use_mkldnn", true);
57+
op_desc->SetAttr("use_onednn", true);
5758
++found_count;
5859
}
5960
}

paddle/fluid/framework/ir/onednn/operator_unsqueeze2_onednn_fuse_pass.cc

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,15 @@ void FuseOperatorUnsqueeze2OneDNNPass::FuseUnsqueeze2(
5656
GET_IR_NODE_FROM_SUBGRAPH(
5757
unsqueeze2_out, unsqueeze2_out, op_unsqueeze2_pattern);
5858

59-
if (!operator_op->Op()->HasAttr("use_mkldnn") ||
59+
bool use_mkldnn_not =
60+
!operator_op->Op()->HasAttr("use_mkldnn") ||
6061
(operator_op->Op()->HasAttr("use_mkldnn") &&
61-
!(PADDLE_GET_CONST(bool, operator_op->Op()->GetAttr("use_mkldnn"))))) {
62+
!(PADDLE_GET_CONST(bool, operator_op->Op()->GetAttr("use_mkldnn"))));
63+
bool use_onednn_not =
64+
!operator_op->Op()->HasAttr("use_onednn") ||
65+
(operator_op->Op()->HasAttr("use_onednn") &&
66+
!(PADDLE_GET_CONST(bool, operator_op->Op()->GetAttr("use_onednn"))));
67+
if (use_mkldnn_not && use_onednn_not) {
6268
VLOG(4) << "Only oneDNN version of " << op_type
6369
<< "can be fused with unsqueeze2.";
6470
return;

0 commit comments

Comments
 (0)