
Commit 7db3dff

executor add use_onednn [fluid_ops] (#74326)
* Fix
* Fix
* Fix
1 parent ee2d0e5 commit 7db3dff

9 files changed: 32 additions, 11 deletions


paddle/fluid/framework/executor.cc

Lines changed: 3 additions & 1 deletion
@@ -598,8 +598,10 @@ void Executor::EnableONEDNN(const ProgramDesc& program) {
   for (size_t bid = 0; bid < program.Size(); ++bid) {
     auto* block = const_cast<ProgramDesc&>(program).MutableBlock(bid);
     for (auto* op : block->AllOps()) {
-      if (FoundOneDNNKernel(op) || FoundPhiOneDNNKernel(op))
+      if (FoundOneDNNKernel(op) || FoundPhiOneDNNKernel(op)) {
         op->SetAttr("use_mkldnn", true);
+        op->SetAttr("use_onednn", true);
+      }
     }
   }
 #else
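A minimal sketch (not part of this commit) of the dual-flag pattern the change introduces: every op that receives the legacy use_mkldnn attribute now also receives the use_onednn alias, so downstream code can honor either name. The helper OneDNNRequested below is hypothetical; GetAttrIfExists is the accessor used elsewhere in this diff.

// Hypothetical helper, not in the patch: accept either spelling of the flag.
#include "paddle/fluid/framework/op_desc.h"

bool OneDNNRequested(const paddle::framework::OpDesc& op_desc) {
  // EnableONEDNN() sets both attributes, but older programs may carry
  // only use_mkldnn, so check both names.
  return op_desc.GetAttrIfExists<bool>("use_mkldnn") ||
         op_desc.GetAttrIfExists<bool>("use_onednn");
}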

paddle/fluid/framework/ir/op_compat_sensible_pass.cc

Lines changed: 2 additions & 0 deletions
@@ -32,6 +32,8 @@ std::unordered_set<std::string> global_extra_attrs = {
     "is_test",
     "use_mkldnn",
     "mkldnn_data_type",
+    "use_onednn",
+    "onednn_data_type",
     "use_quantizer",
     "use_cudnn",
     "name",

paddle/fluid/framework/new_executor/interpreter/data_transfer.cc

Lines changed: 1 addition & 0 deletions
@@ -357,6 +357,7 @@ std::shared_ptr<OperatorBase> TransferDtype(const std::string& var_name,
   attr_map["out_dtype"] = static_cast<int>(out_dtype);
   // NOTE(Aurelius84): In which case use_mkldnn = true?
   attr_map["use_mkldnn"] = false;
+  attr_map["use_onednn"] = false;

   // 3. Create cast op
   std::string op_type("cast");

paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc

Lines changed: 4 additions & 0 deletions
@@ -349,6 +349,10 @@ void CreateAllOps(const framework::BlockDesc& block,
         VLOG(4) << "Set use_mkldnn=True for " << op_base->Type();
         op_base->SetAttr("use_mkldnn", true);
       }
+      if (op->HasAttr("use_onednn")) {
+        VLOG(4) << "Set use_onednn=True for " << op_base->Type();
+        op_base->SetAttr("use_onednn", true);
+      }
     }
 #endif

paddle/fluid/framework/new_executor/interpreter/static_build.cc

Lines changed: 9 additions & 3 deletions
@@ -124,7 +124,7 @@ bool BlockCanBeStaticBuilt(const framework::BlockDesc& block) {
   // in_black_list = (kernelCode >> 5) & 1
   // is_operator_base = (kernelCode >> 4) & 1
   // is_custom_op = (kernelCode >> 3) & 1
-  // use_mkldnn = (kernelCode >> 2) & 1
+  // use_onednn = (kernelCode >> 2) & 1
   // sub_block_can_not_static_build = (kernelCode >> 1) & 1
   using KernelCode = int8_t;
   std::set<std::pair<std::string, KernelCode>> invalid_ops;
@@ -150,6 +150,12 @@ bool BlockCanBeStaticBuilt(const framework::BlockDesc& block) {
       use_mkldnn = attr.index() == 1 ? PADDLE_GET_CONST(int, attr)
                                      : PADDLE_GET_CONST(bool, attr);
     }
+    bool use_onednn = use_mkldnn;
+    if (!use_mkldnn && op->HasAttr("use_onednn")) {
+      Attribute attr = op->GetAttr("use_onednn");
+      use_onednn = attr.index() == 1 ? PADDLE_GET_CONST(int, attr)
+                                     : PADDLE_GET_CONST(bool, attr);
+    }

     bool sub_block_can_not_static_build = false;
     if (op->HasAttr("sub_block")) {
@@ -160,9 +166,9 @@ bool BlockCanBeStaticBuilt(const framework::BlockDesc& block) {

     KernelCode kernel_code = static_cast<KernelCode>(
         (in_black_list << 5) + (is_operator_base << 4) + (is_custom_op << 3) +
-        (use_mkldnn << 2) + (sub_block_can_not_static_build << 1));
+        (use_onednn << 2) + (sub_block_can_not_static_build << 1));

-    if (in_black_list || is_operator_base || is_custom_op || use_mkldnn ||
+    if (in_black_list || is_operator_base || is_custom_op || use_onednn ||
         sub_block_can_not_static_build) {
       invalid_ops.insert(std::make_pair(op_type, kernel_code));
     }
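To make the renamed bit explicit, here is a small standalone sketch (illustrative only, not patch code) of how the KernelCode bitmask is assembled and decoded, using the layout described in the comment above: bit 5 = in_black_list, bit 4 = is_operator_base, bit 3 = is_custom_op, bit 2 = use_onednn, bit 1 = sub_block_can_not_static_build.

#include <cstdint>
#include <cstdio>

using KernelCode = int8_t;

// Assemble the code the same way BlockCanBeStaticBuilt does.
KernelCode MakeKernelCode(bool in_black_list, bool is_operator_base,
                          bool is_custom_op, bool use_onednn,
                          bool sub_block_can_not_static_build) {
  return static_cast<KernelCode>(
      (in_black_list << 5) + (is_operator_base << 4) + (is_custom_op << 3) +
      (use_onednn << 2) + (sub_block_can_not_static_build << 1));
}

int main() {
  // An op flagged only as using oneDNN yields code 0b000100 = 4.
  KernelCode code = MakeKernelCode(false, false, false, true, false);
  std::printf("use_onednn bit: %d\n", (code >> 2) & 1);  // prints 1
  return 0;
}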

paddle/fluid/framework/operator.cc

Lines changed: 2 additions & 1 deletion
@@ -1599,7 +1599,8 @@ bool OperatorWithKernel::SupportsKernelType(

 bool OperatorWithKernel::CanONEDNNBeUsed(const framework::ExecutionContext& ctx,
                                          phi::DataType data_type) const {
-  return ctx.HasAttr("use_mkldnn") && ctx.Attr<bool>("use_mkldnn") &&
+  return ((ctx.HasAttr("use_mkldnn") && ctx.Attr<bool>("use_mkldnn")) ||
+          (ctx.HasAttr("use_onednn") && ctx.Attr<bool>("use_onednn"))) &&
          phi::is_cpu_place(ctx.GetPlace()) && this->SupportsONEDNN(data_type);
 }
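The parenthesization matters here: either attribute can enable the oneDNN path, but the CPU-place and kernel-support checks still gate both. A standalone restatement of the predicate's shape (illustrative names, not Paddle API):

// Same boolean structure as CanONEDNNBeUsed after this change.
bool can_use_onednn(bool has_mkldnn, bool mkldnn_on,
                    bool has_onednn, bool onednn_on,
                    bool is_cpu_place, bool kernel_supports_onednn) {
  const bool flag_enabled =
      (has_mkldnn && mkldnn_on) || (has_onednn && onednn_on);
  return flag_enabled && is_cpu_place && kernel_supports_onednn;
}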

paddle/fluid/framework/phi_utils.cc

Lines changed: 4 additions & 4 deletions
@@ -199,10 +199,10 @@ KernelArgsNameMakerByOpProto::GetAttrsArgsNames() {
   for (int i = 0; i < op_proto_->attrs_size(); ++i) {
     auto& attr = op_proto_->attrs()[i];
     auto& attr_name = attr.name();
-    if (attr_name == "use_mkldnn" || attr_name == "use_cudnn" ||
-        attr_name == "op_role" || attr_name == "op_role_var" ||
-        attr_name == "op_namescope" || attr_name == "op_callstack" ||
-        attr_name == "op_device") {
+    if (attr_name == "use_mkldnn" || attr_name == "use_onednn" ||
+        attr_name == "use_cudnn" || attr_name == "op_role" ||
+        attr_name == "op_role_var" || attr_name == "op_namescope" ||
+        attr_name == "op_callstack" || attr_name == "op_device") {
       continue;
     }
     if ((attr.has_extra() && attr.extra()) ||

paddle/fluid/imperative/tracer.cc

Lines changed: 2 additions & 0 deletions
@@ -246,10 +246,12 @@ void Tracer::TraceOpImpl(const std::string& type,
     if (!FLAGS_tracer_onednn_ops_on.empty()) {
       auto is_on = FLAGS_tracer_onednn_ops_on.find(type) != std::string::npos;
       attrs["use_mkldnn"] = is_on;
+      attrs["use_onednn"] = is_on;
     } else {
       // if ops_on list is empty all ops are enabled except types from off_list
       auto is_off = FLAGS_tracer_onednn_ops_off.find(type) != std::string::npos;
       attrs["use_mkldnn"] = !is_off;
+      attrs["use_onednn"] = !is_off;
     }
   }
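For reference, a simplified sketch of the selection logic (illustrative types; the real code writes into the tracer's attribute map), showing that the new alias always mirrors the legacy flag:

#include <string>
#include <unordered_map>

// ops_on / ops_off stand in for FLAGS_tracer_onednn_ops_on / _off.
void SetOneDNNAttrs(const std::string& op_type, const std::string& ops_on,
                    const std::string& ops_off,
                    std::unordered_map<std::string, bool>* attrs) {
  bool enabled;
  if (!ops_on.empty()) {
    // Non-empty allow-list: only listed ops run on oneDNN.
    enabled = ops_on.find(op_type) != std::string::npos;
  } else {
    // Empty allow-list: every op runs on oneDNN unless it is in the deny-list.
    enabled = ops_off.find(op_type) == std::string::npos;
  }
  (*attrs)["use_mkldnn"] = enabled;
  (*attrs)["use_onednn"] = enabled;  // new alias mirrors the legacy flag
}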

paddle/fluid/ir_adaptor/translator/op_translator.cc

Lines changed: 5 additions & 2 deletions
@@ -262,6 +262,7 @@ inline std::string GetPrefix(pir::IrContext* ctx, const OpDesc& op_desc) {
   }
 #ifdef PADDLE_WITH_DNNL
   if (op_desc.GetAttrIfExists<bool>("use_mkldnn") ||
+      op_desc.GetAttrIfExists<bool>("use_onednn") ||
       paddle::dialect::IsOneDNNOnlyOp(op_desc.Type())) {
     if (!HasOpInfo(ctx, op_desc, kOneDNNTargetDialectPrefix)) {
       VLOG(3) << op_desc.Type()
@@ -1838,7 +1839,8 @@ struct MulOpTranscriber : public OpTranscriber {
                          const OpDesc& op_desc,
                          pir::Block* block) override {
 #ifdef PADDLE_WITH_DNNL
-    if (op_desc.GetAttrIfExists<bool>("use_mkldnn")) {
+    if (op_desc.GetAttrIfExists<bool>("use_mkldnn") ||
+        op_desc.GetAttrIfExists<bool>("use_onednn")) {
       return static_cast<OpTranscriber>(*this).operator()(  // NOLINT
           ctx,
           param_map,
@@ -2015,7 +2017,8 @@ struct MulGradOpTranscriber : public OpTranscriber {
                          const OpDesc& op_desc,
                          pir::Block* block) override {
 #ifdef PADDLE_WITH_DNNL
-    if (op_desc.GetAttrIfExists<bool>("use_mkldnn")) {
+    if (op_desc.GetAttrIfExists<bool>("use_mkldnn") ||
+        op_desc.GetAttrIfExists<bool>("use_onednn")) {
       return static_cast<OpTranscriber>(*this).operator()(  // NOLINT
           ctx,
           param_map,
