Commit 4bbe1a6

add FLAGS_use_onednn [fluid_ops] (#74315)
* Fix
* fix
1 parent 3beb3b3 commit 4bbe1a6

20 files changed: +68 -36 lines

paddle/common/flags.cc

Lines changed: 10 additions & 0 deletions
@@ -696,6 +696,16 @@ PHI_DEFINE_EXPORTED_bool(
  */
 PHI_DEFINE_EXPORTED_bool(use_mkldnn, false, "Use MKLDNN to run");
 
+/**
+ * ONEDNN related FLAG
+ * Name: use_onednn
+ * Since Version:
+ * Value Range: bool, default=false
+ * Example:
+ * Note:
+ */
+PHI_DEFINE_EXPORTED_bool(use_onednn, false, "Use ONEDNN to run");
+
 /**
  * Debug related FLAG
  * Name: FLAGS_call_stack_level
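
For reference, PHI_DEFINE_EXPORTED_bool both defines the FLAGS_use_onednn symbol and exports it, so it can be toggled at startup just like the existing FLAGS_use_mkldnn switch. The rest of this commit consumes it with COMMON_DECLARE_bool and checks the two flags together; a minimal sketch of that consumer pattern, assuming paddle/common/flags.h is the header that provides the declare macro (the function below is illustrative, not code from this diff):

// Sketch of the pattern the remaining files in this commit apply.
// MaybeUseOneDNN() and the include path are assumptions, not from the diff.
#include "paddle/common/flags.h"

COMMON_DECLARE_bool(use_mkldnn);
COMMON_DECLARE_bool(use_onednn);

bool MaybeUseOneDNN() {
#ifdef PADDLE_WITH_DNNL
  // The commit treats the legacy flag and the new one as equivalent switches.
  return FLAGS_use_mkldnn || FLAGS_use_onednn;
#else
  return false;
#endif
}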

paddle/fluid/eager/to_static/run_program_impl.cc

Lines changed: 5 additions & 2 deletions
@@ -46,6 +46,7 @@
 COMMON_DECLARE_bool(enable_pir_with_pt_in_dy2st);
 COMMON_DECLARE_bool(enable_pir_in_executor);
 COMMON_DECLARE_bool(use_mkldnn);
+COMMON_DECLARE_bool(use_onednn);
 COMMON_DECLARE_bool(specialize_device_in_dy2st);
 COMMON_DECLARE_bool(parameters_persistent_mode_in_dy2st);

@@ -673,7 +674,8 @@ std::vector<paddle::Tensor> RunProgramImpl(
   }
 
 #ifdef PADDLE_WITH_DNNL
-  if (FLAGS_use_mkldnn) paddle::platform::DontClearONEDNNCache(place);
+  if (FLAGS_use_mkldnn || FLAGS_use_onednn)
+    paddle::platform::DontClearONEDNNCache(place);
 #endif
   return out;
 }

@@ -1014,7 +1016,8 @@ void LegacyRunProgramImpl(
   }
 
 #ifdef PADDLE_WITH_DNNL
-  if (FLAGS_use_mkldnn) paddle::platform::DontClearONEDNNCache(place);
+  if (FLAGS_use_mkldnn || FLAGS_use_onednn)
+    paddle::platform::DontClearONEDNNCache(place);
 #endif
 }

paddle/fluid/framework/details/build_strategy.cc

Lines changed: 13 additions & 11 deletions
@@ -23,6 +23,7 @@ limitations under the License. */
 
 PD_DECLARE_bool(convert_all_blocks);
 COMMON_DECLARE_bool(use_mkldnn);
+COMMON_DECLARE_bool(use_onednn);
 #ifdef PADDLE_WITH_CINN
 PD_DECLARE_bool(use_cinn);
 #endif

@@ -203,22 +204,23 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
 
   void AppendPassToSetMkldnnAttr(const std::string &pass_name) {
 #ifdef PADDLE_WITH_DNNL
-    if (FLAGS_use_mkldnn) {
+    if (FLAGS_use_mkldnn || FLAGS_use_onednn) {
       AppendPass(pass_name);
     } else if (!strategy_.onednn_enabled_op_types_.empty()) {
-      VLOG(1) << "mkldnn_enabled_op_types specify the operator type list to "
-                 "use MKLDNN acceleration. It is null in default, means "
-                 "that all the operators supported by MKLDNN will be "
+      VLOG(1) << "onednn_enabled_op_types specify the operator type list to "
+                 "use ONEDNN acceleration. It is null in default, means "
+                 "that all the operators supported by ONEDNN will be "
                  "accelerated. And it should not be set when "
-                 "FLAGS_use_mkldnn=false.";
+                 "FLAGS_use_onednn=false.";
     }
 #else
-    PADDLE_ENFORCE_NE(FLAGS_use_mkldnn,
-                      true,
-                      common::errors::PreconditionNotMet(
-                          "FLAGS_use_mkldnn has been set to True, but "
-                          "PaddlePaddle is compiled without MKLDNN. "
-                          "Please compile PaddlePaddle with MKLDNN first."));
+    PADDLE_ENFORCE_NE(
+        FLAGS_use_mkldnn || FLAGS_use_onednn,
+        true,
+        common::errors::PreconditionNotMet(
+            "FLAGS_use_mkldnn or FLAGS_use_onednn has been set to True, but "
+            "PaddlePaddle is compiled without ONEDNN. "
+            "Please compile PaddlePaddle with ONEDNN first."));
 #endif
   }

paddle/fluid/framework/details/build_strategy.h

Lines changed: 2 additions & 2 deletions
@@ -123,11 +123,11 @@ struct BuildStrategy {
   bool fuse_dot_product_attention_{false};
   // Fuse ResUnit
   bool fuse_resunit_{false};
-  // mkldnn_enabled_op_types specify the operator type list to
+  // onednn_enabled_op_types specify the operator type list to
   // use OneDNN acceleration. It is null in default, means
   // that all the operators supported by OneDNN will be
   // accelerated. And it should not be set when
-  // FLAGS_use_mkldnn=false
+  // FLAGS_use_onednn=false
   std::unordered_set<std::string> onednn_enabled_op_types_;
 
   // By default, memory_optimize would be opened if gc is disabled, and

paddle/fluid/framework/executor.cc

Lines changed: 3 additions & 2 deletions
@@ -32,6 +32,7 @@ limitations under the License. */
 
 COMMON_DECLARE_bool(benchmark);
 COMMON_DECLARE_bool(use_mkldnn);
+COMMON_DECLARE_bool(use_onednn);
 
 namespace paddle::framework {
 namespace {

@@ -184,7 +185,7 @@ void Executor::Run(const ProgramDesc& pdesc,
   phi::RecordEvent record_run(
       "Executor::Run", phi::TracerEventType::UserDefined, 1);
   platform::RecordBlock b(block_id);
-  if (FLAGS_use_mkldnn) EnableONEDNN(pdesc);
+  if (FLAGS_use_mkldnn || FLAGS_use_onednn) EnableONEDNN(pdesc);
   auto ctx = Prepare(pdesc, block_id, skip_ref_cnt_vars, force_disable_gc);
 #ifdef PADDLE_WITH_DNNL
   platform::AttachPointerHashToONEDNNKey(this, place_);

@@ -330,7 +331,7 @@ void Executor::Run(const ProgramDesc& program,
   phi::RecordEvent record_run(
       "Executor::Run", phi::TracerEventType::UserDefined, 1);
   platform::RecordBlock b(kProgramId);
-  if (FLAGS_use_mkldnn) EnableONEDNN(program);
+  if (FLAGS_use_mkldnn || FLAGS_use_onednn) EnableONEDNN(program);
 #ifdef PADDLE_WITH_DNNL
   platform::AttachPointerHashToONEDNNKey(this, place_);
 #endif

paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc

Lines changed: 2 additions & 1 deletion
@@ -59,6 +59,7 @@
 #endif
 
 COMMON_DECLARE_bool(use_mkldnn);
+COMMON_DECLARE_bool(use_onednn);
 COMMON_DECLARE_bool(check_nan_inf);
 COMMON_DECLARE_string(static_runtime_data_save_path);
 COMMON_DECLARE_bool(save_static_runtime_data);

@@ -344,7 +345,7 @@ void CreateAllOps(const framework::BlockDesc& block,
     op_base->SetRuntimeAttributeMap(op_runtime_attr_map);
 
 #ifdef PADDLE_WITH_DNNL
-    if (FLAGS_use_mkldnn) {
+    if (FLAGS_use_mkldnn || FLAGS_use_onednn) {
       if (op->HasAttr("use_mkldnn")) {
         VLOG(4) << "Set use_mkldnn=True for " << op_base->Type();
         op_base->SetAttr("use_mkldnn", true);
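
This hunk is where the global switch reaches individual operators: when either flag is on, CreateAllOps sets each op's use_mkldnn attribute to true. A hedged sketch of the receiving side, i.e. how code holding an OperatorBase could read that attribute back (the helper name and include path are assumptions, not taken from this diff):

// Illustrative only: inspect the per-op attribute that CreateAllOps sets above.
// OpWantsOneDNN() is a hypothetical helper, not part of the commit.
#include <string>
#include "paddle/fluid/framework/operator.h"  // assumed header for OperatorBase

bool OpWantsOneDNN(const paddle::framework::OperatorBase& op) {
  // The attribute is written when FLAGS_use_mkldnn || FLAGS_use_onednn is set.
  return op.HasAttr("use_mkldnn") && op.Attr<bool>("use_mkldnn");
}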

paddle/fluid/framework/operator.cc

Lines changed: 3 additions & 2 deletions
@@ -2330,10 +2330,11 @@ void OperatorWithKernel::ChooseKernel(const ExecutionContext& ctx) const {
   auto kernel_iter = kernels.find(expected_kernel_key);
 
 #ifdef PADDLE_WITH_DNNL
-  // workaround for missing MKLDNN kernel when FLAGS_use_mkldnn env var is set
+  // workaround for missing ONEDNN kernel when FLAGS_use_mkldnn or
+  // FLAGS_use_onednn env var is set
   if (kernel_iter == kernels.end() &&
       expected_kernel_key.library_type_ == LibraryType::kMKLDNN) {
-    VLOG(3) << "missing MKLDNN kernel: fallbacking to PLAIN one";
+    VLOG(3) << "missing ONEDNN kernel: fallbacking to PLAIN one";
     expected_kernel_key.library_type_ = LibraryType::kPlain;
     expected_kernel_key.data_layout_ = DataLayout::kAnyLayout;
     kernel_iter = kernels.find(expected_kernel_key);

paddle/fluid/imperative/layer.cc

Lines changed: 5 additions & 2 deletions
@@ -31,6 +31,7 @@
 #endif
 
 COMMON_DECLARE_bool(use_mkldnn);
+COMMON_DECLARE_bool(use_onednn);
 namespace paddle::imperative {
 
 using framework::Variable;

@@ -228,7 +229,8 @@ void VarBase::ClearGradient(bool set_to_zero) {
     auto* grad_t = grad_var_->MutableVar()->GetMutable<phi::SelectedRows>();
     if (grad_t->mutable_value()->IsInitialized()) {
 #ifdef PADDLE_WITH_DNNL
-      if (FLAGS_use_mkldnn) platform::ClearONEDNNCache(grad_t->place());
+      if (FLAGS_use_mkldnn || FLAGS_use_onednn)
+        platform::ClearONEDNNCache(grad_t->place());
 #endif
       grad_t->mutable_rows()->clear();
       grad_t->mutable_value()->clear();

@@ -246,7 +248,8 @@ void VarBase::ClearGradient(bool set_to_zero) {
         grad_t->clear();
       }
 #ifdef PADDLE_WITH_DNNL
-      if (FLAGS_use_mkldnn) platform::ClearONEDNNCache(grad_t->place());
+      if (FLAGS_use_mkldnn || FLAGS_use_onednn)
+        platform::ClearONEDNNCache(grad_t->place());
 #endif
     }
   }

paddle/fluid/imperative/prepared_operator.cc

Lines changed: 1 addition & 1 deletion
@@ -176,7 +176,7 @@ PreparedOp PrepareImpl(
   // OneDNN variant of code reads attributes in some of GetKernelTypeForVar and
   // GetKernelType functions, so we need to copy the attributes there.
   // Const qualifier of Attrs had to be discarded to overwrite it.
-  if (FLAGS_use_mkldnn) {
+  if (FLAGS_use_mkldnn || FLAGS_use_onednn) {
     auto& mutable_op_attrs = const_cast<framework::AttributeMap&>(op.Attrs());
     mutable_op_attrs = default_attrs;
     for (auto& attr : attrs) {

paddle/fluid/imperative/prepared_operator.h

Lines changed: 1 addition & 0 deletions
@@ -37,6 +37,7 @@
 #include "paddle/phi/core/vocab/string_array.h"
 
 COMMON_DECLARE_bool(use_mkldnn);
+COMMON_DECLARE_bool(use_onednn);
 
 namespace paddle {
 namespace imperative {
