Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support corner_pool related custom operators for onnxruntime in mmcv #997

Merged
merged 13 commits into from
May 1, 2021
Merged
26 changes: 26 additions & 0 deletions mmcv/ops/corner_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,17 @@
'right_pool_forward', 'right_pool_backward'
])

_mode_dict = {'top': 0, 'bottom': 1, 'left': 2, 'right': 3}


class TopPoolFunction(Function):

@staticmethod
def symbolic(g, input):
    """Register the ONNX graph node for top pooling (MMCVCornerPool, mode=top)."""
    return g.op(
        'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['top']))

@staticmethod
def forward(ctx, input):
output = ext_module.top_pool_forward(input)
Expand All @@ -28,6 +36,12 @@ def backward(ctx, grad_output):

class BottomPoolFunction(Function):

@staticmethod
def symbolic(g, input):
    """Register the ONNX graph node for bottom pooling (MMCVCornerPool, mode=bottom)."""
    return g.op(
        'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['bottom']))

@staticmethod
def forward(ctx, input):
output = ext_module.bottom_pool_forward(input)
Expand All @@ -43,6 +57,12 @@ def backward(ctx, grad_output):

class LeftPoolFunction(Function):

@staticmethod
def symbolic(g, input):
    """Register the ONNX graph node for left pooling (MMCVCornerPool, mode=left)."""
    return g.op(
        'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['left']))

@staticmethod
def forward(ctx, input):
output = ext_module.left_pool_forward(input)
Expand All @@ -58,6 +78,12 @@ def backward(ctx, grad_output):

class RightPoolFunction(Function):

@staticmethod
def symbolic(g, input):
    """Register the ONNX graph node for right pooling (MMCVCornerPool, mode=right)."""
    return g.op(
        'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['right']))

@staticmethod
def forward(ctx, input):
output = ext_module.right_pool_forward(input)
Expand Down
47 changes: 47 additions & 0 deletions mmcv/ops/csrc/onnxruntime/corner_pool.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#ifndef ONNXRUNTIME_CORNER_POOL_H
#define ONNXRUNTIME_CORNER_POOL_H

#include <assert.h>
#include <string>
#include <onnxruntime_cxx_api.h>

// ONNX Runtime CPU kernel for the custom mmcv::MMCVCornerPool operator.
// The node's int64 "mode" attribute selects the pooling direction:
// 0 = top, 1 = bottom, 2 = left, 3 = right (matches _mode_dict in
// mmcv/ops/corner_pool.py).
struct MMCVCornerPoolKernel {
public:
// Reads the required "mode" attribute once at session build time and keeps
// a default CPU allocator for the per-call scratch buffer.
MMCVCornerPoolKernel(Ort::CustomOpApi ort, const OrtKernelInfo* info): ort_(ort) {
mode_ = ort_.KernelInfoGetAttribute<int64_t>(info, "mode");
// create allocator
allocator_ = Ort::AllocatorWithDefaultOptions();
}

// Runs corner pooling on the single float input tensor and writes the
// result to the single output tensor (implemented in cpu/corner_pool.cpp,
// which indexes the dims as n, c, h, w).
void Compute(OrtKernelContext* context);

private:
Ort::CustomOpApi ort_;
Ort::AllocatorWithDefaultOptions allocator_;

// pooling direction: 0 = top, 1 = bottom, 2 = left, 3 = right
int64_t mode_;
};

// Custom-op descriptor registered with ONNX Runtime (see
// onnxruntime_register.cpp). Declares one float input, one float output,
// and binds the op name "MMCVCornerPool" to MMCVCornerPoolKernel.
struct MMCVCornerPoolCustomOp : Ort::CustomOpBase<MMCVCornerPoolCustomOp, MMCVCornerPoolKernel> {
// Called by the runtime to instantiate the kernel for a session; the
// runtime owns and frees the returned object.
void *CreateKernel(Ort::CustomOpApi api, const OrtKernelInfo* info) {
return new MMCVCornerPoolKernel(api, info);
}

// Op name as it appears in the exported ONNX graph (domain "mmcv").
const char* GetName() const { return "MMCVCornerPool"; }

// Exactly one input: a float tensor.
size_t GetInputTypeCount() const { return 1; }
ONNXTensorElementDataType GetInputType(size_t) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
}

// Exactly one output: a float tensor.
size_t GetOutputTypeCount() const { return 1; }
ONNXTensorElementDataType GetOutputType(size_t) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
}

// force cpu
const char* GetExecutionProviderType() const {
return "CPUExecutionProvider";
}
};
#endif // ONNXRUNTIME_CORNER_POOL_H
157 changes: 157 additions & 0 deletions mmcv/ops/csrc/onnxruntime/cpu/corner_pool.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
#include "corner_pool.h"
#include "../ort_mmcv_utils.h"

// Top corner pool on CPU: each output element becomes the maximum over its
// own row and every row below it in the same column, i.e.
//   output[n][c][h][w] = max(output[n][c][h..height-1][w]).
// `output` must already hold a copy of the input (the caller memcpy's it in);
// pooling is performed in place on `output`.
// `input` and `tmp_output` are retained for interface compatibility but are
// no longer read: a single O(height) backward running-max scan replaces the
// original GPU-style O(height*log(height)) doubling scheme, with identical
// results (max is exact for floats).
// `nthreads` is the total element count (n*c*h*w), from which the batch size
// is recovered.
void TopPoolForwardCPU(const float *input, float *output, float *tmp_output,
                       const int nthreads, const int channels,
                       const int height, const int width) {
  (void)input;
  (void)tmp_output;
  const int batch_size = nthreads / channels / width / height;
  for (int n = 0; n < batch_size; n++) {
    const int index_n = n * channels * width * height;
    for (int c = 0; c < channels; c++) {
      const int index_n_c = index_n + c * width * height;
      // suffix maximum along the height axis, scanning bottom row upwards
      for (int h = height - 2; h >= 0; h--) {
        for (int w = 0; w < width; w++) {
          const int index = index_n_c + h * width + w;
          output[index] = std::max(output[index], output[index + width]);
        }
      }
    }
  }
}

// Bottom corner pool on CPU: each output element becomes the maximum over
// its own row and every row above it in the same column, i.e.
//   output[n][c][h][w] = max(output[n][c][0..h][w]).
// `output` must already hold a copy of the input; pooling is in place.
// `input` and `tmp_output` are retained for interface compatibility but are
// no longer read: a single O(height) forward running-max scan replaces the
// original O(height*log(height)) doubling scheme, with identical results.
// `nthreads` is the total element count (n*c*h*w).
void BottomPoolForwardCPU(const float *input, float *output, float *tmp_output,
                          const int nthreads, const int channels,
                          const int height, const int width) {
  (void)input;
  (void)tmp_output;
  const int batch_size = nthreads / channels / width / height;
  for (int n = 0; n < batch_size; n++) {
    const int index_n = n * channels * width * height;
    for (int c = 0; c < channels; c++) {
      const int index_n_c = index_n + c * width * height;
      // prefix maximum along the height axis, scanning top row downwards
      for (int h = 1; h < height; h++) {
        for (int w = 0; w < width; w++) {
          const int index = index_n_c + h * width + w;
          output[index] = std::max(output[index], output[index - width]);
        }
      }
    }
  }
}

// Left corner pool on CPU: each output element becomes the maximum over its
// own column and every column to its right in the same row, i.e.
//   output[n][c][h][w] = max(output[n][c][h][w..width-1]).
// `output` must already hold a copy of the input; pooling is in place.
// `input` and `tmp_output` are retained for interface compatibility but are
// no longer read: a single O(width) backward running-max scan replaces the
// original O(width*log(width)) doubling scheme, with identical results.
// `nthreads` is the total element count (n*c*h*w).
void LeftPoolForwardCPU(const float *input, float *output, float *tmp_output,
                        const int nthreads, const int channels,
                        const int height, const int width) {
  (void)input;
  (void)tmp_output;
  const int batch_size = nthreads / channels / width / height;
  for (int n = 0; n < batch_size; n++) {
    const int index_n = n * channels * width * height;
    for (int c = 0; c < channels; c++) {
      const int index_n_c = index_n + c * width * height;
      for (int h = 0; h < height; h++) {
        const int row = index_n_c + h * width;
        // suffix maximum along the width axis, right to left
        for (int w = width - 2; w >= 0; w--) {
          output[row + w] = std::max(output[row + w], output[row + w + 1]);
        }
      }
    }
  }
}

// Right corner pool on CPU: each output element becomes the maximum over its
// own column and every column to its left in the same row, i.e.
//   output[n][c][h][w] = max(output[n][c][h][0..w]).
// `output` must already hold a copy of the input; pooling is in place.
// `input` and `tmp_output` are retained for interface compatibility but are
// no longer read: a single O(width) forward running-max scan replaces the
// original O(width*log(width)) doubling scheme, with identical results.
// `nthreads` is the total element count (n*c*h*w).
void RightPoolForwardCPU(const float *input, float *output, float *tmp_output,
                         const int nthreads, const int channels,
                         const int height, const int width) {
  (void)input;
  (void)tmp_output;
  const int batch_size = nthreads / channels / width / height;
  for (int n = 0; n < batch_size; n++) {
    const int index_n = n * channels * width * height;
    for (int c = 0; c < channels; c++) {
      const int index_n_c = index_n + c * width * height;
      for (int h = 0; h < height; h++) {
        const int row = index_n_c + h * width;
        // prefix maximum along the width axis, left to right
        for (int w = 1; w < width; w++) {
          output[row + w] = std::max(output[row + w], output[row + w - 1]);
        }
      }
    }
  }
}

// Runs the corner pooling for one inference call.
// Reads the single float input tensor, allocates a same-shaped output,
// seeds it with a copy of the input, then dispatches on the pooling mode.
void MMCVCornerPoolKernel::Compute(OrtKernelContext *context) {
  // 'top': 0, 'bottom': 1, 'left': 2, 'right': 3
  const int mode = int(mode_);
  assert(mode == 0 || mode == 1 || mode == 2 || mode == 3);

  typedef float T;
  const OrtValue *input = ort_.KernelContext_GetInput(context, 0);
  const T *input_data =
      reinterpret_cast<const float *>(ort_.GetTensorData<T>(input));

  // allocate output memory with the same (n, c, h, w) dims as the input
  OrtTensorDimensions out_dimensions(ort_, input);
  OrtValue *output = ort_.KernelContext_GetOutput(
      context, 0, out_dimensions.data(), out_dimensions.size());
  T *output_data = ort_.GetTensorMutableData<T>(output);

  // the pooling kernels update `output` in place, so seed it with the input
  int batch_size = out_dimensions.data()[0];
  int input_channels = out_dimensions.data()[1];
  int input_height = out_dimensions.data()[2];
  int input_width = out_dimensions.data()[3];
  int output_size = batch_size * input_channels * input_height * input_width;
  memcpy(output_data, input_data, sizeof(T) * output_size);

  // scratch buffer: one column (top/bottom) or one row (left/right)
  // fix: was `mode == 0 || mode_ == 1` — same value, but inconsistent with
  // the local `mode` used everywhere else
  int tmp_output_size;
  if (mode == 0 || mode == 1)
    tmp_output_size = input_height;
  else
    tmp_output_size = input_width;
  T *tmp_output_data = (T *)allocator_.Alloc(sizeof(T) * tmp_output_size);

  // do corner_pool
  if (mode == 0)
    TopPoolForwardCPU(input_data, output_data, tmp_output_data, output_size,
                      input_channels, input_height, input_width);
  else if (mode == 1)
    BottomPoolForwardCPU(input_data, output_data, tmp_output_data, output_size,
                         input_channels, input_height, input_width);
  else if (mode == 2)
    LeftPoolForwardCPU(input_data, output_data, tmp_output_data, output_size,
                       input_channels, input_height, input_width);
  else
    RightPoolForwardCPU(input_data, output_data, tmp_output_data, output_size,
                        input_channels, input_height, input_width);

  // fix: release the scratch buffer — the original leaked it on every call
  allocator_.Free(tmp_output_data);
}
6 changes: 6 additions & 0 deletions mmcv/ops/csrc/onnxruntime/cpu/onnxruntime_register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@
#include "roi_align.h"
#include "roi_align_rotated.h"
#include "soft_nms.h"
#include "corner_pool.h"

// Custom-op domain shared by all MMCV operators; must match the "mmcv::"
// prefix used when exporting the ONNX graph from Python.
const char *c_MMCVOpDomain = "mmcv";
// One static descriptor instance per custom op; their addresses are handed
// to CustomOpDomain_Add below, so they must outlive the session options.
SoftNmsOp c_SoftNmsOp;
NmsOp c_NmsOp;
MMCVRoiAlignCustomOp c_MMCVRoiAlignCustomOp;
MMCVRoIAlignRotatedCustomOp c_MMCVRoIAlignRotatedCustomOp;
GridSampleOp c_GridSampleOp;
MMCVCornerPoolCustomOp c_MMCVCornerPoolCustomOp;

OrtStatus *ORT_API_CALL RegisterCustomOps(OrtSessionOptions *options,
const OrtApiBase *api) {
Expand Down Expand Up @@ -45,5 +47,9 @@ OrtStatus *ORT_API_CALL RegisterCustomOps(OrtSessionOptions *options,
return status;
}

if (auto status = ortApi->CustomOpDomain_Add(domain, &c_MMCVCornerPoolCustomOp)) {
return status;
}

return ortApi->AddCustomOpDomain(options, domain);
}
46 changes: 46 additions & 0 deletions tests/test_ops/test_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,3 +448,49 @@ def func(feat, scale_factor=2):
if os.path.exists(onnx_file):
os.remove(onnx_file)
assert np.allclose(pytorch_result, onnx_result, atol=1e-3)


@pytest.mark.parametrize('mode', ['top', 'bottom', 'left', 'right'])
def test_corner_pool(mode, opset=11):
    """Export CornerPool to ONNX and check onnxruntime matches PyTorch.

    Skipped under parrots (no onnx export) and when the mmcv onnxruntime
    custom-op library has not been compiled.
    """
    if torch.__version__ == 'parrots':
        pytest.skip('onnx is not supported in parrots directly')

    from mmcv.ops import get_onnxruntime_op_path
    ort_custom_op_path = get_onnxruntime_op_path()
    if not os.path.exists(ort_custom_op_path):
        pytest.skip('custom ops for onnxruntime are not compiled.')

    input = torch.rand((2, 3, 9, 12))  # (n,c,h,w)

    from mmcv.ops.corner_pool import CornerPool

    def corner_pool_func(input):
        corner_pool_module = CornerPool(mode)
        return corner_pool_module.corner_pool.apply(input)

    wrapped_model = WrapFunction(corner_pool_func).eval()

    # fix: clean up the exported file even when an assertion below fails,
    # so one failing parametrization does not leave a stale onnx_file behind
    try:
        with torch.no_grad():
            torch.onnx.export(
                wrapped_model,
                input,
                onnx_file,
                export_params=True,
                keep_initializers_as_inputs=True,
                input_names=['input'],
                output_names=['output'],
                opset_version=opset)

        # the exported graph must feed exactly one non-initializer input
        onnx_model = onnx.load(onnx_file)
        input_all = [node.name for node in onnx_model.graph.input]
        input_initializer = [
            node.name for node in onnx_model.graph.initializer
        ]
        net_feed_input = list(set(input_all) - set(input_initializer))
        assert (len(net_feed_input) == 1)

        # run the model through onnxruntime with the custom-op library loaded
        session_options = rt.SessionOptions()
        session_options.register_custom_ops_library(ort_custom_op_path)
        sess = rt.InferenceSession(onnx_file, session_options)
        ort_result = sess.run(None, {'input': input.detach().numpy()})
        pytorch_results = wrapped_model(input.clone())
        assert np.allclose(pytorch_results, ort_result, atol=1e-5)
    finally:
        if os.path.exists(onnx_file):
            os.remove(onnx_file)