-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactorize framework/*.proto #3322
Refactorize framework/*.proto #3322
Conversation
paddle/framework/framework.proto
Outdated
|
||
message Var { | ||
required string name; // e.g. "X" | ||
optional int dup = 2 [ default = 0 ]; // e.g., "1" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add comment here ?
// indices the duplica
optional int dup = 2 [ default = 0 ];
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
message Var {
required string op_proto_name;
repeated string name;
};
paddle/framework/framework.proto
Outdated
required string comment = 2; | ||
// OpDesc::Var::dup indices the duplica. | ||
optional bool duplicable = 3 [ default = false ]; | ||
optional bool intermediate = 4 [ default = false ]; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is intermediate
used to replace temporary
?
The source code should be changed a lot after this PR and we also need a design_doc to describe the design of OpDesc and OpProto. |
* Although backward_test/rnn_test is not pass, just comment them.
Step 1: Make code compile well.
…orize_framework_proto
…orize_framework_proto
…orize_framework_proto
Update grad_op_builder after refactoring framework proto.
…ork_proto Modify rnn op unit test after refactoring framework proto.
Feature/refactorize framework proto
…orize_framework_proto
Catch-up with develop branch
paddle/framework/grad_op_builder.cc
Outdated
#include "paddle/framework/op_registry.h" | ||
|
||
namespace paddle { | ||
namespace framework { | ||
|
||
typedef std::vector<int> Ints; | ||
class OpRegistry; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do not use forward declare in C++;
https://google.github.io/styleguide/cppguide.html#Forward_Declarations
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
std::vector<int>* dst_format = AttrFormat(grad_attrs, dst_type); | ||
|
||
const OpProto& proto = OpRegistry::protos().at(src_op->type_); | ||
static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since src_op
is not nullable, it should be const OperatorBase&
.
But it is not related to this PR. Please issue another PR when this PR has been merged.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
paddle/framework/op_registry.h
Outdated
|
||
VariableBuilder& SetMultiple() { | ||
var_->set_multiple(true); | ||
on_multiple_(); | ||
var_->set_duplicable(true); | ||
return *this; | ||
} | ||
|
||
VariableBuilder& SetTemporary() { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This method name should be renamed to SetIntermediate
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
paddle/framework/op_registry.h
Outdated
PADDLE_ENFORCE(bool(on_temporary_), "Cannot set temporary"); | ||
var_->set_temporary(true); | ||
on_temporary_(); | ||
var_->set_intermediate(true); | ||
return *this; | ||
} | ||
|
||
VariableBuilder& IgnoreGradient() { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Rename to NoGradient
.
However, I do not think NoGradient
is a straight name. Maybe we should change it to another name in following PR @Canpio @qingqing01
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I also think the meaning of IgnoreGradient
is not clear. But NoNeedInGradientOp
is too long.
paddle/framework/op_registry.h
Outdated
}; | ||
|
||
class OpRegistry { | ||
using OpCreator = std::function<OperatorBase*()>; | ||
using VarIndexMap = std::unordered_map<std::string, int>; | ||
using VarNameList = std::vector<std::string>; | ||
using VarNameMap = std::map<std::string, std::vector<std::string>>; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Put this typedef to OperatorBase
is a better idea. Because it is used by OperatorBase::inputs_/outputs_
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
paddle/framework/op_registry.h
Outdated
inputs.reserve((size_t)op_desc.inputs_size()); | ||
std::copy(op_desc.inputs().begin(), op_desc.inputs().end(), | ||
std::back_inserter(inputs)); | ||
VarNameMap inputs; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Copy and Paste code here. This logic should be extracted as a common function.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
paddle/framework/op_registry_test.cc
Outdated
@@ -61,8 +59,13 @@ REGISTER_OP(my_test_op, paddle::framework::MyTestOp, | |||
TEST(OpRegistry, CreateOp) { | |||
paddle::framework::OpDesc op_desc; | |||
op_desc.set_type("cos_sim"); | |||
op_desc.add_inputs("aa"); | |||
op_desc.add_outputs("bb"); | |||
auto input = op_desc.add_inputs(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same Code Again & Again.
Maybe we can extract them into a function.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
paddle/framework/operator.cc
Outdated
"Input Out Of Range"); | ||
const std::string& OperatorBase::Input(const std::string& name) const { | ||
auto it = inputs_.find(name); | ||
PADDLE_ENFORCE(it != inputs_.end(), "Op %s does not have input %s", type_, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PADDLE_ENFORCE_NE
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It cannot use PADDLE_ENFORCE_NE
because operator <<
of it is not defined.
paddle/framework/operator.cc
Outdated
inputs_.begin() + input_format.at(offset), | ||
inputs_.begin() + input_format.at(offset + 1)}; | ||
const std::vector<std::string>& OperatorBase::Inputs( | ||
const std::string& name) const { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PADDLE_ENFORCE when not found name. Give a reasonable error.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
paddle/framework/operator.cc
Outdated
auto it = in_out_idxs_->find(name); | ||
PADDLE_ENFORCE(it != in_out_idxs_->end(), "no key [%s] in in_out_idxs_", | ||
auto it = outputs_.find(name); | ||
PADDLE_ENFORCE(it != outputs_.end(), "Op %s does not have output %s", type_, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PADDLE_ENFORCE_NE
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It cannot use PADDLE_ENFORCE_NE because operator << of it is not defined.
outputs_.begin() + output_format.at(offset), | ||
outputs_.begin() + output_format.at(offset + 1)}; | ||
const std::vector<std::string>& OperatorBase::Outputs( | ||
const std::string& name) const { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PADDLE_ENFORCE_NOT_FOUND
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
paddle/framework/operator.h
Outdated
std::vector<std::string> Outputs(const std::string& name) const; | ||
const std::vector<std::string>& Outputs(const std::string& name) const; | ||
|
||
virtual std::vector<std::string> OutputVars(bool has_intermediate) const { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not in the header.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
paddle/framework/operator.h
Outdated
|
||
const std::string Type() const { return type_; } | ||
const std::vector<std::string> Inputs() const { return inputs_; } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Bad design in develop. Return a const T
is not useful in C++. It should return T
or const T&
paddle/framework/operator.h
Outdated
|
||
public: | ||
std::string type_; | ||
// NOTE: in case of OpGrad, inputs_ contains: | ||
// I (Inputs) | ||
// O (Outputs) | ||
// OG (Output Gradients) | ||
std::vector<std::string> inputs_; | ||
std::map<std::string, std::vector<std::string>> inputs_; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use VarNameMap
before
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
|
||
const Variable* InputVar(const size_t index) const { | ||
return scope_.FindVar(op_.inputs_.at(index)); | ||
size_t InputSize(const std::string& name) const { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PADDLE_ENFORCE here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
} | ||
|
||
Variable* OutputVar(const size_t index) const { | ||
return scope_.FindVar(op_.outputs_.at(index)); | ||
size_t OutputSize(const std::string& name) const { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PADDLE_ENFORCE
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
paddle/framework/operator_test.cc
Outdated
@@ -62,8 +62,13 @@ REGISTER_OP(test_operator, paddle::framework::OpWithoutKernelTest, | |||
TEST(OperatorBase, all) { | |||
paddle::framework::OpDesc op_desc; | |||
op_desc.set_type("test_operator"); | |||
*op_desc.mutable_inputs()->Add() = "IN1"; | |||
*op_desc.mutable_outputs()->Add() = "OUT1"; | |||
auto* ipt = op_desc.mutable_inputs()->Add(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same code again & again
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
rnn_op_->Run(scope_, ctx); | ||
} | ||
using namespace paddle::framework; | ||
// using framework::make_ddim; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Either Remove or uncomment this lines.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
expected = '''Op(add_two_grad), inputs:(X, Y, Out, Out@GRAD), outputs:(X@GRAD, Y@GRAD).''' | ||
self.assertEqual(expected, str(backward_op)) | ||
|
||
#class TestAddGradOp(unittest.TestCase): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This unit test should be uncomment
@@ -117,7 +129,7 @@ std::shared_ptr<OperatorBase> BackwardRecursive( | |||
insert_position.push_back( | |||
{dup_op.back(), | |||
OpRegistry::CreateOp( | |||
"add", {dup_outputs}, {name}, | |||
"add", {{"X", {dup_outputs}}}, {{"Out", {name}}}, | |||
{{"input_format", | |||
std::vector<int>{0, static_cast<int>(dup_outputs.size())}}})}); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The input_format is out of use now. It can be removed.
if (b_name != kEmptyVarName) { | ||
AddOp(OpRegistry::CreateOp("rowwise_add", {Output("mul_result"), b_name}, | ||
{Output("add_result")}, {})); | ||
if (input_b.size() != 0) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if (InputSize("b") != 0) {
{Output("add_result")}, {})); | ||
if (input_b.size() != 0) { | ||
AddOp(OpRegistry::CreateOp( | ||
"rowwise_add", {{"X", {Output("mul_result")}}, {"b", {input_b[0]}}}, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Input("b") and remove line 87.
paddle/framework/op_registry.h
Outdated
PADDLE_ENFORCE(bool(on_temporary_), "Cannot set temporary"); | ||
var_->set_temporary(true); | ||
on_temporary_(); | ||
var_->set_intermediate(true); | ||
return *this; | ||
} | ||
|
||
VariableBuilder& IgnoreGradient() { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I also think the meaning of IgnoreGradient
is not clear. But NoNeedInGradientOp
is too long.
@@ -295,11 +217,13 @@ class OpRegistry { | |||
|
|||
static void GenerateTempVariableName(OperatorBase* op) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we replace temp by intermediate? since we use intermediate in proto.
@qingqing01 I think we should create several issues that this PR discovered but has not been fixed. |
…kuiyi/Paddle into refactorize_framework_proto
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Merge it to develop soon, not block others implementation.
Fixes #3320
attribute.proto
,op_desc.proto
,op_proto.proto
intoframework.proto
.VarProto
intoOpProto::Var
.OpDesc::Var
to change the way to representing "multiple" (which should be "duplicable") attributes.op_desc_test.cc
andop_proto_test.cc
, which doesn't include effective test code.