Prelu with forward, backward and python test passed #4121
Conversation
paddle/operators/prelu_op.cc
Outdated
namespace paddle {
namespace operators {

class PreluOp : public framework::OperatorWithKernel {
PreluOp -> PReluOp
paddle/operators/prelu_op.cc
Outdated
 protected:
  void InferShape(const framework::InferShapeContext &ctx) const override {
    auto *in = ctx.Input<framework::Tensor>("X");
    auto *out = ctx.Output<framework::LoDTensor>("Out");
Add a not-null check for the input, like: https://github.com/PaddlePaddle/Paddle/pull/4086/files#diff-1fcd5ee1c1e63ed40789a0e60fdb1bf6R29 . The more checks, the more readable the error messages.
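The value of such checks can be sketched outside Paddle with a hypothetical `EnforceNotNull` helper (the names here are illustrative, not Paddle's API): a null input fails early with a message naming the missing variable, instead of crashing later on a dereference.

```cpp
#include <stdexcept>
#include <string>

// Hypothetical stand-in for PADDLE_ENFORCE_NOT_NULL: fail early with a
// message naming the missing variable instead of dereferencing null later.
template <typename T>
void EnforceNotNull(const T* ptr, const std::string& msg) {
  if (ptr == nullptr) throw std::runtime_error(msg);
}

// Returns the error message raised for a null pointer, or "" if it is valid.
template <typename T>
std::string NullCheckMessage(const T* ptr, const std::string& name) {
  try {
    EnforceNotNull(ptr, "Input(" + name + ") should not be null");
  } catch (const std::runtime_error& e) {
    return e.what();
  }
  return "";
}
```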
paddle/operators/prelu_op.cc
Outdated

The equation is:
f(x) = alpha * x , for x < 0
f(x) = x , for x >= 0
Add a space before and after the formula, and indent the formula, since the doc may be converted to Markdown.
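For example, the docstring could be formatted like this (indented so a Markdown renderer treats the formula as a code block; this is a sketch of the suggestion, not the final wording):

```
PRelu activation.

The equation is:

    f(x) = alpha * x ,  for x < 0
    f(x) = x ,          for x >= 0
```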
paddle/operators/prelu_op.cc
Outdated
f(x) = x , for x >= 0
)DOC");
    AddAttr<float>("alpha", "The scaling factor alpha of prelu.")
        .SetDefault(0.0);
Put the Attr before the doc. Use a type template for the attr alpha; please refer to scale_op.
paddle/operators/prelu_op.cc
Outdated
};

// The operator to calculate gradients of a prelu operator.
class PreluGradOp : public framework::OperatorWithKernel {
PreluGradOp -> PReluGradOp
paddle/operators/prelu_op.h
Outdated
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;

template <typename Place, typename T>
class PreluKernel : public framework::OpKernel {
PreluKernel -> PReluKernel
paddle/operators/prelu_op.h
Outdated

    // auto place = context.GetEigenDevice<Place>();
    // Out_vec.device(place)
    Out_vec = X_vec.cwiseMax(0.f) + X_vec.cwiseMin(0.f) * alpha;
Need to use Eigen device to support GPU and CPU at the same time.
class PreluTest(OpTest):
    def setUp(self):
        self.op_type = "prelu"
        self.inputs = {'X': np.random.normal(size=(3, 5)).astype("float32")}
size=(3, 5) may be too small.
paddle/operators/prelu_op.h
Outdated
        dX->data<T>()[i] = dO->data<T>()[i];
      } else {
        dX->data<T>()[i] = dO->data<T>()[i] * alpha;
      }
paddle::platform::Transform plus a functor can support CPU and GPU at the same time.
Transform: https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/platform/transform.h
For Transform usage examples, see the unit test: https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/platform/transform_test.cu
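As a rough sketch of that suggestion, with std::transform standing in for paddle::platform::Transform (which additionally dispatches to thrust::transform on GPU), and with PReluFunctor written in the style the review asks for:

```cpp
#include <algorithm>
#include <vector>

// Illustrative functor; platform::Transform would apply it element-wise
// on either CPU (std::transform) or GPU (thrust::transform).
template <typename T>
class PReluFunctor {
 public:
  explicit PReluFunctor(const T& alpha) : alpha_(alpha) {}
  T operator()(const T& x) const { return x > 0 ? x : x * alpha_; }

 private:
  T alpha_;
};

// CPU-only sketch of the forward pass using the functor.
template <typename T>
std::vector<T> PReluForward(const std::vector<T>& x, T alpha) {
  std::vector<T> out(x.size());
  std::transform(x.begin(), x.end(), out.begin(), PReluFunctor<T>(alpha));
  return out;
}
```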

 protected:
  void InferShape(const framework::InferShapeContext &ctx) const override {
    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
Sorry for the incomplete comments last time, the not-null check for output is also needed:
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), "Output(X) should not be null");
I think the message "Output(X) should not be null" is unnecessary; PADDLE_ENFORCE_NOT_NULL is semantically sufficient on its own.
paddle/operators/prelu_op.cc
Outdated

)DOC");
    AddAttr<AttrType>("alpha", "The scaling factor alpha of prelu.")
        .SetDefault(0.0);
Put the AddComment(R"DOC )DOC") last.
paddle/operators/prelu_op.cc
Outdated
f(x) = x , for x >= 0

)DOC");
    AddAttr<AttrType>("alpha", "The scaling factor alpha of prelu.")
Different from other activations, the alpha in PRelu is a learnable weight, so it should be an input, not an attr. The gradient of this weight also needs to be computed in the backward op.
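A minimal sketch of that backward computation, assuming a single scalar alpha input (names illustrative): since f(x) = alpha * x for x < 0, we have df/dalpha = x there and 0 for x >= 0, so the alpha gradient accumulates dOut * x over the negative inputs.

```cpp
#include <cstddef>
#include <vector>

// Sketch of the PRelu backward pass with a learnable scalar alpha.
//   dx[i]  = dout[i]          if x[i] >= 0
//          = dout[i] * alpha  otherwise
//   dalpha = sum over x[i] < 0 of dout[i] * x[i]
// Returns dalpha and fills *dx.
template <typename T>
T PReluBackward(const std::vector<T>& x, const std::vector<T>& dout,
                T alpha, std::vector<T>* dx) {
  dx->resize(x.size());
  T dalpha = 0;
  for (std::size_t i = 0; i < x.size(); ++i) {
    (*dx)[i] = x[i] >= 0 ? dout[i] : dout[i] * alpha;
    if (x[i] < 0) dalpha += dout[i] * x[i];
  }
  return dalpha;
}
```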
paddle/operators/prelu_op.h
Outdated
using platform::Transform;

template <typename T>
class Prelu_functor {
Use google C++ Style: https://google.github.io/styleguide/cppguide.html#Type_Names
template <typename T>
class PReluFunctor {
};
paddle/operators/prelu_op.h
Outdated
};

template <typename T>
class Prelu_Grad_functor {
Prelu_Grad_functor -> PReluGradFunctor
paddle/operators/prelu_op.h
Outdated
class PReluGradKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* dX = context.Output<Tensor>(framework::GradVarName("X"));
dX -> dx, see https://google.github.io/styleguide/cppguide.html#Variable_Names : "The names of variables (including function parameters) and data members are all lowercase, with underscores between words."
paddle/operators/prelu_op.h
Outdated

    auto alpha = static_cast<T>(context.Attr<AttrType>("alpha"));

    T* dX_ptr = dX->mutable_data<T>(context.GetPlace());
dX_ptr -> dx_ptr
from op_test import OpTest


class PreluTest(OpTest):
PreluTest -> PReluTest
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out')
If alpha is changed to an input, also add the check for ignoring one of the inputs' gradients, like mul_op: https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/v2/framework/tests/test_mul_op.py#L49
 protected:
  void InferShape(const framework::InferShapeContext &ctx) const override {
    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
    auto *in = ctx.Input<framework::Tensor>("X");
The first time the inputs are fetched, enforce that all the pointers are not null; otherwise this will core dump.
 public:
  PReluOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "The input tensor of prelu operator.");
prelu -> PReLU, or some other name; make sure it reads as a shorthand for the algorithm.
paddle/operators/prelu_op.cc
Outdated
        ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
    auto *X = ctx.Input<framework::Tensor>("X");

    X_grad->Resize(X->dims());
Enforce that X_grad is not null.
paddle/operators/prelu_op.h
Outdated
  HOSTDEVICE T operator()(const T& X) const {
    if (X > 0)
      return X;
    else
Use if () {} else {} with braces, or just return X > 0 ? X : X * alpha_;
paddle/operators/prelu_op.h
Outdated
  explicit Prelu_Grad_functor(const T& alpha) : alpha_(alpha) {}

  HOSTDEVICE T operator()(const T& Out, const T& dOut) const {
    if (Out > 0)
same comment as above
paddle/operators/prelu_op.h
Outdated

    auto alpha = static_cast<T>(context.Attr<AttrType>("alpha"));

    T* dX_ptr = dX->mutable_data<T>(context.GetPlace());
Enforce that dX is not null first.
@zchen0211 @Superjom
I approve this PR, since PReluGradKernel will be updated later. @zchen0211 can update the code based on @Superjom's comments afterwards.
Fix #4167
Will add GPU support when the GPU parts are ready.