Feature/backward #3068
Conversation
Remove the bool argument; use a class to handle that.
BackwardImpl => BackwardRecursive, following the two-function pattern in https://stackoverflow.com/questions/9243646/recursive-functions-two-functions-or-last-optional-parameter/9243666#9243666
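(For context: the linked answer recommends splitting a recursive function into a public entry point plus a private recursive worker, instead of threading a bool or defaulted parameter through the recursion. A minimal sketch of that pattern, with names mirroring this PR but simplified signatures that are assumptions, not the PR's exact API:

// Private recursive worker: carries the bookkeeping state through the
// recursion, so the public signature stays clean.
static std::shared_ptr<OperatorBase> BackwardRecursive(
    const OperatorBase& forward_op,
    std::unordered_set<std::string>& no_grad_names,
    size_t& uniq_id);

// Public entry point: fixed, user-facing signature.
std::shared_ptr<OperatorBase> Backward(
    const OperatorBase& forward_op,
    const std::unordered_set<std::string>& no_grad_vars) {
  std::unordered_set<std::string> no_grad_names = no_grad_vars;  // mutable copy
  size_t uniq_id = 0;  // internal state the caller never sees
  return BackwardRecursive(forward_op, no_grad_names, uniq_id);
}
)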
class FillZerosLikeKernel : public framework::OpKernel {
 public:
  void Compute(const framework::KernelContext& context) const override {
    auto* output = context.Output(0)->GetMutable<framework::Tensor>();
auto output
#include "paddle/operators/fill_zeros_like_op.h"

REGISTER_OP_GPU_KERNEL(
    fill_zeros_like,
Why name this fill_zeros_like? numpy calls it zeros_like.
zeros_like is not a verb, but every operator name should be a verb, so I think fill_zeros_like is more explicit. Also, zeros_like could be implemented without filling anything into memory, using only a flag; the fill prefix indicates that new memory is actually allocated and written.
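(To illustrate the distinction, a toy sketch in plain C++ rather than the Paddle kernel; names here are illustrative only:

#include <cstddef>
#include <vector>

// "fill" semantics: allocate new memory and explicitly write zeros into it.
std::vector<float> FillZerosLike(const std::vector<float>& src) {
  return std::vector<float>(src.size(), 0.0f);  // fresh buffer, filled
}

// A lazy "zeros_like" could avoid touching memory entirely, keeping only a
// flag and materializing real zeros on first use.
struct LazyZeros {
  std::size_t size;
  bool all_zero = true;  // nothing allocated or written yet
};
)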
template <typename Place, typename T>
class FillZerosLikeKernel : public framework::OpKernel {
 public:
Please add a Python unit test for this operator.
We will add the unit test in another PR. @Canpio @dzhwinter
@reyoung Got it.
paddle/framework/operator.cc (Outdated)
auto output_format = GetAttr<std::vector<int>>("output_format");
auto offset = in_out_idxs_->at(name);

PADDLE_ENFORCE(output_format.at((size_t)offset + 1) <= (int)outputs_.size(),
(int) ==> static_cast
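(For reference, the quoted line rewritten with C++-style casts as suggested; the error message is a placeholder, not the PR's actual text:

PADDLE_ENFORCE(
    output_format.at(static_cast<size_t>(offset) + 1) <=
        static_cast<int>(outputs_.size()),
    "output offset out of range");  // illustrative message only
)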
Done.
op1->inputs_ = {"x", "w1", "b1"};
op1->outputs_ = {"y"};
net.AddOp(op1);
net.InsertOp(0, op1);
Please check that the inserted op ends up in the right place.
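(A minimal sketch of such a check, assuming the net keeps its operators in an ordered ops_ vector as the PlainNet in this PR does; op2 is a hypothetical second operator:

net.InsertOp(0, op1);
ASSERT_EQ(op1, net.ops_[0]);  // the op inserted at index 0 really is first
net.InsertOp(1, op2);
ASSERT_EQ(op2, net.ops_[1]);  // later inserts land at the requested index
)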
paddle/framework/backward.cc (Outdated)
 *
 * See Backward.h for details
 */
static std::shared_ptr<OperatorBase> BackwardImpl(
Is MakeBackwardOp or GetBackwardOp a better name?
Done. BackwardRecursive.
paddle/framework/backward.cc (Outdated)
}

/**
 * All output gradients of forwarding operator do not need to calculate. Then
Add "If": the comment should read "If all output gradients ...".
paddle/framework/backward.cc (Outdated)
OpRegistry::CreateOp(
    "add", {dup_outputs}, {name},
    {{"input_format",
      std::vector<int>{0, (int)dup_outputs.size()}}})});
static_cast
Done.
paddle/framework/backward.cc (Outdated)
 *
 * See Backward.h for details
 */
static std::shared_ptr<OperatorBase> BackwardImpl(
This function is not in a class; why does it need static?
paddle/framework/backward.cc (Outdated)
//! Map from output gradient variable name to operator's indices in backward
//! net. That operator generates that variable.
std::unordered_map<std::string, std::vector<size_t>> dup_output_ops;
What does dup mean?
duplicated
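(For context, a small self-contained sketch of how such a duplicate map gets built: every gradient output name maps to the indices of the backward ops that produce it, and any name with more than one producer later gets an "add" op to sum the duplicates. Names and data here are illustrative only:

#include <cstddef>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
  // Output names emitted by the backward net's ops, in execution order;
  // ops #0 and #2 both write x@GRAD.
  std::vector<std::vector<std::string>> op_outputs = {
      {"x@GRAD"}, {"w@GRAD"}, {"x@GRAD"}};

  // Map each gradient name to the indices of the ops producing it.
  std::unordered_map<std::string, std::vector<std::size_t>> dup_output_ops;
  for (std::size_t i = 0; i < op_outputs.size(); ++i)
    for (const auto& name : op_outputs[i]) dup_output_ops[name].push_back(i);

  // Now dup_output_ops["x@GRAD"] == {0, 2}: those two results must be summed
  // by an inserted "add" op before anything reads x@GRAD.
  return 0;
}
)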
namespace paddle {
namespace framework {

static bool AllInSet(const std::vector<std::string>& names,
Why do we need to use static?
Because we do not want to export them as global symbols.
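(That is, static on a free function gives it internal linkage, so the symbol stays private to its translation unit. A small sketch; the helper body is illustrative, not the PR's exact implementation:

#include <string>
#include <unordered_set>
#include <vector>

// Internal linkage: AllInSet is invisible outside backward.cc, so it cannot
// collide with a same-named function in another translation unit.
static bool AllInSet(const std::vector<std::string>& names,
                     const std::unordered_set<std::string>& set) {
  for (const auto& name : names) {
    if (set.find(name) == set.end()) return false;
  }
  return true;
}

// An anonymous namespace is the more idiomatic modern spelling of the same
// internal linkage:
//   namespace { bool AllInSet(...); }
)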
insert_position.push_back(
    {dup_op.back(),
     OpRegistry::CreateOp(
         "add", {dup_outputs}, {name},
The add op hasn't been implemented yet, has it?
Correct. That op's implementation does not affect the implementation or the unit tests of the Backward algorithm.
paddle/framework/backward_test.cc (Outdated)
ASSERT_EQ("X" + f::OperatorBase::GRAD_VAR_SUFFIX(),
          gop->Output("X" + f::OperatorBase::GRAD_VAR_SUFFIX()));
// LOG(INFO) << gop->Output("X" + "@GRAD");
Remove the unused code.
Done.
auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
ASSERT_NE(fwd, nullptr);
auto gop = f::OpRegistry::CreateGradOp(*fwd);
ASSERT_EQ(1UL, gop->inputs_.size());
I remember we said a GradientOp's inputs are I, O, and OG. Why is gop's input size 1 here?
Because some ops' gradients can be computed without using I and O; that is declared when the op is registered.
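(As a concrete example of why one input is enough: for rowwise_add, Out = X + b, so dX = dOut and db is the column-wise sum of dOut; the gradient needs only Out@GRAD, not X, b, or Out. A plain C++ sketch of that math, illustrative rather than Paddle code:

#include <cstddef>
#include <vector>

// Gradient of Out = X + b, computed from d_out alone.
void RowwiseAddGrad(const std::vector<std::vector<float>>& d_out,
                    std::vector<std::vector<float>>* d_x,
                    std::vector<float>* d_b) {
  *d_x = d_out;  // dX = dOut element-wise
  d_b->assign(d_out.empty() ? 0 : d_out[0].size(), 0.0f);
  for (const auto& row : d_out)
    for (std::size_t j = 0; j < row.size(); ++j) (*d_b)[j] += row[j];  // db = colwise sum
}
)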
gop->outputs_.end());

auto no_input_gop = f::Backward(*fwd, {"X", "b"});
ASSERT_NE(no_input_gop, nullptr);
It seems this doesn't actually test anything?
Done.
@@ -23,8 +23,8 @@ namespace operators {
 template <typename Place, typename T>
 class FillZerosLikeKernel : public framework::OpKernel {
  public:
-  void Compute(const framework::KernelContext& context) const override {
-    auto* output = context.Output(0)->GetMutable<framework::Tensor>();
+  void Compute(const framework::ExecutionContext& context) const override {
framework::xxx ===> xxx
This will be changed in the next PR.
There are also some outdated comments in the code; please update them along the way.
LGTM! Except for some typos in the comments and docs.
Force-pushed from 47dbbdb to 737ea05.
@dzhwinter @Canpio @wangkuiyi I am preparing to merge this PR soon; all code-style comments have been addressed. It still lacks documentation and comments, but it is already a huge PR, so documentation and comments will be added in a following PR.