diff --git a/src/operator/spatial_transformer.cc b/src/operator/spatial_transformer.cc index 6c413f884df9..8dde5268a3b2 100644 --- a/src/operator/spatial_transformer.cc +++ b/src/operator/spatial_transformer.cc @@ -41,8 +41,9 @@ inline void BilinearSamplingForward(const Tensor &output, DType *out = output.dptr_; const DType *data = input.dptr_; const DType *grid = grid_src.dptr_; - const int o_n = output.size(0), o_c = output.size(1), o_h = output.size(2), o_w = output.size(3); - const int i_c = input.size(1), i_h = input.size(2), i_w = input.size(3); + const index_t o_n = output.size(0), o_c = output.size(1), + o_h = output.size(2), o_w = output.size(3); + const index_t i_c = input.size(1), i_h = input.size(2), i_w = input.size(3); for (index_t n = 0; n < static_cast(o_n); ++n) { for (index_t c = 0; c < static_cast(o_c); ++c) { for (index_t h = 0; h < static_cast(o_h); ++h) { @@ -51,23 +52,28 @@ inline void BilinearSamplingForward(const Tensor &output, const index_t grid_index = n * o_h * o_w * 2 + h * o_w + w; const DType y_real = (*(grid + grid_index + o_h * o_w) + 1) * (i_h - 1) / 2; const DType x_real = (*(grid + grid_index) + 1) * (i_w - 1) / 2; - const auto top_left_y = static_cast(std::floor(y_real)); - const auto top_left_x = static_cast(std::floor(x_real)); + const auto top_left_y = static_cast(std::floor(y_real)); + const auto top_left_x = static_cast(std::floor(x_real)); const DType top_left_y_w = 1.0 - (y_real - top_left_y); const DType top_left_x_w = 1.0 - (x_real - top_left_x); - const int data_index = n * i_c * i_h * i_w + c * i_h * i_w + + const index_t data_index = n * i_c * i_h * i_w + c * i_h * i_w + top_left_y * i_w + top_left_x; DType top_left_v = 0; DType top_right_v = 0; DType bottom_left_v = 0; DType bottom_right_v = 0; - if (between(top_left_x, 0, i_w-1) && between(top_left_y, 0, i_h-1)) + index_t lower_bound = 0; + if (between(top_left_x, lower_bound, i_w-1) && + between(top_left_y, lower_bound, i_h-1)) top_left_v = *(data + data_index); - if (between(top_left_x + 1, 0, i_w-1) && between(top_left_y, 0, i_h-1)) + if (between(top_left_x + 1, lower_bound, i_w-1) && + between(top_left_y, lower_bound, i_h-1)) top_right_v = *(data + data_index + 1); - if (between(top_left_x, 0, i_w-1) && between(top_left_y + 1, 0, i_h-1)) + if (between(top_left_x, lower_bound, i_w-1) && + between(top_left_y + 1, lower_bound, i_h-1)) bottom_left_v = *(data + data_index + i_w); - if (between(top_left_x+1, 0, i_w-1) && between(top_left_y + 1, 0, i_h-1)) + if (between(top_left_x+1, lower_bound, i_w-1) && + between(top_left_y + 1, lower_bound, i_h-1)) bottom_right_v = *(data + data_index + i_w + 1); *(out+out_index) = top_left_v * top_left_y_w * top_left_x_w + top_right_v * top_left_y_w * (1.0 - top_left_x_w) + @@ -88,9 +94,9 @@ inline void BilinearSamplingBackward(const Tensor &input_grad, DType *grid_src = grid_src_data.dptr_; const DType *grad = output_grad.dptr_; const DType *data = input_data.dptr_; - const int o_n = output_grad.size(0), o_c = output_grad.size(1), + const index_t o_n = output_grad.size(0), o_c = output_grad.size(1), o_h = output_grad.size(2), o_w = output_grad.size(3); - const int i_c = input_data.size(1), i_h = input_data.size(2), i_w = input_data.size(3); + const index_t i_c = input_data.size(1), i_h = input_data.size(2), i_w = input_data.size(3); for (index_t n = 0; n < static_cast(o_n); ++n) { for (index_t h = 0; h < static_cast(o_h); ++h) { for (index_t w = 0; w < static_cast(o_w); ++w) { @@ -99,34 +105,39 @@ inline void BilinearSamplingBackward(const Tensor &input_grad, const index_t grid_src_index = n * o_h * o_w * 2 + h * o_w + w; const DType y_real = (*(grid_src + grid_src_index + o_h * o_w) + 1) * (i_h - 1) / 2; const DType x_real = (*(grid_src + grid_src_index) + 1) * (i_w - 1) / 2; - const auto top_left_y = static_cast(std::floor(y_real)); - const auto top_left_x = static_cast(std::floor(x_real)); + const auto top_left_y = static_cast(std::floor(y_real)); + const auto top_left_x = static_cast(std::floor(x_real)); const DType top_left_y_w = 1.0 - (y_real - top_left_y); const DType top_left_x_w = 1.0 - (x_real - top_left_x); for (index_t c = 0; c < static_cast(o_c); ++c) { index_t grad_index = n * o_c * o_h * o_w + c * o_h * o_w + h * o_w + w; - const int data_index = n * i_c * i_h * i_w + c * i_h * i_w + + const index_t data_index = n * i_c * i_h * i_w + c * i_h * i_w + top_left_y * i_w + top_left_x; // calc 4 vertex value in input data DType top_left_v = 0; DType top_right_v = 0; DType bottom_left_v = 0; DType bottom_right_v = 0; - if (between(top_left_x, 0, i_w-1) && between(top_left_y, 0, i_h-1)) { + index_t lower_bound = 0; + if (between(top_left_x, lower_bound, i_w-1) && + between(top_left_y, lower_bound, i_h-1)) { *(g_input + data_index) += *(grad + grad_index) * top_left_y_w * top_left_x_w; top_left_v = *(data + data_index); } - if (between(top_left_x+1, 0, i_w-1) && between(top_left_y, 0, i_h-1)) { + if (between(top_left_x+1, lower_bound, i_w-1) && + between(top_left_y, lower_bound, i_h-1)) { *(g_input + data_index + 1) += *(grad + grad_index) * top_left_y_w * (1.0 - top_left_x_w); top_right_v = *(data + data_index + 1); } - if (between(top_left_x, 0, i_w-1) && between(top_left_y+1, 0, i_h-1)) { + if (between(top_left_x, lower_bound, i_w-1) && + between(top_left_y+1, lower_bound, i_h-1)) { *(g_input + data_index+ i_w) += *(grad + grad_index) * (1.0 - top_left_y_w) * top_left_x_w; bottom_left_v = *(data + data_index + i_w); } - if (between(top_left_x+1, 0, i_w-1) && between(top_left_y+1, 0, i_h-1)) { + if (between(top_left_x+1, lower_bound, i_w-1) && + between(top_left_y+1, lower_bound, i_h-1)) { *(g_input + data_index+ i_w + 1) += *(grad + grad_index) * (1.0 - top_left_y_w) * (1.0 - top_left_x_w); bottom_right_v = *(data + data_index + i_w + 1); diff --git a/tests/nightly/test_large_array.py b/tests/nightly/test_large_array.py index 8b36d09cbaf8..0dfeda47385f 100644 --- a/tests/nightly/test_large_array.py +++ b/tests/nightly/test_large_array.py @@ -468,6 +468,7 @@ def check_col2im(): assert res.shape[2] == 2 assert res.shape[3] == 2 assert res.shape[4] == 1 + def check_embedding(): data = nd.random_normal(shape=(LARGE_TENSOR_SHAPE, 1)) weight = nd.random_normal(shape=(LARGE_TENSOR_SHAPE, 1)) @@ -479,6 +480,21 @@ def check_embedding(): assert out.shape[0] == LARGE_TENSOR_SHAPE assert out.shape[1] == 1 assert out.shape[2] == 1 + + def check_spatial_transformer(): + data = nd.random_normal(shape=(2, 2**29, 1, 6)) + loc = nd.random_normal(shape=(2, 6)) + transform_type = 'affine' + sampler_type = 'bilinear' + target_shape = (2, 6) + + res = nd.SpatialTransformer(data=data, loc=loc, transform_type=transform_type, + sampler_type=sampler_type, target_shape=target_shape) + + assert res.shape[0] == 2 + assert res.shape[1] == 536870912 + assert res.shape[2] == 2 + assert res.shape[3] == 6 check_gluon_embedding() check_fully_connected() @@ -501,6 +517,7 @@ def check_embedding(): check_instance_norm() check_col2im() check_embedding() + check_spatial_transformer() def test_tensor():