Skip to content

Commit

Permalink
Fix mixed precision bug (#49239)
Browse files Browse the repository at this point in the history
* [Release2.4] Revert python link prs (#48573)

* Revert "Fix mac link python (#48017)"

This reverts commit 3fa7a73.

* Revert "[Cherry-pick] Fix python link error (#47811)"

This reverts commit ff642c6.

* Update config.go

* fix mixed precision inference

Co-authored-by: Chen Weihang <chenweihang@baidu.com>
  • Loading branch information
yuanlehome and chenwhql authored Dec 22, 2022
1 parent 612bdb1 commit 11c7f57
Showing 1 changed file with 33 additions and 0 deletions.
33 changes: 33 additions & 0 deletions paddle/fluid/framework/ir/auto_mixed_precision_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,20 @@ void AutoMixedPrecisionPass::UpdateOpPrecision() const {
vars_should_not_low_precision.insert(in_var_node->Var()->Name());
}
}

// when op_1 only support cpu kernel. if op_2's intput var is op_1's
// output var, then op_2 should not run half.
if (GetOpOriginalType(op_type) != "feed" &&
!GpuKernelSupportPrecision(GetOpOriginalType(op_type),
phi::DataType::FLOAT32)) {
for (auto* out_var_node : op_node->outputs) {
CHECK_EQ(out_var_node->IsVar(), true);
if (out_var_node->Var()->Persistable()) continue;
if (!VarNodeHasDtype(out_var_node)) continue;

vars_should_not_low_precision.insert(out_var_node->Var()->Name());
}
}
}
}
};
Expand All @@ -449,6 +463,25 @@ void AutoMixedPrecisionPass::UpdateOpPrecision() const {
for (auto* op_node : nodes) {
if (op_run_low_precision_.count(op_node->Op()->Type()) == 0) continue;

for (auto* in_var_node : op_node->inputs) {
CHECK_EQ(in_var_node->IsVar(), true);
if (!VarNodeHasDtype(in_var_node)) continue;

auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()];
if (real_in_var_node->Var()->Persistable()) continue;

if (vars_should_not_low_precision.count(
real_in_var_node->Var()->Name())) {
op_run_low_precision_.erase(op_node->Op()->Type());
precision_updated = true;
VLOG(4) << op_node->Op()->Type()
<< " should not run at low precision.";
break;
}
}

if (op_run_low_precision_.count(op_node->Op()->Type()) == 0) continue;

for (auto* out_var_node : op_node->outputs) {
CHECK_EQ(out_var_node->IsVar(), true);
if (!VarNodeHasDtype(out_var_node)) continue;
Expand Down

0 comments on commit 11c7f57

Please sign in to comment.