diff --git a/docs/en/understand_mmcv/ops.md b/docs/en/understand_mmcv/ops.md index e0a9a3648c..97c9126027 100644 --- a/docs/en/understand_mmcv/ops.md +++ b/docs/en/understand_mmcv/ops.md @@ -60,3 +60,4 @@ We implement common ops used in detection, segmentation, etc. | UpFirDn2d | | √ | | | | Voxelization | √ | √ | | | | PrRoIPool | | √ | | | +| BezierAlign | √ | √ | | | diff --git a/docs/zh_cn/understand_mmcv/ops.md b/docs/zh_cn/understand_mmcv/ops.md index 6b4622146c..3bcc048e7c 100644 --- a/docs/zh_cn/understand_mmcv/ops.md +++ b/docs/zh_cn/understand_mmcv/ops.md @@ -60,3 +60,4 @@ MMCV 提供了检测、分割等任务中常用的算子 | UpFirDn2d | | √ | | | | Voxelization | √ | √ | | | | PrRoIPool | | √ | | | +| BezierAlign | √ | √ | | | diff --git a/mmcv/ops/__init__.py b/mmcv/ops/__init__.py index e96c1577ba..e48628e969 100755 --- a/mmcv/ops/__init__.py +++ b/mmcv/ops/__init__.py @@ -3,6 +3,7 @@ from .assign_score_withk import assign_score_withk from .ball_query import ball_query from .bbox import bbox_overlaps +from .bezier_align import BezierAlign, bezier_align from .border_align import BorderAlign, border_align from .box_iou_quadri import box_iou_quadri from .box_iou_rotated import box_iou_rotated @@ -102,5 +103,5 @@ 'points_in_boxes_cpu', 'points_in_boxes_all', 'points_in_polygons', 'min_area_polygons', 'active_rotated_filter', 'convex_iou', 'convex_giou', 'diff_iou_rotated_2d', 'diff_iou_rotated_3d', 'chamfer_distance', - 'PrRoIPool', 'prroi_pool' + 'PrRoIPool', 'prroi_pool', 'BezierAlign', 'bezier_align' ] diff --git a/mmcv/ops/bezier_align.py b/mmcv/ops/bezier_align.py new file mode 100644 index 0000000000..6db7f5c8d8 --- /dev/null +++ b/mmcv/ops/bezier_align.py @@ -0,0 +1,137 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from typing import Tuple, Union + +import torch +import torch.nn as nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.nn.modules.utils import _pair + +from ..utils import ext_loader + +ext_module = ext_loader.load_ext( + '_ext', ['bezier_align_forward', 'bezier_align_backward']) + + +class BezierAlignFunction(Function): + + @staticmethod + def forward(ctx, + input: torch.Tensor, + beziers: torch.Tensor, + output_size: Union[int, Tuple[int, int]], + spatial_scale: Union[int, float] = 1.0, + sampling_ratio: int = 0, + aligned: bool = True) -> torch.Tensor: + ctx.output_size = _pair(output_size) + ctx.spatial_scale = spatial_scale + ctx.input_shape = input.size() + ctx.sampling_ratio = sampling_ratio + ctx.aligned = aligned + + assert beziers.size(1) == 17 + output_shape = (beziers.size(0), input.size(1), ctx.output_size[0], + ctx.output_size[1]) + output = input.new_zeros(output_shape) + ext_module.bezier_align_forward( + input, + beziers, + output, + aligned_height=ctx.output_size[0], + aligned_width=ctx.output_size[1], + spatial_scale=ctx.spatial_scale, + sampling_ratio=ctx.sampling_ratio, + aligned=ctx.aligned) + + ctx.save_for_backward(beziers) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output: torch.Tensor): + beziers = ctx.saved_tensors[0] + grad_input = grad_output.new_zeros(ctx.input_shape) + grad_output = grad_output.contiguous() + ext_module.bezier_align_backward( + grad_output, + beziers, + grad_input, + aligned_height=ctx.output_size[0], + aligned_width=ctx.output_size[1], + spatial_scale=ctx.spatial_scale, + sampling_ratio=ctx.sampling_ratio, + aligned=ctx.aligned) + return grad_input, None, None, None, None, None + + +bezier_align = BezierAlignFunction.apply + + +class BezierAlign(nn.Module): + """Bezier align pooling layer. 
+ + Args: + output_size (tuple): h, w + spatial_scale (float): scale the input beziers by this number + sampling_ratio (int): number of input samples to take for each + output sample. 0 to take samples densely for current models. + aligned (bool): if False, use the legacy implementation in + MMDetection. If True, align the results more precisely. + + Note: + The implementation of BezierAlign is modified from + https://github.com/aim-uofa/AdelaiDet + + The meaning of aligned=True: + + Given a continuous coordinate c, its two neighboring pixel + indices (in our pixel model) are computed by floor(c - 0.5) and + ceil(c - 0.5). For example, c=1.3 has pixel neighbors with discrete + indices [0] and [1] (which are sampled from the underlying signal + at continuous coordinates 0.5 and 1.5). But the original roi_align + (aligned=False) does not subtract the 0.5 when computing + neighboring pixel indices and therefore it uses pixels with a + slightly incorrect alignment (relative to our pixel model) when + performing bilinear interpolation. + + With `aligned=True`, + we first appropriately scale the ROI and then shift it by -0.5 + prior to calling roi_align. This produces the correct neighbors. + + The choice makes no difference to the model's + performance if ROIAlign is used together with conv layers. + """ + + def __init__( + self, + output_size: Tuple, + spatial_scale: Union[int, float], + sampling_ratio: int, + aligned: bool = True, + ) -> None: + super().__init__() + + self.output_size = _pair(output_size) + self.spatial_scale = float(spatial_scale) + self.sampling_ratio = int(sampling_ratio) + self.aligned = aligned + + def forward(self, input: torch.Tensor, + beziers: torch.Tensor) -> torch.Tensor: + """BezierAlign forward. + + Args: + input (Tensor): input features. + beziers (Tensor): beziers for align. 
+ """ + return bezier_align(input, beziers, self.output_size, + self.spatial_scale, self.sampling_ratio, + self.aligned) + + def __repr__(self): + s = self.__class__.__name__ + s += f'(output_size={self.output_size}, ' + s += f'spatial_scale={self.spatial_scale}, ' + s += f'sampling_ratio={self.sampling_ratio}, ' + s += f'aligned={self.aligned})' + return s diff --git a/mmcv/ops/csrc/common/cuda/bezier_align_cuda_kernel.cuh b/mmcv/ops/csrc/common/cuda/bezier_align_cuda_kernel.cuh new file mode 100644 index 0000000000..537610416e --- /dev/null +++ b/mmcv/ops/csrc/common/cuda/bezier_align_cuda_kernel.cuh @@ -0,0 +1,230 @@ +// Copyright (c) OpenMMLab. All rights reserved +// Modified from +// https://github.com/aim-uofa/AdelaiDet/blob/master/adet/layers/csrc/BezierAlign/BezierAlign_cuda.cu +#ifndef BEZIER_ALIGN_CUDA_KERNEL_CUH +#define BEZIER_ALIGN_CUDA_KERNEL_CUH + +#include +#ifdef MMCV_WITH_TRT +#include "common_cuda_helper.hpp" +#else // MMCV_WITH_TRT +#ifdef MMCV_USE_PARROTS +#include "parrots_cuda_helper.hpp" +#else // MMCV_USE_PARROTS +#include "pytorch_cuda_helper.hpp" +#endif // MMCV_USE_PARROTS +#endif // MMCV_WITH_TRT + +template +__device__ T bezier_curve(const T p0, const T p1, const T p2, const T p3, + const T u) { + return ((1. - u) * (1. - u) * (1. - u) * p0 + + 3. * u * (1. - u) * (1. - u) * p1 + 3. * u * u * (1. 
- u) * p2 + + u * u * u * p3); +} + +template +__global__ void bezier_align_forward_cuda_kernel( + const int nthreads, + const T *bottom_data, // inputs + const T *bottom_rois, // bottom rois contains the bezier curve + T *top_data, // outputs + const int pooled_height, const int pooled_width, const T spatial_scale, + const int sampling_ratio, bool aligned, const int channels, + const int height, const int width) { + CUDA_1D_KERNEL_LOOP(index, nthreads) { + // (n, c, ph, pw) is an element in the pooled output + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + + // beziers have size Nx(1+8*2) = Nx17 + const T *offset_bottom_rois = bottom_rois + n * 17; + int roi_batch_ind = offset_bottom_rois[0]; + + // Do not use rounding; this implementation detail is critical + T offset = aligned ? (T)0.5 : (T)0.0; + + // TODO: avoid this by using parallel annotation, for good + T p0_x = offset_bottom_rois[1] * spatial_scale; + T p0_y = offset_bottom_rois[2] * spatial_scale; + T p1_x = offset_bottom_rois[3] * spatial_scale; + T p1_y = offset_bottom_rois[4] * spatial_scale; + T p2_x = offset_bottom_rois[5] * spatial_scale; + T p2_y = offset_bottom_rois[6] * spatial_scale; + T p3_x = offset_bottom_rois[7] * spatial_scale; + T p3_y = offset_bottom_rois[8] * spatial_scale; + T p4_x = offset_bottom_rois[15] * spatial_scale; + T p4_y = offset_bottom_rois[16] * spatial_scale; + T p5_x = offset_bottom_rois[13] * spatial_scale; + T p5_y = offset_bottom_rois[14] * spatial_scale; + T p6_x = offset_bottom_rois[11] * spatial_scale; + T p6_y = offset_bottom_rois[12] * spatial_scale; + T p7_x = offset_bottom_rois[9] * spatial_scale; + T p7_y = offset_bottom_rois[10] * spatial_scale; + + // compute the coords + const T u = pw / static_cast(pooled_width); + const T v = ph / static_cast(pooled_height); + const T x0 = bezier_curve(p0_x, p1_x, 
p2_x, p3_x, u); + const T y0 = bezier_curve(p0_y, p1_y, p2_y, p3_y, u); + const T x1 = bezier_curve(p4_x, p5_x, p6_x, p7_x, u); + const T y1 = bezier_curve(p4_y, p5_y, p6_y, p7_y, u); + const T x_center = x1 * v + x0 * (1. - v) - offset; + const T y_center = y1 * v + y0 * (1. - v) - offset; + + T roi_width = max(abs(p0_x - p3_x), abs(p4_x - p7_x)); + T roi_height = max(abs(p0_y - p3_y), abs(p4_y - p7_y)); + if (!aligned) { // for backward-compatibility only + roi_width = max(roi_width, (T)1.); + roi_height = max(roi_height, (T)1.); + } + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + const T *offset_bottom_data = + bottom_data + (roi_batch_ind * channels + c) * height * width; + + // We use roi_bin_grid to sample the grid and mimic integral + int roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : ceil(roi_height / pooled_height); // e.g., = 2 + int roi_bin_grid_w = + (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); + + // We do average (integral) pooling inside a bin + // When the grid is empty, output zeros == 0/1, instead of NaN. + const T count = max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. 
= 4 + + T output_val = 0.; + for (int iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1 + { + const T y = y_center - (T)0.5 * bin_size_h + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + const T x = x_center - (T)0.5 * bin_size_w + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + + T val = bilinear_interpolate(offset_bottom_data, height, width, y, x, + index); + output_val += val; + } + } + output_val /= count; + + top_data[index] = output_val; + } +} + +template +__global__ void bezier_align_backward_cuda_kernel( + const int nthreads, const T *top_diff, const T *bottom_rois, T *bottom_diff, + const int pooled_height, const int pooled_width, const T spatial_scale, + const int sampling_ratio, bool aligned, const int channels, + const int height, const int width) { + CUDA_1D_KERNEL_LOOP(index, nthreads) { + // (n, c, ph, pw) is an element in the pooled output + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + + // beziers have size Nx(1+8*2) = Nx17 + const T *offset_bottom_rois = bottom_rois + n * 17; + int roi_batch_ind = offset_bottom_rois[0]; + + // Do not use rounding; this implementation detail is critical + T offset = aligned ? 
(T)0.5 : (T)0.0; + T p0_x = offset_bottom_rois[1] * spatial_scale; + T p0_y = offset_bottom_rois[2] * spatial_scale; + T p1_x = offset_bottom_rois[3] * spatial_scale; + T p1_y = offset_bottom_rois[4] * spatial_scale; + T p2_x = offset_bottom_rois[5] * spatial_scale; + T p2_y = offset_bottom_rois[6] * spatial_scale; + T p3_x = offset_bottom_rois[7] * spatial_scale; + T p3_y = offset_bottom_rois[8] * spatial_scale; + T p4_x = offset_bottom_rois[15] * spatial_scale; + T p4_y = offset_bottom_rois[16] * spatial_scale; + T p5_x = offset_bottom_rois[13] * spatial_scale; + T p5_y = offset_bottom_rois[14] * spatial_scale; + T p6_x = offset_bottom_rois[11] * spatial_scale; + T p6_y = offset_bottom_rois[12] * spatial_scale; + T p7_x = offset_bottom_rois[9] * spatial_scale; + T p7_y = offset_bottom_rois[10] * spatial_scale; + + // compute the coords + const T u = pw / static_cast(pooled_width); + const T v = ph / static_cast(pooled_height); + const T x0 = bezier_curve(p0_x, p1_x, p2_x, p3_x, u); + const T y0 = bezier_curve(p0_y, p1_y, p2_y, p3_y, u); + const T x1 = bezier_curve(p4_x, p5_x, p6_x, p7_x, u); + const T y1 = bezier_curve(p4_y, p5_y, p6_y, p7_y, u); + const T x_center = x1 * v + x0 * (1. - v) - offset; + const T y_center = y1 * v + y0 * (1. 
- v) - offset; + + T roi_width = max(abs(p0_x - p3_x), abs(p4_x - p7_x)); + T roi_height = max(abs(p0_y - p3_y), abs(p4_y - p7_y)); + if (!aligned) { // for backward-compatibility only + roi_width = max(roi_width, (T)1.); + roi_height = max(roi_height, (T)1.); + } + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + T *offset_bottom_diff = + bottom_diff + (roi_batch_ind * channels + c) * height * width; + + int top_offset = (n * channels + c) * pooled_height * pooled_width; + const T *offset_top_diff = top_diff + top_offset; + const T top_diff_this_bin = offset_top_diff[ph * pooled_width + pw]; + + // We use roi_bin_grid to sample the grid and mimic integral + int roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : ceil(roi_height / pooled_height); // e.g., = 2 + int roi_bin_grid_w = + (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); + + // We do average (integral) pooling inside a bin + const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. 
= 4 + + for (int iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1 + { + const T y = y_center - (T)0.5 * bin_size_h + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + const T x = x_center - (T)0.5 * bin_size_w + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + + T w1, w2, w3, w4; + int x_low, x_high, y_low, y_high; + + bilinear_interpolate_gradient(height, width, y, x, w1, w2, w3, w4, + x_low, x_high, y_low, y_high, index); + + T g1 = top_diff_this_bin * w1 / count; + T g2 = top_diff_this_bin * w2 / count; + T g3 = top_diff_this_bin * w3 / count; + T g4 = top_diff_this_bin * w4 / count; + + if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) { + atomicAdd(offset_bottom_diff + y_low * width + x_low, + static_cast(g1)); + atomicAdd(offset_bottom_diff + y_low * width + x_high, + static_cast(g2)); + atomicAdd(offset_bottom_diff + y_high * width + x_low, + static_cast(g3)); + atomicAdd(offset_bottom_diff + y_high * width + x_high, + static_cast(g4)); + } // if + } // ix + } // iy + } // CUDA_1D_KERNEL_LOOP +} // BezierAlignBackward + +#endif // BEZIER_ALIGN_CUDA_KERNEL_CUH diff --git a/mmcv/ops/csrc/pytorch/bezier_align.cpp b/mmcv/ops/csrc/pytorch/bezier_align.cpp new file mode 100644 index 0000000000..b8521d66cb --- /dev/null +++ b/mmcv/ops/csrc/pytorch/bezier_align.cpp @@ -0,0 +1,38 @@ +// Copyright (c) OpenMMLab. 
All rights reserved +#include "pytorch_cpp_helper.hpp" +#include "pytorch_device_registry.hpp" + +void bezier_align_forward_impl(Tensor input, Tensor rois, Tensor output, + int aligned_height, int aligned_width, + float spatial_scale, int sampling_ratio, + bool aligned) { + DISPATCH_DEVICE_IMPL(bezier_align_forward_impl, input, rois, output, + aligned_height, aligned_width, spatial_scale, + sampling_ratio, aligned); +} + +void bezier_align_backward_impl(Tensor grad_output, Tensor rois, + Tensor grad_input, int aligned_height, + int aligned_width, float spatial_scale, + int sampling_ratio, bool aligned) { + DISPATCH_DEVICE_IMPL(bezier_align_backward_impl, grad_output, rois, + grad_input, aligned_height, aligned_width, spatial_scale, + sampling_ratio, aligned); +} + +void bezier_align_forward(Tensor input, Tensor rois, Tensor output, + int aligned_height, int aligned_width, + float spatial_scale, int sampling_ratio, + bool aligned) { + bezier_align_forward_impl(input, rois, output, aligned_height, aligned_width, + spatial_scale, sampling_ratio, aligned); +} + +void bezier_align_backward(Tensor grad_output, Tensor rois, Tensor grad_input, + int aligned_height, int aligned_width, + float spatial_scale, int sampling_ratio, + bool aligned) { + bezier_align_backward_impl(grad_output, rois, grad_input, aligned_height, + aligned_width, spatial_scale, sampling_ratio, + aligned); +} diff --git a/mmcv/ops/csrc/pytorch/cpu/bezier_align.cpp b/mmcv/ops/csrc/pytorch/cpu/bezier_align.cpp new file mode 100644 index 0000000000..7eb0e5b940 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/cpu/bezier_align.cpp @@ -0,0 +1,447 @@ +// Modified from +// https://github.com/facebookresearch/detectron2/tree/master/detectron2/layers/csrc/BezierAlign +// Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved +#include +#include + +#include "pytorch_cpp_helper.hpp" +#include "pytorch_device_registry.hpp" + +// implementation taken from Caffe2 +template +struct PreCalc { + int pos1; + int pos2; + int pos3; + int pos4; + T w1; + T w2; + T w3; + T w4; +}; + +template +T bezier_curve(const T p0, const T p1, const T p2, const T p3, const T u) { + return ((1. - u) * (1. - u) * (1. - u) * p0 + + 3. * u * (1. - u) * (1. - u) * p1 + 3. * u * u * (1. - u) * p2 + + u * u * u * p3); +} + +template +void pre_calc_for_bilinear_interpolate( + const int height, const int width, const int pooled_height, + const int pooled_width, const int iy_upper, const int ix_upper, T p0_x, + T p0_y, T p1_x, T p1_y, T p2_x, T p2_y, T p3_x, T p3_y, T p4_x, T p4_y, + T p5_x, T p5_y, T p6_x, T p6_y, T p7_x, T p7_y, T bin_size_h, T bin_size_w, + int roi_bin_grid_h, int roi_bin_grid_w, T offset, + std::vector> &pre_calc) { + int pre_calc_index = 0; + for (int ph = 0; ph < pooled_height; ph++) { + for (int pw = 0; pw < pooled_width; pw++) { + // compute the coords + const T u = pw / static_cast(pooled_width); + const T v = ph / static_cast(pooled_height); + const T x0 = bezier_curve(p0_x, p1_x, p2_x, p3_x, u); + const T y0 = bezier_curve(p0_y, p1_y, p2_y, p3_y, u); + const T x1 = bezier_curve(p4_x, p5_x, p6_x, p7_x, u); + const T y1 = bezier_curve(p4_y, p5_y, p6_y, p7_y, u); + const T x_center = x1 * v + x0 * (1. - v) - offset; + const T y_center = y1 * v + y0 * (1. 
- v) - offset; + for (int iy = 0; iy < iy_upper; iy++) { + const T yy = y_center - (T)0.5 * bin_size_h + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 + for (int ix = 0; ix < ix_upper; ix++) { + const T xx = x_center - (T)0.5 * bin_size_w + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + + T x = xx; + T y = yy; + // deal with: inverse elements are out of feature map boundary + if (y < -1.0 || y > height || x < -1.0 || x > width) { + // empty + PreCalc pc; + pc.pos1 = 0; + pc.pos2 = 0; + pc.pos3 = 0; + pc.pos4 = 0; + pc.w1 = 0; + pc.w2 = 0; + pc.w3 = 0; + pc.w4 = 0; + pre_calc[pre_calc_index] = pc; + pre_calc_index += 1; + continue; + } + + if (y <= 0) { + y = 0; + } + if (x <= 0) { + x = 0; + } + + int y_low = (int)y; + int x_low = (int)x; + int y_high; + int x_high; + + if (y_low >= height - 1) { + y_high = y_low = height - 1; + y = (T)y_low; + } else { + y_high = y_low + 1; + } + + if (x_low >= width - 1) { + x_high = x_low = width - 1; + x = (T)x_low; + } else { + x_high = x_low + 1; + } + + T ly = y - y_low; + T lx = x - x_low; + T hy = 1. - ly, hx = 1. 
- lx; + T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; + + // save weights and indices + PreCalc pc; + pc.pos1 = y_low * width + x_low; + pc.pos2 = y_low * width + x_high; + pc.pos3 = y_high * width + x_low; + pc.pos4 = y_high * width + x_high; + pc.w1 = w1; + pc.w2 = w2; + pc.w3 = w3; + pc.w4 = w4; + pre_calc[pre_calc_index] = pc; + + pre_calc_index += 1; + } + } + } + } +} + +template +void BezierAlignForward(const int nthreads, const T *input, const T *rois, + T *output, const int pooled_height, + const int pooled_width, const T &spatial_scale, + const int sampling_ratio, bool aligned, + const int channels, const int height, const int width) { + int n_rois = nthreads / channels / pooled_width / pooled_height; + // (n, c, ph, pw) is an element in the pooled output + // can be parallelized using omp + // #pragma omp parallel for num_threads(32) + for (int n = 0; n < n_rois; n++) { + int index_n = n * channels * pooled_width * pooled_height; + + // beziers have size Nx(1+8*2) = Nx17 + const T *offset_rois = rois + n * 17; + int roi_batch_ind = offset_rois[0]; + + T offset = aligned ? 
(T)0.5 : (T)0.0; + // Do not use rounding; this implementation detail is critical + T p0_x = offset_rois[1] * spatial_scale; + T p0_y = offset_rois[2] * spatial_scale; + T p1_x = offset_rois[3] * spatial_scale; + T p1_y = offset_rois[4] * spatial_scale; + T p2_x = offset_rois[5] * spatial_scale; + T p2_y = offset_rois[6] * spatial_scale; + T p3_x = offset_rois[7] * spatial_scale; + T p3_y = offset_rois[8] * spatial_scale; + T p4_x = offset_rois[15] * spatial_scale; + T p4_y = offset_rois[16] * spatial_scale; + T p5_x = offset_rois[13] * spatial_scale; + T p5_y = offset_rois[14] * spatial_scale; + T p6_x = offset_rois[11] * spatial_scale; + T p6_y = offset_rois[12] * spatial_scale; + T p7_x = offset_rois[9] * spatial_scale; + T p7_y = offset_rois[10] * spatial_scale; + + T roi_width = std::max(std::abs(p0_x - p3_x), std::abs(p4_x - p7_x)); + T roi_height = std::max(std::abs(p0_y - p3_y), std::abs(p4_y - p7_y)); + if (aligned) { + AT_ASSERTM(roi_width >= 0 && roi_height >= 0, + "Beziers in BezierAlign cannot have negative size!"); + } else { // for backward-compatibility only + roi_width = std::max(roi_width, (T)1.); + roi_height = std::max(roi_height, (T)1.); + } + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + // We use roi_bin_grid to sample the grid and mimic integral + int roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : ceil(roi_height / pooled_height); // e.g., = 2 + int roi_bin_grid_w = + (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); + + // We do average (integral) pooling inside a bin + // When the grid is empty, output zeros == 0/1, instead of NaN. + const T count = std::max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. 
= 4 + + // we want to precalculate indices and weights shared by all channels, + // this is the key point of optimization + std::vector> pre_calc(roi_bin_grid_h * roi_bin_grid_w * + pooled_width * pooled_height); + pre_calc_for_bilinear_interpolate( + height, width, pooled_height, pooled_width, roi_bin_grid_h, + roi_bin_grid_w, p0_x, p0_y, p1_x, p1_y, p2_x, p2_y, p3_x, p3_y, p4_x, + p4_y, p5_x, p5_y, p6_x, p6_y, p7_x, p7_y, bin_size_h, bin_size_w, + roi_bin_grid_h, roi_bin_grid_w, offset, pre_calc); + + for (int c = 0; c < channels; c++) { + int index_n_c = index_n + c * pooled_width * pooled_height; + const T *offset_input = + input + (roi_batch_ind * channels + c) * height * width; + int pre_calc_index = 0; + + for (int ph = 0; ph < pooled_height; ph++) { + for (int pw = 0; pw < pooled_width; pw++) { + int index = index_n_c + ph * pooled_width + pw; + + T output_val = 0.; + for (int iy = 0; iy < roi_bin_grid_h; iy++) { + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + PreCalc pc = pre_calc[pre_calc_index]; + output_val += pc.w1 * offset_input[pc.pos1] + + pc.w2 * offset_input[pc.pos2] + + pc.w3 * offset_input[pc.pos3] + + pc.w4 * offset_input[pc.pos4]; + + pre_calc_index += 1; + } + } + output_val /= count; + + output[index] = output_val; + } // for pw + } // for ph + } // for c + } // for n +} + +template +void bilinear_interpolate_gradient(const int height, const int width, T y, T x, + T &w1, T &w2, T &w3, T &w4, int &x_low, + int &x_high, int &y_low, int &y_high, + const int index /* index for debug only*/) { + // deal with cases that inverse elements are out of feature map boundary + if (y < -1.0 || y > height || x < -1.0 || x > width) { + // empty + w1 = w2 = w3 = w4 = 0.; + x_low = x_high = y_low = y_high = -1; + return; + } + + if (y <= 0) y = 0; + if (x <= 0) x = 0; + + y_low = (int)y; + x_low = (int)x; + + if (y_low >= height - 1) { + y_high = y_low = height - 1; + y = (T)y_low; + } else { + y_high = y_low + 1; + } + + if (x_low >= width - 1) { + x_high 
= x_low = width - 1; + x = (T)x_low; + } else { + x_high = x_low + 1; + } + + T ly = y - y_low; + T lx = x - x_low; + T hy = 1. - ly, hx = 1. - lx; + + // reference in forward + // T v1 = input[y_low * width + x_low]; + // T v2 = input[y_low * width + x_high]; + // T v3 = input[y_high * width + x_low]; + // T v4 = input[y_high * width + x_high]; + // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + + w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; +} + +template +inline void add(T *address, const T &val) { + *address += val; +} + +template +void BezierAlignBackward(const int nthreads, const T *grad_output, + const T *rois, T *grad_input, const int pooled_height, + const int pooled_width, const T &spatial_scale, + const int sampling_ratio, bool aligned, + const int channels, const int height, const int width, + const int n_stride, const int c_stride, + const int h_stride, const int w_stride) { + for (int index = 0; index < nthreads; index++) { + // (n, c, ph, pw) is an element in the pooled output + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + + const T *offset_rois = rois + n * 17; + int roi_batch_ind = offset_rois[0]; + + // Do not use rounding; this implementation detail is critical + T offset = aligned ? 
(T)0.5 : (T)0.0; + T p0_x = offset_rois[1] * spatial_scale; + T p0_y = offset_rois[2] * spatial_scale; + T p1_x = offset_rois[3] * spatial_scale; + T p1_y = offset_rois[4] * spatial_scale; + T p2_x = offset_rois[5] * spatial_scale; + T p2_y = offset_rois[6] * spatial_scale; + T p3_x = offset_rois[7] * spatial_scale; + T p3_y = offset_rois[8] * spatial_scale; + T p4_x = offset_rois[15] * spatial_scale; + T p4_y = offset_rois[16] * spatial_scale; + T p5_x = offset_rois[13] * spatial_scale; + T p5_y = offset_rois[14] * spatial_scale; + T p6_x = offset_rois[11] * spatial_scale; + T p6_y = offset_rois[12] * spatial_scale; + T p7_x = offset_rois[9] * spatial_scale; + T p7_y = offset_rois[10] * spatial_scale; + + // compute the coords + const T u = pw / static_cast(pooled_width); + const T v = ph / static_cast(pooled_height); + const T x0 = bezier_curve(p0_x, p1_x, p2_x, p3_x, u); + const T y0 = bezier_curve(p0_y, p1_y, p2_y, p3_y, u); + const T x1 = bezier_curve(p4_x, p5_x, p6_x, p7_x, u); + const T y1 = bezier_curve(p4_y, p5_y, p6_y, p7_y, u); + const T x_center = x1 * v + x0 * (1. - v) - offset; + const T y_center = y1 * v + y0 * (1. 
- v) - offset; + + T roi_width = std::max(std::abs(p0_x - p3_x), std::abs(p4_x - p7_x)); + T roi_height = std::max(std::abs(p0_y - p3_y), std::abs(p4_y - p7_y)); + if (aligned) { + AT_ASSERTM(roi_width >= 0 && roi_height >= 0, + "Beziers in BezierAlign cannot have negative size!"); + } else { // for backward-compatibility only + roi_width = std::max(roi_width, (T)1.); + roi_height = std::max(roi_height, (T)1.); + } + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + T *offset_grad_input = + grad_input + ((roi_batch_ind * channels + c) * height * width); + + int output_offset = n * n_stride + c * c_stride; + const T *offset_grad_output = grad_output + output_offset; + const T grad_output_this_bin = + offset_grad_output[ph * h_stride + pw * w_stride]; + + // We use roi_bin_grid to sample the grid and mimic integral + int roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : ceil(roi_height / pooled_height); // e.g., = 2 + int roi_bin_grid_w = + (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); + + // We do average (integral) pooling inside a bin + const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. 
= 4 + + for (int iy = 0; iy < roi_bin_grid_h; iy++) { + const T y = y_center - (T)0.5 * bin_size_h + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + const T x = x_center - (T)0.5 * bin_size_w + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + + T w1, w2, w3, w4; + int x_low, x_high, y_low, y_high; + + bilinear_interpolate_gradient(height, width, y, x, w1, w2, w3, w4, + x_low, x_high, y_low, y_high, index); + + T g1 = grad_output_this_bin * w1 / count; + T g2 = grad_output_this_bin * w2 / count; + T g3 = grad_output_this_bin * w3 / count; + T g4 = grad_output_this_bin * w4 / count; + + if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) { + // atomic add is not needed for now since it is single threaded + add(offset_grad_input + y_low * width + x_low, static_cast(g1)); + add(offset_grad_input + y_low * width + x_high, static_cast(g2)); + add(offset_grad_input + y_high * width + x_low, static_cast(g3)); + add(offset_grad_input + y_high * width + x_high, static_cast(g4)); + } // if + } // ix + } // iy + } // for +} // BezierAlignBackward + +void BezierAlignForwardCPULauncher(Tensor input, Tensor rois, Tensor output, + int aligned_height, int aligned_width, + float spatial_scale, int sampling_ratio, + bool aligned) { + int output_size = output.numel(); + int channels = input.size(1); + int height = input.size(2); + int width = input.size(3); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input.scalar_type(), "BezierAlign_forward", [&] { + BezierAlignForward( + output_size, input.data_ptr(), rois.data_ptr(), + output.data_ptr(), aligned_height, aligned_width, + static_cast(spatial_scale), sampling_ratio, aligned, + channels, height, width); + }); +} + +void BezierAlignBackwardCPULauncher(Tensor grad_output, Tensor rois, + Tensor grad_input, int aligned_height, + int aligned_width, float spatial_scale, + int sampling_ratio, bool aligned) { + int 
output_size = grad_output.numel(); + int channels = grad_input.size(1); + int height = grad_input.size(2); + int width = grad_input.size(3); + + // get stride values to ensure indexing into gradients is correct. + int n_stride = grad_output.stride(0); + int c_stride = grad_output.stride(1); + int h_stride = grad_output.stride(2); + int w_stride = grad_output.stride(3); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad_output.scalar_type(), "BezierAlign_backward", [&] { + BezierAlignBackward( + output_size, grad_output.data_ptr(), + rois.data_ptr(), grad_input.data_ptr(), + aligned_height, aligned_width, static_cast(spatial_scale), + sampling_ratio, aligned, channels, height, width, n_stride, + c_stride, h_stride, w_stride); + }); +} + +void bezier_align_forward_impl(Tensor input, Tensor rois, Tensor output, + int aligned_height, int aligned_width, + float spatial_scale, int sampling_ratio, + bool aligned); + +void bezier_align_backward_impl(Tensor grad_output, Tensor rois, + Tensor grad_input, int aligned_height, + int aligned_width, float spatial_scale, + int sampling_ratio, bool aligned); + +REGISTER_DEVICE_IMPL(bezier_align_forward_impl, CPU, + BezierAlignForwardCPULauncher); +REGISTER_DEVICE_IMPL(bezier_align_backward_impl, CPU, + BezierAlignBackwardCPULauncher); diff --git a/mmcv/ops/csrc/pytorch/cuda/bezier_align_cuda.cu b/mmcv/ops/csrc/pytorch/cuda/bezier_align_cuda.cu new file mode 100644 index 0000000000..b2786a84eb --- /dev/null +++ b/mmcv/ops/csrc/pytorch/cuda/bezier_align_cuda.cu @@ -0,0 +1,53 @@ +// Copyright (c) OpenMMLab. 
All rights reserved +#include "bezier_align_cuda_kernel.cuh" +#include "pytorch_cuda_helper.hpp" + +void BezierAlignForwardCUDAKernelLauncher(Tensor input, Tensor rois, + Tensor output, int aligned_height, + int aligned_width, + float spatial_scale, + int sampling_ratio, bool aligned) { + int output_size = output.numel(); + int channels = input.size(1); + int height = input.size(2); + int width = input.size(3); + + at::cuda::CUDAGuard device_guard(input.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input.scalar_type(), "bezier_align_forward_cuda_kernel", [&] { + bezier_align_forward_cuda_kernel + <<>>( + output_size, input.data_ptr(), + rois.data_ptr(), output.data_ptr(), + aligned_height, aligned_width, + static_cast(spatial_scale), sampling_ratio, aligned, + channels, height, width); + }); + + AT_CUDA_CHECK(cudaGetLastError()); +} + +void BezierAlignBackwardCUDAKernelLauncher( + Tensor grad_output, Tensor rois, Tensor grad_input, int aligned_height, + int aligned_width, float spatial_scale, int sampling_ratio, bool aligned) { + int output_size = grad_output.numel(); + int channels = grad_input.size(1); + int height = grad_input.size(2); + int width = grad_input.size(3); + + at::cuda::CUDAGuard device_guard(grad_output.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad_output.scalar_type(), "bezier_align_backward_cuda_kernel", [&] { + bezier_align_backward_cuda_kernel + <<>>( + output_size, grad_output.data_ptr(), + rois.data_ptr(), grad_input.data_ptr(), + aligned_height, aligned_width, + static_cast(spatial_scale), sampling_ratio, aligned, + channels, height, width); + }); + + AT_CUDA_CHECK(cudaGetLastError()); +} diff --git a/mmcv/ops/csrc/pytorch/cuda/cudabind.cpp b/mmcv/ops/csrc/pytorch/cuda/cudabind.cpp index e558634068..0fcfdec7be 100644 --- a/mmcv/ops/csrc/pytorch/cuda/cudabind.cpp +++ b/mmcv/ops/csrc/pytorch/cuda/cudabind.cpp @@ 
-1867,3 +1867,28 @@ REGISTER_DEVICE_IMPL(prroi_pool_forward_impl, CUDA, prroi_pool_forward_cuda);
 REGISTER_DEVICE_IMPL(prroi_pool_backward_impl, CUDA, prroi_pool_backward_cuda);
 REGISTER_DEVICE_IMPL(prroi_pool_coor_backward_impl, CUDA,
                      prroi_pool_coor_backward_cuda);
+
+void BezierAlignForwardCUDAKernelLauncher(Tensor input, Tensor rois,
+                                          Tensor output, int aligned_height,
+                                          int aligned_width,
+                                          float spatial_scale,
+                                          int sampling_ratio, bool aligned);
+
+void BezierAlignBackwardCUDAKernelLauncher(
+    Tensor grad_output, Tensor rois, Tensor grad_input, int aligned_height,
+    int aligned_width, float spatial_scale, int sampling_ratio, bool aligned);
+
+void bezier_align_forward_impl(Tensor input, Tensor rois, Tensor output,
+                               int aligned_height, int aligned_width,
+                               float spatial_scale, int sampling_ratio,
+                               bool aligned);  // declared only; no body here
+
+void bezier_align_backward_impl(Tensor grad_output, Tensor rois,
+                                Tensor grad_input, int aligned_height,
+                                int aligned_width, float spatial_scale,
+                                int sampling_ratio, bool aligned);
+
+REGISTER_DEVICE_IMPL(bezier_align_forward_impl, CUDA,  // CUDA impl:
+                     BezierAlignForwardCUDAKernelLauncher);
+REGISTER_DEVICE_IMPL(bezier_align_backward_impl, CUDA,  // CUDA impl:
+                     BezierAlignBackwardCUDAKernelLauncher);
diff --git a/mmcv/ops/csrc/pytorch/pybind.cpp b/mmcv/ops/csrc/pytorch/pybind.cpp
index 4947b72152..d59baca836 100644
--- a/mmcv/ops/csrc/pytorch/pybind.cpp
+++ b/mmcv/ops/csrc/pytorch/pybind.cpp
@@ -446,6 +446,16 @@ Tensor nms_quadri(const Tensor dets, const Tensor scores, const Tensor order,
                   const Tensor dets_sorted, const float iou_threshold,
                   const int multi_label);
 
+void bezier_align_forward(Tensor input, Tensor rois, Tensor output,
+                          int aligned_height, int aligned_width,
+                          float spatial_scale, int sampling_ratio,
+                          bool aligned);  // bound with m.def() in this file
+
+void bezier_align_backward(Tensor grad_output, Tensor rois, Tensor grad_input,
+                           int aligned_height, int aligned_width,
+                           float spatial_scale, int sampling_ratio,
+                           bool aligned);  // bound with m.def() in this file
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)", py::arg("input"),
         py::arg("kernel"), py::arg("up_x"), py::arg("up_y"), py::arg("down_x"),
@@ -899,4 +909,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         py::arg("dets"), py::arg("scores"), py::arg("order"),
         py::arg("dets_sorted"), py::arg("iou_threshold"),
         py::arg("multi_label"));
+  m.def("bezier_align_forward", &bezier_align_forward, "bezier_align forward",
+        py::arg("input"), py::arg("rois"), py::arg("output"),
+        py::arg("aligned_height"), py::arg("aligned_width"),
+        py::arg("spatial_scale"), py::arg("sampling_ratio"),
+        py::arg("aligned"));  // names match the Python caller's keywords
+  m.def("bezier_align_backward", &bezier_align_backward,
+        "bezier_align backward", py::arg("grad_output"), py::arg("rois"),
+        py::arg("grad_input"), py::arg("aligned_height"),
+        py::arg("aligned_width"), py::arg("spatial_scale"),
+        py::arg("sampling_ratio"), py::arg("aligned"));  // same keyword set
 }
diff --git a/tests/test_ops/test_bezier_align.py b/tests/test_ops/test_bezier_align.py
new file mode 100644
index 0000000000..b86812acee
--- /dev/null
+++ b/tests/test_ops/test_bezier_align.py
@@ -0,0 +1,54 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+import pytest
+import torch
+
+from mmcv.utils import IS_CUDA_AVAILABLE
+
+inputs = ([[[  # (1x1x4x4 feature map, one 17-value bezier row)
+    [1., 2., 5., 6.],
+    [3., 4., 7., 8.],
+    [9., 10., 13., 14.],
+    [11., 12., 15., 16.],
+]]], [[0., 0., 0., 1, 0., 2., 0., 3., 0., 3., 3., 2., 3., 1., 3., 0., 3.]])
+outputs = ([[[[1., 1.75, 3.5, 5.25], [2.5, 3.25, 5., 6.75],  # [0]: forward
+              [6., 6.75, 8.5, 10.25],
+              [9.5, 10.25, 12., 13.75]]]], [[[[1.5625, 1.5625, 1.5625, 0.3125],
+                                              [1.5625, 1.5625, 1.5625, 0.3125],
+                                              [1.5625, 1.5625, 1.5625, 0.3125],
+                                              [0.3125, 0.3125, 0.3125,
+                                               0.0625]]]])  # [1]: input grad
+
+
+@pytest.mark.parametrize('device', [
+    'cpu',
+    pytest.param(
+        'cuda',
+        marks=pytest.mark.skipif(
+            not IS_CUDA_AVAILABLE, reason='requires CUDA support'))
+])
+@pytest.mark.parametrize('dtype', [torch.float, torch.double, torch.half])
+def test_bezieralign(device, dtype):  # forward and backward vs. references
+    try:
+        from mmcv.ops import bezier_align
+    except ModuleNotFoundError:
+        pytest.skip('test requires compilation')
+    pool_h = 4  # pooled output height
+    pool_w = 4  # pooled output width
+    spatial_scale = 1.0
+    sampling_ratio = 1
+    np_input = np.array(inputs[0])
+    np_rois = np.array(inputs[1])
+    np_output = np.array(outputs[0])
+    np_grad = np.array(outputs[1])
+
+    x = torch.tensor(np_input, dtype=dtype, device=device, requires_grad=True)
+    rois = torch.tensor(np_rois, dtype=dtype, device=device)
+
+    output = bezier_align(x, rois, (pool_h, pool_w), spatial_scale,
+                          sampling_ratio, False)  # last argument: aligned
+    output.backward(torch.ones_like(output))  # all-ones upstream gradient
+    assert np.allclose(
+        output.data.type(torch.float).cpu().numpy(), np_output, atol=1e-3)
+    assert np.allclose(
+        x.grad.data.type(torch.float).cpu().numpy(), np_grad, atol=1e-3)