From 51bc7f2f074b41b335e5d147f5e58249ffa36e5e Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Thu, 23 Aug 2018 16:21:01 -0400
Subject: [PATCH 1/5] ROI Pooling CPU and CUDA code as well as pytorch function and layer

---
 torchvision/csrc/ROIPool.h            |  46 ++++++
 torchvision/csrc/cpu/ROIPool_cpu.cpp  | 152 +++++++++++++++++++
 torchvision/csrc/cpu/vision.h         |  19 +++
 torchvision/csrc/cuda/ROIPool_cuda.cu | 208 ++++++++++++++++++++++++++
 torchvision/csrc/cuda/cuda_helpers.h  |   8 +
 torchvision/csrc/cuda/vision.h        |  19 +++
 torchvision/csrc/vision.cpp           |   7 +
 torchvision/layers/__init__.py        |   6 +
 torchvision/layers/roi_pool.py        |  53 +++++++
 9 files changed, 518 insertions(+)
 create mode 100644 torchvision/csrc/ROIPool.h
 create mode 100644 torchvision/csrc/cpu/ROIPool_cpu.cpp
 create mode 100644 torchvision/csrc/cpu/vision.h
 create mode 100644 torchvision/csrc/cuda/ROIPool_cuda.cu
 create mode 100644 torchvision/csrc/cuda/cuda_helpers.h
 create mode 100644 torchvision/csrc/cuda/vision.h
 create mode 100644 torchvision/csrc/vision.cpp
 create mode 100644 torchvision/layers/__init__.py
 create mode 100644 torchvision/layers/roi_pool.py

diff --git a/torchvision/csrc/ROIPool.h b/torchvision/csrc/ROIPool.h
new file mode 100644
index 00000000000..bd15a9e70fd
--- /dev/null
+++ b/torchvision/csrc/ROIPool.h
@@ -0,0 +1,46 @@
+#pragma once
+
+#include "cpu/vision.h"
+
+#ifdef WITH_CUDA
+#include "cuda/vision.h"
+#endif
+
+std::tuple<at::Tensor, at::Tensor> ROIPool_forward(const at::Tensor &input,
+                                                   const at::Tensor &rois,
+                                                   const float spatial_scale,
+                                                   const int pooled_height,
+                                                   const int pooled_width)
+{
+    if (input.type().is_cuda())
+    {
+#ifdef WITH_CUDA
+        return ROIPool_forward_cuda(input, rois, spatial_scale, pooled_height, pooled_width);
+#else
+        AT_ERROR("Not compiled with GPU support");
+#endif
+    }
+    return ROIPool_forward_cpu(input, rois, spatial_scale, pooled_height, pooled_width);
+}
+
+at::Tensor ROIPool_backward(const at::Tensor &grad,
+                            const at::Tensor &rois,
+                            const at::Tensor &argmax,
+                            const float spatial_scale,
+                            const int pooled_height,
+                            const int pooled_width,
+                            const int batch_size,
+                            const int channels,
+                            const int height,
+                            const int width)
+{
+    if (grad.type().is_cuda())
+    {
+#ifdef WITH_CUDA
+        return ROIPool_backward_cuda(grad, rois, argmax, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width);
+#else
+        AT_ERROR("Not compiled with GPU support");
+#endif
+    }
+    return ROIPool_backward_cpu(grad, rois, argmax, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width);
+}
\ No newline at end of file
diff --git a/torchvision/csrc/cpu/ROIPool_cpu.cpp b/torchvision/csrc/cpu/ROIPool_cpu.cpp
new file mode 100644
index 00000000000..0f587d4323f
--- /dev/null
+++ b/torchvision/csrc/cpu/ROIPool_cpu.cpp
@@ -0,0 +1,152 @@
+#include <ATen/ATen.h>
+#include <cfloat>
+#include <cmath>
+
+std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cpu(const at::Tensor &input,
+                                                       const at::Tensor &rois,
+                                                       const float spatial_scale,
+                                                       const int pooled_height,
+                                                       const int pooled_width)
+{
+    AT_ASSERTM(input.device().is_cpu(), "input must be a CPU tensor");
+    AT_ASSERTM(rois.device().is_cpu(), "rois must be a CPU tensor");
+
+    int num_rois = rois.size(0);
+    int channels = input.size(1);
+    int input_height = input.size(2);
+    int input_width = input.size(3);
+
+    at::Tensor output = input.type().tensor({num_rois, channels, pooled_height, pooled_width});
+    at::Tensor argmax = input.type().toScalarType(at::kInt).tensor({num_rois, channels, pooled_height, pooled_width}).zero_();
+
+    // define accessors for indexing
+    auto input_a = input.accessor<float, 4>();
+    auto rois_a = rois.accessor<float, 2>();
+    auto output_a = output.accessor<float, 4>();
+    auto argmax_a = argmax.accessor<int, 4>();
+
+    if (output.numel() == 0)
+    {
+        return std::make_tuple(output, argmax);
+    }
+
+    for (int n = 0; n < num_rois; ++n)
+    {
+        int roi_batch_ind = rois_a[n][0];
+        int roi_start_w = round(rois_a[n][1] * spatial_scale);
+        int roi_start_h = round(rois_a[n][2] * spatial_scale);
+        int roi_end_w = round(rois_a[n][3] * spatial_scale);
+        int roi_end_h = round(rois_a[n][4] * spatial_scale);
+
+        // Force malformed ROIs to be 1x1 or HxW
+        int roi_width = std::max(roi_end_w - roi_start_w + 1, 1);
+        int roi_height = std::max(roi_end_h - roi_start_h + 1, 1);
+        float bin_size_h = static_cast<float>(roi_height) / static_cast<float>(pooled_height);
+        float bin_size_w = static_cast<float>(roi_width) / static_cast<float>(pooled_width);
+
+        for (int ph = 0; ph < pooled_height; ++ph)
+        {
+            for (int pw = 0; pw < pooled_width; ++pw)
+            {
+                int hstart = static_cast<int>(floor(static_cast<float>(ph) * bin_size_h));
+                int wstart = static_cast<int>(floor(static_cast<float>(pw) * bin_size_w));
+                int hend = static_cast<int>(ceil(static_cast<float>(ph + 1) * bin_size_h));
+                int wend = static_cast<int>(ceil(static_cast<float>(pw + 1) * bin_size_w));
+
+                // Add roi offsets and clip to input boundaries
+                hstart = std::min(std::max(hstart + roi_start_h, 0), input_height);
+                hend = std::min(std::max(hend + roi_start_h, 0), input_height);
+                wstart = std::min(std::max(wstart + roi_start_w, 0), input_width);
+                wend = std::min(std::max(wend + roi_start_w, 0), input_width);
+                bool is_empty = (hend <= hstart) || (wend <= wstart);
+
+                for (int c = 0; c < channels; ++c)
+                {
+                    // Define an empty pooling region to be zero.
+                    // Reset per channel so the max does not leak across channels.
+                    float maxval = is_empty ? 0 : -FLT_MAX;
+                    // If nothing is pooled, argmax = -1 causes nothing to be backprop'd
+                    int maxidx = -1;
+
+                    for (int h = hstart; h < hend; ++h)
+                    {
+                        for (int w = wstart; w < wend; ++w)
+                        {
+                            int index = h * input_width + w;
+                            if (input_a[roi_batch_ind][c][h][w] > maxval)
+                            {
+                                maxval = input_a[roi_batch_ind][c][h][w];
+                                maxidx = index;
+                            }
+                        }
+                    }
+                    output_a[n][c][ph][pw] = maxval;
+                    argmax_a[n][c][ph][pw] = maxidx;
+                }
+            }
+        }
+    }
+
+    return std::make_tuple(output, argmax);
+}
+
+at::Tensor ROIPool_backward_cpu(const at::Tensor &grad,
+                                const at::Tensor &rois,
+                                const at::Tensor &argmax,
+                                const float spatial_scale,
+                                const int pooled_height,
+                                const int pooled_width,
+                                const int batch_size,
+                                const int channels,
+                                const int height,
+                                const int width)
+{
+    // Check if input tensors are CPU tensors
+    AT_ASSERTM(grad.device().is_cpu(), "grad must be a CPU tensor");
+    AT_ASSERTM(rois.device().is_cpu(), "rois must be a CPU tensor");
+    AT_ASSERTM(argmax.device().is_cpu(), "argmax must be a CPU tensor");
+
+    auto num_rois = rois.size(0);
+
+    at::Tensor grad_input = at::zeros({batch_size, channels, height, width}, grad.type());
+
+    // handle possibly empty gradients
+    if (grad.numel() == 0)
+    {
+        return grad_input;
+    }
+
+    // get stride values to ensure indexing into gradients is correct.
+    int n_stride = grad.stride(0);
+    int c_stride = grad.stride(1);
+    int h_stride = grad.stride(2);
+    int w_stride = grad.stride(3);
+
+    // define accessors for tensors
+    auto grad_input_a = grad_input.accessor<float, 4>();
+    auto argmax_a = argmax.accessor<int, 4>();
+    auto rois_a = rois.accessor<float, 2>();
+
+    // use a raw pointer for grad so the explicit strides above are applied
+    // exactly once (an accessor would apply them a second time)
+    const float *grad_data = grad.data<float>();
+
+    for (int n = 0; n < num_rois; ++n)
+    {
+        int roi_batch_ind = rois_a[n][0];
+
+        for (int c = 0; c < channels; ++c)
+        {
+            for (int ph = 0; ph < pooled_height; ++ph)
+            {
+                for (int pw = 0; pw < pooled_width; ++pw)
+                {
+                    int argmax_idx = argmax_a[n][c][ph][pw];
+                    // argmax = -1 marks an empty pooling region: nothing to backprop
+                    if (argmax_idx == -1)
+                    {
+                        continue;
+                    }
+
+                    // recover height and width indices from the flat argmax index
+                    // (the forward pass stored h * width + w)
+                    int h = argmax_idx / width;
+                    int w = argmax_idx % width;
+
+                    grad_input_a[roi_batch_ind][c][h][w] +=
+                        grad_data[n * n_stride + c * c_stride + ph * h_stride + pw * w_stride];
+                }
+            }
+        }
+    }
+
+    return grad_input;
+}
\ No newline at end of file
diff --git a/torchvision/csrc/cpu/vision.h b/torchvision/csrc/cpu/vision.h
new file mode 100644
index 00000000000..64e65e66864
--- /dev/null
+++ b/torchvision/csrc/cpu/vision.h
@@ -0,0 +1,19 @@
+#pragma once
+#include <torch/extension.h>
+
+std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cpu(const at::Tensor &input,
+                                                       const at::Tensor &rois,
+                                                       const float spatial_scale,
+                                                       const int pooled_height,
+                                                       const int pooled_width);
+
+at::Tensor ROIPool_backward_cpu(const at::Tensor &grad,
+                                const at::Tensor &rois,
+                                const at::Tensor &argmax,
+                                const float spatial_scale,
+                                const int pooled_height,
+                                const int pooled_width,
+                                const int batch_size,
+                                const int channels,
+                                const int height,
+                                const int width);
diff --git a/torchvision/csrc/cuda/ROIPool_cuda.cu b/torchvision/csrc/cuda/ROIPool_cuda.cu
new file mode 100644
index 00000000000..77a27b1d1ff
--- /dev/null
+++ b/torchvision/csrc/cuda/ROIPool_cuda.cu
@@ -0,0 +1,208 @@
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+
+#include <THC/THC.h>
+#include <THC/THCAtomics.cuh>
+#include <THC/THCDeviceUtils.cuh>
+
+#include "cuda_helpers.h"
+#include <cfloat>
+
+
+template <typename T>
+__global__ void RoIPoolForward(const int nthreads, const T* bottom_data,
+    const T spatial_scale, const int channels, const int height,
+    const int width, const int pooled_height, const int pooled_width,
+    const T* bottom_rois, T* top_data, int* argmax_data) {
+  CUDA_1D_KERNEL_LOOP(index, nthreads) {
+    // (n, c, ph, pw) is an element in the pooled output
+    int pw = index % pooled_width;
+    int ph = (index / pooled_width) % pooled_height;
+    int c = (index / pooled_width / pooled_height) % channels;
+    int n = index / pooled_width / pooled_height / channels;
+
+    const T* offset_bottom_rois = bottom_rois + n * 5;
+    int roi_batch_ind = offset_bottom_rois[0];
+    int roi_start_w = round(offset_bottom_rois[1] * spatial_scale);
+    int roi_start_h = round(offset_bottom_rois[2] * spatial_scale);
+    int roi_end_w = round(offset_bottom_rois[3] * spatial_scale);
+    int roi_end_h = round(offset_bottom_rois[4] * spatial_scale);
+
+    // Force malformed ROIs to be 1x1 or HxW
+    int roi_width = max(roi_end_w - roi_start_w + 1, 1);
+    int roi_height = max(roi_end_h - roi_start_h + 1, 1);
+    T bin_size_h = static_cast<T>(roi_height)
+                   / static_cast<T>(pooled_height);
+    T bin_size_w = static_cast<T>(roi_width)
+                   / static_cast<T>(pooled_width);
+
+    int hstart = static_cast<int>(floor(static_cast<T>(ph)
+                                        * bin_size_h));
+    int wstart = static_cast<int>(floor(static_cast<T>(pw)
+                                        * bin_size_w));
+    int hend = static_cast<int>(ceil(static_cast<T>(ph + 1)
+                                     * bin_size_h));
+    int wend = static_cast<int>(ceil(static_cast<T>(pw + 1)
+                                     * bin_size_w));
+
+    // Add roi offsets and clip to input boundaries
+    hstart = min(max(hstart + roi_start_h, 0), height);
+    hend = min(max(hend + roi_start_h, 0), height);
+    wstart = min(max(wstart + roi_start_w, 0), width);
+    wend = min(max(wend + roi_start_w, 0), width);
+    bool is_empty = (hend <= hstart) || (wend <= wstart);
+
+    // Define an empty pooling region to be zero
+    T maxval = is_empty ? 0 : -FLT_MAX;
+    // If nothing is pooled, argmax = -1 causes nothing to be backprop'd
+    int maxidx = -1;
+    const T* offset_bottom_data =
+        bottom_data + (roi_batch_ind * channels + c) * height * width;
+    for (int h = hstart; h < hend; ++h) {
+      for (int w = wstart; w < wend; ++w) {
+        int bottom_index = h * width + w;
+        if (offset_bottom_data[bottom_index] > maxval) {
+          maxval = offset_bottom_data[bottom_index];
+          maxidx = bottom_index;
+        }
+      }
+    }
+    top_data[index] = maxval;
+    argmax_data[index] = maxidx;
+  }
+}
+
+template <typename T>
+__global__ void RoIPoolBackward(const int nthreads, const T* top_grad,
+    const int* argmax_data, const int num_rois, const T spatial_scale,
+    const int channels, const int height, const int width,
+    const int pooled_height, const int pooled_width, T* bottom_data,
+    const T* bottom_rois,
+    const int n_stride, const int c_stride,
+    const int h_stride, const int w_stride) {
+
+  CUDA_1D_KERNEL_LOOP(index, nthreads) {
+    // (n, c, ph, pw) is an element in the pooled output
+    int pw = index % pooled_width;
+    int ph = (index / pooled_width) % pooled_height;
+    int c = (index / pooled_width / pooled_height) % channels;
+    int n = index / pooled_width / pooled_height / channels;
+
+    const T* offset_bottom_rois = bottom_rois + n * 5;
+    int roi_batch_ind = offset_bottom_rois[0];
+    int bottom_offset = (roi_batch_ind * channels + c) * height * width;
+    T* bottom_data_offset = bottom_data + bottom_offset;
+
+    int top_offset = n*n_stride + c*c_stride;
+    const int* argmax_data_offset = argmax_data + n*channels*pooled_height*pooled_width;
+    int argmax = argmax_data_offset[c*pooled_height*pooled_width + ph*pooled_width + pw];
+
+    if (argmax != -1) {
+      atomicAdd(bottom_data_offset + argmax,
+                static_cast<T>(top_grad[top_offset + ph*h_stride + pw*w_stride]));
+    }
+  }
+}
+
+std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cuda(const at::Tensor& input,
+                                                        const at::Tensor& rois,
+                                                        const float spatial_scale,
+                                                        const int pooled_height,
+                                                        const int pooled_width) {
+  AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
+  AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor");
+
+  auto num_rois = rois.size(0);
+  auto channels = input.size(1);
+  auto height = input.size(2);
+  auto width = input.size(3);
+
+  at::Tensor output = input.type().tensor({num_rois, channels, pooled_height, pooled_width});
+  at::Tensor argmax = input.type().toScalarType(at::kInt).tensor({num_rois, channels, pooled_height, pooled_width}).zero_();
+
+  auto output_size = num_rois * pooled_height * pooled_width * channels;
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  dim3 grid(std::min(THCCeilDiv(output_size, 512L), 4096L));
+  dim3 block(512);
+
+  if (output.numel() == 0) {
+    THCudaCheck(cudaGetLastError());
+    return std::make_tuple(output, argmax);
+  }
+
+  AT_DISPATCH_FLOATING_TYPES(input.type(), "ROIPool_forward", [&] {
+    RoIPoolForward<scalar_t><<<grid, block, 0, stream>>>(
+        output_size,
+        input.data<scalar_t>(),
+        spatial_scale,
+        channels,
+        height,
+        width,
+        pooled_height,
+        pooled_width,
+        rois.data<scalar_t>(),
+        output.data<scalar_t>(),
+        argmax.data<int>());
+  });
+  THCudaCheck(cudaGetLastError());
+  return std::make_tuple(output, argmax);
+}
+
+at::Tensor ROIPool_backward_cuda(const at::Tensor& grad,
+                                 const at::Tensor& rois,
+                                 const at::Tensor& argmax,
+                                 const float spatial_scale,
+                                 const int pooled_height,
+                                 const int pooled_width,
+                                 const int batch_size,
+                                 const int channels,
+                                 const int height,
+                                 const int width) {
+  // Check if input tensors are CUDA tensors
+  AT_ASSERTM(grad.type().is_cuda(), "grad must be a CUDA tensor");
+  AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor");
+  AT_ASSERTM(argmax.type().is_cuda(), "argmax must be a CUDA tensor");
+
+  auto num_rois = rois.size(0);
+
+  at::Tensor grad_input = at::zeros({batch_size, channels, height, width}, grad.type());
+
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  dim3 grid(std::min(THCCeilDiv(grad.numel(), 512L), 4096L));
+  dim3 block(512);
+
+  // handle possibly empty gradients
+  if (grad.numel() == 0) {
+    THCudaCheck(cudaGetLastError());
+    return grad_input;
+  }
+
+  int n_stride = grad.stride(0);
+  int c_stride = grad.stride(1);
+  int h_stride = grad.stride(2);
+  int w_stride = grad.stride(3);
+
+  AT_DISPATCH_FLOATING_TYPES(grad.type(), "ROIPool_backward", [&] {
+    RoIPoolBackward<scalar_t><<<grid, block, 0, stream>>>(
+        grad.numel(),
+        grad.data<scalar_t>(),
+        argmax.data<int>(),
+        num_rois,
+        spatial_scale,
+        channels,
+        height,
+        width,
+        pooled_height,
+        pooled_width,
+        grad_input.data<scalar_t>(),
+        rois.data<scalar_t>(),
+        n_stride,
+        c_stride,
+        h_stride,
+        w_stride);
+  });
+  THCudaCheck(cudaGetLastError());
+  return grad_input;
+}
\ No newline at end of file
diff --git a/torchvision/csrc/cuda/cuda_helpers.h b/torchvision/csrc/cuda/cuda_helpers.h
new file mode 100644
index 00000000000..15fb7f6031a
--- /dev/null
+++ b/torchvision/csrc/cuda/cuda_helpers.h
@@ -0,0 +1,8 @@
+#ifndef CUDA_HELPERS_H
+#define CUDA_HELPERS_H
+
+// grid-stride loop: each thread covers indices i, i + blockDim.x * gridDim.x, ...
+#define CUDA_1D_KERNEL_LOOP(i, n)                                \
+  for (int i = (blockIdx.x * blockDim.x) + threadIdx.x; i < (n); \
+       i += (blockDim.x * gridDim.x))
+
+#endif // CUDA_HELPERS_H
\ No newline at end of file
diff --git a/torchvision/csrc/cuda/vision.h b/torchvision/csrc/cuda/vision.h
new file mode 100644
index 00000000000..4f83f83b4f4
--- /dev/null
+++ b/torchvision/csrc/cuda/vision.h
@@ -0,0 +1,19 @@
+#pragma once
+#include <ATen/ATen.h>
+
+std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cuda(const at::Tensor &input,
+                                                        const at::Tensor &rois,
+                                                        const float spatial_scale,
+                                                        const int pooled_height,
+                                                        const int pooled_width);
+
+at::Tensor ROIPool_backward_cuda(const at::Tensor &grad,
+                                 const at::Tensor &rois,
+                                 const at::Tensor &argmax,
+                                 const float spatial_scale,
+                                 const int pooled_height,
+                                 const int pooled_width,
+                                 const int batch_size,
+                                 const int channels,
+                                 const int height,
+                                 const int width);
\ No newline at end of file
diff --git a/torchvision/csrc/vision.cpp b/torchvision/csrc/vision.cpp
new file mode 100644
index 00000000000..88caec61d43
--- /dev/null
+++ b/torchvision/csrc/vision.cpp
@@ -0,0 +1,7 @@
+#include "ROIPool.h"
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+    m.def("roi_pool_forward", &ROIPool_forward, "ROIPool_forward");
+    m.def("roi_pool_backward", &ROIPool_backward, "ROIPool_backward");
+}
\ No newline at end of file
diff --git a/torchvision/layers/__init__.py b/torchvision/layers/__init__.py
new file mode 100644
index 00000000000..1e54c6311c3
--- /dev/null
+++ b/torchvision/layers/__init__.py
@@ -0,0 +1,6 @@
+from .roi_pool import roi_pool, ROIPool
+
+
+__all__ = [
+    'roi_pool', 'ROIPool'
+]
diff --git a/torchvision/layers/roi_pool.py b/torchvision/layers/roi_pool.py
new file mode 100644
index 00000000000..f232d0cc160
--- /dev/null
+++ b/torchvision/layers/roi_pool.py
@@ -0,0 +1,53 @@
+import torch
+from torch import nn
+
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+
+from torch.nn.modules.utils import _pair
+
+from torchvision import _C
+
+
+class _ROIPool(Function):
+    @staticmethod
+    def forward(ctx, input, roi, output_size, spatial_scale):
+        ctx.output_size = _pair(output_size)
+        ctx.spatial_scale = spatial_scale
+        ctx.input_shape = input.size()
+        output, argmax = _C.roi_pool_forward(
+            input, roi, spatial_scale,
+            ctx.output_size[0], ctx.output_size[1])
+        ctx.save_for_backward(roi, argmax)
+        return output
+
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, grad_output):
+        rois, argmax = ctx.saved_tensors
+        output_size = ctx.output_size
+        spatial_scale = ctx.spatial_scale
+        bs, ch, h, w = ctx.input_shape
+        grad_input = _C.roi_pool_backward(
+            grad_output, rois, argmax, spatial_scale,
+            output_size[0], output_size[1], bs, ch, h, w)
+        return grad_input, None, None, None
+
+roi_pool = _ROIPool.apply
+
+
+class ROIPool(nn.Module):
+    def __init__(self, output_size, spatial_scale):
+        super(ROIPool, self).__init__()
+        self.output_size = output_size
+        self.spatial_scale = spatial_scale
+
+    def forward(self, input, rois):
+        return roi_pool(input, rois, self.output_size, self.spatial_scale)
+
+    def __repr__(self):
+        tmpstr = self.__class__.__name__ + '('
+        tmpstr += 'output_size=' + str(self.output_size)
+        tmpstr += ', spatial_scale=' + str(self.spatial_scale)
+        tmpstr += ')'
+        return tmpstr
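
Usage sketch (illustrative, not part of the patch): each ROI row is
(batch_index, x1, y1, x2, y2) in the input's coordinate frame, and the op
multiplies the box by spatial_scale before pooling. The tensor values below
are arbitrary.

    import torch
    from torchvision import layers

    x = torch.rand(1, 3, 32, 32)                # NCHW feature map
    rois = torch.tensor([[0, 4, 4, 15, 15]],    # (batch_idx, x1, y1, x2, y2)
                        dtype=torch.float32)
    pool = layers.ROIPool(output_size=(7, 7), spatial_scale=1.0)
    y = pool(x, rois)
    print(y.shape)                              # torch.Size([1, 3, 7, 7])
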
From 17a2c9337ffe3f2251695d60264c5e59df09911e Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Thu, 23 Aug 2018 16:21:24 -0400
Subject: [PATCH 2/5] Updated setup.py to compile Cpp and CUDA extensions for ROI Pooling

---
 setup.py | 42 ++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 42 insertions(+)

diff --git a/setup.py b/setup.py
index ff11430c828..f6697cd80ea 100644
--- a/setup.py
+++ b/setup.py
@@ -6,6 +6,10 @@
 import sys
 from setuptools import setup, find_packages
 from pkg_resources import get_distribution, DistributionNotFound
+import glob
+
+import torch.cuda
+from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME
 
 
 def read(*names, **kwargs):
@@ -49,6 +53,41 @@ def find_version(*file_paths):
     tqdm_ver = ' == 4.19.9' if sys.version_info[0] < 3 else ''
     requirements.append('tqdm' + tqdm_ver)
 
+
+def get_extensions():
+    this_dir = os.path.dirname(os.path.abspath(__file__))
+    extensions_dir = os.path.join(this_dir, 'torchvision', 'csrc')
+
+    main_file = glob.glob(os.path.join(extensions_dir, '*.cpp'))
+    source_cpu = glob.glob(os.path.join(extensions_dir, 'cpu', '*.cpp'))
+    source_cuda = glob.glob(os.path.join(extensions_dir, 'cuda', '*.cu'))
+
+    sources = main_file + source_cpu
+    extension = CppExtension
+
+    extra_cflags = []
+    define_macros = []
+
+    if torch.cuda.is_available() and CUDA_HOME is not None:
+        extension = CUDAExtension
+        sources += source_cuda
+        define_macros += [('WITH_CUDA', None)]
+
+    sources = [os.path.join(extensions_dir, s) for s in sources]
+
+    include_dirs = [extensions_dir]
+
+    ext_modules = [
+        extension(
+            'torchvision._C',
+            sources,
+            include_dirs=include_dirs,
+            define_macros=define_macros
+        )
+    ]
+
+    return ext_modules
+
 setup(
     # Metadata
     name='torchvision',
@@ -65,4 +104,7 @@ def find_version(*file_paths):
     zip_safe=True,
 
     install_requires=requirements,
+
+    ext_modules=get_extensions(),
+    cmdclass={'build_ext': torch.utils.cpp_extension.BuildExtension}
 )
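
Build sketch (illustrative, not part of the patch): after "pip install ." or
"python setup.py install", the compiled extension is importable as
torchvision._C, and the layer from patch 1 is a thin wrapper over these raw
bindings.

    import torch
    from torchvision import _C

    x = torch.rand(1, 1, 10, 10)
    rois = torch.tensor([[0, 0, 0, 4, 4]], dtype=torch.float32)
    # raw binding: (input, rois, spatial_scale, pooled_height, pooled_width)
    output, argmax = _C.roi_pool_forward(x, rois, 1.0, 5, 5)
    print(output.shape, argmax.shape)  # both torch.Size([1, 1, 5, 5])
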
From ac023f6f9611d8161a5968209efd7fb7f86be99e Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Thu, 23 Aug 2018 16:22:43 -0400
Subject: [PATCH 3/5] tests for ROI Pooling

---
 test/test_layers.py | 190 ++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 190 insertions(+)
 create mode 100644 test/test_layers.py

diff --git a/test/test_layers.py b/test/test_layers.py
new file mode 100644
index 00000000000..d508393c64a
--- /dev/null
+++ b/test/test_layers.py
@@ -0,0 +1,190 @@
+import torch
+
+from torchvision import layers
+
+import unittest
+
+
+class ROIPoolTester(unittest.TestCase):
+
+    def test_roi_pool_basic_cpu(self):
+        dtype = torch.float32
+        device = torch.device('cpu')
+        x = torch.rand(1, 1, 10, 10, dtype=dtype, device=device)
+        rois = torch.tensor([[0, 0, 0, 4, 4]],  # format is (xyxy)
+                            dtype=dtype, device=device)
+
+        pool_h, pool_w = (5, 5)
+        roi_pool = layers.ROIPool((pool_h, pool_w), 1)
+        y = roi_pool(x, rois)
+
+        gt_y = torch.zeros(rois.size(0), x.size(1), pool_h, pool_w)
+
+        for n in range(0, gt_y.size(0)):
+            start_h, end_h = int(rois[n, 2].item()), int(rois[n, 4].item()) + 1
+            start_w, end_w = int(rois[n, 1].item()), int(rois[n, 3].item()) + 1
+            roi_x = x[:, :, start_h:end_h, start_w:end_w]
+            bin_h, bin_w = roi_x.size(2) // pool_h, roi_x.size(3) // pool_w
+            for j in range(0, pool_h):
+                for i in range(0, pool_w):
+                    gt_y[n, :, j, i] = torch.max(roi_x[:, :, j * bin_h:(j + 1) * bin_h, i * bin_w:(i + 1) * bin_w])
+
+        assert torch.equal(gt_y, y), 'ROIPool layer incorrect'
+
+    def test_roi_pool_cpu(self):
+        dtype = torch.float32
+        device = torch.device('cpu')
+        x = torch.rand(2, 1, 10, 10, dtype=dtype, device=device)
+        rois = torch.tensor([[0, 0, 0, 9, 9],  # format is (xyxy)
+                             [0, 0, 5, 4, 9],
+                             [0, 5, 5, 9, 9],
+                             [1, 0, 0, 9, 9]],
+                            dtype=dtype, device=device)
+
+        pool_h, pool_w = (5, 5)
+        roi_pool = layers.ROIPool((pool_h, pool_w), 1)
+        y = roi_pool(x, rois)
+
+        gt_y = torch.zeros(rois.size(0), x.size(1), pool_h, pool_w, device=device)
+        for n in range(0, gt_y.size(0)):
+            for r, roi in enumerate(rois):
+                if roi[0] == n:
+                    start_h, end_h = int(roi[2].item()), int(roi[4].item()) + 1
+                    start_w, end_w = int(roi[1].item()), int(roi[3].item()) + 1
+                    roi_x = x[roi[0].long():roi[0].long() + 1, :, start_h:end_h, start_w:end_w]
+                    bin_h, bin_w = roi_x.size(2) // pool_h, roi_x.size(3) // pool_w
+                    for j in range(0, pool_h):
+                        for i in range(0, pool_w):
+                            gt_y[r, :, j, i] = torch.max(gt_y[r, :, j, i],
+                                                         torch.max(roi_x[:, :,
+                                                                         j * bin_h:(j + 1) * bin_h,
+                                                                         i * bin_w:(i + 1) * bin_w]))
+
+        assert torch.equal(gt_y, y), 'ROIPool layer incorrect'
+
+    def test_roi_pool_gradient_cpu(self):
+        dtype = torch.float32
+        device = torch.device('cpu')
+        layer = layers.ROIPool((5, 5), 1).to(dtype=dtype, device=device)
+        x = torch.ones(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=True)
+        rois = torch.tensor([
+            [0, 0, 0, 9, 9],
+            [0, 0, 5, 4, 9],
+            [0, 0, 0, 4, 4]],
+            dtype=dtype, device=device)
+
+        y = layer(x, rois)
+        s = y.sum()
+        s.backward()
+
+        gt_grad = torch.tensor([[[[2., 1., 2., 1., 2., 0., 1., 0., 1., 0.],
+                                  [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
+                                  [2., 1., 2., 1., 2., 0., 1., 0., 1., 0.],
+                                  [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
+                                  [2., 1., 2., 1., 2., 0., 1., 0., 1., 0.],
+                                  [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
+                                  [2., 1., 2., 1., 2., 0., 1., 0., 1., 0.],
+                                  [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
+                                  [2., 1., 2., 1., 2., 0., 1., 0., 1., 0.],
+                                  [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.]]]], device=device)
+
+        assert torch.equal(x.grad, gt_grad), 'gradient incorrect for roi_pool'
+
+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
+    def test_roi_pool_basic_gpu(self):
+        dtype = torch.float32
+        device = torch.device('cuda')
+        x = torch.rand(1, 1, 10, 10, dtype=dtype, device=device)
+        rois = torch.tensor([[0, 0, 0, 4, 4]],  # format is (xyxy)
+                            dtype=dtype, device=device)
+
+        pool_h, pool_w = (5, 5)
+        roi_pool = layers.ROIPool((pool_h, pool_w), 1)
+        y = roi_pool(x, rois)
+
+        gt_y = torch.zeros(rois.size(0), x.size(1), pool_h, pool_w)
+
+        for n in range(0, gt_y.size(0)):
+            start_h, end_h = int(rois[n, 2].item()), int(rois[n, 4].item()) + 1
+            start_w, end_w = int(rois[n, 1].item()), int(rois[n, 3].item()) + 1
+            roi_x = x[:, :, start_h:end_h, start_w:end_w]
+            bin_h, bin_w = roi_x.size(2) // pool_h, roi_x.size(3) // pool_w
+            for j in range(0, pool_h):
+                for i in range(0, pool_w):
+                    gt_y[n, :, j, i] = torch.max(roi_x[:, :, j * bin_h:(j + 1) * bin_h, i * bin_w:(i + 1) * bin_w])
+
+        assert torch.equal(gt_y.cuda(), y), 'ROIPool layer incorrect'
+
+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
+    def test_roi_pool_gpu(self):
+        dtype = torch.float32
+        device = torch.device('cuda')
+        x = torch.rand(2, 1, 10, 10, dtype=dtype, device=device)
+        rois = torch.tensor([[0, 0, 0, 9, 9],  # format is (xyxy)
+                             [0, 0, 5, 4, 9],
+                             [0, 5, 5, 9, 9],
+                             [1, 0, 0, 9, 9]],
+                            dtype=dtype, device=device)
+
+        pool_h, pool_w = (5, 5)
+        roi_pool = layers.ROIPool((pool_h, pool_w), 1)
+        y = roi_pool(x, rois)
+
+        gt_y = torch.zeros(rois.size(0), x.size(1), pool_h, pool_w, device=device)
+        for n in range(0, gt_y.size(0)):
+            for r, roi in enumerate(rois):
+                if roi[0] == n:
+                    start_h, end_h = int(roi[2].item()), int(roi[4].item()) + 1
+                    start_w, end_w = int(roi[1].item()), int(roi[3].item()) + 1
+                    roi_x = x[roi[0].long():roi[0].long() + 1, :, start_h:end_h, start_w:end_w]
+                    bin_h, bin_w = roi_x.size(2) // pool_h, roi_x.size(3) // pool_w
+                    for j in range(0, pool_h):
+                        for i in range(0, pool_w):
+                            gt_y[r, :, j, i] = torch.max(gt_y[r, :, j, i],
+                                                         torch.max(roi_x[:, :,
+                                                                         j * bin_h:(j + 1) * bin_h,
+                                                                         i * bin_w:(i + 1) * bin_w]))
+
+        assert torch.equal(gt_y.cuda(), y), 'ROIPool layer incorrect'
+
+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
+    def test_roi_pool_gradient_gpu(self):
+        dtype = torch.float32
+        device = torch.device('cuda')
+        layer = layers.ROIPool((5, 5), 1).to(dtype=dtype, device=device)
+        x = torch.ones(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=True)
+        rois = torch.tensor([
+            [0, 0, 0, 9, 9],
+            [0, 0, 5, 4, 9],
+            [0, 0, 0, 4, 4]],
+            dtype=dtype, device=device)
+
+        y = layer(x, rois)
+        s = y.sum()
+        s.backward()
+
+        gt_grad = torch.tensor([[[[2., 1., 2., 1., 2., 0., 1., 0., 1., 0.],
+                                  [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
+                                  [2., 1., 2., 1., 2., 0., 1., 0., 1., 0.],
+                                  [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
+                                  [2., 1., 2., 1., 2., 0., 1., 0., 1., 0.],
+                                  [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
+                                  [2., 1., 2., 1., 2., 0., 1., 0., 1., 0.],
+                                  [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
+                                  [2., 1., 2., 1., 2., 0., 1., 0., 1., 0.],
+                                  [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.]]]], device=device)
+
+        assert torch.equal(x.grad, gt_grad), 'gradient incorrect for roi_pool'
+
+
+if __name__ == '__main__':
+    unittest.main()
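
Worked example (illustrative, not part of the patch): the kernels place bin
boundaries with floor/ceil on a fractional bin size, so adjacent bins can
overlap when the ROI does not divide evenly. The tests above pick ROI sizes
that divide evenly, which is why the integer-strided reference pooling
matches the op exactly.

    import math

    # bin boundaries as computed in RoIPoolForward / ROIPool_forward_cpu
    roi_height, pooled_height = 5, 2
    bin_size_h = roi_height / pooled_height           # 2.5
    for ph in range(pooled_height):
        hstart = int(math.floor(ph * bin_size_h))     # 0, then 2
        hend = int(math.ceil((ph + 1) * bin_size_h))  # 3, then 5
        print(ph, hstart, hend)                       # row 2 falls in both bins
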
From 37671c9f172cbe302bd322dabcd968f65e46679f Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Thu, 23 Aug 2018 17:46:30 -0400
Subject: [PATCH 4/5] don't import torch.cuda explicitly

---
 setup.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index f6697cd80ea..6c271c72bb4 100644
--- a/setup.py
+++ b/setup.py
@@ -8,7 +8,7 @@
 from pkg_resources import get_distribution, DistributionNotFound
 import glob
 
-import torch.cuda
+import torch
 from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME

From 151c6ca94045a46a9bfdcfebf489a694d2fdc623 Mon Sep 17 00:00:00 2001
From: Varun Agrawal
Date: Thu, 30 Aug 2018 14:44:35 -0400
Subject: [PATCH 5/5] Add Half type support

---
 torchvision/csrc/cuda/ROIPool_cuda.cu | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/torchvision/csrc/cuda/ROIPool_cuda.cu b/torchvision/csrc/cuda/ROIPool_cuda.cu
index 77a27b1d1ff..5f95de1da43 100644
--- a/torchvision/csrc/cuda/ROIPool_cuda.cu
+++ b/torchvision/csrc/cuda/ROIPool_cuda.cu
@@ -131,7 +131,7 @@ std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cuda(const at::Tensor& input,
     return std::make_tuple(output, argmax);
   }
 
-  AT_DISPATCH_FLOATING_TYPES(input.type(), "ROIPool_forward", [&] {
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "ROIPool_forward", [&] {
     RoIPoolForward<scalar_t><<<grid, block, 0, stream>>>(
         output_size,
         input.data<scalar_t>(),
@@ -184,7 +184,7 @@ at::Tensor ROIPool_backward_cuda(const at::Tensor& grad,
   int h_stride = grad.stride(2);
   int w_stride = grad.stride(3);
 
-  AT_DISPATCH_FLOATING_TYPES(grad.type(), "ROIPool_backward", [&] {
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad.type(), "ROIPool_backward", [&] {
    RoIPoolBackward<scalar_t><<<grid, block, 0, stream>>>(
        grad.numel(),
        grad.data<scalar_t>(),
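
Half-precision sketch (illustrative, assuming the extension was built with
CUDA support): with the dispatch widened to AT_DISPATCH_FLOATING_TYPES_AND_HALF,
the layer accepts fp16 feature maps directly. The rois tensor must use the
same dtype as the input, since both are dispatched on the same scalar type.

    import torch
    from torchvision import layers

    x = torch.rand(1, 1, 10, 10, device='cuda', dtype=torch.half)
    rois = torch.tensor([[0, 0, 0, 4, 4]], device='cuda', dtype=torch.half)
    y = layers.ROIPool((5, 5), 1)(x, rois)
    print(y.dtype)  # torch.float16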