-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] Using LoDTensor instead of Tensor in every operator. #4048
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -53,17 +53,23 @@ bool operator==(const LoD& a, const LoD& b); | |
*/ | ||
class LoDTensor { | ||
public: | ||
LoDTensor() {} | ||
LoDTensor(const LoD& lod, Tensor* t) : lod_(lod), tensor_(t) {} | ||
LoDTensor() : tensor_(new Tensor()) {} | ||
|
||
void set_lod(const LoD& lod) { lod_ = lod; } | ||
LoDTensor(const LoD& lod) : lod_(lod), tensor_(new Tensor()) {} | ||
|
||
void set_tensor(Tensor* tensor) { tensor_ = tensor; } | ||
~LoDTensor() { delete tensor_; } | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe LoDTensor should
The simplest way to implement it is make |
||
|
||
Tensor& tensor() { return *tensor_; } | ||
void set_lod(const LoD& lod) { lod_ = lod; } | ||
|
||
Tensor& tensor() { | ||
PADDLE_ENFORCE(tensor_, "The tensor_ must be null."); | ||
return *tensor_; | ||
} | ||
|
||
LoD lod() { return lod_; } | ||
|
||
void CopyLoDFrom(const LoDTensor& src) { set_lod(src.lod()); } | ||
|
||
/* | ||
* Get a element from LoD. | ||
*/ | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -289,13 +289,26 @@ class InferShapeContext { | |
} | ||
|
||
template <typename T> | ||
const T* Input(const std::string& name) const { | ||
virtual const T* Input(const std::string& name) const { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A template method cannot be virtual. |
||
auto* var = InputVar(name); | ||
return var == nullptr ? nullptr : &var->Get<T>(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
var is false only if it equals nullptr. |
||
} | ||
|
||
// template <> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ?? |
||
// virtual const Tensor* Input<Tensor>(const std::string& name) const { | ||
// auto* var = InputVar(name); | ||
// if (var == nullptr) return nullptr; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. delete these |
||
// if (var->IsType<LoDTensor>()>) { | ||
// return &var->Get<LoDTensor>()->tensor(); | ||
// } else { | ||
// PADDLE_ENFORCE(var->IsType<Tensor>()>); | ||
// return &var->Get<Tensor>(); | ||
// } | ||
// return; | ||
// } | ||
|
||
template <typename T> | ||
T* Output(const std::string& name) const { | ||
virtual T* Output(const std::string& name) const { | ||
auto var = OutputVar(name); | ||
return var == nullptr ? nullptr : var->GetMutable<T>(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same as above |
||
} | ||
|
@@ -314,7 +327,7 @@ class InferShapeContext { | |
} | ||
|
||
template <typename T> | ||
std::vector<T*> MultiOutput(const std::string& name) const { | ||
virtual std::vector<T*> MultiOutput(const std::string& name) const { | ||
auto names = op_.Outputs(name); | ||
std::vector<T*> res; | ||
res.reserve(names.size()); | ||
|
@@ -326,6 +339,14 @@ class InferShapeContext { | |
return res; | ||
} | ||
|
||
void CopyLoD(const std::string& in_name, const std::string& out_name) { | ||
PADDLE_ENFORCE(InputVar(in_name)->IsType<LoDTensor>(), | ||
"The Input(%s) must be LoDTensor.", in_name); | ||
PADDLE_ENFORCE(OutputVar(out_name)->IsType<LoDTensor>(), | ||
"The Output(%s) must be LoDTensor.", out_name); | ||
Output<LoDTensor>(out_name)->set_lod(Input<LoDTensor>(in_name)->lod()); | ||
} | ||
|
||
private: | ||
const OperatorBase& op_; | ||
const Scope& scope_; | ||
|
@@ -363,6 +384,27 @@ class ExecutionContext : public InferShapeContext { | |
return device_context_; | ||
} | ||
|
||
template <typename T> | ||
T* Output(const std::string& name) const override { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Output be a const T& Output
T* mutable_output
or
T* MutableOutput borrowed from Protobuf. |
||
auto var = OutputVar(name); | ||
// Different from InferShapeContext, call Get instread of GetMutable. | ||
return var == nullptr ? nullptr : var->Get<T>(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Variable::Get 返回的是一个const reference。Variable::GetMutable返回的是一个mutable pointer。也就是说如要要返回mutable pointer,应该调用的是 GetMutable,而不是Get。这里为什么要调用Get而不是GetMutable呢? |
||
} | ||
|
||
template <typename T> | ||
std::vector<T*> MultiOutput(const std::string& name) const override { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same as above. |
||
auto names = op_.Outputs(name); | ||
std::vector<T*> res; | ||
res.reserve(names.size()); | ||
// Different from InferShapeContext, call Get instread of GetMutable. | ||
std::transform(names.begin(), names.end(), std::back_inserter(res), | ||
[&](const std::string& sub_name) { | ||
auto var = scope_.FindVar(sub_name); | ||
return var == nullptr ? nullptr : var->Get<T>(); | ||
}); | ||
return res; | ||
} | ||
|
||
const platform::DeviceContext* device_context_; | ||
}; | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,6 +29,21 @@ class Variable { | |
return *static_cast<const T*>(holder_->Ptr()); | ||
} | ||
|
||
// Specialized template for Tensor, | ||
// so that the Tensor can be got from LoDTensor. | ||
// How to get Tensor from LoDTensor also can be put in InferShapeContext. | ||
template <> | ||
const Tensor& Get<Tensor>() const { | ||
if (IsType<LoDTensor>()) { | ||
auto lod_t = static_cast<const LoDTensor*>(holder_->Ptr()); | ||
return *lod_t->tensor(); | ||
} else { | ||
PADDLE_ENFORCE(IsType<Tensor>(), | ||
"Variable must be type LoDTensor or Tensor"); | ||
return *static_cast<const Tensor*>(holder_->Ptr()); | ||
} | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 为了使得不用LoD信息的operators比较方便的从
|
||
|
||
template <typename T> | ||
T* GetMutable() { | ||
if (!IsType<T>()) { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -45,7 +45,13 @@ class MulOp : public framework::OperatorWithKernel { | |
PADDLE_ENFORCE_EQ( | ||
x_mat_dims[1], y_mat_dims[0], | ||
"First matrix's width must be equal with second matrix's height."); | ||
ctx.Output<Tensor>("Out")->Resize({x_mat_dims[0], y_mat_dims[1]}); | ||
// Each operator's InferShape must call Output<LoDTensor> to create | ||
// LoDTensor in variable. | ||
ctx.Output<LoDTensor>("Out")->tensor()->Resize( | ||
{x_mat_dims[0], y_mat_dims[1]}); | ||
// Only the forward operator needs to pass the lod. | ||
// pass the lod of Input(X) to output. | ||
ctx.CopyLoD(/*in = */ "X", /* out = */ "Out"); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 在每个Op里显示调用LoD传递函数, 解释: #4047 (comment) |
||
} | ||
}; | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
explicit LoDTensor