-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
change staticRNN to while #48213
change staticRNN to while #48213
Conversation
… remove_recurrent_op_and_static_rnn_api
你的PR提交成功,感谢你对开源项目的贡献! |
… remove_recurrent_op_and_static_rnn_api
… remove_recurrent_op_and_static_rnn_api
[bugfix] fix _find_op_path_ bugs in append_backward.
… remove_recurrent_op_and_static_rnn_api
… remove_recurrent_op_and_static_rnn_api
… remove_recurrent_op_and_static_rnn_api
… remove_recurrent_op_and_static_rnn_api
… remove_recurrent_op_and_static_rnn_api
… remove_recurrent_op_and_static_rnn_api
… remove_recurrent_op_and_static_rnn_api
… remove_recurrent_op_and_static_rnn_api
… remove_recurrent_op_and_static_rnn_api
auto shape = var_desc->GetShape(); | ||
VLOG(8) << "Found uninitialized tensor " << outside_og_name | ||
<< " in step 0, fill it with 0.0f. dims=" | ||
<< phi::make_ddim(shape); | ||
framework::AttributeMap attrs; | ||
attrs["dtype"] = var_desc->GetDataType(); | ||
attrs["shape"] = phi::vectorize<int>(phi::make_ddim(shape)); | ||
attrs["value"] = 0.0f; | ||
|
||
auto var_name = outside_og_name; | ||
auto zero_op = | ||
framework::OpRegistry::CreateOp("fill_constant", | ||
framework::VariableNameMap{}, | ||
{{"Out", {var_name}}}, | ||
attrs); | ||
zero_op->Run(scope, dev_place); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Better to use SetConstant or Full API directly.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个逻辑上可以改成full_like; 但是
- while op内部没有数据类型的T,需要使用visitor
- full like的api的dense tensor是返回值, 要摁到scope里面,还需要一些操作
这种方式改完,逻辑比当前还要复杂
auto &og_outside = *scope.FindVar(outside_og_name); | ||
auto &og_inside = *cur_scope.Var(inside_og_name); | ||
if (og_outside.IsType<phi::DenseTensor>()) { | ||
auto &outside_tensor = og_outside.Get<phi::DenseTensor>(); | ||
auto &inside_tensor = *og_inside.GetMutable<phi::DenseTensor>(); | ||
inside_tensor.set_lod(outside_tensor.lod()); | ||
inside_tensor.ShareDataWith(outside_tensor); | ||
if (outside_tensor.IsInitialized()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this line needed after set zero
added?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
shoule be removed here
auto is_var_input_and_output = | ||
std::find(outside_og_names.begin(), | ||
outside_og_names.end(), | ||
pg_ig_names[param_id]) != outside_og_names.end(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why change it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
原来是获取一个iterator,现在是把iterator的判断提前
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
… remove_recurrent_op_and_static_rnn_api
… remove_recurrent_op_and_static_rnn_api
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
… remove_recurrent_op_and_static_rnn_api
PR types
Others
PR changes
Others
Describe
我们计划移除recurrent op 和 StaticRNN,将paddle.nn.RNN中依赖的staticRNN替换为while