This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-33] Enhance mkldnn pooling to support full convention #11047

Merged — 18 commits, merged on Nov 17, 2018
18 changes: 5 additions & 13 deletions src/operator/nn/mkldnn/mkldnn_pooling-inl.h
@@ -113,20 +113,12 @@ inline bool SupportMKLDNNPooling(const PoolingParam &param,
if (!ret)
return false;

if (param.pooling_convention == pool_enum::kValid)
if (param.pooling_convention == pool_enum::kValid) {
return true;
else
return false;

// need to support pooling convention full
// https://issues.apache.org/jira/browse/MXNET-33
#if 0
if (((dshape[2] + 2 * param.pad[0] - param.kernel[0]) % param.stride[0] == 0) &&
((dshape[3] + 2 * param.pad[1] - param.kernel[1]) % param.stride[1] == 0))
return true;
else
return false;
#endif
Contributor: Can we ignore the shape completely, even for the case of kValid?

Member Author: Yes, I think so. Previously, the MKL-DNN pooling operator only supported pooling_convention=kValid, so there was no need to check the shape for kValid. But to support kFull, we need to adjust the padding size to get the correct output shape.

} else {
// currently, only max-pooling is supported for full convention
return param.pool_type == pool_enum::kMaxPooling;
}
}

inline bool MKLDNNRequireWorkspace(const PoolingParam &param) {
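The thread above comes down to how MXNet's two pooling conventions turn an input length into an output length: 'valid' rounds the number of windows down, while 'full' rounds it up, which is why the full convention may need extra trailing padding before MKL-DNN (which always floors) can reproduce it. A minimal sketch of that arithmetic, using the dimensions from the test added in this PR (the helper name `pool_out_size` is ours, not part of the change):

```python
import math

def pool_out_size(x, kernel, pad, stride, convention):
    """Output length along one spatial axis for MXNet-style pooling.

    'valid' rounds the fractional window count down; 'full' rounds it up.
    """
    windows = (x + 2 * pad - kernel) / stride
    if convention == 'valid':
        return math.floor(windows) + 1
    return math.ceil(windows) + 1

# Dimensions used by test_pooling_full_2d below: 10x10 input,
# kernel (4, 5), pad (1, 2), stride (3, 4).
assert pool_out_size(10, 4, 1, 3, 'valid') == 3
assert pool_out_size(10, 4, 1, 3, 'full') == 4   # o_h in the test comment
assert pool_out_size(10, 5, 2, 4, 'valid') == 3
assert pool_out_size(10, 5, 2, 4, 'full') == 4   # o_w in the test comment
```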
40 changes: 32 additions & 8 deletions src/operator/nn/mkldnn/mkldnn_pooling.cc
@@ -134,6 +134,14 @@ mkldnn::algorithm GetMKLDNNPoolAlgo(const PoolingParam &param) {
}
}

static inline int GetPaddingSizeFull(int x, int padl, int padr, int k, int s) {
if ((x + padl + padr - k) % s != 0) {
return (padr + s - ((x + padl + padr - k) % s));
} else {
return padr;
}
}

mkldnn::pooling_forward::primitive_desc GetPoolingFwdPdesc(
const PoolingParam &param, const bool is_train, const memory::desc &data_md,
const memory::desc &out_md) {
@@ -154,11 +162,17 @@ mkldnn::pooling_forward::primitive_desc GetPoolingFwdPdesc(
int pad_l_ = param.pad[1], pad_r_ = param.pad[1];
int stride_h_ = param.stride[0], stride_w_ = param.stride[1];

if (param.pooling_convention == pool_enum::kFull) {
Contributor: Is it possible to write a macro/function for the same kFull check that is repeated in different places in this file?

Member Author: Done.

pad_b_ = GetPaddingSizeFull(data_md.data.dims[2], pad_t_, pad_b_, kernel_h_, stride_h_);
pad_r_ = GetPaddingSizeFull(data_md.data.dims[3], pad_l_, pad_r_, kernel_w_, stride_w_);
}

const mkldnn::engine engine = CpuEngine::Get()->get_engine();
if (param.global_pool) {
pad_t_ = pad_b_ = pad_l_ = pad_r_ = 0;
stride_h_ = stride_w_ = 1;
}

if (pad_t_ != 0 || pad_l_ != 0) {
CHECK(param.pool_type == pool_enum::kAvgPooling ||
param.pool_type == pool_enum::kMaxPooling)
@@ -167,7 +181,6 @@ mkldnn::pooling_forward::primitive_desc GetPoolingFwdPdesc(
CHECK_LT(pad_t_, kernel_h_);
}


const mkldnn::algorithm alg = GetMKLDNNPoolAlgo(param);
mkldnn::prop_kind kind = mkldnn::prop_kind::forward_scoring;
if (is_train && alg != algorithm::pooling_avg) {
@@ -227,17 +240,22 @@ MKLDNNPoolingFwd &GetPoolingFwd(const PoolingParam &param,
int pad_l_ = param.pad[1], pad_r_ = param.pad[1];
int stride_h_ = param.stride[0], stride_w_ = param.stride[1];

if (param.pooling_convention == pool_enum::kFull) {
pad_b_ = GetPaddingSizeFull(data_md.data.dims[2], pad_t_, pad_b_, kernel_h_, stride_h_);
pad_r_ = GetPaddingSizeFull(data_md.data.dims[3], pad_l_, pad_r_, kernel_w_, stride_w_);
}

if (param.global_pool) {
pad_t_ = pad_b_ = pad_l_ = pad_r_ = 0;
stride_h_ = stride_w_ = 1;
}

if (pad_t_ != 0 || pad_l_ != 0) {
CHECK(param.pool_type == pool_enum::kAvgPooling ||
param.pool_type == pool_enum::kMaxPooling)
<< "Padding implemented only for average and max pooling.";
CHECK_LT(pad_l_, kernel_w_);
CHECK_LT(pad_t_, kernel_h_);
}

const mkldnn::algorithm alg = GetMKLDNNPoolAlgo(param);
Expand Down Expand Up @@ -353,6 +371,12 @@ MKLDNNPoolingBwd &GetPoolingBwd(const PoolingParam &param,
int pad_t_ = param.pad[0], pad_b_ = param.pad[0];
int pad_l_ = param.pad[1], pad_r_ = param.pad[1];
int stride_h_ = param.stride[0], stride_w_ = param.stride[1];

if (param.pooling_convention == pool_enum::kFull) {
pad_b_ = GetPaddingSizeFull(data_md.data.dims[2], pad_t_, pad_b_, kernel_h_, stride_h_);
pad_r_ = GetPaddingSizeFull(data_md.data.dims[3], pad_l_, pad_r_, kernel_w_, stride_w_);
}

if (param.global_pool) {
pad_t_ = pad_b_ = pad_l_ = pad_r_ = 0;
stride_h_ = stride_w_ = 1;
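GetPaddingSizeFull above grows only the trailing (bottom or right) padding until the pooling window strides evenly over the padded input, so the floor-based output size MKL-DNN computes lands on the same value the full convention reaches with a ceiling. A small Python mirror of that logic, written for illustration rather than taken from the PR, checked against the height axis of the test below:

```python
def get_padding_size_full(x, pad_l, pad_r, kernel, stride):
    """Grow the trailing pad so the kernel strides evenly over the padded input."""
    remainder = (x + pad_l + pad_r - kernel) % stride
    if remainder == 0:
        return pad_r
    return pad_r + stride - remainder

# Height axis of the test below: x=10, pad=1, kernel=4, stride=3.
pad_b = get_padding_size_full(10, 1, 1, 4, 3)   # -> 2 (was 1)
assert (10 + 1 + pad_b - 4) % 3 == 0            # window now fits exactly
assert (10 + 1 + pad_b - 4) // 3 + 1 == 4       # matches the full-convention o_h
```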
29 changes: 29 additions & 0 deletions tests/python/gpu/test_operator_gpu.py
@@ -996,6 +996,35 @@ def test_3d_pooling(pool_type, p_value=2, count_include_pad=True):
test_3d_pooling('lp', p_value=3)


@with_seed()
def test_pooling_full_2d():
def test_pooling_full_2d_type(pool_type):
data = (2, 2, 10, 10)
kernel = (4, 5)
pad = (1, 2)
stride = (3, 4)

convention = 'full'
ctx_list = []
sym_list = []

# o_h = ceil((10 + 1 + 1 - 4) / 3) + 1 = 4
# o_w = ceil((10 + 2 + 2 - 5) / 4) + 1 = 4
ctx_list.append({'ctx': mx.cpu(0), 'pool_data': data, 'type_dict': {'pool_data': np.float32}})
sym_list.append(mx.sym.Pooling(kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
pooling_convention=convention, global_pool=False, name='pool'))

ctx_list.append({'ctx': mx.gpu(0), 'pool_data': data, 'type_dict': {'pool_data': np.float32}})
sym_list.append(mx.sym.Pooling(kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
pooling_convention=convention, global_pool=False, name='pool'))

check_consistency(sym_list, ctx_list)

test_pooling_full_2d_type('max')
test_pooling_full_2d_type('avg')
test_pooling_full_2d_type('sum')

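The consistency check above compares the CPU (MKL-DNN) and GPU results; for a quick standalone look at the expected 4x4 spatial output, the imperative API can be used as well. A sketch assuming the standard `mx.nd.Pooling` operator (not part of the test itself):

```python
import mxnet as mx

# Same configuration as test_pooling_full_2d: 10x10 input, kernel (4, 5),
# pad (1, 2), stride (3, 4), full convention.
x = mx.nd.ones((2, 2, 10, 10))
y = mx.nd.Pooling(data=x, kernel=(4, 5), pad=(1, 2), stride=(3, 4),
                  pool_type='max', pooling_convention='full')
print(y.shape)  # expected: (2, 2, 4, 4)
```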

@with_seed()
def test_global_pooling():
def test_1d_pooling(pool_type, p_value=2):