Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Using LoDTensor instead of Tensor in every operator. #4048

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions paddle/framework/lod_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,17 +53,23 @@ bool operator==(const LoD& a, const LoD& b);
*/
class LoDTensor {
public:
LoDTensor() {}
LoDTensor(const LoD& lod, Tensor* t) : lod_(lod), tensor_(t) {}
LoDTensor() : tensor_(new Tensor()) {}

void set_lod(const LoD& lod) { lod_ = lod; }
LoDTensor(const LoD& lod) : lod_(lod), tensor_(new Tensor()) {}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

explicit LoDTensor


void set_tensor(Tensor* tensor) { tensor_ = tensor; }
~LoDTensor() { delete tensor_; }
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LoDTensor管理成员变量Tensor* tensor_的构造与释放,原因:#4047 (comment)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe LoDTensor should

  • Disable copy
  • Disable assign
  • Disable move

The simplest way to implement it is make tensor_ as unique_ptr


Tensor& tensor() { return *tensor_; }
void set_lod(const LoD& lod) { lod_ = lod; }

Tensor& tensor() {
PADDLE_ENFORCE(tensor_, "The tensor_ must be null.");
return *tensor_;
}

LoD lod() { return lod_; }

void CopyLoDFrom(const LoDTensor& src) { set_lod(src.lod()); }

/*
* Get a element from LoD.
*/
Expand Down
48 changes: 45 additions & 3 deletions paddle/framework/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -289,13 +289,26 @@ class InferShapeContext {
}

template <typename T>
const T* Input(const std::string& name) const {
virtual const T* Input(const std::string& name) const {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A template method cannot be virtual.

auto* var = InputVar(name);
return var == nullptr ? nullptr : &var->Get<T>();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

var ? var : &var->Get<T>()

var is false only if it equals nullptr.

}

// template <>
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

??
either delete it or uncomment it

// virtual const Tensor* Input<Tensor>(const std::string& name) const {
// auto* var = InputVar(name);
// if (var == nullptr) return nullptr;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

delete these

// if (var->IsType<LoDTensor>()>) {
// return &var->Get<LoDTensor>()->tensor();
// } else {
// PADDLE_ENFORCE(var->IsType<Tensor>()>);
// return &var->Get<Tensor>();
// }
// return;
// }

template <typename T>
T* Output(const std::string& name) const {
virtual T* Output(const std::string& name) const {
auto var = OutputVar(name);
return var == nullptr ? nullptr : var->GetMutable<T>();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same as above

}
Expand All @@ -314,7 +327,7 @@ class InferShapeContext {
}

template <typename T>
std::vector<T*> MultiOutput(const std::string& name) const {
virtual std::vector<T*> MultiOutput(const std::string& name) const {
auto names = op_.Outputs(name);
std::vector<T*> res;
res.reserve(names.size());
Expand All @@ -326,6 +339,14 @@ class InferShapeContext {
return res;
}

void CopyLoD(const std::string& in_name, const std::string& out_name) {
PADDLE_ENFORCE(InputVar(in_name)->IsType<LoDTensor>(),
"The Input(%s) must be LoDTensor.", in_name);
PADDLE_ENFORCE(OutputVar(out_name)->IsType<LoDTensor>(),
"The Output(%s) must be LoDTensor.", out_name);
Output<LoDTensor>(out_name)->set_lod(Input<LoDTensor>(in_name)->lod());
}

private:
const OperatorBase& op_;
const Scope& scope_;
Expand Down Expand Up @@ -363,6 +384,27 @@ class ExecutionContext : public InferShapeContext {
return device_context_;
}

template <typename T>
T* Output(const std::string& name) const override {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Output be a const T& better?

const T& Output
T* mutable_output
or 
T* MutableOutput

borrowed from Protobuf.

auto var = OutputVar(name);
// Different from InferShapeContext, call Get instread of GetMutable.
return var == nullptr ? nullptr : var->Get<T>();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Variable::Get 返回的是一个const reference。Variable::GetMutable返回的是一个mutable pointer。也就是说如要要返回mutable pointer,应该调用的是 GetMutable,而不是Get。这里为什么要调用Get而不是GetMutable呢?

}

template <typename T>
std::vector<T*> MultiOutput(const std::string& name) const override {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same as above.

auto names = op_.Outputs(name);
std::vector<T*> res;
res.reserve(names.size());
// Different from InferShapeContext, call Get instread of GetMutable.
std::transform(names.begin(), names.end(), std::back_inserter(res),
[&](const std::string& sub_name) {
auto var = scope_.FindVar(sub_name);
return var == nullptr ? nullptr : var->Get<T>();
});
return res;
}

const platform::DeviceContext* device_context_;
};

Expand Down
15 changes: 15 additions & 0 deletions paddle/framework/variable.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,21 @@ class Variable {
return *static_cast<const T*>(holder_->Ptr());
}

// Specialized template for Tensor,
// so that the Tensor can be got from LoDTensor.
// How to get Tensor from LoDTensor also can be put in InferShapeContext.
template <>
const Tensor& Get<Tensor>() const {
if (IsType<LoDTensor>()) {
auto lod_t = static_cast<const LoDTensor*>(holder_->Ptr());
return *lod_t->tensor();
} else {
PADDLE_ENFORCE(IsType<Tensor>(),
"Variable must be type LoDTensor or Tensor");
return *static_cast<const Tensor*>(holder_->Ptr());
}
}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

为了使得不用LoD信息的operators比较方便的从LoDTensor里获取Tensor

  1. 这里特化了Variable::Get<Tensor>接口使得可以
  2. 也可通过方式 https://github.com/PaddlePaddle/Paddle/pull/4048/files#diff-91a47df6a639afa5ad046a8d43c640e0R298 在特化InferShapeContext::Input<Tensor>(const std::string& name)InferShapeContext::Output<Tensor>(const std::string& name) 接口。


template <typename T>
T* GetMutable() {
if (!IsType<T>()) {
Expand Down
8 changes: 7 additions & 1 deletion paddle/operators/mul_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,13 @@ class MulOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(
x_mat_dims[1], y_mat_dims[0],
"First matrix's width must be equal with second matrix's height.");
ctx.Output<Tensor>("Out")->Resize({x_mat_dims[0], y_mat_dims[1]});
// Each operator's InferShape must call Output<LoDTensor> to create
// LoDTensor in variable.
ctx.Output<LoDTensor>("Out")->tensor()->Resize(
{x_mat_dims[0], y_mat_dims[1]});
// Only the forward operator needs to pass the lod.
// pass the lod of Input(X) to output.
ctx.CopyLoD(/*in = */ "X", /* out = */ "Out");
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

在每个Op里显示调用LoD传递函数, 解释: #4047 (comment)

}
};

Expand Down