
Maxpool op nhwc #7214

Merged: 27 commits, Jan 15, 2022
Commits
62cf062  maxpool2d_support_nhwc (Flowingsun007, Jan 8, 2022)
b6d746d  refine (Flowingsun007, Jan 9, 2022)
eac5c10  add test case (Flowingsun007, Jan 9, 2022)
944fcbc  format (Flowingsun007, Jan 9, 2022)
feb7597  Merge branch 'master' into maxpool_op_nhwc (Flowingsun007, Jan 9, 2022)
d75c323  Merge branch 'master' into maxpool_op_nhwc (Flowingsun007, Jan 12, 2022)
851efdc  refine (Flowingsun007, Jan 12, 2022)
6d67786  refine (Flowingsun007, Jan 12, 2022)
419c869  Merge branch 'master' into maxpool_op_nhwc (Flowingsun007, Jan 12, 2022)
d82ce7b  Merge branch 'master' into maxpool_op_nhwc (Flowingsun007, Jan 12, 2022)
3b13cad  Merge branch 'master' into maxpool_op_nhwc (Flowingsun007, Jan 13, 2022)
e24676e  Merge branch 'master' into maxpool_op_nhwc (Flowingsun007, Jan 13, 2022)
5f9a6cd  fix comments (Flowingsun007, Jan 13, 2022)
959654b  Merge branch 'master' into maxpool_op_nhwc (Flowingsun007, Jan 13, 2022)
e32629f  refine (Flowingsun007, Jan 13, 2022)
1169bf5  Merge branch 'maxpool_op_nhwc' of github.com:Oneflow-Inc/oneflow into… (Flowingsun007, Jan 13, 2022)
be64341  auto format by CI (oneflow-ci-bot, Jan 13, 2022)
e3193c0  fix clang warning (Flowingsun007, Jan 13, 2022)
43e3689  auto format by CI (oneflow-ci-bot, Jan 13, 2022)
6c672db  Merge branch 'master' into maxpool_op_nhwc (oneflow-ci-bot, Jan 13, 2022)
f2d9dd0  Merge branch 'master' into maxpool_op_nhwc (oneflow-ci-bot, Jan 13, 2022)
a5a6a80  Merge branch 'master' into maxpool_op_nhwc (Flowingsun007, Jan 14, 2022)
041f8e5  Merge branch 'master' into maxpool_op_nhwc (Flowingsun007, Jan 14, 2022)
455f0f1  Merge branch 'master' into maxpool_op_nhwc (oneflow-ci-bot, Jan 14, 2022)
e78c6b6  Merge branch 'master' into maxpool_op_nhwc (oneflow-ci-bot, Jan 14, 2022)
f66f8dd  Merge branch 'master' into maxpool_op_nhwc (oneflow-ci-bot, Jan 14, 2022)
1ebfb91  Merge branch 'master' into maxpool_op_nhwc (oneflow-ci-bot, Jan 14, 2022)
131 changes: 113 additions & 18 deletions oneflow/user/kernels/pooling_kernel.cpp
@@ -40,6 +40,61 @@ std::shared_ptr<PoolingOpKernelCache> CreateOpKernelCache(user_op::KernelCacheCo
return cache;
}

namespace {

template<typename T>
void Maxpool2dForwardComputeCLast(const NdIndexOffsetHelper<int64_t, 4>& index_helper,
int64_t elem_num, const T* src, T* dest, int64_t* indice_ptr,
const int32_t padding_h, const int32_t padding_w,
const int64_t n_batch, const int64_t n_channel,
const int64_t x_height, const int64_t x_width,
const int64_t y_height, const int64_t y_width,
const int32_t kernel_size_h, const int32_t kernel_size_w,
const int32_t stride_h, const int32_t stride_w,
const int32_t dilation_h, const int32_t dilation_w) {
int64_t n = 0, h = 0, w = 0, c = 0;
for (int64_t num = 0; num < elem_num; ++num) {
index_helper.OffsetToNdIndex(num, n, h, w, c);
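// For channels_last the output shape is (N, H, W, C), so the flat offset decodes to
// (n, h, w, c) with the channel index c varying fastest.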

const int64_t x_start_idx = n * x_height * x_width * n_channel;
const int64_t y_start_idx = n * y_height * y_width * n_channel;
int64_t hstart = h * stride_h - padding_h;
int64_t wstart = w * stride_w - padding_w;
const int64_t hend = (hstart + (kernel_size_h - 1) * dilation_h + 1) <= x_height
? (hstart + (kernel_size_h - 1) * dilation_h + 1)
: x_height;
const int64_t wend = (wstart + (kernel_size_w - 1) * dilation_w + 1) <= x_width
? (wstart + (kernel_size_w - 1) * dilation_w + 1)
: x_width;

while (hstart < 0) { hstart += dilation_h; }
while (wstart < 0) { wstart += dilation_w; }
/* compute the max value (src[src_idx]) in the kernel window region, and save it to dest[out_idx] */
int64_t max_index = hstart * x_width * n_channel + wstart * n_channel + c;
int64_t src_idx = 0;
/* equal to -std::numeric_limits<T>::infinity(); */
T max_value = detail::numeric_limits<T>::lower_bound();

for (int64_t i = hstart; i < hend; i += dilation_h) {
for (int64_t j = wstart; j < wend; j += dilation_w) {
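// Offset of (i, j, c) within one NHWC sample: rows stride by x_width * n_channel,
// columns stride by n_channel.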
const int64_t window_idx = i * x_width * n_channel + j * n_channel + c;
const int64_t search_idx = x_start_idx + window_idx;
T val = src[search_idx];
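// Comparisons with NaN are false, so isnan() is checked explicitly to propagate NaN to the output.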
if (val > max_value || detail::numerics<T>::isnan(val)) {
max_value = val;
max_index = window_idx;
src_idx = search_idx;
}
}
}
const int64_t out_idx = y_start_idx + h * y_width * n_channel + w * n_channel + c;
dest[out_idx] = src[src_idx];
indice_ptr[out_idx] = max_index;
}
}

} // namespace

template<typename T>
struct PoolingKernelUtil<DeviceType::kCPU, T> {
static void Maxpool1dForward(ep::Stream* stream,
@@ -62,11 +117,11 @@ struct PoolingKernelUtil<DeviceType::kCPU, T> {
params_3d.GetYShape5D().At(4), params_3d.GetXShape5D().At(4));
}

- static void Maxpool2dForward(ep::Stream* stream,
- const NdIndexOffsetHelper<int64_t, 4>& index_helper,
- const int64_t elem_num, const T* src, T* dest, int64_t* indice_ptr,
- const MaxPoolingParams3D& params_3d) {
- Maxpool2dForwardCompute<T>(
+ static void Maxpool2dForwardCFirst(ep::Stream* stream,
+ const NdIndexOffsetHelper<int64_t, 4>& index_helper,
+ const int64_t elem_num, const T* src, T* dest,
+ int64_t* indice_ptr, const MaxPoolingParams3D& params_3d) {
+ Maxpool2dForwardComputeCFirst<T>(
index_helper, elem_num, src, dest, indice_ptr, params_3d.padding()[1],
params_3d.padding()[2], params_3d.num_batch(), params_3d.num_channel(),
params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4), params_3d.GetYShape5D().At(3),
@@ -75,14 +130,39 @@ struct PoolingKernelUtil<DeviceType::kCPU, T> {
params_3d.dilation_3d()[1], params_3d.dilation_3d()[2]);
}

- static void Maxpool2dBackward(ep::Stream* stream,
- const NdIndexOffsetHelper<int64_t, 4>& index_helper,
- const int64_t elem_num, const T* src, T* dest,
- const int64_t* indice_ptr, const MaxPoolingParams3D& params_3d) {
- Maxpool2dBackwardCompute<T>(index_helper, elem_num, src, dest, indice_ptr,
- params_3d.num_batch(), params_3d.num_channel(),
- params_3d.GetYShape5D().At(3), params_3d.GetYShape5D().At(4),
- params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4));
+ static void Maxpool2dBackwardCFirst(ep::Stream* stream,
+ const NdIndexOffsetHelper<int64_t, 4>& index_helper,
+ const int64_t elem_num, const T* src, T* dest,
+ const int64_t* indice_ptr,
+ const MaxPoolingParams3D& params_3d) {
+ Maxpool2dBackwardComputeCFirst<T>(index_helper, elem_num, src, dest, indice_ptr,
+ params_3d.num_batch(), params_3d.num_channel(),
+ params_3d.GetYShape5D().At(3), params_3d.GetYShape5D().At(4),
+ params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4));
}

static void Maxpool2dForwardCLast(ep::Stream* stream,
const NdIndexOffsetHelper<int64_t, 4>& index_helper,
const int64_t elem_num, const T* src, T* dest,
int64_t* indice_ptr, const MaxPoolingParams3D& params_3d) {
Maxpool2dForwardComputeCLast<T>(
index_helper, elem_num, src, dest, indice_ptr, params_3d.padding()[1],
params_3d.padding()[2], params_3d.num_batch(), params_3d.num_channel(),
params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4), params_3d.GetYShape5D().At(3),
params_3d.GetYShape5D().At(4), params_3d.pooling_size_3d()[1],
params_3d.pooling_size_3d()[2], params_3d.stride_3d()[1], params_3d.stride_3d()[2],
params_3d.dilation_3d()[1], params_3d.dilation_3d()[2]);
}

static void Maxpool2dBackwardCLast(ep::Stream* stream,
const NdIndexOffsetHelper<int64_t, 4>& index_helper,
const int64_t elem_num, const T* src, T* dest,
const int64_t* indice_ptr,
const MaxPoolingParams3D& params_3d) {
Maxpool2dBackwardComputeCLast<T>(index_helper, elem_num, src, dest, indice_ptr,
params_3d.num_batch(), params_3d.num_channel(),
params_3d.GetYShape5D().At(3), params_3d.GetYShape5D().At(4),
params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4));
}

static void Maxpool3dForward(ep::Stream* stream,
@@ -216,9 +296,16 @@ class MaxPool2dKernel final : public user_op::OpKernel {
DimVector y_vector;
y->shape().ToDimVector(&y_vector);
NdIndexOffsetHelper<int64_t, 4> index_helper(y_vector.data());

- PoolingKernelUtil<device_type, T>::Maxpool2dForward(ctx->stream(), index_helper, elem_num, src,
- dest, indice_ptr, params_3d);
+ const std::string& data_format = ctx->Attr<std::string>("data_format");
+ if (data_format == "channels_first") {
+ PoolingKernelUtil<device_type, T>::Maxpool2dForwardCFirst(
+ ctx->stream(), index_helper, elem_num, src, dest, indice_ptr, params_3d);
+ } else if (data_format == "channels_last") {
+ PoolingKernelUtil<device_type, T>::Maxpool2dForwardCLast(
+ ctx->stream(), index_helper, elem_num, src, dest, indice_ptr, params_3d);
+ } else {
+ UNIMPLEMENTED() << "Unsupported data_format";
+ }
};
};

@@ -255,8 +342,16 @@ class MaxPool2dGradKernel final : public user_op::OpKernel {
size_t out_bytes_size = dx->shape().elem_cnt() * GetSizeOfDataType(dx->data_type());
Memset<device_type>(ctx->stream(), dest, 0, out_bytes_size);

- PoolingKernelUtil<device_type, T>::Maxpool2dBackward(ctx->stream(), index_helper, elem_num, src,
- dest, indice_ptr, params_3d);
+ const std::string& data_format = ctx->Attr<std::string>("data_format");
+ if (data_format == "channels_first") {
+ PoolingKernelUtil<device_type, T>::Maxpool2dBackwardCFirst(
+ ctx->stream(), index_helper, elem_num, src, dest, indice_ptr, params_3d);
+ } else if (data_format == "channels_last") {
+ PoolingKernelUtil<device_type, T>::Maxpool2dBackwardCLast(
+ ctx->stream(), index_helper, elem_num, src, dest, indice_ptr, params_3d);
+ } else {
+ UNIMPLEMENTED() << "Unsupported data_format";
+ }
};
};
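
For readers unfamiliar with the channels-last layout, the sketch below is a minimal, self-contained illustration of the NHWC indexing that Maxpool2dForwardComputeCLast uses: element (n, h, w, c) lives at flat offset ((n * H + h) * W + w) * C + c, so the channel is the fastest-varying dimension. This is not OneFlow code; the function name maxpool2d_nhwc, the zero-padding/unit-dilation simplification, and the main() driver are illustrative assumptions, not part of the PR.

// Minimal standalone sketch of channels-last (NHWC) 2-D max pooling.
// Simplified: no padding, no dilation, no argmax indices.
#include <cstdint>
#include <iostream>
#include <limits>
#include <vector>

// x has shape (N, H, W, C); element (n, h, w, c) sits at ((n * H + h) * W + w) * C + c.
void maxpool2d_nhwc(const std::vector<float>& x, std::vector<float>& y, int64_t N, int64_t H,
                    int64_t W, int64_t C, int32_t kernel_h, int32_t kernel_w, int32_t stride_h,
                    int32_t stride_w) {
  const int64_t out_h = (H - kernel_h) / stride_h + 1;
  const int64_t out_w = (W - kernel_w) / stride_w + 1;
  y.assign(N * out_h * out_w * C, 0.0f);
  for (int64_t n = 0; n < N; ++n) {
    for (int64_t oh = 0; oh < out_h; ++oh) {
      for (int64_t ow = 0; ow < out_w; ++ow) {
        for (int64_t c = 0; c < C; ++c) {
          float max_value = -std::numeric_limits<float>::infinity();
          for (int64_t i = oh * stride_h; i < oh * stride_h + kernel_h; ++i) {
            for (int64_t j = ow * stride_w; j < ow * stride_w + kernel_w; ++j) {
              // Channel is the innermost (fastest-varying) dimension in NHWC.
              const float val = x[((n * H + i) * W + j) * C + c];
              if (val > max_value) { max_value = val; }
            }
          }
          y[((n * out_h + oh) * out_w + ow) * C + c] = max_value;
        }
      }
    }
  }
}

int main() {
  // One 4x4 image with 2 channels, pooled with a 2x2 kernel and stride 2.
  const int64_t N = 1, H = 4, W = 4, C = 2;
  std::vector<float> x(N * H * W * C);
  for (size_t i = 0; i < x.size(); ++i) { x[i] = static_cast<float>(i); }
  std::vector<float> y;
  maxpool2d_nhwc(x, y, N, H, W, C, /*kernel*/ 2, 2, /*stride*/ 2, 2);
  for (float v : y) { std::cout << v << " "; }  // prints: 10 11 14 15 26 27 30 31
  std::cout << std::endl;
  return 0;
}

The OneFlow kernel additionally handles padding, dilation, NaN propagation, and the recorded argmax indices for the backward pass, but the offset arithmetic is the same.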
