Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Large Tensor] Fixed Spatial Transformer op #17617

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 29 additions & 18 deletions src/operator/spatial_transformer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,9 @@ inline void BilinearSamplingForward(const Tensor<cpu, 4, DType> &output,
DType *out = output.dptr_;
const DType *data = input.dptr_;
const DType *grid = grid_src.dptr_;
const int o_n = output.size(0), o_c = output.size(1), o_h = output.size(2), o_w = output.size(3);
const int i_c = input.size(1), i_h = input.size(2), i_w = input.size(3);
const index_t o_n = output.size(0), o_c = output.size(1),
o_h = output.size(2), o_w = output.size(3);
const index_t i_c = input.size(1), i_h = input.size(2), i_w = input.size(3);
for (index_t n = 0; n < static_cast<index_t>(o_n); ++n) {
for (index_t c = 0; c < static_cast<index_t>(o_c); ++c) {
for (index_t h = 0; h < static_cast<index_t>(o_h); ++h) {
Expand All @@ -51,23 +52,28 @@ inline void BilinearSamplingForward(const Tensor<cpu, 4, DType> &output,
const index_t grid_index = n * o_h * o_w * 2 + h * o_w + w;
const DType y_real = (*(grid + grid_index + o_h * o_w) + 1) * (i_h - 1) / 2;
const DType x_real = (*(grid + grid_index) + 1) * (i_w - 1) / 2;
const auto top_left_y = static_cast<int>(std::floor(y_real));
const auto top_left_x = static_cast<int>(std::floor(x_real));
const auto top_left_y = static_cast<index_t>(std::floor(y_real));
const auto top_left_x = static_cast<index_t>(std::floor(x_real));
const DType top_left_y_w = 1.0 - (y_real - top_left_y);
const DType top_left_x_w = 1.0 - (x_real - top_left_x);
const int data_index = n * i_c * i_h * i_w + c * i_h * i_w +
const index_t data_index = n * i_c * i_h * i_w + c * i_h * i_w +
top_left_y * i_w + top_left_x;
DType top_left_v = 0;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why was this changed back to DType?
It's the index position, right?

Copy link
Contributor Author

@connorgoggins connorgoggins Feb 19, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ChaiBapchya Changing DType to index_t for these specific variables makes the op generate incorrect output on standard inputs (e.g. the inputs in the CI run): the values generated in the output NDArray are all integers instead of floats. This is because these variables do not represent index positions (as I also originally believed); they hold the underlying data values at the vertices.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool, good catch.
The `_v` suffix indicates the value at that particular index. Thanks!

DType top_right_v = 0;
DType bottom_left_v = 0;
DType bottom_right_v = 0;
if (between(top_left_x, 0, i_w-1) && between(top_left_y, 0, i_h-1))
index_t lower_bound = 0;
if (between(top_left_x, lower_bound, i_w-1) &&
between(top_left_y, lower_bound, i_h-1))
top_left_v = *(data + data_index);
if (between(top_left_x + 1, 0, i_w-1) && between(top_left_y, 0, i_h-1))
if (between(top_left_x + 1, lower_bound, i_w-1) &&
between(top_left_y, lower_bound, i_h-1))
top_right_v = *(data + data_index + 1);
if (between(top_left_x, 0, i_w-1) && between(top_left_y + 1, 0, i_h-1))
if (between(top_left_x, lower_bound, i_w-1) &&
between(top_left_y + 1, lower_bound, i_h-1))
bottom_left_v = *(data + data_index + i_w);
if (between(top_left_x+1, 0, i_w-1) && between(top_left_y + 1, 0, i_h-1))
if (between(top_left_x+1, lower_bound, i_w-1) &&
between(top_left_y + 1, lower_bound, i_h-1))
bottom_right_v = *(data + data_index + i_w + 1);
*(out+out_index) = top_left_v * top_left_y_w * top_left_x_w +
top_right_v * top_left_y_w * (1.0 - top_left_x_w) +
Expand All @@ -88,9 +94,9 @@ inline void BilinearSamplingBackward(const Tensor<cpu, 4, DType> &input_grad,
DType *grid_src = grid_src_data.dptr_;
const DType *grad = output_grad.dptr_;
const DType *data = input_data.dptr_;
const int o_n = output_grad.size(0), o_c = output_grad.size(1),
const index_t o_n = output_grad.size(0), o_c = output_grad.size(1),
o_h = output_grad.size(2), o_w = output_grad.size(3);
const int i_c = input_data.size(1), i_h = input_data.size(2), i_w = input_data.size(3);
const index_t i_c = input_data.size(1), i_h = input_data.size(2), i_w = input_data.size(3);
for (index_t n = 0; n < static_cast<index_t>(o_n); ++n) {
for (index_t h = 0; h < static_cast<index_t>(o_h); ++h) {
for (index_t w = 0; w < static_cast<index_t>(o_w); ++w) {
Expand All @@ -99,34 +105,39 @@ inline void BilinearSamplingBackward(const Tensor<cpu, 4, DType> &input_grad,
const index_t grid_src_index = n * o_h * o_w * 2 + h * o_w + w;
const DType y_real = (*(grid_src + grid_src_index + o_h * o_w) + 1) * (i_h - 1) / 2;
const DType x_real = (*(grid_src + grid_src_index) + 1) * (i_w - 1) / 2;
const auto top_left_y = static_cast<int>(std::floor(y_real));
const auto top_left_x = static_cast<int>(std::floor(x_real));
const auto top_left_y = static_cast<index_t>(std::floor(y_real));
const auto top_left_x = static_cast<index_t>(std::floor(x_real));
const DType top_left_y_w = 1.0 - (y_real - top_left_y);
const DType top_left_x_w = 1.0 - (x_real - top_left_x);
for (index_t c = 0; c < static_cast<index_t>(o_c); ++c) {
index_t grad_index = n * o_c * o_h * o_w + c * o_h * o_w + h * o_w + w;
const int data_index = n * i_c * i_h * i_w + c * i_h * i_w +
const index_t data_index = n * i_c * i_h * i_w + c * i_h * i_w +
top_left_y * i_w + top_left_x;
// calc 4 vertex value in input data
DType top_left_v = 0;
DType top_right_v = 0;
DType bottom_left_v = 0;
DType bottom_right_v = 0;
if (between(top_left_x, 0, i_w-1) && between(top_left_y, 0, i_h-1)) {
index_t lower_bound = 0;
if (between(top_left_x, lower_bound, i_w-1) &&
between(top_left_y, lower_bound, i_h-1)) {
*(g_input + data_index) += *(grad + grad_index) * top_left_y_w * top_left_x_w;
top_left_v = *(data + data_index);
}
if (between(top_left_x+1, 0, i_w-1) && between(top_left_y, 0, i_h-1)) {
if (between(top_left_x+1, lower_bound, i_w-1) &&
between(top_left_y, lower_bound, i_h-1)) {
*(g_input + data_index + 1) += *(grad + grad_index) * top_left_y_w
* (1.0 - top_left_x_w);
top_right_v = *(data + data_index + 1);
}
if (between(top_left_x, 0, i_w-1) && between(top_left_y+1, 0, i_h-1)) {
if (between(top_left_x, lower_bound, i_w-1) &&
between(top_left_y+1, lower_bound, i_h-1)) {
*(g_input + data_index+ i_w) += *(grad + grad_index) * (1.0 - top_left_y_w)
* top_left_x_w;
bottom_left_v = *(data + data_index + i_w);
}
if (between(top_left_x+1, 0, i_w-1) && between(top_left_y+1, 0, i_h-1)) {
if (between(top_left_x+1, lower_bound, i_w-1) &&
between(top_left_y+1, lower_bound, i_h-1)) {
*(g_input + data_index+ i_w + 1) += *(grad + grad_index) * (1.0 - top_left_y_w)
* (1.0 - top_left_x_w);
bottom_right_v = *(data + data_index + i_w + 1);
Expand Down
17 changes: 17 additions & 0 deletions tests/nightly/test_large_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,7 @@ def check_col2im():
assert res.shape[2] == 2
assert res.shape[3] == 2
assert res.shape[4] == 1

def check_embedding():
data = nd.random_normal(shape=(LARGE_TENSOR_SHAPE, 1))
weight = nd.random_normal(shape=(LARGE_TENSOR_SHAPE, 1))
Expand All @@ -479,6 +480,21 @@ def check_embedding():
assert out.shape[0] == LARGE_TENSOR_SHAPE
assert out.shape[1] == 1
assert out.shape[2] == 1

def check_spatial_transformer():
    """Run SpatialTransformer on a very large input and verify the output shape.

    The input has shape (2, 2**29, 1, 6), i.e. 6 * 2**30 elements, which is
    more than 2**31. Only the shape of the result is checked, not its values.
    """
    # Draw the input first and the localisation parameters second, so the
    # random generator is consumed in the same order as before.
    inputs = nd.random_normal(shape=(2, 2**29, 1, 6))
    loc_params = nd.random_normal(shape=(2, 6))

    res = nd.SpatialTransformer(
        data=inputs,
        loc=loc_params,
        transform_type='affine',
        sampler_type='bilinear',
        target_shape=(2, 6),
    )

    # Expected size of each of the first four output axes, checked in order.
    expected_dims = (2, 536870912, 2, 6)
    for axis, expected in enumerate(expected_dims):
        assert res.shape[axis] == expected

check_gluon_embedding()
check_fully_connected()
Expand All @@ -501,6 +517,7 @@ def check_embedding():
check_instance_norm()
check_col2im()
check_embedding()
check_spatial_transformer()


def test_tensor():
Expand Down