diff --git a/docs/onnxruntime_custom_ops.md b/docs/onnxruntime_custom_ops.md
index 837d947184..928d530991 100644
--- a/docs/onnxruntime_custom_ops.md
+++ b/docs/onnxruntime_custom_ops.md
@@ -33,6 +33,18 @@
     - [Inputs](#inputs-4)
     - [Outputs](#outputs-4)
     - [Type Constraints](#type-constraints-4)
+  - [cummax](#cummax)
+    - [Description](#description-5)
+    - [Parameters](#parameters-5)
+    - [Inputs](#inputs-5)
+    - [Outputs](#outputs-5)
+    - [Type Constraints](#type-constraints-5)
+  - [cummin](#cummin)
+    - [Description](#description-6)
+    - [Parameters](#parameters-6)
+    - [Inputs](#inputs-6)
+    - [Outputs](#outputs-6)
+    - [Type Constraints](#type-constraints-6)
@@ -207,3 +219,67 @@ Perform CornerPool on `input` features. Read [CornerNet -- Detecting Objects as Paired Keypoints](https://arxiv.org/abs/1808.01244) for more details.
 ### Type Constraints

 - T:tensor(float32)
+
+## cummax
+
+### Description
+
+Returns a tuple (`values`, `indices`), where `values` contains the cumulative maximum elements of `input` along the dimension `dim`, and `indices` is the index location of each maximum value found along that dimension. Read [torch.cummax](https://pytorch.org/docs/stable/generated/torch.cummax.html) for more details.
+
+### Parameters
+
+| Type  | Parameter | Description                            |
+| ----- | --------- | -------------------------------------- |
+| `int` | `dim`     | the dimension to do the operation over |
+
+### Inputs
+
+<dl>
+<dt><tt>input</tt>: T</dt>
+<dd>The input tensor, which may have any shape. Tensors with zero elements are also supported.</dd>
+</dl>
+
+### Outputs
+
+<dl>
+<dt><tt>output</tt>: T</dt>
+<dd>The cumulative maximum elements of `input` along the dimension `dim`, with the same shape and dtype as `input`.</dd>
+<dt><tt>indices</tt>: tensor(int64)</dt>
+<dd>The index location of each cumulative maximum value found along the dimension `dim`, with the same shape as `input`.</dd>
+</dl>
+
+### Type Constraints
+
+- T:tensor(float32)
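+
+For intuition, a minimal sketch of the PyTorch call this op mirrors (expected results shown as comments):
+
+```python
+import torch
+
+x = torch.tensor([[1., 3., 2.],
+                  [4., 0., 5.]])
+values, indices = torch.cummax(x, dim=1)
+# values:  [[1., 3., 3.],
+#           [4., 4., 5.]]
+# indices: [[0, 1, 1],
+#           [0, 0, 2]]
+```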
+
+## cummin
+
+### Description
+
+Returns a tuple (`values`, `indices`), where `values` contains the cumulative minimum elements of `input` along the dimension `dim`, and `indices` is the index location of each minimum value found along that dimension. Read [torch.cummin](https://pytorch.org/docs/stable/generated/torch.cummin.html) for more details.
+
+### Parameters
+
+| Type  | Parameter | Description                            |
+| ----- | --------- | -------------------------------------- |
+| `int` | `dim`     | the dimension to do the operation over |
+
+### Inputs
+
+<dl>
+<dt><tt>input</tt>: T</dt>
+<dd>The input tensor, which may have any shape. Tensors with zero elements are also supported.</dd>
+</dl>
+
+### Outputs
+
+<dl>
+<dt><tt>output</tt>: T</dt>
+<dd>The cumulative minimum elements of `input` along the dimension `dim`, with the same shape and dtype as `input`.</dd>
+<dt><tt>indices</tt>: tensor(int64)</dt>
+<dd>The index location of each cumulative minimum value found along the dimension `dim`, with the same shape as `input`.</dd>
+</dl>
+
+### Type Constraints
+
+- T:tensor(float32)
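+
+A hedged sketch of running an exported model containing this op, with the compiled MMCV custom-op library loaded (the model file and the input name `input` are illustrative; `get_onnxruntime_op_path` comes from `mmcv.ops`):
+
+```python
+import numpy as np
+import onnxruntime as ort
+from mmcv.ops import get_onnxruntime_op_path
+
+session_options = ort.SessionOptions()
+# the compiled MMCV custom ops (including cummin/cummax) live in this library
+session_options.register_custom_ops_library(get_onnxruntime_op_path())
+sess = ort.InferenceSession('cummin.onnx', session_options)
+values, indices = sess.run(
+    None, {'input': np.random.rand(2, 3).astype(np.float32)})
+```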
diff --git a/docs/onnxruntime_op.md b/docs/onnxruntime_op.md
index 0e2f62adb4..e43ce70fc6 100644
--- a/docs/onnxruntime_op.md
+++ b/docs/onnxruntime_op.md
@@ -21,7 +21,9 @@
 | [RoIAlign](onnxruntime_custom_ops.md#roialign) | Y | N | 1.2.5 |
 | [NMS](onnxruntime_custom_ops.md#nms) | Y | N | 1.2.7 |
 | [grid_sampler](onnxruntime_custom_ops.md#grid_sampler) | Y | N | master |
-| [CornerPool](onnxruntime_custom_ops.md#cornerpool) | Y | N | master |
+| [CornerPool](onnxruntime_custom_ops.md#cornerpool) | Y | N | master |
+| [cummax](onnxruntime_custom_ops.md#cummax) | Y | N | master |
+| [cummin](onnxruntime_custom_ops.md#cummin) | Y | N | master |

 ## How to build custom operators for ONNX Runtime

@@ -115,7 +117,9 @@ Take custom operator `soft_nms` for example.

 ## Known Issues

-- None
+- "RuntimeError: tuple appears in op that does not forward tuples, unsupported kind: `prim::PythonOp`."
+  1. Note: in general `cummax` and `cummin` are exportable to ONNX as long as the PyTorch version is >= 1.5.0, since `torch.cummax` is only supported from 1.5.0 onwards. But when `cummax` or `cummin` serves as an intermediate component whose outputs are used as inputs to other modules, PyTorch >= 1.7.0 is required; otherwise the error above may be raised when the exported ONNX model is run with ONNX Runtime (see the sketch below).
+  2. Solution: update PyTorch to 1.7.0 or higher.
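+
+  A minimal sketch of the failing pattern (the module and file names here are illustrative, not part of MMCV):
+
+  ```python
+  import torch
+  from mmcv.onnx.symbolic import register_extra_symbolics
+
+  register_extra_symbolics(11)  # make `mmcv::cummax` visible to the exporter
+
+  class CummaxThenAdd(torch.nn.Module):  # hypothetical module
+      def forward(self, x):
+          values, indices = torch.cummax(x, dim=1)
+          return values + 1  # the tuple output feeds another op
+
+  # with torch < 1.7.0, running the exported model in ONNX Runtime
+  # may fail with the `prim::PythonOp` error quoted above
+  torch.onnx.export(CummaxThenAdd(), torch.rand(2, 3), 'tmp.onnx',
+                    opset_version=11)
+  ```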

 ## References

diff --git a/mmcv/onnx/symbolic.py b/mmcv/onnx/symbolic.py
index c73c830280..1990e3c248 100644
--- a/mmcv/onnx/symbolic.py
+++ b/mmcv/onnx/symbolic.py
@@ -396,6 +396,16 @@ def grid_sampler(g,
         align_corners_i=align_corners)


+@parse_args('v', 'i')
+def cummax(g, input, dim):
+    return g.op('mmcv::cummax', input, dim_i=dim, outputs=2)
+
+
+@parse_args('v', 'i')
+def cummin(g, input, dim):
+    return g.op('mmcv::cummin', input, dim_i=dim, outputs=2)
+
+
 def register_extra_symbolics(opset=11):
     register_op('one_hot', one_hot, '', opset)
     register_op('im2col', im2col, '', opset)
@@ -421,3 +431,5 @@ def register_extra_symbolics(opset=11):
     register_op('upsample_bicubic2d', upsample_bicubic2d, '', opset)
     register_op('new_full', new_full, '', opset)
     register_op('grid_sampler', grid_sampler, '', opset)
+    register_op('cummax', cummax, '', opset)
+    register_op('cummin', cummin, '', opset)
diff --git a/mmcv/ops/corner_pool.py b/mmcv/ops/corner_pool.py
index 189506e6aa..f1593369e5 100644
--- a/mmcv/ops/corner_pool.py
+++ b/mmcv/ops/corner_pool.py
@@ -140,6 +140,15 @@ def __init__(self, mode):

     def forward(self, x):
         if torch.__version__ != 'parrots' and torch.__version__ >= '1.5.0':
+            if torch.onnx.is_in_onnx_export():
+                assert torch.__version__ >= '1.7.0', \
+                    'When `cummax` serves as an intermediate component ' \
+                    'whose outputs are used as inputs to other modules, ' \
+                    'the PyTorch version must be >= 1.7.0. Otherwise an ' \
+                    'error like `RuntimeError: tuple appears in op that ' \
+                    'does not forward tuples, unsupported kind: ' \
+                    'prim::PythonOp` may be raised.'
+
             dim, flip = self.cummax_dim_flip[self.mode]
             if flip:
                 x = x.flip(dim)
diff --git a/mmcv/ops/csrc/onnxruntime/cpu/onnxruntime_register.cpp b/mmcv/ops/csrc/onnxruntime/cpu/onnxruntime_register.cpp
index b55114b188..468fb13fff 100644
--- a/mmcv/ops/csrc/onnxruntime/cpu/onnxruntime_register.cpp
+++ b/mmcv/ops/csrc/onnxruntime/cpu/onnxruntime_register.cpp
@@ -4,6 +4,7 @@
 #include "grid_sample.h"
 #include "nms.h"
 #include "ort_mmcv_utils.h"
+#include "reduce_ops.h"
 #include "roi_align.h"
 #include "roi_align_rotated.h"
 #include "soft_nms.h"
@@ -14,6 +15,8 @@ NmsOp c_NmsOp;
 MMCVRoiAlignCustomOp c_MMCVRoiAlignCustomOp;
 MMCVRoIAlignRotatedCustomOp c_MMCVRoIAlignRotatedCustomOp;
 GridSampleOp c_GridSampleOp;
+MMCVCumMaxCustomOp c_MMCVCumMaxCustomOp;
+MMCVCumMinCustomOp c_MMCVCumMinCustomOp;
 MMCVCornerPoolCustomOp c_MMCVCornerPoolCustomOp;

 OrtStatus *ORT_API_CALL RegisterCustomOps(OrtSessionOptions *options,
@@ -52,5 +55,13 @@ OrtStatus *ORT_API_CALL RegisterCustomOps(OrtSessionOptions *options,
     return status;
   }

+  if (auto status = ortApi->CustomOpDomain_Add(domain, &c_MMCVCumMaxCustomOp)) {
+    return status;
+  }
+
+  if (auto status = ortApi->CustomOpDomain_Add(domain, &c_MMCVCumMinCustomOp)) {
+    return status;
+  }
+
   return ortApi->AddCustomOpDomain(options, domain);
 }
diff --git a/mmcv/ops/csrc/onnxruntime/cpu/reduce_ops.cpp b/mmcv/ops/csrc/onnxruntime/cpu/reduce_ops.cpp
new file mode 100644
index 0000000000..8c5a03a769
--- /dev/null
+++ b/mmcv/ops/csrc/onnxruntime/cpu/reduce_ops.cpp
@@ -0,0 +1,187 @@
+#include "reduce_ops.h"
+
+#include <assert.h>
+
+#include <vector>
+
+#include "../ort_mmcv_utils.h"
+
+// modified from
+// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/ReduceOps.cpp
+
+static inline int64_t maybe_wrap_dim(int64_t dim, int64_t ndims) {
+  int64_t min = -ndims;
+  int64_t max = ndims - 1;
+  assert(dim >= min && dim <= max);
+  if (dim < 0) dim += ndims;
+  return dim;
+}
+
+static inline int64_t get_dim_stride(const int64_t dim, const int64_t ndims,
+                                     const int64_t *reversed_dim_cumprod) {
+  return dim == ndims - 1 ? 1 : reversed_dim_cumprod[dim + 1];
+}
+
+static inline int64_t get_dim_size(const int64_t dim, const int64_t ndims,
+                                   const int64_t *reversed_dim_cumprod) {
+  return dim == ndims - 1
+             ? reversed_dim_cumprod[dim]
+             : reversed_dim_cumprod[dim] / reversed_dim_cumprod[dim + 1];
+}
+
+template <typename T1, typename T2, typename Operation>
+void cummax_cummin_helper(const T1 *input, T1 *output, T2 *indices,
+                          const int64_t input_dim_size, const int64_t stride) {
+  Operation op;
+  T1 out = input[0];
+  int64_t idx = 0;
+  for (int64_t i = 0; i < input_dim_size; i++) {
+    T1 curr_elem = input[i * stride];
+    if (op(curr_elem, out)) {
+      out = curr_elem;
+      idx = i;
+    }
+    output[i * stride] = out;
+    indices[i * stride] = idx;
+  }
+}
+
+// modified `tensor_dim_apply3` from
+// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorDimApply.h.
+// the differences are: (1) `reversed_dim_cumprod` is used for fast computation
+// of the `size` and `stride` of a given `dim`; (2) the same `stride` is used
+// for input, output and indices, since separate values are unnecessary.
+// currently `tensor_dim_apply3` is only used by `cummax` and `cummin`, as in
+// the official pytorch project: https://github.com/pytorch/pytorch.
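+// a worked example of the bookkeeping above: for an input of shape (2, 3, 4),
+//   reversed_dim_cumprod = {24, 12, 4}
+//   get_dim_size(dim=1)   = 12 / 4 = 3  (extent of dim 1)
+//   get_dim_stride(dim=1) = 4           (step between elements along dim 1)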
+template <typename T1, typename T2, typename Function>
+void tensor_dim_apply3(const T1 *input, T1 *output, T2 *indices,
+                       const int64_t dim, const int64_t ndims,
+                       const int64_t *reversed_dim_cumprod, Function func) {
+  int dim_apply_finished = 0;
+  int64_t input_dim_size = get_dim_size(dim, ndims, reversed_dim_cumprod);
+  // the same stride is used for input, output and indices
+  int64_t stride = get_dim_stride(dim, ndims, reversed_dim_cumprod);
+  std::vector<int64_t> counter(ndims, 0);
+
+  while (!dim_apply_finished) {
+    // call `func` once to update output and indices
+    func(input, output, indices, input_dim_size, stride);
+    if (ndims == 1) break;
+    for (int64_t dim_i = 0; dim_i < ndims; dim_i++) {
+      if (dim_i == dim) {
+        if (dim_i == (ndims - 1)) {
+          dim_apply_finished = 1;
+          break;
+        }
+        continue;
+      }
+      counter[dim_i]++;
+
+      // the same stride is used for input, output, and indices
+      int64_t stride_dim_i = get_dim_stride(dim_i, ndims, reversed_dim_cumprod);
+      input += stride_dim_i;
+      output += stride_dim_i;
+      indices += stride_dim_i;
+
+      if (counter[dim_i] == get_dim_size(dim_i, ndims, reversed_dim_cumprod)) {
+        if (dim_i == ndims - 1) {
+          dim_apply_finished = 1;
+          break;
+        } else {
+          input -= counter[dim_i] * stride_dim_i;
+          output -= counter[dim_i] * stride_dim_i;
+          indices -= counter[dim_i] * stride_dim_i;
+          counter[dim_i] = 0;
+        }
+      } else {
+        break;
+      }  // if
+    }    // for
+  }      // while
+}
+
+template <typename T1, typename T2, typename Operation>
+void CumMax_CumMin_CPU(const T1 *input, T1 *output, T2 *indices,
+                       int64_t *reversed_dim_cumprod, const int64_t dim,
+                       const OrtTensorDimensions &out_dimensions) {
+  // calculate numel
+  const int64_t ndims = out_dimensions.size();
+  int64_t numel = 1;
+  for (int64_t dim_i = 0; dim_i < ndims; dim_i++) {
+    numel *= out_dimensions.data()[dim_i];
+  }
+
+  // cummax/cummin is only applied to a non-empty input with non-zero dims
+  if (numel) {
+    // compute the reversed cumulative product of the dimension sizes,
+    // which is then used to compute the stride or size of a specific `dim`.
+    reversed_dim_cumprod[ndims - 1] = out_dimensions.data()[ndims - 1];
+    for (int64_t dim_i = ndims - 2; dim_i >= 0; dim_i--) {
+      reversed_dim_cumprod[dim_i] =
+          reversed_dim_cumprod[dim_i + 1] * out_dimensions.data()[dim_i];
+    }
+
+    // do cummax or cummin based on the `Operation` type
+    tensor_dim_apply3<T1, T2>(input, output, indices, dim, ndims,
+                              reversed_dim_cumprod,
+                              cummax_cummin_helper<T1, T2, Operation>);
+  }
+}
+
+void MMCVCumMaxKernel::Compute(OrtKernelContext *context) {
+  // get input
+  const OrtValue *input = ort_.KernelContext_GetInput(context, 0);
+  const float *input_data =
+      reinterpret_cast<const float *>(ort_.GetTensorData<float>(input));
+
+  // get output
+  OrtTensorDimensions out_dimensions(ort_, input);
+  OrtValue *output = ort_.KernelContext_GetOutput(
+      context, 0, out_dimensions.data(), out_dimensions.size());
+  float *output_data = ort_.GetTensorMutableData<float>(output);
+  OrtValue *indices = ort_.KernelContext_GetOutput(
+      context, 1, out_dimensions.data(), out_dimensions.size());
+  int64_t *indices_data = ort_.GetTensorMutableData<int64_t>(indices);
+
+  // allocate temporary memory for the reversed cumulative product of the
+  // dimension sizes
+  const int64_t ndims = out_dimensions.size();
+  assert(ndims > 0);
+  int64_t *reversed_dim_cumprod =
+      (int64_t *)allocator_.Alloc(sizeof(int64_t) * ndims);
+
+  // dim should be wrapped if it's negative (e.g. -1)
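+  // for instance, with ndims = 3, dim = -1 wraps to 2 and dim = -3 wraps
+  // to 0; out-of-range values trip the assert inside `maybe_wrap_dim`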
+  const int64_t dim = maybe_wrap_dim(dim_, ndims);
+  CumMax_CumMin_CPU<float, int64_t, std::greater_equal<float>>(
+      input_data, output_data, indices_data, reversed_dim_cumprod, dim,
+      out_dimensions);
+}
+
+void MMCVCumMinKernel::Compute(OrtKernelContext *context) {
+  // get input
+  const OrtValue *input = ort_.KernelContext_GetInput(context, 0);
+  const float *input_data =
+      reinterpret_cast<const float *>(ort_.GetTensorData<float>(input));
+
+  // get output
+  OrtTensorDimensions out_dimensions(ort_, input);
+  OrtValue *output = ort_.KernelContext_GetOutput(
+      context, 0, out_dimensions.data(), out_dimensions.size());
+  float *output_data = ort_.GetTensorMutableData<float>(output);
+  OrtValue *indices = ort_.KernelContext_GetOutput(
+      context, 1, out_dimensions.data(), out_dimensions.size());
+  int64_t *indices_data = ort_.GetTensorMutableData<int64_t>(indices);
+
+  // allocate temporary memory for the reversed cumulative product of the
+  // dimension sizes
+  const int64_t ndims = out_dimensions.size();
+  assert(ndims > 0);
+  int64_t *reversed_dim_cumprod =
+      (int64_t *)allocator_.Alloc(sizeof(int64_t) * ndims);
+
+  // dim should be wrapped if it's negative (e.g. -1)
+  const int64_t dim = maybe_wrap_dim(dim_, ndims);
+  CumMax_CumMin_CPU<float, int64_t, std::less_equal<float>>(
+      input_data, output_data, indices_data, reversed_dim_cumprod, dim,
+      out_dimensions);
+}
diff --git a/mmcv/ops/csrc/onnxruntime/reduce_ops.h b/mmcv/ops/csrc/onnxruntime/reduce_ops.h
new file mode 100644
index 0000000000..efd8c7b31d
--- /dev/null
+++ b/mmcv/ops/csrc/onnxruntime/reduce_ops.h
@@ -0,0 +1,94 @@
+#ifndef ONNXRUNTIME_REDUCE_OPS_H
+#define ONNXRUNTIME_REDUCE_OPS_H
+
+#include <onnxruntime_cxx_api.h>
+
+struct MMCVCumMaxKernel {
+ public:
+  MMCVCumMaxKernel(Ort::CustomOpApi ort, const OrtKernelInfo* info)
+      : ort_(ort) {
+    dim_ = ort_.KernelInfoGetAttribute<int64_t>(info, "dim");
+
+    // create allocator
+    allocator_ = Ort::AllocatorWithDefaultOptions();
+  }
+
+  void Compute(OrtKernelContext* context);
+
+ private:
+  Ort::CustomOpApi ort_;
+  Ort::AllocatorWithDefaultOptions allocator_;
+
+  int64_t dim_;
+};
+
+struct MMCVCumMinKernel {
+ public:
+  MMCVCumMinKernel(Ort::CustomOpApi ort, const OrtKernelInfo* info)
+      : ort_(ort) {
+    dim_ = ort_.KernelInfoGetAttribute<int64_t>(info, "dim");
+
+    // create allocator
+    allocator_ = Ort::AllocatorWithDefaultOptions();
+  }
+
+  void Compute(OrtKernelContext* context);
+
+ private:
+  Ort::CustomOpApi ort_;
+  Ort::AllocatorWithDefaultOptions allocator_;
+
+  int64_t dim_;
+};
+
+struct MMCVCumMaxCustomOp
+    : Ort::CustomOpBase<MMCVCumMaxCustomOp, MMCVCumMaxKernel> {
+  void* CreateKernel(Ort::CustomOpApi api, const OrtKernelInfo* info) {
+    return new MMCVCumMaxKernel(api, info);
+  }
+
+  const char* GetName() const { return "cummax"; }
+
+  size_t GetInputTypeCount() const { return 1; }
+  ONNXTensorElementDataType GetInputType(size_t) const {
+    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
+  };
+
+  size_t GetOutputTypeCount() const { return 2; }
+  ONNXTensorElementDataType GetOutputType(size_t index) const {
+    // output 0 (`values`) is float, output 1 (`indices`) is int64
+    if (index == 1) return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
+    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
+  };
+
+  // force cpu
+  const char* GetExecutionProviderType() const {
+    return "CPUExecutionProvider";
+  };
+};
+
+struct MMCVCumMinCustomOp
+    : Ort::CustomOpBase<MMCVCumMinCustomOp, MMCVCumMinKernel> {
+  void* CreateKernel(Ort::CustomOpApi api, const OrtKernelInfo* info) {
+    return new MMCVCumMinKernel(api, info);
+  }
+
+  const char* GetName() const { return "cummin"; }
+
+  size_t GetInputTypeCount() const { return 1; }
+  ONNXTensorElementDataType GetInputType(size_t) const {
+    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
+  };
+
+  size_t GetOutputTypeCount() const { return 2; }
+  ONNXTensorElementDataType GetOutputType(size_t index) const {
+    // output 0 (`values`) is float, output 1 (`indices`) is int64
+    if (index == 1) return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
+    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
+  };
+
+  // force cpu
+  const char* GetExecutionProviderType() const {
+    return "CPUExecutionProvider";
+  };
+};
+
+#endif  // ONNXRUNTIME_REDUCE_OPS_H
diff --git a/mmcv/ops/nms.py b/mmcv/ops/nms.py
index 91f0ed3618..a3ad97f6a0 100644
--- a/mmcv/ops/nms.py
+++ b/mmcv/ops/nms.py
@@ -218,6 +218,7 @@ def soft_nms(boxes,
                            float(iou_threshold), float(sigma),
                            float(min_score), method_dict[method],
                            int(offset))
+    # keep only as many dets as there are indices returned by soft-nms
+    dets = dets[:inds.size(0)]

     if is_numpy:
diff --git a/tests/test_ops/test_onnx.py b/tests/test_ops/test_onnx.py
index 62859c32f5..c07cd908d9 100644
--- a/tests/test_ops/test_onnx.py
+++ b/tests/test_ops/test_onnx.py
@@ -494,3 +494,78 @@ def corner_pool_func(input):
     pytorch_results = wrapped_model(input.clone())
     os.remove(onnx_file)
     assert np.allclose(pytorch_results, ort_result, atol=1e-5)
+
+
+@pytest.mark.parametrize('key', ['cummax', 'cummin'])
+def test_cummax_cummin(key, opset=11):
+    if torch.__version__ == 'parrots':
+        pytest.skip('onnx is not supported in parrots directly')
+
+    # In general `cummax` or `cummin` is exportable to ONNX as long as
+    # the PyTorch version is >= 1.5.0, since `torch.cummax` is only
+    # supported from 1.5.0 onwards. But when `cummax` or `cummin` serves
+    # as an intermediate component whose outputs are used as inputs to
+    # other modules, PyTorch >= 1.7.0 is required. Otherwise an error
+    # appears like: `RuntimeError: tuple appears in op that does not
+    # forward tuples, unsupported kind: prim::PythonOp`.
+    if version.parse(torch.__version__) < version.parse('1.7.0'):
+        pytest.skip('test_cummax_cummin should be run with pytorch >= 1.7.0')
+
+    # register the custom ops `mmcv::cummax` and `mmcv::cummin`
+    from mmcv.onnx.symbolic import register_extra_symbolics
+    register_extra_symbolics(opset)
+
+    from mmcv.ops import get_onnxruntime_op_path
+    ort_custom_op_path = get_onnxruntime_op_path()
+    if not os.path.exists(ort_custom_op_path):
+        pytest.skip('custom ops for onnxruntime are not compiled.')
+
+    input_list = [
+        # arbitrary shapes, e.g. 1-D, 2-D, 3-D, ...
+        torch.rand((2, 3, 4, 1, 5)),
+        torch.rand((1)),
+        torch.rand((2, 0, 1)),  # tensor.numel() is 0
+        torch.FloatTensor(),  # empty tensor
+    ]
+
+    cummax_cummin_funcs = {'cummax': torch.cummax, 'cummin': torch.cummin}
+
+    for input in input_list:
+        ndims = input.dim()
+        # the valid `dim` range is [-ndims, ndims - 1]; test every valid value
+        for dim in range(-ndims, ndims):
+            cummax_func = partial(cummax_cummin_funcs[key], dim=dim)
+            wrapped_model = WrapFunction(cummax_func).eval()
+
+            with torch.no_grad():
+                torch.onnx.export(
+                    wrapped_model,
+                    input,
+                    onnx_file,
+                    export_params=True,
+                    keep_initializers_as_inputs=True,
+                    input_names=['input'],
+                    output_names=['output', 'indices'],
+                    opset_version=opset)
+
+            onnx_model = onnx.load(onnx_file)
+            input_all = [node.name for node in onnx_model.graph.input]
+            input_initializer = [
+                node.name for node in onnx_model.graph.initializer
+            ]
+            net_feed_input = list(set(input_all) - set(input_initializer))
+            assert (len(net_feed_input) == 1)
+
+            session_options = rt.SessionOptions()
+            session_options.register_custom_ops_library(ort_custom_op_path)
+            sess = rt.InferenceSession(onnx_file, session_options)
+            ort_output, ort_inds = sess.run(None,
+                                            {'input': input.detach().numpy()})
+            pytorch_output, pytorch_inds = wrapped_model(input.clone())
+            pytorch_output = pytorch_output.detach().numpy()
+            pytorch_inds = pytorch_inds.detach().numpy()
+            assert np.allclose(pytorch_output, ort_output, atol=1e-5)
+            assert np.all(pytorch_inds == ort_inds)
+            os.remove(onnx_file)