Skip to content

Commit

Permalink
[Feature] Add BezierAlign CUDA op (open-mmlab#2393)
Browse files Browse the repository at this point in the history
* bezier align

* add ut

* fix comment

* updata ut

* fix link and comment

* fix comment
  • Loading branch information
Harold-lkk authored and akozlov-outrider committed May 8, 2023
1 parent 4fbdb77 commit 756a86e
Show file tree
Hide file tree
Showing 11 changed files with 1,008 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/en/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,4 @@ We implement common ops used in detection, segmentation, etc.
| UpFirDn2d | || | |
| Voxelization ||| | |
| PrRoIPool | || | |
| BezierAlign ||| | |
1 change: 1 addition & 0 deletions docs/zh_cn/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,4 @@ MMCV 提供了检测、分割等任务中常用的算子
| UpFirDn2d | || | |
| Voxelization ||| | |
| PrRoIPool | || | |
| BezierAlign ||| | |
4 changes: 3 additions & 1 deletion mmcv/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .assign_score_withk import assign_score_withk
from .ball_query import ball_query
from .bbox import bbox_overlaps
from .bezier_align import BezierAlign, bezier_align
from .border_align import BorderAlign, border_align
from .box_iou_quadri import box_iou_quadri
from .box_iou_rotated import box_iou_rotated
Expand Down Expand Up @@ -105,5 +106,6 @@
'min_area_polygons', 'active_rotated_filter', 'convex_iou', 'convex_giou',
'diff_iou_rotated_2d', 'diff_iou_rotated_3d', 'chamfer_distance',
'PrRoIPool', 'prroi_pool', 'three_nn_vector_pool_by_two_step',
'stack_three_interpolate', 'vector_pool_with_voxel_query'
'stack_three_interpolate', 'vector_pool_with_voxel_query',
'BezierAlign', 'bezier_align'
]
137 changes: 137 additions & 0 deletions mmcv/ops/bezier_align.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple, Union

import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair

from ..utils import ext_loader

ext_module = ext_loader.load_ext(
'_ext', ['bezier_align_forward', 'bezier_align_backward'])


class BezierAlignFunction(Function):

@staticmethod
def forward(ctx,
input: torch.Tensor,
beziers: torch.Tensor,
output_size: Union[int, Tuple[int, int]],
spatial_scale: Union[int, float] = 1.0,
sampling_ratio: int = 0,
aligned: bool = True) -> torch.Tensor:
ctx.output_size = _pair(output_size)
ctx.spatial_scale = spatial_scale
ctx.input_shape = input.size()
ctx.sampling_ratio = sampling_ratio
ctx.aligned = aligned

assert beziers.size(1) == 17
output_shape = (beziers.size(0), input.size(1), ctx.output_size[0],
ctx.output_size[1])
output = input.new_zeros(output_shape)
ext_module.bezier_align_forward(
input,
beziers,
output,
aligned_height=ctx.output_size[0],
aligned_width=ctx.output_size[1],
spatial_scale=ctx.spatial_scale,
sampling_ratio=ctx.sampling_ratio,
aligned=ctx.aligned)

ctx.save_for_backward(beziers)
return output

@staticmethod
@once_differentiable
def backward(ctx, grad_output: torch.Tensor):
beziers = ctx.saved_tensors[0]
grad_input = grad_output.new_zeros(ctx.input_shape)
grad_output = grad_output.contiguous()
ext_module.bezier_align_backward(
grad_output,
beziers,
grad_input,
aligned_height=ctx.output_size[0],
aligned_width=ctx.output_size[1],
spatial_scale=ctx.spatial_scale,
sampling_ratio=ctx.sampling_ratio,
aligned=ctx.aligned)
return grad_input, None, None, None, None, None


bezier_align = BezierAlignFunction.apply


class BezierAlign(nn.Module):
"""Bezier align pooling layer.
Args:
output_size (tuple): h, w
spatial_scale (float): scale the input boxes by this number
sampling_ratio (int): number of inputs samples to take for each
output sample. 0 to take samples densely for current models.
aligned (bool): if False, use the legacy implementation in
MMDetection. If True, align the results more perfectly.
Note:
The implementation of BezierAlign is modified from
https://github.com/aim-uofa/AdelaiDet
The meaning of aligned=True:
Given a continuous coordinate c, its two neighboring pixel
indices (in our pixel model) are computed by floor(c - 0.5) and
ceil(c - 0.5). For example, c=1.3 has pixel neighbors with discrete
indices [0] and [1] (which are sampled from the underlying signal
at continuous coordinates 0.5 and 1.5). But the original roi_align
(aligned=False) does not subtract the 0.5 when computing
neighboring pixel indices and therefore it uses pixels with a
slightly incorrect alignment (relative to our pixel model) when
performing bilinear interpolation.
With `aligned=True`,
we first appropriately scale the ROI and then shift it by -0.5
prior to calling roi_align. This produces the correct neighbors;
The difference does not make a difference to the model's
performance if ROIAlign is used together with conv layers.
"""

def __init__(
self,
output_size: Tuple,
spatial_scale: Union[int, float],
sampling_ratio: int,
aligned: bool = True,
) -> None:
super().__init__()

self.output_size = _pair(output_size)
self.spatial_scale = float(spatial_scale)
self.sampling_ratio = int(sampling_ratio)
self.aligned = aligned

def forward(self, input: torch.Tensor,
beziers: torch.Tensor) -> torch.Tensor:
"""BezierAlign forward.
Args:
inputs (Tensor): input features.
beziers (Tensor): beziers for align.
"""
return bezier_align(input, beziers, self.output_size,
self.spatial_scale, self.sampling_ratio,
self.aligned)

def __repr__(self):
s = self.__class__.__name__
s += f'(output_size={self.output_size}, '
s += f'spatial_scale={self.spatial_scale})'
s += f'sampling_ratio={self.sampling_ratio})'
s += f'aligned={self.aligned})'
return s
230 changes: 230 additions & 0 deletions mmcv/ops/csrc/common/cuda/bezier_align_cuda_kernel.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
// Copyright (c) OpenMMLab. All rights reserved
// Modified from
// https://github.com/aim-uofa/AdelaiDet/blob/master/adet/layers/csrc/BezierAlign/BezierAlign_cuda.cu
#ifndef BEZIER_ALIGN_CUDA_KERNEL_CUH
#define BEZIER_ALIGN_CUDA_KERNEL_CUH

#include <float.h>
#ifdef MMCV_WITH_TRT
#include "common_cuda_helper.hpp"
#else // MMCV_WITH_TRT
#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else // MMCV_USE_PARROTS
#include "pytorch_cuda_helper.hpp"
#endif // MMCV_USE_PARROTS
#endif // MMCV_WITH_TRT

template <typename T>
__device__ T bezier_curve(const T p0, const T p1, const T p2, const T p3,
const T u) {
return ((1. - u) * (1. - u) * (1. - u) * p0 +
3. * u * (1. - u) * (1. - u) * p1 + 3. * u * u * (1. - u) * p2 +
u * u * u * p3);
}

template <typename T>
__global__ void bezier_align_forward_cuda_kernel(
const int nthreads,
const T *bottom_data, // inputs
const T *bottom_rois, // bottom rois contains the bezier curve
T *top_data, // outputs
const int pooled_height, const int pooled_width, const T spatial_scale,
const int sampling_ratio, bool aligned, const int channels,
const int height, const int width) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
// (n, c, ph, pw) is an element in the pooled output
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;

// beziers have size Nx(1+8*2) = Nx17
const T *offset_bottom_rois = bottom_rois + n * 17;
int roi_batch_ind = offset_bottom_rois[0];

// Do not use rounding; this implementation detail is critical
T offset = aligned ? (T)0.5 : (T)0.0;

// TODO: avoid this by using parallel annotation, for good
T p0_x = offset_bottom_rois[1] * spatial_scale;
T p0_y = offset_bottom_rois[2] * spatial_scale;
T p1_x = offset_bottom_rois[3] * spatial_scale;
T p1_y = offset_bottom_rois[4] * spatial_scale;
T p2_x = offset_bottom_rois[5] * spatial_scale;
T p2_y = offset_bottom_rois[6] * spatial_scale;
T p3_x = offset_bottom_rois[7] * spatial_scale;
T p3_y = offset_bottom_rois[8] * spatial_scale;
T p4_x = offset_bottom_rois[15] * spatial_scale;
T p4_y = offset_bottom_rois[16] * spatial_scale;
T p5_x = offset_bottom_rois[13] * spatial_scale;
T p5_y = offset_bottom_rois[14] * spatial_scale;
T p6_x = offset_bottom_rois[11] * spatial_scale;
T p6_y = offset_bottom_rois[12] * spatial_scale;
T p7_x = offset_bottom_rois[9] * spatial_scale;
T p7_y = offset_bottom_rois[10] * spatial_scale;

// compute the coords
const T u = pw / static_cast<T>(pooled_width);
const T v = ph / static_cast<T>(pooled_height);
const T x0 = bezier_curve(p0_x, p1_x, p2_x, p3_x, u);
const T y0 = bezier_curve(p0_y, p1_y, p2_y, p3_y, u);
const T x1 = bezier_curve(p4_x, p5_x, p6_x, p7_x, u);
const T y1 = bezier_curve(p4_y, p5_y, p6_y, p7_y, u);
const T x_center = x1 * v + x0 * (1. - v) - offset;
const T y_center = y1 * v + y0 * (1. - v) - offset;

T roi_width = max(abs(p0_x - p3_x), abs(p4_x - p7_x));
T roi_height = max(abs(p0_y - p3_y), abs(p4_y - p7_y));
if (!aligned) { // for backward-compatibility only
roi_width = max(roi_width, (T)1.);
roi_height = max(roi_height, (T)1.);
}
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);

const T *offset_bottom_data =
bottom_data + (roi_batch_ind * channels + c) * height * width;

// We use roi_bin_grid to sample the grid and mimic integral
int roi_bin_grid_h = (sampling_ratio > 0)
? sampling_ratio
: ceil(roi_height / pooled_height); // e.g., = 2
int roi_bin_grid_w =
(sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);

// We do average (integral) pooling inside a bin
// When the grid is empty, output zeros == 0/1, instead of NaN.
const T count = max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. = 4

T output_val = 0.;
for (int iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1
{
const T y = y_center - (T)0.5 * bin_size_h +
static_cast<T>(iy + .5f) * bin_size_h /
static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
for (int ix = 0; ix < roi_bin_grid_w; ix++) {
const T x = x_center - (T)0.5 * bin_size_w +
static_cast<T>(ix + .5f) * bin_size_w /
static_cast<T>(roi_bin_grid_w);

T val = bilinear_interpolate(offset_bottom_data, height, width, y, x,
index);
output_val += val;
}
}
output_val /= count;

top_data[index] = output_val;
}
}

template <typename T>
__global__ void bezier_align_backward_cuda_kernel(
const int nthreads, const T *top_diff, const T *bottom_rois, T *bottom_diff,
const int pooled_height, const int pooled_width, const T spatial_scale,
const int sampling_ratio, bool aligned, const int channels,
const int height, const int width) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
// (n, c, ph, pw) is an element in the pooled output
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;

// beziers have size Nx(1+8*2) = Nx17
const T *offset_bottom_rois = bottom_rois + n * 17;
int roi_batch_ind = offset_bottom_rois[0];

// Do not use rounding; this implementation detail is critical
T offset = aligned ? (T)0.5 : (T)0.0;
T p0_x = offset_bottom_rois[1] * spatial_scale;
T p0_y = offset_bottom_rois[2] * spatial_scale;
T p1_x = offset_bottom_rois[3] * spatial_scale;
T p1_y = offset_bottom_rois[4] * spatial_scale;
T p2_x = offset_bottom_rois[5] * spatial_scale;
T p2_y = offset_bottom_rois[6] * spatial_scale;
T p3_x = offset_bottom_rois[7] * spatial_scale;
T p3_y = offset_bottom_rois[8] * spatial_scale;
T p4_x = offset_bottom_rois[15] * spatial_scale;
T p4_y = offset_bottom_rois[16] * spatial_scale;
T p5_x = offset_bottom_rois[13] * spatial_scale;
T p5_y = offset_bottom_rois[14] * spatial_scale;
T p6_x = offset_bottom_rois[11] * spatial_scale;
T p6_y = offset_bottom_rois[12] * spatial_scale;
T p7_x = offset_bottom_rois[9] * spatial_scale;
T p7_y = offset_bottom_rois[10] * spatial_scale;

// compute the coords
const T u = pw / static_cast<T>(pooled_width);
const T v = ph / static_cast<T>(pooled_height);
const T x0 = bezier_curve(p0_x, p1_x, p2_x, p3_x, u);
const T y0 = bezier_curve(p0_y, p1_y, p2_y, p3_y, u);
const T x1 = bezier_curve(p4_x, p5_x, p6_x, p7_x, u);
const T y1 = bezier_curve(p4_y, p5_y, p6_y, p7_y, u);
const T x_center = x1 * v + x0 * (1. - v) - offset;
const T y_center = y1 * v + y0 * (1. - v) - offset;

T roi_width = max(abs(p0_x - p3_x), abs(p4_x - p7_x));
T roi_height = max(abs(p0_y - p3_y), abs(p4_y - p7_y));
if (!aligned) { // for backward-compatibility only
roi_width = max(roi_width, (T)1.);
roi_height = max(roi_height, (T)1.);
}
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);

T *offset_bottom_diff =
bottom_diff + (roi_batch_ind * channels + c) * height * width;

int top_offset = (n * channels + c) * pooled_height * pooled_width;
const T *offset_top_diff = top_diff + top_offset;
const T top_diff_this_bin = offset_top_diff[ph * pooled_width + pw];

// We use roi_bin_grid to sample the grid and mimic integral
int roi_bin_grid_h = (sampling_ratio > 0)
? sampling_ratio
: ceil(roi_height / pooled_height); // e.g., = 2
int roi_bin_grid_w =
(sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);

// We do average (integral) pooling inside a bin
const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4

for (int iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1
{
const T y = y_center - (T)0.5 * bin_size_h +
static_cast<T>(iy + .5f) * bin_size_h /
static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
for (int ix = 0; ix < roi_bin_grid_w; ix++) {
const T x = x_center - (T)0.5 * bin_size_w +
static_cast<T>(ix + .5f) * bin_size_w /
static_cast<T>(roi_bin_grid_w);

T w1, w2, w3, w4;
int x_low, x_high, y_low, y_high;

bilinear_interpolate_gradient(height, width, y, x, w1, w2, w3, w4,
x_low, x_high, y_low, y_high, index);

T g1 = top_diff_this_bin * w1 / count;
T g2 = top_diff_this_bin * w2 / count;
T g3 = top_diff_this_bin * w3 / count;
T g4 = top_diff_this_bin * w4 / count;

if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
atomicAdd(offset_bottom_diff + y_low * width + x_low,
static_cast<T>(g1));
atomicAdd(offset_bottom_diff + y_low * width + x_high,
static_cast<T>(g2));
atomicAdd(offset_bottom_diff + y_high * width + x_low,
static_cast<T>(g3));
atomicAdd(offset_bottom_diff + y_high * width + x_high,
static_cast<T>(g4));
} // if
} // ix
} // iy
} // CUDA_1D_KERNEL_LOOP
} // BezierAlignBackward

#endif // BEZIER_ALIGN_CUDA_KERNEL_CUH
Loading

0 comments on commit 756a86e

Please sign in to comment.