Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feature]Nearest #44901

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions paddle/phi/kernels/funcs/interpolate_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,14 @@ namespace funcs {

template <typename T>
HOSTDEVICE inline T CubicConvolution1(T x, T A) {
return ((A + 2) * x - (A + 3)) * x * x + 1;
return ((A + static_cast<T>(2)) * x - (A + static_cast<T>(3))) * x * x +
static_cast<T>(1);
}

template <typename T>
HOSTDEVICE inline T CubicConvolution2(T x, T A) {
return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A;
return ((A * x - static_cast<T>(5) * A) * x + static_cast<T>(8) * A) * x -
static_cast<T>(4) * A;
}

template <typename T>
Expand Down
72 changes: 41 additions & 31 deletions paddle/phi/kernels/gpu/interpolate_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
#include "paddle/phi/common/layout.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/interpolate_function.h"

namespace phi {
using paddle::platform::FastDivMod;

Expand All @@ -34,11 +34,11 @@ __forceinline__ __device__ void PreCalculatorForLinearInterpInputIndex(
T* lambda2,
T src_x,
const int in_img_x) {
src_x = (src_x > 0) ? src_x : 0.f;
src_x = (src_x > static_cast<T>(0)) ? src_x : static_cast<T>(0);
*in_img_idx = static_cast<int>(src_x);
*x_id = (*in_img_idx < in_img_x - 1) ? 1 : 0;
*lambda1 = src_x - *in_img_idx;
*lambda2 = 1.f - *lambda1;
*lambda1 = src_x - static_cast<T>(*in_img_idx);
*lambda2 = static_cast<T>(1) - *lambda1;
}

template <typename T>
Expand Down Expand Up @@ -79,11 +79,11 @@ __global__ void KeLinearInterpFw(const T* in,
in_img_idx = (in_img_idx > 0) ? in_img_idx : 0; // w
int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0; // w_id

T src_w = ratio_w * (out_img_idx + 0.5) - 0.5;
float src_w = ratio_w * (out_img_idx + 0.5) - 0.5;
src_w = (src_w > 0) ? src_w : 0;
T w1lambda =
align_flag ? src_w - in_img_idx : ratio_w * out_img_idx - in_img_idx;
T w2lambda = 1.f - w1lambda;
T w1lambda = static_cast<T>(
align_flag ? src_w - in_img_idx : ratio_w * out_img_idx - in_img_idx);
T w2lambda = static_cast<T>(1) - w1lambda;

if (data_layout == DataLayout::kNCHW) {
const T* in_pos =
Expand Down Expand Up @@ -222,8 +222,12 @@ __global__ void KeBilinearInterpFw(const T* in,

int in_img_idx, in_img_idy, h_id, w_id;
T h1lambda, w1lambda, h2lambda, w2lambda;
T src_w = ratio_w * (out_img_idx + align_type_value) - align_type_value;
T src_h = ratio_h * (out_img_idy + align_type_value) - align_type_value;
T src_w = static_cast<T>(ratio_w) *
(static_cast<T>(out_img_idx) + align_type_value) -
align_type_value;
T src_h = static_cast<T>(ratio_h) *
(static_cast<T>(out_img_idy) + align_type_value) -
align_type_value;

PreCalculatorForLinearInterpInputIndex(
&in_img_idx, &w_id, &w1lambda, &w2lambda, src_w, in_img_w);
Expand Down Expand Up @@ -262,8 +266,12 @@ __global__ void KeBilinearInterpNCHWFw(const T* in,

int in_img_idx, in_img_idy, h_id, w_id;
T h1lambda, w1lambda, h2lambda, w2lambda;
T src_w = ratio_w * (out_img_idx + align_type_value) - align_type_value;
T src_h = ratio_h * (out_img_idy + align_type_value) - align_type_value;
T src_w = static_cast<T>(ratio_w) *
(static_cast<T>(out_img_idx) + align_type_value) -
align_type_value;
T src_h = static_cast<T>(ratio_h) *
(static_cast<T>(out_img_idy) + align_type_value) -
align_type_value;

PreCalculatorForLinearInterpInputIndex(
&in_img_idx, &w_id, &w1lambda, &w2lambda, src_w, in_img_w);
Expand Down Expand Up @@ -296,13 +304,13 @@ template <typename T>
__device__ __forceinline__ static T Kecubic_interp(
const T x0, const T x1, const T x2, const T x3, T t) {
T coeffs[4];
T a = -0.75;
T a = static_cast<T>(-0.75);
T x_1 = t;
T x_2 = 1.0 - t;
coeffs[0] = funcs::CubicConvolution2<T>(x_1 + 1.0, a);
T x_2 = static_cast<T>(1) - t;
coeffs[0] = funcs::CubicConvolution2<T>(x_1 + static_cast<T>(1), a);
coeffs[1] = funcs::CubicConvolution1<T>(x_1, a);
coeffs[2] = funcs::CubicConvolution1<T>(x_2, a);
coeffs[3] = funcs::CubicConvolution2<T>(x_2 + 1.0, a);
coeffs[3] = funcs::CubicConvolution2<T>(x_2 + static_cast<T>(1), a);
return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3];
}

Expand Down Expand Up @@ -348,13 +356,13 @@ __global__ void KeBicubicInterpFw(const T* in,
? static_cast<T>(ratio_h * out_img_idy)
: static_cast<T>(ratio_h * (out_img_idy + 0.5) - 0.5);
int input_y = floorf(in_img_idy);
const T y_t = in_img_idy - input_y;
const T y_t = in_img_idy - static_cast<T>(input_y);

T in_img_idx = align_corners
? static_cast<T>(ratio_w * out_img_idx)
: static_cast<T>(ratio_w * (out_img_idx + 0.5) - 0.5);
int input_x = floorf(in_img_idx);
const T x_t = in_img_idx - input_x;
const T x_t = in_img_idx - static_cast<T>(input_x);

T coefficients[4];
const T* in_pos_0;
Expand Down Expand Up @@ -482,33 +490,33 @@ __global__ void KeTrilinearInterpFw(const T* in,
: static_cast<int>(ratio_d * out_img_idt);
in_img_idt = (in_img_idt > 0) ? in_img_idt : 0;
int d_id = (in_img_idt < in_img_d - 1) ? 1 : 0;
T src_d = ratio_d * (out_img_idt + 0.5) - 0.5;
float src_d = ratio_d * (out_img_idt + 0.5) - 0.5;
src_d = (src_d > 0) ? src_d : 0;
T d1lambda =
align_flag ? src_d - in_img_idt : ratio_d * out_img_idt - in_img_idt;
T d2lambda = 1.f - d1lambda;
T d1lambda = static_cast<T>(
align_flag ? src_d - in_img_idt : ratio_d * out_img_idt - in_img_idt);
T d2lambda = static_cast<T>(1) - d1lambda;

int in_img_idy = align_flag
? static_cast<int>(ratio_h * (out_img_idy + 0.5) - 0.5)
: static_cast<int>(ratio_h * out_img_idy);
in_img_idy = (in_img_idy > 0) ? in_img_idy : 0;
int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0;
T src_h = ratio_h * (out_img_idy + 0.5) - 0.5;
float src_h = ratio_h * (out_img_idy + 0.5) - 0.5;
src_h = (src_h > 0) ? src_h : 0;
T h1lambda =
align_flag ? src_h - in_img_idy : ratio_h * out_img_idy - in_img_idy;
T h2lambda = 1.f - h1lambda;
T h1lambda = static_cast<T>(
align_flag ? src_h - in_img_idy : ratio_h * out_img_idy - in_img_idy);
T h2lambda = static_cast<T>(1) - h1lambda;

int in_img_idx = align_flag
? static_cast<int>(ratio_w * (out_img_idx + 0.5) - 0.5)
: static_cast<int>(ratio_w * out_img_idx);
in_img_idx = (in_img_idx > 0) ? in_img_idx : 0;
int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0;
T src_w = ratio_w * (out_img_idx + 0.5) - 0.5;
float src_w = ratio_w * (out_img_idx + 0.5) - 0.5;
src_w = (src_w > 0) ? src_w : 0;
T w1lambda =
align_flag ? src_w - in_img_idx : ratio_w * out_img_idx - in_img_idx;
T w2lambda = 1.f - w1lambda;
T w1lambda = static_cast<T>(
align_flag ? src_w - in_img_idx : ratio_w * out_img_idx - in_img_idx);
T w2lambda = static_cast<T>(1) - w1lambda;

if (data_layout == DataLayout::kNCHW) {
int in_pos1_idx = out_id_h * input_w + channel_id * in_img_size +
Expand Down Expand Up @@ -926,7 +934,8 @@ static void Interpolate2DCUDAFwd(
thread_num = 512;
}
#endif
const T align_type_value = (align_mode == 0 && !align_corners) ? 0.5f : 0;
const T align_type_value =
static_cast<T>((align_mode == 0 && !align_corners) ? 0.5f : 0);
if (data_layout == DataLayout::kNCHW) {
// get launch 3D config
int nc = n * c;
Expand Down Expand Up @@ -1454,6 +1463,7 @@ PD_REGISTER_KERNEL(nearest_interp_v2,
GPU,
ALL_LAYOUT,
phi::NearestInterpKernel,
phi::dtype::float16,
float,
double,
int,
Expand Down