Skip to content

Adding ROIAlign backwards for CPU #504

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion torchvision/csrc/ROIAlign.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,6 @@ at::Tensor ROIAlign_backward(const at::Tensor& grad,
AT_ERROR("Not compiled with GPU support");
#endif
}
AT_ERROR("Not implemented on the CPU");
return ROIAlign_backward_cpu(grad, rois, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width, sampling_ratio);
}

206 changes: 203 additions & 3 deletions torchvision/csrc/cpu/ROIAlign_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ void ROIAlignForward_cpu_kernel(
const T* bottom_rois,
//int roi_cols,
T* top_data) {
//AT_ASSERT(roi_cols == 4 || roi_cols == 5);
//AT_CHECK(roi_cols == 4 || roi_cols == 5);
int roi_cols = 5;

int n_rois = nthreads / channels / pooled_width / pooled_height;
Expand Down Expand Up @@ -217,14 +217,182 @@ void ROIAlignForward_cpu_kernel(
} // for n
}



// Accumulate `val` into the location pointed to by `address`.
// (Single-threaded CPU path, so no atomic operation is required.)
template <class T>
inline void add(const T& val, T* address) {
  *address = *address + val;
}

// Computes the four bilinear-interpolation weights (w1..w4) and the integer
// corner coordinates (x_low/x_high, y_low/y_high) for a sample point (x, y).
// These are the partial derivatives of the interpolated value with respect to
// the four surrounding feature-map values used by the forward pass.
template <typename T>
void bilinear_interpolate_gradient(
    const int height,
    const int width,
    T y,
    T x,
    T& w1,
    T& w2,
    T& w3,
    T& w4,
    int& x_low,
    int& x_high,
    int& y_low,
    int& y_high,
    const int /*index*/ /* index for debug only*/) {
  // Sample points outside the (slightly padded) feature-map boundary
  // contribute nothing: report zero weights and -1 corner indices.
  const bool outside =
      (y < -1.0) || (y > height) || (x < -1.0) || (x > width);
  if (outside) {
    w1 = w2 = w3 = w4 = 0.;
    x_low = x_high = y_low = y_high = -1;
    return;
  }

  // Coordinates in [-1, 0) are clamped onto the first row/column.
  y = std::max(y, (T)0);
  x = std::max(x, (T)0);

  y_low = (int)y;
  x_low = (int)x;

  // Snap to the last row/column when the low index reaches the border,
  // so the "high" neighbour never indexes past the feature map.
  if (y_low < height - 1) {
    y_high = y_low + 1;
  } else {
    y_high = y_low = height - 1;
    y = (T)y_low;
  }

  if (x_low < width - 1) {
    x_high = x_low + 1;
  } else {
    x_high = x_low = width - 1;
    x = (T)x_low;
  }

  // Fractional offsets inside the cell and their complements.
  const T ly = y - y_low;
  const T lx = x - x_low;
  const T hy = 1. - ly;
  const T hx = 1. - lx;

  // Weights follow the forward pass:
  //   val = w1*v[y_low][x_low] + w2*v[y_low][x_high]
  //       + w3*v[y_high][x_low] + w4*v[y_high][x_high]
  w1 = hy * hx;
  w2 = hy * lx;
  w3 = ly * hx;
  w4 = ly * lx;
}

// Backward pass of ROIAlign on CPU: scatters the gradient of each pooled
// output element back onto the input feature map through the bilinear
// interpolation weights used in the forward pass.
//
// nthreads       - total number of pooled output elements (n*c*ph*pw)
// top_diff       - gradient w.r.t. the pooled output, shape (num_rois, channels,
//                  pooled_height, pooled_width)
// bottom_diff    - gradient w.r.t. the input feature map, accumulated in place;
//                  caller must pre-zero it
// bottom_rois    - (num_rois, 5) rows of [batch_index, x1, y1, x2, y2]
template <typename T>
void ROIAlignBackwardFeature(
    const int nthreads,
    const T* top_diff,
    const int num_rois,
    const T& spatial_scale,
    const int channels,
    const int height,
    const int width,
    const int pooled_height,
    const int pooled_width,
    const int sampling_ratio,
    T* bottom_diff,
    const T* bottom_rois) {
  (void)num_rois; // kept for signature parity with the CUDA kernel

  for (int index = 0; index < nthreads; index++) {
    // (n, c, ph, pw) is an element in the pooled output
    int pw = index % pooled_width;
    int ph = (index / pooled_width) % pooled_height;
    int c = (index / pooled_width / pooled_height) % channels;
    int n = index / pooled_width / pooled_height / channels;

    const T* offset_bottom_rois = bottom_rois + n * 5;
    int roi_batch_ind = offset_bottom_rois[0];

    // Do not use rounding; this implementation detail is critical.
    // BUGFIX: column 0 is the batch index — box coordinates are columns 1..4.
    T roi_start_w = offset_bottom_rois[1] * spatial_scale;
    T roi_start_h = offset_bottom_rois[2] * spatial_scale;
    T roi_end_w = offset_bottom_rois[3] * spatial_scale;
    T roi_end_h = offset_bottom_rois[4] * spatial_scale;

    // Force malformed ROIs to be 1x1
    T roi_width = std::max(roi_end_w - roi_start_w, (T)1.);
    T roi_height = std::max(roi_end_h - roi_start_h, (T)1.);
    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);

    T* offset_bottom_diff =
        bottom_diff + (roi_batch_ind * channels + c) * height * width;

    // BUGFIX: top_diff is laid out with pooled spatial dimensions,
    // not the input feature-map dimensions.
    int top_offset = (n * channels + c) * pooled_height * pooled_width;

    const T* offset_top_diff = top_diff + top_offset;
    const T top_diff_this_bin = offset_top_diff[ph * pooled_width + pw];

    // We use roi_bin_grid to sample the grid and mimic integral
    int roi_bin_grid_h = (sampling_ratio > 0)
        ? sampling_ratio
        : ceil(roi_height / pooled_height); // e.g., = 2
    int roi_bin_grid_w =
        (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
    // We do average (integral) pooling inside a bin
    const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4

    for (int iy = 0; iy < roi_bin_grid_h; iy++) {
      const T y = roi_start_h + ph * bin_size_h +
          static_cast<T>(iy + .5f) * bin_size_h /
              static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
      for (int ix = 0; ix < roi_bin_grid_w; ix++) {
        const T x = roi_start_w + pw * bin_size_w +
            static_cast<T>(ix + .5f) * bin_size_w /
                static_cast<T>(roi_bin_grid_w);

        T w1, w2, w3, w4;
        int x_low, x_high, y_low, y_high;

        bilinear_interpolate_gradient(
            height,
            width,
            y,
            x,
            w1,
            w2,
            w3,
            w4,
            x_low,
            x_high,
            y_low,
            y_high,
            index);

        // Split this bin's gradient across the four corners, averaged
        // over the sampling grid (matches forward's average pooling).
        T g1 = top_diff_this_bin * w1 / count;
        T g2 = top_diff_this_bin * w2 / count;
        T g3 = top_diff_this_bin * w3 / count;
        T g4 = top_diff_this_bin * w4 / count;

        if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
          // atomic add is not needed for now since it is single threaded
          add(static_cast<T>(g1), offset_bottom_diff + y_low * width + x_low);
          add(static_cast<T>(g2), offset_bottom_diff + y_low * width + x_high);
          add(static_cast<T>(g3), offset_bottom_diff + y_high * width + x_low);
          add(static_cast<T>(g4), offset_bottom_diff + y_high * width + x_high);
        } // if
      } // ix
    } // iy
  } // for
} // ROIAlignBackwardFeature

at::Tensor ROIAlign_forward_cpu(const at::Tensor& input,
const at::Tensor& rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int sampling_ratio) {
AT_ASSERT(!input.type().is_cuda(), "input must be a CPU tensor");
AT_ASSERT(!rois.type().is_cuda(), "rois must be a CPU tensor");
AT_CHECK(!input.type().is_cuda(), "input must be a CPU tensor");
AT_CHECK(!rois.type().is_cuda(), "rois must be a CPU tensor");

auto num_rois = rois.size(0);
auto channels = input.size(1);
Expand Down Expand Up @@ -254,3 +422,35 @@ at::Tensor ROIAlign_forward_cpu(const at::Tensor& input,
});
return output;
}

// Backward pass of ROIAlign on CPU.
// grad: gradient w.r.t. the pooled output, (num_rois, channels, pooled_h, pooled_w).
// rois: (num_rois, 5) rows of [batch_index, x1, y1, x2, y2].
// Returns the gradient w.r.t. the input feature map,
// shape (batch_size, channels, height, width).
at::Tensor ROIAlign_backward_cpu(const at::Tensor& grad,
                                 const at::Tensor& rois,
                                 const float spatial_scale,
                                 const int pooled_height,
                                 const int pooled_width,
                                 const int batch_size,
                                 const int channels,
                                 const int height,
                                 const int width,
                                 const int sampling_ratio){
  // Validate device placement, consistent with ROIAlign_forward_cpu.
  AT_CHECK(!grad.type().is_cuda(), "grad must be a CPU tensor");
  AT_CHECK(!rois.type().is_cuda(), "rois must be a CPU tensor");

  auto num_rois = rois.size(0);

  // Gradient buffer must start zeroed: the kernel accumulates into it.
  at::Tensor grad_input = grad.type().tensor({batch_size, channels, height, width}).zero_();

  AT_DISPATCH_FLOATING_TYPES(grad.type(), "ROIAlign_backward", [&]{
    ROIAlignBackwardFeature<scalar_t>(
        grad.numel(),
        grad.data<scalar_t>(),
        num_rois,
        spatial_scale,
        channels,
        height,
        width,
        pooled_height,
        pooled_width,
        sampling_ratio,
        grad_input.data<scalar_t>(),
        rois.data<scalar_t>());
  });
  return grad_input;
}
6 changes: 3 additions & 3 deletions torchvision/csrc/cpu/nms_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ template <typename scalar_t>
at::Tensor nms_cpu_kernel(const at::Tensor& dets,
const at::Tensor& scores,
const float threshold) {
AT_ASSERT(!dets.type().is_cuda(), "dets must be a CPU tensor");
AT_ASSERT(!scores.type().is_cuda(), "scores must be a CPU tensor");
AT_ASSERT(dets.type() == scores.type(), "dets should have the same type as scores");
AT_CHECK(!dets.type().is_cuda(), "dets must be a CPU tensor");
AT_CHECK(!scores.type().is_cuda(), "scores must be a CPU tensor");
AT_CHECK(dets.type() == scores.type(), "dets should have the same type as scores");

if (dets.numel() == 0)
return torch::CPU(at::kLong).tensor();
Expand Down
11 changes: 11 additions & 0 deletions torchvision/csrc/cpu/vision.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,14 @@ at::Tensor ROIAlign_forward_cpu(const at::Tensor& input,
at::Tensor nms_cpu(const at::Tensor& dets,
const at::Tensor& scores,
const float threshold);

// Backward pass of ROIAlign on CPU; returns the gradient w.r.t. the input
// feature map, shape (batch_size, channels, height, width).
// grad is the gradient w.r.t. the pooled output; rois is (num_rois, 5).
at::Tensor ROIAlign_backward_cpu(const at::Tensor& grad,
                                 const at::Tensor& rois,
                                 const float spatial_scale,
                                 const int pooled_height,
                                 const int pooled_width,
                                 const int batch_size,
                                 const int channels,
                                 const int height,
                                 const int width,
                                 const int sampling_ratio);