Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactorize framework/*.proto #3322

Merged
merged 45 commits into from
Aug 14, 2017

Conversation

wangkuiyi
Copy link
Collaborator

@wangkuiyi wangkuiyi commented Aug 8, 2017

Fixes #3320

  1. Merge attribute.proto, op_desc.proto, op_proto.proto into framework.proto.
  2. Move message VarProto into OpProto::Var.
  3. Add message OpDesc::Var to change the way to representing "multiple" (which should be "duplicable") attributes.
  4. Remove op_desc_test.cc and op_proto_test.cc, which doesn't include effective test code.


message Var {
required string name; // e.g. "X"
optional int dup = 2 [ default = 0 ]; // e.g., "1"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add comment here ?

// indices the duplica
optional int dup = 2 [ default = 0 ];

Copy link
Collaborator

@reyoung reyoung Aug 8, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

message Var {
  required string op_proto_name;
  repeated string name;
};

required string comment = 2;
// OpDesc::Var::dup indices the duplica.
optional bool duplicable = 3 [ default = false ];
optional bool intermediate = 4 [ default = false ];
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is intermediate used to replace temporary?

@jacquesqiao
Copy link
Member

The source code should be changed a lot after this PR and we also need a design_doc to describe the design of OpDesc and OpProto.

reyoung and others added 20 commits August 8, 2017 14:55
* Although backward_test/rnn_test is not pass, just comment them.
Update grad_op_builder after refactoring framework proto.
…ork_proto

Modify rnn op unit test after refactoring framework proto.
#include "paddle/framework/op_registry.h"

namespace paddle {
namespace framework {

typedef std::vector<int> Ints;
class OpRegistry;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

std::vector<int>* dst_format = AttrFormat(grad_attrs, dst_type);

const OpProto& proto = OpRegistry::protos().at(src_op->type_);
static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since src_op is not nullable, it should be const OperatorBase&.

But it is not related to this PR. Please issue another PR when this PR has been merged.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


VariableBuilder& SetMultiple() {
var_->set_multiple(true);
on_multiple_();
var_->set_duplicable(true);
return *this;
}

VariableBuilder& SetTemporary() {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method name should be renamed to SetIntermediate

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

PADDLE_ENFORCE(bool(on_temporary_), "Cannot set temporary");
var_->set_temporary(true);
on_temporary_();
var_->set_intermediate(true);
return *this;
}

VariableBuilder& IgnoreGradient() {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rename to NoGradient.

However, I do not think NoGradient is a straight name. Maybe we should change it to another name in following PR @Canpio @qingqing01

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also think the meaning of IgnoreGradient is not clear. But NoNeedInGradientOp is too long.

};

class OpRegistry {
using OpCreator = std::function<OperatorBase*()>;
using VarIndexMap = std::unordered_map<std::string, int>;
using VarNameList = std::vector<std::string>;
using VarNameMap = std::map<std::string, std::vector<std::string>>;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Put this typedef to OperatorBase is a better idea. Because it is used by OperatorBase::inputs_/outputs_

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

inputs.reserve((size_t)op_desc.inputs_size());
std::copy(op_desc.inputs().begin(), op_desc.inputs().end(),
std::back_inserter(inputs));
VarNameMap inputs;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy and Paste code here. This logic should be extracted as a common function.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

@@ -61,8 +59,13 @@ REGISTER_OP(my_test_op, paddle::framework::MyTestOp,
TEST(OpRegistry, CreateOp) {
paddle::framework::OpDesc op_desc;
op_desc.set_type("cos_sim");
op_desc.add_inputs("aa");
op_desc.add_outputs("bb");
auto input = op_desc.add_inputs();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same Code Again & Again.

Maybe we can extract them into a function.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

"Input Out Of Range");
const std::string& OperatorBase::Input(const std::string& name) const {
auto it = inputs_.find(name);
PADDLE_ENFORCE(it != inputs_.end(), "Op %s does not have input %s", type_,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PADDLE_ENFORCE_NE

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It cannot use PADDLE_ENFORCE_NE because operator << of it is not defined.

inputs_.begin() + input_format.at(offset),
inputs_.begin() + input_format.at(offset + 1)};
const std::vector<std::string>& OperatorBase::Inputs(
const std::string& name) const {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PADDLE_ENFORCE when not found name. Give a reasonable error.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

auto it = in_out_idxs_->find(name);
PADDLE_ENFORCE(it != in_out_idxs_->end(), "no key [%s] in in_out_idxs_",
auto it = outputs_.find(name);
PADDLE_ENFORCE(it != outputs_.end(), "Op %s does not have output %s", type_,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PADDLE_ENFORCE_NE

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It cannot use PADDLE_ENFORCE_NE because operator << of it is not defined.

outputs_.begin() + output_format.at(offset),
outputs_.begin() + output_format.at(offset + 1)};
const std::vector<std::string>& OperatorBase::Outputs(
const std::string& name) const {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PADDLE_ENFORCE_NOT_FOUND

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

std::vector<std::string> Outputs(const std::string& name) const;
const std::vector<std::string>& Outputs(const std::string& name) const;

virtual std::vector<std::string> OutputVars(bool has_intermediate) const {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not in the header.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.


const std::string Type() const { return type_; }
const std::vector<std::string> Inputs() const { return inputs_; }
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bad design in develop. Return a const T is not useful in C++. It should return T or const T&


public:
std::string type_;
// NOTE: in case of OpGrad, inputs_ contains:
// I (Inputs)
// O (Outputs)
// OG (Output Gradients)
std::vector<std::string> inputs_;
std::map<std::string, std::vector<std::string>> inputs_;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use VarNameMap before

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.


const Variable* InputVar(const size_t index) const {
return scope_.FindVar(op_.inputs_.at(index));
size_t InputSize(const std::string& name) const {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PADDLE_ENFORCE here

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

}

Variable* OutputVar(const size_t index) const {
return scope_.FindVar(op_.outputs_.at(index));
size_t OutputSize(const std::string& name) const {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PADDLE_ENFORCE

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

@@ -62,8 +62,13 @@ REGISTER_OP(test_operator, paddle::framework::OpWithoutKernelTest,
TEST(OperatorBase, all) {
paddle::framework::OpDesc op_desc;
op_desc.set_type("test_operator");
*op_desc.mutable_inputs()->Add() = "IN1";
*op_desc.mutable_outputs()->Add() = "OUT1";
auto* ipt = op_desc.mutable_inputs()->Add();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same code again & again

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

rnn_op_->Run(scope_, ctx);
}
using namespace paddle::framework;
// using framework::make_ddim;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Either Remove or uncomment this lines.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

expected = '''Op(add_two_grad), inputs:(X, Y, Out, Out@GRAD), outputs:(X@GRAD, Y@GRAD).'''
self.assertEqual(expected, str(backward_op))

#class TestAddGradOp(unittest.TestCase):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This unit test should be uncomment

@@ -117,7 +129,7 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
insert_position.push_back(
{dup_op.back(),
OpRegistry::CreateOp(
"add", {dup_outputs}, {name},
"add", {{"X", {dup_outputs}}}, {{"Out", {name}}},
{{"input_format",
std::vector<int>{0, static_cast<int>(dup_outputs.size())}}})});
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The input_format is out of use now. It can be removed.

if (b_name != kEmptyVarName) {
AddOp(OpRegistry::CreateOp("rowwise_add", {Output("mul_result"), b_name},
{Output("add_result")}, {}));
if (input_b.size() != 0) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if (InputSize("b") != 0) {

{Output("add_result")}, {}));
if (input_b.size() != 0) {
AddOp(OpRegistry::CreateOp(
"rowwise_add", {{"X", {Output("mul_result")}}, {"b", {input_b[0]}}},
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Input("b") and remove line 87.

PADDLE_ENFORCE(bool(on_temporary_), "Cannot set temporary");
var_->set_temporary(true);
on_temporary_();
var_->set_intermediate(true);
return *this;
}

VariableBuilder& IgnoreGradient() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also think the meaning of IgnoreGradient is not clear. But NoNeedInGradientOp is too long.

@@ -295,11 +217,13 @@ class OpRegistry {

static void GenerateTempVariableName(OperatorBase* op) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we replace temp by intermediate? since we use intermediate in proto.

@reyoung
Copy link
Collaborator

reyoung commented Aug 14, 2017

@qingqing01 I think we should create several issues that this PR discovered but has not been fixed.

Copy link
Collaborator

@reyoung reyoung left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Merge it to develop soon, not block others implementation.

@reyoung reyoung merged commit 81f5f86 into PaddlePaddle:develop Aug 14, 2017
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants