Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add get set lod for infershape #7606

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions paddle/framework/op_desc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ class CompileTimeInferShapeContext : public InferShapeContext {
public:
CompileTimeInferShapeContext(const OpDesc &op, const BlockDesc &block);

bool IsCompileTime() const override { return true; }

bool HasInput(const std::string &name) const override;

bool HasOutput(const std::string &name) const override;
Expand All @@ -51,6 +53,14 @@ class CompileTimeInferShapeContext : public InferShapeContext {
const std::vector<std::string> &Outputs(
const std::string &name) const override;

LoD GetLoD(const std::string &name) const override {
PADDLE_THROW("does not support GetLoD in Compile stage");
}

void SetLoD(const std::string &name, const framework::LoD &lod) override {
PADDLE_THROW("does not support SetLoD in Compile stage");
}

void ShareLoD(const std::string &in, const std::string &out, size_t i = 0,
size_t j = 0) const override {
PADDLE_ENFORCE_LT(i, Inputs(in).size());
Expand Down
35 changes: 35 additions & 0 deletions paddle/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,8 @@ class RuntimeInferShapeContext : public InferShapeContext {
RuntimeInferShapeContext(const OperatorBase& op, const Scope& scope)
: op_(op), scope_(scope) {}

bool IsCompileTime() const override { return false; }

bool HasInput(const std::string& name) const override {
auto& ins = Inputs(name);
size_t length = ins.size();
Expand Down Expand Up @@ -388,6 +390,39 @@ class RuntimeInferShapeContext : public InferShapeContext {
return op_.Outputs(name);
}

LoD GetLoD(const std::string& name) const override {
auto names = Inputs(name);
PADDLE_ENFORCE_EQ(names.size(), 1,
"%s not found, "
"GetLoD only support get input lod",
name);
Variable* var = scope_.FindVar(names[0]);
PADDLE_ENFORCE_NOT_NULL(var, "%s not found", name);
if (var->IsType<LoDTensor>()) {
return var->GetMutable<LoDTensor>()->lod();
} else {
PADDLE_THROW("Variable %s type_id %s, expect LoDTensor.", name,
var->Type().name());
}
}

void SetLoD(const std::string& name, const framework::LoD& lod) override {
auto names = Outputs(name);
PADDLE_ENFORCE_EQ(names.size(), 1,
"%s not found, "
"SetLoD only support set output lod",
name);

Variable* var = scope_.FindVar(names[0]);
PADDLE_ENFORCE_NOT_NULL(var, "%s not found", name);
if (var->IsType<LoDTensor>()) {
var->GetMutable<LoDTensor>()->set_lod(lod);
} else {
PADDLE_THROW("Variable %s type_id %s, expect LoDTensor.", name,
var->Type().name());
}
}

void ShareLoD(const std::string& in, const std::string& out, size_t i = 0,
size_t j = 0) const override {
PADDLE_ENFORCE_LT(i, Inputs(in).size());
Expand Down
4 changes: 4 additions & 0 deletions paddle/framework/shape_inference.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@ limitations under the License. */
#include "paddle/framework/attribute.h"
#include "paddle/framework/ddim.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/lod_tensor.h"

namespace paddle {
namespace framework {

class InferShapeContext {
public:
virtual ~InferShapeContext() = default;
virtual bool IsCompileTime() const = 0;
virtual bool HasInput(const std::string &name) const = 0;
virtual bool HasOutput(const std::string &name) const = 0;

Expand All @@ -50,6 +52,8 @@ class InferShapeContext {
virtual const std::vector<std::string> &Outputs(
const std::string &name) const = 0;

virtual LoD GetLoD(const std::string &name) const = 0;
virtual void SetLoD(const std::string &name, const LoD &lod) = 0;
virtual void ShareLoD(const std::string &in, const std::string &out,
size_t i = 0, size_t j = 0) const = 0;

Expand Down