[WIP] Backward #2949
Conversation
```cpp
limitations under the License. */

#include "paddle/framework/fully_connected_op.h"
#include <iostream>
```
```cpp
#include <iostream>
#include "paddle/framework/fully_connected_op.h"
```
We put the forward Op's inputs, outputs, and output gradients into the Grad Op's inputs, and put the forward Op's input gradients into the Grad Op's outputs. So the Grad Op's `in_out_idx`, `input_format`, and `output_format` need to be rebuilt during Op creation.
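A minimal sketch of this mapping, for illustration only (the `@GRAD` suffix, `GradVarName`, and the helper's signature are assumptions, not this PR's exact code):

```cpp
#include <string>
#include <vector>

// Hypothetical naming convention for a variable's gradient.
static std::string GradVarName(const std::string& name) {
  return name + "@GRAD";
}

// Assemble a grad op's inputs/outputs from the forward op's ones.
void AssembleInOutGradient(const std::vector<std::string>& fwd_inputs,
                           const std::vector<std::string>& fwd_outputs,
                           std::vector<std::string>* grad_inputs,
                           std::vector<std::string>* grad_outputs) {
  // Grad op inputs: forward inputs, forward outputs, and output gradients.
  *grad_inputs = fwd_inputs;
  grad_inputs->insert(grad_inputs->end(), fwd_outputs.begin(),
                      fwd_outputs.end());
  for (const auto& out : fwd_outputs) {
    grad_inputs->push_back(GradVarName(out));
  }
  // Grad op outputs: gradients of the forward inputs.
  grad_outputs->clear();
  for (const auto& in : fwd_inputs) {
    grad_outputs->push_back(GradVarName(in));
  }
}
```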
paddle/framework/op_registry.h
Outdated
```cpp
 */
#define REGISTER_GRADIENT_OP(__op_type, __op_class)  \
  STATIC_ASSERT_GLOBAL_NAMESPACE(                    \
      __reg_op__##__op_type,                         \
```
This name could be more distinctive, e.g. `__reg_gradient_op__##__op_type`; otherwise it can easily collide with other REGISTER macros.
done
paddle/framework/op_registry.h
Outdated
```cpp
      __reg_op__##__op_type,                                    \
      "REGISTER_GRADIENT_OP must be in global namespace");      \
  static ::paddle::framework::GradOpRegisterHelper<__op_class>  \
      __op_register_##__op_type##__(#__op_type);                \
```
Likewise, this name could also be more distinctive.
Done.
paddle/framework/op_registry.h
Outdated
```cpp
@@ -274,6 +281,18 @@ class OpRegistry {
    return CreateOp(op_desc.type(), inputs, outputs, attrs);
  }

  static OperatorPtr CreateGradOp(OperatorPtr op) {
    OperatorPtr grad_op(grad_creators().at(op->type_)());
```
If the gradient Op is not found, just returning nullptr would be fine.
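A minimal sketch of that suggestion (assuming the surrounding OpRegistry context from this PR, i.e. `grad_creators()`, `OperatorPtr`, and `op->type_`; the body is illustrative, not the final code):

```cpp
static OperatorPtr CreateGradOp(OperatorPtr op) {
  auto it = grad_creators().find(op->type_);
  if (it == grad_creators().end()) {
    return nullptr;  // no gradient op registered for this forward op type
  }
  return OperatorPtr(it->second());
}
```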
Done.
```cpp
@@ -101,5 +101,7 @@ class PlainNet : public Net {
  }
};

std::shared_ptr<PlainNet> AddBackwardOp(std::shared_ptr<PlainNet> ForwardOps);
```
I think this method might be unnecessary; we can directly invoke OpRegistry::CreateGradOp(...), even when the Op is a net.
That's true. As I recall, AddBackwardOp is an interface for PlainNet that generates the backward ops so they can be used from Python. Removing it directly is fine with me.
@reyoung reminded us that not all inputs/outputs are required for computing an Op's gradient, so we are going to refactor the related functions to support flexible generation of Grad Op inputs/outputs.
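For example, for Out = X + Y the input gradients are just copies of Out's gradient, so neither X nor Y needs to be fed to the grad op. A hypothetical sketch of such filtering (the NeededByGrad table and FilterGradInputs helper are assumptions for illustration, not this PR's design):

```cpp
#include <string>
#include <unordered_set>
#include <vector>

// Hypothetical table of which forward variables a grad op actually consumes.
std::unordered_set<std::string> NeededByGrad(const std::string& op_type) {
  if (op_type == "add") return {};          // add's grads need no forward vars
  if (op_type == "mul") return {"X", "W"};  // mul's grads reuse both inputs
  return {};
}

// Keep only the forward variables the grad op really uses.
std::vector<std::string> FilterGradInputs(
    const std::string& op_type, const std::vector<std::string>& candidates) {
  const auto needed = NeededByGrad(op_type);
  std::vector<std::string> kept;
  for (const auto& name : candidates) {
    if (needed.count(name)) kept.push_back(name);
  }
  return kept;
}
```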
paddle/framework/op_registry.h
Outdated
```cpp
@@ -274,6 +285,24 @@ class OpRegistry {
    return CreateOp(op_desc.type(), inputs, outputs, attrs);
  }

  static OperatorPtr CreateGradOp(OperatorPtr op) {
```
CreateGradientOp?
We name every Grad operator XXXGrad, so we named the creation interface CreateGradOp to match.
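For illustration, registration under this convention might look like the following (the op and class names are assumed, not taken from this PR):

```cpp
// Forward op "fc" gets a gradient op whose class follows the XXXGrad pattern.
REGISTER_GRADIENT_OP(fc, FCGrad);
```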
paddle/framework/op_registry.h
Outdated
```cpp
  static std::unordered_map<std::string, OpCreator>& creators() {
    static std::unordered_map<std::string, OpCreator> creators_;
    return creators_;
  static void AssembleGradInOut(OperatorPtr op, OperatorPtr grad_op) {
```
AssembleInOutGradient?
Done.
```cpp
@@ -63,6 +63,11 @@ class OperatorBase {
  /// but it will be convert to a unique name in scope after OpCreator.
  static std::string TMP_VAR_NAME() { return "@TEMP@"; }
```
Would a static const member be better? Then there's no need for a getter.
A const static data member cannot guarantee initialization order; see the Effective C++ rule about non-local statics.
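A brief sketch of the distinction (per Effective C++, Item 4: the initialization order of non-local statics across translation units is unspecified, while a function-local static is initialized on first use). Names below are illustrative:

```cpp
#include <string>

struct OperatorBase {
  // Risky: if another global's constructor reads this during startup, it may
  // run before this member is initialized (static initialization order fiasco).
  // static const std::string TMP_VAR_NAME;

  // Safe: the local static is constructed the first time the getter is
  // called, so it is always initialized before use.
  static const std::string& TmpVarName() {
    static const std::string name = "@TEMP@";
    return name;
  }
};
```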
Replace `OperatorPtr` with `std::shared_ptr<OperatorBase>`
LGTM
Naive implementation of the backward op. Need @Canpio to add reordering of the input indices.