[AMP] fix bf16 amp training error (#54571) (#54643)
fix bf16 amp training error
cherry-pick of #54571
zhangting2020 authored Jun 15, 2023
1 parent 6b778b9 commit e93e48e
Showing 3 changed files with 34 additions and 41 deletions.
12 changes: 7 additions & 5 deletions paddle/fluid/eager/amp_auto_cast.h
@@ -69,15 +69,16 @@ inline paddle::Tensor AmpAutoCast(const std::string& input_name,
VLOG(6) << "AMP AmpAutoCasts:"
<< " input(" << input_name << ") dst_dtype("
<< phi::DataTypeToString(dst_dtype) << ").";

if ((op_name == "batch_norm" || op_name == "layer_norm" ||
op_name == "sync_batch_norm") &&
input_name != "X") {
return input;
}
if (dst_dtype == phi::DataType::FLOAT16) {
if (op_name == "run_program") {
return input;
}
if ((op_name == "batch_norm" || op_name == "layer_norm" ||
op_name == "sync_batch_norm") &&
input_name != "X") {
return input;
}
if ((op_name == "fused_attention" || op_name == "fused_feedforward")) {
if (input_name == "LnScale" || input_name == "LnBias" ||
input_name == "Ln2Scale" || input_name == "Ln2Bias" ||
@@ -86,6 +87,7 @@ inline paddle::Tensor AmpAutoCast(const std::string& input_name,
}
}
}

if (NeedCast(input, dst_dtype)) {
paddle::framework::AttributeMap cast_attrs = {
{"in_dtype", paddle::framework::TransToProtoVarType(input.dtype())},
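The hunk above moves the batch_norm / layer_norm / sync_batch_norm guard out of the dst_dtype == FLOAT16 branch, so inputs other than "X" keep their original dtype under any AMP target dtype, bfloat16 included; previously the guard only ran on the float16 path. A minimal sketch of that check as a standalone predicate (KeepOriginalDtype is an illustrative name, not a Paddle function):

#include <string>

// Returns true when AMP should leave this input uncast: for the norm ops,
// every input except "X" (e.g. scale, bias, statistics) keeps its dtype,
// regardless of whether the autocast target is float16 or bfloat16.
bool KeepOriginalDtype(const std::string& op_name,
                       const std::string& input_name) {
  return (op_name == "batch_norm" || op_name == "layer_norm" ||
          op_name == "sync_batch_norm") &&
         input_name != "X";
}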
52 changes: 21 additions & 31 deletions paddle/fluid/eager/amp_utils.h
@@ -26,19 +26,24 @@ static inline phi::DataType GetPromoteType(
kSlotSmallVectorSize>& amp_tensors_vector,
const phi::DataType& amp_dtype) {
auto dst_type = amp_dtype;
// only consider the dtype of input(X).
if (op_name == "batch_norm" || op_name == "layer_norm" ||
op_name == "sync_batch_norm" ||
op_name == "moving_average_abs_max_scale") {
if (amp_tensors_vector[0][0].dtype() == phi::DataType::FLOAT32) {
dst_type = phi::DataType::FLOAT32;
}
return dst_type;
}

if (egr::Controller::Instance().GetCurrentTracer()->GetAmpDtype() ==
"float16") {
if (op_name == "batch_norm" || op_name == "layer_norm" ||
op_name == "sync_batch_norm") {
if (amp_tensors_vector[0][0].dtype() == phi::DataType::FLOAT32) {
dst_type = phi::DataType::FLOAT32;
}
} else if (op_name == "fused_attention") {
if (op_name == "fused_attention") {
for (size_t i = 0; i < amp_tensors_vector.size(); i++) {
if (i != 3 || i != 4 || i != 9 || i != 10) {
if (amp_tensors_vector[i][0].dtype() == phi::DataType::FLOAT32) {
dst_type = phi::DataType::FLOAT32;
break;
return dst_type;
}
}
}
@@ -47,37 +52,22 @@ static inline phi::DataType GetPromoteType(
if (i != 7 || i != 8 || i != 9 || i != 10) {
if (amp_tensors_vector[i][0].dtype() == phi::DataType::FLOAT32) {
dst_type = phi::DataType::FLOAT32;
break;
}
}
}
} else {
for (const auto& tensors : amp_tensors_vector) {
for (const auto& tensor : tensors) {
if (tensor.dtype() == phi::DataType::FLOAT32) {
dst_type = tensor.dtype();
break;
return dst_type;
}
}
}
}
} else {
for (const auto& tensors : amp_tensors_vector) {
for (const auto& tensor : tensors) {
if (tensor.dtype() == phi::DataType::FLOAT32) {
dst_type = tensor.dtype();
break;
}
}
}
}
// NOTE(juncai): moving_average_abs_max_scale only consider the dtype of
// input(X)
if (op_name == "moving_average_abs_max_scale") {
if (amp_tensors_vector[0][0].dtype() == phi::DataType::FLOAT16) {
dst_type = phi::DataType::FLOAT16;

for (const auto& tensors : amp_tensors_vector) {
for (const auto& tensor : tensors) {
if (tensor.dtype() == phi::DataType::FLOAT32) {
dst_type = tensor.dtype();
break;
}
}
}

return dst_type;
}

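In GetPromoteType, the ops that should be decided by input X alone (the three norm ops plus moving_average_abs_max_scale) now return early before any float16-specific handling, and the generic fallback simply promotes to float32 when any input tensor is float32. A condensed, self-contained sketch of that rule over plain dtype lists; the enum and function name are illustrative, and the fused_attention / fused_feedforward special cases kept on the float16 path are omitted here:

#include <string>
#include <vector>

enum class DType { FLOAT32, FLOAT16, BFLOAT16 };

// amp_dtype is the configured autocast dtype (float16 or bfloat16);
// inputs holds the dtype of the first tensor in each input slot.
DType PromoteTypeSketch(const std::string& op_name,
                        const std::vector<DType>& inputs,
                        DType amp_dtype) {
  DType dst = amp_dtype;
  // Norm-like ops: only input X (slot 0) decides.
  if (op_name == "batch_norm" || op_name == "layer_norm" ||
      op_name == "sync_batch_norm" ||
      op_name == "moving_average_abs_max_scale") {
    if (!inputs.empty() && inputs[0] == DType::FLOAT32) {
      dst = DType::FLOAT32;
    }
    return dst;
  }
  // Everything else: any float32 input promotes the whole op to float32,
  // whether the autocast dtype is float16 or bfloat16.
  for (DType d : inputs) {
    if (d == DType::FLOAT32) {
      dst = DType::FLOAT32;
      break;
    }
  }
  return dst;
}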
11 changes: 6 additions & 5 deletions paddle/fluid/eager/eager_amp_auto_cast.h
@@ -89,15 +89,16 @@ inline paddle::Tensor EagerAmpAutoCast(const std::string& input_name,
VLOG(6) << "AMP AmpAutoCasts:"
<< " input(" << egr::EagerUtils::TensorStr(input) << " to dst_dtype("
<< phi::DataTypeToString(dst_dtype) << ").";
if ((op_name == "batch_norm" || op_name == "layer_norm" ||
op_name == "sync_batch_norm") &&
input_name != "x") {
return input;
}

if (dst_dtype == phi::DataType::FLOAT16) {
if (op_name == "run_program") {
return input;
}
if ((op_name == "batch_norm" || op_name == "layer_norm" ||
op_name == "sync_batch_norm") &&
input_name != "x") {
return input;
}
if ((op_name == "fused_attention" || op_name == "fused_feedforward")) {
if (input_name == "LnScale" || input_name == "LnBias" ||
input_name == "Ln2Scale" || input_name == "Ln2Bias" ||
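eager_amp_auto_cast.h receives the same relocation for the eager API, where the first input is named "x" in lower case. Taken together with the run_program guard that stays on the float16 path, the cast decision after this commit can be condensed into a small runnable demo; all names below are hypothetical stand-ins (ShouldCast approximates the guards plus a simplified NeedCast), not Paddle API:

#include <iostream>
#include <string>

enum class DType { FLOAT32, FLOAT16, BFLOAT16 };

bool ShouldCast(const std::string& op_name, const std::string& input_name,
                DType input_dtype, DType dst_dtype) {
  // Norm-op guard: applies for any AMP dtype after this commit.
  if ((op_name == "batch_norm" || op_name == "layer_norm" ||
       op_name == "sync_batch_norm") &&
      input_name != "x") {
    return false;
  }
  // run_program is still skipped only on the float16 path, as in the diff.
  if (dst_dtype == DType::FLOAT16 && op_name == "run_program") {
    return false;
  }
  return input_dtype != dst_dtype;  // crude stand-in for NeedCast
}

int main() {
  // A float32 scale/bias-style input of batch_norm stays uncast under bf16 AMP.
  std::cout << ShouldCast("batch_norm", "scale", DType::FLOAT32,
                          DType::BFLOAT16)
            << "\n";  // 0
  // The "x" input itself is cast to bfloat16.
  std::cout << ShouldCast("batch_norm", "x", DType::FLOAT32, DType::BFLOAT16)
            << "\n";  // 1
  return 0;
}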
