RecurrentOp implementation #2890
Conversation
Tensor* step_input = step_scopes[j]
                         ->CreateVariable(inlinks[i].internal)
                         ->GetMutable<Tensor>();
*step_input = input->Slice<float>(j, j + 1);
The hard-coded float data type needs to be handled properly later.
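A minimal sketch of one way to address this, assuming the Tensor exposes a runtime element-type query (the `type()` accessor and the supported-type list are assumptions, not the current API):

```cpp
// Illustrative only: dispatch the slice on the tensor's runtime element type
// instead of hard-coding float. `type()` is a hypothetical accessor.
if (input->type() == typeid(float)) {
  *step_input = input->Slice<float>(j, j + 1);
} else if (input->type() == typeid(double)) {
  *step_input = input->Slice<double>(j, j + 1);
} else {
  // report an unsupported element type here
}
```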
    ->GetMutable<Tensor>()
    ->dims();
std::vector<int> dims_vec = vectorize(step_dims);
dims_vec.insert(dims_vec.begin(), seq_len);
If the output shape were also constructed ahead of time in RNNOp's InferShape, there would be no need to infer dims_vec here.
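For reference, a rough sketch of that suggestion: build the outlink dims once in RNNOp's InferShape so the concatenation step does not have to (the final shape-setting call is a hypothetical placeholder, not the current API):

```cpp
// Illustrative only: prepend seq_len to the per-step dims ahead of time so
// ConcatOutputs does not need to infer dims_vec itself.
std::vector<int> dims_vec = vectorize(step_dims);  // e.g. {batch, hidden}
dims_vec.insert(dims_vec.begin(), seq_len);        // -> {seq_len, batch, hidden}
SetOutputDims(output, make_ddim(dims_vec));        // hypothetical shape setter
```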
    ->GetMutable<Tensor>();
// TODO: data type and platform::DeviceContext() should be set correctly
(output->Slice<float>(j, j + 1))
    .CopyFrom<float>(*step_output, platform::CPUDeviceContext());
The data type and device type need to be handled correctly later.
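A hedged sketch of what the device-type handling could eventually look like, choosing the copy's device context from where the output lives (the `place()` accessor and `is_cpu_place` check are assumptions here):

```cpp
// Illustrative only: pick the device context from the output's placement
// instead of hard-coding platform::CPUDeviceContext().
if (platform::is_cpu_place(output->place())) {  // both calls are assumptions
  (output->Slice<float>(j, j + 1))
      .CopyFrom<float>(*step_output, platform::CPUDeviceContext());
} else {
  // GPU path: would pass a CUDA device context once CopyFrom supports it.
}
```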
// InferShape. That's a problem. Does the RNN op need InferShape or not?
// Do the following functions (SegmentInputs, InitMemories, ...) need to be
// rewritten for the RNN op?
rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_);
The RNN op's InferShape needs to slice the inputs. The SegmentInputs function called here splits the input by calling the tensor's Slice function, which requires the tensor to already own memory.
However, in the current design, an ordinary op's InferShape only sets the output shapes, and the outputs have no memory at that point, so this is problematic.
Do these functions called from InferShape need to be redesigned? Does the RNN op need InferShape at all?
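One possible direction, sketched only to make the concern concrete, as it might appear inside RecurrentAlgorithm::InferShape: propagate shapes without touching memory, and defer the actual Slice/copy to Run(). The `set_dims` call is a hypothetical shape-only setter; the other names come from the surrounding diff.

```cpp
// Illustrative only: set per-step input shapes in InferShape; do the real
// Slice in Run(), so no memory is required at shape-inference time.
for (size_t i = 0; i < inlinks.size(); ++i) {
  auto in_dims = scope->GetVariable(inlinks[i].external)
                     ->GetMutable<Tensor>()
                     ->dims();
  std::vector<int> step_vec = vectorize(in_dims);
  step_vec.erase(step_vec.begin());  // drop the sequence axis
  for (size_t j = 0; j < seq_len; ++j) {
    step_scopes[j]
        ->CreateVariable(inlinks[i].internal)
        ->GetMutable<Tensor>()
        ->set_dims(make_ddim(step_vec));  // hypothetical shape-only setter
  }
}
```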
for (size_t i = 0; i < seq_len_; i++) {
  if (i > 0) {
    rnn::LinkMemories(step_scopes, arg_->memories, i, -1);
  }
To infer the output shapes of the ops inside the step net, the memories of all time steps are linked here. Is that necessary? Could we run InferShape only for step 0, and then in Run(), before each call to the step net's Run(), call the step net's InferShape so that the output shapes of the step net's ops are set correctly for every time step?
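Roughly what that could look like inside Run(), sketched with placeholder names (the step-net lookup and its type are assumptions; the loop mirrors the code above):

```cpp
// Illustrative only: infer shapes per step right before running the step net,
// so the outer InferShape only has to handle step 0.
for (size_t step = 0; step < seq_len_; ++step) {
  if (step > 0) {
    rnn::LinkMemories(step_scopes, arg_->memories, step, -1);
  }
  auto* net = GetStepNet(scope);       // hypothetical lookup of the step net
  net->InferShape(step_scopes[step]);  // set this step's output shapes
  net->Run(step_scopes[step], dev_ctx);
}
```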
namespace paddle {
namespace operators {

using namespace paddle::framework;
Bad code style; maybe you can use "paddle/framework/type_alias.h" instead.
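For illustration, the suggested style would look roughly like this (assuming the alias header exposes the framework types used in this file):

```cpp
// Illustrative only: rely on the aliases from type_alias.h rather than a
// blanket using-directive for paddle::framework.
#include "paddle/framework/type_alias.h"

namespace paddle {
namespace operators {
// Tensor, Scope, OperatorBase, etc. would then resolve through the aliases
// instead of `using namespace paddle::framework;`.
}  // namespace operators
}  // namespace paddle
```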
// 4. More Complex RNN architecture, such as Gated Feedback RNN.
// Refer to: https://arxiv.org/pdf/1502.02367.pdf

class RecurrentAlgorithm {
Why is 'RecurrentAlgorithm' needed? Why can't we simply use
class RNNForward : public OperatorBase
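A minimal sketch of that alternative, reusing the InferShape/Run signatures visible elsewhere in this diff (the class name and method list are illustrative):

```cpp
// Illustrative only: the forward RNN as a plain operator instead of a
// separate RecurrentAlgorithm class.
class RNNForward : public OperatorBase {
 public:
  void InferShape(const std::shared_ptr<Scope>& scope) const override;
  void Run(const std::shared_ptr<Scope>& scope,
           const platform::DeviceContext& dev_ctx) const override;
};
```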
}  // namespace rnn

void RecurrentAlgorithm::InferShape(const std::shared_ptr<Scope>& scope) const {
  seq_len_ = scope->GetVariable((arg_->inlinks[0]).external)
(arg_->inlinks[0]).external
--> arg_->inlinks[0].external
What is an inlink?
What if the sequence lengths of the in_links are not all equal?
What if in_link[0] of the RNN is not a tensor?
  rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_);
}

void RecurrentAlgorithm::CreateScopes(std::shared_ptr<Scope> scope) const {
Pass the scope as const std::shared_ptr<Scope>& here.
}

void RecurrentAlgorithm::CreateScopes(std::shared_ptr<Scope> scope) const {
  // TODO(xxx) Only two scopes are needed for inference, this case will be
Who is XXX?
We could merge it first, and change it later.