-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add crop op #3906
Merged
Merged
Add crop op #3906
Changes from all commits
Commits
Show all changes
22 commits
Select commit
Hold shift + click to select a range
06b42e9
Add crop op.
wanghaoshuang 96a7c70
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
wanghaoshuang f23ab48
Fix attr int_64 error.
wanghaoshuang 7deddab
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
wanghaoshuang b299d07
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
wanghaoshuang a8584a9
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
wanghaoshuang b21aee6
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
wanghaoshuang e2d75bd
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
wanghaoshuang 57011b2
reste
wanghaoshuang 57a3b8b
1. Implement GPUCrop kernel instead of eigen.
wanghaoshuang 46888c3
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
wanghaoshuang 0c05ea3
Pull latest pybind.cc to crop_op
wanghaoshuang 5e0e455
Add CUDA stream when launching kernel.
wanghaoshuang 8d9d537
remove op_test_util.py
wanghaoshuang 2c29cf1
Use Tensor as the temp variables instead of CUDA api
wanghaoshuang fa4908d
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
wanghaoshuang a4b1abe
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
wanghaoshuang 94fa9d1
Remove const cast for device context
wanghaoshuang ce709b7
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
wanghaoshuang 68b5e5b
Use stridecpy instead of CUDA kernel
wanghaoshuang bc632df
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
wanghaoshuang 8b6fda6
move stride function to ddim.h
wanghaoshuang File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include "paddle/operators/crop_op.h" | ||
#include <boost/lexical_cast.hpp> | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
using framework::Tensor; | ||
using framework::LoDTensor; | ||
|
||
class CropOp : public framework::OperatorWithKernel { | ||
public: | ||
using framework::OperatorWithKernel::OperatorWithKernel; | ||
|
||
protected: | ||
void InferShape(const framework::InferShapeContext &ctx) const override { | ||
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), | ||
"Input(X) of CropOp should not be null."); | ||
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), | ||
"Output(Out) of CropOp should not be null."); | ||
auto x_dim = ctx.Input<LoDTensor>("X")->dims(); | ||
auto *y = ctx.Input<LoDTensor>("Y"); | ||
auto *out = ctx.Output<LoDTensor>("Out"); | ||
if (y == nullptr) { | ||
auto shape = Attr<std::vector<int>>("shape"); | ||
PADDLE_ENFORCE_EQ( | ||
int64_t(shape.size()), x_dim.size(), | ||
"Shape size should be equal to dimention size of input tensor."); | ||
std::vector<int64_t> tensor_shape(shape.size()); | ||
for (size_t i = 0; i < shape.size(); ++i) { | ||
tensor_shape[i] = static_cast<int64_t>(shape[i]); | ||
} | ||
out->Resize(framework::make_ddim(tensor_shape)); | ||
} else { | ||
PADDLE_ENFORCE_EQ(framework::arity(x_dim), framework::arity(y->dims()), | ||
"Tensor rank of both CropOp's " | ||
"inputs must be same."); | ||
out->Resize(y->dims()); | ||
} | ||
} | ||
}; | ||
|
||
class CropOpMaker : public framework::OpProtoAndCheckerMaker { | ||
public: | ||
CropOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) | ||
: OpProtoAndCheckerMaker(proto, op_checker) { | ||
AddInput("X", | ||
"The input of pad op. " | ||
"The input should be a k-D tensor(k > 0 and k < 7)"); | ||
AddInput("Y", | ||
"The input used as reference for cropping" | ||
" with the same dimension as X. "); | ||
AddOutput("Out", | ||
"The output of crop op " | ||
"with the same dimension as X."); | ||
AddAttr<std::vector<int>>("offsets", | ||
"A list<int> describing offsets to be cropped." | ||
"The size of offsets list should be as same as " | ||
"dimension size of input X."); | ||
AddAttr<std::vector<int>>("shape", | ||
"A list<int> describing the shape of output." | ||
"The size of shape list should be as same as " | ||
"dimension size of input X.") | ||
.SetDefault(std::vector<int>()); | ||
AddComment(R"DOC( | ||
Crop Operator. | ||
Crop input into output, as specified by offsets and shape. | ||
|
||
There are two ways to set shape: | ||
1. referenc input: crop input X as shape as reference input. | ||
The dimension of reference input should | ||
be as same as input X. | ||
2. shape list: crop input X by shape described by a list<int>. | ||
The size of shape list should be as same as | ||
dimension size of input X. | ||
|
||
The input should be a k-D tensor(k > 0 and k < 7). As an example: | ||
|
||
Given: | ||
|
||
X = [[0, 1, 2, 0, 0] | ||
[0, 3, 4, 0, 0] | ||
[0, 0, 0, 0, 0]] | ||
|
||
and | ||
|
||
offsets = [0, 1] | ||
|
||
and | ||
|
||
shape = [2, 2] | ||
|
||
then we get | ||
|
||
Out = [[1, 2], | ||
[3, 4]] | ||
|
||
)DOC"); | ||
} | ||
}; | ||
|
||
class CropOpGrad : public framework::OperatorWithKernel { | ||
public: | ||
using framework::OperatorWithKernel::OperatorWithKernel; | ||
|
||
protected: | ||
void InferShape(const framework::InferShapeContext &ctx) const override { | ||
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null"); | ||
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), | ||
"Input(Out@GRAD) should not be null"); | ||
auto x_dims = ctx.Input<LoDTensor>("X")->dims(); | ||
auto *x_grad = ctx.Output<LoDTensor>(framework::GradVarName("X")); | ||
if (x_grad != nullptr) { | ||
x_grad->Resize(x_dims); | ||
} | ||
} | ||
}; | ||
|
||
} // namespace operators | ||
} // namespace paddle | ||
|
||
namespace ops = paddle::operators; | ||
REGISTER_OP(crop, ops::CropOp, ops::CropOpMaker, crop_grad, ops::CropOpGrad); | ||
REGISTER_OP_CPU_KERNEL(crop, ops::CropKernel<float>); | ||
REGISTER_OP_CPU_KERNEL(crop_grad, | ||
ops::CropGradKernel<paddle::platform::CPUPlace, float>); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#define EIGEN_USE_GPU | ||
#include "paddle/operators/crop_op.h" | ||
|
||
namespace ops = paddle::operators; | ||
REGISTER_OP_GPU_KERNEL(crop, ops::CropKernel<float>); | ||
REGISTER_OP_GPU_KERNEL(crop_grad, | ||
ops::CropGradKernel<paddle::platform::GPUPlace, float>); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
/* Copyright (c) 2016 CropdleCropdle Authors. All Rights Reserve. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#pragma once | ||
|
||
#include "paddle/framework/eigen.h" | ||
#include "paddle/framework/op_registry.h" | ||
#include "paddle/operators/strided_memcpy.h" | ||
|
||
namespace paddle { | ||
namespace operators { // Internal | ||
|
||
template <typename T, size_t D, int MajorType = Eigen::RowMajor, | ||
typename IndexType = Eigen::DenseIndex> | ||
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>; | ||
using framework::Tensor; | ||
|
||
template <typename T> | ||
class CropKernel : public framework::OpKernel { | ||
public: | ||
void Compute(const framework::ExecutionContext& context) const override { | ||
auto* x = context.Input<Tensor>("X"); | ||
auto* out = context.Output<Tensor>("Out"); | ||
const T* x_data = x->data<T>(); | ||
T* out_data = out->mutable_data<T>(context.GetPlace()); | ||
auto x_stride = framework::stride(x->dims()); | ||
auto out_stride = framework::stride(out->dims()); | ||
auto offsets = context.Attr<std::vector<int>>("offsets"); | ||
PADDLE_ENFORCE_EQ( | ||
x->dims().size(), offsets.size(), | ||
"Offsets size should be equal to dimension size of input tensor."); | ||
int64_t offset = 0; | ||
for (int i = 0; i < offsets.size(); ++i) { | ||
offset += (x_stride[i] * offsets[i]); | ||
} | ||
StridedMemcpy<T>(context.device_context(), x_data + offset, x_stride, | ||
out->dims(), out_stride, out_data); | ||
} | ||
}; | ||
|
||
template <typename Place, typename T, size_t D> | ||
void CropGradFunction(const framework::ExecutionContext& context) { | ||
auto* d_x = context.Output<Tensor>(framework::GradVarName("X")); | ||
if (d_x != nullptr) { | ||
auto* d_out = context.Input<Tensor>(framework::GradVarName("Out")); | ||
d_x->mutable_data<T>(context.GetPlace()); | ||
auto offsets = context.Attr<std::vector<int>>("offsets"); | ||
Eigen::array<std::pair<int, int>, D> paddings; | ||
for (int i = 0; i < D; ++i) { | ||
paddings[i].first = offsets[i]; | ||
paddings[i].second = d_x->dims()[i] - d_out->dims()[i] - offsets[i]; | ||
} | ||
auto d_x_tensor = EigenTensor<T, D>::From(*d_x); | ||
auto d_out_tensor = EigenTensor<T, D>::From(*d_out); | ||
d_x_tensor.device(context.GetEigenDevice<Place>()) = | ||
d_out_tensor.pad(paddings, 0); | ||
} | ||
} | ||
|
||
template <typename Place, typename T> | ||
class CropGradKernel : public framework::OpKernel { | ||
public: | ||
void Compute(const framework::ExecutionContext& context) const override { | ||
size_t rank = | ||
context.Input<Tensor>(framework::GradVarName("Out"))->dims().size(); | ||
switch (rank) { | ||
case 1: | ||
CropGradFunction<Place, T, 1>(context); | ||
break; | ||
case 2: | ||
CropGradFunction<Place, T, 2>(context); | ||
break; | ||
case 3: | ||
CropGradFunction<Place, T, 3>(context); | ||
break; | ||
case 4: | ||
CropGradFunction<Place, T, 4>(context); | ||
break; | ||
case 5: | ||
CropGradFunction<Place, T, 5>(context); | ||
break; | ||
case 6: | ||
CropGradFunction<Place, T, 6>(context); | ||
break; | ||
default: | ||
PADDLE_THROW( | ||
"CropOp only support tensors with no more than 6 dimensions."); | ||
} | ||
} | ||
}; | ||
|
||
} // namespace operators | ||
} // namespace paddle |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
import unittest | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the copyrights. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 好像所有单测都没copyrights。确定需要加上么? |
||
import numpy as np | ||
from op_test import OpTest | ||
|
||
|
||
def crop(data, offsets, crop_shape): | ||
def indexOf(shape, index): | ||
result = [] | ||
for dim in reversed(shape): | ||
result.append(index % dim) | ||
index = index / dim | ||
return result[::-1] | ||
|
||
result = [] | ||
for i, value in enumerate(data.flatten()): | ||
index = indexOf(data.shape, i) | ||
selected = True | ||
if len(index) == len(offsets): | ||
for j, offset in enumerate(offsets): | ||
selected = selected and index[j] >= offset and index[ | ||
j] < crop_shape[j] + offset | ||
if selected: | ||
result.append(value) | ||
return np.array(result).reshape(crop_shape) | ||
|
||
|
||
class TestCropOp(OpTest): | ||
def setUp(self): | ||
self.op_type = "crop" | ||
self.crop_by_input = False | ||
self.attrs = {} | ||
self.initTestCase() | ||
self.attrs['offsets'] = self.offsets | ||
if self.crop_by_input: | ||
self.inputs = { | ||
'X': np.random.random(self.x_shape).astype("float32"), | ||
'Y': np.random.random(self.crop_shape).astype("float32") | ||
} | ||
else: | ||
self.attrs['shape'] = self.crop_shape | ||
self.inputs = { | ||
'X': np.random.random(self.x_shape).astype("float32"), | ||
} | ||
self.outputs = { | ||
'Out': crop(self.inputs['X'], self.offsets, self.crop_shape) | ||
} | ||
|
||
def initTestCase(self): | ||
self.x_shape = (8, 8) | ||
self.crop_shape = (2, 2) | ||
self.offsets = [1, 2] | ||
|
||
def test_check_output(self): | ||
self.check_output() | ||
|
||
def test_check_grad_normal(self): | ||
self.check_grad(['X'], 'Out', max_relative_error=0.006) | ||
|
||
|
||
class TestCase1(TestCropOp): | ||
def initTestCase(self): | ||
self.x_shape = (16, 8, 32) | ||
self.crop_shape = [2, 2, 3] | ||
self.offsets = [1, 5, 3] | ||
|
||
|
||
class TestCase2(TestCropOp): | ||
def initTestCase(self): | ||
self.x_shape = (4, 8) | ||
self.crop_shape = [4, 8] | ||
self.offsets = [0, 0] | ||
|
||
|
||
class TestCase3(TestCropOp): | ||
def initTestCase(self): | ||
self.x_shape = (4, 8, 16) | ||
self.crop_shape = [2, 2, 3] | ||
self.offsets = [1, 5, 3] | ||
self.crop_by_input = True | ||
|
||
|
||
class TestCase4(TestCropOp): | ||
def initTestCase(self): | ||
self.x_shape = (4, 4) | ||
self.crop_shape = [4, 4] | ||
self.offsets = [0, 0] | ||
self.crop_by_input = True | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
InferShape
里的检查需要全面一些,比如对输入非空的检查: https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/squared_l2_distance_op.cc#L26There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thx. Fixed.