RoI Pooling Layer & Tests #587


Closed · wants to merge 41 commits

Commits
aad3b4b
Transforms Documentation changes (#464)
vishwakftw Apr 7, 2018
1a6038e
Fix for min/max returning 0dim tensors now (#469)
ssnl Apr 12, 2018
7bda0e8
Backward compat fix for Pytorch < 0.4 (#472)
fmassa Apr 16, 2018
f6ab107
Add support in transforms.ToTensor for PIL Images mod '1' (#471)
arturml Apr 16, 2018
6f7e26b
Fix Densenet module keys (#474)
ssnl Apr 21, 2018
1d0a3b1
bump up version to 0.2.1
soumith Apr 24, 2018
f87a896
fix models for PyTorch v0.4 (remove .data and add _ for the initializ…
moskomule Apr 30, 2018
4db0398
Update test_utils.py (#486)
vfdev-5 May 4, 2018
73281b4
Update resnet.py (#487)
vfdev-5 May 9, 2018
47214f0
Progress Bar for download_url (#497)
maruthgoyal May 11, 2018
8a4786a
Add functional transforms to docs (#499)
vishwakftw May 15, 2018
9f28cff
Update test_transforms.py (#500)
vfdev-5 May 15, 2018
972b80c
Load and parse metadata for CIFAR-10, CIFAR-100 (#502)
xenosoz May 17, 2018
1a47a44
Revert "Load and parse metadata for CIFAR-10, CIFAR-100 (#502)" (#506)
soumith May 17, 2018
628e90c
Add metadata to some datasets (#501)
davidlmorton May 17, 2018
2e25533
Fix documentation and add is_image_file (#507)
bstriner May 18, 2018
55d7395
Fix invalid argument error when using lsun method in windows (#508)
tanvach May 18, 2018
6cfd2ae
Fix transforms.Pad and transforms.LinearTransformation doc strings (#…
vfdev-5 May 24, 2018
11da4f1
add padding-mode choice to `RandomCrop` (#512)
jxgu1016 May 24, 2018
9aff567
Revert "add padding-mode choice to `RandomCrop` (#512)" (#515)
fmassa May 24, 2018
f27ecce
fix a bug described in issue #488 (#489)
KovenYu May 24, 2018
b1ef1fe
Normalize and pil link fix (#519)
kohr-h May 26, 2018
cf65f39
Add note on in-place nature of Normalize, closes #517 (#520)
kohr-h May 27, 2018
da67a1e
Partially revert #519 due to performance regression & other issues (#…
kohr-h May 28, 2018
5a0d079
make vision depend on pillow-simd if already installed (#522)
May 30, 2018
3f6c23c
Addresses issue #145 as per @fmessa's suggestion. (#527)
Choco31415 Jun 6, 2018
50b2f91
Fix broken progress bar (#524)
kohr-h Jun 6, 2018
0bbb1aa
fix #530. (#531)
csukuangfj Jun 14, 2018
1fb0ccf
Add progress bar based downloading to MNIST (#535)
vishwakftw Jun 25, 2018
cd1f58f
Downgrade tqdm version to 4.19 for py2.7 (#542)
vishwakftw Jul 13, 2018
f1b5907
Fix links to pillow docs (#554)
vishwakftw Jul 23, 2018
d6c7900
Use scandir in _find_classes (#557) (#559)
Jul 27, 2018
b51a2c3
ColorJitter Enhancement (#548)
yaox12 Jul 30, 2018
be68e24
Fix small Typo in RandomAffine comment (#563)
ashaw596 Jul 31, 2018
fe973ce
Removed the +1 and divide by 2 to allow for random offset padding (#564)
ryanpeach Aug 2, 2018
c74b79c
MNIST loader refactored: permanent 'data' and 'targets' fields (#578)
dizcza Aug 16, 2018
51bc7f2
ROI Pooling CPU and CUDA code as well as pytorch function and layer
varunagrawal Aug 23, 2018
17a2c93
Updated setup.py to compile Cpp and CUDA extensions for ROI Pooling
varunagrawal Aug 23, 2018
ac023f6
tests for ROI Pooling
varunagrawal Aug 23, 2018
37671c9
don't import torch.cuda explicitly
varunagrawal Aug 23, 2018
151c6ca
Add Half type support
varunagrawal Aug 30, 2018
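
This PR adds an ROI (region of interest) max-pooling layer with CPU and CUDA kernels plus Half support, exposed to Python as `layers.ROIPool`. A minimal usage sketch, based only on the API exercised by the tests in this PR (the second constructor argument is taken to be the spatial scale):

    import torch
    from torchvision import layers

    x = torch.rand(1, 256, 32, 32)            # feature map (N, C, H, W)
    rois = torch.tensor([[0, 0, 0, 15, 15]],  # rows are [batch_idx, x1, y1, x2, y2]
                        dtype=torch.float32)
    roi_pool = layers.ROIPool((7, 7), 1)      # output size, spatial scale
    y = roi_pool(x, rois)                     # -> (num_rois, C, 7, 7)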
48 changes: 31 additions & 17 deletions docs/source/transforms.rst
@@ -10,37 +10,45 @@ Transforms are common image transforms. They can be chained together using :class:`Compose`.
 Transforms on PIL Image
 -----------------------
 
-.. autoclass:: Resize
+.. autoclass:: CenterCrop
 
-.. autoclass:: Scale
+.. autoclass:: ColorJitter
 
-.. autoclass:: CenterCrop
+.. autoclass:: FiveCrop
 
-.. autoclass:: RandomCrop
+.. autoclass:: Grayscale
 
-.. autoclass:: RandomHorizontalFlip
+.. autoclass:: LinearTransformation
 
-.. autoclass:: RandomVerticalFlip
+.. autoclass:: Pad
 
-.. autoclass:: RandomResizedCrop
+.. autoclass:: RandomAffine
 
-.. autoclass:: RandomSizedCrop
+.. autoclass:: RandomApply
 
-.. autoclass:: Grayscale
+.. autoclass:: RandomChoice
+
+.. autoclass:: RandomCrop
 
 .. autoclass:: RandomGrayscale
 
-.. autoclass:: FiveCrop
+.. autoclass:: RandomHorizontalFlip
 
-.. autoclass:: TenCrop
+.. autoclass:: RandomOrder
 
-.. autoclass:: Pad
-
-.. autoclass:: ColorJitter
+.. autoclass:: RandomResizedCrop
 
 .. autoclass:: RandomRotation
 
-.. autoclass:: RandomAffine
+.. autoclass:: RandomSizedCrop
+
+.. autoclass:: RandomVerticalFlip
+
+.. autoclass:: Resize
+
+.. autoclass:: Scale
+
+.. autoclass:: TenCrop
 
 Transforms on torch.\*Tensor
 ----------------------------
@@ -53,11 +61,11 @@ Transforms on torch.\*Tensor
 Conversion Transforms
 ---------------------
 
-.. autoclass:: ToTensor
+.. autoclass:: ToPILImage
     :members: __call__
     :special-members:
 
-.. autoclass:: ToPILImage
+.. autoclass:: ToTensor
     :members: __call__
     :special-members:
@@ -66,3 +74,9 @@ Generic Transforms
 
 .. autoclass:: Lambda
 
+
+Functional Transforms
+---------------------
+
+.. automodule:: torchvision.transforms.functional
+   :members:
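
The new "Functional Transforms" section documents `torchvision.transforms.functional`, the deterministic building blocks behind the random transforms. A short sketch of the usual motivation, applying one random decision consistently to an image/target pair (`paired_flip` is an illustrative name, not part of the diff):

    import random
    import torchvision.transforms.functional as F

    def paired_flip(img, mask):
        # draw the random decision once, then apply it to both PIL images
        if random.random() < 0.5:
            img, mask = F.hflip(img), F.hflip(mask)
        return F.to_tensor(img), F.to_tensor(mask)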
58 changes: 57 additions & 1 deletion setup.py
@@ -5,6 +5,11 @@
 import shutil
 import sys
 from setuptools import setup, find_packages
+from pkg_resources import get_distribution, DistributionNotFound
+import glob
+
+import torch
+from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME
 
 
 def read(*names, **kwargs):
@@ -15,6 +20,13 @@ def read(*names, **kwargs):
         return fp.read()
 
 
+def get_dist(pkgname):
+    try:
+        return get_distribution(pkgname)
+    except DistributionNotFound:
+        return None
+
+
 def find_version(*file_paths):
     version_file = read(*file_paths)
     version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]",
@@ -30,11 +42,52 @@ def find_version(*file_paths):
 
 requirements = [
     'numpy',
-    'pillow >= 4.1.1',
     'six',
     'torch',
 ]
 
+pillow_ver = ' >= 4.1.1'
+pillow_req = 'pillow-simd' if get_dist('pillow-simd') is not None else 'pillow'
+requirements.append(pillow_req + pillow_ver)
+
+tqdm_ver = ' == 4.19.9' if sys.version_info[0] < 3 else ''
+requirements.append('tqdm' + tqdm_ver)
+
+
+def get_extensions():
+    this_dir = os.path.dirname(os.path.abspath(__file__))
+    extensions_dir = os.path.join(this_dir, 'torchvision', 'csrc')
+
+    main_file = glob.glob(os.path.join(extensions_dir, '*.cpp'))
+    source_cpu = glob.glob(os.path.join(extensions_dir, 'cpu', '*.cpp'))
+    source_cuda = glob.glob(os.path.join(extensions_dir, 'cuda', '*.cu'))
+
+    sources = main_file + source_cpu
+    extension = CppExtension
+
+    extra_cflags = []
+    define_macros = []
+
+    if torch.cuda.is_available() and CUDA_HOME is not None:
+        extension = CUDAExtension
+        sources += source_cuda
+        define_macros += [('WITH_CUDA', None)]
+
+    sources = [os.path.join(extensions_dir, s) for s in sources]
+
+    include_dirs = [extensions_dir]
+
+    ext_modules = [
+        extension(
+            'torchvision._C',
+            sources,
+            include_dirs=include_dirs,
+            define_macros=define_macros
+        )
+    ]
+
+    return ext_modules
+
 
 setup(
     # Metadata
     name='torchvision',
@@ -51,4 +104,7 @@ def find_version(*file_paths):
 
     zip_safe=True,
     install_requires=requirements,
+
+    ext_modules=get_extensions(),
+    cmdclass={'build_ext': torch.utils.cpp_extension.BuildExtension}
 )
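
With these changes, building the package (e.g. `pip install .`) compiles the C++ sources under torchvision/csrc into a `torchvision._C` extension, including the CUDA kernels only when PyTorch sees a GPU and CUDA_HOME resolves. A quick post-install smoke check (a sketch; only the extension name comes from the diff above):

    import torch           # load libtorch's shared libraries first
    import torchvision._C  # ImportError here means the extension did not build
    print('compiled extension:', torchvision._C.__file__)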
190 changes: 190 additions & 0 deletions test/test_layers.py
@@ -0,0 +1,190 @@
import unittest

import torch

from torchvision import layers


class ROIPoolTester(unittest.TestCase):

def test_roi_pool_basic_cpu(self):
dtype = torch.float32
device = torch.device('cpu')
x = torch.rand(1, 1, 10, 10, dtype=dtype, device=device)
        rois = torch.tensor([[0, 0, 0, 4, 4]],  # rows are [batch_idx, x1, y1, x2, y2]
dtype=dtype, device=device)

pool_h, pool_w = (5, 5)
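        # args: output size (pool_h, pool_w) and spatial scale; scale 1 means the
        # roi coordinates are already in feature-map pixels (argument semantics
        # assumed from this PR's layer definition)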
roi_pool = layers.ROIPool((pool_h, pool_w), 1)
y = roi_pool(x, rois)

gt_y = torch.zeros(rois.size(0), x.size(1), pool_h, pool_w)

for n in range(0, gt_y.size(0)):
start_h, end_h = int(rois[n, 2].item()), int(rois[n, 4].item()) + 1
start_w, end_w = int(rois[n, 1].item()), int(rois[n, 3].item()) + 1
roi_x = x[:, :, start_h:end_h, start_w:end_w]
bin_h, bin_w = roi_x.size(2) // pool_h, roi_x.size(3) // pool_w
for j in range(0, pool_h):
for i in range(0, pool_w):
gt_y[n, :, j, i] = torch.max(roi_x[:, :, j * bin_h:(j + 1) * bin_h, i * bin_w:(i + 1) * bin_w])

assert torch.equal(gt_y, y), 'ROIPool layer incorrect'

def test_roi_pool_cpu(self):
dtype = torch.float32
device = torch.device('cpu')
x = torch.rand(2, 1, 10, 10, dtype=dtype, device=device)
        rois = torch.tensor([[0, 0, 0, 9, 9],  # rows are [batch_idx, x1, y1, x2, y2]
[0, 0, 5, 4, 9],
[0, 5, 5, 9, 9],
[1, 0, 0, 9, 9]],
dtype=dtype, device=device)

pool_h, pool_w = (5, 5)
roi_pool = layers.ROIPool((pool_h, pool_w), 1)
y = roi_pool(x, rois)

gt_y = torch.zeros(rois.size(0), x.size(1), pool_h, pool_w, device=device)
for n in range(0, gt_y.size(0)):
for r, roi in enumerate(rois):
if roi[0] == n:
start_h, end_h = int(roi[2].item()), int(roi[4].item()) + 1
start_w, end_w = int(roi[1].item()), int(roi[3].item()) + 1
roi_x = x[roi[0].long():roi[0].long() + 1, :, start_h:end_h, start_w:end_w]
bin_h, bin_w = roi_x.size(2) // pool_h, roi_x.size(3) // pool_w
for j in range(0, pool_h):
for i in range(0, pool_w):
gt_y[r, :, j, i] = torch.max(gt_y[r, :, j, i],
torch.max(roi_x[:, :,
j * bin_h:(j + 1) * bin_h,
i * bin_w:(i + 1) * bin_w])
)

assert torch.equal(gt_y, y), 'ROIPool layer incorrect'

def test_roi_pool_gradient_cpu(self):
dtype = torch.float32
device = torch.device('cpu')
layer = layers.ROIPool((5, 5), 1).to(dtype=dtype, device=device)
x = torch.ones(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=True)
rois = torch.tensor([
[0, 0, 0, 9, 9],
[0, 0, 5, 4, 9],
[0, 0, 0, 4, 4]],
dtype=dtype, device=device)

y = layer(x, rois)
s = y.sum()
s.backward()
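
        # Expected gradient: each input cell accumulates 1 for every (roi, bin)
        # pair in which it is that bin's max. With an all-ones input this assumes
        # ties go to the first element scanned in a bin: the full-image roi with
        # 2x2 bins hits the even-row/even-col cells, while the two 5x5 rois use
        # 1x1 bins and hit each of their cells exactly once.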

gt_grad = torch.tensor([[[[2., 1., 2., 1., 2., 0., 1., 0., 1., 0.],
[1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
[2., 1., 2., 1., 2., 0., 1., 0., 1., 0.],
[1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
[2., 1., 2., 1., 2., 0., 1., 0., 1., 0.],
[1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
[2., 1., 2., 1., 2., 0., 1., 0., 1., 0.],
[1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
[2., 1., 2., 1., 2., 0., 1., 0., 1., 0.],
[1., 1., 1., 1., 1., 0., 0., 0., 0., 0.]]]], device=device)

assert torch.equal(x.grad, gt_grad), 'gradient incorrect for roi_pool'

@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
def test_roi_pool_basic_gpu(self):
dtype = torch.float32
device = torch.device('cuda')
x = torch.rand(1, 1, 10, 10, dtype=dtype, device=device)
        rois = torch.tensor([[0, 0, 0, 4, 4]],  # rows are [batch_idx, x1, y1, x2, y2]
dtype=dtype, device=device)

pool_h, pool_w = (5, 5)
roi_pool = layers.ROIPool((pool_h, pool_w), 1)
y = roi_pool(x, rois)

gt_y = torch.zeros(rois.size(0), x.size(1), pool_h, pool_w)

for n in range(0, gt_y.size(0)):
start_h, end_h = int(rois[n, 2].item()), int(rois[n, 4].item()) + 1
start_w, end_w = int(rois[n, 1].item()), int(rois[n, 3].item()) + 1
roi_x = x[:, :, start_h:end_h, start_w:end_w]
bin_h, bin_w = roi_x.size(2) // pool_h, roi_x.size(3) // pool_w
for j in range(0, pool_h):
for i in range(0, pool_w):
gt_y[n, :, j, i] = torch.max(roi_x[:, :, j * bin_h:(j + 1) * bin_h, i * bin_w:(i + 1) * bin_w])

assert torch.equal(gt_y.cuda(), y), 'ROIPool layer incorrect'

@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
def test_roi_pool_gpu(self):
dtype = torch.float32
        device = torch.device('cuda')
x = torch.rand(2, 1, 10, 10, dtype=dtype, device=device)
        rois = torch.tensor([[0, 0, 0, 9, 9],  # rows are [batch_idx, x1, y1, x2, y2]
[0, 0, 5, 4, 9],
[0, 5, 5, 9, 9],
[1, 0, 0, 9, 9]],
dtype=dtype, device=device)

pool_h, pool_w = (5, 5)
roi_pool = layers.ROIPool((pool_h, pool_w), 1)
y = roi_pool(x, rois)

gt_y = torch.zeros(rois.size(0), x.size(1), pool_h, pool_w, device=device)
for n in range(0, gt_y.size(0)):
for r, roi in enumerate(rois):
if roi[0] == n:
start_h, end_h = int(roi[2].item()), int(roi[4].item()) + 1
start_w, end_w = int(roi[1].item()), int(roi[3].item()) + 1
roi_x = x[roi[0].long():roi[0].long() + 1, :, start_h:end_h, start_w:end_w]
bin_h, bin_w = roi_x.size(2) // pool_h, roi_x.size(3) // pool_w
for j in range(0, pool_h):
for i in range(0, pool_w):
gt_y[r, :, j, i] = torch.max(gt_y[r, :, j, i],
torch.max(roi_x[:, :,
j * bin_h:(j + 1) * bin_h,
i * bin_w:(i + 1) * bin_w])
)

assert torch.equal(gt_y.cuda(), y), 'ROIPool layer incorrect'

@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
def test_roi_pool_gradient_gpu(self):
dtype = torch.float32
device = torch.device('cuda')
layer = layers.ROIPool((5, 5), 1).to(dtype=dtype, device=device)
x = torch.ones(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=True)
rois = torch.tensor([
[0, 0, 0, 9, 9],
[0, 0, 5, 4, 9],
[0, 0, 0, 4, 4]],
dtype=dtype, device=device)

        y = layer(x, rois)
s = y.sum()
s.backward()
gt_grad = torch.tensor([[[[2., 1., 2., 1., 2., 0., 1., 0., 1., 0.],
[1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
[2., 1., 2., 1., 2., 0., 1., 0., 1., 0.],
[1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
[2., 1., 2., 1., 2., 0., 1., 0., 1., 0.],
[1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
[2., 1., 2., 1., 2., 0., 1., 0., 1., 0.],
[1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
[2., 1., 2., 1., 2., 0., 1., 0., 1., 0.],
[1., 1., 1., 1., 1., 0., 0., 0., 0., 0.]]]], device=device)

assert torch.equal(x.grad, gt_grad), 'gradient incorrect for roi_pool'


if __name__ == '__main__':
unittest.main()
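
For reference, a standalone pure-PyTorch sketch of the max ROI pooling these tests compute as ground truth (rois rows are [batch_idx, x1, y1, x2, y2] at spatial scale 1, with the same integer bin sizes as the loops above; `roi_pool_reference` is a name introduced here, not part of the PR):

    import torch

    def roi_pool_reference(x, rois, pool_h, pool_w):
        out = torch.zeros(rois.size(0), x.size(1), pool_h, pool_w, dtype=x.dtype)
        for r, roi in enumerate(rois):
            b = int(roi[0].item())
            x1, y1, x2, y2 = (int(v.item()) for v in roi[1:])
            patch = x[b, :, y1:y2 + 1, x1:x2 + 1]  # (C, roi_h, roi_w)
            bin_h, bin_w = patch.size(1) // pool_h, patch.size(2) // pool_w
            for j in range(pool_h):
                for i in range(pool_w):
                    region = patch[:, j * bin_h:(j + 1) * bin_h,
                                   i * bin_w:(i + 1) * bin_w]
                    # per-channel max over each bin
                    out[r, :, j, i] = region.reshape(region.size(0), -1).max(dim=1)[0]
        return out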