[CustomOp] Fix double grad var judging #41072

Merged
43 changes: 28 additions & 15 deletions paddle/fluid/framework/custom_operator.cc
@@ -62,11 +62,6 @@ static T* DynLoad(void* handle, std::string name) {
   return func;
 }
 
-inline static bool IsGradVar(const std::string& var_name) {
-  std::string suffix = kGradVarSuffix;
-  return var_name.rfind(suffix) != std::string::npos;
-}
-
 inline static bool IsDuplicableVar(const std::string& var_name) {
   std::string suffix = kTensorVectorSuffix;
   return var_name.rfind(suffix) != std::string::npos;
@@ -77,6 +72,17 @@ inline static std::string NoGrad(const std::string& var_name) {
   return var_name.substr(0, var_name.size() - kGradVarSuffixSize);
 }
 
+inline static bool IsGradVar(const std::string& var_name, bool is_double_grad) {
+  std::string suffix = kGradVarSuffix;
+  if (!is_double_grad) {
+    return var_name.rfind(suffix) != std::string::npos;
+  } else {
+    // For double grad cases, X@GRAD is not a grad var but X@GRAD@GRAD is,
+    // so we remove one @GRAD suffix before checking.
+    return NoGrad(var_name).rfind(suffix) != std::string::npos;
+  }
+}
+
 inline static bool IsMemberOf(const std::vector<std::string>& vec,
                               const std::string& name) {
   return std::find(vec.cbegin(), vec.cend(), name) != vec.cend();
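
Why the second argument matters: a minimal standalone sketch of the predicate's behavior (illustration only, not the library source; it assumes Paddle's naming convention kGradVarSuffix == "@GRAD"):

#include <cassert>
#include <string>

// Assumed suffix constants, mirroring Paddle's @GRAD naming convention.
static const std::string kGradVarSuffix = "@GRAD";
static const size_t kGradVarSuffixSize = kGradVarSuffix.size();

// Strips one @GRAD suffix, e.g. "X@GRAD@GRAD" -> "X@GRAD".
static std::string NoGrad(const std::string& var_name) {
  return var_name.substr(0, var_name.size() - kGradVarSuffixSize);
}

static bool IsGradVar(const std::string& var_name, bool is_double_grad) {
  if (!is_double_grad) {
    return var_name.rfind(kGradVarSuffix) != std::string::npos;
  }
  // Double grad: strip one @GRAD first, so "X@GRAD" -> "X" (not a grad var)
  // while "X@GRAD@GRAD" -> "X@GRAD" (still a grad var).
  return NoGrad(var_name).rfind(kGradVarSuffix) != std::string::npos;
}

int main() {
  assert(IsGradVar("X@GRAD", /*is_double_grad=*/false));
  assert(!IsGradVar("X@GRAD", /*is_double_grad=*/true));
  assert(IsGradVar("X@GRAD@GRAD", /*is_double_grad=*/true));
  return 0;
}

The removed single-argument IsGradVar returned true for any name containing @GRAD, so a double-grad op's first-order names such as X@GRAD were misclassified as grad vars.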
@@ -493,11 +499,12 @@ class CustomGradOpMaker<OpDesc> : public SingleGradOpMaker<OpDesc> {
       std::unordered_map<std::string, std::string>* grad_to_var,
       const std::vector<BlockDesc*>& grad_block, const std::string& name,
       const std::vector<std::string>& inputs,
-      const std::vector<std::string>& outputs)
+      const std::vector<std::string>& outputs, bool is_double_grad)
       : SingleGradOpMaker<OpDesc>(fwd_op, no_grad_set, grad_to_var, grad_block),
         name_(name),
         inputs_(inputs),
-        outputs_(outputs) {}
+        outputs_(outputs),
+        is_double_grad_(is_double_grad) {}
 
  protected:
   void Apply(GradOpPtr<OpDesc> grad_op) const override {
@@ -508,7 +515,7 @@ class CustomGradOpMaker<OpDesc> : public SingleGradOpMaker<OpDesc> {
 
     for (auto& in_name : inputs_) {
       VLOG(3) << "Custom Operator: GradOpDescMaker - input: " << in_name;
-      if (!detail::IsGradVar(in_name)) {
+      if (!detail::IsGradVar(in_name, is_double_grad_)) {
         if (detail::IsMemberOf(fwd_op_inputs, in_name)) {
           grad_op->SetInput(in_name, this->Input(in_name));
         } else if (detail::IsMemberOf(fwd_op_outputs, in_name)) {
@@ -540,6 +547,7 @@ class CustomGradOpMaker<OpDesc> : public SingleGradOpMaker<OpDesc> {
   std::string name_;
   std::vector<std::string> inputs_;
   std::vector<std::string> outputs_;
+  bool is_double_grad_{false};
 };
 
 template <>
@@ -553,12 +561,13 @@ class CustomGradOpMaker<imperative::OpBase>
       const AttributeMap& attrs,
       const std::map<std::string, std::string>& inplace_map,
       const std::string& name, const std::vector<std::string>& inputs,
-      const std::vector<std::string>& outputs)
+      const std::vector<std::string>& outputs, bool is_double_grad)
       : SingleGradOpMaker<imperative::OpBase>(
             type, var_base_map_in, var_base_map_out, attrs, inplace_map),
         name_(name),
         inputs_(inputs),
-        outputs_(outputs) {}
+        outputs_(outputs),
+        is_double_grad_(is_double_grad) {}
 
  protected:
   // TODO(chenweihang): The code is duplicated with the previous one, because
@@ -574,7 +583,7 @@ class CustomGradOpMaker<imperative::OpBase>
 
     for (auto& in_name : inputs_) {
       VLOG(3) << "Custom Operator: GradOpBaseMaker - input: " << in_name;
-      if (!detail::IsGradVar(in_name)) {
+      if (!detail::IsGradVar(in_name, is_double_grad_)) {
         if (detail::IsMemberOf(fwd_op_inputs, in_name)) {
           grad_op->SetInput(in_name, this->Input(in_name));
         } else if (detail::IsMemberOf(fwd_op_outputs, in_name)) {
@@ -600,6 +609,7 @@ class CustomGradOpMaker<imperative::OpBase>
   std::string name_;
   std::vector<std::string> inputs_;
   std::vector<std::string> outputs_;
+  bool is_double_grad_{false};
 };
 
 //////////// Operator and Kernel Register //////////////
@@ -832,21 +842,24 @@ void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos,
     VLOG(3) << "Custom Operator: backward, op outputs: "
             << string::join_strings(grad_op_outputs, ',');
 
+    bool is_double_grad = (i == 2);
+
     // GradOpDescMaker
-    info.grad_op_maker_ = [grad_op_name, grad_op_inputs, grad_op_outputs](
+    info.grad_op_maker_ = [grad_op_name, grad_op_inputs, grad_op_outputs,
+                           is_double_grad](
         const OpDesc& fwd_op,
         const std::unordered_set<std::string>& no_grad_set,
         std::unordered_map<std::string, std::string>* grad_to_var,
         const std::vector<BlockDesc*>& grad_block) {
       CustomGradOpMaker<paddle::framework::OpDesc> maker(
           fwd_op, no_grad_set, grad_to_var, grad_block, grad_op_name,
-          grad_op_inputs, grad_op_outputs);
+          grad_op_inputs, grad_op_outputs, is_double_grad);
       return maker();
     };
 
     // GradOpBaseMaker
     info.dygraph_grad_op_maker_ = [grad_op_name, grad_op_inputs,
-                                   grad_op_outputs](
+                                   grad_op_outputs, is_double_grad](
         const std::string& type,
         const imperative::NameVarBaseMap& var_base_map_in,
         const imperative::NameVarBaseMap& var_base_map_out,
@@ -855,7 +868,7 @@ void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos,
         const std::map<std::string, std::string>& inplace_map) {
       CustomGradOpMaker<paddle::imperative::OpBase> maker(
           type, var_base_map_in, var_base_map_out, attrs, inplace_map,
-          grad_op_name, grad_op_inputs, grad_op_outputs);
+          grad_op_name, grad_op_inputs, grad_op_outputs, is_double_grad);
       maker.SetDygraphDefaultAttrsMap(default_attrs);
       return maker();
     };
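
For context, i indexes the operator's OpMetaInfo vector: entry 0 is the forward op, entry 1 the grad op, and entry 2 the double-grad op, so is_double_grad = (i == 2) flags the third entry. A hedged sketch of how those three entries are typically declared on the user side with Paddle's custom-op macros (op and kernel names here are hypothetical, kernel bodies omitted; comments show the derived variable names):

// Hypothetical custom op with first- and second-order gradients.
PD_BUILD_OP(custom_tanh)
    .Inputs({"X"})
    .Outputs({"Out"})
    .SetKernelFn(PD_KERNEL(TanhForward));

PD_BUILD_GRAD_OP(custom_tanh)
    .Inputs({"Out", paddle::Grad("Out")})              // "Out", "Out@GRAD"
    .Outputs({paddle::Grad("X")})                      // "X@GRAD"
    .SetKernelFn(PD_KERNEL(TanhBackward));

PD_BUILD_DOUBLE_GRAD_OP(custom_tanh)
    .Inputs({"Out", paddle::Grad(paddle::Grad("X"))})  // "Out", "X@GRAD@GRAD"
    .Outputs({paddle::Grad(paddle::Grad("Out"))})      // "Out@GRAD@GRAD"
    .SetKernelFn(PD_KERNEL(TanhDoubleBackward));

Without the flag, IsGradVar matched any name containing @GRAD, so the double-grad makers treated first-order names like X@GRAD as grad vars; with it, only doubled suffixes such as X@GRAD@GRAD match.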