Skip to content

Commit

Permalink
[Fix] Apply kernel loop for bicubic interpolate (open-mmlab#271)
Browse files: browse the repository at this point in the history
  • Loading branch information
AllentDan authored Dec 9, 2021
1 parent c213879 commit d5c3be9
Showing 1 changed file with 43 additions and 45 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -89,57 +89,55 @@ __global__ void resize_cubic_kernel_torch(const int num_elements, const scalar_t
int srcHeight, scalar_t *dst, int dstWidth, int dstHeight,
bool align_corners, float height_scale,
float width_scale) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
if (index >= num_elements) {
return;
}
// Special case: input and output are the same size, just copy
const int output_x = index % dstWidth;
const int output_y = index / dstWidth;
CUDA_1D_KERNEL_LOOP(index, num_elements) {
// Special case: input and output are the same size, just copy
const int output_x = index % dstWidth;
const int output_y = index / dstWidth;

if (srcHeight == dstHeight && srcWidth == dstWidth) {
for (int n = 0; n < batchsize; n++) {
for (int c = 0; c < channels; c++) {
const scalar_t val = src[n * channels * dstHeight * dstWidth + c * dstHeight * dstWidth +
output_y * dstWidth + output_x];
dst[n * channels * dstHeight * dstWidth + c * dstHeight * dstWidth + output_y * dstWidth +
output_x] = val;
}
}
return;
}
// Interpolation kernel
scalar_t real_x =
area_pixel_compute_source_index(width_scale, output_x, align_corners, /*cubic=*/true);
int in_x = floorf(real_x);
scalar_t t_x = real_x - in_x;

scalar_t real_y =
area_pixel_compute_source_index(height_scale, output_y, align_corners, /*cubic=*/true);
int in_y = floorf(real_y);
scalar_t t_y = real_y - in_y;

if (srcHeight == dstHeight && srcWidth == dstWidth) {
for (int n = 0; n < batchsize; n++) {
for (int c = 0; c < channels; c++) {
const scalar_t val = src[n * channels * dstHeight * dstWidth + c * dstHeight * dstWidth +
output_y * dstWidth + output_x];
scalar_t coefficients[4];

for (int k = 0; k < 4; k++) {
coefficients[k] = cubic_interp1d<scalar_t>(
upsample_get_value_bounded(src, n, c, batchsize, channels, srcHeight, srcWidth,
in_y - 1 + k, in_x - 1),
upsample_get_value_bounded(src, n, c, batchsize, channels, srcHeight, srcWidth,
in_y - 1 + k, in_x + 0),
upsample_get_value_bounded(src, n, c, batchsize, channels, srcHeight, srcWidth,
in_y - 1 + k, in_x + 1),
upsample_get_value_bounded(src, n, c, batchsize, channels, srcHeight, srcWidth,
in_y - 1 + k, in_x + 2),
t_x);
}

dst[n * channels * dstHeight * dstWidth + c * dstHeight * dstWidth + output_y * dstWidth +
output_x] = val;
output_x] = scalar_t(cubic_interp1d(coefficients[0], coefficients[1], coefficients[2],
coefficients[3], t_y));
}
}
return;
}
// Interpolation kernel
scalar_t real_x =
area_pixel_compute_source_index(width_scale, output_x, align_corners, /*cubic=*/true);
int in_x = floorf(real_x);
scalar_t t_x = real_x - in_x;

scalar_t real_y =
area_pixel_compute_source_index(height_scale, output_y, align_corners, /*cubic=*/true);
int in_y = floorf(real_y);
scalar_t t_y = real_y - in_y;

for (int n = 0; n < batchsize; n++) {
for (int c = 0; c < channels; c++) {
scalar_t coefficients[4];

for (int k = 0; k < 4; k++) {
coefficients[k] = cubic_interp1d<scalar_t>(
upsample_get_value_bounded(src, n, c, batchsize, channels, srcHeight, srcWidth,
in_y - 1 + k, in_x - 1),
upsample_get_value_bounded(src, n, c, batchsize, channels, srcHeight, srcWidth,
in_y - 1 + k, in_x + 0),
upsample_get_value_bounded(src, n, c, batchsize, channels, srcHeight, srcWidth,
in_y - 1 + k, in_x + 1),
upsample_get_value_bounded(src, n, c, batchsize, channels, srcHeight, srcWidth,
in_y - 1 + k, in_x + 2),
t_x);
}

dst[n * channels * dstHeight * dstWidth + c * dstHeight * dstWidth + output_y * dstWidth +
output_x] = scalar_t(cubic_interp1d(coefficients[0], coefficients[1], coefficients[2],
coefficients[3], t_y));
}
}
}

Expand Down

0 comments on commit d5c3be9

Please sign in to comment.