diff --git a/setup.py b/setup.py index 4e34fb5f528..c1d30e5b030 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ from setuptools import setup, find_packages import glob -import torch.cuda +import torch from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME @@ -51,12 +51,15 @@ def get_extensions(): sources = main_file + source_cpu extension = CppExtension + extra_compile_args = {'cxx': []} define_macros = [] if torch.cuda.is_available() and CUDA_HOME is not None: extension = CUDAExtension sources += source_cuda define_macros += [('WITH_CUDA', None)] + extra_compile_args['nvcc'] = ['-DCUDA_HAS_FP16=1', '-D__CUDA_NO_HALF_OPERATORS__', + '-D__CUDA_NO_HALF_CONVERSIONS__', '-D__CUDA_NO_HALF2_OPERATORS__'] sources = [os.path.join(extensions_dir, s) for s in sources] @@ -67,7 +70,8 @@ def get_extensions(): 'torchvision._C', sources, include_dirs=include_dirs, - define_macros=define_macros + define_macros=define_macros, + extra_compile_args=extra_compile_args, ) ] diff --git a/test/test_layers.py b/test/test_layers.py index 51022fd3881..d508393c64a 100644 --- a/test/test_layers.py +++ b/test/test_layers.py @@ -7,65 +7,184 @@ import unittest -class Tester(unittest.TestCase): +class ROIPoolTester(unittest.TestCase): - def test_roi_align(self): - outputs = [] + def test_roi_pool_basic_cpu(self): dtype = torch.float32 - x = torch.rand(1, 1, 10, 10, dtype=dtype) - rois = torch.tensor([ - [0, 0, 0, 10, 10], - [0, 0, 5, 5, 10], - [0, 5, 5, 10, 10]], dtype=dtype) + device = torch.device('cpu') + x = torch.rand(1, 1, 10, 10, dtype=dtype, device=device) + rois = torch.tensor([[0, 0, 0, 4, 4]], # format is (xyxy) + dtype=dtype, device=device) + + pool_h, pool_w = (5, 5) + roi_pool = layers.ROIPool((pool_h, pool_w), 1) + y = roi_pool(x, rois) + + gt_y = torch.zeros(rois.size(0), x.size(1), pool_h, pool_w) - for device in ['cpu', 'cuda']: - device = torch.device(device) - x_n = x.to(device) - rois_n = rois.to(device) - output = layers.roi_align(x_n, rois_n, (5, 5), 0.5, 1).to('cpu') - outputs.append(output) + for n in range(0, gt_y.size(0)): + start_h, end_h = int(rois[n, 2].item()), int(rois[n, 4].item()) + 1 + start_w, end_w = int(rois[n, 1].item()), int(rois[n, 3].item()) + 1 + roi_x = x[:, :, start_h:end_h, start_w:end_w] + bin_h, bin_w = roi_x.size(2) // pool_h, roi_x.size(3) // pool_w + for j in range(0, pool_h): + for i in range(0, pool_w): + gt_y[n, :, j, i] = torch.max(roi_x[:, :, j * bin_h:(j + 1) * bin_h, i * bin_w:(i + 1) * bin_w]) - assert (outputs[0] - outputs[1]).abs().max() < 1e-6 + assert torch.equal(gt_y, y), 'ROIPool layer incorrect' - def test_roi_align_gradient(self): - dtype = torch.float64 + def test_roi_pool_cpu(self): + dtype = torch.float32 + device = torch.device('cpu') + x = torch.rand(2, 1, 10, 10, dtype=dtype, device=device) + rois = torch.tensor([[0, 0, 0, 9, 9], # format is (xyxy) + [0, 0, 5, 4, 9], + [0, 5, 5, 9, 9], + [1, 0, 0, 9, 9]], + dtype=dtype, device=device) + + pool_h, pool_w = (5, 5) + roi_pool = layers.ROIPool((pool_h, pool_w), 1) + y = roi_pool(x, rois) + + gt_y = torch.zeros(rois.size(0), x.size(1), pool_h, pool_w, device=device) + for n in range(0, gt_y.size(0)): + for r, roi in enumerate(rois): + if roi[0] == n: + start_h, end_h = int(roi[2].item()), int(roi[4].item()) + 1 + start_w, end_w = int(roi[1].item()), int(roi[3].item()) + 1 + roi_x = x[roi[0].long():roi[0].long() + 1, :, start_h:end_h, start_w:end_w] + bin_h, bin_w = roi_x.size(2) // pool_h, roi_x.size(3) // pool_w + for j in range(0, pool_h): + for i in range(0, pool_w): + gt_y[r, :, j, i] = torch.max(gt_y[r, :, j, i], + torch.max(roi_x[:, :, + j * bin_h:(j + 1) * bin_h, + i * bin_w:(i + 1) * bin_w]) + ) + + assert torch.equal(gt_y, y), 'ROIPool layer incorrect' + + def test_roi_pool_gradient_cpu(self): + dtype = torch.float32 + device = torch.device('cpu') + layer = layers.ROIPool((5, 5), 1).to(dtype=dtype, device=device) + x = torch.ones(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=True) + cx = torch.ones(1, 1, 10, 10, dtype=dtype, requires_grad=True).cuda() + rois = torch.tensor([ + [0, 0, 0, 9, 9], + [0, 0, 5, 4, 9], + [0, 0, 0, 4, 4]], + dtype=dtype, device=device) + + y = layer(x, rois) + s = y.sum() + s.backward() + + gt_grad = torch.tensor([[[[2., 1., 2., 1., 2., 0., 1., 0., 1., 0.], + [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.], + [2., 1., 2., 1., 2., 0., 1., 0., 1., 0.], + [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.], + [2., 1., 2., 1., 2., 0., 1., 0., 1., 0.], + [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.], + [2., 1., 2., 1., 2., 0., 1., 0., 1., 0.], + [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.], + [2., 1., 2., 1., 2., 0., 1., 0., 1., 0.], + [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.]]]], device=device) + + assert torch.equal(x.grad, gt_grad), 'gradient incorrect for roi_pool' + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable") + def test_roi_pool_basic_gpu(self): + dtype = torch.float32 device = torch.device('cuda') - m = layers.ROIAlign((5, 5), 0.5, 1).to(dtype=dtype, device=device) x = torch.rand(1, 1, 10, 10, dtype=dtype, device=device) - rois = torch.tensor([ - [0, 0, 0, 10, 10], - [0, 0, 5, 5, 10], - [0, 5, 5, 10, 10]], dtype=dtype, device=device) + rois = torch.tensor([[0, 0, 0, 4, 4]], # format is (xyxy) + dtype=dtype, device=device) - def func(input): - return m(input, rois) + pool_h, pool_w = (5, 5) + roi_pool = layers.ROIPool((pool_h, pool_w), 1) + y = roi_pool(x, rois) + + gt_y = torch.zeros(rois.size(0), x.size(1), pool_h, pool_w) + + for n in range(0, gt_y.size(0)): + start_h, end_h = int(rois[n, 2].item()), int(rois[n, 4].item()) + 1 + start_w, end_w = int(rois[n, 1].item()), int(rois[n, 3].item()) + 1 + roi_x = x[:, :, start_h:end_h, start_w:end_w] + bin_h, bin_w = roi_x.size(2) // pool_h, roi_x.size(3) // pool_w + for j in range(0, pool_h): + for i in range(0, pool_w): + gt_y[n, :, j, i] = torch.max(roi_x[:, :, j * bin_h:(j + 1) * bin_h, i * bin_w:(i + 1) * bin_w]) - assert gradcheck(func, (x,)), 'gradcheck failed for roi_align' + assert torch.equal(gt_y.cuda(), y), 'ROIPool layer incorrect' - def test_roi_pool_gradient(self): - dtype = torch.float64 + @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable") + def test_roi_pool_gpu(self): + dtype = torch.float32 + device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + x = torch.rand(2, 1, 10, 10, dtype=dtype, device=device) + rois = torch.tensor([[0, 0, 0, 9, 9], # format is (xyxy) + [0, 0, 5, 4, 9], + [0, 5, 5, 9, 9], + [1, 0, 0, 9, 9]], + dtype=dtype, device=device) + + pool_h, pool_w = (5, 5) + roi_pool = layers.ROIPool((pool_h, pool_w), 1) + y = roi_pool(x, rois) + + gt_y = torch.zeros(rois.size(0), x.size(1), pool_h, pool_w, device=device) + for n in range(0, gt_y.size(0)): + for r, roi in enumerate(rois): + if roi[0] == n: + start_h, end_h = int(roi[2].item()), int(roi[4].item()) + 1 + start_w, end_w = int(roi[1].item()), int(roi[3].item()) + 1 + roi_x = x[roi[0].long():roi[0].long() + 1, :, start_h:end_h, start_w:end_w] + bin_h, bin_w = roi_x.size(2) // pool_h, roi_x.size(3) // pool_w + for j in range(0, pool_h): + for i in range(0, pool_w): + gt_y[r, :, j, i] = torch.max(gt_y[r, :, j, i], + torch.max(roi_x[:, :, + j * bin_h:(j + 1) * bin_h, + i * bin_w:(i + 1) * bin_w]) + ) + + assert torch.equal(gt_y.cuda(), y), 'ROIPool layer incorrect' + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable") + def test_roi_pool_gradient_gpu(self): + dtype = torch.float32 device = torch.device('cuda') - m = layers.ROIPool((5, 5), 0.5).to(dtype=dtype, device=device) - x = torch.rand(1, 1, 10, 10, dtype=dtype, device=device) + layer = layers.ROIPool((5, 5), 1).to(dtype=dtype, device=device) + x = torch.ones(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=True) rois = torch.tensor([ - [0, 0, 0, 10, 10], - [0, 0, 5, 5, 10], - [0, 5, 5, 10, 10]], dtype=dtype, device=device) + [0, 0, 0, 9, 9], + [0, 0, 5, 4, 9], + [0, 0, 0, 4, 4]], + dtype=dtype, device=device) def func(input): - return m(input, rois) - - assert gradcheck(func, (x,)), 'gradcheck failed for roi_pool' - - def test_nms(self): - boxes = torch.tensor([ - [0, 0, 100, 100], - [2, 2, 98, 98], - [50, 50, 200, 200], - [50, 50, 200, 200]], dtype=torch.float32) - scores = torch.tensor([1, 2, 0.5, 1], dtype=torch.float32) - keep = layers.nms(boxes, scores, 0.5) - assert keep.tolist() == [1, 3] + return layer(input, rois) + + x.requires_grad = True + y = layer(x, rois) + # print(argmax, argmax.shape) + s = y.sum() + s.backward() + gt_grad = torch.tensor([[[[2., 1., 2., 1., 2., 0., 1., 0., 1., 0.], + [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.], + [2., 1., 2., 1., 2., 0., 1., 0., 1., 0.], + [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.], + [2., 1., 2., 1., 2., 0., 1., 0., 1., 0.], + [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.], + [2., 1., 2., 1., 2., 0., 1., 0., 1., 0.], + [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.], + [2., 1., 2., 1., 2., 0., 1., 0., 1., 0.], + [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.]]]], device=device) + + assert torch.equal(x.grad, gt_grad), 'gradient incorrect for roi_pool' + if __name__ == '__main__': unittest.main() diff --git a/torchvision/csrc/ROIPool.h b/torchvision/csrc/ROIPool.h index 6a3951fa2fe..54dceb3ad78 100644 --- a/torchvision/csrc/ROIPool.h +++ b/torchvision/csrc/ROIPool.h @@ -6,42 +6,41 @@ #include "cuda/vision.h" #endif - -std::tuple ROIPool_forward(const at::Tensor& input, - const at::Tensor& rois, - const float spatial_scale, - const int pooled_height, - const int pooled_width) { - if (input.type().is_cuda()) { +std::tuple ROIPool_forward(const at::Tensor &input, + const at::Tensor &rois, + const float spatial_scale, + const int pooled_height, + const int pooled_width) +{ + if (input.type().is_cuda()) + { #ifdef WITH_CUDA - return ROIPool_forward_cuda(input, rois, spatial_scale, pooled_height, pooled_width); + return ROIPool_forward_cuda(input, rois, spatial_scale, pooled_height, pooled_width); #else - AT_ERROR("Not compiled with GPU support"); + AT_ERROR("Not compiled with GPU support"); #endif - } - AT_ERROR("Not implemented on the CPU"); + } + return ROIPool_forward_cpu(input, rois, spatial_scale, pooled_height, pooled_width); } -at::Tensor ROIPool_backward(const at::Tensor& grad, - const at::Tensor& input, - const at::Tensor& rois, - const at::Tensor& argmax, - const float spatial_scale, - const int pooled_height, - const int pooled_width, - const int batch_size, - const int channels, - const int height, - const int width) { - if (grad.type().is_cuda()) { +at::Tensor ROIPool_backward(const at::Tensor &grad, + const at::Tensor &rois, + const at::Tensor &argmax, + const float spatial_scale, + const int pooled_height, + const int pooled_width, + const int batch_size, + const int channels, + const int height, + const int width) +{ + if (grad.type().is_cuda()) + { #ifdef WITH_CUDA - return ROIPool_backward_cuda(grad, input, rois, argmax, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width); + return ROIPool_backward_cuda(grad, rois, argmax, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width); #else - AT_ERROR("Not compiled with GPU support"); + AT_ERROR("Not compiled with GPU support"); #endif - } - AT_ERROR("Not implemented on the CPU"); -} - - - + } + return ROIPool_backward_cpu(grad, rois, argmax, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width); +} \ No newline at end of file diff --git a/torchvision/csrc/cpu/ROIPool_cpu.cpp b/torchvision/csrc/cpu/ROIPool_cpu.cpp new file mode 100644 index 00000000000..5500a63f4ca --- /dev/null +++ b/torchvision/csrc/cpu/ROIPool_cpu.cpp @@ -0,0 +1,152 @@ +#include +#include +#include + +std::tuple ROIPool_forward_cpu(const at::Tensor &input, + const at::Tensor &rois, + const float spatial_scale, + const int pooled_height, + const int pooled_width) +{ + AT_ASSERTM(input.device().is_cpu(), "input must be a CPU tensor"); + AT_ASSERTM(rois.device().is_cpu(), "rois must be a CPU tensor"); + + int num_rois = rois.size(0); + int channels = input.size(1); + int input_height = input.size(2); + int input_width = input.size(3); + + at::Tensor output = input.type().tensor({num_rois, channels, pooled_height, pooled_width}); + at::Tensor argmax = input.type().toScalarType(at::kInt).tensor({num_rois, channels, pooled_height, pooled_width}).zero_(); + + // define accessors for indexing + auto input_a = input.accessor(); + auto rois_a = rois.accessor(); + auto output_a = output.accessor(); + auto argmax_a = argmax.accessor(); + + if (output.numel() == 0) + { + return std::make_tuple(output, argmax); + } + + for (int n = 0; n < num_rois; ++n) + { + int roi_batch_ind = rois_a[n][0]; + int roi_start_w = round(rois_a[n][1] * spatial_scale); + int roi_start_h = round(rois_a[n][2] * spatial_scale); + int roi_end_w = round(rois_a[n][3] * spatial_scale); + int roi_end_h = round(rois_a[n][4] * spatial_scale); + + // Force malformed ROIs to be 1x1 + int roi_width = std::max(roi_end_w - roi_start_w + 1, 1); + int roi_height = std::max(roi_end_h - roi_start_h + 1, 1); + float bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + float bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + for (int ph = 0; ph < pooled_height; ++ph) + { + for (int pw = 0; pw < pooled_width; ++pw) + { + int hstart = static_cast(floor(static_cast(ph) * bin_size_h)); + int wstart = static_cast(floor(static_cast(pw) * bin_size_w)); + int hend = static_cast(ceil(static_cast(ph + 1) * bin_size_h)); + int wend = static_cast(ceil(static_cast(pw + 1) * bin_size_w)); + + // Add roi offsets and clip to input boundaries + hstart = std::min(std::max(hstart + roi_start_h, 0), input_height); + hend = std::min(std::max(hend + roi_start_h, 0), input_height); + wstart = std::min(std::max(wstart + roi_start_w, 0), input_width); + wend = std::min(std::max(wend + roi_start_w, 0), input_width); + bool is_empty = (hend <= hstart) || (wend <= wstart); + + // Define an empty pooling region to be zero + float maxval = is_empty ? 0 : -FLT_MAX; + // If nothing is pooled, argmax = -1 causes nothing to be backprop'd + int maxidx = -1; + + for (int c = 0; c < channels; ++c) + { + for (int h = hstart; h < hend; ++h) + { + for (int w = wstart; w < wend; ++w) + { + int index = h * input_width + w; + if (input_a[roi_batch_ind][c][h][w] > maxval) + { + maxval = input_a[roi_batch_ind][c][h][w]; + maxidx = index; + } + } + } + output_a[n][c][ph][pw] = maxval; + argmax_a[n][c][ph][pw] = maxidx; + } + } + } + } + + return std::make_tuple(output, argmax); +} + +at::Tensor ROIPool_backward_cpu(const at::Tensor &grad, + const at::Tensor &rois, + const at::Tensor &argmax, + const float spatial_scale, + const int pooled_height, + const int pooled_width, + const int batch_size, + const int channels, + const int height, + const int width) +{ + // Check if input tensors are CPU tensors + AT_ASSERTM(grad.device().is_cpu(), "grad must be a CPU tensor"); + AT_ASSERTM(rois.device().is_cpu(), "rois must be a CPU tensor"); + AT_ASSERTM(argmax.device().is_cpu(), "argmax must be a CPU tensor"); + + auto num_rois = rois.size(0); + + at::Tensor grad_input = at::zeros({batch_size, channels, height, width}, grad.type()); + + // handle possibly empty gradients + if (grad.numel() == 0) + { + return grad_input; + } + + // get stride values to ensure indexing into gradients is correct. + int n_stride = grad.stride(0); + int c_stride = grad.stride(1); + int h_stride = grad.stride(2); + int w_stride = grad.stride(3); + + // define accessors for tensors + auto grad_input_a = grad_input.accessor(); + auto grad_a = grad.accessor(); + auto argmax_a = argmax.accessor(); + auto rois_a = rois.accessor(); + + for (int n = 0; n < num_rois; ++n) + { + int roi_batch_ind = rois_a[n][0]; + + for (int c = 0; c < channels; ++c) + { + for (int ph = 0; ph < pooled_height; ++ph) + { + for (int pw = 0; pw < pooled_width; ++pw) + { + int argmax_idx = argmax_a[n][c][ph][pw]; + // get height and width index from argmax index + int h = argmax_idx / height; + int w = argmax_idx % width; + + grad_input_a[roi_batch_ind][c][h][w] += grad_a[n * n_stride][c * c_stride][ph * h_stride][pw * w_stride]; + } + } + } + } + + return grad_input; +} \ No newline at end of file diff --git a/torchvision/csrc/cpu/vision.h b/torchvision/csrc/cpu/vision.h index 9b69073693c..eebf4b95ad5 100644 --- a/torchvision/csrc/cpu/vision.h +++ b/torchvision/csrc/cpu/vision.h @@ -1,15 +1,30 @@ #pragma once #include +std::tuple ROIPool_forward_cpu(const at::Tensor &input, + const at::Tensor &rois, + const float spatial_scale, + const int pooled_height, + const int pooled_width); -at::Tensor ROIAlign_forward_cpu(const at::Tensor& input, - const at::Tensor& rois, +at::Tensor ROIPool_backward_cpu(const at::Tensor &grad, + const at::Tensor &rois, + const at::Tensor &argmax, const float spatial_scale, const int pooled_height, const int pooled_width, - const int sampling_ratio); + const int batch_size, + const int channels, + const int height, + const int width); +at::Tensor ROIAlign_forward_cpu(const at::Tensor &input, + const at::Tensor &rois, + const float spatial_scale, + const int pooled_height, + const int pooled_width, + const int sampling_ratio); -at::Tensor nms_cpu(const at::Tensor& dets, - const at::Tensor& scores, +at::Tensor nms_cpu(const at::Tensor &dets, + const at::Tensor &scores, const float threshold); diff --git a/torchvision/csrc/cuda/ROIAlign_cuda.cu b/torchvision/csrc/cuda/ROIAlign_cuda.cu index bc94c8017be..e8408c4ee06 100644 --- a/torchvision/csrc/cuda/ROIAlign_cuda.cu +++ b/torchvision/csrc/cuda/ROIAlign_cuda.cu @@ -1,4 +1,5 @@ #include +#include #include #include @@ -269,7 +270,7 @@ at::Tensor ROIAlign_forward_cuda(const at::Tensor& input, at::Tensor output = input.type().tensor({num_rois, channels, pooled_height, pooled_width}); auto output_size = num_rois * pooled_height * pooled_width * channels; - cudaStream_t stream = at::globalContext().getCurrentCUDAStream(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); dim3 grid(std::min(THCCeilDiv(output_size, 512L), 4096L)); dim3 block(512); @@ -314,7 +315,7 @@ at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad, auto num_rois = rois.size(0); at::Tensor grad_input = grad.type().tensor({batch_size, channels, height, width}).zero_(); - cudaStream_t stream = at::globalContext().getCurrentCUDAStream(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); dim3 grid(std::min(THCCeilDiv(grad.numel(), 512L), 4096L)); dim3 block(512); diff --git a/torchvision/csrc/cuda/ROIPool_cuda.cu b/torchvision/csrc/cuda/ROIPool_cuda.cu index 57da27b8144..29d9c9c9319 100644 --- a/torchvision/csrc/cuda/ROIPool_cuda.cu +++ b/torchvision/csrc/cuda/ROIPool_cuda.cu @@ -1,21 +1,19 @@ #include +#include #include #include #include - -// TODO make it in a common file -#define CUDA_1D_KERNEL_LOOP(i, n) \ - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \ - i += blockDim.x * gridDim.x) +#include "cuda_helpers.h" +#include template -__global__ void RoIPoolFForward(const int nthreads, const T* bottom_data, +__global__ void RoIPoolForward(const int nthreads, const T* input, const T spatial_scale, const int channels, const int height, const int width, const int pooled_height, const int pooled_width, - const T* bottom_rois, T* top_data, int* argmax_data) { + const T* rois, T* output, int* argmax_data) { CUDA_1D_KERNEL_LOOP(index, nthreads) { // (n, c, ph, pw) is an element in the pooled output int pw = index % pooled_width; @@ -23,12 +21,12 @@ __global__ void RoIPoolFForward(const int nthreads, const T* bottom_data, int c = (index / pooled_width / pooled_height) % channels; int n = index / pooled_width / pooled_height / channels; - const T* offset_bottom_rois = bottom_rois + n * 5; - int roi_batch_ind = offset_bottom_rois[0]; - int roi_start_w = round(offset_bottom_rois[1] * spatial_scale); - int roi_start_h = round(offset_bottom_rois[2] * spatial_scale); - int roi_end_w = round(offset_bottom_rois[3] * spatial_scale); - int roi_end_h = round(offset_bottom_rois[4] * spatial_scale); + const T* offset_rois = rois + n * 5; + int roi_batch_ind = offset_rois[0]; + int roi_start_w = round(offset_rois[1] * spatial_scale); + int roi_start_h = round(offset_rois[2] * spatial_scale); + int roi_end_w = round(offset_rois[3] * spatial_scale); + int roi_end_h = round(offset_rois[4] * spatial_scale); // Force malformed ROIs to be 1x1 int roi_width = max(roi_end_w - roi_start_w + 1, 1); @@ -58,51 +56,51 @@ __global__ void RoIPoolFForward(const int nthreads, const T* bottom_data, T maxval = is_empty ? 0 : -FLT_MAX; // If nothing is pooled, argmax = -1 causes nothing to be backprop'd int maxidx = -1; - const T* offset_bottom_data = - bottom_data + (roi_batch_ind * channels + c) * height * width; + const T* offset_input = + input + (roi_batch_ind * channels + c) * height * width; for (int h = hstart; h < hend; ++h) { for (int w = wstart; w < wend; ++w) { - int bottom_index = h * width + w; - if (offset_bottom_data[bottom_index] > maxval) { - maxval = offset_bottom_data[bottom_index]; - maxidx = bottom_index; + int input_index = h * width + w; + if (offset_input[input_index] > maxval) { + maxval = offset_input[input_index]; + maxidx = input_index; } } } - top_data[index] = maxval; + output[index] = maxval; argmax_data[index] = maxidx; } } template -__global__ void RoIPoolFBackward(const int nthreads, const T* top_diff, +__global__ void RoIPoolBackward(const int nthreads, const T* grad_output, const int* argmax_data, const int num_rois, const T spatial_scale, const int channels, const int height, const int width, - const int pooled_height, const int pooled_width, T* bottom_diff, - const T* bottom_rois) { - CUDA_1D_KERNEL_LOOP(index, nthreads) { - // (n, c, ph, pw) is an element in the pooled output - int pw = index % pooled_width; - int ph = (index / pooled_width) % pooled_height; - int c = (index / pooled_width / pooled_height) % channels; - int n = index / pooled_width / pooled_height / channels; - - const T* offset_bottom_rois = bottom_rois + n * 5; - int roi_batch_ind = offset_bottom_rois[0]; - int bottom_offset = (roi_batch_ind * channels + c) * height * width; - int top_offset = (n * channels + c) * pooled_height * pooled_width; - const T* offset_top_diff = top_diff + top_offset; - T* offset_bottom_diff = bottom_diff + bottom_offset; - const int* offset_argmax_data = argmax_data + top_offset; - - int argmax = offset_argmax_data[ph * pooled_width + pw]; - if (argmax != -1) { - atomicAdd( - offset_bottom_diff + argmax, - static_cast(offset_top_diff[ph * pooled_width + pw])); - + const int pooled_height, const int pooled_width, + T* grad_input, const T* rois, + const int n_stride, const int c_stride, + const int h_stride, const int w_stride) { + + CUDA_1D_KERNEL_LOOP(index, nthreads) { + // (n, c, ph, pw) is an element in the pooled output + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + + const T* offset_rois = rois + n * 5; + int roi_batch_ind = offset_rois[0]; + T* grad_input_offset = grad_input + ((roi_batch_ind * channels + c) * height * width); + + int output_offset = n*n_stride + c*c_stride; + const int* argmax_data_offset = argmax_data + n*channels*pooled_height*pooled_width; + int argmax = argmax_data_offset[c*pooled_height*pooled_width + ph*pooled_width + pw]; + + if (argmax != -1) { + atomicAdd(grad_input_offset + argmax, + static_cast(grad_output[output_offset + ph*h_stride + pw*w_stride])); + } } - } } std::tuple ROIPool_forward_cuda(const at::Tensor& input, @@ -122,7 +120,7 @@ std::tuple ROIPool_forward_cuda(const at::Tensor& input, at::Tensor argmax = input.type().toScalarType(at::kInt).tensor({num_rois, channels, pooled_height, pooled_width}).zero_(); auto output_size = num_rois * pooled_height * pooled_width * channels; - cudaStream_t stream = at::globalContext().getCurrentCUDAStream(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); dim3 grid(std::min(THCCeilDiv(output_size, 512L), 4096L)); dim3 block(512); @@ -132,8 +130,8 @@ std::tuple ROIPool_forward_cuda(const at::Tensor& input, return std::make_tuple(output, argmax); } - AT_DISPATCH_FLOATING_TYPES(input.type(), "ROIPool_forward", [&] { - RoIPoolFForward<<>>( + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "ROIPool_forward", [&] { + RoIPoolForward<<>>( output_size, input.data(), spatial_scale, @@ -150,9 +148,7 @@ std::tuple ROIPool_forward_cuda(const at::Tensor& input, return std::make_tuple(output, argmax); } -// TODO remove the dependency on input and use instead its sizes -> save memory at::Tensor ROIPool_backward_cuda(const at::Tensor& grad, - const at::Tensor& input, const at::Tensor& rois, const at::Tensor& argmax, const float spatial_scale, @@ -162,14 +158,16 @@ at::Tensor ROIPool_backward_cuda(const at::Tensor& grad, const int channels, const int height, const int width) { + // Check if input tensors are CUDA tensors AT_ASSERTM(grad.type().is_cuda(), "grad must be a CUDA tensor"); AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor"); - // TODO add more checks + AT_ASSERTM(argmax.type().is_cuda(), "argmax must be a CUDA tensor"); auto num_rois = rois.size(0); - at::Tensor grad_input = grad.type().tensor({batch_size, channels, height, width}).zero_(); + + at::Tensor grad_input = at::zeros({batch_size, channels, height, width}, grad.type()); - cudaStream_t stream = at::globalContext().getCurrentCUDAStream(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); dim3 grid(std::min(THCCeilDiv(grad.numel(), 512L), 4096L)); dim3 block(512); @@ -179,9 +177,14 @@ at::Tensor ROIPool_backward_cuda(const at::Tensor& grad, THCudaCheck(cudaGetLastError()); return grad_input; } - - AT_DISPATCH_FLOATING_TYPES(grad.type(), "ROIPool_backward", [&] { - RoIPoolFBackward<<>>( + + int n_stride = grad.stride(0); + int c_stride = grad.stride(1); + int h_stride = grad.stride(2); + int w_stride = grad.stride(3); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad.type(), "ROIPool_backward", [&] { + RoIPoolBackward<<>>( grad.numel(), grad.data(), argmax.data(), @@ -193,7 +196,11 @@ at::Tensor ROIPool_backward_cuda(const at::Tensor& grad, pooled_height, pooled_width, grad_input.data(), - rois.data()); + rois.data(), + n_stride, + c_stride, + h_stride, + w_stride); }); THCudaCheck(cudaGetLastError()); return grad_input; diff --git a/torchvision/csrc/cuda/cuda_helpers.h b/torchvision/csrc/cuda/cuda_helpers.h new file mode 100644 index 00000000000..15fb7f6031a --- /dev/null +++ b/torchvision/csrc/cuda/cuda_helpers.h @@ -0,0 +1,8 @@ +#ifndef CUDA_HELPERS_H +#define CUDA_HELPERS_H + +#define CUDA_1D_KERNEL_LOOP(i, n) \ + for (int i = (blockIdx.x * blockDim.x) + threadIdx.x; i < (n); \ + i += (blockDim.x * gridDim.x)) + +#endif // CUDA_HELPERS_H \ No newline at end of file diff --git a/torchvision/csrc/cuda/vision.h b/torchvision/csrc/cuda/vision.h index e7ce0ae720e..c9e738e175c 100644 --- a/torchvision/csrc/cuda/vision.h +++ b/torchvision/csrc/cuda/vision.h @@ -1,16 +1,15 @@ #pragma once #include - -at::Tensor ROIAlign_forward_cuda(const at::Tensor& input, - const at::Tensor& rois, +at::Tensor ROIAlign_forward_cuda(const at::Tensor &input, + const at::Tensor &rois, const float spatial_scale, const int pooled_height, const int pooled_width, const int sampling_ratio); -at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad, - const at::Tensor& rois, +at::Tensor ROIAlign_backward_cuda(const at::Tensor &grad, + const at::Tensor &rois, const float spatial_scale, const int pooled_height, const int pooled_width, @@ -20,21 +19,19 @@ at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad, const int width, const int sampling_ratio); +std::tuple ROIPool_forward_cuda(const at::Tensor &input, + const at::Tensor &rois, + const float spatial_scale, + const int pooled_height, + const int pooled_width); -std::tuple ROIPool_forward_cuda(const at::Tensor& input, - const at::Tensor& rois, - const float spatial_scale, - const int pooled_height, - const int pooled_width); - -at::Tensor ROIPool_backward_cuda(const at::Tensor& grad, - const at::Tensor& input, - const at::Tensor& rois, - const at::Tensor& argmax, +at::Tensor ROIPool_backward_cuda(const at::Tensor &grad, + const at::Tensor &rois, + const at::Tensor &argmax, const float spatial_scale, const int pooled_height, const int pooled_width, const int batch_size, const int channels, const int height, - const int width); + const int width); \ No newline at end of file diff --git a/torchvision/layers/roi_pool.py b/torchvision/layers/roi_pool.py index 510911571be..f232d0cc160 100644 --- a/torchvision/layers/roi_pool.py +++ b/torchvision/layers/roi_pool.py @@ -16,21 +16,21 @@ def forward(ctx, input, roi, output_size, spatial_scale): ctx.spatial_scale = spatial_scale ctx.input_shape = input.size() output, argmax = _C.roi_pool_forward( - input, roi, spatial_scale, - output_size[0], output_size[1]) - ctx.save_for_backward(input, roi, argmax) + input, roi, spatial_scale, + output_size[0], output_size[1]) + ctx.save_for_backward(roi, argmax) return output @staticmethod @once_differentiable def backward(ctx, grad_output): - input, rois, argmax = ctx.saved_tensors + rois, argmax = ctx.saved_tensors output_size = ctx.output_size spatial_scale = ctx.spatial_scale bs, ch, h, w = ctx.input_shape grad_input = _C.roi_pool_backward( - grad_output, input, rois, argmax, spatial_scale, - output_size[0], output_size[1], bs, ch, h, w) + grad_output, rois, argmax, spatial_scale, + output_size[0], output_size[1], bs, ch, h, w) return grad_input, None, None, None roi_pool = _ROIPool.apply