Add dilation for conv_trans_op #6279

Merged
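
This PR makes dilation configurable for the transposed-convolution operators: a `dilations` attribute is added to the conv2d/conv3d transpose op makers and honored in `InferShape` and the GEMM kernels (where it was previously hard-coded to `{1, 1, 1}`), the Fluid `conv2d_transpose` layer gains a `dilation` argument, and the Python unit tests cover the dilated case. The cuDNN dilation test is left commented out because cuDNN v5 does not support dilated convolution.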
paddle/operators/conv_transpose_cudnn_op.cc (0 additions, 4 deletions). The duplicated `dilations` attribute is removed here; the cuDNN makers now inherit it from the base makers (see conv_transpose_op.cc below).

```diff
@@ -21,8 +21,6 @@ class CudnnConv2DTransposeOpMaker : public Conv2DTransposeOpMaker {
  public:
   CudnnConv2DTransposeOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : Conv2DTransposeOpMaker(proto, op_checker) {
-    AddAttr<std::vector<int>>("dilations", "dilations of convolution operator.")
-        .SetDefault({1, 1});
     AddAttr<int>("workspace_size_MB",
                  "workspace size for cudnn, in MB, "
                  "workspace is a section of GPU memory which will be "
@@ -37,8 +35,6 @@ class CudnnConv3DTransposeOpMaker : public Conv3DTransposeOpMaker {
  public:
   CudnnConv3DTransposeOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : Conv3DTransposeOpMaker(proto, op_checker) {
-    AddAttr<std::vector<int>>("dilations", "dilations of convolution operator.")
-        .SetDefault({1, 1, 1});
     AddAttr<int>("workspace_size_MB",
                  "workspace size for cudnn, in MB, "
                  "workspace is a section of GPU memory which will be "
```
paddle/operators/conv_transpose_op.cc (21 additions, 3 deletions)

```diff
@@ -29,6 +29,7 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
   auto filter_dims = ctx->GetInputDim("Filter");
   std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
   std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
+  std::vector<int> dilations = ctx->Attrs().Get<std::vector<int>>("dilations");
 
   PADDLE_ENFORCE(in_dims.size() == 4 || in_dims.size() == 5,
                  "ConvTransposeOp input should be 4-D or 5-D tensor.");
@@ -41,14 +42,18 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
   PADDLE_ENFORCE_EQ(paddings.size(), strides.size(),
                     "ConvTransposeOp paddings dimension and strides "
                     "dimension should be the same.");
+  PADDLE_ENFORCE_EQ(paddings.size(), dilations.size(),
+                    "ConvTransposeOp paddings dimension and dilations "
+                    "dimension should be the same.");
   PADDLE_ENFORCE_EQ(in_dims[1], filter_dims[0],
                     "In ConvTransposeOp, the input channel should be the same "
                     "as the number of filters.");
 
   std::vector<int64_t> output_shape({in_dims[0], filter_dims[1]});
   for (size_t i = 0; i < strides.size(); ++i) {
+    auto filter_extent = dilations[i] * (filter_dims[i + 2] - 1) + 1;
     output_shape.push_back((in_dims[i + 2] - 1) * strides[i] - 2 * paddings[i] +
-                           filter_dims[i + 2]);
+                           filter_extent);
   }
   ctx->SetOutputDim("Output", framework::make_ddim(output_shape));
 }
@@ -73,6 +78,12 @@ Conv2DTransposeOpMaker::Conv2DTransposeOpMaker(OpProto* proto,
AddOutput("Output",
"(Tensor) The output tensor of convolution transpose operator. "
"The format of output tensor is also NCHW.");

AddAttr<std::vector<int>>("dilations",
"(vector<int> default:{1, 1}), the "
"dilations(h_dilation, w_dilation) of convolution "
"transpose operator.")
.SetDefault({1, 1});
AddAttr<std::vector<int>>(
"strides",
"(vector<int> default:{1, 1}), the strides(h_stride, w_stride) of "
@@ -87,7 +98,7 @@ Conv2DTransposeOpMaker::Conv2DTransposeOpMaker(OpProto* proto,
 Convolution2D Transpose Operator.
 
 The convolution transpose operation calculates the output based on the input, filter
-and strides, paddings, groups parameters. The size of each dimension of the
+and dilations, strides, paddings, groups parameters. The size of each dimension of the
 parameters is checked in the infer-shape.
 Input(Input) and output(Output) are in NCHW format. Where N is batch size, C is the
 number of channels, H is the height of the feature, and W is the width of the feature.
@@ -136,6 +147,13 @@ Conv3DTransposeOpMaker::Conv3DTransposeOpMaker(OpProto* proto,
"Where N is batch size, C is "
"the number of channels, D is the depth of the feature, H is the "
"height of the feature, and W is the width of the feature.");

AddAttr<std::vector<int>>(
"dilations",
"(vector<int> default:{1, 1, 1}), the "
"dilations(d_dilation,h_dilation, w_dilation) of convolution "
"transpose operator.")
.SetDefault({1, 1, 1});
AddAttr<std::vector<int>>("strides",
"(vector<int> default:{1, 1, 1}), the "
"strides{d_stride, h_stride, w_stride} of "
@@ -149,7 +167,7 @@ Conv3DTransposeOpMaker::Conv3DTransposeOpMaker(OpProto* proto,
 Convolution3D Transpose Operator.
 
 The convolution transpose operation calculates the output based on the input, filter
-and strides, paddings, groups parameters. The size of each dimension of the
+and dilations, strides, paddings, groups parameters. The size of each dimension of the
 parameters is checked in the infer-shape.
 Input(Input) and output(Output) are in NCDHW format. Where N is batch size, C is the
 number of channels, D is the depth of the feature, H is the height of the feature,
```
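
The new output-shape rule generalizes the old one: a filter of size `k` dilated by `d` covers an extent of `d * (k - 1) + 1` positions. A minimal Python sketch of the rule used in `InferShape` above (illustrative only, not part of the PR):

```python
def conv_transpose_out_size(in_size, kernel, stride=1, pad=0, dilation=1):
    """Output size along one spatial dimension of a transposed convolution."""
    filter_extent = dilation * (kernel - 1) + 1
    return (in_size - 1) * stride - 2 * pad + filter_extent

# With dilation=1 this reduces to the old rule (in - 1) * stride - 2 * pad + kernel:
assert conv_transpose_out_size(5, kernel=3) == 7
# With dilation=2 a 3-wide filter covers a 5-wide extent:
assert conv_transpose_out_size(5, kernel=3, pad=1, dilation=2) == 7
```

Since the attribute defaults to all ones, the extent reduces to the plain filter size and existing models are unaffected.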
paddle/operators/conv_transpose_op.h (2 additions, 2 deletions)

```diff
@@ -61,6 +61,7 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {

   std::vector<int> strides = context.Attr<std::vector<int>>("strides");
   std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
+  std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
   // groups will always be disabled in conv2dtranspose.
 
   const int batch_size = static_cast<int>(input->dims()[0]);
@@ -113,7 +114,6 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {

   math::Col2ImFunctor<math::ColFormat::kCFO, DeviceContext, T> col2im;
   math::Col2VolFunctor<DeviceContext, T> col2vol;
-  std::vector<int> dilations({1, 1, 1});
 
   // convolution transpose: gemm + col2im or col2vol (similar to conv-backward
   // on input)
@@ -165,6 +165,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {

   std::vector<int> strides = context.Attr<std::vector<int>>("strides");
   std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
+  std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
 
   const int batch_size = static_cast<int>(input->dims()[0]);
 
@@ -219,7 +220,6 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {

   math::Im2ColFunctor<math::ColFormat::kCFO, DeviceContext, T> im2col;
   math::Vol2ColFunctor<DeviceContext, T> vol2col;
-  std::vector<int> dilations({1, 1, 1});
 
   if (input_grad) {
     input_grad->mutable_data<T>(context.GetPlace());
```
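
The kernels now pass the attribute through to `col2im`/`col2vol` instead of a hard-coded `{1, 1, 1}`. As a rough 1-D analogue of what the dilated scatter does (an illustrative assumption, not Paddle's actual `math::Col2ImFunctor`):

```python
import numpy as np

def col2im_1d(cols, in_size, kernel, stride=1, pad=0, dilation=1):
    """Scatter columns back to the output: tap k of input position i
    lands at offset i * stride + k * dilation in the padded output."""
    out_size = (in_size - 1) * stride - 2 * pad + dilation * (kernel - 1) + 1
    out = np.zeros(out_size + 2 * pad)  # scatter into a padded buffer
    for i in range(in_size):            # input position
        for k in range(kernel):         # filter tap
            out[i * stride + k * dilation] += cols[i, k]
    return out[pad:pad + out_size]      # crop the padding off
```

With `dilation = 1` this is the usual col2im accumulation, which is why dropping the hard-coded vector is behavior-preserving for existing graphs.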
python/paddle/v2/fluid/layers/nn.py (17 additions, 5 deletions)

```diff
@@ -704,6 +704,7 @@ def conv2d_transpose(input,
                      filter_size=None,
                      padding=None,
                      stride=None,
+                     dilation=None,
                      param_attr=None):
     """
     The transpose of conv2d layer.
@@ -727,6 +728,9 @@ def conv2d_transpose(input,
         stride(int|tuple): The stride size. If stride is a tuple, it must
             contain two integers, (stride_H, stride_W). Otherwise, the
             stride_H = stride_W = stride.
+        dilation(int|tuple): The dilation size. If dilation is a tuple, it must
+            contain two integers, (dilation_H, dilation_W). Otherwise, the
+            dilation_H = dilation_W = dilation.
         param_attr: Parameter Attribute.
         main_program(Program): the main program
         startup_program(Program): the startup program
@@ -747,10 +751,15 @@ def conv2d_transpose(input,
         op_attr['paddings'] = padding
 
     if isinstance(stride, int):
-        op_attr['strides'] = stride
+        op_attr['strides'] = [stride, stride]
     elif stride is not None:
         op_attr['strides'] = stride
 
+    if isinstance(dilation, int):
+        op_attr['dilations'] = [dilation, dilation]
+    elif dilation is not None:
+        op_attr['dilations'] = dilation
+
     if filter_size is None:
         if output_size is None:
             raise ValueError("output_size must be set when filter_size is None")
@@ -759,14 +768,17 @@ def conv2d_transpose(input,

         padding = op_attr.get('paddings', [0, 0])
         stride = op_attr.get('strides', [1, 1])
+        dilation = op_attr.get('dilations', [1, 1])
 
         h_in = input.shape[2]
         w_in = input.shape[3]
-        filter_size_h = output_size[0] - \
-            (h_in - 1) * stride[0] + 2 * padding[0]
-        filter_size_w = output_size[1] - \
-            (w_in - 1) * stride[1] + 2 * padding[1]
 
+        filter_size_h = (output_size[0] - (h_in - 1) * stride[0] + 2 *
+                         padding[0] - 1) / dilation[0] + 1
+        filter_size_w = (output_size[1] - (w_in - 1) * stride[1] + 2 *
+                         padding[1] - 1) / dilation[1] + 1
         filter_size = [filter_size_h, filter_size_w]
 
     elif isinstance(filter_size, int):
         filter_size = [filter_size, filter_size]
 
```
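
`conv2d_transpose` now accepts `dilation` as an int or a 2-tuple, and when only `output_size` is given it inverts the output-size rule to infer the filter size. A quick self-check of that inversion (illustrative; it assumes Python 2 floor division, as in the code above):

```python
def infer_filter_size(out_size, in_size, stride, pad, dilation):
    # Mirrors the nn.py computation above, solved for the filter size.
    return (out_size - (in_size - 1) * stride + 2 * pad - 1) // dilation + 1

# Round-trip against the forward rule from conv_transpose_op.cc:
in_size, stride, pad, dilation, k = 5, 2, 1, 2, 3
out_size = (in_size - 1) * stride - 2 * pad + dilation * (k - 1) + 1  # = 11
assert infer_filter_size(out_size, in_size, stride, pad, dilation) == k
```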
python/paddle/v2/fluid/tests/test_conv2d_transpose_op.py (63 additions, 10 deletions)

```diff
@@ -3,14 +3,17 @@
 from op_test import OpTest
 
 
-def conv2dtranspose_forward_naive(input_, filter_, conv2dtranspose_param):
+def conv2dtranspose_forward_naive(input_, filter_, attrs):
     in_n, in_c, in_h, in_w = input_.shape
     f_c, out_c, f_h, f_w = filter_.shape
     assert in_c == f_c
 
-    stride, pad = conv2dtranspose_param['stride'], conv2dtranspose_param['pad']
-    out_h = (in_h - 1) * stride[0] + f_h
-    out_w = (in_w - 1) * stride[1] + f_w
+    stride, pad, dilations = attrs['strides'], attrs['paddings'], attrs[
+        'dilations']
+    d_block_h = dilations[0] * (f_h - 1) + 1
+    d_block_w = dilations[1] * (f_w - 1) + 1
+    out_h = (in_h - 1) * stride[0] + d_block_h
+    out_w = (in_w - 1) * stride[1] + d_block_w
 
     out = np.zeros((in_n, out_c, out_h, out_w))
 
@@ -23,9 +26,9 @@ def conv2dtranspose_forward_naive(input_, filter_, conv2dtranspose_param):

             for k in range(out_c):
                 tmp_out = np.sum(input_masked * filter_[:, k, :, :], axis=0)
-                i1, i2 = i * stride[0], i * stride[0] + f_h
-                j1, j2 = j * stride[0], j * stride[0] + f_w
-                out[n, k, i1:i2, j1:j2] += tmp_out
+                i1, i2 = i * stride[0], i * stride[0] + d_block_h
+                j1, j2 = j * stride[1], j * stride[1] + d_block_w
+                out[n, k, i1:i2:dilations[0], j1:j2:dilations[1]] += tmp_out
 
     out = out[:, :, pad[0]:out_h - pad[0], pad[1]:out_w - pad[1]]
     return out
@@ -37,18 +40,19 @@ def setUp(self):
         self.init_op_type()
         self.init_test_case()
 
-        conv2dtranspose_param = {'stride': self.stride, 'pad': self.pad}
         input_ = np.random.random(self.input_size).astype("float32")
         filter_ = np.random.random(self.filter_size).astype("float32")
-        output = conv2dtranspose_forward_naive(
-            input_, filter_, conv2dtranspose_param).astype('float32')
 
         self.inputs = {'Input': input_, 'Filter': filter_}
         self.attrs = {
             'strides': self.stride,
             'paddings': self.pad,
+            'dilations': self.dilations
         }
 
+        output = conv2dtranspose_forward_naive(input_, filter_,
+                                               self.attrs).astype('float32')
+
         self.outputs = {'Output': output}
 
     def test_check_output(self):
@@ -104,11 +108,60 @@ def init_test_case(self):
         self.filter_size = [f_c, 6, 3, 3]
 
 
+class TestWithDilation(TestConv2dTransposeOp):
+    def init_test_case(self):
+        self.pad = [1, 1]
+        self.stride = [1, 1]
+        self.dilations = [2, 2]
+        self.input_size = [2, 3, 5, 5]  # NCHW
+        f_c = self.input_size[1]
+        self.filter_size = [f_c, 6, 3, 3]
+
+
 # ------------ test_cudnn ------------
 class TestCudnn(TestConv2dTransposeOp):
     def init_op_type(self):
         self.op_type = "conv2d_transpose_cudnn"
 
 
+class TestCudnnWithPad(TestWithPad):
+    def init_test_case(self):
+        self.pad = [1, 1]
+        self.stride = [1, 1]
+        self.dilations = [1, 1]
+        self.input_size = [2, 3, 5, 5]  # NCHW
+        f_c = self.input_size[1]
+        self.filter_size = [f_c, 6, 3, 3]
+
+    def init_op_type(self):
+        self.op_type = "conv2d_transpose_cudnn"
+
+
+class TestCudnnWithStride(TestWithStride):
+    def init_test_case(self):
+        self.pad = [1, 1]
+        self.stride = [2, 2]
+        self.dilations = [1, 1]
+        self.input_size = [2, 3, 5, 5]  # NCHW
+        f_c = self.input_size[1]
+        self.filter_size = [f_c, 6, 3, 3]
+
+    def init_op_type(self):
+        self.op_type = "conv2d_transpose_cudnn"
+
+
+# #cudnn v5 does not support dilation conv.
```

Review discussion on the commented-out test:

> **Contributor:** Simply commenting the test out doesn't seem quite right. You could run the test against different CUDNN_VERSION values instead; see this for reference:
> https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/gserver/tests/test_LayerGrad.cpp#L198
>
> **Contributor Author:** I think we can add an interface for querying CUDNN_VERSION from the Python side in a next PR.

```diff
+# class TestCudnnWithDilation(TestWithDilation):
+#     def init_test_case(self):
+#         self.pad = [1, 1]
+#         self.stride = [2, 2]
+#         self.dilations = [2, 2]
+#         self.input_size = [2, 3, 5, 5]  # NCHW
+#         f_c = self.input_size[1]
+#         self.filter_size = [f_c, 6, 3, 3]
+#
+#     def init_op_type(self):
+#         self.op_type = "conv2d_transpose_cudnn"
 
 if __name__ == '__main__':
     unittest.main()
```
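
As a sketch of the reviewer's suggestion: if a cuDNN-version accessor were exposed to Python (the `get_cudnn_version()` helper below is hypothetical, not an existing Fluid API; the author proposes adding one in a follow-up PR), the dilation test could be gated rather than commented out. This assumes dilated-convolution support arrived with cuDNN 6 and reuses `TestWithDilation` from the test file above:

```python
import unittest

def get_cudnn_version():
    # Hypothetical accessor; at the time of this PR Fluid exposes no cuDNN
    # version query to Python. Returns major*1000 + minor*100 + patch.
    return 5005  # placeholder: pretend cuDNN 5.0.5 is installed

@unittest.skipIf(get_cudnn_version() < 6000,
                 "cudnn before v6 does not support dilated convolution")
class TestCudnnWithDilation(TestWithDilation):
    def init_op_type(self):
        self.op_type = "conv2d_transpose_cudnn"
```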