From 3bc982b8beedce82077075ce80a813f3bca7b081 Mon Sep 17 00:00:00 2001 From: v-qjqs Date: Tue, 27 Apr 2021 14:52:49 +0800 Subject: [PATCH 01/11] supports for onnxruntime custom op `mmcv::MMCVTopPool` --- mmcv/ops/corner_pool.py | 5 ++ mmcv/ops/csrc/onnxruntime/corner_pool.h | 42 +++++++++++++ mmcv/ops/csrc/onnxruntime/cpu/corner_pool.cpp | 62 +++++++++++++++++++ .../onnxruntime/cpu/onnxruntime_register.cpp | 6 ++ 4 files changed, 115 insertions(+) create mode 100644 mmcv/ops/csrc/onnxruntime/corner_pool.h create mode 100644 mmcv/ops/csrc/onnxruntime/cpu/corner_pool.cpp diff --git a/mmcv/ops/corner_pool.py b/mmcv/ops/corner_pool.py index 6b0d871933..dafef096ef 100644 --- a/mmcv/ops/corner_pool.py +++ b/mmcv/ops/corner_pool.py @@ -13,6 +13,11 @@ class TopPoolFunction(Function): + @staticmethod + def symbolic(g, input): + output = g.op('mmcv::MMCVTopPool', input) + return output + @staticmethod def forward(ctx, input): output = ext_module.top_pool_forward(input) diff --git a/mmcv/ops/csrc/onnxruntime/corner_pool.h b/mmcv/ops/csrc/onnxruntime/corner_pool.h new file mode 100644 index 0000000000..cc3d41ccf5 --- /dev/null +++ b/mmcv/ops/csrc/onnxruntime/corner_pool.h @@ -0,0 +1,42 @@ +#ifndef ONNXRUNTIME_CORNER_POOL_H +#define ONNXRUNTIME_CORNER_POOL_H + +#include + +struct MMCVTopPoolKernel { + public: + MMCVTopPoolKernel(Ort::CustomOpApi ort, const OrtKernelInfo* info): ort_(ort) { + // create allocator + allocator_ = Ort::AllocatorWithDefaultOptions(); + } + + void Compute(OrtKernelContext* context); + + private: + Ort::CustomOpApi ort_; + Ort::AllocatorWithDefaultOptions allocator_; +}; + +struct MMCVTopPoolCustomOp : Ort::CustomOpBase { + void *CreateKernel(Ort::CustomOpApi api, const OrtKernelInfo* info) { + return new MMCVTopPoolKernel(api, info); + } + + const char* GetName() const { return "MMCVTopPool"; } + + size_t GetInputTypeCount() const { return 1; } + ONNXTensorElementDataType GetInputType(size_t) const { + return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; + } + + size_t GetOutputTypeCount() const { return 1; } + ONNXTensorElementDataType GetOutputType(size_t) const { + return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; + } + + // force cpu + const char* GetExecutionProviderType() const { + return "CPUExecutionProvider"; + } +}; +#endif // ONNXRUNTIME_CORNER_POOL_H diff --git a/mmcv/ops/csrc/onnxruntime/cpu/corner_pool.cpp b/mmcv/ops/csrc/onnxruntime/cpu/corner_pool.cpp new file mode 100644 index 0000000000..94c86f3d4a --- /dev/null +++ b/mmcv/ops/csrc/onnxruntime/cpu/corner_pool.cpp @@ -0,0 +1,62 @@ +#include "corner_pool.h" + +#include +#include + +#include "../ort_mmcv_utils.h" + +void TopPoolForwardCPU(const float *input, float *output, float *tmp_output, + const int nthreads, const int channels, + const int height, const int width) { + int batch_size = nthreads / channels / width / height; + for (int n = 0; n < batch_size; n++) { + int index_n = n * channels * width * height; + for (int c = 0; c < channels; c++) { + int index_n_c = index_n + c * width * height; + for (int w = 0; w < width; w++) { + // copy column from output to tmp_output + for (int h = 0; h < height; h++) { + int index = index_n_c + h * width + w; + tmp_output[h] = output[index]; + } + // do top_pool + for (int ind = 1; ind < height; ind <<= 1) { + for (int h = 0; h < height - ind; h++) { + output[index_n_c + h * width + w] = std::max(tmp_output[h], tmp_output[h+ind]); + } + // copy column from updated output to tmp_output + for (int h = 0; h < height - ind; h++) { + tmp_output[h] = output[index_n_c + h * width + w]; + } + } // for ind + } // for w + } // for c + } // for n + +} + +void MMCVTopPoolKernel::Compute(OrtKernelContext *context) { + typedef float T; + const OrtValue *input = ort_.KernelContext_GetInput(context, 0); + const T *input_data = reinterpret_cast(ort_.GetTensorData(input)); + + OrtTensorDimensions out_dimensions(ort_, input); + + int input_channels = out_dimensions.data()[1]; + int input_height = out_dimensions.data()[2]; + int input_width = out_dimensions.data()[3]; + + // allocate tmp and output memory + OrtValue *output = ort_.KernelContext_GetOutput(context, 0, out_dimensions.data(), out_dimensions.size()); + T *output_data = ort_.GetTensorMutableData(output); + T *tmp_output_data = (T *)allocator_.Alloc(sizeof(T) * input_height); + + // copy input_data to output_data + int output_size = out_dimensions.data()[0]; + for (auto i = 1; i < out_dimensions.size(); ++i) { + output_size *= out_dimensions.data()[i]; + } + memcpy(output_data, input_data, sizeof(T) * output_size); + + TopPoolForwardCPU(input_data, output_data, tmp_output_data, output_size, input_channels, input_height, input_width); +} diff --git a/mmcv/ops/csrc/onnxruntime/cpu/onnxruntime_register.cpp b/mmcv/ops/csrc/onnxruntime/cpu/onnxruntime_register.cpp index a46e5b6215..7d9006fefc 100644 --- a/mmcv/ops/csrc/onnxruntime/cpu/onnxruntime_register.cpp +++ b/mmcv/ops/csrc/onnxruntime/cpu/onnxruntime_register.cpp @@ -6,6 +6,7 @@ #include "roi_align.h" #include "roi_align_rotated.h" #include "soft_nms.h" +#include "corner_pool.h" const char *c_MMCVOpDomain = "mmcv"; SoftNmsOp c_SoftNmsOp; @@ -13,6 +14,7 @@ NmsOp c_NmsOp; MMCVRoiAlignCustomOp c_MMCVRoiAlignCustomOp; MMCVRoIAlignRotatedCustomOp c_MMCVRoIAlignRotatedCustomOp; GridSampleOp c_GridSampleOp; +MMCVTopPoolCustomOp c_MMCVTopPoolCustomOp; OrtStatus *ORT_API_CALL RegisterCustomOps(OrtSessionOptions *options, const OrtApiBase *api) { @@ -45,5 +47,9 @@ OrtStatus *ORT_API_CALL RegisterCustomOps(OrtSessionOptions *options, return status; } + if (auto status = ortApi->CustomOpDomain_Add(domain, &c_MMCVTopPoolCustomOp)) { + return status; + } + return ortApi->AddCustomOpDomain(options, domain); } From df3caca7a7bef8d30d795ba3d41bfd1860e6587d Mon Sep 17 00:00:00 2001 From: v-qjqs Date: Tue, 27 Apr 2021 22:30:01 +0800 Subject: [PATCH 02/11] supports for onnxruntime custom op `mmcv::MMCVCornerPool`, involving TopPool, BottomPool, LeftPool and RightPool --- mmcv/ops/corner_pool.py | 23 +++- mmcv/ops/csrc/onnxruntime/corner_pool.h | 15 ++- mmcv/ops/csrc/onnxruntime/cpu/corner_pool.cpp | 127 +++++++++++++++--- .../onnxruntime/cpu/onnxruntime_register.cpp | 4 +- 4 files changed, 145 insertions(+), 24 deletions(-) diff --git a/mmcv/ops/corner_pool.py b/mmcv/ops/corner_pool.py index dafef096ef..189506e6aa 100644 --- a/mmcv/ops/corner_pool.py +++ b/mmcv/ops/corner_pool.py @@ -10,12 +10,15 @@ 'right_pool_forward', 'right_pool_backward' ]) +_mode_dict = {'top': 0, 'bottom': 1, 'left': 2, 'right': 3} + class TopPoolFunction(Function): @staticmethod def symbolic(g, input): - output = g.op('mmcv::MMCVTopPool', input) + output = g.op( + 'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['top'])) return output @staticmethod @@ -33,6 +36,12 @@ def backward(ctx, grad_output): class BottomPoolFunction(Function): + @staticmethod + def symbolic(g, input): + output = g.op( + 'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['bottom'])) + return output + @staticmethod def forward(ctx, input): output = ext_module.bottom_pool_forward(input) @@ -48,6 +57,12 @@ def backward(ctx, grad_output): class LeftPoolFunction(Function): + @staticmethod + def symbolic(g, input): + output = g.op( + 'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['left'])) + return output + @staticmethod def forward(ctx, input): output = ext_module.left_pool_forward(input) @@ -63,6 +78,12 @@ def backward(ctx, grad_output): class RightPoolFunction(Function): + @staticmethod + def symbolic(g, input): + output = g.op( + 'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['right'])) + return output + @staticmethod def forward(ctx, input): output = ext_module.right_pool_forward(input) diff --git a/mmcv/ops/csrc/onnxruntime/corner_pool.h b/mmcv/ops/csrc/onnxruntime/corner_pool.h index cc3d41ccf5..def7b80c38 100644 --- a/mmcv/ops/csrc/onnxruntime/corner_pool.h +++ b/mmcv/ops/csrc/onnxruntime/corner_pool.h @@ -1,11 +1,14 @@ #ifndef ONNXRUNTIME_CORNER_POOL_H #define ONNXRUNTIME_CORNER_POOL_H +#include +#include #include -struct MMCVTopPoolKernel { +struct MMCVCornerPoolKernel { public: - MMCVTopPoolKernel(Ort::CustomOpApi ort, const OrtKernelInfo* info): ort_(ort) { + MMCVCornerPoolKernel(Ort::CustomOpApi ort, const OrtKernelInfo* info): ort_(ort) { + mode_ = ort_.KernelInfoGetAttribute(info, "mode"); // create allocator allocator_ = Ort::AllocatorWithDefaultOptions(); } @@ -15,14 +18,16 @@ struct MMCVTopPoolKernel { private: Ort::CustomOpApi ort_; Ort::AllocatorWithDefaultOptions allocator_; + + int64_t mode_; }; -struct MMCVTopPoolCustomOp : Ort::CustomOpBase { +struct MMCVCornerPoolCustomOp : Ort::CustomOpBase { void *CreateKernel(Ort::CustomOpApi api, const OrtKernelInfo* info) { - return new MMCVTopPoolKernel(api, info); + return new MMCVCornerPoolKernel(api, info); } - const char* GetName() const { return "MMCVTopPool"; } + const char* GetName() const { return "MMCVCornerPool"; } size_t GetInputTypeCount() const { return 1; } ONNXTensorElementDataType GetInputType(size_t) const { diff --git a/mmcv/ops/csrc/onnxruntime/cpu/corner_pool.cpp b/mmcv/ops/csrc/onnxruntime/cpu/corner_pool.cpp index 94c86f3d4a..eb9f77e0c8 100644 --- a/mmcv/ops/csrc/onnxruntime/cpu/corner_pool.cpp +++ b/mmcv/ops/csrc/onnxruntime/cpu/corner_pool.cpp @@ -1,8 +1,4 @@ #include "corner_pool.h" - -#include -#include - #include "../ort_mmcv_utils.h" void TopPoolForwardCPU(const float *input, float *output, float *tmp_output, @@ -35,28 +31,127 @@ void TopPoolForwardCPU(const float *input, float *output, float *tmp_output, } -void MMCVTopPoolKernel::Compute(OrtKernelContext *context) { +void BottomPoolForwardCPU(const float *input, float *output, float *tmp_output, + const int nthreads, const int channels, + const int height, const int width) { + int batch_size = nthreads / channels / width / height; + for (int n = 0; n < batch_size; n++) { + int index_n = n * channels * width * height; + for (int c = 0; c < channels; c++) { + int index_n_c = index_n + c * width * height; + for (int w = 0; w < width; w++) { + // copy column from output to tmp_output + for (int h = 0; h < height; h++) { + int index = index_n_c + h * width + w; + tmp_output[h] = output[index]; + } + // do bottom_pool + for (int ind = 1; ind < height; ind <<= 1) { + for (int h = ind; h < height; h++) { + output[index_n_c + h * width + w] = std::max(tmp_output[h], tmp_output[h-ind]); + } + // copy column from updated output to tmp_output + for (int h = ind; h < height; h++) { + tmp_output[h] = output[index_n_c + h * width + w]; + } + } // for ind + } // for w + } // for c + } // for n + +} + +void LeftPoolForwardCPU(const float *input, float *output, float *tmp_output, + const int nthreads, const int channels, + const int height, const int width) { + int batch_size = nthreads / channels / width / height; + for (int n = 0; n < batch_size; n++) { + int index_n = n * channels * width * height; + for (int c = 0; c < channels; c++) { + int index_n_c = index_n + c * width * height; + for (int h = 0; h < height; h++) { + // copy row from output to tmp_output + for (int w = 0; w < width; w++) { + int index = index_n_c + h * width + w; + tmp_output[w] = output[index]; + } + // do left_pool + for (int ind = 1; ind < width; ind <<= 1) { + for (int w = 0; w < width - ind; w++) { + output[index_n_c + h * width + w] = std::max(tmp_output[w], tmp_output[w+ind]); + } + // copy row from updated output to tmp_output + for (int w = 0; w < width - ind; w++) { + tmp_output[w] = output[index_n_c + h * width + w]; + } + } // for ind + } // for h + } // for c + } // for n + +} + +void RightPoolForwardCPU(const float *input, float *output, float *tmp_output, + const int nthreads, const int channels, + const int height, const int width) { + int batch_size = nthreads / channels / width / height; + for (int n = 0; n < batch_size; n++) { + int index_n = n * channels * width * height; + for (int c = 0; c < channels; c++) { + int index_n_c = index_n + c * width * height; + for (int h = 0; h < height; h++) { + // copy row from output to tmp_output + for (int w = 0; w < width; w++) { + int index = index_n_c + h * width + w; + tmp_output[w] = output[index]; + } + // do right_pool + for (int ind = 1; ind < width; ind <<= 1) { + for (int w = ind; w < width; w++) { + output[index_n_c + h * width + w] = std::max(tmp_output[w], tmp_output[w-ind]); + } + // copy row from updated output to tmp_output + for (int w = ind; w < width; w++) { + tmp_output[w] = output[index_n_c + h * width + w]; + } + } // for ind + } // for h + } // for c + } // for n + +} + +void MMCVCornerPoolKernel::Compute(OrtKernelContext *context) { + const int mode = int(mode_); typedef float T; const OrtValue *input = ort_.KernelContext_GetInput(context, 0); const T *input_data = reinterpret_cast(ort_.GetTensorData(input)); + // allocate output memory OrtTensorDimensions out_dimensions(ort_, input); + OrtValue *output = ort_.KernelContext_GetOutput(context, 0, out_dimensions.data(), out_dimensions.size()); + T *output_data = ort_.GetTensorMutableData(output); + // copy input_data to output_data + int batch_size = out_dimensions.data()[0]; int input_channels = out_dimensions.data()[1]; int input_height = out_dimensions.data()[2]; int input_width = out_dimensions.data()[3]; + int output_size = batch_size * input_channels * input_height * input_width; + memcpy(output_data, input_data, sizeof(T) * output_size); - // allocate tmp and output memory - OrtValue *output = ort_.KernelContext_GetOutput(context, 0, out_dimensions.data(), out_dimensions.size()); - T *output_data = ort_.GetTensorMutableData(output); - T *tmp_output_data = (T *)allocator_.Alloc(sizeof(T) * input_height); + // allocate tmp_output memory + // 'top': 0, 'bottom': 1, 'left': 2, 'right':3 + assert(mode == 0 || mode == 1 || mode == 2 || mode == 3); + int tmp_output_size; + if (mode == 0 || mode_ == 1) tmp_output_size = input_height; + else tmp_output_size = input_width; + T *tmp_output_data = (T *)allocator_.Alloc(sizeof(T) * tmp_output_size); - // copy input_data to output_data - int output_size = out_dimensions.data()[0]; - for (auto i = 1; i < out_dimensions.size(); ++i) { - output_size *= out_dimensions.data()[i]; - } - memcpy(output_data, input_data, sizeof(T) * output_size); + // do corner_pool + if (mode == 0) TopPoolForwardCPU(input_data, output_data, tmp_output_data, output_size, input_channels, input_height, input_width); + else if (mode == 1) BottomPoolForwardCPU(input_data, output_data, tmp_output_data, output_size, input_channels, input_height, input_width); + else if (mode == 2) LeftPoolForwardCPU(input_data, output_data, tmp_output_data, output_size, input_channels, input_height, input_width); + else RightPoolForwardCPU(input_data, output_data, tmp_output_data, output_size, input_channels, input_height, input_width); - TopPoolForwardCPU(input_data, output_data, tmp_output_data, output_size, input_channels, input_height, input_width); } diff --git a/mmcv/ops/csrc/onnxruntime/cpu/onnxruntime_register.cpp b/mmcv/ops/csrc/onnxruntime/cpu/onnxruntime_register.cpp index 7d9006fefc..c906c128dc 100644 --- a/mmcv/ops/csrc/onnxruntime/cpu/onnxruntime_register.cpp +++ b/mmcv/ops/csrc/onnxruntime/cpu/onnxruntime_register.cpp @@ -14,7 +14,7 @@ NmsOp c_NmsOp; MMCVRoiAlignCustomOp c_MMCVRoiAlignCustomOp; MMCVRoIAlignRotatedCustomOp c_MMCVRoIAlignRotatedCustomOp; GridSampleOp c_GridSampleOp; -MMCVTopPoolCustomOp c_MMCVTopPoolCustomOp; +MMCVCornerPoolCustomOp c_MMCVCornerPoolCustomOp; OrtStatus *ORT_API_CALL RegisterCustomOps(OrtSessionOptions *options, const OrtApiBase *api) { @@ -47,7 +47,7 @@ OrtStatus *ORT_API_CALL RegisterCustomOps(OrtSessionOptions *options, return status; } - if (auto status = ortApi->CustomOpDomain_Add(domain, &c_MMCVTopPoolCustomOp)) { + if (auto status = ortApi->CustomOpDomain_Add(domain, &c_MMCVCornerPoolCustomOp)) { return status; } From 0fe054e959a93df1d6bc5d3f7ee31e9308d9fd3d Mon Sep 17 00:00:00 2001 From: v-qjqs Date: Wed, 28 Apr 2021 15:24:48 +0800 Subject: [PATCH 03/11] add unittest for corner_pool --- tests/test_ops/test_onnx.py | 46 +++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/tests/test_ops/test_onnx.py b/tests/test_ops/test_onnx.py index 0e50f5403f..41320f10d6 100644 --- a/tests/test_ops/test_onnx.py +++ b/tests/test_ops/test_onnx.py @@ -448,3 +448,49 @@ def func(feat, scale_factor=2): if os.path.exists(onnx_file): os.remove(onnx_file) assert np.allclose(pytorch_result, onnx_result, atol=1e-3) + + +@pytest.mark.parametrize('mode', ['top', 'bottom', 'left', 'right']) +def test_corner_pool(mode, opset=11): + if torch.__version__ == 'parrots': + pytest.skip('onnx is not supported in parrots directly') + + from mmcv.ops import get_onnxruntime_op_path + ort_custom_op_path = get_onnxruntime_op_path() + if not os.path.exists(ort_custom_op_path): + pytest.skip('custom ops for onnxruntime are not compiled.') + + input = torch.rand((2, 3, 9, 12)) # (n,c,h,w) + + from mmcv.ops.corner_pool import CornerPool + + def corner_pool_func(input): + corner_pool_module = CornerPool(mode) + return corner_pool_module.corner_pool.apply(input) + + wrapped_model = WrapFunction(corner_pool_func).eval() + + with torch.no_grad(): + torch.onnx.export( + wrapped_model, + input, + onnx_file, + export_params=True, + keep_initializers_as_inputs=True, + input_names=['input'], + output_names=['output'], + opset_version=opset) + + onnx_model = onnx.load(onnx_file) + input_all = [node.name for node in onnx_model.graph.input] + input_initializer = [node.name for node in onnx_model.graph.initializer] + net_feed_input = list(set(input_all) - set(input_initializer)) + assert (len(net_feed_input) == 1) + + session_options = rt.SessionOptions() + session_options.register_custom_ops_library(ort_custom_op_path) + sess = rt.InferenceSession(onnx_file, session_options) + ort_result = sess.run(None, {'input': input.detach().numpy()}) + pytorch_results = wrapped_model(input.clone()) + os.remove(onnx_file) + assert np.allclose(pytorch_results, ort_result, atol=1e-5) From 8c03354d3fd0cfb5f6d24f55e4ce38f02d8ac0be Mon Sep 17 00:00:00 2001 From: v-qjqs Date: Thu, 29 Apr 2021 15:48:25 +0800 Subject: [PATCH 04/11] supports mmcv::CornerPool without memcpy --- mmcv/ops/csrc/onnxruntime/corner_pool.h | 4 - mmcv/ops/csrc/onnxruntime/cpu/corner_pool.cpp | 104 ++++++------------ tests/test_ops/test_onnx.py | 4 +- 3 files changed, 33 insertions(+), 79 deletions(-) diff --git a/mmcv/ops/csrc/onnxruntime/corner_pool.h b/mmcv/ops/csrc/onnxruntime/corner_pool.h index def7b80c38..3e966081ac 100644 --- a/mmcv/ops/csrc/onnxruntime/corner_pool.h +++ b/mmcv/ops/csrc/onnxruntime/corner_pool.h @@ -2,22 +2,18 @@ #define ONNXRUNTIME_CORNER_POOL_H #include -#include #include struct MMCVCornerPoolKernel { public: MMCVCornerPoolKernel(Ort::CustomOpApi ort, const OrtKernelInfo* info): ort_(ort) { mode_ = ort_.KernelInfoGetAttribute(info, "mode"); - // create allocator - allocator_ = Ort::AllocatorWithDefaultOptions(); } void Compute(OrtKernelContext* context); private: Ort::CustomOpApi ort_; - Ort::AllocatorWithDefaultOptions allocator_; int64_t mode_; }; diff --git a/mmcv/ops/csrc/onnxruntime/cpu/corner_pool.cpp b/mmcv/ops/csrc/onnxruntime/cpu/corner_pool.cpp index eb9f77e0c8..ea03005f21 100644 --- a/mmcv/ops/csrc/onnxruntime/cpu/corner_pool.cpp +++ b/mmcv/ops/csrc/onnxruntime/cpu/corner_pool.cpp @@ -1,7 +1,7 @@ #include "corner_pool.h" #include "../ort_mmcv_utils.h" -void TopPoolForwardCPU(const float *input, float *output, float *tmp_output, +void TopPoolForwardCPU(const float *input, float *output, const int nthreads, const int channels, const int height, const int width) { int batch_size = nthreads / channels / width / height; @@ -10,28 +10,19 @@ void TopPoolForwardCPU(const float *input, float *output, float *tmp_output, for (int c = 0; c < channels; c++) { int index_n_c = index_n + c * width * height; for (int w = 0; w < width; w++) { - // copy column from output to tmp_output - for (int h = 0; h < height; h++) { - int index = index_n_c + h * width + w; - tmp_output[h] = output[index]; - } + // directly copy the most bottom value from input to output + output[index_n_c + (height - 1) * width + w] = input[index_n_c + (height - 1) * width + w]; // do top_pool - for (int ind = 1; ind < height; ind <<= 1) { - for (int h = 0; h < height - ind; h++) { - output[index_n_c + h * width + w] = std::max(tmp_output[h], tmp_output[h+ind]); - } - // copy column from updated output to tmp_output - for (int h = 0; h < height - ind; h++) { - tmp_output[h] = output[index_n_c + h * width + w]; - } - } // for ind + for (int h = height - 2; h >= 0; h--) { + output[index_n_c + h * width + w] = std::max(output[index_n_c + (h+1) * width + w], input[index_n_c + h * width + w]); + } // for h } // for w } // for c } // for n } -void BottomPoolForwardCPU(const float *input, float *output, float *tmp_output, +void BottomPoolForwardCPU(const float *input, float *output, const int nthreads, const int channels, const int height, const int width) { int batch_size = nthreads / channels / width / height; @@ -40,28 +31,19 @@ void BottomPoolForwardCPU(const float *input, float *output, float *tmp_output, for (int c = 0; c < channels; c++) { int index_n_c = index_n + c * width * height; for (int w = 0; w < width; w++) { - // copy column from output to tmp_output - for (int h = 0; h < height; h++) { - int index = index_n_c + h * width + w; - tmp_output[h] = output[index]; - } - // do bottom_pool - for (int ind = 1; ind < height; ind <<= 1) { - for (int h = ind; h < height; h++) { - output[index_n_c + h * width + w] = std::max(tmp_output[h], tmp_output[h-ind]); - } - // copy column from updated output to tmp_output - for (int h = ind; h < height; h++) { - tmp_output[h] = output[index_n_c + h * width + w]; - } - } // for ind + // directly copy the most top value from input to output + output[index_n_c + w] = input[index_n_c + w]; + // do top_pool + for (int h = 1; h < height; h++) { + output[index_n_c + h * width + w] = std::max(output[index_n_c + (h-1) * width + w], input[index_n_c + h * width + w]); + } // for h } // for w } // for c } // for n } -void LeftPoolForwardCPU(const float *input, float *output, float *tmp_output, +void LeftPoolForwardCPU(const float *input, float *output, const int nthreads, const int channels, const int height, const int width) { int batch_size = nthreads / channels / width / height; @@ -70,28 +52,19 @@ void LeftPoolForwardCPU(const float *input, float *output, float *tmp_output, for (int c = 0; c < channels; c++) { int index_n_c = index_n + c * width * height; for (int h = 0; h < height; h++) { - // copy row from output to tmp_output - for (int w = 0; w < width; w++) { - int index = index_n_c + h * width + w; - tmp_output[w] = output[index]; - } + // directly copy the most right value from input to output + output[index_n_c + h * width + width - 1] = input[index_n_c + h * width + width - 1]; // do left_pool - for (int ind = 1; ind < width; ind <<= 1) { - for (int w = 0; w < width - ind; w++) { - output[index_n_c + h * width + w] = std::max(tmp_output[w], tmp_output[w+ind]); - } - // copy row from updated output to tmp_output - for (int w = 0; w < width - ind; w++) { - tmp_output[w] = output[index_n_c + h * width + w]; - } - } // for ind + for (int w = width - 2; w >= 0; w--){ + output[index_n_c + h * width + w] = std::max(output[index_n_c + h * width + w + 1], input[index_n_c + h * width +w]); + } // for w } // for h } // for c } // for n } -void RightPoolForwardCPU(const float *input, float *output, float *tmp_output, +void RightPoolForwardCPU(const float *input, float *output, const int nthreads, const int channels, const int height, const int width) { int batch_size = nthreads / channels / width / height; @@ -100,21 +73,12 @@ void RightPoolForwardCPU(const float *input, float *output, float *tmp_output, for (int c = 0; c < channels; c++) { int index_n_c = index_n + c * width * height; for (int h = 0; h < height; h++) { - // copy row from output to tmp_output - for (int w = 0; w < width; w++) { - int index = index_n_c + h * width + w; - tmp_output[w] = output[index]; - } + // directly copy the most left value from input to output + output[index_n_c + h * width] = input[index_n_c + h * width]; // do right_pool - for (int ind = 1; ind < width; ind <<= 1) { - for (int w = ind; w < width; w++) { - output[index_n_c + h * width + w] = std::max(tmp_output[w], tmp_output[w-ind]); - } - // copy row from updated output to tmp_output - for (int w = ind; w < width; w++) { - tmp_output[w] = output[index_n_c + h * width + w]; - } - } // for ind + for (int w = 1; w < width; w++) { + output[index_n_c + h * width + w] = std::max(output[index_n_c + h * width + w - 1], input[index_n_c + h * width + w]); + } // for w } // for h } // for c } // for n @@ -127,31 +91,25 @@ void MMCVCornerPoolKernel::Compute(OrtKernelContext *context) { const OrtValue *input = ort_.KernelContext_GetInput(context, 0); const T *input_data = reinterpret_cast(ort_.GetTensorData(input)); - // allocate output memory + // get output memory OrtTensorDimensions out_dimensions(ort_, input); OrtValue *output = ort_.KernelContext_GetOutput(context, 0, out_dimensions.data(), out_dimensions.size()); T *output_data = ort_.GetTensorMutableData(output); - // copy input_data to output_data + // get output_size int batch_size = out_dimensions.data()[0]; int input_channels = out_dimensions.data()[1]; int input_height = out_dimensions.data()[2]; int input_width = out_dimensions.data()[3]; int output_size = batch_size * input_channels * input_height * input_width; - memcpy(output_data, input_data, sizeof(T) * output_size); - // allocate tmp_output memory // 'top': 0, 'bottom': 1, 'left': 2, 'right':3 assert(mode == 0 || mode == 1 || mode == 2 || mode == 3); - int tmp_output_size; - if (mode == 0 || mode_ == 1) tmp_output_size = input_height; - else tmp_output_size = input_width; - T *tmp_output_data = (T *)allocator_.Alloc(sizeof(T) * tmp_output_size); // do corner_pool - if (mode == 0) TopPoolForwardCPU(input_data, output_data, tmp_output_data, output_size, input_channels, input_height, input_width); - else if (mode == 1) BottomPoolForwardCPU(input_data, output_data, tmp_output_data, output_size, input_channels, input_height, input_width); - else if (mode == 2) LeftPoolForwardCPU(input_data, output_data, tmp_output_data, output_size, input_channels, input_height, input_width); - else RightPoolForwardCPU(input_data, output_data, tmp_output_data, output_size, input_channels, input_height, input_width); + if (mode == 0) TopPoolForwardCPU(input_data, output_data, output_size, input_channels, input_height, input_width); + else if (mode == 1) BottomPoolForwardCPU(input_data, output_data, output_size, input_channels, input_height, input_width); + else if (mode == 2) LeftPoolForwardCPU(input_data, output_data, output_size, input_channels, input_height, input_width); + else RightPoolForwardCPU(input_data, output_data, output_size, input_channels, input_height, input_width); } diff --git a/tests/test_ops/test_onnx.py b/tests/test_ops/test_onnx.py index 41320f10d6..62859c32f5 100644 --- a/tests/test_ops/test_onnx.py +++ b/tests/test_ops/test_onnx.py @@ -460,8 +460,6 @@ def test_corner_pool(mode, opset=11): if not os.path.exists(ort_custom_op_path): pytest.skip('custom ops for onnxruntime are not compiled.') - input = torch.rand((2, 3, 9, 12)) # (n,c,h,w) - from mmcv.ops.corner_pool import CornerPool def corner_pool_func(input): @@ -470,6 +468,8 @@ def corner_pool_func(input): wrapped_model = WrapFunction(corner_pool_func).eval() + input = torch.rand((2, 3, 9, 12)) # (n,c,h,w) + with torch.no_grad(): torch.onnx.export( wrapped_model, From 5858960769d482cd04c618f5d61755d39d9ee213 Mon Sep 17 00:00:00 2001 From: v-qjqs Date: Thu, 29 Apr 2021 16:17:07 +0800 Subject: [PATCH 05/11] add docs for mmcv::CornerPool --- docs/onnxruntime_custom_ops.md | 37 ++++++++++++++++++++++++++++++++++ docs/onnxruntime_op.md | 1 + 2 files changed, 38 insertions(+) diff --git a/docs/onnxruntime_custom_ops.md b/docs/onnxruntime_custom_ops.md index e42032d23d..93f8e0276e 100644 --- a/docs/onnxruntime_custom_ops.md +++ b/docs/onnxruntime_custom_ops.md @@ -27,6 +27,12 @@ - [Inputs](#inputs-3) - [Outputs](#outputs-3) - [Type Constraints](#type-constraints-3) + - [CornerPool](#corner_pool) + - [Description](#description-4) + - [Parameters](#parameters-4) + - [Inputs](#inputs-4) + - [Outputs](#outputs-4) + - [Type Constraints](#type-constraints-4) @@ -171,3 +177,34 @@ Perform sample from `input` with pixel locations from `grid`. ### Type Constraints - T:tensor(float32, Linear) + +## CornerPool + +### Description + +Perform CornerPool on `input` features. Read [CornerNet -- Detecting Objects as Paired Keypoints](https://arxiv.org/abs/1808.01244) for more details. + +### Parameters + +| Type | Parameter | Description | +| ------- | --------------- | ---------------------------------------------------------------- | +| `int` | `mode` | corner pool mode, (0: `top`, 1: `bottom`, 2: `left`, 3: `right`) | + +### Inputs + +
+
input: T
+
Input features. 4-D tensor of shape (N, C, H, W). N is the batch size.
+
+ +### Outputs + +
+
output: tensor(int64)
+
Output the pooled features. 4-D tensor of shape (N, C, H, W).
+
+ +### Type Constraints + +- T:tensor(float32) + diff --git a/docs/onnxruntime_op.md b/docs/onnxruntime_op.md index 9324524e39..c9d8f11252 100644 --- a/docs/onnxruntime_op.md +++ b/docs/onnxruntime_op.md @@ -21,6 +21,7 @@ | [RoIAlign](onnxruntime_custom_ops.md#roialign) | Y | N | 1.2.5 | | [NMS](onnxruntime_custom_ops.md#nms) | Y | N | 1.2.7 | | [grid_sampler](onnxruntime_custom_ops.md#grid_sampler) | Y | N | master | +| [CornerPool](onnxruntime_custom_ops.md#corner_pool) | Y | N | master | ## How to build custom operators for ONNX Runtime From d7dc41641a9db99508540fe34eca16f7f3b0c1fe Mon Sep 17 00:00:00 2001 From: v-qjqs Date: Thu, 29 Apr 2021 16:30:38 +0800 Subject: [PATCH 06/11] re-add docs for mmcv::CornerPool --- docs/onnxruntime_custom_ops.md | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/onnxruntime_custom_ops.md b/docs/onnxruntime_custom_ops.md index 93f8e0276e..666c0d455a 100644 --- a/docs/onnxruntime_custom_ops.md +++ b/docs/onnxruntime_custom_ops.md @@ -207,4 +207,3 @@ Perform CornerPool on `input` features. Read [CornerNet -- Detecting Objects as ### Type Constraints - T:tensor(float32) - From a7a8bad7a9c22ef1cedb2cc2dd72ddf88aecac6e Mon Sep 17 00:00:00 2001 From: v-qjqs Date: Thu, 29 Apr 2021 16:35:20 +0800 Subject: [PATCH 07/11] fix output dtype doc --- docs/onnxruntime_custom_ops.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/onnxruntime_custom_ops.md b/docs/onnxruntime_custom_ops.md index 666c0d455a..ee487cb86f 100644 --- a/docs/onnxruntime_custom_ops.md +++ b/docs/onnxruntime_custom_ops.md @@ -200,7 +200,7 @@ Perform CornerPool on `input` features. Read [CornerNet -- Detecting Objects as ### Outputs
-
output: tensor(int64)
+
output: T
Output the pooled features. 4-D tensor of shape (N, C, H, W).
From 4b29326f43c9fb675d3122fed67b0796476f6bab Mon Sep 17 00:00:00 2001 From: v-qjqs Date: Thu, 29 Apr 2021 17:34:59 +0800 Subject: [PATCH 08/11] reformat --- docs/onnxruntime_custom_ops.md | 2 +- docs/onnxruntime_op.md | 2 +- mmcv/ops/csrc/onnxruntime/cpu/corner_pool.cpp | 39 ++++++++----------- 3 files changed, 18 insertions(+), 25 deletions(-) diff --git a/docs/onnxruntime_custom_ops.md b/docs/onnxruntime_custom_ops.md index ee487cb86f..837d947184 100644 --- a/docs/onnxruntime_custom_ops.md +++ b/docs/onnxruntime_custom_ops.md @@ -27,7 +27,7 @@ - [Inputs](#inputs-3) - [Outputs](#outputs-3) - [Type Constraints](#type-constraints-3) - - [CornerPool](#corner_pool) + - [CornerPool](#cornerpool) - [Description](#description-4) - [Parameters](#parameters-4) - [Inputs](#inputs-4) diff --git a/docs/onnxruntime_op.md b/docs/onnxruntime_op.md index c9d8f11252..0e2f62adb4 100644 --- a/docs/onnxruntime_op.md +++ b/docs/onnxruntime_op.md @@ -21,7 +21,7 @@ | [RoIAlign](onnxruntime_custom_ops.md#roialign) | Y | N | 1.2.5 | | [NMS](onnxruntime_custom_ops.md#nms) | Y | N | 1.2.7 | | [grid_sampler](onnxruntime_custom_ops.md#grid_sampler) | Y | N | master | -| [CornerPool](onnxruntime_custom_ops.md#corner_pool) | Y | N | master | +| [CornerPool](onnxruntime_custom_ops.md#cornerpool) | Y | N | master | ## How to build custom operators for ONNX Runtime diff --git a/mmcv/ops/csrc/onnxruntime/cpu/corner_pool.cpp b/mmcv/ops/csrc/onnxruntime/cpu/corner_pool.cpp index ea03005f21..1371b2248c 100644 --- a/mmcv/ops/csrc/onnxruntime/cpu/corner_pool.cpp +++ b/mmcv/ops/csrc/onnxruntime/cpu/corner_pool.cpp @@ -2,9 +2,8 @@ #include "../ort_mmcv_utils.h" void TopPoolForwardCPU(const float *input, float *output, - const int nthreads, const int channels, - const int height, const int width) { - int batch_size = nthreads / channels / width / height; + const int batch_size, const int channels, + const int height, const int width) { for (int n = 0; n < batch_size; n++) { int index_n = n * channels * width * height; for (int c = 0; c < channels; c++) { @@ -23,9 +22,8 @@ void TopPoolForwardCPU(const float *input, float *output, } void BottomPoolForwardCPU(const float *input, float *output, - const int nthreads, const int channels, + const int batch_size, const int channels, const int height, const int width) { - int batch_size = nthreads / channels / width / height; for (int n = 0; n < batch_size; n++) { int index_n = n * channels * width * height; for (int c = 0; c < channels; c++) { @@ -44,9 +42,8 @@ void BottomPoolForwardCPU(const float *input, float *output, } void LeftPoolForwardCPU(const float *input, float *output, - const int nthreads, const int channels, - const int height, const int width) { - int batch_size = nthreads / channels / width / height; + const int batch_size, const int channels, + const int height, const int width) { for (int n = 0; n < batch_size; n++) { int index_n = n * channels * width * height; for (int c = 0; c < channels; c++) { @@ -65,9 +62,8 @@ void LeftPoolForwardCPU(const float *input, float *output, } void RightPoolForwardCPU(const float *input, float *output, - const int nthreads, const int channels, - const int height, const int width) { - int batch_size = nthreads / channels / width / height; + const int batch_size, const int channels, + const int height, const int width) { for (int n = 0; n < batch_size; n++) { int index_n = n * channels * width * height; for (int c = 0; c < channels; c++) { @@ -94,22 +90,19 @@ void MMCVCornerPoolKernel::Compute(OrtKernelContext *context) { // get output memory OrtTensorDimensions out_dimensions(ort_, input); OrtValue *output = ort_.KernelContext_GetOutput(context, 0, out_dimensions.data(), out_dimensions.size()); - T *output_data = ort_.GetTensorMutableData(output); - - // get output_size - int batch_size = out_dimensions.data()[0]; - int input_channels = out_dimensions.data()[1]; - int input_height = out_dimensions.data()[2]; - int input_width = out_dimensions.data()[3]; - int output_size = batch_size * input_channels * input_height * input_width; + T *output_data = ort_.GetTensorMutableData(output); // 'top': 0, 'bottom': 1, 'left': 2, 'right':3 assert(mode == 0 || mode == 1 || mode == 2 || mode == 3); // do corner_pool - if (mode == 0) TopPoolForwardCPU(input_data, output_data, output_size, input_channels, input_height, input_width); - else if (mode == 1) BottomPoolForwardCPU(input_data, output_data, output_size, input_channels, input_height, input_width); - else if (mode == 2) LeftPoolForwardCPU(input_data, output_data, output_size, input_channels, input_height, input_width); - else RightPoolForwardCPU(input_data, output_data, output_size, input_channels, input_height, input_width); + int batch_size = out_dimensions.data()[0]; + int input_channels = out_dimensions.data()[1]; + int input_height = out_dimensions.data()[2]; + int input_width = out_dimensions.data()[3]; + if (mode == 0) TopPoolForwardCPU(input_data, output_data, batch_size, input_channels, input_height, input_width); + else if (mode == 1) BottomPoolForwardCPU(input_data, output_data, batch_size, input_channels, input_height, input_width); + else if (mode == 2) LeftPoolForwardCPU(input_data, output_data, batch_size, input_channels, input_height, input_width); + else RightPoolForwardCPU(input_data, output_data, batch_size, input_channels, input_height, input_width); } From 5fbd0484ad51525b0da2b9b0f9e750b9065c272d Mon Sep 17 00:00:00 2001 From: v-qjqs Date: Thu, 29 Apr 2021 18:18:19 +0800 Subject: [PATCH 09/11] format with pre-commit --- docs/onnxruntime_op.md | 1 + mmcv/ops/csrc/onnxruntime/cpu/corner_pool.cpp | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/onnxruntime_op.md b/docs/onnxruntime_op.md index 0e2f62adb4..30e13f234d 100644 --- a/docs/onnxruntime_op.md +++ b/docs/onnxruntime_op.md @@ -119,5 +119,6 @@ Take custom operator `soft_nms` for example. ## References + - [How to export Pytorch model with custom op to ONNX and run it in ONNX Runtime](https://github.com/onnx/tutorials/blob/master/PyTorchCustomOperator/README.md) - [How to add a custom operator/kernel in ONNX Runtime](https://github.com/microsoft/onnxruntime/blob/master/docs/AddingCustomOp.md) diff --git a/mmcv/ops/csrc/onnxruntime/cpu/corner_pool.cpp b/mmcv/ops/csrc/onnxruntime/cpu/corner_pool.cpp index 1371b2248c..9b2a6c331c 100644 --- a/mmcv/ops/csrc/onnxruntime/cpu/corner_pool.cpp +++ b/mmcv/ops/csrc/onnxruntime/cpu/corner_pool.cpp @@ -90,7 +90,7 @@ void MMCVCornerPoolKernel::Compute(OrtKernelContext *context) { // get output memory OrtTensorDimensions out_dimensions(ort_, input); OrtValue *output = ort_.KernelContext_GetOutput(context, 0, out_dimensions.data(), out_dimensions.size()); - T *output_data = ort_.GetTensorMutableData(output); + T *output_data = ort_.GetTensorMutableData(output); // 'top': 0, 'bottom': 1, 'left': 2, 'right':3 assert(mode == 0 || mode == 1 || mode == 2 || mode == 3); From 83bbb92ec8340f858b05a158fe45141854657d6d Mon Sep 17 00:00:00 2001 From: v-qjqs Date: Thu, 29 Apr 2021 18:22:45 +0800 Subject: [PATCH 10/11] format --- docs/onnxruntime_op.md | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/onnxruntime_op.md b/docs/onnxruntime_op.md index 30e13f234d..0e2f62adb4 100644 --- a/docs/onnxruntime_op.md +++ b/docs/onnxruntime_op.md @@ -119,6 +119,5 @@ Take custom operator `soft_nms` for example. ## References - - [How to export Pytorch model with custom op to ONNX and run it in ONNX Runtime](https://github.com/onnx/tutorials/blob/master/PyTorchCustomOperator/README.md) - [How to add a custom operator/kernel in ONNX Runtime](https://github.com/microsoft/onnxruntime/blob/master/docs/AddingCustomOp.md) From 5c505f7fb3f39899f5d5cc84f38859a8a09a82a1 Mon Sep 17 00:00:00 2001 From: liqiaofei1 Date: Thu, 29 Apr 2021 22:07:57 +0800 Subject: [PATCH 11/11] fix lint error, by using google clang-format style for c/c++ --- mmcv/ops/csrc/onnxruntime/corner_pool.h | 52 ++--- mmcv/ops/csrc/onnxruntime/cpu/corner_pool.cpp | 188 ++++++++++-------- .../onnxruntime/cpu/onnxruntime_register.cpp | 5 +- 3 files changed, 131 insertions(+), 114 deletions(-) diff --git a/mmcv/ops/csrc/onnxruntime/corner_pool.h b/mmcv/ops/csrc/onnxruntime/corner_pool.h index 3e966081ac..4edca2cb8f 100644 --- a/mmcv/ops/csrc/onnxruntime/corner_pool.h +++ b/mmcv/ops/csrc/onnxruntime/corner_pool.h @@ -5,39 +5,41 @@ #include struct MMCVCornerPoolKernel { - public: - MMCVCornerPoolKernel(Ort::CustomOpApi ort, const OrtKernelInfo* info): ort_(ort) { - mode_ = ort_.KernelInfoGetAttribute(info, "mode"); - } + public: + MMCVCornerPoolKernel(Ort::CustomOpApi ort, const OrtKernelInfo* info) + : ort_(ort) { + mode_ = ort_.KernelInfoGetAttribute(info, "mode"); + } - void Compute(OrtKernelContext* context); + void Compute(OrtKernelContext* context); - private: - Ort::CustomOpApi ort_; + private: + Ort::CustomOpApi ort_; - int64_t mode_; + int64_t mode_; }; -struct MMCVCornerPoolCustomOp : Ort::CustomOpBase { - void *CreateKernel(Ort::CustomOpApi api, const OrtKernelInfo* info) { - return new MMCVCornerPoolKernel(api, info); - } +struct MMCVCornerPoolCustomOp + : Ort::CustomOpBase { + void* CreateKernel(Ort::CustomOpApi api, const OrtKernelInfo* info) { + return new MMCVCornerPoolKernel(api, info); + } - const char* GetName() const { return "MMCVCornerPool"; } + const char* GetName() const { return "MMCVCornerPool"; } - size_t GetInputTypeCount() const { return 1; } - ONNXTensorElementDataType GetInputType(size_t) const { - return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; - } + size_t GetInputTypeCount() const { return 1; } + ONNXTensorElementDataType GetInputType(size_t) const { + return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; + } - size_t GetOutputTypeCount() const { return 1; } - ONNXTensorElementDataType GetOutputType(size_t) const { - return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; - } + size_t GetOutputTypeCount() const { return 1; } + ONNXTensorElementDataType GetOutputType(size_t) const { + return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; + } - // force cpu - const char* GetExecutionProviderType() const { - return "CPUExecutionProvider"; - } + // force cpu + const char* GetExecutionProviderType() const { + return "CPUExecutionProvider"; + } }; #endif // ONNXRUNTIME_CORNER_POOL_H diff --git a/mmcv/ops/csrc/onnxruntime/cpu/corner_pool.cpp b/mmcv/ops/csrc/onnxruntime/cpu/corner_pool.cpp index 9b2a6c331c..d9d4dc3aad 100644 --- a/mmcv/ops/csrc/onnxruntime/cpu/corner_pool.cpp +++ b/mmcv/ops/csrc/onnxruntime/cpu/corner_pool.cpp @@ -1,108 +1,122 @@ #include "corner_pool.h" -#include "../ort_mmcv_utils.h" -void TopPoolForwardCPU(const float *input, float *output, - const int batch_size, const int channels, - const int height, const int width) { - for (int n = 0; n < batch_size; n++) { - int index_n = n * channels * width * height; - for (int c = 0; c < channels; c++) { - int index_n_c = index_n + c * width * height; - for (int w = 0; w < width; w++) { - // directly copy the most bottom value from input to output - output[index_n_c + (height - 1) * width + w] = input[index_n_c + (height - 1) * width + w]; - // do top_pool - for (int h = height - 2; h >= 0; h--) { - output[index_n_c + h * width + w] = std::max(output[index_n_c + (h+1) * width + w], input[index_n_c + h * width + w]); - } // for h - } // for w - } // for c - } // for n +#include "../ort_mmcv_utils.h" +void TopPoolForwardCPU(const float *input, float *output, const int batch_size, + const int channels, const int height, const int width) { + for (int n = 0; n < batch_size; n++) { + int index_n = n * channels * width * height; + for (int c = 0; c < channels; c++) { + int index_n_c = index_n + c * width * height; + for (int w = 0; w < width; w++) { + // directly copy the most bottom value from input to output + output[index_n_c + (height - 1) * width + w] = + input[index_n_c + (height - 1) * width + w]; + // do top_pool + for (int h = height - 2; h >= 0; h--) { + output[index_n_c + h * width + w] = + std::max(output[index_n_c + (h + 1) * width + w], + input[index_n_c + h * width + w]); + } // for h + } // for w + } // for c + } // for n } void BottomPoolForwardCPU(const float *input, float *output, const int batch_size, const int channels, const int height, const int width) { - for (int n = 0; n < batch_size; n++) { - int index_n = n * channels * width * height; - for (int c = 0; c < channels; c++) { - int index_n_c = index_n + c * width * height; - for (int w = 0; w < width; w++) { - // directly copy the most top value from input to output - output[index_n_c + w] = input[index_n_c + w]; - // do top_pool - for (int h = 1; h < height; h++) { - output[index_n_c + h * width + w] = std::max(output[index_n_c + (h-1) * width + w], input[index_n_c + h * width + w]); - } // for h - } // for w - } // for c - } // for n - + for (int n = 0; n < batch_size; n++) { + int index_n = n * channels * width * height; + for (int c = 0; c < channels; c++) { + int index_n_c = index_n + c * width * height; + for (int w = 0; w < width; w++) { + // directly copy the most top value from input to output + output[index_n_c + w] = input[index_n_c + w]; + // do top_pool + for (int h = 1; h < height; h++) { + output[index_n_c + h * width + w] = + std::max(output[index_n_c + (h - 1) * width + w], + input[index_n_c + h * width + w]); + } // for h + } // for w + } // for c + } // for n } -void LeftPoolForwardCPU(const float *input, float *output, - const int batch_size, const int channels, - const int height, const int width) { - for (int n = 0; n < batch_size; n++) { - int index_n = n * channels * width * height; - for (int c = 0; c < channels; c++) { - int index_n_c = index_n + c * width * height; - for (int h = 0; h < height; h++) { - // directly copy the most right value from input to output - output[index_n_c + h * width + width - 1] = input[index_n_c + h * width + width - 1]; - // do left_pool - for (int w = width - 2; w >= 0; w--){ - output[index_n_c + h * width + w] = std::max(output[index_n_c + h * width + w + 1], input[index_n_c + h * width +w]); - } // for w - } // for h - } // for c - } // for n - +void LeftPoolForwardCPU(const float *input, float *output, const int batch_size, + const int channels, const int height, const int width) { + for (int n = 0; n < batch_size; n++) { + int index_n = n * channels * width * height; + for (int c = 0; c < channels; c++) { + int index_n_c = index_n + c * width * height; + for (int h = 0; h < height; h++) { + // directly copy the most right value from input to output + output[index_n_c + h * width + width - 1] = + input[index_n_c + h * width + width - 1]; + // do left_pool + for (int w = width - 2; w >= 0; w--) { + output[index_n_c + h * width + w] = + std::max(output[index_n_c + h * width + w + 1], + input[index_n_c + h * width + w]); + } // for w + } // for h + } // for c + } // for n } void RightPoolForwardCPU(const float *input, float *output, const int batch_size, const int channels, const int height, const int width) { - for (int n = 0; n < batch_size; n++) { - int index_n = n * channels * width * height; - for (int c = 0; c < channels; c++) { - int index_n_c = index_n + c * width * height; - for (int h = 0; h < height; h++) { - // directly copy the most left value from input to output - output[index_n_c + h * width] = input[index_n_c + h * width]; - // do right_pool - for (int w = 1; w < width; w++) { - output[index_n_c + h * width + w] = std::max(output[index_n_c + h * width + w - 1], input[index_n_c + h * width + w]); - } // for w - } // for h - } // for c - } // for n - + for (int n = 0; n < batch_size; n++) { + int index_n = n * channels * width * height; + for (int c = 0; c < channels; c++) { + int index_n_c = index_n + c * width * height; + for (int h = 0; h < height; h++) { + // directly copy the most left value from input to output + output[index_n_c + h * width] = input[index_n_c + h * width]; + // do right_pool + for (int w = 1; w < width; w++) { + output[index_n_c + h * width + w] = + std::max(output[index_n_c + h * width + w - 1], + input[index_n_c + h * width + w]); + } // for w + } // for h + } // for c + } // for n } -void MMCVCornerPoolKernel::Compute(OrtKernelContext *context) { - const int mode = int(mode_); - typedef float T; - const OrtValue *input = ort_.KernelContext_GetInput(context, 0); - const T *input_data = reinterpret_cast(ort_.GetTensorData(input)); - - // get output memory - OrtTensorDimensions out_dimensions(ort_, input); - OrtValue *output = ort_.KernelContext_GetOutput(context, 0, out_dimensions.data(), out_dimensions.size()); - T *output_data = ort_.GetTensorMutableData(output); +void MMCVCornerPoolKernel::Compute(OrtKernelContext *context) { + const int mode = int(mode_); + typedef float T; + const OrtValue *input = ort_.KernelContext_GetInput(context, 0); + const T *input_data = + reinterpret_cast(ort_.GetTensorData(input)); - // 'top': 0, 'bottom': 1, 'left': 2, 'right':3 - assert(mode == 0 || mode == 1 || mode == 2 || mode == 3); + // get output memory + OrtTensorDimensions out_dimensions(ort_, input); + OrtValue *output = ort_.KernelContext_GetOutput( + context, 0, out_dimensions.data(), out_dimensions.size()); + T *output_data = ort_.GetTensorMutableData(output); - // do corner_pool - int batch_size = out_dimensions.data()[0]; - int input_channels = out_dimensions.data()[1]; - int input_height = out_dimensions.data()[2]; - int input_width = out_dimensions.data()[3]; - if (mode == 0) TopPoolForwardCPU(input_data, output_data, batch_size, input_channels, input_height, input_width); - else if (mode == 1) BottomPoolForwardCPU(input_data, output_data, batch_size, input_channels, input_height, input_width); - else if (mode == 2) LeftPoolForwardCPU(input_data, output_data, batch_size, input_channels, input_height, input_width); - else RightPoolForwardCPU(input_data, output_data, batch_size, input_channels, input_height, input_width); + // 'top': 0, 'bottom': 1, 'left': 2, 'right':3 + assert(mode == 0 || mode == 1 || mode == 2 || mode == 3); + // do corner_pool + int batch_size = out_dimensions.data()[0]; + int input_channels = out_dimensions.data()[1]; + int input_height = out_dimensions.data()[2]; + int input_width = out_dimensions.data()[3]; + if (mode == 0) + TopPoolForwardCPU(input_data, output_data, batch_size, input_channels, + input_height, input_width); + else if (mode == 1) + BottomPoolForwardCPU(input_data, output_data, batch_size, input_channels, + input_height, input_width); + else if (mode == 2) + LeftPoolForwardCPU(input_data, output_data, batch_size, input_channels, + input_height, input_width); + else + RightPoolForwardCPU(input_data, output_data, batch_size, input_channels, + input_height, input_width); } diff --git a/mmcv/ops/csrc/onnxruntime/cpu/onnxruntime_register.cpp b/mmcv/ops/csrc/onnxruntime/cpu/onnxruntime_register.cpp index c906c128dc..b55114b188 100644 --- a/mmcv/ops/csrc/onnxruntime/cpu/onnxruntime_register.cpp +++ b/mmcv/ops/csrc/onnxruntime/cpu/onnxruntime_register.cpp @@ -1,12 +1,12 @@ #include "onnxruntime_register.h" +#include "corner_pool.h" #include "grid_sample.h" #include "nms.h" #include "ort_mmcv_utils.h" #include "roi_align.h" #include "roi_align_rotated.h" #include "soft_nms.h" -#include "corner_pool.h" const char *c_MMCVOpDomain = "mmcv"; SoftNmsOp c_SoftNmsOp; @@ -47,7 +47,8 @@ OrtStatus *ORT_API_CALL RegisterCustomOps(OrtSessionOptions *options, return status; } - if (auto status = ortApi->CustomOpDomain_Add(domain, &c_MMCVCornerPoolCustomOp)) { + if (auto status = + ortApi->CustomOpDomain_Add(domain, &c_MMCVCornerPoolCustomOp)) { return status; }