[WIP] Using LoDTensor instead of Tensor in every operator. #4048
      "Variable must be type LoDTensor or Tensor");
    return *static_cast<const Tensor*>(holder_->Ptr());
  }
}
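To make it convenient for operators that do not use LoD information to get a Tensor from a LoDTensor:
- This change specializes the Variable::Get<Tensor> interface, which makes that possible.
- Alternatively, following the approach at https://github.com/PaddlePaddle/Paddle/pull/4048/files#diff-91a47df6a639afa5ad046a8d43c640e0R298, one could specialize the InferShapeContext::Input<Tensor>(const std::string& name) and InferShapeContext::Output<Tensor>(const std::string& name) interfaces.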
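As a rough illustration of the first option, the sketch below specializes a `Get<Tensor>` accessor so that it accepts a variable holding either a `Tensor` or a `LoDTensor`. Every type here is a simplified stand-in written for this example, not the framework's real classes. In particular, storing the `Tensor` by value inside `LoDTensor` and the `Placeholder`/`Holder` layout are assumptions.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <stdexcept>
#include <typeindex>
#include <typeinfo>
#include <vector>

// Simplified stand-ins for the framework types (hypothetical).
struct Tensor {
  int numel = 0;
};

struct LoDTensor {
  std::vector<std::vector<std::size_t>> lod;  // level-of-detail offsets
  Tensor tensor;                              // plain tensor payload
};

class Variable {
 public:
  // Generic accessor: the stored type must match T exactly.
  template <typename T>
  const T& Get() const {
    if (!holder_ || holder_->Type() != std::type_index(typeid(T))) {
      throw std::runtime_error("Variable holds a different type");
    }
    return *static_cast<const T*>(holder_->Ptr());
  }

  // Creates (or replaces) the stored object so that it has type T.
  template <typename T>
  T* GetMutable() {
    if (!holder_ || holder_->Type() != std::type_index(typeid(T))) {
      holder_.reset(new Holder<T>());
    }
    return static_cast<T*>(holder_->Ptr());
  }

 private:
  struct Placeholder {
    virtual ~Placeholder() = default;
    virtual std::type_index Type() const = 0;
    virtual void* Ptr() = 0;
  };
  template <typename T>
  struct Holder : Placeholder {
    std::type_index Type() const override { return typeid(T); }
    void* Ptr() override { return &value; }
    T value;
  };
  std::unique_ptr<Placeholder> holder_;
};

// Specialization: a Tensor can be read from a Tensor or from a LoDTensor.
template <>
inline const Tensor& Variable::Get<Tensor>() const {
  if (holder_ && holder_->Type() == std::type_index(typeid(LoDTensor))) {
    return static_cast<const LoDTensor*>(holder_->Ptr())->tensor;
  }
  if (holder_ && holder_->Type() == std::type_index(typeid(Tensor))) {
    return *static_cast<const Tensor*>(holder_->Ptr());
  }
  throw std::runtime_error("Variable must be type LoDTensor or Tensor");
}
```

With this shape, an operator that ignores LoD can keep calling `Get<Tensor>()` unchanged, while LoD-aware operators call `Get<LoDTensor>()`.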
      {x_mat_dims[0], y_mat_dims[1]});
  // Only the forward operator needs to pass the lod.
  // pass the lod of Input(X) to output.
  ctx.CopyLoD(/*in = */ "X", /* out = */ "Out");
Each Op explicitly calls the LoD-passing function. Explanation: #4047 (comment)
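A minimal sketch of what explicit LoD passing could look like, assuming a context that maps argument names to tensors. The `Context` class, its `Var` helper, and `InferShapeForward` are hypothetical names for this example; only the `CopyLoD("X", "Out")` call mirrors the diff.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <vector>

using LoD = std::vector<std::vector<std::size_t>>;

// Simplified stand-in: only the LoD part matters for this sketch.
struct LoDTensor {
  LoD lod;
};

// Hypothetical context that maps argument names to tensors.
class Context {
 public:
  LoDTensor& Var(const std::string& name) { return vars_[name]; }

  // Copies the LoD of input `in` to output `out`.
  // Throws std::out_of_range if `in` does not exist.
  void CopyLoD(const std::string& in, const std::string& out) {
    vars_[out].lod = vars_.at(in).lod;
  }

 private:
  std::map<std::string, LoDTensor> vars_;
};

// A forward operator's shape inference finishes by passing the LoD along.
inline void InferShapeForward(Context& ctx) {
  // ... output dims would be set here ...
  ctx.CopyLoD(/*in=*/"X", /*out=*/"Out");
}
```

Making the call explicit keeps the propagation visible in each operator, at the cost of every forward operator having to remember it.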
  void set_tensor(Tensor* tensor) { tensor_ = tensor; }
  ~LoDTensor() { delete tensor_; }
LoDTensor manages the construction and release of its member variable Tensor* tensor_. Reason: #4047 (comment)
Maybe LoDTensor should:
- Disable copy
- Disable assign
- Disable move
The simplest way to implement this is to make tensor_ a unique_ptr.
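A sketch of that suggestion, using simplified stand-in types rather than the real class. One caveat: a `std::unique_ptr` member on its own removes the manual `delete` and disables copying, but the class would still be movable, so this version also deletes the move operations explicitly. Having `set_tensor` take a `std::unique_ptr` is an assumption made for the example.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <type_traits>
#include <utility>
#include <vector>

struct Tensor {
  int numel = 0;
};
using LoD = std::vector<std::vector<std::size_t>>;

class LoDTensor {
 public:
  LoDTensor() : tensor_(new Tensor()) {}
  explicit LoDTensor(const LoD& lod) : lod_(lod), tensor_(new Tensor()) {}

  // Copy, assignment, and move are all disabled.
  LoDTensor(const LoDTensor&) = delete;
  LoDTensor& operator=(const LoDTensor&) = delete;
  LoDTensor(LoDTensor&&) = delete;
  LoDTensor& operator=(LoDTensor&&) = delete;

  const LoD& lod() const { return lod_; }
  Tensor& tensor() { return *tensor_; }

  // Takes ownership of the new tensor; the old one is freed automatically.
  void set_tensor(std::unique_ptr<Tensor> tensor) {
    tensor_ = std::move(tensor);
  }

 private:
  LoD lod_;
  std::unique_ptr<Tensor> tensor_;  // no user-written destructor needed
};

static_assert(!std::is_copy_constructible<LoDTensor>::value,
              "copy construction is disabled");
static_assert(!std::is_copy_assignable<LoDTensor>::value,
              "copy assignment is disabled");
static_assert(!std::is_move_constructible<LoDTensor>::value,
              "move construction is disabled");
```

The ownership rule then lives in the type itself instead of in a hand-written destructor.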
  void set_lod(const LoD& lod) { lod_ = lod; }
  LoDTensor(const LoD& lod) : lod_(lod), tensor_(new Tensor()) {}
Mark this single-argument constructor explicit: explicit LoDTensor(const LoD& lod)
@@ -289,13 +289,26 @@ class InferShapeContext {
   }

   template <typename T>
-  const T* Input(const std::string& name) const {
+  virtual const T* Input(const std::string& name) const {
A member function template cannot be virtual.
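One common way around this restriction, sketched here under simplified assumptions, is to keep the templated accessor non-virtual and route the part that varies per context through an ordinary virtual function. The `Variable` stand-in and the `MapContext` class are hypothetical and exist only to make the example self-contained.

```cpp
#include <cassert>
#include <map>
#include <string>

// Simplified stand-in: a variable that just points at some object.
// The cast is unchecked, which is acceptable only for a sketch.
struct Variable {
  const void* data = nullptr;

  template <typename T>
  const T& Get() const {
    return *static_cast<const T*>(data);
  }
};

class InferShapeContext {
 public:
  virtual ~InferShapeContext() = default;

  // Stays non-virtual: adding `virtual` to a member template is ill-formed.
  template <typename T>
  const T* Input(const std::string& name) const {
    const Variable* var = InputVar(name);
    return var == nullptr ? nullptr : &var->Get<T>();
  }

 protected:
  // The customization point is an ordinary virtual function.
  virtual const Variable* InputVar(const std::string& name) const = 0;
};

// Hypothetical concrete context backed by a name-to-variable map.
class MapContext : public InferShapeContext {
 public:
  void Set(const std::string& name, const void* data) {
    vars_[name].data = data;
  }

 protected:
  const Variable* InputVar(const std::string& name) const override {
    auto it = vars_.find(name);
    return it == vars_.end() ? nullptr : &it->second;
  }

 private:
  std::map<std::string, Variable> vars_;
};
```

Type-specific behavior, such as treating a LoDTensor as a Tensor, can then be handled by explicit specialization rather than by overriding.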
    auto* var = InputVar(name);
    return var == nullptr ? nullptr : &var->Get<T>();
  }

  // template <>
Why is this commented out? Either delete it or uncomment it.
  // template <>
  // virtual const Tensor* Input<Tensor>(const std::string& name) const {
  //   auto* var = InputVar(name);
  //   if (var == nullptr) return nullptr;
delete these
@@ -289,13 +289,26 @@ class InferShapeContext {
   }

   template <typename T>
-  const T* Input(const std::string& name) const {
+  virtual const T* Input(const std::string& name) const {
     auto* var = InputVar(name);
     return var == nullptr ? nullptr : &var->Get<T>();
var ? &var->Get<T>() : nullptr
var is false only if it equals nullptr.
   template <typename T>
-  T* Output(const std::string& name) const {
+  virtual T* Output(const std::string& name) const {
     auto var = OutputVar(name);
     return var == nullptr ? nullptr : var->GetMutable<T>();
same as above
@@ -363,6 +384,27 @@ class ExecutionContext : public InferShapeContext {
     return device_context_;
   }

+  template <typename T>
+  T* Output(const std::string& name) const override {
Would it be better for Output to return a const T&? For example:
const T& Output
T* mutable_output
or
T* MutableOutput
This naming is borrowed from Protobuf.
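A small sketch of that naming convention, assuming a context that stores its outputs in a map. The `ExecutionContext` below is a simplified stand-in and its storage layout is hypothetical. The idea follows Protobuf's generated C++ accessors, where the plain getter returns a const reference and the `mutable_` variant returns a pointer.

```cpp
#include <cassert>
#include <map>
#include <string>

struct Tensor {
  int numel = 0;
};

// Simplified stand-in for the real context class.
class ExecutionContext {
 public:
  // Read-only access returns a const reference.
  // Throws std::out_of_range if the output does not exist.
  const Tensor& Output(const std::string& name) const {
    return outputs_.at(name);
  }

  // Mutable access returns a pointer, so writes are visible at the call site.
  // Creates the output if it does not exist yet.
  Tensor* MutableOutput(const std::string& name) { return &outputs_[name]; }

 private:
  std::map<std::string, Tensor> outputs_;
};
```

Splitting the two accessors this way also makes it harder to hand out a mutable pointer from a method that is meant to be read-only.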
  }

  template <typename T>
  std::vector<T*> MultiOutput(const std::string& name) const override {
same as above.
  T* Output(const std::string& name) const override {
    auto var = OutputVar(name);
    // Different from InferShapeContext, call Get instead of GetMutable.
    return var == nullptr ? nullptr : var->Get<T>();
Variable::Get returns a const reference, while Variable::GetMutable returns a mutable pointer. So to return a mutable pointer, the call should be GetMutable, not Get. Why does this code call Get instead of GetMutable?
Fix #4047
Fix #3717
Problem addressed:
The construction and release of Tensor* tensor_.
@wangkuiyi @reyoung @Superjom @QiJune @hedaoyuan