Skip to content

Support for ROI Pooling #592

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 3, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from setuptools import setup, find_packages
import glob

import torch.cuda
import torch
from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME


Expand Down Expand Up @@ -51,12 +51,15 @@ def get_extensions():
sources = main_file + source_cpu
extension = CppExtension

extra_compile_args = {'cxx': []}
define_macros = []

if torch.cuda.is_available() and CUDA_HOME is not None:
extension = CUDAExtension
sources += source_cuda
define_macros += [('WITH_CUDA', None)]
extra_compile_args['nvcc'] = ['-DCUDA_HAS_FP16=1', '-D__CUDA_NO_HALF_OPERATORS__',
'-D__CUDA_NO_HALF_CONVERSIONS__', '-D__CUDA_NO_HALF2_OPERATORS__']

sources = [os.path.join(extensions_dir, s) for s in sources]

Expand All @@ -67,7 +70,8 @@ def get_extensions():
'torchvision._C',
sources,
include_dirs=include_dirs,
define_macros=define_macros
define_macros=define_macros,
extra_compile_args=extra_compile_args,
)
]

Expand Down
209 changes: 164 additions & 45 deletions test/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,65 +7,184 @@
import unittest


class Tester(unittest.TestCase):
class ROIPoolTester(unittest.TestCase):

def test_roi_align(self):
outputs = []
def test_roi_pool_basic_cpu(self):
dtype = torch.float32
x = torch.rand(1, 1, 10, 10, dtype=dtype)
rois = torch.tensor([
[0, 0, 0, 10, 10],
[0, 0, 5, 5, 10],
[0, 5, 5, 10, 10]], dtype=dtype)
device = torch.device('cpu')
x = torch.rand(1, 1, 10, 10, dtype=dtype, device=device)
rois = torch.tensor([[0, 0, 0, 4, 4]], # format is (xyxy)
dtype=dtype, device=device)

pool_h, pool_w = (5, 5)
roi_pool = layers.ROIPool((pool_h, pool_w), 1)
y = roi_pool(x, rois)

gt_y = torch.zeros(rois.size(0), x.size(1), pool_h, pool_w)

for device in ['cpu', 'cuda']:
device = torch.device(device)
x_n = x.to(device)
rois_n = rois.to(device)
output = layers.roi_align(x_n, rois_n, (5, 5), 0.5, 1).to('cpu')
outputs.append(output)
for n in range(0, gt_y.size(0)):
start_h, end_h = int(rois[n, 2].item()), int(rois[n, 4].item()) + 1
start_w, end_w = int(rois[n, 1].item()), int(rois[n, 3].item()) + 1
roi_x = x[:, :, start_h:end_h, start_w:end_w]
bin_h, bin_w = roi_x.size(2) // pool_h, roi_x.size(3) // pool_w
for j in range(0, pool_h):
for i in range(0, pool_w):
gt_y[n, :, j, i] = torch.max(roi_x[:, :, j * bin_h:(j + 1) * bin_h, i * bin_w:(i + 1) * bin_w])

assert (outputs[0] - outputs[1]).abs().max() < 1e-6
assert torch.equal(gt_y, y), 'ROIPool layer incorrect'

def test_roi_align_gradient(self):
dtype = torch.float64
def test_roi_pool_cpu(self):
dtype = torch.float32
device = torch.device('cpu')
x = torch.rand(2, 1, 10, 10, dtype=dtype, device=device)
rois = torch.tensor([[0, 0, 0, 9, 9], # format is (xyxy)
[0, 0, 5, 4, 9],
[0, 5, 5, 9, 9],
[1, 0, 0, 9, 9]],
dtype=dtype, device=device)

pool_h, pool_w = (5, 5)
roi_pool = layers.ROIPool((pool_h, pool_w), 1)
y = roi_pool(x, rois)

gt_y = torch.zeros(rois.size(0), x.size(1), pool_h, pool_w, device=device)
for n in range(0, gt_y.size(0)):
for r, roi in enumerate(rois):
if roi[0] == n:
start_h, end_h = int(roi[2].item()), int(roi[4].item()) + 1
start_w, end_w = int(roi[1].item()), int(roi[3].item()) + 1
roi_x = x[roi[0].long():roi[0].long() + 1, :, start_h:end_h, start_w:end_w]
bin_h, bin_w = roi_x.size(2) // pool_h, roi_x.size(3) // pool_w
for j in range(0, pool_h):
for i in range(0, pool_w):
gt_y[r, :, j, i] = torch.max(gt_y[r, :, j, i],
torch.max(roi_x[:, :,
j * bin_h:(j + 1) * bin_h,
i * bin_w:(i + 1) * bin_w])
)

assert torch.equal(gt_y, y), 'ROIPool layer incorrect'

def test_roi_pool_gradient_cpu(self):
dtype = torch.float32
device = torch.device('cpu')
layer = layers.ROIPool((5, 5), 1).to(dtype=dtype, device=device)
x = torch.ones(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=True)
cx = torch.ones(1, 1, 10, 10, dtype=dtype, requires_grad=True).cuda()
rois = torch.tensor([
[0, 0, 0, 9, 9],
[0, 0, 5, 4, 9],
[0, 0, 0, 4, 4]],
dtype=dtype, device=device)

y = layer(x, rois)
s = y.sum()
s.backward()

gt_grad = torch.tensor([[[[2., 1., 2., 1., 2., 0., 1., 0., 1., 0.],
[1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
[2., 1., 2., 1., 2., 0., 1., 0., 1., 0.],
[1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
[2., 1., 2., 1., 2., 0., 1., 0., 1., 0.],
[1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
[2., 1., 2., 1., 2., 0., 1., 0., 1., 0.],
[1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
[2., 1., 2., 1., 2., 0., 1., 0., 1., 0.],
[1., 1., 1., 1., 1., 0., 0., 0., 0., 0.]]]], device=device)

assert torch.equal(x.grad, gt_grad), 'gradient incorrect for roi_pool'

@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
def test_roi_pool_basic_gpu(self):
dtype = torch.float32
device = torch.device('cuda')
m = layers.ROIAlign((5, 5), 0.5, 1).to(dtype=dtype, device=device)
x = torch.rand(1, 1, 10, 10, dtype=dtype, device=device)
rois = torch.tensor([
[0, 0, 0, 10, 10],
[0, 0, 5, 5, 10],
[0, 5, 5, 10, 10]], dtype=dtype, device=device)
rois = torch.tensor([[0, 0, 0, 4, 4]], # format is (xyxy)
dtype=dtype, device=device)

def func(input):
return m(input, rois)
pool_h, pool_w = (5, 5)
roi_pool = layers.ROIPool((pool_h, pool_w), 1)
y = roi_pool(x, rois)

gt_y = torch.zeros(rois.size(0), x.size(1), pool_h, pool_w)

for n in range(0, gt_y.size(0)):
start_h, end_h = int(rois[n, 2].item()), int(rois[n, 4].item()) + 1
start_w, end_w = int(rois[n, 1].item()), int(rois[n, 3].item()) + 1
roi_x = x[:, :, start_h:end_h, start_w:end_w]
bin_h, bin_w = roi_x.size(2) // pool_h, roi_x.size(3) // pool_w
for j in range(0, pool_h):
for i in range(0, pool_w):
gt_y[n, :, j, i] = torch.max(roi_x[:, :, j * bin_h:(j + 1) * bin_h, i * bin_w:(i + 1) * bin_w])

assert gradcheck(func, (x,)), 'gradcheck failed for roi_align'
assert torch.equal(gt_y.cuda(), y), 'ROIPool layer incorrect'

def test_roi_pool_gradient(self):
dtype = torch.float64
@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
def test_roi_pool_gpu(self):
dtype = torch.float32
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
x = torch.rand(2, 1, 10, 10, dtype=dtype, device=device)
rois = torch.tensor([[0, 0, 0, 9, 9], # format is (xyxy)
[0, 0, 5, 4, 9],
[0, 5, 5, 9, 9],
[1, 0, 0, 9, 9]],
dtype=dtype, device=device)

pool_h, pool_w = (5, 5)
roi_pool = layers.ROIPool((pool_h, pool_w), 1)
y = roi_pool(x, rois)

gt_y = torch.zeros(rois.size(0), x.size(1), pool_h, pool_w, device=device)
for n in range(0, gt_y.size(0)):
for r, roi in enumerate(rois):
if roi[0] == n:
start_h, end_h = int(roi[2].item()), int(roi[4].item()) + 1
start_w, end_w = int(roi[1].item()), int(roi[3].item()) + 1
roi_x = x[roi[0].long():roi[0].long() + 1, :, start_h:end_h, start_w:end_w]
bin_h, bin_w = roi_x.size(2) // pool_h, roi_x.size(3) // pool_w
for j in range(0, pool_h):
for i in range(0, pool_w):
gt_y[r, :, j, i] = torch.max(gt_y[r, :, j, i],
torch.max(roi_x[:, :,
j * bin_h:(j + 1) * bin_h,
i * bin_w:(i + 1) * bin_w])
)

assert torch.equal(gt_y.cuda(), y), 'ROIPool layer incorrect'

@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
def test_roi_pool_gradient_gpu(self):
dtype = torch.float32
device = torch.device('cuda')
m = layers.ROIPool((5, 5), 0.5).to(dtype=dtype, device=device)
x = torch.rand(1, 1, 10, 10, dtype=dtype, device=device)
layer = layers.ROIPool((5, 5), 1).to(dtype=dtype, device=device)
x = torch.ones(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=True)
rois = torch.tensor([
[0, 0, 0, 10, 10],
[0, 0, 5, 5, 10],
[0, 5, 5, 10, 10]], dtype=dtype, device=device)
[0, 0, 0, 9, 9],
[0, 0, 5, 4, 9],
[0, 0, 0, 4, 4]],
dtype=dtype, device=device)

def func(input):
return m(input, rois)

assert gradcheck(func, (x,)), 'gradcheck failed for roi_pool'

def test_nms(self):
boxes = torch.tensor([
[0, 0, 100, 100],
[2, 2, 98, 98],
[50, 50, 200, 200],
[50, 50, 200, 200]], dtype=torch.float32)
scores = torch.tensor([1, 2, 0.5, 1], dtype=torch.float32)
keep = layers.nms(boxes, scores, 0.5)
assert keep.tolist() == [1, 3]
return layer(input, rois)

x.requires_grad = True
y = layer(x, rois)
# print(argmax, argmax.shape)
s = y.sum()
s.backward()
gt_grad = torch.tensor([[[[2., 1., 2., 1., 2., 0., 1., 0., 1., 0.],
[1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
[2., 1., 2., 1., 2., 0., 1., 0., 1., 0.],
[1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
[2., 1., 2., 1., 2., 0., 1., 0., 1., 0.],
[1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
[2., 1., 2., 1., 2., 0., 1., 0., 1., 0.],
[1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
[2., 1., 2., 1., 2., 0., 1., 0., 1., 0.],
[1., 1., 1., 1., 1., 0., 0., 0., 0., 0.]]]], device=device)

assert torch.equal(x.grad, gt_grad), 'gradient incorrect for roi_pool'


if __name__ == '__main__':
unittest.main()
61 changes: 30 additions & 31 deletions torchvision/csrc/ROIPool.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,42 +6,41 @@
#include "cuda/vision.h"
#endif


std::tuple<at::Tensor, at::Tensor> ROIPool_forward(const at::Tensor& input,
const at::Tensor& rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width) {
if (input.type().is_cuda()) {
std::tuple<at::Tensor, at::Tensor> ROIPool_forward(const at::Tensor &input,
const at::Tensor &rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width)
{
if (input.type().is_cuda())
{
#ifdef WITH_CUDA
return ROIPool_forward_cuda(input, rois, spatial_scale, pooled_height, pooled_width);
return ROIPool_forward_cuda(input, rois, spatial_scale, pooled_height, pooled_width);
#else
AT_ERROR("Not compiled with GPU support");
AT_ERROR("Not compiled with GPU support");
#endif
}
AT_ERROR("Not implemented on the CPU");
}
return ROIPool_forward_cpu(input, rois, spatial_scale, pooled_height, pooled_width);
}

at::Tensor ROIPool_backward(const at::Tensor& grad,
const at::Tensor& input,
const at::Tensor& rois,
const at::Tensor& argmax,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int batch_size,
const int channels,
const int height,
const int width) {
if (grad.type().is_cuda()) {
at::Tensor ROIPool_backward(const at::Tensor &grad,
const at::Tensor &rois,
const at::Tensor &argmax,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int batch_size,
const int channels,
const int height,
const int width)
{
if (grad.type().is_cuda())
{
#ifdef WITH_CUDA
return ROIPool_backward_cuda(grad, input, rois, argmax, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width);
return ROIPool_backward_cuda(grad, rois, argmax, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width);
#else
AT_ERROR("Not compiled with GPU support");
AT_ERROR("Not compiled with GPU support");
#endif
}
AT_ERROR("Not implemented on the CPU");
}



}
return ROIPool_backward_cpu(grad, rois, argmax, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width);
}
Loading