@@ -2342,7 +2342,9 @@ PDNode *patterns::QuantConv::operator()(const std::string &conv_type) {
23422342 auto conv_op = pattern->NewNode (conv_op_repr ())->assert_is_op (conv_type);
23432343 conv_op->assert_more ([&](Node *node) {
23442344 return node->Op ()->GetAttrIfExists <std::string>(" mkldnn_data_type" ) ==
2345- " bfloat16" ;
2345+ " bfloat16" ||
2346+ node->Op ()->GetAttrIfExists <std::string>(" onednn_data_type" ) ==
2347+ " bfloat16" ;
23462348 });
23472349
23482350 quant_op->LinksFrom ({quant_in}).LinksTo ({conv_in});
@@ -3172,7 +3174,8 @@ PDNode *patterns::QuantizePlacement::operator()(
31723174 auto *op =
31733175 pattern->NewNode (op_repr ())->assert_is_ops (quantize_enabled_op_types);
31743176 op->assert_more ([&](Node *node) {
3175- return node->Op ()->GetAttrIfExists <bool >(" use_mkldnn" );
3177+ return node->Op ()->GetAttrIfExists <bool >(" use_mkldnn" ) ||
3178+ node->Op ()->GetAttrIfExists <bool >(" use_onednn" );
31763179 });
31773180 return op;
31783181}
@@ -3218,6 +3221,7 @@ PDNode *patterns::Bfloat16Placement::operator()(
32183221 auto *op = pattern->NewNode (op_repr ())->assert_is_ops (supported_op_types);
32193222 op->assert_more ([&](Node *node) {
32203223 return node->Op ()->GetAttrIfExists <bool >(" use_mkldnn" ) ||
3224+ node->Op ()->GetAttrIfExists <bool >(" use_onednn" ) ||
32213225 node->Op ()->Type () == " reshape2" ;
32223226 });
32233227 op->LinksFrom ({op_in});
@@ -3227,25 +3231,35 @@ PDNode *patterns::Bfloat16Placement::operator()(
32273231PDNode *patterns::OrphanedBfloat16::operator ()() {
32283232 auto *prev_op = pattern->NewNode (prev_op_repr ())->assert_is_op ();
32293233 prev_op->assert_more ([&](Node *node) {
3230- bool data_type_is_missing = !node->Op ()->HasAttr (" mkldnn_data_type" );
3231- bool data_type_is_fp32 = node->Op ()->GetAttrIfExists <std::string>(
3232- " mkldnn_data_type" ) == " float32" ;
3234+ bool data_type_is_missing = !node->Op ()->HasAttr (" mkldnn_data_type" ) &&
3235+ !node->Op ()->HasAttr (" onednn_data_type" );
3236+ bool data_type_is_fp32 =
3237+ node->Op ()->GetAttrIfExists <std::string>(" mkldnn_data_type" ) ==
3238+ " float32" ||
3239+ node->Op ()->GetAttrIfExists <std::string>(" onednn_data_type" ) ==
3240+ " float32" ;
32333241 return data_type_is_missing || data_type_is_fp32;
32343242 });
32353243 auto *prev_out = pattern->NewNode (prev_out_repr ())->AsOutput ();
32363244
32373245 auto *op = pattern->NewNode (op_repr ())->assert_is_op ();
32383246 op->assert_more ([&](Node *node) {
32393247 return node->Op ()->GetAttrIfExists <std::string>(" mkldnn_data_type" ) ==
3240- " bfloat16" ;
3248+ " bfloat16" ||
3249+ node->Op ()->GetAttrIfExists <std::string>(" onednn_data_type" ) ==
3250+ " bfloat16" ;
32413251 });
32423252 auto *op_out = pattern->NewNode (op_out_repr ())->AsOutput ();
32433253
32443254 auto *next_op = pattern->NewNode (next_op_repr ())->assert_is_op ();
32453255 next_op->assert_more ([&](Node *node) {
3246- bool data_type_is_missing = !node->Op ()->HasAttr (" mkldnn_data_type" );
3247- bool data_type_is_fp32 = node->Op ()->GetAttrIfExists <std::string>(
3248- " mkldnn_data_type" ) == " float32" ;
3256+ bool data_type_is_missing = !node->Op ()->HasAttr (" mkldnn_data_type" ) &&
3257+ !node->Op ()->HasAttr (" onednn_data_type" );
3258+ bool data_type_is_fp32 =
3259+ node->Op ()->GetAttrIfExists <std::string>(" mkldnn_data_type" ) ==
3260+ " float32" ||
3261+ node->Op ()->GetAttrIfExists <std::string>(" onednn_data_type" ) ==
3262+ " float32" ;
32493263 return data_type_is_missing || data_type_is_fp32;
32503264 });
32513265
@@ -3258,14 +3272,17 @@ PDNode *patterns::OrphanedBfloat16::operator()() {
32583272PDNode *patterns::UnsupportedBfloat16::operator ()() {
32593273 auto *prev_op = pattern->NewNode (prev_op_repr ())->assert_is_op ();
32603274 prev_op->assert_more ([&](Node *node) {
3261- return node->Op ()->HasAttr (" mkldnn_data_type" ) == false ;
3275+ return node->Op ()->HasAttr (" mkldnn_data_type" ) == false &&
3276+ node->Op ()->HasAttr (" onednn_data_type" ) == false ;
32623277 });
32633278 auto *prev_out = pattern->NewNode (prev_out_repr ())->AsOutput ();
32643279
32653280 auto *op = pattern->NewNode (op_repr ())->assert_is_op ();
32663281 op->assert_more ([&](Node *node) {
32673282 return node->Op ()->GetAttrIfExists <std::string>(" mkldnn_data_type" ) ==
3268- " bfloat16" ;
3283+ " bfloat16" ||
3284+ node->Op ()->GetAttrIfExists <std::string>(" onednn_data_type" ) ==
3285+ " bfloat16" ;
32693286 });
32703287 prev_op->LinksTo ({prev_out});
32713288 op->LinksFrom ({prev_out});
@@ -3276,7 +3293,9 @@ PDNode *patterns::Bloat16Ops::operator()() {
32763293 auto op = pattern->NewNode (op_repr ())->assert_is_op ();
32773294 op->assert_more ([&](Node *node) {
32783295 return node->Op ()->GetAttrIfExists <std::string>(" mkldnn_data_type" ) ==
3279- " bfloat16" ;
3296+ " bfloat16" ||
3297+ node->Op ()->GetAttrIfExists <std::string>(" onednn_data_type" ) ==
3298+ " bfloat16" ;
32803299 });
32813300 return op;
32823301}
@@ -3298,8 +3317,8 @@ PDNode *patterns::ONEDNNInPlace::operator()() {
32983317 auto next_op = pattern->NewNode (next_op_repr ())->assert_is_op ();
32993318 auto next_output = pattern->NewNode (next_op_out_repr ())->AsOutput ();
33003319
3301- // Check if op is MKL -DNN enabled
3302- possible_inplace_op->assert_op_attr (" use_mkldnn" , true );
3320+ // Check if op is ONE -DNN enabled
3321+ possible_inplace_op->assert_op_attr_or (" use_mkldnn" , " use_onednn " , true );
33033322
33043323 // linked structure
33053324 possible_inplace_op->LinksTo ({output});
0 commit comments