Skip to content

Commit 1a0beda

Browse files
authored
fc_lstm_fuse_pass.cc modify use_mkldnn [fluid_ops] (#74550)
1 parent a418cd0 commit 1a0beda

File tree

5 files changed: +22 −12 lines changed

paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ int FCLstmFusePass::BuildFusion(Graph* graph,
195195
Node* cell,
196196
Node* xx,
197197
Node* fc_bias,
198-
const bool use_mkldnn) {
198+
const bool use_onednn) {
199199
OpDesc op_desc;
200200
op_desc.SetType("fusion_lstm");
201201
#define SET_IN(Key, node__) op_desc.SetInput(#Key, {node__->Name()});
@@ -235,7 +235,7 @@ int FCLstmFusePass::BuildFusion(Graph* graph,
235235
op_desc.SetOutput("XX", {xx->Name()});
236236
op_desc.SetAttr("is_reverse", lstm->Op()->GetAttr("is_reverse"));
237237
op_desc.SetAttr("use_peepholes", lstm->Op()->GetAttr("use_peepholes"));
238-
op_desc.SetAttr("use_mkldnn", use_mkldnn);
238+
op_desc.SetAttr("use_onednn", use_onednn);
239239
// TODO(TJ): get from attr
240240
op_desc.SetAttr("use_seq", true);
241241

@@ -300,8 +300,9 @@ int FCLstmFusePass::BuildFusion(Graph* graph,
300300
GET_IR_NODE_FROM_SUBGRAPH(Cell, Cell, lstm_pattern);
301301
GET_IR_NODE_FROM_SUBGRAPH(w, w, fc_pattern);
302302
GET_IR_NODE_FROM_SUBGRAPH(mul, mul, fc_pattern);
303-
const bool use_mkldnn =
304-
(mul->Op()->GetAttrIfExists<bool>("use_mkldnn") &&
303+
const bool use_onednn =
304+
((mul->Op()->GetAttrIfExists<bool>("use_mkldnn") ||
305+
mul->Op()->GetAttrIfExists<bool>("use_onednn")) &&
305306
lstm->Op()->GetAttrIfExists<std::string>("gate_activation") ==
306307
"sigmoid" &&
307308
lstm->Op()->GetAttrIfExists<std::string>("cell_activation") ==
@@ -323,7 +324,7 @@ int FCLstmFusePass::BuildFusion(Graph* graph,
323324
Cell,
324325
fc_out,
325326
fc_bias,
326-
use_mkldnn);
327+
use_onednn);
327328
// Remove unneeded nodes.
328329
std::unordered_set<const Node*> marked_nodes(
329330
{mul, lstm, elementwise_add, mul_out, BatchGate, BatchCellPreAct});
@@ -339,7 +340,7 @@ int FCLstmFusePass::BuildFusion(Graph* graph,
339340
Cell,
340341
fc_out,
341342
nullptr,
342-
use_mkldnn);
343+
use_onednn);
343344
// Remove unneeded nodes.
344345
std::unordered_set<const Node*> marked_nodes(
345346
{mul, lstm, BatchGate, BatchCellPreAct});

paddle/fluid/framework/ir/fuse_pass_base.cc

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,16 @@ void FusePassBase::AddStatis(int count_of_fused) const {
5858
FuseOptions FusePassBase::FindFuseOption(const Node& node1,
5959
const Node& node2) const {
6060
#ifdef PADDLE_WITH_DNNL
61-
bool node1_onednn = node1.Op()->HasAttr("use_mkldnn") &&
62-
PADDLE_GET_CONST(bool, node1.Op()->GetAttr("use_mkldnn"));
63-
bool node2_onednn = node2.Op()->HasAttr("use_mkldnn") &&
64-
PADDLE_GET_CONST(bool, node2.Op()->GetAttr("use_mkldnn"));
61+
bool node1_onednn =
62+
(node1.Op()->HasAttr("use_mkldnn") &&
63+
PADDLE_GET_CONST(bool, node1.Op()->GetAttr("use_mkldnn"))) ||
64+
(node1.Op()->HasAttr("use_onednn") &&
65+
PADDLE_GET_CONST(bool, node1.Op()->GetAttr("use_onednn")));
66+
bool node2_onednn =
67+
(node2.Op()->HasAttr("use_mkldnn") &&
68+
PADDLE_GET_CONST(bool, node2.Op()->GetAttr("use_mkldnn"))) ||
69+
(node2.Op()->HasAttr("use_onednn") &&
70+
PADDLE_GET_CONST(bool, node2.Op()->GetAttr("use_onednn")));
6571
if (node1_onednn && node2_onednn)
6672
return FUSE_ONEDNN;
6773
else if (!node1_onednn && !node2_onednn)

paddle/fluid/framework/ir/gpu_cpu_map_matmul_to_mul_pass.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,9 @@ void GpuCpuMapMatmulV2ToMatmulPass::ApplyImpl(ir::Graph* graph) const {
444444
if (matmul_v2_op->Op()->HasAttr("use_mkldnn")) {
445445
desc.SetAttr("use_mkldnn", matmul_v2_op->Op()->GetAttr("use_mkldnn"));
446446
}
447+
if (matmul_v2_op->Op()->HasAttr("use_onednn")) {
448+
desc.SetAttr("use_onednn", matmul_v2_op->Op()->GetAttr("use_onednn"));
449+
}
447450
if (matmul_v2_op->Op()->HasAttr("enable_int8")) {
448451
desc.SetAttr("enable_int8", matmul_v2_op->Op()->GetAttr("enable_int8"));
449452
desc.SetAttr("Input_scale", matmul_v2_op->Op()->GetAttr("Input_scale"));

paddle/fluid/framework/op_desc.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ class TEST_API OpDesc {
235235
// attribute name => all original attrs
236236
AttributeMap attrs_;
237237
// runtime_attrs_ contains the attributes which used for dispatching kernel
238-
// (use_mkldnn, use_cudnn, ...) or passing additional configuration for
238+
// (use_onednn, use_cudnn, ...) or passing additional configuration for
239239
// special heterogeneous kernel (workspace_size_MB, ...).
240240
// The attributes in runtime_attrs_ are set by framework (such as PASS),
241241
// and not in the python api.

paddle/phi/kernels/fusion/onednn/fc_kernel.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,7 @@ void RunKernel(const phi::OneDNNContext& dev_ctx,
411411
const paddle::optional<DenseTensor>& bias,
412412
const int in_num_col_dims,
413413
const std::string& activation_type,
414-
const bool use_mkldnn,
414+
const bool use_onednn,
415415
const bool padding_weights,
416416
const bool use_quantizer,
417417
const std::string& mkldnn_data_type,

0 commit comments

Comments (0)