From 8c618f525a20ee8323ab3fec24fb8383561af485 Mon Sep 17 00:00:00 2001 From: co63oc Date: Wed, 13 Aug 2025 15:10:04 +0800 Subject: [PATCH] Fix --- paddle/fluid/framework/ir/onednn/cpu_bfloat16_pass_tester.cc | 2 +- paddle/fluid/framework/ir/onednn/interpolate_onednn_pass.cc | 5 +++-- paddle/fluid/framework/ir/onednn/multi_gru_fuse_pass.cc | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/framework/ir/onednn/cpu_bfloat16_pass_tester.cc b/paddle/fluid/framework/ir/onednn/cpu_bfloat16_pass_tester.cc index a13e2f7fdb798b..c56253074a09c3 100644 --- a/paddle/fluid/framework/ir/onednn/cpu_bfloat16_pass_tester.cc +++ b/paddle/fluid/framework/ir/onednn/cpu_bfloat16_pass_tester.cc @@ -28,7 +28,7 @@ void SetOp(ProgramDesc* prog, const std::string& onednn_data_type = "float32") { auto* op = prog->MutableBlock(0)->AppendOp(); op->SetType(type); - op->SetAttr("use_mkldnn", use_onednn); + op->SetAttr("use_onednn", use_onednn); op->SetAttr("name", name); if (type == "conv2d") { diff --git a/paddle/fluid/framework/ir/onednn/interpolate_onednn_pass.cc b/paddle/fluid/framework/ir/onednn/interpolate_onednn_pass.cc index 2659df8e830b41..f707166b514a46 100644 --- a/paddle/fluid/framework/ir/onednn/interpolate_onednn_pass.cc +++ b/paddle/fluid/framework/ir/onednn/interpolate_onednn_pass.cc @@ -31,7 +31,8 @@ void InterpolateOneDNNPass::ApplyImpl(ir::Graph* graph) const { PADDLE_ENFORCE_NOT_NULL(graph, common::errors::InvalidArgument( "Pointer to graph argument should not be NULL.")); - if (!(graph->Has("use_mkldnn") && graph->Get("use_mkldnn"))) { + if (!(graph->Has("use_mkldnn") && graph->Get("use_mkldnn")) && + !(graph->Has("use_onednn") && graph->Get("use_onednn"))) { VLOG(3) << "Do not handle interpolate_onednn_pass"; return; } @@ -53,7 +54,7 @@ void InterpolateOneDNNPass::ApplyImpl(ir::Graph* graph) const { interpolate_op_types.end(), node->Name()) != interpolate_op_types.end()) { auto* op_desc = node->Op(); - op_desc->SetAttr("use_mkldnn", true); + op_desc->SetAttr("use_onednn", true); ++found_count; } } diff --git a/paddle/fluid/framework/ir/onednn/multi_gru_fuse_pass.cc b/paddle/fluid/framework/ir/onednn/multi_gru_fuse_pass.cc index 9634ca0759c436..509dd0278a7445 100644 --- a/paddle/fluid/framework/ir/onednn/multi_gru_fuse_pass.cc +++ b/paddle/fluid/framework/ir/onednn/multi_gru_fuse_pass.cc @@ -186,7 +186,7 @@ MultiGRUFusePass::MultiGRUFusePass() { .AddAttr("origin_mode") .IsType() .End() - .AddAttr("use_mkldnn") + .AddAttr("use_onednn") .IsType() .End() .AddAttr("mkldnn_data_type")