diff --git a/paddle/fluid/framework/custom_operator.cc b/paddle/fluid/framework/custom_operator.cc index 927a8c4754c1a..6f74fbe5f8f8b 100644 --- a/paddle/fluid/framework/custom_operator.cc +++ b/paddle/fluid/framework/custom_operator.cc @@ -62,11 +62,6 @@ static T* DynLoad(void* handle, std::string name) { return func; } -inline static bool IsGradVar(const std::string& var_name) { - std::string suffix = kGradVarSuffix; - return var_name.rfind(suffix) != std::string::npos; -} - inline static bool IsDuplicableVar(const std::string& var_name) { std::string suffix = kTensorVectorSuffix; return var_name.rfind(suffix) != std::string::npos; @@ -77,6 +72,17 @@ inline static std::string NoGrad(const std::string& var_name) { return var_name.substr(0, var_name.size() - kGradVarSuffixSize); } +inline static bool IsGradVar(const std::string& var_name, bool is_double_grad) { + std::string suffix = kGradVarSuffix; + if (!is_double_grad) { + return var_name.rfind(suffix) != std::string::npos; + } else { + // for double grad cases, the X@GRAD is not a grad var, X@GRAD@GRAD is a + // grad var, here we remove a @GRAD suffix + return NoGrad(var_name).rfind(suffix) != std::string::npos; + } +} + inline static bool IsMemberOf(const std::vector& vec, const std::string& name) { return std::find(vec.cbegin(), vec.cend(), name) != vec.cend(); @@ -493,11 +499,12 @@ class CustomGradOpMaker : public SingleGradOpMaker { std::unordered_map* grad_to_var, const std::vector& grad_block, const std::string& name, const std::vector& inputs, - const std::vector& outputs) + const std::vector& outputs, bool is_double_grad) : SingleGradOpMaker(fwd_op, no_grad_set, grad_to_var, grad_block), name_(name), inputs_(inputs), - outputs_(outputs) {} + outputs_(outputs), + is_double_grad_(is_double_grad) {} protected: void Apply(GradOpPtr grad_op) const override { @@ -508,7 +515,7 @@ class CustomGradOpMaker : public SingleGradOpMaker { for (auto& in_name : inputs_) { VLOG(3) << "Custom Operator: GradOpDescMaker - input: " << in_name; - if (!detail::IsGradVar(in_name)) { + if (!detail::IsGradVar(in_name, is_double_grad_)) { if (detail::IsMemberOf(fwd_op_inputs, in_name)) { grad_op->SetInput(in_name, this->Input(in_name)); } else if (detail::IsMemberOf(fwd_op_outputs, in_name)) { @@ -540,6 +547,7 @@ class CustomGradOpMaker : public SingleGradOpMaker { std::string name_; std::vector inputs_; std::vector outputs_; + bool is_double_grad_{false}; }; template <> @@ -553,12 +561,13 @@ class CustomGradOpMaker const AttributeMap& attrs, const std::map& inplace_map, const std::string& name, const std::vector& inputs, - const std::vector& outputs) + const std::vector& outputs, bool is_double_grad) : SingleGradOpMaker( type, var_base_map_in, var_base_map_out, attrs, inplace_map), name_(name), inputs_(inputs), - outputs_(outputs) {} + outputs_(outputs), + is_double_grad_(is_double_grad) {} protected: // TODO(chenweihang): The code is duplicated with the previous one, because @@ -574,7 +583,7 @@ class CustomGradOpMaker for (auto& in_name : inputs_) { VLOG(3) << "Custom Operator: GradOpBaseMaker - input: " << in_name; - if (!detail::IsGradVar(in_name)) { + if (!detail::IsGradVar(in_name, is_double_grad_)) { if (detail::IsMemberOf(fwd_op_inputs, in_name)) { grad_op->SetInput(in_name, this->Input(in_name)); } else if (detail::IsMemberOf(fwd_op_outputs, in_name)) { @@ -600,6 +609,7 @@ class CustomGradOpMaker std::string name_; std::vector inputs_; std::vector outputs_; + bool is_double_grad_{false}; }; //////////// Operator and Kernel Register ////////////// @@ -832,21 +842,24 @@ void RegisterOperatorWithMetaInfo(const std::vector& op_meta_infos, VLOG(3) << "Custom Operator: backward, op outputs: " << string::join_strings(grad_op_outputs, ','); + bool is_double_grad = (i == 2); + // GradOpDescMaker - info.grad_op_maker_ = [grad_op_name, grad_op_inputs, grad_op_outputs]( + info.grad_op_maker_ = [grad_op_name, grad_op_inputs, grad_op_outputs, + is_double_grad]( const OpDesc& fwd_op, const std::unordered_set& no_grad_set, std::unordered_map* grad_to_var, const std::vector& grad_block) { CustomGradOpMaker maker( fwd_op, no_grad_set, grad_to_var, grad_block, grad_op_name, - grad_op_inputs, grad_op_outputs); + grad_op_inputs, grad_op_outputs, is_double_grad); return maker(); }; // GradOpBaseMaker info.dygraph_grad_op_maker_ = [grad_op_name, grad_op_inputs, - grad_op_outputs]( + grad_op_outputs, is_double_grad]( const std::string& type, const imperative::NameVarBaseMap& var_base_map_in, const imperative::NameVarBaseMap& var_base_map_out, @@ -855,7 +868,7 @@ void RegisterOperatorWithMetaInfo(const std::vector& op_meta_infos, const std::map& inplace_map) { CustomGradOpMaker maker( type, var_base_map_in, var_base_map_out, attrs, inplace_map, - grad_op_name, grad_op_inputs, grad_op_outputs); + grad_op_name, grad_op_inputs, grad_op_outputs, is_double_grad); maker.SetDygraphDefaultAttrsMap(default_attrs); return maker(); };