From 9c1efa2c93558da3270e63c4d70c189eae78d70a Mon Sep 17 00:00:00 2001 From: 0x45f Date: Thu, 26 May 2022 15:19:05 +0000 Subject: [PATCH] Fix cond_block_grad error when handle no need grad vras --- .../controlflow/conditional_block_op.cc | 21 ++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/operators/controlflow/conditional_block_op.cc b/paddle/fluid/operators/controlflow/conditional_block_op.cc index fd06e33a6bb6e..69e173eb57468 100644 --- a/paddle/fluid/operators/controlflow/conditional_block_op.cc +++ b/paddle/fluid/operators/controlflow/conditional_block_op.cc @@ -153,7 +153,7 @@ class ConditionalBlockGradOp : public ConditionalOp { /* keep_kid_scopes */ false); AssignLocalGradientToParentScope(dev_place, cur_scope, scope, - inside_grads, outside_grads); + inside_grads, outside_grads, inputs); return; } @@ -165,27 +165,34 @@ class ConditionalBlockGradOp : public ConditionalOp { const platform::Place &place, const framework::Scope &cur_scope, const framework::Scope &parent_scope, const std::vector &inside_grads, - const std::vector &outside_grads) const { + const std::vector &outside_grads, + const std::vector &inputs) const { + std::vector assign_zero_outside_grads; + std::vector assign_zero_inputs; for (size_t i = 0; i < outside_grads.size(); ++i) { const std::string &outside_grad_name = outside_grads[i]; const std::string &inside_grad_name = inside_grads[i]; VLOG(4) << "inside_grad_name = " << inside_grad_name << ", outside_grad_name = " << outside_grad_name; - framework::Variable *inside_var = - cur_scope.FindLocalVar(inside_grad_name); - if (inside_var == nullptr) { - continue; - } framework::Variable *outside_var = parent_scope.FindVar(outside_grad_name); if (outside_var == nullptr) { continue; } + framework::Variable *inside_var = + cur_scope.FindLocalVar(inside_grad_name); + if (inside_var == nullptr) { + assign_zero_outside_grads.push_back(outside_grad_name); + assign_zero_inputs.push_back(inputs[i]); + continue; + } platform::DeviceContext *dev_ctx = platform::DeviceContextPool::Instance().Get(place); framework::VisitVarType(*inside_var, AssignFunctor(outside_var, *dev_ctx)); } + AssignZeroToParentScope(place, parent_scope, assign_zero_inputs, + assign_zero_outside_grads); } void AssignZeroToParentScope(