Skip to content

Commit

Permalink
Fix for Bfloat16 placement pass. (#43109)
Browse files Browse the repository at this point in the history
* Fix bfloat16 placement pass

* Make it nicer

* Fix leftovers

* Style
  • Loading branch information
tsocha authored Jun 2, 2022
1 parent 990c5e7 commit 030b23d
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions paddle/fluid/framework/ir/graph_pattern_detector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2631,8 +2631,10 @@ PDNode *patterns::Bfloat16Placement::operator()(
PDNode *patterns::OrphanedBfloat16::operator()() {
auto *prev_op = pattern->NewNode(prev_op_repr())->assert_is_op();
prev_op->assert_more([&](Node *node) {
return node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") ==
"float32";
bool data_type_is_missing = !node->Op()->HasAttr("mkldnn_data_type");
bool data_type_is_fp32 = node->Op()->GetAttrIfExists<std::string>(
"mkldnn_data_type") == "float32";
return data_type_is_missing || data_type_is_fp32;
});
auto *prev_out = pattern->NewNode(prev_out_repr())->AsOutput();

Expand All @@ -2645,8 +2647,10 @@ PDNode *patterns::OrphanedBfloat16::operator()() {

auto *next_op = pattern->NewNode(next_op_repr())->assert_is_op();
next_op->assert_more([&](Node *node) {
return node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") ==
"float32";
bool data_type_is_missing = !node->Op()->HasAttr("mkldnn_data_type");
bool data_type_is_fp32 = node->Op()->GetAttrIfExists<std::string>(
"mkldnn_data_type") == "float32";
return data_type_is_missing || data_type_is_fp32;
});

prev_op->LinksTo({prev_out});
Expand Down

0 comments on commit 030b23d

Please sign in to comment.