ROI Pooling Layer #585

Closed
wants to merge 5 commits into from
42 changes: 42 additions & 0 deletions setup.py
@@ -6,6 +6,10 @@
import sys
from setuptools import setup, find_packages
from pkg_resources import get_distribution, DistributionNotFound
import glob

import torch
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME


def read(*names, **kwargs):
@@ -49,6 +53,41 @@ def find_version(*file_paths):
tqdm_ver = ' == 4.19.9' if sys.version_info[0] < 3 else ''
requirements.append('tqdm' + tqdm_ver)


def get_extensions():
this_dir = os.path.dirname(os.path.abspath(__file__))
extensions_dir = os.path.join(this_dir, 'torchvision', 'csrc')

main_file = glob.glob(os.path.join(extensions_dir, '*.cpp'))
source_cpu = glob.glob(os.path.join(extensions_dir, 'cpu', '*.cpp'))
source_cuda = glob.glob(os.path.join(extensions_dir, 'cuda', '*.cu'))

sources = main_file + source_cpu
extension = CppExtension

extra_cflags = []
define_macros = []

if torch.cuda.is_available() and CUDA_HOME is not None:
extension = CUDAExtension
sources += source_cuda
define_macros += [('WITH_CUDA', None)]

sources = [os.path.join(extensions_dir, s) for s in sources]

include_dirs = [extensions_dir]

ext_modules = [
extension(
'torchvision._C',
sources,
include_dirs=include_dirs,
define_macros=define_macros
)
]

return ext_modules

setup(
# Metadata
name='torchvision',
@@ -65,4 +104,7 @@ def find_version(*file_paths):

zip_safe=True,
install_requires=requirements,

ext_modules=get_extensions(),
    cmdclass={'build_ext': BuildExtension}
)
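
For reviewers trying the change locally, here is a minimal smoke test (an illustrative sketch, not part of this diff): after the extension has been built, e.g. by a development install that runs the BuildExtension hook above, the compiled torchvision._C module should load and the new layer should run end to end. The ROI format follows the tests below: (batch_idx, x1, y1, x2, y2).

import torch
from torchvision import layers  # requires the compiled torchvision._C extension

x = torch.rand(1, 1, 10, 10)
# one ROI on image 0 covering the top-left 5x5 patch: (batch_idx, x1, y1, x2, y2)
rois = torch.tensor([[0, 0, 0, 4, 4]], dtype=torch.float32)
roi_pool = layers.ROIPool((5, 5), 1)  # output size (5, 5), spatial_scale 1
print(roi_pool(x, rois).shape)  # expected: torch.Size([1, 1, 5, 5])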
190 changes: 190 additions & 0 deletions test/test_layers.py
@@ -0,0 +1,190 @@
import torch

from torchvision import layers


import unittest


class ROIPoolTester(unittest.TestCase):

def test_roi_pool_basic_cpu(self):
dtype = torch.float32
device = torch.device('cpu')
x = torch.rand(1, 1, 10, 10, dtype=dtype, device=device)
        rois = torch.tensor([[0, 0, 0, 4, 4]],  # format is (batch_idx, x1, y1, x2, y2)
dtype=dtype, device=device)

pool_h, pool_w = (5, 5)
roi_pool = layers.ROIPool((pool_h, pool_w), 1)
y = roi_pool(x, rois)

gt_y = torch.zeros(rois.size(0), x.size(1), pool_h, pool_w)

for n in range(0, gt_y.size(0)):
start_h, end_h = int(rois[n, 2].item()), int(rois[n, 4].item()) + 1
start_w, end_w = int(rois[n, 1].item()), int(rois[n, 3].item()) + 1
roi_x = x[:, :, start_h:end_h, start_w:end_w]
bin_h, bin_w = roi_x.size(2) // pool_h, roi_x.size(3) // pool_w
for j in range(0, pool_h):
for i in range(0, pool_w):
gt_y[n, :, j, i] = torch.max(roi_x[:, :, j * bin_h:(j + 1) * bin_h, i * bin_w:(i + 1) * bin_w])

assert torch.equal(gt_y, y), 'ROIPool layer incorrect'

def test_roi_pool_cpu(self):
dtype = torch.float32
device = torch.device('cpu')
x = torch.rand(2, 1, 10, 10, dtype=dtype, device=device)
        rois = torch.tensor([[0, 0, 0, 9, 9],  # format is (batch_idx, x1, y1, x2, y2)
[0, 0, 5, 4, 9],
[0, 5, 5, 9, 9],
[1, 0, 0, 9, 9]],
dtype=dtype, device=device)

pool_h, pool_w = (5, 5)
roi_pool = layers.ROIPool((pool_h, pool_w), 1)
y = roi_pool(x, rois)

gt_y = torch.zeros(rois.size(0), x.size(1), pool_h, pool_w, device=device)
for n in range(0, gt_y.size(0)):
for r, roi in enumerate(rois):
if roi[0] == n:
start_h, end_h = int(roi[2].item()), int(roi[4].item()) + 1
start_w, end_w = int(roi[1].item()), int(roi[3].item()) + 1
roi_x = x[roi[0].long():roi[0].long() + 1, :, start_h:end_h, start_w:end_w]
bin_h, bin_w = roi_x.size(2) // pool_h, roi_x.size(3) // pool_w
for j in range(0, pool_h):
for i in range(0, pool_w):
gt_y[r, :, j, i] = torch.max(gt_y[r, :, j, i],
torch.max(roi_x[:, :,
j * bin_h:(j + 1) * bin_h,
i * bin_w:(i + 1) * bin_w])
)

assert torch.equal(gt_y, y), 'ROIPool layer incorrect'

def test_roi_pool_gradient_cpu(self):
dtype = torch.float32
device = torch.device('cpu')
layer = layers.ROIPool((5, 5), 1).to(dtype=dtype, device=device)
x = torch.ones(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=True)
rois = torch.tensor([
[0, 0, 0, 9, 9],
[0, 0, 5, 4, 9],
[0, 0, 0, 4, 4]],
dtype=dtype, device=device)

y = layer(x, rois)
s = y.sum()
s.backward()

gt_grad = torch.tensor([[[[2., 1., 2., 1., 2., 0., 1., 0., 1., 0.],
[1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
[2., 1., 2., 1., 2., 0., 1., 0., 1., 0.],
[1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
[2., 1., 2., 1., 2., 0., 1., 0., 1., 0.],
[1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
[2., 1., 2., 1., 2., 0., 1., 0., 1., 0.],
[1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
[2., 1., 2., 1., 2., 0., 1., 0., 1., 0.],
[1., 1., 1., 1., 1., 0., 0., 0., 0., 0.]]]], device=device)

assert torch.equal(x.grad, gt_grad), 'gradient incorrect for roi_pool'

@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
def test_roi_pool_basic_gpu(self):
dtype = torch.float32
device = torch.device('cuda')
x = torch.rand(1, 1, 10, 10, dtype=dtype, device=device)
        rois = torch.tensor([[0, 0, 0, 4, 4]],  # format is (batch_idx, x1, y1, x2, y2)
dtype=dtype, device=device)

pool_h, pool_w = (5, 5)
roi_pool = layers.ROIPool((pool_h, pool_w), 1)
y = roi_pool(x, rois)

gt_y = torch.zeros(rois.size(0), x.size(1), pool_h, pool_w)

for n in range(0, gt_y.size(0)):
start_h, end_h = int(rois[n, 2].item()), int(rois[n, 4].item()) + 1
start_w, end_w = int(rois[n, 1].item()), int(rois[n, 3].item()) + 1
roi_x = x[:, :, start_h:end_h, start_w:end_w]
bin_h, bin_w = roi_x.size(2) // pool_h, roi_x.size(3) // pool_w
for j in range(0, pool_h):
for i in range(0, pool_w):
gt_y[n, :, j, i] = torch.max(roi_x[:, :, j * bin_h:(j + 1) * bin_h, i * bin_w:(i + 1) * bin_w])

assert torch.equal(gt_y.cuda(), y), 'ROIPool layer incorrect'

@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
def test_roi_pool_gpu(self):
dtype = torch.float32
        device = torch.device('cuda')
x = torch.rand(2, 1, 10, 10, dtype=dtype, device=device)
        rois = torch.tensor([[0, 0, 0, 9, 9],  # format is (batch_idx, x1, y1, x2, y2)
[0, 0, 5, 4, 9],
[0, 5, 5, 9, 9],
[1, 0, 0, 9, 9]],
dtype=dtype, device=device)

pool_h, pool_w = (5, 5)
roi_pool = layers.ROIPool((pool_h, pool_w), 1)
y = roi_pool(x, rois)

gt_y = torch.zeros(rois.size(0), x.size(1), pool_h, pool_w, device=device)
for n in range(0, gt_y.size(0)):
for r, roi in enumerate(rois):
if roi[0] == n:
start_h, end_h = int(roi[2].item()), int(roi[4].item()) + 1
start_w, end_w = int(roi[1].item()), int(roi[3].item()) + 1
roi_x = x[roi[0].long():roi[0].long() + 1, :, start_h:end_h, start_w:end_w]
bin_h, bin_w = roi_x.size(2) // pool_h, roi_x.size(3) // pool_w
for j in range(0, pool_h):
for i in range(0, pool_w):
gt_y[r, :, j, i] = torch.max(gt_y[r, :, j, i],
torch.max(roi_x[:, :,
j * bin_h:(j + 1) * bin_h,
i * bin_w:(i + 1) * bin_w])
)

assert torch.equal(gt_y.cuda(), y), 'ROIPool layer incorrect'

@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
def test_roi_pool_gradient_gpu(self):
dtype = torch.float32
device = torch.device('cuda')
layer = layers.ROIPool((5, 5), 1).to(dtype=dtype, device=device)
x = torch.ones(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=True)
rois = torch.tensor([
[0, 0, 0, 9, 9],
[0, 0, 5, 4, 9],
[0, 0, 0, 4, 4]],
dtype=dtype, device=device)

        y = layer(x, rois)
s = y.sum()
s.backward()
gt_grad = torch.tensor([[[[2., 1., 2., 1., 2., 0., 1., 0., 1., 0.],
[1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
[2., 1., 2., 1., 2., 0., 1., 0., 1., 0.],
[1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
[2., 1., 2., 1., 2., 0., 1., 0., 1., 0.],
[1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
[2., 1., 2., 1., 2., 0., 1., 0., 1., 0.],
[1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
[2., 1., 2., 1., 2., 0., 1., 0., 1., 0.],
[1., 1., 1., 1., 1., 0., 0., 0., 0., 0.]]]], device=device)

assert torch.equal(x.grad, gt_grad), 'gradient incorrect for roi_pool'


if __name__ == '__main__':
unittest.main()
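
The torchvision.layers Python wrapper exercised by these tests is not shown in this excerpt. A minimal sketch of how such a wrapper could bridge to the C++ entry points declared in ROIPool.h below, assuming the compiled torchvision._C module exposes them as roi_pool_forward and roi_pool_backward (the binding names are an assumption, not taken from this diff):

import torch
from torch import nn
from torch.autograd import Function

from torchvision import _C  # compiled extension built by setup.py above


class _ROIPoolFunction(Function):
    # Hypothetical autograd bridge; the PR's actual wrapper may differ.
    @staticmethod
    def forward(ctx, input, rois, output_size, spatial_scale):
        output, argmax = _C.roi_pool_forward(input, rois, spatial_scale,
                                             output_size[0], output_size[1])
        ctx.save_for_backward(rois, argmax)
        ctx.input_shape = input.shape
        ctx.output_size = output_size
        ctx.spatial_scale = spatial_scale
        return output

    @staticmethod
    def backward(ctx, grad_output):
        rois, argmax = ctx.saved_tensors
        b, c, h, w = ctx.input_shape
        grad_input = _C.roi_pool_backward(grad_output, rois, argmax,
                                          ctx.spatial_scale,
                                          ctx.output_size[0], ctx.output_size[1],
                                          b, c, h, w)
        # no gradients for rois, output_size, spatial_scale
        return grad_input, None, None, None


class ROIPool(nn.Module):
    def __init__(self, output_size, spatial_scale):
        super(ROIPool, self).__init__()
        self.output_size = output_size
        self.spatial_scale = spatial_scale

    def forward(self, input, rois):
        return _ROIPoolFunction.apply(input, rois, self.output_size, self.spatial_scale)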
46 changes: 46 additions & 0 deletions torchvision/csrc/ROIPool.h
@@ -0,0 +1,46 @@
#pragma once

#include "cpu/vision.h"

#ifdef WITH_CUDA
#include "cuda/vision.h"
#endif

std::tuple<at::Tensor, at::Tensor> ROIPool_forward(const at::Tensor &input,
const at::Tensor &rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width)
{
if (input.type().is_cuda())
{
#ifdef WITH_CUDA
return ROIPool_forward_cuda(input, rois, spatial_scale, pooled_height, pooled_width);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
return ROIPool_forward_cpu(input, rois, spatial_scale, pooled_height, pooled_width);
}

at::Tensor ROIPool_backward(const at::Tensor &grad,
const at::Tensor &rois,
const at::Tensor &argmax,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int batch_size,
const int channels,
const int height,
const int width)
{
if (grad.type().is_cuda())
{
#ifdef WITH_CUDA
return ROIPool_backward_cuda(grad, rois, argmax, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
return ROIPool_backward_cpu(grad, rois, argmax, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width);
}
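
The CPU and CUDA kernels behind ROIPool_forward are not part of this header. As a reading aid, here is a plain-Python reference of the pooling semantics the tests above assume (spatial_scale of 1 and integer ROI coordinates): each ROI is cropped from its image, split into pool_h x pool_w bins, and each output cell takes the per-channel max of its bin. This is an illustrative sketch; it ignores the kernels' exact coordinate rounding and the argmax bookkeeping used for backward.

import torch


def roi_pool_reference(x, rois, pool_h, pool_w):
    # x: (N, C, H, W); rois: (R, 5) rows of (batch_idx, x1, y1, x2, y2); spatial_scale == 1
    out = x.new_zeros(rois.size(0), x.size(1), pool_h, pool_w)
    for r, roi in enumerate(rois):
        b = int(roi[0])
        x1, y1, x2, y2 = (int(v) for v in roi[1:])
        patch = x[b, :, y1:y2 + 1, x1:x2 + 1]  # (C, roi_h, roi_w)
        bin_h, bin_w = patch.size(1) // pool_h, patch.size(2) // pool_w
        for j in range(pool_h):
            for i in range(pool_w):
                cell = patch[:, j * bin_h:(j + 1) * bin_h, i * bin_w:(i + 1) * bin_w]
                out[r, :, j, i] = cell.reshape(x.size(1), -1).max(dim=1)[0]
    return out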