diff --git a/paddle/framework/lod_tensor.h b/paddle/framework/lod_tensor.h
index 568f4e89819c83..a514b96553af05 100644
--- a/paddle/framework/lod_tensor.h
+++ b/paddle/framework/lod_tensor.h
@@ -53,17 +53,23 @@ bool operator==(const LoD& a, const LoD& b);
  */
 class LoDTensor {
  public:
-  LoDTensor() {}
-  LoDTensor(const LoD& lod, Tensor* t) : lod_(lod), tensor_(t) {}
+  LoDTensor() : tensor_(new Tensor()) {}
 
-  void set_lod(const LoD& lod) { lod_ = lod; }
+  LoDTensor(const LoD& lod) : lod_(lod), tensor_(new Tensor()) {}
 
-  void set_tensor(Tensor* tensor) { tensor_ = tensor; }
+  ~LoDTensor() { delete tensor_; }
 
-  Tensor& tensor() { return *tensor_; }
+  void set_lod(const LoD& lod) { lod_ = lod; }
+
+  Tensor& tensor() {
+    PADDLE_ENFORCE(tensor_, "The tensor_ must not be null.");
+    return *tensor_;
+  }
 
   LoD lod() { return lod_; }
 
+  void CopyLoDFrom(const LoDTensor& src) { set_lod(src.lod()); }
+
   /*
    * Get a element from LoD.
    */
diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h
index 4600b06009bcef..3cae7cb81c1d66 100644
--- a/paddle/framework/operator.h
+++ b/paddle/framework/operator.h
@@ -289,13 +289,26 @@ class InferShapeContext {
   }
 
   template <typename T>
-  const T* Input(const std::string& name) const {
+  virtual const T* Input(const std::string& name) const {
     auto* var = InputVar(name);
     return var == nullptr ? nullptr : &var->Get<T>();
   }
 
+  // template <>
+  // virtual const Tensor* Input<Tensor>(const std::string& name) const {
+  //   auto* var = InputVar(name);
+  //   if (var == nullptr) return nullptr;
+  //   if (var->IsType<LoDTensor>()) {
+  //     return &var->Get<LoDTensor>().tensor();
+  //   } else {
+  //     PADDLE_ENFORCE(var->IsType<Tensor>());
+  //     return &var->Get<Tensor>();
+  //   }
+  //   return;
+  // }
+
   template <typename T>
-  T* Output(const std::string& name) const {
+  virtual T* Output(const std::string& name) const {
     auto var = OutputVar(name);
     return var == nullptr ? nullptr : var->GetMutable<T>();
   }
@@ -314,7 +327,7 @@ class InferShapeContext {
   }
 
   template <typename T>
-  std::vector<T*> MultiOutput(const std::string& name) const {
+  virtual std::vector<T*> MultiOutput(const std::string& name) const {
     auto names = op_.Outputs(name);
     std::vector<T*> res;
     res.reserve(names.size());
@@ -326,6 +339,14 @@ class InferShapeContext {
     return res;
   }
 
+  void CopyLoD(const std::string& in_name, const std::string& out_name) {
+    PADDLE_ENFORCE(InputVar(in_name)->IsType<LoDTensor>(),
+                   "The Input(%s) must be LoDTensor.", in_name);
+    PADDLE_ENFORCE(OutputVar(out_name)->IsType<LoDTensor>(),
+                   "The Output(%s) must be LoDTensor.", out_name);
+    Output<LoDTensor>(out_name)->set_lod(Input<LoDTensor>(in_name)->lod());
+  }
+
  private:
   const OperatorBase& op_;
   const Scope& scope_;
@@ -363,6 +384,27 @@ class ExecutionContext : public InferShapeContext {
     return device_context_;
   }
 
+  template <typename T>
+  T* Output(const std::string& name) const override {
+    auto var = OutputVar(name);
+    // Different from InferShapeContext, call Get instead of GetMutable.
+    return var == nullptr ? nullptr : var->Get<T>();
+  }
+
+  template <typename T>
+  std::vector<T*> MultiOutput(const std::string& name) const override {
+    auto names = op_.Outputs(name);
+    std::vector<T*> res;
+    res.reserve(names.size());
+    // Different from InferShapeContext, call Get instead of GetMutable.
+    std::transform(names.begin(), names.end(), std::back_inserter(res),
+                   [&](const std::string& sub_name) {
+                     auto var = scope_.FindVar(sub_name);
+                     return var == nullptr ? nullptr : var->Get<T>();
+                   });
+    return res;
+  }
+
   const platform::DeviceContext* device_context_;
 };
 
diff --git a/paddle/framework/variable.h b/paddle/framework/variable.h
index 38fc2720a30230..0904e8067e136a 100644
--- a/paddle/framework/variable.h
+++ b/paddle/framework/variable.h
@@ -29,6 +29,21 @@ class Variable {
     return *static_cast<const T*>(holder_->Ptr());
   }
 
+  // Specialized template for Tensor, so that a Tensor can also be
+  // fetched from a LoDTensor. The logic for getting a Tensor out of a
+  // LoDTensor could also be put in InferShapeContext instead.
+  template <>
+  const Tensor& Get<Tensor>() const {
+    if (IsType<LoDTensor>()) {
+      auto lod_t = static_cast<LoDTensor*>(holder_->Ptr());
+      return lod_t->tensor();
+    } else {
+      PADDLE_ENFORCE(IsType<Tensor>(),
+                     "Variable must be type LoDTensor or Tensor");
+      return *static_cast<const Tensor*>(holder_->Ptr());
+    }
+  }
+
   template <typename T>
   T* GetMutable() {
     if (!IsType<T>()) {
diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc
index 710a56a0e8e2d1..2cda41fafd5399 100644
--- a/paddle/operators/mul_op.cc
+++ b/paddle/operators/mul_op.cc
@@ -45,7 +45,13 @@ class MulOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE_EQ(
         x_mat_dims[1], y_mat_dims[0],
         "First matrix's width must be equal with second matrix's height.");
-    ctx.Output<Tensor>("Out")->Resize({x_mat_dims[0], y_mat_dims[1]});
+    // Each operator's InferShape must call Output<LoDTensor> to create
+    // the LoDTensor in the output variable.
+    ctx.Output<framework::LoDTensor>("Out")->tensor().Resize(
+        {x_mat_dims[0], y_mat_dims[1]});
+    // Only the forward operator needs to pass the LoD;
+    // pass the LoD of Input(X) to the output.
+    ctx.CopyLoD(/*in = */ "X", /* out = */ "Out");
   }
 };
 
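For reference, below is a minimal sketch (not part of the diff) of how another forward operator's InferShape could be written against the proposed interfaces. The operator name "SoftmaxOp" and the single-input shape logic are only illustrative; the sketch assumes the CopyLoD, Output<LoDTensor>, and specialized Variable::Get<Tensor> shown above.

// Illustrative only: a hypothetical forward operator updated to the
// proposed interfaces; it is not part of the diff above.
class SoftmaxOp : public framework::OperatorWithKernel {
 protected:
  void InferShape(const framework::InferShapeContext &ctx) const override {
    // Input<Tensor> keeps working for InferShape and kernels, because the
    // specialized Variable::Get<Tensor> unwraps a LoDTensor when needed.
    auto dims = ctx.Input<framework::Tensor>("X")->dims();
    // Output<LoDTensor> creates the LoDTensor in the output variable;
    // the underlying Tensor is resized as before.
    ctx.Output<framework::LoDTensor>("Out")->tensor().Resize(dims);
    // Forward operators propagate the LoD from input to output.
    ctx.CopyLoD(/*in = */ "X", /* out = */ "Out");
  }
};

The corresponding backward operator would simply omit the CopyLoD call, since only the forward pass needs to propagate the LoD.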