@@ -124,7 +124,7 @@ bool BlockCanBeStaticBuilt(const framework::BlockDesc& block) {
124124 // in_black_list = (kernelCode >> 5) & 1
125125 // is_operator_base = (kernelCode >> 4) & 1
126126 // is_custom_op = (kernelCode >> 3) & 1
127- // use_mkldnn = (kernelCode >> 2) & 1
127+ // use_onednn = (kernelCode >> 2) & 1
128128 // sub_block_can_not_static_build = (kernelCode >> 1) & 1
129129 using KernelCode = int8_t ;
130130 std::set<std::pair<std::string, KernelCode>> invalid_ops;
@@ -150,6 +150,12 @@ bool BlockCanBeStaticBuilt(const framework::BlockDesc& block) {
150150 use_mkldnn = attr.index () == 1 ? PADDLE_GET_CONST (int , attr)
151151 : PADDLE_GET_CONST (bool , attr);
152152 }
153+ bool use_onednn = use_mkldnn;
154+ if (!use_mkldnn && op->HasAttr (" use_onednn" )) {
155+ Attribute attr = op->GetAttr (" use_onednn" );
156+ use_onednn = attr.index () == 1 ? PADDLE_GET_CONST (int , attr)
157+ : PADDLE_GET_CONST (bool , attr);
158+ }
153159
154160 bool sub_block_can_not_static_build = false ;
155161 if (op->HasAttr (" sub_block" )) {
@@ -160,9 +166,9 @@ bool BlockCanBeStaticBuilt(const framework::BlockDesc& block) {
160166
161167 KernelCode kernel_code = static_cast <KernelCode>(
162168 (in_black_list << 5 ) + (is_operator_base << 4 ) + (is_custom_op << 3 ) +
163- (use_mkldnn << 2 ) + (sub_block_can_not_static_build << 1 ));
169+ (use_onednn << 2 ) + (sub_block_can_not_static_build << 1 ));
164170
165- if (in_black_list || is_operator_base || is_custom_op || use_mkldnn ||
171+ if (in_black_list || is_operator_base || is_custom_op || use_onednn ||
166172 sub_block_can_not_static_build) {
167173 invalid_ops.insert (std::make_pair (op_type, kernel_code));
168174 }
0 commit comments