change staticRNN to while #48213
Changes from all commits
@@ -305,6 +305,7 @@ class WhileOp : public framework::OperatorBase {
        cond_data = GetCondData(
            scope.FindVar(Input(kCondition))->Get<phi::DenseTensor>());
      }

      scope.DeleteScope(&current_scope);
    }
  }

@@ -367,6 +368,7 @@ class WhileGradOp : public framework::OperatorBase {

    auto *block = Attr<framework::BlockDesc *>(kStepBlock);
    auto *program = block->Program();
    auto *parent_block = block->ParentBlock();

    auto &skip_vars = Attr<std::vector<std::string>>(kSkipEagerDeletionVars);
    VLOG(2) << GetSkipEagerDeletionVarsDebugString(skip_vars);

@@ -428,6 +430,35 @@ class WhileGradOp : public framework::OperatorBase {
        continue;
      }

+     if (cur_scope_iter == step_scopes->rbegin()) {
+       auto &og_outside = *scope.FindVar(outside_og_name);
+       if (og_outside.IsType<phi::DenseTensor>() &&
+           !og_outside.GetMutable<phi::DenseTensor>()->IsInitialized()) {
+         auto *var_desc = parent_block->FindVarRecursive(outside_og_name);
+         PADDLE_ENFORCE_NOT_NULL(var_desc,
+                                 platform::errors::PreconditionNotMet(
+                                     "Var `%s` is not found in parent "
+                                     "block, can't fill constant.",
+                                     outside_og_name));
+         auto shape = var_desc->GetShape();
+         VLOG(8) << "Found uninitialized tensor " << outside_og_name
+                 << " in step 0, fill it with 0.0f. dims="
+                 << phi::make_ddim(shape);
+         framework::AttributeMap attrs;
+         attrs["dtype"] = var_desc->GetDataType();
+         attrs["shape"] = phi::vectorize<int>(phi::make_ddim(shape));
+         attrs["value"] = 0.0f;
+
+         auto var_name = outside_og_name;
+         auto zero_op =
+             framework::OpRegistry::CreateOp("fill_constant",
+                                             framework::VariableNameMap{},
+                                             {{"Out", {var_name}}},
+                                             attrs);
+         zero_op->Run(scope, dev_place);
+       }
+     }

      auto &og_outside = *scope.FindVar(outside_og_name);
      auto &og_inside = *cur_scope.Var(inside_og_name);
      if (og_outside.IsType<phi::DenseTensor>()) {

@@ -534,9 +565,10 @@ class WhileGradOp : public framework::OperatorBase {
      //   continue;
      // }

-     auto var_iter = std::find(outside_og_names.begin(),
-                               outside_og_names.end(),
-                               pg_ig_names[param_id]);
+     auto is_var_input_and_output =
+         std::find(outside_og_names.begin(),
+                   outside_og_names.end(),
+                   pg_ig_names[param_id]) != outside_og_names.end();
why change it?

Originally this obtained an iterator; now the iterator check is moved up front, so the membership test is evaluated once and stored as a boolean.

ok

      // zero gradient variable in step 0
      if (cur_scope_iter == step_scopes->rbegin()) {

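To make the refactor discussed in the comments above concrete: instead of keeping the std::find iterator and comparing it to end() at each later use, the membership test is evaluated once and stored as a boolean. A minimal, self-contained sketch; the helper name is illustrative and not part of the PR:

```cpp
#include <algorithm>
#include <string>
#include <vector>

// Evaluate "is this gradient name both an input and an output?" once,
// instead of carrying a std::find iterator around and comparing it to
// end() at every use site.
bool IsVarInputAndOutput(const std::vector<std::string> &outside_og_names,
                         const std::string &pg_ig_name) {
  return std::find(outside_og_names.begin(),
                   outside_og_names.end(),
                   pg_ig_name) != outside_og_names.end();
}
```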
@@ -555,8 +587,7 @@ class WhileGradOp : public framework::OperatorBase {
              inside_grad_name,
              framework::ToTypeName(var->Type())));

-     if ((var_iter == outside_og_names.end()) &&
-         var->IsType<phi::DenseTensor>()) {
+     if (!is_var_input_and_output && var->IsType<phi::DenseTensor>()) {
        auto &inside_tensor = var->Get<phi::DenseTensor>();
        framework::AttributeMap attrs;
        attrs["dtype"] =

@@ -575,10 +606,7 @@ class WhileGradOp : public framework::OperatorBase {
                               inside_tensor.lod());
        }
      }
-     auto var_outside = scope.FindVar(pg_ig_names[param_id]);
-     if ((var_iter == outside_og_names.end()) ||
-         ((var_iter != outside_og_names.end()) &&
-          var_outside->IsType<framework::LoDTensorArray>())) {
+     if (!is_var_input_and_output) {
        auto new_inside_name = cur_scope.Rename(inside_grad_name);
        auto sum_op = framework::OpRegistry::CreateOp(
            "sum",

@@ -587,6 +615,8 @@ class WhileGradOp : public framework::OperatorBase {
            framework::AttributeMap{{"use_mkldnn", {false}}});
        sum_op->Run(cur_scope, dev_place);
        cur_scope.Rename(new_inside_name, inside_grad_name);
+     } else {
+       ShareVariable(cur_scope, scope, pg_ig_names[param_id]);
      }
    }
    dev_ctx.Wait();

@@ -595,6 +625,29 @@ class WhileGradOp : public framework::OperatorBase {
      step_scopes->clear();
    }

+   void ShareVariable(const framework::Scope &source,
+                      const framework::Scope &dest,
+                      std::string name) const {
+     auto from_var = source.FindVar(name);
+     auto to_var = dest.FindVar(name);
+     if (from_var->IsType<phi::DenseTensor>()) {
+       if (from_var->Get<phi::DenseTensor>().IsInitialized()) {
+         to_var->GetMutable<phi::DenseTensor>()->ShareDataWith(
+             from_var->Get<phi::DenseTensor>());
+       }
+     } else if (from_var->IsType<framework::LoDTensorArray>()) {
+       auto from_arr = from_var->GetMutable<framework::LoDTensorArray>();
+       auto to_arr = to_var->GetMutable<framework::LoDTensorArray>();
+       to_arr->clear();
+       to_arr->resize(from_arr->size());
+       for (size_t i = 0; i < to_arr->size(); ++i) {
+         if (from_arr->at(i).IsInitialized()) {
+           to_arr->at(i).ShareDataWith(from_arr->at(i));
+         }
+       }
+     }
+   }

   private:
    mutable std::shared_ptr<framework::Executor> executor_{nullptr};
    mutable std::unique_ptr<framework::ExecutorPrepareContext> ctx_{nullptr};

@@ -646,6 +699,7 @@ class WhileGradOpMaker : public framework::SingleGradOpMaker<T> {
      block_ins.insert(o);
    }
    std::unordered_set<std::string> output_grads;

    for (const auto *op : grad_block->AllOps()) {
      for (auto &input_name : op->InputArgumentNames()) {
        // If the input of Op has been recorded or is generated by the forward

@@ -658,7 +712,6 @@ class WhileGradOpMaker : public framework::SingleGradOpMaker<T> {
           parent_block->FindVarRecursive(input_name) != nullptr)) {
        continue;
      }

      output_grads.insert(input_name);
    }
    for (auto &output_name : op->OutputArgumentNames()) {

Better to use SetConstant or Full API directly.

Logically this could be changed to full_like, but rewritten that way the logic would end up even more complex than it is now.
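For reference, a rough sketch of the SetConstant route suggested above, restricted to CPU and float for brevity. This is an illustration under assumptions, not the merged code: it presumes the phi::funcs::SetConstant helper from math_function.h can be used at this point, and the helper name ZeroFillCpuFloat is hypothetical. Covering every device and dtype along this path is part of the extra complexity the reply mentions, which is why the PR instead runs a fill_constant op and lets the regular operator dispatch handle dtype and placement.

```cpp
#include <cstdint>
#include <vector>

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/math_function.h"

// Hypothetical helper: resize an uninitialized DenseTensor to the shape
// taken from its VarDesc and fill it with zeros on CPU, using SetConstant
// instead of creating and running a fill_constant op.
void ZeroFillCpuFloat(const std::vector<int64_t> &shape,
                      const phi::CPUContext &dev_ctx,
                      phi::DenseTensor *tensor) {
  tensor->Resize(phi::make_ddim(shape));
  dev_ctx.Alloc<float>(tensor);  // allocate float storage via the context
  phi::funcs::SetConstant<phi::CPUContext, float> set_zero;
  set_zero(dev_ctx, tensor, 0.0f);
}
```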