|
| 1 | +// Modified from |
| 2 | +// https://github.com/facebookresearch/detectron2/tree/master/detectron2/layers/csrc/ROIAlignRotated |
| 3 | +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved |
| 4 | +#include "roi_align_rotated.h" |
| 5 | +#include "../ort_mmcv_utils.h" |
| 6 | + |
| 7 | +struct PreCalc { |
| 8 | + int pos1; |
| 9 | + int pos2; |
| 10 | + int pos3; |
| 11 | + int pos4; |
| 12 | + float w1; |
| 13 | + float w2; |
| 14 | + float w3; |
| 15 | + float w4; |
| 16 | +}; |
| 17 | + |
| 18 | +void pre_calc_for_bilinear_interpolate( |
| 19 | + const int height, const int width, const int pooled_height, |
| 20 | + const int pooled_width, const int iy_upper, const int ix_upper, |
| 21 | + float roi_start_h, float roi_start_w, float bin_size_h, float bin_size_w, |
| 22 | + int roi_bin_grid_h, int roi_bin_grid_w, float roi_center_h, |
| 23 | + float roi_center_w, float cos_theta, float sin_theta, |
| 24 | + std::vector<PreCalc> &pre_calc) { |
| 25 | + int pre_calc_index = 0; |
| 26 | + for (int ph = 0; ph < pooled_height; ph++) { |
| 27 | + for (int pw = 0; pw < pooled_width; pw++) { |
| 28 | + for (int iy = 0; iy < iy_upper; iy++) { |
| 29 | + const float yy = |
| 30 | + roi_start_h + ph * bin_size_h + |
| 31 | + static_cast<float>(iy + .5f) * bin_size_h / |
| 32 | + static_cast<float>(roi_bin_grid_h); // e.g., 0.5, 1.5 |
| 33 | + for (int ix = 0; ix < ix_upper; ix++) { |
| 34 | + const float xx = roi_start_w + pw * bin_size_w + |
| 35 | + static_cast<float>(ix + .5f) * bin_size_w / |
| 36 | + static_cast<float>(roi_bin_grid_w); |
| 37 | + |
| 38 | + // Rotate by theta around the center and translate |
| 39 | + // In image space, (y, x) is the order for Right Handed System, |
| 40 | + // and this is essentially multiplying the point by a rotation matrix |
| 41 | + // to rotate it counterclockwise through angle theta. |
| 42 | + float y = yy * cos_theta - xx * sin_theta + roi_center_h; |
| 43 | + float x = yy * sin_theta + xx * cos_theta + roi_center_w; |
| 44 | + // deal with: inverse elements are out of feature map boundary |
| 45 | + if (y < -1.0 || y > height || x < -1.0 || x > width) { |
| 46 | + // empty |
| 47 | + PreCalc pc; |
| 48 | + pc.pos1 = 0; |
| 49 | + pc.pos2 = 0; |
| 50 | + pc.pos3 = 0; |
| 51 | + pc.pos4 = 0; |
| 52 | + pc.w1 = 0; |
| 53 | + pc.w2 = 0; |
| 54 | + pc.w3 = 0; |
| 55 | + pc.w4 = 0; |
| 56 | + pre_calc[pre_calc_index] = pc; |
| 57 | + pre_calc_index += 1; |
| 58 | + continue; |
| 59 | + } |
| 60 | + |
| 61 | + if (y < 0) { |
| 62 | + y = 0; |
| 63 | + } |
| 64 | + if (x < 0) { |
| 65 | + x = 0; |
| 66 | + } |
| 67 | + |
| 68 | + int y_low = (int)y; |
| 69 | + int x_low = (int)x; |
| 70 | + int y_high; |
| 71 | + int x_high; |
| 72 | + |
| 73 | + if (y_low >= height - 1) { |
| 74 | + y_high = y_low = height - 1; |
| 75 | + y = (float)y_low; |
| 76 | + } else { |
| 77 | + y_high = y_low + 1; |
| 78 | + } |
| 79 | + |
| 80 | + if (x_low >= width - 1) { |
| 81 | + x_high = x_low = width - 1; |
| 82 | + x = (float)x_low; |
| 83 | + } else { |
| 84 | + x_high = x_low + 1; |
| 85 | + } |
| 86 | + |
| 87 | + float ly = y - y_low; |
| 88 | + float lx = x - x_low; |
| 89 | + float hy = 1. - ly, hx = 1. - lx; |
| 90 | + float w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; |
| 91 | + |
| 92 | + // save weights and indices |
| 93 | + PreCalc pc; |
| 94 | + pc.pos1 = y_low * width + x_low; |
| 95 | + pc.pos2 = y_low * width + x_high; |
| 96 | + pc.pos3 = y_high * width + x_low; |
| 97 | + pc.pos4 = y_high * width + x_high; |
| 98 | + pc.w1 = w1; |
| 99 | + pc.w2 = w2; |
| 100 | + pc.w3 = w3; |
| 101 | + pc.w4 = w4; |
| 102 | + pre_calc[pre_calc_index] = pc; |
| 103 | + |
| 104 | + pre_calc_index += 1; |
| 105 | + } |
| 106 | + } |
| 107 | + } |
| 108 | + } |
| 109 | +} |
| 110 | + |
| 111 | +void ROIAlignRotatedForwardCPU(const int nthreads, const float *input, |
| 112 | + const float *rois, float *output, |
| 113 | + const float &spatial_scale, const int aligned, |
| 114 | + const int clockwise, const int channels, |
| 115 | + const int height, const int width, |
| 116 | + const int pooled_height, const int pooled_width, |
| 117 | + const int sampling_ratio) { |
| 118 | + int n_rois = nthreads / channels / pooled_width / pooled_height; |
| 119 | + // (n, c, ph, pw) is an element in the pooled output |
| 120 | + // can be parallelized using omp |
| 121 | + // #pragma omp parallel for num_threads(32) |
| 122 | + for (int n = 0; n < n_rois; n++) { |
| 123 | + int index_n = n * channels * pooled_width * pooled_height; |
| 124 | + |
| 125 | + const float *current_roi = rois + n * 6; |
| 126 | + int roi_batch_ind = current_roi[0]; |
| 127 | + |
| 128 | + // Do not use rounding; this implementation detail is critical |
| 129 | + float offset = aligned ? (float)0.5 : (float)0.0; |
| 130 | + float roi_center_w = current_roi[1] * spatial_scale - offset; |
| 131 | + float roi_center_h = current_roi[2] * spatial_scale - offset; |
| 132 | + float roi_width = current_roi[3] * spatial_scale; |
| 133 | + float roi_height = current_roi[4] * spatial_scale; |
| 134 | + // float theta = current_roi[5] * M_PI / 180.0; |
| 135 | + float theta = current_roi[5]; // Radian angle by default |
| 136 | + if (clockwise) { |
| 137 | + theta = -theta; |
| 138 | + } |
| 139 | + float cos_theta = cos(theta); |
| 140 | + float sin_theta = sin(theta); |
| 141 | + if (!aligned) { // for backward-compatibility only |
| 142 | + roi_width = std::max(roi_width, (float)1.); |
| 143 | + roi_height = std::max(roi_height, (float)1.); |
| 144 | + } |
| 145 | + |
| 146 | + float bin_size_h = |
| 147 | + static_cast<float>(roi_height) / static_cast<float>(pooled_height); |
| 148 | + float bin_size_w = |
| 149 | + static_cast<float>(roi_width) / static_cast<float>(pooled_width); |
| 150 | + |
| 151 | + // We use roi_bin_grid to sample the grid and mimic integral |
| 152 | + int roi_bin_grid_h = (sampling_ratio > 0) |
| 153 | + ? sampling_ratio |
| 154 | + : ceil(roi_height / pooled_height); // e.g., = 2 |
| 155 | + int roi_bin_grid_w = |
| 156 | + (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); |
| 157 | + |
| 158 | + // We do average (integral) pooling inside a bin |
| 159 | + const float count = |
| 160 | + std::max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. = 4 |
| 161 | + |
| 162 | + // we want to precalculate indices and weights shared by all channels, |
| 163 | + // this is the key point of optimization |
| 164 | + std::vector<PreCalc> pre_calc(roi_bin_grid_h * roi_bin_grid_w * |
| 165 | + pooled_width * pooled_height); |
| 166 | + |
| 167 | + // roi_start_h and roi_start_w are computed wrt the center of RoI (x, y). |
| 168 | + // Appropriate translation needs to be applied after. |
| 169 | + float roi_start_h = -roi_height / 2.0; |
| 170 | + float roi_start_w = -roi_width / 2.0; |
| 171 | + |
| 172 | + pre_calc_for_bilinear_interpolate( |
| 173 | + height, width, pooled_height, pooled_width, roi_bin_grid_h, |
| 174 | + roi_bin_grid_w, roi_start_h, roi_start_w, bin_size_h, bin_size_w, |
| 175 | + roi_bin_grid_h, roi_bin_grid_w, roi_center_h, roi_center_w, cos_theta, |
| 176 | + sin_theta, pre_calc); |
| 177 | + |
| 178 | + for (int c = 0; c < channels; c++) { |
| 179 | + int index_n_c = index_n + c * pooled_width * pooled_height; |
| 180 | + const float *offset_input = |
| 181 | + input + (roi_batch_ind * channels + c) * height * width; |
| 182 | + int pre_calc_index = 0; |
| 183 | + |
| 184 | + for (int ph = 0; ph < pooled_height; ph++) { |
| 185 | + for (int pw = 0; pw < pooled_width; pw++) { |
| 186 | + int index = index_n_c + ph * pooled_width + pw; |
| 187 | + |
| 188 | + float output_val = 0.; |
| 189 | + for (int iy = 0; iy < roi_bin_grid_h; iy++) { |
| 190 | + for (int ix = 0; ix < roi_bin_grid_w; ix++) { |
| 191 | + PreCalc pc = pre_calc[pre_calc_index]; |
| 192 | + output_val += pc.w1 * offset_input[pc.pos1] + |
| 193 | + pc.w2 * offset_input[pc.pos2] + |
| 194 | + pc.w3 * offset_input[pc.pos3] + |
| 195 | + pc.w4 * offset_input[pc.pos4]; |
| 196 | + |
| 197 | + pre_calc_index += 1; |
| 198 | + } |
| 199 | + } |
| 200 | + output_val /= count; |
| 201 | + |
| 202 | + output[index] = output_val; |
| 203 | + } // for pw |
| 204 | + } // for ph |
| 205 | + } // for c |
| 206 | + } // for n |
| 207 | +} |
| 208 | + |
| 209 | +void MMCVRoIAlignRotatedKernel::Compute(OrtKernelContext *context) { |
| 210 | + // Setup inputs |
| 211 | + const OrtValue *input_X = ort_.KernelContext_GetInput(context, 0); |
| 212 | + const float *X_data = |
| 213 | + reinterpret_cast<const float *>(ort_.GetTensorData<float>(input_X)); |
| 214 | + const OrtValue *input_rois = ort_.KernelContext_GetInput(context, 1); |
| 215 | + const float *rois = reinterpret_cast<const float *>( |
| 216 | + ort_.GetTensorData<const float *>(input_rois)); |
| 217 | + |
| 218 | + // Setup output |
| 219 | + OrtTensorDimensions out_dimensions(ort_, input_X); |
| 220 | + OrtTensorDimensions roi_dimensions(ort_, input_rois); |
| 221 | + |
| 222 | + int batch_size = out_dimensions.data()[0]; |
| 223 | + int input_channels = out_dimensions.data()[1]; |
| 224 | + int input_height = out_dimensions.data()[2]; |
| 225 | + int input_width = out_dimensions.data()[3]; |
| 226 | + |
| 227 | + out_dimensions.data()[0] = roi_dimensions.data()[0]; |
| 228 | + out_dimensions.data()[2] = aligned_height_; |
| 229 | + out_dimensions.data()[3] = aligned_width_; |
| 230 | + |
| 231 | + OrtValue *output = ort_.KernelContext_GetOutput( |
| 232 | + context, 0, out_dimensions.data(), out_dimensions.size()); |
| 233 | + float *out = ort_.GetTensorMutableData<float>(output); |
| 234 | + OrtTensorTypeAndShapeInfo *output_info = ort_.GetTensorTypeAndShape(output); |
| 235 | + ort_.ReleaseTensorTypeAndShapeInfo(output_info); |
| 236 | + |
| 237 | + // TODO: forward here |
| 238 | + int output_size = out_dimensions.data()[0]; |
| 239 | + for (auto i = 1; i < out_dimensions.size(); ++i) { |
| 240 | + output_size *= out_dimensions.data()[i]; |
| 241 | + } |
| 242 | + ROIAlignRotatedForwardCPU(output_size, X_data, rois, out, spatial_scale_, |
| 243 | + aligned_, clockwise_, input_channels, input_height, |
| 244 | + input_width, aligned_height_, aligned_width_, |
| 245 | + sampling_ratio_); |
| 246 | +} |
0 commit comments