Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add crop layers for supporting FCN model #2490

Merged
merged 20 commits into from
Jul 19, 2017

Conversation

wanghaoshuang
Copy link
Contributor

solve #2470

@qingqing01
Copy link
Contributor

Cropping and Padding are the opposite operations. Can we use the padding function for this cropping operation?

@wanghaoshuang
Copy link
Contributor Author

@qingqing01 @pkuyym padding function and cropping funtion is not opposite operations, strictly?
Padding forward funtion pad zero(or other values) to output tensor, while cropping backward function won`t change marginal element of output tensor.

'Padding forward funtion' != 'Cropping backward function'

@wanghaoshuang
Copy link
Contributor Author

wanghaoshuang commented Jun 20, 2017

@qingqing01 @pkuyym Thks for interpretion and discussion about cropping function and padding function. I think i got it. Padding function and cropping funtion is opposite operations.
Using the padding function in crop layer make code clean, but confusing. Maybe the best solution is putting computaion code into tensor library.
So, shall we keep the current code structure?

Copy link
Contributor

@qingqing01 qingqing01 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doc shows how to write a new layer. This code also needs the following contents.

required ImageConfig image_conf = 1;
repeated uint32 crop_corner = 2;
repeated uint32 crop_shape = 3;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Both Caffe and MXNet use offset to indicate crop_corner.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thx,get it.


void CropLayer::forward(PassType passType) {
Layer::forward(passType);
MatrixPtr input = inputLayers_[0]->getOutputValue();
Copy link
Contributor

@qingqing01 qingqing01 Jun 20, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This layer should be able to accept two inputs, get the height and width from the second input. So, we can support to accept one or two inputs.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Get it.

Copy link
Contributor Author

@wanghaoshuang wanghaoshuang Jul 5, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@qingqing01 I have submit a pr to add python wrapper and grad test for crop layer. But TeamCity build job was blocked by issue #2490

@qingqing01
Copy link
Contributor

@wanghaoshuang 如果提交了comments,请回复下哈~

1. change configure content to 'axis, offset, shape'
2. add an optional input to crop layer as cropping reference
Copy link
Contributor

@qingqing01 qingqing01 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"""
if isinstance(input, LayerOutput):
input = [input]
elif isinstance(input, Projection):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不支持Projection 输入,应该去掉。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.


@wrap_name_default()
@layer_support()
def crop_layer(input, axis, offset, shape=None, name=None, layer_attr=None):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

axis给个默认值吧。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

:param offset: The crop offset
:type offset: Sequence
:param shape: The shape to be cropped. Default is None.
:type shape: Sqquence | None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sqquence是啥?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed. Sqquence->Sequence

@layer_support()
def crop_layer(input, axis, offset, shape=None, name=None, layer_attr=None):
"""
The crop layer crop images by offset and shape. User can set crop shape by
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The crop layer crop -> The crop layer crops

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

inDims_.setDim(0, batchSize);
int h = inputLayers_[0]->getOutput().getFrameHeight();
if (h != 0) inDims_.setDim(2, h);
int w = inputLayers_[0]->getOutput().getFrameWidth();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

应该是:inputLayers_[1]?看Python接口里,第2个input是reference_input

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里是取原始input的dims,所以是inputLayers_[0].

}

void CropLayer::setTensorDim(const size_t batchSize) {
CHECK_EQ(static_cast<int>(inputLayers_.size()), 2);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Python接口写的是支持一个或两个输入,这块却是CHECK_EQ ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

@wanghaoshuang
Copy link
Contributor Author

@qingqing01 crop layer相关bugs已经fixed.已commit.

@wanghaoshuang wanghaoshuang merged commit 2e2a674 into PaddlePaddle:develop Jul 19, 2017
@wanghaoshuang wanghaoshuang changed the title add crop layer Add layers for supporting FCN model Aug 11, 2017
@wanghaoshuang wanghaoshuang changed the title Add layers for supporting FCN model Add crop layers for supporting FCN model Aug 11, 2017
@wanghaoshuang wanghaoshuang deleted the crop_layer branch May 20, 2022 03:59
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants