Skip to content

Commit

Permalink
Fix cond_block_grad error when handle no need grad vras
Browse files Browse the repository at this point in the history
  • Loading branch information
0x45f committed May 26, 2022
1 parent 6af32a7 commit 9c1efa2
Showing 1 changed file with 14 additions and 7 deletions.
21 changes: 14 additions & 7 deletions paddle/fluid/operators/controlflow/conditional_block_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ class ConditionalBlockGradOp : public ConditionalOp {
/* keep_kid_scopes */ false);

AssignLocalGradientToParentScope(dev_place, cur_scope, scope,
inside_grads, outside_grads);
inside_grads, outside_grads, inputs);
return;
}

Expand All @@ -165,27 +165,34 @@ class ConditionalBlockGradOp : public ConditionalOp {
const platform::Place &place, const framework::Scope &cur_scope,
const framework::Scope &parent_scope,
const std::vector<std::string> &inside_grads,
const std::vector<std::string> &outside_grads) const {
const std::vector<std::string> &outside_grads,
const std::vector<std::string> &inputs) const {
std::vector<std::string> assign_zero_outside_grads;
std::vector<std::string> assign_zero_inputs;
for (size_t i = 0; i < outside_grads.size(); ++i) {
const std::string &outside_grad_name = outside_grads[i];
const std::string &inside_grad_name = inside_grads[i];
VLOG(4) << "inside_grad_name = " << inside_grad_name
<< ", outside_grad_name = " << outside_grad_name;
framework::Variable *inside_var =
cur_scope.FindLocalVar(inside_grad_name);
if (inside_var == nullptr) {
continue;
}
framework::Variable *outside_var =
parent_scope.FindVar(outside_grad_name);
if (outside_var == nullptr) {
continue;
}
framework::Variable *inside_var =
cur_scope.FindLocalVar(inside_grad_name);
if (inside_var == nullptr) {
assign_zero_outside_grads.push_back(outside_grad_name);
assign_zero_inputs.push_back(inputs[i]);
continue;
}
platform::DeviceContext *dev_ctx =
platform::DeviceContextPool::Instance().Get(place);
framework::VisitVarType(*inside_var,
AssignFunctor(outside_var, *dev_ctx));
}
AssignZeroToParentScope(place, parent_scope, assign_zero_inputs,
assign_zero_outside_grads);
}

void AssignZeroToParentScope(
Expand Down

0 comments on commit 9c1efa2

Please sign in to comment.