
Add crop op #3906

Merged · 22 commits into PaddlePaddle:develop · Sep 21, 2017
Conversation

@wanghaoshuang (Contributor) commented Sep 6, 2017:

fixed #3905
fixed #3911

AddInput("X", "The input of crop op");
AddInput("Y", "The input used as reference for cropping. ");
AddOutput("Out", "The output of crop op.");
AddComment(R"DOC(
@dzhwinter (Contributor), Sep 8, 2017:

This documentation is weak: "The offsets for cropping" carries no information at all. Maybe describe some details such as (not strictly necessary):

  1. in which scenarios this Op will be used;
  2. the inputs and outputs of the Op;
  3. the details of each attribute, e.g. that shape refers to the sliced tensor's shape.

Contributor Author:

Thx. I will refine it.

typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;

using Tensor = framework::Tensor;
Contributor:

Follow the naming convention of the Google style guide (click here for details). We should not put any `using` declaration in a header file.

Contributor Author:

Thx.

Contributor:

In addition, Tensor has no template parameter; just write

using framework::Tensor;

is okay.

Contributor Author:

Fixed.

protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
auto dim0 = ctx.Input<Tensor>("X")->dims();
auto Y = ctx.Input<Tensor>("Y");
Contributor:

I cannot find any test case that takes Y as an input; there is only the normal one with offsets and shape.
Can you give a proper application case where this usage ('Y' as input) is necessary? Thx.
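For context, one plausible use of a reference input is "crop-to-match": the output takes Y's shape while the values come from X. A minimal numpy sketch of that behavior (crop_like is a hypothetical helper, not this op's API):

```python
import numpy as np

def crop_like(x, y, offsets):
    # Crop `x` down to the shape of the reference tensor `y`, starting
    # at `offsets`. Only y's shape is used; its values are ignored.
    slices = tuple(slice(o, o + s) for o, s in zip(offsets, y.shape))
    return x[slices]

x = np.arange(30).reshape(5, 6)
y = np.zeros((2, 3))               # reference tensor: only its shape matters
out = crop_like(x, y, offsets=(1, 2))
assert out.shape == y.shape
```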

"Shape size should be equal to dimention size of input tensor.");
std::vector<int64_t> tensor_shape(shape.size());
for (int i = 0; i < shape.size(); ++i) {
tensor_shape[i] = (int64_t)shape[i];
Contributor:

We already have a make_ddim interface in ddim that accepts a vector<int> parameter; do we have to convert its type here?

typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;

using Tensor = framework::Tensor;
Contributor:

In addition, Tensor has no template parameter; just write

using framework::Tensor;

is okay.

auto out_dims = out->dims();

auto offsets = context.op().GetAttr<std::vector<int>>("offsets");
PADDLE_ENFORCE_EQ(
Contributor:

Please move this enforce check into InferShape.
An operator's InferShape runs before runtime; that is where the real configuration checks should happen.
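The review point, validating the configuration at shape-inference time rather than inside the kernel, can be sketched in Python (infer_crop_shape is a hypothetical stand-in for the C++ InferShape, and the window check is an assumed extra validation):

```python
def infer_crop_shape(x_shape, offsets, shape):
    # Run all static configuration checks up front, before any kernel
    # executes, mirroring what InferShape does ahead of runtime.
    if len(offsets) != len(x_shape):
        raise ValueError("Offsets size should be equal to dimension size of input tensor.")
    if len(shape) != len(x_shape):
        raise ValueError("Shape size should be equal to dimension size of input tensor.")
    for o, s, d in zip(offsets, shape, x_shape):
        if o < 0 or o + s > d:
            raise ValueError("Crop window falls outside the input tensor.")
    return tuple(shape)
```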

… crop_op

Conflicts:
	paddle/pybind/pybind.cc
	python/paddle/v2/framework/tests/op_test_util.py
using framework::OperatorWithKernel::OperatorWithKernel;

protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
Contributor:

The checks in InferShape need to be more thorough, e.g. checking that the inputs are non-null: https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/squared_l2_distance_op.cc#L26

Contributor Author:

Thx. Fixed.

for (int i = 0; i < shape.size(); ++i) {
tensor_shape[i] = (int64_t)shape[i];
}
ctx.Output<Tensor>("Out")->Resize(
Contributor:

After LoDTensor is merged in, the output in InferShape must use: Output<framework::LoDTensor>

Contributor Author:

Fixed.

ctx.Output<Tensor>("Out")->Resize(
paddle::framework::make_ddim(tensor_shape));
} else {
ctx.Output<Tensor>("Out")->Resize(Y->dims());
Contributor:

After LoDTensor is merged in, the output in InferShape must use: Output<framework::LoDTensor>

Contributor Author:

Fixed.

PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
auto x_dims = ctx.Input<Tensor>("X")->dims();
auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
Contributor:

For the output, use: Output<framework::LoDTensor>

Contributor Author:

Fixed.

CropFunction<Place, T, 6>(context);
break;
default:
LOG(ERROR) << "Only ranks up to 6 supported.";
Contributor:

Do not use GLOG for error reporting; use PADDLE_ENFORCE/PADDLE_THROW:

PADDLE_THROW("Tensors with rank at most 6 are supported").

Contributor Author:

Fixed.
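The point of this exchange, raise an error instead of logging so an unsupported rank cannot pass silently, in a Python sketch (rank_dispatch is a hypothetical stand-in for the C++ switch over CropFunction):

```python
def rank_dispatch(rank):
    # Fail loudly on unsupported ranks instead of logging and continuing,
    # which is the intent behind preferring PADDLE_THROW over LOG(ERROR).
    if not 1 <= rank <= 6:
        raise ValueError("Tensors with rank at most 6 are supported")
    # A real kernel would dispatch CropFunction<Place, T, rank> here.
    return rank
```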

CropGradFunction<Place, T, 6>(context);
break;
default:
LOG(ERROR) << "Only ranks up to 6 supported.";
Contributor:

Same as above.

Contributor Author:

Fixed.

self.outputs = {'Out': self.inputs['X'][2:10, 3:11]}


class TestCropGradOp(GradientChecker):
Contributor:

Please use the new unit test framework.

Contributor Author:

Fixed.


def test_normal(self):
self.check_grad(
self.op, self.inputs, set(["X"]), "Out", max_relative_error=0.5)
Contributor:

max_relative_error=0.5 is too large.

Contributor Author:

Fixed.
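On the tolerance point: crop is a linear operation, so its analytic gradient (scatter the output gradient into zeros at the crop window) should match finite differences almost exactly, and a loose max_relative_error would hide bugs. A numpy sketch with hypothetical helper names, not the framework's GradientChecker:

```python
import numpy as np

def crop(x, offsets, shape):
    # Forward: slice a `shape`-sized window starting at `offsets`.
    slices = tuple(slice(o, o + s) for o, s in zip(offsets, shape))
    return x[slices]

def crop_grad(x, offsets, shape, dout):
    # Backward: scatter the output gradient into a zero tensor of x's shape.
    dx = np.zeros_like(x)
    slices = tuple(slice(o, o + s) for o, s in zip(offsets, shape))
    dx[slices] = dout
    return dx

def numeric_grad(f, x, dout, eps=1e-4):
    # Central finite differences of sum(f(x) * dout) w.r.t. each element of x.
    g = np.zeros_like(x)
    flat = x.reshape(-1)        # view into x, so writes perturb x in place
    for i in range(flat.size):
        orig = flat[i]
        flat[i] = orig + eps
        hi = float(np.sum(f(x) * dout))
        flat[i] = orig - eps
        lo = float(np.sum(f(x) * dout))
        flat[i] = orig
        g.reshape(-1)[i] = (hi - lo) / (2 * eps)
    return g
```

For a linear op the two gradients agree to floating-point precision, far tighter than 0.5.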

PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
"Input(X) of CropOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
"Output(Out) of CropOp should not be null.");
Contributor:

Put the check before line 30. In general: check first, then use; only start using a value after its checks pass.

Contributor Author:

Fixed.

"Shape size should be equal to dimention size of input tensor.");
std::vector<int64_t> tensor_shape(shape.size());
for (size_t i = 0; i < shape.size(); ++i) {
tensor_shape[i] = (int64_t)shape[i];
Contributor:

static_cast<int64_t>(shape[i])

Contributor Author:

Fixed.

"The size of shape list should be as same as "
"dimension size of input X.")
.SetDefault(std::vector<int>());
}
Contributor:

Put the AddComment(R"DOC( )DOC") at last.

Contributor Author:

Fixed.


X = [[0, 1, 2, 0, 0]
[0, 3, 4, 0, 0]
[0, 0, 0, 0, 0]]
Contributor:

The docs may be converted to Markdown in the future; for proper rendering, please indent lines 84-86, and likewise lines 90, 94, and 98.

Contributor Author:

Fixed.
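The DOC example quoted above can be reproduced with plain slicing; the offsets and shape below are assumed values chosen to pick out the nonzero block:

```python
import numpy as np

# The example matrix from the op documentation.
X = np.array([[0, 1, 2, 0, 0],
              [0, 3, 4, 0, 0],
              [0, 0, 0, 0, 0]])

def crop(x, offsets, shape):
    # Take a `shape`-sized window of `x` starting at `offsets`.
    slices = tuple(slice(o, o + s) for o, s in zip(offsets, shape))
    return x[slices]

out = crop(X, offsets=(0, 1), shape=(2, 2))
# out == [[1, 2],
#         [3, 4]]
```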

T *out_data = out->mutable_data<T>(context.GetPlace());
auto x_dims = x->dims();
auto out_dims = out->dims();
int64_t out_count = framework::product(out_dims);
Contributor:

out_count -> out_dims->numel()

Contributor Author:

Fixed.
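For reference, the element count that Tensor::numel() returns is simply the product of the dimension sizes, playing the role of numpy's `.size`:

```python
import numpy as np

t = np.empty((2, 3, 4))
# numel-style element count: the product of all dimension sizes.
count = int(np.prod(t.shape))
assert count == t.size == 24
```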

public:
void Compute(const framework::ExecutionContext &context) const override {
auto *x = context.Input<LoDTensor>("X");
auto *out = context.Output<LoDTensor>("Out");
Contributor:

In the OpKernel of non-sequence operators, LoDTensor is actually unnecessary; plain Tensor is enough:

#4151

Contributor Author:

Fixed.

std::vector<int64_t> x_shape = framework::vectorize(x_dims);
std::vector<int64_t> out_shape = framework::vectorize(out_dims);

auto offsets = context.op().Attr<std::vector<int>>("offsets");
Contributor:

context.op().Attr<> -> context.Attr<>

Contributor Author:

Fixed.

T* out_data = out->mutable_data<T>(paddle::platform::GPUPlace());
auto x_dims = x->dims();
auto out_dims = out->dims();
int64_t out_count = framework::product(out_dims);
Contributor:

use Tensor::numel() instead of framework::product.

Contributor Author:

Fixed.

cudaMemcpy(x_shape_gpu, x_shape, sizeof(int64_t) * D, cudaMemcpyHostToDevice);
cudaMalloc((void**)&out_shape_gpu, sizeof(int64_t) * D);
cudaMemcpy(out_shape_gpu, out_shape, sizeof(int64_t) * D,
cudaMemcpyHostToDevice);
@qingqing01 (Contributor), Sep 18, 2017:

Do not use cudaMalloc and cudaMemcpy directly; use Tensor's functions instead, and use Tensor for the temporary variables.

Contributor Author:

Fixed.

@reyoung (Collaborator) commented Sep 19, 2017:

@wanghaoshuang I just pushed PR #4174, which lets this PR avoid using const_cast. Please update.

Also, I think Transform may be a better way to implement Crop.

@reyoung (Collaborator) commented Sep 19, 2017:

I will give it a try and submit a PR to you.

@reyoung (Collaborator) commented Sep 19, 2017:

Sorry, I misunderstood what crop means. It seems it is not related to Transform at all.

@reyoung (Collaborator) commented Sep 19, 2017:

Please reference #4205. I think Crop and Concat can both be abstracted by StridedMemcpy.
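A rough Python sketch of that idea, crop expressed as one strided block copy; strided_copy here is a hypothetical stand-in, not Paddle's actual StridedMemcpy signature:

```python
import numpy as np

def strided_copy(src, src_strides, src_off, dst, dst_strides, dst_off, shape):
    # Copy a `shape`-sized block between flat buffers, walking each
    # dimension with explicit element strides.
    if len(shape) == 1:
        for i in range(shape[0]):
            dst[dst_off + i * dst_strides[0]] = src[src_off + i * src_strides[0]]
        return
    for i in range(shape[0]):
        strided_copy(src, src_strides[1:], src_off + i * src_strides[0],
                     dst, dst_strides[1:], dst_off + i * dst_strides[0],
                     shape[1:])

def crop(x, offsets, out_shape):
    # Crop as a single strided copy: the offsets select the start element,
    # the input strides skip over the cropped-away region.
    in_strides = [s // x.itemsize for s in x.strides]
    out = np.empty(out_shape, dtype=x.dtype)
    out_strides = [s // out.itemsize for s in out.strides]
    src_off = sum(o * st for o, st in zip(offsets, in_strides))
    strided_copy(x.reshape(-1), in_strides, src_off,
                 out.reshape(-1), out_strides, 0, list(out_shape))
    return out
```

Concat is the mirror image: the same strided copy writes each input into an offset window of the output.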

dzhwinter previously approved these changes Sep 21, 2017

@dzhwinter (Contributor) left a comment:

LGTM++; the code still needs slight modification based on the comments.

PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
"Output(Out) of CropOp should not be null.");
auto x_dim = ctx.Input<LoDTensor>("X")->dims();
auto Y = ctx.Input<LoDTensor>("Y");
Contributor:

If an auto-deduced variable holds a pointer, please write it as auto *Y.

Contributor Author:

Fixed.

PADDLE_ENFORCE_EQ(framework::arity(x_dim), framework::arity(Y->dims()),
"Tensor rank of both CropOp's "
"inputs must be same.");
ctx.Output<LoDTensor>("Out")->Resize(Y->dims());
Contributor:

ctx.Output<LoDTensor> is called twice; it's better to reuse the first result.

Contributor Author:

Fixed.

int block = 512;
int grid = (n * d + block - 1) / block;

CropKernel<
Contributor:

Was this re-folded by pre-commit? The formatting looks strange; maybe it's better to add a newline character on the next line.

Contributor Author:

Fixed.
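As an aside, the launch configuration in the snippet above sizes the grid with integer ceil-division so that every one of the n * d elements is covered by some thread:

```python
def ceil_div(a, b):
    # Number of blocks of size b needed to cover a elements:
    # equivalent to (a + b - 1) / b in integer C++ arithmetic.
    return (a + b - 1) // b

block = 512
assert ceil_div(512, block) == 1   # exactly one full block
assert ceil_div(513, block) == 2   # one extra element forces a second block
```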

Contributor Author:

Fixed.

@@ -0,0 +1,91 @@
import unittest
Contributor:

Please add the copyright header.

Contributor Author:

It seems none of the unit tests has a copyright header. Are you sure we need to add one?

@dzhwinter (Contributor) left a comment:

LGTM++

@wanghaoshuang wanghaoshuang merged commit da2aabb into PaddlePaddle:develop Sep 21, 2017
@wanghaoshuang wanghaoshuang deleted the crop_op branch May 20, 2022 03:56