diff --git a/paddle/fluid/framework/ir/onednn/conv_activation_onednn_fuse_pass.cc b/paddle/fluid/framework/ir/onednn/conv_activation_onednn_fuse_pass.cc
index 483554fbb81890..434bff293f5eb7 100644
--- a/paddle/fluid/framework/ir/onednn/conv_activation_onednn_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/onednn/conv_activation_onednn_fuse_pass.cc
@@ -122,7 +122,8 @@ void ConvActivationOnednnFusePass::FuseConvConcatAct(
     }
 
     bool is_not_conv_onednn =
-        !(prev_op_nodes[0]->Op()->GetAttrIfExists<bool>("use_mkldnn"));
+        !(prev_op_nodes[0]->Op()->GetAttrIfExists<bool>("use_mkldnn") ||
+          prev_op_nodes[0]->Op()->GetAttrIfExists<bool>("use_onednn"));
     if ((prev_op_nodes[0]->Op()->Type() != "conv2d" &&
          prev_op_nodes[0]->Op()->Type() != "fused_conv2d") ||
         is_not_conv_onednn) {
diff --git a/paddle/fluid/framework/ir/onednn/conv_affine_channel_onednn_fuse_pass.cc b/paddle/fluid/framework/ir/onednn/conv_affine_channel_onednn_fuse_pass.cc
index e5024ae307c679..c63b8fd960d545 100644
--- a/paddle/fluid/framework/ir/onednn/conv_affine_channel_onednn_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/onednn/conv_affine_channel_onednn_fuse_pass.cc
@@ -288,7 +288,9 @@ void ConvAffineChannelFusePass::FuseConvAffineChannel(
     desc.SetOutput("Out", std::vector<std::string>({ac_out->Name()}));
     desc.SetType("elementwise_add");
     desc.SetAttr("axis", 1);
-    desc.SetAttr("use_mkldnn", conv->Op()->GetAttrIfExists<bool>("use_mkldnn"));
+    desc.SetAttr("use_onednn",
+                 conv->Op()->GetAttrIfExists<bool>("use_mkldnn") ||
+                     conv->Op()->GetAttrIfExists<bool>("use_onednn"));
 
     auto eltwise_op = g->CreateOpNode(&desc);  // OpDesc will be copied.
 
diff --git a/paddle/fluid/framework/ir/onednn/cpu_bfloat16_placement_pass_tester.cc b/paddle/fluid/framework/ir/onednn/cpu_bfloat16_placement_pass_tester.cc
index bf3ac6c20b5abd..034d36b0790264 100644
--- a/paddle/fluid/framework/ir/onednn/cpu_bfloat16_placement_pass_tester.cc
+++ b/paddle/fluid/framework/ir/onednn/cpu_bfloat16_placement_pass_tester.cc
@@ -25,11 +25,11 @@ void SetOp(ProgramDesc* prog,
            const std::vector<std::string>& inputs,
            const std::vector<std::string>& outputs,
            const std::string& mkldnn_data_type = "float32",
-           const bool use_mkldnn = true) {
+           const bool use_onednn = true) {
   auto* op = prog->MutableBlock(0)->AppendOp();
 
   op->SetType(type);
-  if (type != "reshape2") op->SetAttr("use_mkldnn", use_mkldnn);
+  if (type != "reshape2") op->SetAttr("use_onednn", use_onednn);
   op->SetAttr("mkldnn_data_type", mkldnn_data_type);
 
   if (type == "conv2d") {
diff --git a/paddle/fluid/framework/ir/onednn/cpu_quantize_placement_pass_tester.cc b/paddle/fluid/framework/ir/onednn/cpu_quantize_placement_pass_tester.cc
index bd5db7c0e3df21..7f0a863fa478c3 100644
--- a/paddle/fluid/framework/ir/onednn/cpu_quantize_placement_pass_tester.cc
+++ b/paddle/fluid/framework/ir/onednn/cpu_quantize_placement_pass_tester.cc
@@ -30,7 +30,7 @@ void SetOp(ProgramDesc* prog,
   auto* op = prog->MutableBlock(0)->AppendOp();
 
   op->SetType(type);
-  op->SetAttr("use_mkldnn", true);
+  op->SetAttr("use_onednn", true);
   op->SetAttr("mkldnn_data_type", mkldnn_data_type);
 
   if (type == "conv2d") {
diff --git a/paddle/fluid/framework/ir/onednn/cpu_quantize_squash_pass_tester.cc b/paddle/fluid/framework/ir/onednn/cpu_quantize_squash_pass_tester.cc
index 592aa2aa009643..a02f9387b11a8a 100644
--- a/paddle/fluid/framework/ir/onednn/cpu_quantize_squash_pass_tester.cc
+++ b/paddle/fluid/framework/ir/onednn/cpu_quantize_squash_pass_tester.cc
@@ -34,7 +34,7 @@ void SetOp(ProgramDesc* prog,
            bool is_negative_input = true) {
   auto* op = prog->MutableBlock(0)->AppendOp();
   op->SetType(type);
-  op->SetAttr("use_mkldnn", use_onednn);
+  op->SetAttr("use_onednn", use_onednn);
   op->SetAttr("name", name);
   if (type != "dropout" && type != "quantize" && type != "dequantize") {
     op->SetAttr("mkldnn_data_type", onednn_data_type);
diff --git a/paddle/fluid/framework/ir/onednn/depthwise_conv_onednn_pass.cc b/paddle/fluid/framework/ir/onednn/depthwise_conv_onednn_pass.cc
index 62b398463d91e7..45c0e77329a781 100644
--- a/paddle/fluid/framework/ir/onednn/depthwise_conv_onednn_pass.cc
+++ b/paddle/fluid/framework/ir/onednn/depthwise_conv_onednn_pass.cc
@@ -80,7 +80,7 @@ void DepthwiseConvMKLDNNPass::ApplyImpl(ir::Graph* graph) const {
   auto* pattern = gpd.mutable_pattern();
   pattern->NewNode("depthwise_conv")
       ->assert_is_op("depthwise_conv2d")
-      ->assert_op_attr("use_mkldnn", true);
+      ->assert_op_attr_or("use_mkldnn", "use_onednn", true);
 
   int found_depthwise_conv_onednn_count = 0;
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
diff --git a/paddle/fluid/framework/ir/onednn/fc_onednn_pass.cc b/paddle/fluid/framework/ir/onednn/fc_onednn_pass.cc
index f120dd282b861f..6011d1d708b568 100644
--- a/paddle/fluid/framework/ir/onednn/fc_onednn_pass.cc
+++ b/paddle/fluid/framework/ir/onednn/fc_onednn_pass.cc
@@ -43,9 +43,10 @@ void FCONEDNNPass::ApplyImpl(ir::Graph* graph) const {
   int found_fc_count = 0;
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
-    VLOG(4) << "Handle FC MKL-DNN pass";
-    if (!(graph->Has("use_mkldnn") && graph->Get<bool>("use_mkldnn"))) {
-      VLOG(3) << "do not enable FC MKL-DNN because it doesn't have use_mkldnn "
+    VLOG(4) << "Handle FC ONE-DNN pass";
+    if (!(graph->Has("use_mkldnn") && graph->Get<bool>("use_mkldnn")) &&
+        !(graph->Has("use_onednn") && graph->Get<bool>("use_onednn"))) {
+      VLOG(3) << "do not enable FC ONE-DNN because it doesn't have use_onednn "
                 "attribute.";
       return;
     }
@@ -68,7 +69,7 @@ void FCONEDNNPass::ApplyImpl(ir::Graph* graph) const {
           "2, 3 & 4, or when width or height is different than one.";
       return;
     }
-    desc->SetAttr("use_mkldnn", true);
+    desc->SetAttr("use_onednn", true);
 
     found_fc_count++;
   };
diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h
index 3025d3f2ff27b8..e7e6c41eb6ea27 100644
--- a/paddle/fluid/framework/operator.h
+++ b/paddle/fluid/framework/operator.h
@@ -398,7 +398,7 @@ class TEST_API OperatorBase {
   VariableNameMap outputs_;
   AttributeMap attrs_;
   // NOTE: runtime_attrs_ contains the attributes which used for dispatching
-  // kernel (use_mkldnn, use_cudnn, ...) or passing additional configuration
+  // kernel (use_onednn, use_cudnn, ...) or passing additional configuration
  // for special heterogeneous kernel (workspace_size_MB, ...).
  // The attributes in runtime_attrs_ are set by framework (such as PASS),
  // and not in the python api.
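
Note on the pattern used throughout this diff: every touched pass *reads* either the legacy `use_mkldnn` attribute or the new `use_onednn` attribute, but *writes* only `use_onednn`, so programs serialized before the rename keep fusing while newly generated programs carry only the new name. The `assert_op_attr_or` call in the depthwise pass presumably wraps the same either-attribute check inside the pattern detector; its definition is not part of this diff. Below is a minimal, self-contained sketch of the lookup (plain C++, not Paddle's actual API; `AttributeMap` and `IsOnednnEnabled` here are simplified stand-ins invented for illustration):

```cpp
// Hypothetical illustration of the "old name OR new name" attribute lookup.
#include <iostream>
#include <map>
#include <string>

// Stand-in for an op's attribute map; Paddle's real AttributeMap holds a
// variant type, but bool is all this pattern needs.
using AttributeMap = std::map<std::string, bool>;

// Returns true if either the legacy attribute ("use_mkldnn") or its
// replacement ("use_onednn") is present and set. A missing attribute
// defaults to false, mirroring GetAttrIfExists<bool> semantics.
bool IsOnednnEnabled(const AttributeMap& attrs) {
  auto get_or_false = [&](const std::string& name) {
    auto it = attrs.find(name);
    return it != attrs.end() && it->second;
  };
  return get_or_false("use_mkldnn") || get_or_false("use_onednn");
}

int main() {
  AttributeMap legacy_op{{"use_mkldnn", true}};   // op from an old program
  AttributeMap renamed_op{{"use_onednn", true}};  // op from a new program
  AttributeMap plain_op{};                        // oneDNN not requested

  std::cout << IsOnednnEnabled(legacy_op) << "\n";   // 1
  std::cout << IsOnednnEnabled(renamed_op) << "\n";  // 1
  std::cout << IsOnednnEnabled(plain_op) << "\n";    // 0
  return 0;
}
```

Once no serialized programs with the old attribute remain in circulation, the `use_mkldnn` arm of each check can presumably be dropped, completing the rename.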