Merge pull request #11576 from JiayiFeng/dev_refine_bilinear_interp
Add bilinear interp support for uint8
JiayiFeng authored Jun 20, 2018
2 parents 80f6364 + e418862 commit 6e1c48d
Showing 5 changed files with 71 additions and 23 deletions.
3 changes: 2 additions & 1 deletion paddle/fluid/operators/bilinear_interp_op.cc
@@ -110,6 +110,7 @@ REGISTER_OPERATOR(bilinear_interp, ops::BilinearInterpOp,
ops::BilinearInterpOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(bilinear_interp_grad, ops::BilinearInterpOpGrad);
-REGISTER_OP_CPU_KERNEL(bilinear_interp, ops::BilinearInterpKernel<float>);
+REGISTER_OP_CPU_KERNEL(bilinear_interp, ops::BilinearInterpKernel<float>,
+                       ops::BilinearInterpKernel<uint8_t>);
REGISTER_OP_CPU_KERNEL(bilinear_interp_grad,
ops::BilinearInterpGradKernel<float>);
42 changes: 24 additions & 18 deletions paddle/fluid/operators/bilinear_interp_op.h
@@ -46,8 +46,10 @@ class BilinearInterpKernel : public framework::OpKernel<T> {
int in_chw = channels * in_hw;
int out_chw = channels * out_hw;

-    T ratio_h = (out_h > 1) ? static_cast<T>(in_h - 1) / (out_h - 1) : 0.f;
-    T ratio_w = (out_w > 1) ? static_cast<T>(in_w - 1) / (out_w - 1) : 0.f;
+    float ratio_h =
+        (out_h > 1) ? static_cast<float>(in_h - 1) / (out_h - 1) : 0.f;
+    float ratio_w =
+        (out_w > 1) ? static_cast<float>(in_w - 1) / (out_w - 1) : 0.f;

if (in_h == out_h && in_w == out_w) {
memcpy(output, input, input_t->numel() * sizeof(T));
@@ -56,24 +58,24 @@ class BilinearInterpKernel : public framework::OpKernel<T> {
for (int i = 0; i < out_h; ++i) { // loop for images
int h = ratio_h * i;
int hid = (h < in_h - 1) ? 1 : 0;
-        T h1lambda = ratio_h * i - h;
-        T h2lambda = 1 - h1lambda;
+        float h1lambda = ratio_h * i - h;
+        float h2lambda = 1.f - h1lambda;

for (int j = 0; j < out_w; ++j) {
int w = ratio_w * j;
int wid = (w < in_w - 1) ? 1 : 0;
-          T w1lambda = ratio_w * j - w;
-          T w2lambda = 1 - w1lambda;
+          float w1lambda = ratio_w * j - w;
+          float w2lambda = 1.f - w1lambda;
// calculate four position for bilinear interpolation
const T* in_pos = &input[k * in_chw + h * in_w + w];
T* out_pos = &output[k * out_chw + i * out_w + j];

for (int c = 0; c < channels; ++c) { // loop for channels
// bilinear interpolation
-            out_pos[0] =
+            out_pos[0] = static_cast<T>(
h2lambda * (w2lambda * in_pos[0] + w1lambda * in_pos[wid]) +
h1lambda * (w2lambda * in_pos[hid * in_w] +
-                             w1lambda * in_pos[hid * in_w + wid]);
+                             w1lambda * in_pos[hid * in_w + wid]));
in_pos += in_hw;
out_pos += out_hw;
}
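For reference, here is a minimal NumPy sketch of the arithmetic the updated forward kernel performs; the function and variable names (bilinear_resize, h1/h2/w1/w2) are illustrative, not from the codebase. The point of the diff is visible throughout: the ratios and lambda weights stay in float even when T is uint8_t, and only the final value is cast back to the input type.

import numpy as np

def bilinear_resize(img, out_h, out_w):
    # NCHW input; same source-coordinate math as the kernel above.
    n, c, in_h, in_w = img.shape
    ratio_h = float(in_h - 1) / (out_h - 1) if out_h > 1 else 0.0
    ratio_w = float(in_w - 1) / (out_w - 1) if out_w > 1 else 0.0
    out = np.zeros((n, c, out_h, out_w), dtype=img.dtype)
    for i in range(out_h):
        h = int(ratio_h * i)
        hid = 1 if h < in_h - 1 else 0
        h1 = ratio_h * i - h      # float weights, even for uint8 data
        h2 = 1.0 - h1
        for j in range(out_w):
            w = int(ratio_w * j)
            wid = 1 if w < in_w - 1 else 0
            w1 = ratio_w * j - w
            w2 = 1.0 - w1
            val = (h2 * (w2 * img[:, :, h, w] + w1 * img[:, :, h, w + wid]) +
                   h1 * (w2 * img[:, :, h + hid, w] +
                         w1 * img[:, :, h + hid, w + wid]))
            out[:, :, i, j] = val  # truncating cast back to img.dtype, like static_cast<T>
    return out

Upsampling a random uint8 batch, e.g. bilinear_resize(np.random.randint(0, 256, (1, 3, 9, 6), dtype=np.uint8), 10, 9), should match the kernel's output up to the one-unit truncation tolerance used by the tests below.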
@@ -117,8 +119,10 @@ class BilinearInterpGradKernel : public framework::OpKernel<T> {
int in_chw = channels * in_hw;
int out_chw = channels * out_hw;

-    T ratio_h = (out_h > 1) ? static_cast<T>(in_h - 1) / (out_h - 1) : 0.f;
-    T ratio_w = (out_w > 1) ? static_cast<T>(in_w - 1) / (out_w - 1) : 0.f;
+    float ratio_h =
+        (out_h > 1) ? static_cast<float>(in_h - 1) / (out_h - 1) : 0.f;
+    float ratio_w =
+        (out_w > 1) ? static_cast<float>(in_w - 1) / (out_w - 1) : 0.f;

if (in_h == out_h && in_w == out_w) {
memcpy(d_input, d_output, d_input_t->numel() * sizeof(T));
@@ -127,22 +131,24 @@ class BilinearInterpGradKernel : public framework::OpKernel<T> {
for (int i = 0; i < out_h; ++i) { // loop for images
int h = ratio_h * i;
int hid = (h < in_h - 1) ? 1 : 0;
-        T h1lambda = ratio_h * i - h;
-        T h2lambda = 1 - h1lambda;
+        float h1lambda = ratio_h * i - h;
+        float h2lambda = 1 - h1lambda;

for (int j = 0; j < out_w; ++j) {
int w = ratio_w * j;
int wid = (w < in_w - 1) ? 1 : 0;
-          T w1lambda = ratio_w * j - w;
-          T w2lambda = 1 - w1lambda;
+          float w1lambda = ratio_w * j - w;
+          float w2lambda = 1 - w1lambda;
T* in_pos = &d_input[k * in_chw + h * in_w + w];
const T* out_pos = &d_output[k * out_chw + i * out_w + j];

for (int c = 0; c < channels; ++c) { // loop for channels
-          in_pos[0] += h2lambda * w2lambda * out_pos[0];
-          in_pos[wid] += h2lambda * w1lambda * out_pos[0];
-          in_pos[hid * in_w] += h1lambda * w2lambda * out_pos[0];
-          in_pos[hid * in_w + wid] += h1lambda * w1lambda * out_pos[0];
+          in_pos[0] += static_cast<T>(h2lambda * w2lambda * out_pos[0]);
+          in_pos[wid] += static_cast<T>(h2lambda * w1lambda * out_pos[0]);
+          in_pos[hid * in_w] +=
+              static_cast<T>(h1lambda * w2lambda * out_pos[0]);
+          in_pos[hid * in_w + wid] +=
+              static_cast<T>(h1lambda * w1lambda * out_pos[0]);
in_pos += in_hw;
out_pos += out_hw;
}
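The backward pass scatters each output gradient to its four source pixels with the same float lambda weights; the static_cast<T> added here only matters for float today, since the grad kernel remains registered for float only (see bilinear_interp_op.cc above). A hedged NumPy sketch, with illustrative names:

import numpy as np

def bilinear_resize_grad(d_out, in_h, in_w):
    # Scatter each d_out element back to the four input pixels it was
    # interpolated from, weighted by the same lambdas as the forward pass.
    n, c, out_h, out_w = d_out.shape
    ratio_h = float(in_h - 1) / (out_h - 1) if out_h > 1 else 0.0
    ratio_w = float(in_w - 1) / (out_w - 1) if out_w > 1 else 0.0
    d_in = np.zeros((n, c, in_h, in_w), dtype=d_out.dtype)
    for i in range(out_h):
        h = int(ratio_h * i)
        hid = 1 if h < in_h - 1 else 0
        h1 = ratio_h * i - h
        h2 = 1.0 - h1
        for j in range(out_w):
            w = int(ratio_w * j)
            wid = 1 if w < in_w - 1 else 0
            w1 = ratio_w * j - w
            w2 = 1.0 - w1
            g = d_out[:, :, i, j]
            d_in[:, :, h, w] += h2 * w2 * g
            d_in[:, :, h, w + wid] += h2 * w1 * g
            d_in[:, :, h + hid, w] += h1 * w2 * g
            d_in[:, :, h + hid, w + wid] += h1 * w1 * g
    return d_in

Note that when wid (or hid) is 0, two of the scatter targets alias the same pixel and both contributions accumulate, matching the in_pos[0] / in_pos[wid] behavior of the kernel.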
1 change: 1 addition & 0 deletions paddle/fluid/operators/math/math_function.cc
@@ -30,6 +30,7 @@ template struct SetConstant<platform::CPUDeviceContext, double>;
template struct SetConstant<platform::CPUDeviceContext, int>;
template struct SetConstant<platform::CPUDeviceContext, int64_t>;
template struct SetConstant<platform::CPUDeviceContext, bool>;
+template struct SetConstant<platform::CPUDeviceContext, uint8_t>;

#define DEFINE_CPU_TRANS(RANK) \
template struct Transpose<platform::CPUDeviceContext, platform::float16, \
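SetConstant is the functor the framework uses to fill a tensor with a single value (for example, zero-initializing a buffer), so a plausible reading is that some code path exercised by the new uint8 support needs this explicit uint8_t instantiation. In NumPy terms it does roughly this:

import numpy as np

t = np.empty((2, 3), dtype=np.uint8)
t[...] = 0  # fill with one constant, roughly what SetConstant<..., uint8_t> provides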
2 changes: 1 addition & 1 deletion paddle/fluid/pybind/tensor_py.h
@@ -97,7 +97,7 @@ struct CastToPyBufferImpl<true, I, ARGS...> {
inline pybind11::buffer_info CastToPyBuffer(const framework::Tensor &tensor) {
auto buffer_info =
details::CastToPyBufferImpl<true, 0, float, int, double, int64_t, bool,
-                                 platform::float16>()(tensor);
+                                 uint8_t, platform::float16>()(tensor);
return buffer_info;
}

46 changes: 43 additions & 3 deletions python/paddle/fluid/tests/unittests/test_bilinear_interp_op.py
@@ -15,6 +15,7 @@
import unittest
import numpy as np
from op_test import OpTest
+import paddle.fluid.core as core


def bilinear_interp_np(input, out_h, out_w, out_size):
@@ -45,9 +46,9 @@ def bilinear_interp_np(input, out_h, out_w, out_size):

out[:, :, i, j] = h2lambda*(w2lambda*input[:, :, h, w] +
w1lambda*input[:, :, h, w+wid]) + \
-                  h1lambda*(w2lambda*input[:, :, h+hid, w] +
-                            w1lambda*input[:, :, h+hid, w+wid])
-    return out.astype("float32")
+                   h1lambda*(w2lambda*input[:, :, h+hid, w] +
+                             w1lambda*input[:, :, h+hid, w+wid])
+    return out.astype(input.dtype)


class TestBilinearInterpOp(OpTest):
@@ -122,5 +123,44 @@ def init_test_case(self):
self.out_size = np.array([65, 129]).astype("int32")


+class TestBilinearInterpOpUint8(OpTest):
+    def setUp(self):
+        self.out_size = None
+        self.init_test_case()
+        self.op_type = "bilinear_interp"
+        input_np = np.random.randint(
+            low=0, high=256, size=self.input_shape).astype("uint8")
+        output_np = bilinear_interp_np(input_np, self.out_h, self.out_w,
+                                       self.out_size)
+        self.inputs = {'X': input_np}
+        if self.out_size is not None:
+            self.inputs['OutSize'] = self.out_size
+        self.attrs = {'out_h': self.out_h, 'out_w': self.out_w}
+        self.outputs = {'Out': output_np}
+
+    def test_check_output(self):
+        self.check_output_with_place(place=core.CPUPlace(), atol=1)
+
+    def init_test_case(self):
+        self.input_shape = [1, 3, 9, 6]
+        self.out_h = 10
+        self.out_w = 9
+
+
+class TestCase1Uint8(TestBilinearInterpOpUint8):
+    def init_test_case(self):
+        self.input_shape = [2, 3, 128, 64]
+        self.out_h = 120
+        self.out_w = 50
+
+
+class TestCase2Uint8(TestBilinearInterpOpUint8):
+    def init_test_case(self):
+        self.input_shape = [4, 1, 7, 8]
+        self.out_h = 5
+        self.out_w = 13
+        self.out_size = np.array([6, 15]).astype("int32")
+
+
if __name__ == "__main__":
unittest.main()
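Two details of the test changes are worth spelling out. First, bilinear_interp_np now returns out.astype(input.dtype) instead of hard-coding float32, so the reference output matches the operator's output dtype for uint8 inputs. Second, the uint8 cases check with atol=1: both NumPy's astype and the kernel's static_cast truncate toward zero, but they truncate sums accumulated at different precisions (float64 in NumPy vs. float in the kernel), so individual pixels can land on either side of an integer boundary. A small illustration:

import numpy as np

vals = np.array([200.9999999, 201.0000001])
print(vals.astype(np.uint8))  # [200 201] -- truncation, so a tiny precision
                              # difference in the sum can shift a pixel by 1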
