@@ -23,6 +23,7 @@ limitations under the License. */
2323
2424PD_DECLARE_bool (convert_all_blocks);
2525COMMON_DECLARE_bool (use_mkldnn);
26+ COMMON_DECLARE_bool (use_onednn);
2627#ifdef PADDLE_WITH_CINN
2728PD_DECLARE_bool (use_cinn);
2829#endif
@@ -203,22 +204,23 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
203204
204205 void AppendPassToSetMkldnnAttr (const std::string &pass_name) {
205206#ifdef PADDLE_WITH_DNNL
206- if (FLAGS_use_mkldnn) {
207+ if (FLAGS_use_mkldnn || FLAGS_use_onednn ) {
207208 AppendPass (pass_name);
208209 } else if (!strategy_.onednn_enabled_op_types_ .empty ()) {
209- VLOG (1 ) << " mkldnn_enabled_op_types specify the operator type list to "
210- " use MKLDNN acceleration. It is null in default, means "
211- " that all the operators supported by MKLDNN will be "
210+ VLOG (1 ) << " onednn_enabled_op_types specify the operator type list to "
211+ " use ONEDNN acceleration. It is null in default, means "
212+ " that all the operators supported by ONEDNN will be "
212213 " accelerated. And it should not be set when "
213- " FLAGS_use_mkldnn =false." ;
214+ " FLAGS_use_onednn =false." ;
214215 }
215216#else
216- PADDLE_ENFORCE_NE (FLAGS_use_mkldnn,
217- true ,
218- common::errors::PreconditionNotMet (
219- " FLAGS_use_mkldnn has been set to True, but "
220- " PaddlePaddle is compiled without MKLDNN. "
221- " Please compile PaddlePaddle with MKLDNN first." ));
217+ PADDLE_ENFORCE_NE (
218+ FLAGS_use_mkldnn || FLAGS_use_onednn,
219+ true ,
220+ common::errors::PreconditionNotMet (
221+ " FLAGS_use_mkldnn or FLAGS_use_onednn has been set to True, but "
222+ " PaddlePaddle is compiled without ONEDNN. "
223+ " Please compile PaddlePaddle with ONEDNN first." ));
222224#endif
223225 }
224226
0 commit comments