Add pad op #3765
Conversation
… pad_op Conflicts: paddle/pybind/pybind.cc
paddle/operators/pad_op.h
Outdated
// Eigen::DenseIndex>> X_tensor = EigenTensor<T, 2>::From(*X);
// Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor, Eigen::DenseIndex>>
// Out_tensor = EigenTensor<T, 2>::From(*Out);
EigenTensor<T, dims.size()>::ConstType X_tensor =
The value of dims here must be determined at compile time; it cannot be determined at run time by passing in a variable. So PadKernel needs one extra template parameter.
template<typename Place, typename T, int Dims>
class PadKernel : public framework::OpKernel {
...
EigenTensor<T, Dims>::ConstType x_tensor = EigenTensor<T, Dims>::From(*X);
...
};
You can refer to how TensorFlow implements its pad_op.
Thx.
I think it should work like TensorFlow: only the pad functor needs the template parameter Dims (or call it Rank): https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/pad_op.cc#L143
PadKernel itself does not need the template parameter Dims.
paddle/operators/pad_op.h
Outdated

#pragma once

#include "paddle/operators/math/math_function.h"
remove this line.
Fixed.
paddle/operators/pad_op.cc
Outdated

class MulOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  MulOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
MulOpMaker -> PadOpMaker
Done.
paddle/operators/pad_op.cc
Outdated
void InferShape(const framework::InferShapeContext &ctx) const override {
  auto dim0 = ctx.Input<Tensor>("X")->dims();
  auto dim1 = ctx.Output<Tensor>("Out")->dims();
  auto paddings = GetAttr<std::vector<std::pair<int32, int32>>>("paddings");
Check that paddings.size() == dim0.size().
Done.
op = Operator("pad", paddings=((0, 1), (2, 3)), pad_value=0)
inputs = {'X': np.random.random((16, 16)).astype("float32"), }

self.check_grad(op, inputs, set(["X"]), "Out", max_relative_error=0.5)
Add compare_grad to compare the CPU and GPU results.
Done.
… pad_op Conflicts: paddle/operators/CMakeLists.txt paddle/pybind/CMakeLists.txt
paddle/operators/pad_op.h
Outdated
template <typename Place, typename T, size_t D>
void PadFunction(const framework::ExecutionContext& context) {
  auto pads =
      context.op().GetAttr<std::vector<std::pair<int, int>>>("paddings");
context.op() -> context
Done.
paddle/operators/pad_op.h
Outdated
}
T pad_value = context.op().GetAttr<T>("pad_value");

auto* X = context.Input<Tensor>("X");
The names of variables (including function parameters) and data members should be all lowercase:
https://google.github.io/styleguide/cppguide.html#Variable_Names
Done.
paddle/operators/pad_op.h
Outdated
for (int i = 0; i < pads.size(); ++i) {
  paddings[i] = pads[i];
}
T pad_value = context.op().GetAttr<T>("pad_value");
context.op() -> context
Done.
paddle/operators/pad_op.h
Outdated
    PadFunction<Place, T, 6>(context);
    break;
  default:
    LOG(ERROR) << "Only ranks up to 6 supported.";
LOG(ERROR) -> PADDLE_THROW("Only support tensor with ranks up to 6.")
Fixed.
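The rank-dispatch pattern the reviewers describe can be illustrated with a pure-Python sketch (the real operator is C++: a switch on the runtime rank selects a PadFunction instantiation whose rank is fixed at compile time). The names pad1d, pad2d, and pad below are illustrative, not part of the PR.

```python
def pad1d(x, pads, value):
    """Constant-pad a 1-D list; pads[0] is the (before, after) pair."""
    before, after = pads[0]
    return [value] * before + list(x) + [value] * after

def pad2d(x, pads, value):
    """Constant-pad a 2-D nested list; pads = [(before0, after0), (before1, after1)]."""
    rows = [pad1d(row, pads[1:], value) for row in x]
    width = len(rows[0]) if rows else 0
    blank = [value] * width
    return ([list(blank) for _ in range(pads[0][0])] + rows +
            [list(blank) for _ in range(pads[0][1])])

def pad(x, rank, pads, value=0):
    """Dispatch on rank; unsupported ranks raise, analogous to PADDLE_THROW."""
    dispatch = {1: pad1d, 2: pad2d}  # the C++ switch covers ranks 1..6
    if rank not in dispatch:
        raise ValueError("Only tensors with supported ranks can be padded.")
    return dispatch[rank](x, pads, value)
```

For example, `pad([1, 2], 1, [(1, 1)])` yields `[0, 1, 2, 0]`; an unsupported rank raises instead of silently logging, which is the point of the PADDLE_THROW suggestion.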
paddle/operators/pad_op.h
Outdated
  default:
    LOG(ERROR) << "Only ranks up to 6 supported.";
  }
}
Same problem as in PadFunction.
Fixed.

def test_normal(self):
    self.check_grad(
        self.op, self.inputs, set(["X"]), "Out", max_relative_error=0.5)
Remove max_relative_error=0.5. For pad op, the max_relative_error should be very small.
Done.
paddle/operators/pad_op.cc
Outdated
AddAttr<std::vector<std::pair<int, int>>>(
    "paddings", "The padding rules for each dimension");
AddAttr<float>("pad_value", "The value to be padded into tensor")
    .SetDefault(0.0f);
Needs more detailed and meaningful comments.
Done.
'Out': np.pad(self.inputs['X'],
              self.attrs['paddings'],
              mode='constant',
              constant_values=0)
Add unit tests for other shapes and padding values.
Done.
paddle/operators/pad_op.cc
Outdated
AddComment(R"DOC(
Pad Operator.
)DOC");
AddAttr<std::vector<std::pair<int, int>>>(
How was this attribute set successfully? Looking at the definition of attribute, only the following types are currently supported:
typedef boost::variant<boost::blank, int, float, std::string, std::vector<int>,
                       std::vector<float>, std::vector<std::string>>
    Attribute;
There is actually no std::vector<std::pair<int, int>> type.
We are discussing this issue:
#3877
1. Rename variables by Google style.
2. Add more test cases.
3. Add more detailed and meaningful comments.
4. Change type of "padding" to vector<int>.
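Since the Attribute variant has no std::vector<std::pair<int, int>> alternative, the change above flattens each (before, after) pair into a plain vector<int>. A hypothetical Python sketch of that round-trip (helper names are illustrative):

```python
def flatten_paddings(pairs):
    """((0, 1), (2, 3)) -> [0, 1, 2, 3], the vector<int> form of the attribute."""
    flat = []
    for before, after in pairs:
        flat.extend([before, after])
    return flat

def unflatten_paddings(flat):
    """Recover the (before, after) pairs; the flat list must have even length."""
    assert len(flat) % 2 == 0, "paddings must come in (before, after) pairs"
    return [(flat[i], flat[i + 1]) for i in range(0, len(flat), 2)]
```

Entries 2*i and 2*i+1 of the flat vector are the before/after padding of dimension i, which is the layout the C++ InferShape below indexes.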
paddle/operators/pad_op.cc
Outdated
void InferShape(const framework::InferShapeContext &ctx) const override {
  auto dim0 = ctx.Input<Tensor>("X")->dims();
  auto paddings = GetAttr<std::vector<int>>("paddings");
  PADDLE_ENFORCE_EQ(dim0.size(), (int)(paddings.size() / 2),
(int)(...) is C style; please use int(...) instead.
Assume that dim_size = 2 and padding_size = 5; this is illegal because padding_size must always be even, but the current assert can't catch it. We can change it to:
PADDLE_ENFORCE_EQ(dim0.size() * 2, int(paddings.size()), ...)
Thx.
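The corrected check and the output-shape computation can be sketched in Python (the real code is the C++ InferShape above; `infer_pad_out_dims` is a hypothetical helper name):

```python
def infer_pad_out_dims(x_dims, paddings):
    # Mirrors PADDLE_ENFORCE_EQ(dim0.size() * 2, int(paddings.size()), ...):
    # an odd paddings size (e.g. rank 2 with 5 entries) is rejected here,
    # whereas comparing rank against paddings.size() / 2 would let it pass.
    if len(x_dims) * 2 != len(paddings):
        raise ValueError("paddings size must be exactly twice the input rank")
    # out_dims[i] = x_dims[i] + before_i + after_i
    return [x_dims[i] + paddings[2 * i] + paddings[2 * i + 1]
            for i in range(len(x_dims))]
```

For instance, a 16x16 input with paddings [0, 1, 2, 3] gives an output shape of [17, 21], matching the test case in the PR.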
paddle/operators/pad_op.cc
Outdated
 protected:
  void InferShape(const framework::InferShapeContext &ctx) const override {
    auto dim0 = ctx.Input<Tensor>("X")->dims();
    auto paddings = GetAttr<std::vector<int>>("paddings");
GetAttr has been renamed to Attr.
Fixed. Thx.
paddle/operators/pad_op.cc
Outdated

 protected:
  void InferShape(const framework::InferShapeContext &ctx) const override {
    auto dim0 = ctx.Input<Tensor>("X")->dims();
I think x_dims is a more meaningful name.
Yes. Thx.
paddle/operators/pad_op.cc
Outdated
    "of input tensor.");
std::vector<int> dim1(dim0.size());
for (int i = 0; i < dim0.size(); ++i) {
  dim1[i] = dim0[i] + paddings[i * 2] + paddings[i * 2 + 1];
out_dims is better?
Thx.
paddle/operators/pad_op.cc
Outdated
for (int i = 0; i < dim0.size(); ++i) {
  dim1[i] = dim0[i] + paddings[i * 2] + paddings[i * 2 + 1];
}
ctx.Output<Tensor>("Out")->Resize(paddle::framework::make_ddim(dim1));
framework::make_ddim(dim1) is enough.
Fixed.
paddle/operators/pad_op.cc
Outdated
PadOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
    : OpProtoAndCheckerMaker(proto, op_checker) {
  AddInput("X", "The input of pad op.");
  AddOutput("Out", "The output of pad op.");
Please add dim information to the Input/Output comments, like https://github.com/PaddlePaddle/Paddle/pull/3887/files
paddle/operators/pad_op.cc
Outdated
    " 2 * dimension size of input tensor.");
AddAttr<float>("pad_value",
               "(float) default to 0; "
               "The value to be padded into tensor. ")
"The value to fill padded areas. " Is this better?
Fixed.
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
                        "Input(Out@GRAD) should not be null");
auto x_dims = ctx.Input<Tensor>("X")->dims();
auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
It seems that we should also assert that x_grad is not nullptr?
Fixed.
Sorry, I made a mistake... x_grad can be nullptr. If so, we don't need to compute x_grad.
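The backward behavior settled on here can be sketched in Python for the 1-D case (the real kernel is C++ and handles ranks up to 6; `pad_grad` is a hypothetical helper): the gradient of a constant pad is just a slice of out_grad over the original extent, and when x_grad is not requested (nullptr in C++), the computation is skipped entirely.

```python
def pad_grad(out_grad, x_dims, paddings, x_grad_needed=True):
    """1-D sketch of pad's backward pass."""
    if not x_grad_needed:  # x_grad is nullptr: nothing to compute
        return None
    before = paddings[0]  # leading pad of the single dimension
    # Padded positions receive no gradient; slice out the original extent.
    return out_grad[before:before + x_dims[0]]
```

For example, with x_dims = [3] and paddings = [1, 1], a 5-element out_grad [1, 2, 3, 4, 5] yields the x gradient [2, 3, 4].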
paddle/operators/pad_op.h
Outdated
auto* x = context.Input<Tensor>("X");
auto* out = context.Output<Tensor>("Out");
out->mutable_data<T>(context.GetPlace());
auto dims = x->dims();
This dims seems unused.
Fixed.
paddle/operators/pad_op.h
Outdated
class PadKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    int dim = context.Input<Tensor>("X")->dims().size();
The number of a tensor's dims is called rank in TensorFlow. We can also use this name.
Thanks.
paddle/operators/pad_op.h
Outdated
    PadFunction<Place, T, 6>(context);
    break;
  default:
    PADDLE_THROW("Only ranks up to 6 supported.");
"PadOp only support tensors with no more than 6 dimensions."
Just for your reference.
Thanks.
class TestCase3(TestPadOp):
    def initTestCase(self):
        self.shape = (8)
        self.paddings = [(0, 1)]
Can paddings be negative? If so, we should add unit tests for it.
Thx. But I think negative paddings are not suggested for pad op; crop op is better for your scenario.
LGTM
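For reference, a negative padding would amount to cropping, which is why crop op is suggested instead of extending pad op. A hypothetical 1-D sketch (not part of either operator):

```python
def crop1d(x, before, after):
    """Drop `before` leading and `after` trailing elements (a "negative pad")."""
    return x[before:len(x) - after]
```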
AddInput("X",
         "The input of pad op. "
         "The input should be a k-D tensor(k > 0 and k < 7)");
AddOutput("Out",
It seems that output Out will not be used in backward InferShape and computing? If so, mark it with NotInGradient(); then our graph compiler is able to optimize the memory.
See https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/minus_op.cc#L45
Fixed.
fix #3912