change staticRNN to while #48213

Merged
Changes from all commits
32 commits
04bafba
change staticRNN to while
phlrain Nov 21, 2022
06590b2
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Nov 21, 2022
8790ec4
update code
phlrain Nov 21, 2022
5de707c
fix rnn bug
phlrain Nov 23, 2022
e48cb85
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Nov 23, 2022
c8f184c
update
phlrain Nov 23, 2022
53b0eb6
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Nov 23, 2022
25968b4
fix _find_op_path_ bugs in append_backward.
2742195759 Nov 23, 2022
f380f10
Merge pull request #5 from 2742195759/pr_48213
phlrain Nov 24, 2022
715fb46
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Nov 24, 2022
caf97ed
polish code
phlrain Nov 24, 2022
f6877e1
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Nov 24, 2022
61b0f4f
revert op proto
phlrain Nov 24, 2022
2adafc6
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Nov 24, 2022
360735d
update
phlrain Nov 29, 2022
e0df6ba
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Dec 1, 2022
5bb85e8
update while
phlrain Dec 1, 2022
2471bac
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Dec 1, 2022
5e03fda
format
phlrain Dec 1, 2022
2ed574c
revert test while loop op
phlrain Dec 2, 2022
436281a
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Dec 2, 2022
11e5767
fix create array
phlrain Dec 2, 2022
5f40c97
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Dec 2, 2022
2cedb0d
fix windows error
phlrain Dec 8, 2022
75bc922
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Dec 8, 2022
f111e38
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Dec 11, 2022
a4bb359
fix bug
phlrain Dec 11, 2022
149ba2e
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Dec 13, 2022
24f24ca
update
phlrain Dec 13, 2022
3213c42
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Dec 14, 2022
acf64d4
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Dec 14, 2022
f9967c6
fix array write bug
phlrain Dec 14, 2022
1 change: 1 addition & 0 deletions .gitignore
@@ -84,3 +84,4 @@ paddle/fluid/pybind/tmp_eager_op_function_impl.h
paddle/fluid/pybind/eager_op_function_impl.h
paddle/fluid/pybind/eager_op_function_impl.h
paddle/fluid/pybind/op_function_impl.h
paddle/fluid/pybind/*final_state_op_function_impl.h
73 changes: 63 additions & 10 deletions paddle/fluid/operators/controlflow/while_op.cc
@@ -305,6 +305,7 @@ class WhileOp : public framework::OperatorBase {
cond_data = GetCondData(
scope.FindVar(Input(kCondition))->Get<phi::DenseTensor>());
}

scope.DeleteScope(&current_scope);
}
}
@@ -367,6 +368,7 @@ class WhileGradOp : public framework::OperatorBase {

auto *block = Attr<framework::BlockDesc *>(kStepBlock);
auto *program = block->Program();
auto *parent_block = block->ParentBlock();

auto &skip_vars = Attr<std::vector<std::string>>(kSkipEagerDeletionVars);
VLOG(2) << GetSkipEagerDeletionVarsDebugString(skip_vars);
@@ -428,6 +430,35 @@ class WhileGradOp : public framework::OperatorBase {
continue;
}

if (cur_scope_iter == step_scopes->rbegin()) {
auto &og_outside = *scope.FindVar(outside_og_name);
if (og_outside.IsType<phi::DenseTensor>() &&
!og_outside.GetMutable<phi::DenseTensor>()->IsInitialized()) {
auto *var_desc = parent_block->FindVarRecursive(outside_og_name);
PADDLE_ENFORCE_NOT_NULL(var_desc,
platform::errors::PreconditionNotMet(
"Var `%s` is not found in parent "
"block, can't fill constant.",
outside_og_name));
auto shape = var_desc->GetShape();
VLOG(8) << "Found uninitialized tensor " << outside_og_name
<< " in step 0, fill it with 0.0f. dims="
<< phi::make_ddim(shape);
framework::AttributeMap attrs;
attrs["dtype"] = var_desc->GetDataType();
attrs["shape"] = phi::vectorize<int>(phi::make_ddim(shape));
attrs["value"] = 0.0f;

auto var_name = outside_og_name;
auto zero_op =
framework::OpRegistry::CreateOp("fill_constant",
framework::VariableNameMap{},
{{"Out", {var_name}}},
attrs);
zero_op->Run(scope, dev_place);
Comment on lines +443 to +458
Contributor:

Better to use SetConstant or Full API directly.

Collaborator (Author):

Logically this could be changed to full_like, but:

  1. Inside the while op there is no compile-time data type T, so a visitor would be needed to dispatch on the runtime dtype.
  2. The full_like API returns the dense tensor as its return value; pushing that result back into the scope would take extra steps.
    Changed that way, the logic would end up more complex than the current version.
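
For context, a minimal sketch of the dtype dispatch the author is describing. The visitor name ZeroFillVisitor is hypothetical, the snippet assumes a CPU context for brevity, and it only relies on the phi::funcs::SetConstant functor and framework::VisitDataType helper; it is not code from this PR:

// Hypothetical sketch, not part of this PR: zero-filling through
// phi::funcs::SetConstant requires a compile-time element type T, so the
// runtime dtype has to be dispatched through a visitor.
#include "paddle/fluid/framework/data_type.h"        // framework::VisitDataType
#include "paddle/phi/kernels/funcs/math_function.h"  // phi::funcs::SetConstant

struct ZeroFillVisitor {
  const phi::CPUContext& ctx_;  // assumes a CPU context for brevity
  phi::DenseTensor* tensor_;

  template <typename T>
  void apply() const {
    phi::funcs::SetConstant<phi::CPUContext, T> set_zero;
    set_zero(ctx_, tensor_, static_cast<T>(0));  // fill with typed zero
  }
};

// Usage (hypothetical): dispatch on the VarDesc's proto dtype.
//   framework::VisitDataType(var_desc->GetDataType(),
//                            ZeroFillVisitor{cpu_ctx, tensor});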

}
}

auto &og_outside = *scope.FindVar(outside_og_name);
auto &og_inside = *cur_scope.Var(inside_og_name);
if (og_outside.IsType<phi::DenseTensor>()) {
@@ -534,9 +565,10 @@ class WhileGradOp : public framework::OperatorBase {
// continue;
// }

auto var_iter = std::find(outside_og_names.begin(),
outside_og_names.end(),
pg_ig_names[param_id]);
auto is_var_input_and_output =
std::find(outside_og_names.begin(),
outside_og_names.end(),
pg_ig_names[param_id]) != outside_og_names.end();
Contributor:

why change it?

Collaborator (Author):

Previously it obtained an iterator; now the iterator comparison is hoisted up front into a boolean.

Contributor:

ok
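
To illustrate the reply, a small self-contained sketch of the refactor (the standalone function is hypothetical, not code from the PR): the std::find result is compared against end() once, up front, and the outcome travels as a named boolean instead of an iterator.

#include <algorithm>
#include <string>
#include <vector>

// Before the change, callers kept the iterator and compared it with end()
// at every use site; after it, the membership test runs once and the
// result is a descriptively named flag.
bool IsVarInputAndOutput(const std::vector<std::string>& outside_og_names,
                         const std::string& name) {
  return std::find(outside_og_names.begin(), outside_og_names.end(), name) !=
         outside_og_names.end();
}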


// zero gradient variable in step 0
if (cur_scope_iter == step_scopes->rbegin()) {
@@ -555,8 +587,7 @@
inside_grad_name,
framework::ToTypeName(var->Type())));

if ((var_iter == outside_og_names.end()) &&
var->IsType<phi::DenseTensor>()) {
if (!is_var_input_and_output && var->IsType<phi::DenseTensor>()) {
auto &inside_tensor = var->Get<phi::DenseTensor>();
framework::AttributeMap attrs;
attrs["dtype"] =
@@ -575,10 +606,7 @@
inside_tensor.lod());
}
}
auto var_outside = scope.FindVar(pg_ig_names[param_id]);
if ((var_iter == outside_og_names.end()) ||
((var_iter != outside_og_names.end()) &&
var_outside->IsType<framework::LoDTensorArray>())) {
if (!is_var_input_and_output) {
auto new_inside_name = cur_scope.Rename(inside_grad_name);
auto sum_op = framework::OpRegistry::CreateOp(
"sum",
@@ -587,6 +615,8 @@
framework::AttributeMap{{"use_mkldnn", {false}}});
sum_op->Run(cur_scope, dev_place);
cur_scope.Rename(new_inside_name, inside_grad_name);
} else {
ShareVariable(cur_scope, scope, pg_ig_names[param_id]);
}
}
dev_ctx.Wait();
@@ -595,6 +625,29 @@
step_scopes->clear();
}

void ShareVariable(const framework::Scope &source,
const framework::Scope &dest,
std::string name) const {
auto from_var = source.FindVar(name);
auto to_var = dest.FindVar(name);
if (from_var->IsType<phi::DenseTensor>()) {
if (from_var->Get<phi::DenseTensor>().IsInitialized()) {
to_var->GetMutable<phi::DenseTensor>()->ShareDataWith(
from_var->Get<phi::DenseTensor>());
}
} else if (from_var->IsType<framework::LoDTensorArray>()) {
auto from_arr = from_var->GetMutable<framework::LoDTensorArray>();
auto to_arr = to_var->GetMutable<framework::LoDTensorArray>();
to_arr->clear();
to_arr->resize(from_arr->size());
for (size_t i = 0; i < to_arr->size(); ++i) {
if (from_arr->at(i).IsInitialized()) {
to_arr->at(i).ShareDataWith(from_arr->at(i));
}
}
}
}

private:
mutable std::shared_ptr<framework::Executor> executor_{nullptr};
mutable std::unique_ptr<framework::ExecutorPrepareContext> ctx_{nullptr};
@@ -646,6 +699,7 @@ class WhileGradOpMaker : public framework::SingleGradOpMaker<T> {
block_ins.insert(o);
}
std::unordered_set<std::string> output_grads;

for (const auto *op : grad_block->AllOps()) {
for (auto &input_name : op->InputArgumentNames()) {
// If the input of Op has been recorded or is generated by the forward
@@ -658,7 +712,6 @@ class WhileGradOpMaker : public framework::SingleGradOpMaker<T> {
parent_block->FindVarRecursive(input_name) != nullptr)) {
continue;
}

output_grads.insert(input_name);
}
for (auto &output_name : op->OutputArgumentNames()) {
4 changes: 4 additions & 0 deletions python/paddle/fluid/backward.py
@@ -2220,6 +2220,10 @@ def _find_op_path_(
op.desc.output_arg_names(), output_names
):
relevant_op_flags[i] = True
if core.has_non_empty_grad_op_maker(op.type):
for name in op.desc.input_arg_names():
if name not in no_grad_set:
output_names.add(name)

op_path = [
block.ops[i] for i in range(len(block.ops)) if relevant_op_flags[i]