Skip to content

Commit ee041ce

Browse files
[Feature]: Add Rotated ROI align op for pytorch (cpu&cuda), parrots (cpu&cuda) and onnxruntime (cpu) (#933)
* add roi_align_rotated * code format * Add align key to roi align rotated * Add clockwise for rotated roi align * fix bugs in onnx export * Add docstring for RoIAlignRotated * remove cuda unittest * Reformat c++ code * add onnx roi align rotated file * fix unittest * Add cpu and float64 of cuda support for parrots * code format * Add unified header to roi align rotated Co-authored-by: luopeichao <luopeichao@sensetime.com>
1 parent de4f14e commit ee041ce

18 files changed

+2298
-1
lines changed

mmcv/ops/__init__.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
rel_roi_point_to_rel_img_point)
2525
from .psa_mask import PSAMask
2626
from .roi_align import RoIAlign, roi_align
27+
from .roi_align_rotated import RoIAlignRotated, roi_align_rotated
2728
from .roi_pool import RoIPool, roi_pool
2829
from .saconv import SAConv2d
2930
from .sync_bn import SyncBatchNorm
@@ -44,5 +45,6 @@
4445
'ConvTranspose2d', 'Linear', 'MaxPool2d', 'CrissCrossAttention', 'PSAMask',
4546
'point_sample', 'rel_roi_point_to_rel_img_point', 'SimpleRoIAlign',
4647
'SAConv2d', 'TINShift', 'tin_shift', 'box_iou_rotated', 'nms_rotated',
47-
'upfirdn2d', 'FusedBiasLeakyReLU', 'fused_bias_leakyrelu'
48+
'upfirdn2d', 'FusedBiasLeakyReLU', 'fused_bias_leakyrelu',
49+
'RoIAlignRotated', 'roi_align_rotated'
4850
]

mmcv/ops/box_iou_rotated.py

+2
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,10 @@ def box_iou_rotated(bboxes1, bboxes2, mode='iou', aligned=False):
1616
Arguments:
1717
boxes1 (Tensor): rotated bboxes 1. \
1818
It has shape (N, 5), indicating (x, y, w, h, theta) for each row.
19+
Note that theta is in radian.
1920
boxes2 (Tensor): rotated bboxes 2. \
2021
It has shape (M, 5), indicating (x, y, w, h, theta) for each row.
22+
Note that theta is in radian.
2123
mode (str): "iou" (intersection over union) or iof (intersection over
2224
foreground).
2325

mmcv/ops/csrc/onnxruntime/cpu/onnxruntime_register.cpp

+7
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@
44
#include "nms.h"
55
#include "ort_mmcv_utils.h"
66
#include "roi_align.h"
7+
#include "roi_align_rotated.h"
78
#include "soft_nms.h"
89

910
const char *c_MMCVOpDomain = "mmcv";
1011
SoftNmsOp c_SoftNmsOp;
1112
NmsOp c_NmsOp;
1213
MMCVRoiAlignCustomOp c_MMCVRoiAlignCustomOp;
14+
MMCVRoIAlignRotatedCustomOp c_MMCVRoIAlignRotatedCustomOp;
1315
GridSampleOp c_GridSampleOp;
1416

1517
OrtStatus *ORT_API_CALL RegisterCustomOps(OrtSessionOptions *options,
@@ -34,6 +36,11 @@ OrtStatus *ORT_API_CALL RegisterCustomOps(OrtSessionOptions *options,
3436
return status;
3537
}
3638

39+
if (auto status =
40+
ortApi->CustomOpDomain_Add(domain, &c_MMCVRoIAlignRotatedCustomOp)) {
41+
return status;
42+
}
43+
3744
if (auto status = ortApi->CustomOpDomain_Add(domain, &c_GridSampleOp)) {
3845
return status;
3946
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,246 @@
1+
// Modified from
2+
// https://github.com/facebookresearch/detectron2/tree/master/detectron2/layers/csrc/ROIAlignRotated
3+
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
4+
#include "roi_align_rotated.h"
5+
#include "../ort_mmcv_utils.h"
6+
7+
struct PreCalc {
8+
int pos1;
9+
int pos2;
10+
int pos3;
11+
int pos4;
12+
float w1;
13+
float w2;
14+
float w3;
15+
float w4;
16+
};
17+
18+
void pre_calc_for_bilinear_interpolate(
19+
const int height, const int width, const int pooled_height,
20+
const int pooled_width, const int iy_upper, const int ix_upper,
21+
float roi_start_h, float roi_start_w, float bin_size_h, float bin_size_w,
22+
int roi_bin_grid_h, int roi_bin_grid_w, float roi_center_h,
23+
float roi_center_w, float cos_theta, float sin_theta,
24+
std::vector<PreCalc> &pre_calc) {
25+
int pre_calc_index = 0;
26+
for (int ph = 0; ph < pooled_height; ph++) {
27+
for (int pw = 0; pw < pooled_width; pw++) {
28+
for (int iy = 0; iy < iy_upper; iy++) {
29+
const float yy =
30+
roi_start_h + ph * bin_size_h +
31+
static_cast<float>(iy + .5f) * bin_size_h /
32+
static_cast<float>(roi_bin_grid_h); // e.g., 0.5, 1.5
33+
for (int ix = 0; ix < ix_upper; ix++) {
34+
const float xx = roi_start_w + pw * bin_size_w +
35+
static_cast<float>(ix + .5f) * bin_size_w /
36+
static_cast<float>(roi_bin_grid_w);
37+
38+
// Rotate by theta around the center and translate
39+
// In image space, (y, x) is the order for Right Handed System,
40+
// and this is essentially multiplying the point by a rotation matrix
41+
// to rotate it counterclockwise through angle theta.
42+
float y = yy * cos_theta - xx * sin_theta + roi_center_h;
43+
float x = yy * sin_theta + xx * cos_theta + roi_center_w;
44+
// deal with: inverse elements are out of feature map boundary
45+
if (y < -1.0 || y > height || x < -1.0 || x > width) {
46+
// empty
47+
PreCalc pc;
48+
pc.pos1 = 0;
49+
pc.pos2 = 0;
50+
pc.pos3 = 0;
51+
pc.pos4 = 0;
52+
pc.w1 = 0;
53+
pc.w2 = 0;
54+
pc.w3 = 0;
55+
pc.w4 = 0;
56+
pre_calc[pre_calc_index] = pc;
57+
pre_calc_index += 1;
58+
continue;
59+
}
60+
61+
if (y < 0) {
62+
y = 0;
63+
}
64+
if (x < 0) {
65+
x = 0;
66+
}
67+
68+
int y_low = (int)y;
69+
int x_low = (int)x;
70+
int y_high;
71+
int x_high;
72+
73+
if (y_low >= height - 1) {
74+
y_high = y_low = height - 1;
75+
y = (float)y_low;
76+
} else {
77+
y_high = y_low + 1;
78+
}
79+
80+
if (x_low >= width - 1) {
81+
x_high = x_low = width - 1;
82+
x = (float)x_low;
83+
} else {
84+
x_high = x_low + 1;
85+
}
86+
87+
float ly = y - y_low;
88+
float lx = x - x_low;
89+
float hy = 1. - ly, hx = 1. - lx;
90+
float w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
91+
92+
// save weights and indices
93+
PreCalc pc;
94+
pc.pos1 = y_low * width + x_low;
95+
pc.pos2 = y_low * width + x_high;
96+
pc.pos3 = y_high * width + x_low;
97+
pc.pos4 = y_high * width + x_high;
98+
pc.w1 = w1;
99+
pc.w2 = w2;
100+
pc.w3 = w3;
101+
pc.w4 = w4;
102+
pre_calc[pre_calc_index] = pc;
103+
104+
pre_calc_index += 1;
105+
}
106+
}
107+
}
108+
}
109+
}
110+
111+
void ROIAlignRotatedForwardCPU(const int nthreads, const float *input,
112+
const float *rois, float *output,
113+
const float &spatial_scale, const int aligned,
114+
const int clockwise, const int channels,
115+
const int height, const int width,
116+
const int pooled_height, const int pooled_width,
117+
const int sampling_ratio) {
118+
int n_rois = nthreads / channels / pooled_width / pooled_height;
119+
// (n, c, ph, pw) is an element in the pooled output
120+
// can be parallelized using omp
121+
// #pragma omp parallel for num_threads(32)
122+
for (int n = 0; n < n_rois; n++) {
123+
int index_n = n * channels * pooled_width * pooled_height;
124+
125+
const float *current_roi = rois + n * 6;
126+
int roi_batch_ind = current_roi[0];
127+
128+
// Do not use rounding; this implementation detail is critical
129+
float offset = aligned ? (float)0.5 : (float)0.0;
130+
float roi_center_w = current_roi[1] * spatial_scale - offset;
131+
float roi_center_h = current_roi[2] * spatial_scale - offset;
132+
float roi_width = current_roi[3] * spatial_scale;
133+
float roi_height = current_roi[4] * spatial_scale;
134+
// float theta = current_roi[5] * M_PI / 180.0;
135+
float theta = current_roi[5]; // Radian angle by default
136+
if (clockwise) {
137+
theta = -theta;
138+
}
139+
float cos_theta = cos(theta);
140+
float sin_theta = sin(theta);
141+
if (!aligned) { // for backward-compatibility only
142+
roi_width = std::max(roi_width, (float)1.);
143+
roi_height = std::max(roi_height, (float)1.);
144+
}
145+
146+
float bin_size_h =
147+
static_cast<float>(roi_height) / static_cast<float>(pooled_height);
148+
float bin_size_w =
149+
static_cast<float>(roi_width) / static_cast<float>(pooled_width);
150+
151+
// We use roi_bin_grid to sample the grid and mimic integral
152+
int roi_bin_grid_h = (sampling_ratio > 0)
153+
? sampling_ratio
154+
: ceil(roi_height / pooled_height); // e.g., = 2
155+
int roi_bin_grid_w =
156+
(sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
157+
158+
// We do average (integral) pooling inside a bin
159+
const float count =
160+
std::max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. = 4
161+
162+
// we want to precalculate indices and weights shared by all channels,
163+
// this is the key point of optimization
164+
std::vector<PreCalc> pre_calc(roi_bin_grid_h * roi_bin_grid_w *
165+
pooled_width * pooled_height);
166+
167+
// roi_start_h and roi_start_w are computed wrt the center of RoI (x, y).
168+
// Appropriate translation needs to be applied after.
169+
float roi_start_h = -roi_height / 2.0;
170+
float roi_start_w = -roi_width / 2.0;
171+
172+
pre_calc_for_bilinear_interpolate(
173+
height, width, pooled_height, pooled_width, roi_bin_grid_h,
174+
roi_bin_grid_w, roi_start_h, roi_start_w, bin_size_h, bin_size_w,
175+
roi_bin_grid_h, roi_bin_grid_w, roi_center_h, roi_center_w, cos_theta,
176+
sin_theta, pre_calc);
177+
178+
for (int c = 0; c < channels; c++) {
179+
int index_n_c = index_n + c * pooled_width * pooled_height;
180+
const float *offset_input =
181+
input + (roi_batch_ind * channels + c) * height * width;
182+
int pre_calc_index = 0;
183+
184+
for (int ph = 0; ph < pooled_height; ph++) {
185+
for (int pw = 0; pw < pooled_width; pw++) {
186+
int index = index_n_c + ph * pooled_width + pw;
187+
188+
float output_val = 0.;
189+
for (int iy = 0; iy < roi_bin_grid_h; iy++) {
190+
for (int ix = 0; ix < roi_bin_grid_w; ix++) {
191+
PreCalc pc = pre_calc[pre_calc_index];
192+
output_val += pc.w1 * offset_input[pc.pos1] +
193+
pc.w2 * offset_input[pc.pos2] +
194+
pc.w3 * offset_input[pc.pos3] +
195+
pc.w4 * offset_input[pc.pos4];
196+
197+
pre_calc_index += 1;
198+
}
199+
}
200+
output_val /= count;
201+
202+
output[index] = output_val;
203+
} // for pw
204+
} // for ph
205+
} // for c
206+
} // for n
207+
}
208+
209+
void MMCVRoIAlignRotatedKernel::Compute(OrtKernelContext *context) {
210+
// Setup inputs
211+
const OrtValue *input_X = ort_.KernelContext_GetInput(context, 0);
212+
const float *X_data =
213+
reinterpret_cast<const float *>(ort_.GetTensorData<float>(input_X));
214+
const OrtValue *input_rois = ort_.KernelContext_GetInput(context, 1);
215+
const float *rois = reinterpret_cast<const float *>(
216+
ort_.GetTensorData<const float *>(input_rois));
217+
218+
// Setup output
219+
OrtTensorDimensions out_dimensions(ort_, input_X);
220+
OrtTensorDimensions roi_dimensions(ort_, input_rois);
221+
222+
int batch_size = out_dimensions.data()[0];
223+
int input_channels = out_dimensions.data()[1];
224+
int input_height = out_dimensions.data()[2];
225+
int input_width = out_dimensions.data()[3];
226+
227+
out_dimensions.data()[0] = roi_dimensions.data()[0];
228+
out_dimensions.data()[2] = aligned_height_;
229+
out_dimensions.data()[3] = aligned_width_;
230+
231+
OrtValue *output = ort_.KernelContext_GetOutput(
232+
context, 0, out_dimensions.data(), out_dimensions.size());
233+
float *out = ort_.GetTensorMutableData<float>(output);
234+
OrtTensorTypeAndShapeInfo *output_info = ort_.GetTensorTypeAndShape(output);
235+
ort_.ReleaseTensorTypeAndShapeInfo(output_info);
236+
237+
// TODO: forward here
238+
int output_size = out_dimensions.data()[0];
239+
for (auto i = 1; i < out_dimensions.size(); ++i) {
240+
output_size *= out_dimensions.data()[i];
241+
}
242+
ROIAlignRotatedForwardCPU(output_size, X_data, rois, out, spatial_scale_,
243+
aligned_, clockwise_, input_channels, input_height,
244+
input_width, aligned_height_, aligned_width_,
245+
sampling_ratio_);
246+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
#ifndef ONNXRUNTIME_ROI_ALIGN_ROTATED_H
2+
#define ONNXRUNTIME_ROI_ALIGN_ROTATED_H
3+
4+
#include <assert.h>
5+
#include <onnxruntime_cxx_api.h>
6+
7+
#include <cmath>
8+
#include <mutex>
9+
#include <string>
10+
#include <vector>
11+
12+
struct MMCVRoIAlignRotatedKernel {
13+
public:
14+
MMCVRoIAlignRotatedKernel(Ort::CustomOpApi ort, const OrtKernelInfo* info)
15+
: ort_(ort) {
16+
aligned_height_ =
17+
ort_.KernelInfoGetAttribute<int64_t>(info, "output_height");
18+
aligned_width_ = ort_.KernelInfoGetAttribute<int64_t>(info, "output_width");
19+
sampling_ratio_ =
20+
ort_.KernelInfoGetAttribute<int64_t>(info, "sampling_ratio");
21+
spatial_scale_ = ort_.KernelInfoGetAttribute<float>(info, "spatial_scale");
22+
aligned_ = ort_.KernelInfoGetAttribute<int64_t>(info, "aligned");
23+
clockwise_ = ort_.KernelInfoGetAttribute<int64_t>(info, "clockwise");
24+
}
25+
26+
void Compute(OrtKernelContext* context);
27+
28+
private:
29+
Ort::CustomOpApi ort_;
30+
int aligned_height_;
31+
int aligned_width_;
32+
float spatial_scale_;
33+
int sampling_ratio_;
34+
int aligned_;
35+
int clockwise_;
36+
};
37+
38+
struct MMCVRoIAlignRotatedCustomOp
39+
: Ort::CustomOpBase<MMCVRoIAlignRotatedCustomOp,
40+
MMCVRoIAlignRotatedKernel> {
41+
void* CreateKernel(Ort::CustomOpApi api, const OrtKernelInfo* info) {
42+
return new MMCVRoIAlignRotatedKernel(api, info);
43+
}
44+
const char* GetName() const { return "MMCVRoIAlignRotated"; }
45+
46+
size_t GetInputTypeCount() const { return 2; }
47+
ONNXTensorElementDataType GetInputType(size_t) const {
48+
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
49+
}
50+
51+
size_t GetOutputTypeCount() const { return 1; }
52+
ONNXTensorElementDataType GetOutputType(size_t) const {
53+
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
54+
}
55+
56+
// force cpu
57+
const char* GetExecutionProviderType() const {
58+
return "CPUExecutionProvider";
59+
}
60+
};
61+
#endif // ONNXRUNTIME_ROI_ALIGN_ROTATED_H

0 commit comments

Comments
 (0)