Skip to content

Commit

Permalink
refine the code
Browse files Browse the repository at this point in the history
  • Loading branch information
juncaipeng committed Jun 18, 2021
1 parent bbe2e45 commit c76e017
Showing 1 changed file with 14 additions and 12 deletions.
26 changes: 14 additions & 12 deletions paddle/fluid/imperative/amp_auto_cast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ static inline std::shared_ptr<imperative::VarBase> CastToFP32(
}

static inline framework::proto::VarType::Type GetPromoteType(
const NameVarBaseMap& ins) {
const std::string& op_type, const NameVarBaseMap& ins) {
auto dst_type = framework::proto::VarType::FP16;
for (const auto& pair : ins) {
for (const auto& var : pair.second) {
Expand All @@ -151,6 +151,18 @@ static inline framework::proto::VarType::Type GetPromoteType(
}
}
}

// NOTE(juncai): moving_average_abs_max_scale only consider the
// dtype of input(X)
if (op_type == "moving_average_abs_max_scale") {
for (const auto& pair : ins) {
if (pair.first == "X" &&
pair.second.front()->DataType() == framework::proto::VarType::FP16) {
dst_type = framework::proto::VarType::FP16;
}
}
}

return dst_type;
}

Expand Down Expand Up @@ -183,17 +195,7 @@ NameVarBaseMap AutoCastInputs(const std::string& op_type,
}
return new_ins;
} else {
auto dst_type = GetPromoteType(ins);

if (op_type == "moving_average_abs_max_scale") {
for (const auto& pair : ins) {
if (pair.first == "X" &&
pair.second.front()->DataType() ==
framework::proto::VarType::FP16) {
dst_type = framework::proto::VarType::FP16;
}
}
}
auto dst_type = GetPromoteType(op_type, ins);

// NOTE(zhiqiu): if the op has op fp16 kernel, fall back to fp32.
if (dst_type == framework::proto::VarType::FP16 &&
Expand Down

1 comment on commit c76e017

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.