diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index c6cbb42ecfc..c1b81a54dd7 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -10,37 +10,45 @@ Transforms are common image transforms. They can be chained together using :clas Transforms on PIL Image ----------------------- -.. autoclass:: Resize +.. autoclass:: CenterCrop -.. autoclass:: Scale +.. autoclass:: ColorJitter -.. autoclass:: CenterCrop +.. autoclass:: FiveCrop -.. autoclass:: RandomCrop +.. autoclass:: Grayscale -.. autoclass:: RandomHorizontalFlip +.. autoclass:: LinearTransformation -.. autoclass:: RandomVerticalFlip +.. autoclass:: Pad -.. autoclass:: RandomResizedCrop +.. autoclass:: RandomAffine -.. autoclass:: RandomSizedCrop +.. autoclass:: RandomApply -.. autoclass:: Grayscale +.. autoclass:: RandomChoice + +.. autoclass:: RandomCrop .. autoclass:: RandomGrayscale -.. autoclass:: FiveCrop +.. autoclass:: RandomHorizontalFlip -.. autoclass:: TenCrop +.. autoclass:: RandomOrder -.. autoclass:: Pad - -.. autoclass:: ColorJitter +.. autoclass:: RandomResizedCrop .. autoclass:: RandomRotation -.. autoclass:: RandomAffine +.. autoclass:: RandomSizedCrop + +.. autoclass:: RandomVerticalFlip + +.. autoclass:: Resize + +.. autoclass:: Scale + +.. autoclass:: TenCrop Transforms on torch.\*Tensor ---------------------------- @@ -53,11 +61,11 @@ Transforms on torch.\*Tensor Conversion Transforms --------------------- -.. autoclass:: ToTensor +.. autoclass:: ToPILImage :members: __call__ :special-members: -.. autoclass:: ToPILImage +.. autoclass:: ToTensor :members: __call__ :special-members: @@ -66,3 +74,9 @@ Generic Transforms .. autoclass:: Lambda + +Functional Transforms +--------------------- + +.. automodule:: torchvision.transforms.functional + :members: diff --git a/setup.py b/setup.py index 0f46586deec..6c271c72bb4 100644 --- a/setup.py +++ b/setup.py @@ -5,6 +5,11 @@ import shutil import sys from setuptools import setup, find_packages +from pkg_resources import get_distribution, DistributionNotFound +import glob + +import torch +from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME def read(*names, **kwargs): @@ -15,6 +20,13 @@ def read(*names, **kwargs): return fp.read() +def get_dist(pkgname): + try: + return get_distribution(pkgname) + except DistributionNotFound: + return None + + def find_version(*file_paths): version_file = read(*file_paths) version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", @@ -30,11 +42,52 @@ def find_version(*file_paths): requirements = [ 'numpy', - 'pillow >= 4.1.1', 'six', 'torch', ] +pillow_ver = ' >= 4.1.1' +pillow_req = 'pillow-simd' if get_dist('pillow-simd') is not None else 'pillow' +requirements.append(pillow_req + pillow_ver) + +tqdm_ver = ' == 4.19.9' if sys.version_info[0] < 3 else '' +requirements.append('tqdm' + tqdm_ver) + + +def get_extensions(): + this_dir = os.path.dirname(os.path.abspath(__file__)) + extensions_dir = os.path.join(this_dir, 'torchvision', 'csrc') + + main_file = glob.glob(os.path.join(extensions_dir, '*.cpp')) + source_cpu = glob.glob(os.path.join(extensions_dir, 'cpu', '*.cpp')) + source_cuda = glob.glob(os.path.join(extensions_dir, 'cuda', '*.cu')) + + sources = main_file + source_cpu + extension = CppExtension + + extra_cflags = [] + define_macros = [] + + if torch.cuda.is_available() and CUDA_HOME is not None: + extension = CUDAExtension + sources += source_cuda + define_macros += [('WITH_CUDA', None)] + + sources = [os.path.join(extensions_dir, s) for s in sources] + + include_dirs = [extensions_dir] + + ext_modules = [ + extension( + 'torchvision._C', + sources, + include_dirs=include_dirs, + define_macros=define_macros + ) + ] + + return ext_modules + setup( # Metadata name='torchvision', @@ -51,4 +104,7 @@ def find_version(*file_paths): zip_safe=True, install_requires=requirements, + + ext_modules=get_extensions(), + cmdclass={'build_ext': torch.utils.cpp_extension.BuildExtension} ) diff --git a/test/test_layers.py b/test/test_layers.py new file mode 100644 index 00000000000..d508393c64a --- /dev/null +++ b/test/test_layers.py @@ -0,0 +1,190 @@ +import torch +from torch.autograd import gradcheck + +from torchvision import layers + + +import unittest + + +class ROIPoolTester(unittest.TestCase): + + def test_roi_pool_basic_cpu(self): + dtype = torch.float32 + device = torch.device('cpu') + x = torch.rand(1, 1, 10, 10, dtype=dtype, device=device) + rois = torch.tensor([[0, 0, 0, 4, 4]], # format is (xyxy) + dtype=dtype, device=device) + + pool_h, pool_w = (5, 5) + roi_pool = layers.ROIPool((pool_h, pool_w), 1) + y = roi_pool(x, rois) + + gt_y = torch.zeros(rois.size(0), x.size(1), pool_h, pool_w) + + for n in range(0, gt_y.size(0)): + start_h, end_h = int(rois[n, 2].item()), int(rois[n, 4].item()) + 1 + start_w, end_w = int(rois[n, 1].item()), int(rois[n, 3].item()) + 1 + roi_x = x[:, :, start_h:end_h, start_w:end_w] + bin_h, bin_w = roi_x.size(2) // pool_h, roi_x.size(3) // pool_w + for j in range(0, pool_h): + for i in range(0, pool_w): + gt_y[n, :, j, i] = torch.max(roi_x[:, :, j * bin_h:(j + 1) * bin_h, i * bin_w:(i + 1) * bin_w]) + + assert torch.equal(gt_y, y), 'ROIPool layer incorrect' + + def test_roi_pool_cpu(self): + dtype = torch.float32 + device = torch.device('cpu') + x = torch.rand(2, 1, 10, 10, dtype=dtype, device=device) + rois = torch.tensor([[0, 0, 0, 9, 9], # format is (xyxy) + [0, 0, 5, 4, 9], + [0, 5, 5, 9, 9], + [1, 0, 0, 9, 9]], + dtype=dtype, device=device) + + pool_h, pool_w = (5, 5) + roi_pool = layers.ROIPool((pool_h, pool_w), 1) + y = roi_pool(x, rois) + + gt_y = torch.zeros(rois.size(0), x.size(1), pool_h, pool_w, device=device) + for n in range(0, gt_y.size(0)): + for r, roi in enumerate(rois): + if roi[0] == n: + start_h, end_h = int(roi[2].item()), int(roi[4].item()) + 1 + start_w, end_w = int(roi[1].item()), int(roi[3].item()) + 1 + roi_x = x[roi[0].long():roi[0].long() + 1, :, start_h:end_h, start_w:end_w] + bin_h, bin_w = roi_x.size(2) // pool_h, roi_x.size(3) // pool_w + for j in range(0, pool_h): + for i in range(0, pool_w): + gt_y[r, :, j, i] = torch.max(gt_y[r, :, j, i], + torch.max(roi_x[:, :, + j * bin_h:(j + 1) * bin_h, + i * bin_w:(i + 1) * bin_w]) + ) + + assert torch.equal(gt_y, y), 'ROIPool layer incorrect' + + def test_roi_pool_gradient_cpu(self): + dtype = torch.float32 + device = torch.device('cpu') + layer = layers.ROIPool((5, 5), 1).to(dtype=dtype, device=device) + x = torch.ones(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=True) + cx = torch.ones(1, 1, 10, 10, dtype=dtype, requires_grad=True).cuda() + rois = torch.tensor([ + [0, 0, 0, 9, 9], + [0, 0, 5, 4, 9], + [0, 0, 0, 4, 4]], + dtype=dtype, device=device) + + y = layer(x, rois) + s = y.sum() + s.backward() + + gt_grad = torch.tensor([[[[2., 1., 2., 1., 2., 0., 1., 0., 1., 0.], + [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.], + [2., 1., 2., 1., 2., 0., 1., 0., 1., 0.], + [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.], + [2., 1., 2., 1., 2., 0., 1., 0., 1., 0.], + [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.], + [2., 1., 2., 1., 2., 0., 1., 0., 1., 0.], + [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.], + [2., 1., 2., 1., 2., 0., 1., 0., 1., 0.], + [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.]]]], device=device) + + assert torch.equal(x.grad, gt_grad), 'gradient incorrect for roi_pool' + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable") + def test_roi_pool_basic_gpu(self): + dtype = torch.float32 + device = torch.device('cuda') + x = torch.rand(1, 1, 10, 10, dtype=dtype, device=device) + rois = torch.tensor([[0, 0, 0, 4, 4]], # format is (xyxy) + dtype=dtype, device=device) + + pool_h, pool_w = (5, 5) + roi_pool = layers.ROIPool((pool_h, pool_w), 1) + y = roi_pool(x, rois) + + gt_y = torch.zeros(rois.size(0), x.size(1), pool_h, pool_w) + + for n in range(0, gt_y.size(0)): + start_h, end_h = int(rois[n, 2].item()), int(rois[n, 4].item()) + 1 + start_w, end_w = int(rois[n, 1].item()), int(rois[n, 3].item()) + 1 + roi_x = x[:, :, start_h:end_h, start_w:end_w] + bin_h, bin_w = roi_x.size(2) // pool_h, roi_x.size(3) // pool_w + for j in range(0, pool_h): + for i in range(0, pool_w): + gt_y[n, :, j, i] = torch.max(roi_x[:, :, j * bin_h:(j + 1) * bin_h, i * bin_w:(i + 1) * bin_w]) + + assert torch.equal(gt_y.cuda(), y), 'ROIPool layer incorrect' + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable") + def test_roi_pool_gpu(self): + dtype = torch.float32 + device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + x = torch.rand(2, 1, 10, 10, dtype=dtype, device=device) + rois = torch.tensor([[0, 0, 0, 9, 9], # format is (xyxy) + [0, 0, 5, 4, 9], + [0, 5, 5, 9, 9], + [1, 0, 0, 9, 9]], + dtype=dtype, device=device) + + pool_h, pool_w = (5, 5) + roi_pool = layers.ROIPool((pool_h, pool_w), 1) + y = roi_pool(x, rois) + + gt_y = torch.zeros(rois.size(0), x.size(1), pool_h, pool_w, device=device) + for n in range(0, gt_y.size(0)): + for r, roi in enumerate(rois): + if roi[0] == n: + start_h, end_h = int(roi[2].item()), int(roi[4].item()) + 1 + start_w, end_w = int(roi[1].item()), int(roi[3].item()) + 1 + roi_x = x[roi[0].long():roi[0].long() + 1, :, start_h:end_h, start_w:end_w] + bin_h, bin_w = roi_x.size(2) // pool_h, roi_x.size(3) // pool_w + for j in range(0, pool_h): + for i in range(0, pool_w): + gt_y[r, :, j, i] = torch.max(gt_y[r, :, j, i], + torch.max(roi_x[:, :, + j * bin_h:(j + 1) * bin_h, + i * bin_w:(i + 1) * bin_w]) + ) + + assert torch.equal(gt_y.cuda(), y), 'ROIPool layer incorrect' + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable") + def test_roi_pool_gradient_gpu(self): + dtype = torch.float32 + device = torch.device('cuda') + layer = layers.ROIPool((5, 5), 1).to(dtype=dtype, device=device) + x = torch.ones(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=True) + rois = torch.tensor([ + [0, 0, 0, 9, 9], + [0, 0, 5, 4, 9], + [0, 0, 0, 4, 4]], + dtype=dtype, device=device) + + def func(input): + return layer(input, rois) + + x.requires_grad = True + y = layer(x, rois) + # print(argmax, argmax.shape) + s = y.sum() + s.backward() + gt_grad = torch.tensor([[[[2., 1., 2., 1., 2., 0., 1., 0., 1., 0.], + [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.], + [2., 1., 2., 1., 2., 0., 1., 0., 1., 0.], + [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.], + [2., 1., 2., 1., 2., 0., 1., 0., 1., 0.], + [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.], + [2., 1., 2., 1., 2., 0., 1., 0., 1., 0.], + [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.], + [2., 1., 2., 1., 2., 0., 1., 0., 1., 0.], + [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.]]]], device=device) + + assert torch.equal(x.grad, gt_grad), 'gradient incorrect for roi_pool' + + +if __name__ == '__main__': + unittest.main() diff --git a/test/test_transforms.py b/test/test_transforms.py index e2232e2491b..34024acc6dd 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -245,7 +245,7 @@ def test_pad_with_tuple_of_pad_values(self): def test_pad_with_non_constant_padding_modes(self): """Unit tests for edge, reflect, symmetric padding""" - img = torch.zeros(3, 27, 27) + img = torch.zeros(3, 27, 27).byte() img[:, :, 0] = 1 # Constant value added to leftmost edge img = transforms.ToPILImage()(img) img = F.pad(img, 1, (200, 200, 200)) @@ -255,7 +255,7 @@ def test_pad_with_non_constant_padding_modes(self): # First 6 elements of leftmost edge in the middle of the image, values are in order: # edge_pad, edge_pad, edge_pad, constant_pad, constant value added to leftmost edge, 0 edge_middle_slice = np.asarray(edge_padded_img).transpose(2, 0, 1)[0][17][:6] - assert np.all(edge_middle_slice == np.asarray([200, 200, 200, 200, 255, 0])) + assert np.all(edge_middle_slice == np.asarray([200, 200, 200, 200, 1, 0])) assert transforms.ToTensor()(edge_padded_img).size() == (3, 35, 35) # Pad 3 to left/right, 2 to top/bottom @@ -263,7 +263,7 @@ def test_pad_with_non_constant_padding_modes(self): # First 6 elements of leftmost edge in the middle of the image, values are in order: # reflect_pad, reflect_pad, reflect_pad, constant_pad, constant value added to leftmost edge, 0 reflect_middle_slice = np.asarray(reflect_padded_img).transpose(2, 0, 1)[0][17][:6] - assert np.all(reflect_middle_slice == np.asarray([0, 0, 255, 200, 255, 0])) + assert np.all(reflect_middle_slice == np.asarray([0, 0, 1, 200, 1, 0])) assert transforms.ToTensor()(reflect_padded_img).size() == (3, 33, 35) # Pad 3 to left, 2 to top, 2 to right, 1 to bottom @@ -271,7 +271,7 @@ def test_pad_with_non_constant_padding_modes(self): # First 6 elements of leftmost edge in the middle of the image, values are in order: # sym_pad, sym_pad, sym_pad, constant_pad, constant value added to leftmost edge, 0 symmetric_middle_slice = np.asarray(symmetric_padded_img).transpose(2, 0, 1)[0][17][:6] - assert np.all(symmetric_middle_slice == np.asarray([0, 255, 200, 200, 255, 0])) + assert np.all(symmetric_middle_slice == np.asarray([0, 1, 200, 200, 1, 0])) assert transforms.ToTensor()(symmetric_padded_img).size() == (3, 32, 34) def test_pad_raises_with_invalid_pad_sequence_len(self): @@ -404,6 +404,12 @@ def test_to_tensor(self): expected_output = ndarray.transpose((2, 0, 1)) assert np.allclose(output.numpy(), expected_output) + # separate test for mode '1' PIL images + input_data = torch.ByteTensor(1, height, width).bernoulli_() + img = transforms.ToPILImage()(input_data.mul(255)).convert('1') + output = trans(img) + assert np.allclose(input_data.numpy(), output.numpy()) + @unittest.skipIf(accimage is None, 'accimage not available') def test_accimage_to_tensor(self): trans = transforms.ToTensor() @@ -646,7 +652,7 @@ def test_random_horizontal_flip(self): # Checking if RandomHorizontalFlip can be printed as string transforms.RandomHorizontalFlip().__repr__() - @unittest.skipIf(stats is None, 'scipt.stats is not available') + @unittest.skipIf(stats is None, 'scipy.stats is not available') def test_normalize(self): def samples_from_standard_normal(tensor): p_value = stats.kstest(list(tensor.view(-1)), 'norm', args=(0, 1)).pvalue diff --git a/test/test_utils.py b/test/test_utils.py index df6ae972bda..2f8392e0a94 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -18,14 +18,6 @@ def test_make_grid_not_inplace(self): utils.make_grid(t, normalize=True, scale_each=True) assert torch.equal(t, t_clone), 'make_grid modified tensor in-place' - def test_make_grid_raises_with_variable(self): - t = torch.autograd.Variable(torch.rand(3, 10, 10)) - with self.assertRaises(TypeError): - utils.make_grid(t) - - with self.assertRaises(TypeError): - utils.make_grid([t, t, t, t]) - if __name__ == '__main__': unittest.main() diff --git a/torchvision/__init__.py b/torchvision/__init__.py index 87b3dd9df1e..64ff27e1a07 100644 --- a/torchvision/__init__.py +++ b/torchvision/__init__.py @@ -3,7 +3,7 @@ from torchvision import transforms from torchvision import utils -__version__ = '0.2.0' +__version__ = '0.2.1' _image_backend = 'PIL' diff --git a/torchvision/csrc/ROIPool.h b/torchvision/csrc/ROIPool.h new file mode 100644 index 00000000000..bd15a9e70fd --- /dev/null +++ b/torchvision/csrc/ROIPool.h @@ -0,0 +1,46 @@ +#pragma once + +#include "cpu/vision.h" + +#ifdef WITH_CUDA +#include "cuda/vision.h" +#endif + +std::tuple ROIPool_forward(const at::Tensor &input, + const at::Tensor &rois, + const float spatial_scale, + const int pooled_height, + const int pooled_width) +{ + if (input.type().is_cuda()) + { +#ifdef WITH_CUDA + return ROIPool_forward_cuda(input, rois, spatial_scale, pooled_height, pooled_width); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + return ROIPool_forward_cpu(input, rois, spatial_scale, pooled_height, pooled_width); +} + +at::Tensor ROIPool_backward(const at::Tensor &grad, + const at::Tensor &rois, + const at::Tensor &argmax, + const float spatial_scale, + const int pooled_height, + const int pooled_width, + const int batch_size, + const int channels, + const int height, + const int width) +{ + if (grad.type().is_cuda()) + { +#ifdef WITH_CUDA + return ROIPool_backward_cuda(grad, rois, argmax, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + return ROIPool_backward_cpu(grad, rois, argmax, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width); +} \ No newline at end of file diff --git a/torchvision/csrc/cpu/ROIPool_cpu.cpp b/torchvision/csrc/cpu/ROIPool_cpu.cpp new file mode 100644 index 00000000000..0f587d4323f --- /dev/null +++ b/torchvision/csrc/cpu/ROIPool_cpu.cpp @@ -0,0 +1,152 @@ +#include +#include +#include + +std::tuple ROIPool_forward_cpu(const at::Tensor &input, + const at::Tensor &rois, + const float spatial_scale, + const int pooled_height, + const int pooled_width) +{ + AT_ASSERTM(input.device().is_cpu(), "input must be a CPU tensor"); + AT_ASSERTM(rois.device().is_cpu(), "rois must be a CPU tensor"); + + int num_rois = rois.size(0); + int channels = input.size(1); + int input_height = input.size(2); + int input_width = input.size(3); + + at::Tensor output = input.type().tensor({num_rois, channels, pooled_height, pooled_width}); + at::Tensor argmax = input.type().toScalarType(at::kInt).tensor({num_rois, channels, pooled_height, pooled_width}).zero_(); + + // define accessors for indexing + auto input_a = input.accessor(); + auto rois_a = rois.accessor(); + auto output_a = output.accessor(); + auto argmax_a = argmax.accessor(); + + if (output.numel() == 0) + { + return std::make_tuple(output, argmax); + } + + for (int n = 0; n < num_rois; ++n) + { + int roi_batch_ind = rois_a[n][0]; + int roi_start_w = round(rois_a[n][1] * spatial_scale); + int roi_start_h = round(rois_a[n][2] * spatial_scale); + int roi_end_w = round(rois_a[n][3] * spatial_scale); + int roi_end_h = round(rois_a[n][4] * spatial_scale); + + // Force malformed ROIs to be 1x1 or HxW + int roi_width = std::max(roi_end_w - roi_start_w + 1, 1); + int roi_height = std::max(roi_end_h - roi_start_h + 1, 1); + float bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + float bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + for (int ph = 0; ph < pooled_height; ++ph) + { + for (int pw = 0; pw < pooled_width; ++pw) + { + int hstart = static_cast(floor(static_cast(ph) * bin_size_h)); + int wstart = static_cast(floor(static_cast(pw) * bin_size_w)); + int hend = static_cast(ceil(static_cast(ph + 1) * bin_size_h)); + int wend = static_cast(ceil(static_cast(pw + 1) * bin_size_w)); + + // Add roi offsets and clip to input boundaries + hstart = std::min(std::max(hstart + roi_start_h, 0), input_height); + hend = std::min(std::max(hend + roi_start_h, 0), input_height); + wstart = std::min(std::max(wstart + roi_start_w, 0), input_width); + wend = std::min(std::max(wend + roi_start_w, 0), input_width); + bool is_empty = (hend <= hstart) || (wend <= wstart); + + // Define an empty pooling region to be zero + float maxval = is_empty ? 0 : -FLT_MAX; + // If nothing is pooled, argmax = -1 causes nothing to be backprop'd + int maxidx = -1; + + for (int c = 0; c < channels; ++c) + { + for (int h = hstart; h < hend; ++h) + { + for (int w = wstart; w < wend; ++w) + { + int index = h * input_width + w; + if (input_a[roi_batch_ind][c][h][w] > maxval) + { + maxval = input_a[roi_batch_ind][c][h][w]; + maxidx = index; + } + } + } + output_a[n][c][ph][pw] = maxval; + argmax_a[n][c][ph][pw] = maxidx; + } + } + } + } + + return std::make_tuple(output, argmax); +} + +at::Tensor ROIPool_backward_cpu(const at::Tensor &grad, + const at::Tensor &rois, + const at::Tensor &argmax, + const float spatial_scale, + const int pooled_height, + const int pooled_width, + const int batch_size, + const int channels, + const int height, + const int width) +{ + // Check if input tensors are CPU tensors + AT_ASSERTM(grad.device().is_cpu(), "grad must be a CPU tensor"); + AT_ASSERTM(rois.device().is_cpu(), "rois must be a CPU tensor"); + AT_ASSERTM(argmax.device().is_cpu(), "argmax must be a CPU tensor"); + + auto num_rois = rois.size(0); + + at::Tensor grad_input = at::zeros({batch_size, channels, height, width}, grad.type()); + + // handle possibly empty gradients + if (grad.numel() == 0) + { + return grad_input; + } + + // get stride values to ensure indexing into gradients is correct. + int n_stride = grad.stride(0); + int c_stride = grad.stride(1); + int h_stride = grad.stride(2); + int w_stride = grad.stride(3); + + // define accessors for tensors + auto grad_input_a = grad_input.accessor(); + auto grad_a = grad.accessor(); + auto argmax_a = argmax.accessor(); + auto rois_a = rois.accessor(); + + for (int n = 0; n < num_rois; ++n) + { + int roi_batch_ind = rois_a[n][0]; + + for (int c = 0; c < channels; ++c) + { + for (int ph = 0; ph < pooled_height; ++ph) + { + for (int pw = 0; pw < pooled_width; ++pw) + { + int argmax_idx = argmax_a[n][c][ph][pw]; + // get height and width index from argmax index + int h = argmax_idx / height; + int w = argmax_idx % width; + + grad_input_a[roi_batch_ind][c][h][w] += grad_a[n * n_stride][c * c_stride][ph * h_stride][pw * w_stride]; + } + } + } + } + + return grad_input; +} \ No newline at end of file diff --git a/torchvision/csrc/cpu/vision.h b/torchvision/csrc/cpu/vision.h new file mode 100644 index 00000000000..64e65e66864 --- /dev/null +++ b/torchvision/csrc/cpu/vision.h @@ -0,0 +1,19 @@ +#pragma once +#include + +std::tuple ROIPool_forward_cpu(const at::Tensor &input, + const at::Tensor &rois, + const float spatial_scale, + const int pooled_height, + const int pooled_width); + +at::Tensor ROIPool_backward_cpu(const at::Tensor &grad, + const at::Tensor &rois, + const at::Tensor &argmax, + const float spatial_scale, + const int pooled_height, + const int pooled_width, + const int batch_size, + const int channels, + const int height, + const int width); diff --git a/torchvision/csrc/cuda/ROIPool_cuda.cu b/torchvision/csrc/cuda/ROIPool_cuda.cu new file mode 100644 index 00000000000..5f95de1da43 --- /dev/null +++ b/torchvision/csrc/cuda/ROIPool_cuda.cu @@ -0,0 +1,208 @@ +#include +#include + +#include +#include +#include + +#include "cuda_helpers.h" +#include + + +template +__global__ void RoIPoolForward(const int nthreads, const T* bottom_data, + const T spatial_scale, const int channels, const int height, + const int width, const int pooled_height, const int pooled_width, + const T* bottom_rois, T* top_data, int* argmax_data) { + CUDA_1D_KERNEL_LOOP(index, nthreads) { + // (n, c, ph, pw) is an element in the pooled output + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + + const T* offset_bottom_rois = bottom_rois + n * 5; + int roi_batch_ind = offset_bottom_rois[0]; + int roi_start_w = round(offset_bottom_rois[1] * spatial_scale); + int roi_start_h = round(offset_bottom_rois[2] * spatial_scale); + int roi_end_w = round(offset_bottom_rois[3] * spatial_scale); + int roi_end_h = round(offset_bottom_rois[4] * spatial_scale); + + // Force malformed ROIs to be 1x1 or HxW + int roi_width = max(roi_end_w - roi_start_w + 1, 1); + int roi_height = max(roi_end_h - roi_start_h + 1, 1); + T bin_size_h = static_cast(roi_height) + / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) + / static_cast(pooled_width); + + int hstart = static_cast(floor(static_cast(ph) + * bin_size_h)); + int wstart = static_cast(floor(static_cast(pw) + * bin_size_w)); + int hend = static_cast(ceil(static_cast(ph + 1) + * bin_size_h)); + int wend = static_cast(ceil(static_cast(pw + 1) + * bin_size_w)); + + // Add roi offsets and clip to input boundaries + hstart = min(max(hstart + roi_start_h, 0), height); + hend = min(max(hend + roi_start_h, 0), height); + wstart = min(max(wstart + roi_start_w, 0), width); + wend = min(max(wend + roi_start_w, 0), width); + bool is_empty = (hend <= hstart) || (wend <= wstart); + + // Define an empty pooling region to be zero + T maxval = is_empty ? 0 : -FLT_MAX; + // If nothing is pooled, argmax = -1 causes nothing to be backprop'd + int maxidx = -1; + const T* offset_bottom_data = + bottom_data + (roi_batch_ind * channels + c) * height * width; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + int bottom_index = h * width + w; + if (offset_bottom_data[bottom_index] > maxval) { + maxval = offset_bottom_data[bottom_index]; + maxidx = bottom_index; + } + } + } + top_data[index] = maxval; + argmax_data[index] = maxidx; + } +} + +template +__global__ void RoIPoolBackward(const int nthreads, const T* top_grad, + const int* argmax_data, const int num_rois, const T spatial_scale, + const int channels, const int height, const int width, + const int pooled_height, const int pooled_width, T* bottom_data, + const T* bottom_rois, + const int n_stride, const int c_stride, + const int h_stride, const int w_stride) { + + CUDA_1D_KERNEL_LOOP(index, nthreads) { + // (n, c, ph, pw) is an element in the pooled output + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + + const T* offset_bottom_rois = bottom_rois + n * 5; + int roi_batch_ind = offset_bottom_rois[0]; + int bottom_offset = (roi_batch_ind * channels + c) * height * width; + T* bottom_data_offset = bottom_data + bottom_offset; + + int top_offset = n*n_stride + c*c_stride; + const int* argmax_data_offset = argmax_data + n*channels*pooled_height*pooled_width; + int argmax = argmax_data_offset[c*pooled_height*pooled_width + ph*pooled_width + pw]; + + if (argmax != -1) { + atomicAdd(bottom_data_offset + argmax, + static_cast(top_grad[top_offset + ph*h_stride + pw*w_stride])); + } + } +} + +std::tuple ROIPool_forward_cuda(const at::Tensor& input, + const at::Tensor& rois, + const float spatial_scale, + const int pooled_height, + const int pooled_width) { + AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor"); + AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor"); + + auto num_rois = rois.size(0); + auto channels = input.size(1); + auto height = input.size(2); + auto width = input.size(3); + + at::Tensor output = input.type().tensor({num_rois, channels, pooled_height, pooled_width}); + at::Tensor argmax = input.type().toScalarType(at::kInt).tensor({num_rois, channels, pooled_height, pooled_width}).zero_(); + + auto output_size = num_rois * pooled_height * pooled_width * channels; + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + dim3 grid(std::min(THCCeilDiv(output_size, 512L), 4096L)); + dim3 block(512); + + if (output.numel() == 0) { + THCudaCheck(cudaGetLastError()); + return std::make_tuple(output, argmax); + } + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "ROIPool_forward", [&] { + RoIPoolForward<<>>( + output_size, + input.data(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + rois.data(), + output.data(), + argmax.data()); + }); + THCudaCheck(cudaGetLastError()); + return std::make_tuple(output, argmax); +} + +at::Tensor ROIPool_backward_cuda(const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& argmax, + const float spatial_scale, + const int pooled_height, + const int pooled_width, + const int batch_size, + const int channels, + const int height, + const int width) { + // Check if input tensors are CUDA tensors + AT_ASSERTM(grad.type().is_cuda(), "grad must be a CUDA tensor"); + AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor"); + AT_ASSERTM(argmax.type().is_cuda(), "argmax must be a CUDA tensor"); + + auto num_rois = rois.size(0); + + at::Tensor grad_input = at::zeros({batch_size, channels, height, width}, grad.type()); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + dim3 grid(std::min(THCCeilDiv(grad.numel(), 512L), 4096L)); + dim3 block(512); + + // handle possibly empty gradients + if (grad.numel() == 0) { + THCudaCheck(cudaGetLastError()); + return grad_input; + } + + int n_stride = grad.stride(0); + int c_stride = grad.stride(1); + int h_stride = grad.stride(2); + int w_stride = grad.stride(3); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad.type(), "ROIPool_backward", [&] { + RoIPoolBackward<<>>( + grad.numel(), + grad.data(), + argmax.data(), + num_rois, + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + grad_input.data(), + rois.data(), + n_stride, + c_stride, + h_stride, + w_stride); + }); + THCudaCheck(cudaGetLastError()); + return grad_input; +} \ No newline at end of file diff --git a/torchvision/csrc/cuda/cuda_helpers.h b/torchvision/csrc/cuda/cuda_helpers.h new file mode 100644 index 00000000000..15fb7f6031a --- /dev/null +++ b/torchvision/csrc/cuda/cuda_helpers.h @@ -0,0 +1,8 @@ +#ifndef CUDA_HELPERS_H +#define CUDA_HELPERS_H + +#define CUDA_1D_KERNEL_LOOP(i, n) \ + for (int i = (blockIdx.x * blockDim.x) + threadIdx.x; i < (n); \ + i += (blockDim.x * gridDim.x)) + +#endif // CUDA_HELPERS_H \ No newline at end of file diff --git a/torchvision/csrc/cuda/vision.h b/torchvision/csrc/cuda/vision.h new file mode 100644 index 00000000000..4f83f83b4f4 --- /dev/null +++ b/torchvision/csrc/cuda/vision.h @@ -0,0 +1,19 @@ +#pragma once +#include + +std::tuple ROIPool_forward_cuda(const at::Tensor &input, + const at::Tensor &rois, + const float spatial_scale, + const int pooled_height, + const int pooled_width); + +at::Tensor ROIPool_backward_cuda(const at::Tensor &grad, + const at::Tensor &rois, + const at::Tensor &argmax, + const float spatial_scale, + const int pooled_height, + const int pooled_width, + const int batch_size, + const int channels, + const int height, + const int width); \ No newline at end of file diff --git a/torchvision/csrc/vision.cpp b/torchvision/csrc/vision.cpp new file mode 100644 index 00000000000..88caec61d43 --- /dev/null +++ b/torchvision/csrc/vision.cpp @@ -0,0 +1,7 @@ +#include "ROIPool.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("roi_pool_forward", &ROIPool_forward, "ROIPool_forward"); + m.def("roi_pool_backward", &ROIPool_backward, "ROIPool_backward"); +} \ No newline at end of file diff --git a/torchvision/datasets/cifar.py b/torchvision/datasets/cifar.py index ac7a7269af3..d7fdfbc18b5 100644 --- a/torchvision/datasets/cifar.py +++ b/torchvision/datasets/cifar.py @@ -45,6 +45,18 @@ class CIFAR10(data.Dataset): test_list = [ ['test_batch', '40351d587109b95175f43aff81a1287e'], ] + meta = { + 'filename': 'batches.meta', + 'key': 'label_names', + 'md5': '5ff9c542aee3614f3951f8cda6e48888', + } + + @property + def targets(self): + if self.train: + return self.train_labels + else: + return self.test_labels def __init__(self, root, train=True, transform=None, target_transform=None, @@ -100,6 +112,21 @@ def __init__(self, root, train=True, self.test_data = self.test_data.reshape((10000, 3, 32, 32)) self.test_data = self.test_data.transpose((0, 2, 3, 1)) # convert to HWC + self._load_meta() + + def _load_meta(self): + path = os.path.join(self.root, self.base_folder, self.meta['filename']) + if not check_integrity(path, self.meta['md5']): + raise RuntimeError('Dataset metadata file not found or corrupted.' + + ' You can use download=True to download it') + with open(path, 'rb') as infile: + if sys.version_info[0] == 2: + data = pickle.load(infile) + else: + data = pickle.load(infile, encoding='latin1') + self.classes = data[self.meta['key']] + self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)} + def __getitem__(self, index): """ Args: @@ -187,3 +214,8 @@ class CIFAR100(CIFAR10): test_list = [ ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'], ] + meta = { + 'filename': 'meta', + 'key': 'fine_label_names', + 'md5': '7973b15100ade9c7d40fb424638fde48', + } diff --git a/torchvision/datasets/folder.py b/torchvision/datasets/folder.py index 1df4bcbf44d..3bd4c485b65 100644 --- a/torchvision/datasets/folder.py +++ b/torchvision/datasets/folder.py @@ -4,6 +4,7 @@ import os import os.path +import sys def has_file_allowed_extension(filename, extensions): @@ -11,25 +12,31 @@ def has_file_allowed_extension(filename, extensions): Args: filename (string): path to a file + extensions (iterable of strings): extensions to consider (lowercase) Returns: - bool: True if the filename ends with a known image extension + bool: True if the filename ends with one of given extensions """ filename_lower = filename.lower() return any(filename_lower.endswith(ext) for ext in extensions) -def find_classes(dir): - classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] - classes.sort() - class_to_idx = {classes[i]: i for i in range(len(classes))} - return classes, class_to_idx +def is_image_file(filename): + """Checks if a file is an allowed image extension. + + Args: + filename (string): path to a file + + Returns: + bool: True if the filename ends with a known image extension + """ + return has_file_allowed_extension(filename, IMG_EXTENSIONS) def make_dataset(dir, class_to_idx, extensions): images = [] dir = os.path.expanduser(dir) - for target in sorted(os.listdir(dir)): + for target in sorted(class_to_idx.keys()): d = os.path.join(dir, target) if not os.path.isdir(d): continue @@ -69,10 +76,11 @@ class DatasetFolder(data.Dataset): classes (list): List of the class names. class_to_idx (dict): Dict with items (class_name, class_index). samples (list): List of (sample path, class_index) tuples + targets (list): The class_index value for each image in the dataset """ def __init__(self, root, loader, extensions, transform=None, target_transform=None): - classes, class_to_idx = find_classes(root) + classes, class_to_idx = self._find_classes(root) samples = make_dataset(root, class_to_idx, extensions) if len(samples) == 0: raise(RuntimeError("Found 0 files in subfolders of: " + root + "\n" @@ -85,10 +93,33 @@ def __init__(self, root, loader, extensions, transform=None, target_transform=No self.classes = classes self.class_to_idx = class_to_idx self.samples = samples + self.targets = [s[1] for s in samples] self.transform = transform self.target_transform = target_transform + def _find_classes(self, dir): + """ + Finds the class folders in a dataset. + + Args: + dir (string): Root directory path. + + Returns: + tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary. + + Ensures: + No class is a subdirectory of another. + """ + if sys.version_info >= (3, 5): + # Faster and available in Python 3.5 and above + classes = [d.name for d in os.scandir(dir) if d.is_dir()] + else: + classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] + classes.sort() + class_to_idx = {classes[i]: i for i in range(len(classes))} + return classes, class_to_idx + def __getitem__(self, index): """ Args: diff --git a/torchvision/datasets/lsun.py b/torchvision/datasets/lsun.py index c7fe27e040a..0d63a061a1f 100644 --- a/torchvision/datasets/lsun.py +++ b/torchvision/datasets/lsun.py @@ -22,7 +22,7 @@ def __init__(self, root, transform=None, target_transform=None): readahead=False, meminit=False) with self.env.begin(write=False) as txn: self.length = txn.stat()['entries'] - cache_file = '_cache_' + root.replace('/', '_') + cache_file = '_cache_' + ''.join(c for c in root if c in string.ascii_letters) if os.path.isfile(cache_file): self.keys = pickle.load(open(cache_file, "rb")) else: diff --git a/torchvision/datasets/mnist.py b/torchvision/datasets/mnist.py index 7f4463eff64..ee4f30d4f18 100644 --- a/torchvision/datasets/mnist.py +++ b/torchvision/datasets/mnist.py @@ -7,6 +7,7 @@ import numpy as np import torch import codecs +from .utils import download_url class MNIST(data.Dataset): @@ -35,6 +36,9 @@ class MNIST(data.Dataset): processed_folder = 'processed' training_file = 'training.pt' test_file = 'test.pt' + classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four', + '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine'] + class_to_idx = {_class: i for i, _class in enumerate(classes)} def __init__(self, root, train=True, transform=None, target_transform=None, download=False): self.root = os.path.expanduser(root) @@ -50,11 +54,10 @@ def __init__(self, root, train=True, transform=None, target_transform=None, down ' You can use download=True to download it') if self.train: - self.train_data, self.train_labels = torch.load( - os.path.join(self.root, self.processed_folder, self.training_file)) + data_file = self.training_file else: - self.test_data, self.test_labels = torch.load( - os.path.join(self.root, self.processed_folder, self.test_file)) + data_file = self.test_file + self.data, self.targets = torch.load(os.path.join(self.root, self.processed_folder, data_file)) def __getitem__(self, index): """ @@ -64,10 +67,7 @@ def __getitem__(self, index): Returns: tuple: (image, target) where target is index of the target class. """ - if self.train: - img, target = self.train_data[index], self.train_labels[index] - else: - img, target = self.test_data[index], self.test_labels[index] + img, target = self.data[index], self.targets[index] # doing this so that it is consistent with all other datasets # to return a PIL Image @@ -82,10 +82,7 @@ def __getitem__(self, index): return img, target def __len__(self): - if self.train: - return len(self.train_data) - else: - return len(self.test_data) + return len(self.data) def _check_exists(self): return os.path.exists(os.path.join(self.root, self.processed_folder, self.training_file)) and \ @@ -93,7 +90,6 @@ def _check_exists(self): def download(self): """Download the MNIST data if it doesn't exist in processed_folder already.""" - from six.moves import urllib import gzip if self._check_exists(): @@ -110,12 +106,10 @@ def download(self): raise for url in self.urls: - print('Downloading ' + url) - data = urllib.request.urlopen(url) filename = url.rpartition('/')[2] file_path = os.path.join(self.root, self.raw_folder, filename) - with open(file_path, 'wb') as f: - f.write(data.read()) + download_url(url, root=os.path.join(self.root, self.raw_folder), + filename=filename, md5=None) with open(file_path.replace('.gz', ''), 'wb') as out_f, \ gzip.GzipFile(file_path) as zip_f: out_f.write(zip_f.read()) @@ -174,6 +168,9 @@ class FashionMNIST(MNIST): 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz', 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz', ] + classes = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', + 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'] + class_to_idx = {_class: i for i, _class in enumerate(classes)} class EMNIST(MNIST): @@ -216,7 +213,6 @@ def _test_file(self, split): def download(self): """Download the EMNIST data if it doesn't exist in processed_folder already.""" - from six.moves import urllib import gzip import shutil import zipfile @@ -234,13 +230,10 @@ def download(self): else: raise - print('Downloading ' + self.url) - data = urllib.request.urlopen(self.url) filename = self.url.rpartition('/')[2] raw_folder = os.path.join(self.root, self.raw_folder) file_path = os.path.join(raw_folder, filename) - with open(file_path, 'wb') as f: - f.write(data.read()) + download_url(self.url, root=file_path, filename=filename, md5=None) print('Extracting zip archive') with zipfile.ZipFile(file_path) as zip_f: diff --git a/torchvision/datasets/utils.py b/torchvision/datasets/utils.py index 9fa3b0b8c9b..43e5896801a 100644 --- a/torchvision/datasets/utils.py +++ b/torchvision/datasets/utils.py @@ -2,9 +2,22 @@ import os.path import hashlib import errno +from tqdm import tqdm -def check_integrity(fpath, md5): +def gen_bar_updater(pbar): + def bar_update(count, block_size, total_size): + if pbar.total is None and total_size: + pbar.total = total_size + progress_bytes = count * block_size + pbar.update(progress_bytes - pbar.n) + + return bar_update + + +def check_integrity(fpath, md5=None): + if md5 is None: + return True if not os.path.isfile(fpath): return False md5o = hashlib.md5() @@ -38,13 +51,19 @@ def download_url(url, root, filename, md5): else: try: print('Downloading ' + url + ' to ' + fpath) - urllib.request.urlretrieve(url, fpath) + urllib.request.urlretrieve( + url, fpath, + reporthook=gen_bar_updater(tqdm(unit='B', unit_scale=True)) + ) except: if url[:5] == 'https': url = url.replace('https:', 'http:') print('Failed download. Trying https -> http instead.' ' Downloading ' + url + ' to ' + fpath) - urllib.request.urlretrieve(url, fpath) + urllib.request.urlretrieve( + url, fpath, + reporthook=gen_bar_updater(tqdm(unit='B', unit_scale=True)) + ) def list_dir(root, prefix=False): diff --git a/torchvision/layers/__init__.py b/torchvision/layers/__init__.py new file mode 100644 index 00000000000..1e54c6311c3 --- /dev/null +++ b/torchvision/layers/__init__.py @@ -0,0 +1,6 @@ +from .roi_pool import roi_pool, ROIPool + + +__all__ = [ + 'roi_pool', 'ROIPool' +] diff --git a/torchvision/layers/roi_pool.py b/torchvision/layers/roi_pool.py new file mode 100644 index 00000000000..f232d0cc160 --- /dev/null +++ b/torchvision/layers/roi_pool.py @@ -0,0 +1,53 @@ +import torch +from torch import nn + +from torch.autograd import Function +from torch.autograd.function import once_differentiable + +from torch.nn.modules.utils import _pair + +from torchvision import _C + + +class _ROIPool(Function): + @staticmethod + def forward(ctx, input, roi, output_size, spatial_scale): + ctx.output_size = _pair(output_size) + ctx.spatial_scale = spatial_scale + ctx.input_shape = input.size() + output, argmax = _C.roi_pool_forward( + input, roi, spatial_scale, + output_size[0], output_size[1]) + ctx.save_for_backward(roi, argmax) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + rois, argmax = ctx.saved_tensors + output_size = ctx.output_size + spatial_scale = ctx.spatial_scale + bs, ch, h, w = ctx.input_shape + grad_input = _C.roi_pool_backward( + grad_output, rois, argmax, spatial_scale, + output_size[0], output_size[1], bs, ch, h, w) + return grad_input, None, None, None + +roi_pool = _ROIPool.apply + + +class ROIPool(nn.Module): + def __init__(self, output_size, spatial_scale): + super(ROIPool, self).__init__() + self.output_size = output_size + self.spatial_scale = spatial_scale + + def forward(self, input, rois): + return roi_pool(input, rois, self.output_size, self.spatial_scale) + + def __repr__(self): + tmpstr = self.__class__.__name__ + '(' + tmpstr += 'output_size=' + str(self.output_size) + tmpstr += ', spatial_scale=' + str(self.spatial_scale) + tmpstr += ')' + return tmpstr diff --git a/torchvision/models/densenet.py b/torchvision/models/densenet.py index 1bcdb0b429d..79064849dfd 100644 --- a/torchvision/models/densenet.py +++ b/torchvision/models/densenet.py @@ -1,3 +1,4 @@ +import re import torch import torch.nn as nn import torch.nn.functional as F @@ -25,7 +26,20 @@ def densenet121(pretrained=False, **kwargs): model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), **kwargs) if pretrained: - model.load_state_dict(model_zoo.load_url(model_urls['densenet121'])) + # '.'s are no longer allowed in module names, but pervious _DenseLayer + # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. + # They are also in the checkpoints in model_urls. This pattern is used + # to find such keys. + pattern = re.compile( + r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') + state_dict = model_zoo.load_url(model_urls['densenet121']) + for key in list(state_dict.keys()): + res = pattern.match(key) + if res: + new_key = res.group(1) + res.group(2) + state_dict[new_key] = state_dict[key] + del state_dict[key] + model.load_state_dict(state_dict) return model @@ -39,7 +53,20 @@ def densenet169(pretrained=False, **kwargs): model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32), **kwargs) if pretrained: - model.load_state_dict(model_zoo.load_url(model_urls['densenet169'])) + # '.'s are no longer allowed in module names, but pervious _DenseLayer + # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. + # They are also in the checkpoints in model_urls. This pattern is used + # to find such keys. + pattern = re.compile( + r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') + state_dict = model_zoo.load_url(model_urls['densenet169']) + for key in list(state_dict.keys()): + res = pattern.match(key) + if res: + new_key = res.group(1) + res.group(2) + state_dict[new_key] = state_dict[key] + del state_dict[key] + model.load_state_dict(state_dict) return model @@ -53,7 +80,20 @@ def densenet201(pretrained=False, **kwargs): model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32), **kwargs) if pretrained: - model.load_state_dict(model_zoo.load_url(model_urls['densenet201'])) + # '.'s are no longer allowed in module names, but pervious _DenseLayer + # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. + # They are also in the checkpoints in model_urls. This pattern is used + # to find such keys. + pattern = re.compile( + r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') + state_dict = model_zoo.load_url(model_urls['densenet201']) + for key in list(state_dict.keys()): + res = pattern.match(key) + if res: + new_key = res.group(1) + res.group(2) + state_dict[new_key] = state_dict[key] + del state_dict[key] + model.load_state_dict(state_dict) return model @@ -67,20 +107,33 @@ def densenet161(pretrained=False, **kwargs): model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24), **kwargs) if pretrained: - model.load_state_dict(model_zoo.load_url(model_urls['densenet161'])) + # '.'s are no longer allowed in module names, but pervious _DenseLayer + # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. + # They are also in the checkpoints in model_urls. This pattern is used + # to find such keys. + pattern = re.compile( + r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') + state_dict = model_zoo.load_url(model_urls['densenet161']) + for key in list(state_dict.keys()): + res = pattern.match(key) + if res: + new_key = res.group(1) + res.group(2) + state_dict[new_key] = state_dict[key] + del state_dict[key] + model.load_state_dict(state_dict) return model class _DenseLayer(nn.Sequential): def __init__(self, num_input_features, growth_rate, bn_size, drop_rate): super(_DenseLayer, self).__init__() - self.add_module('norm.1', nn.BatchNorm2d(num_input_features)), - self.add_module('relu.1', nn.ReLU(inplace=True)), - self.add_module('conv.1', nn.Conv2d(num_input_features, bn_size * + self.add_module('norm1', nn.BatchNorm2d(num_input_features)), + self.add_module('relu1', nn.ReLU(inplace=True)), + self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False)), - self.add_module('norm.2', nn.BatchNorm2d(bn_size * growth_rate)), - self.add_module('relu.2', nn.ReLU(inplace=True)), - self.add_module('conv.2', nn.Conv2d(bn_size * growth_rate, growth_rate, + self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)), + self.add_module('relu2', nn.ReLU(inplace=True)), + self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False)), self.drop_rate = drop_rate @@ -122,6 +175,7 @@ class DenseNet(nn.Module): drop_rate (float) - dropout rate after each dense layer num_classes (int) - number of classification classes """ + def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000): @@ -156,12 +210,12 @@ def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), # Official init from torch repo. for m in self.modules(): if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal(m.weight.data) + nn.init.kaiming_normal_(m.weight) elif isinstance(m, nn.BatchNorm2d): - m.weight.data.fill_(1) - m.bias.data.zero_() + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) elif isinstance(m, nn.Linear): - m.bias.data.zero_() + nn.init.constant_(m.bias, 0) def forward(self, x): features = self.features(x) diff --git a/torchvision/models/inception.py b/torchvision/models/inception.py index 3f1283fef9d..425c5b88bcb 100644 --- a/torchvision/models/inception.py +++ b/torchvision/models/inception.py @@ -61,12 +61,12 @@ def __init__(self, num_classes=1000, aux_logits=True, transform_input=False): import scipy.stats as stats stddev = m.stddev if hasattr(m, 'stddev') else 0.1 X = stats.truncnorm(-2, 2, scale=stddev) - values = torch.Tensor(X.rvs(m.weight.data.numel())) - values = values.view(m.weight.data.size()) + values = torch.Tensor(X.rvs(m.weight.numel())) + values = values.view(m.weight.size()) m.weight.data.copy_(values) elif isinstance(m, nn.BatchNorm2d): - m.weight.data.fill_(1) - m.bias.data.zero_() + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) def forward(self, x): if self.transform_input: diff --git a/torchvision/models/resnet.py b/torchvision/models/resnet.py index 033f3415034..14135b92de5 100644 --- a/torchvision/models/resnet.py +++ b/torchvision/models/resnet.py @@ -64,8 +64,8 @@ def __init__(self, inplanes, planes, stride=1, downsample=None): self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) self.bn2 = nn.BatchNorm2d(planes) - self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) - self.bn3 = nn.BatchNorm2d(planes * 4) + self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) self.relu = nn.ReLU(inplace=True) self.downsample = downsample self.stride = stride @@ -112,11 +112,10 @@ def __init__(self, block, layers, num_classes=1000): for m in self.modules(): if isinstance(m, nn.Conv2d): - n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - m.weight.data.normal_(0, math.sqrt(2. / n)) + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') elif isinstance(m, nn.BatchNorm2d): - m.weight.data.fill_(1) - m.bias.data.zero_() + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) def _make_layer(self, block, planes, blocks, stride=1): downsample = None diff --git a/torchvision/models/squeezenet.py b/torchvision/models/squeezenet.py index 9965e9efc3c..428e8d4a4df 100644 --- a/torchvision/models/squeezenet.py +++ b/torchvision/models/squeezenet.py @@ -89,11 +89,11 @@ def __init__(self, version=1.0, num_classes=1000): for m in self.modules(): if isinstance(m, nn.Conv2d): if m is final_conv: - init.normal(m.weight.data, mean=0.0, std=0.01) + init.normal_(m.weight, mean=0.0, std=0.01) else: - init.kaiming_uniform(m.weight.data) + init.kaiming_uniform_(m.weight) if m.bias is not None: - m.bias.data.zero_() + init.constant_(m.bias, 0) def forward(self, x): x = self.features(x) diff --git a/torchvision/models/vgg.py b/torchvision/models/vgg.py index 4f112d96772..9da6db5686f 100644 --- a/torchvision/models/vgg.py +++ b/torchvision/models/vgg.py @@ -47,16 +47,15 @@ def forward(self, x): def _initialize_weights(self): for m in self.modules(): if isinstance(m, nn.Conv2d): - n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - m.weight.data.normal_(0, math.sqrt(2. / n)) + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') if m.bias is not None: - m.bias.data.zero_() + nn.init.constant_(m.bias, 0) elif isinstance(m, nn.BatchNorm2d): - m.weight.data.fill_(1) - m.bias.data.zero_() + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) elif isinstance(m, nn.Linear): - m.weight.data.normal_(0, 0.01) - m.bias.data.zero_() + nn.init.normal_(m.weight, 0, 0.01) + nn.init.constant_(m.bias, 0) def make_layers(cfg, batch_norm=False): diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 5d5325078be..95030caede0 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -64,9 +64,11 @@ def to_tensor(pic): img = torch.from_numpy(np.array(pic, np.int16, copy=False)) elif pic.mode == 'F': img = torch.from_numpy(np.array(pic, np.float32, copy=False)) + elif pic.mode == '1': + img = 255 * torch.from_numpy(np.array(pic, np.uint8, copy=False)) else: img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) - # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK + # PIL image mode: L, P, I, F, RGB, YCbCr, RGBA, CMYK if pic.mode == 'YCbCr': nchannel = 3 elif pic.mode == 'I;16': @@ -86,13 +88,13 @@ def to_tensor(pic): def to_pil_image(pic, mode=None): """Convert a tensor or an ndarray to PIL Image. - See :class:`~torchvision.transforms.ToPIlImage` for more details. + See :class:`~torchvision.transforms.ToPILImage` for more details. Args: pic (Tensor or numpy.ndarray): Image to be converted to PIL Image. mode (`PIL.Image mode`_): color space and pixel depth of input data (optional). - .. _PIL.Image mode: http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#modes + .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes Returns: PIL Image: Image converted to PIL Image. @@ -149,7 +151,10 @@ def to_pil_image(pic, mode=None): def normalize(tensor, mean, std): """Normalize a tensor image with mean and standard deviation. - See ``Normalize`` for more details. + .. note:: + This transform acts in-place, i.e., it mutates the input tensor. + + See :class:`~torchvision.transforms.Normalize` for more details. Args: tensor (Tensor): Tensor image of size (C, H, W) to be normalized. @@ -161,14 +166,15 @@ def normalize(tensor, mean, std): """ if not _is_tensor_image(tensor): raise TypeError('tensor is not a torch image.') - # TODO: make efficient + + # This is faster than using broadcasting, don't change without benchmarking for t, m, s in zip(tensor, mean, std): t.sub_(m).div_(s) return tensor def resize(img, size, interpolation=Image.BILINEAR): - """Resize the input PIL Image to the given size. + r"""Resize the input PIL Image to the given size. Args: img (PIL Image): Image to be resized. @@ -176,7 +182,7 @@ def resize(img, size, interpolation=Image.BILINEAR): (h, w), the output size will be matched to this. If size is an int, the smaller edge of the image will be matched to this number maintaing the aspect ratio. i.e, if height > width, then image will be rescaled to - (size * height / width, size) + :math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)` interpolation (int, optional): Desired interpolation. Default is ``PIL.Image.BILINEAR`` @@ -211,7 +217,7 @@ def scale(*args, **kwargs): def pad(img, padding, fill=0, padding_mode='constant'): - """Pad the given PIL Image on all sides with speficified padding mode and fill value. + r"""Pad the given PIL Image on all sides with specified padding mode and fill value. Args: img (PIL Image): Image to be padded. @@ -224,14 +230,20 @@ def pad(img, padding, fill=0, padding_mode='constant'): length 3, it is used to fill R, G, B channels respectively. This value is only used when the padding_mode is constant padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant. - constant: pads with a constant value, this value is specified with fill - edge: pads with the last value on the edge of the image - reflect: pads with reflection of image (without repeating the last value on the edge) - padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode - will result in [3, 2, 1, 2, 3, 4, 3, 2] - symmetric: pads with reflection of image (repeating the last value on the edge) - padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode - will result in [2, 1, 1, 2, 3, 4, 4, 3] + + - constant: pads with a constant value, this value is specified with fill + + - edge: pads with the last value on the edge of the image + + - reflect: pads with reflection of image (without repeating the last value on the edge) + + padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode + will result in [3, 2, 1, 2, 3, 4, 3, 2] + + - symmetric: pads with reflection of image (repeating the last value on the edge) + + padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode + will result in [2, 1, 1, 2, 3, 4, 4, 3] Returns: PIL Image: Padded image. @@ -310,7 +322,7 @@ def center_crop(img, output_size): def resized_crop(img, i, j, h, w, size, interpolation=Image.BILINEAR): """Crop the given PIL Image and resize it to desired size. - Notably used in RandomResizedCrop. + Notably used in :class:`~torchvision.transforms.RandomResizedCrop`. Args: img (PIL Image): Image to be cropped. @@ -371,9 +383,10 @@ def five_crop(img, size): size (sequence or int): Desired output size of the crop. If size is an int instead of sequence like (h, w), a square crop (size, size) is made. + Returns: - tuple: tuple (tl, tr, bl, br, center) corresponding top left, - top right, bottom left, bottom right and center crop. + tuple: tuple (tl, tr, bl, br, center) + Corresponding top left, top right, bottom left, bottom right and center crop. """ if isinstance(size, numbers.Number): size = (int(size), int(size)) @@ -394,24 +407,23 @@ def five_crop(img, size): def ten_crop(img, size, vertical_flip=False): - """Crop the given PIL Image into four corners and the central crop plus the - flipped version of these (horizontal flipping is used by default). + r"""Crop the given PIL Image into four corners and the central crop plus the + flipped version of these (horizontal flipping is used by default). .. Note:: This transform returns a tuple of images and there may be a mismatch in the number of inputs and targets your ``Dataset`` returns. - Args: - size (sequence or int): Desired output size of the crop. If size is an - int instead of sequence like (h, w), a square crop (size, size) is - made. - vertical_flip (bool): Use vertical flipping instead of horizontal - - Returns: - tuple: tuple (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, - br_flip, center_flip) corresponding top left, top right, - bottom left, bottom right and center crop and same for the - flipped image. + Args: + size (sequence or int): Desired output size of the crop. If size is an + int instead of sequence like (h, w), a square crop (size, size) is + made. + vertical_flip (bool): Use vertical flipping instead of horizontal + + Returns: + tuple: tuple (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip) + Corresponding top left, top right, bottom left, bottom right and center crop + and same for the flipped image. """ if isinstance(size, numbers.Number): size = (int(size), int(size)) @@ -499,7 +511,9 @@ def adjust_hue(img, hue_factor): `hue_factor` is the amount of shift in H channel and must be in the interval `[-0.5, 0.5]`. - See https://en.wikipedia.org/wiki/Hue for more details on Hue. + See `Hue`_ for more details. + + .. _Hue: https://en.wikipedia.org/wiki/Hue Args: img (PIL Image): PIL Image to be adjusted. @@ -535,20 +549,23 @@ def adjust_hue(img, hue_factor): def adjust_gamma(img, gamma, gain=1): - """Perform gamma correction on an image. + r"""Perform gamma correction on an image. Also known as Power Law Transform. Intensities in RGB mode are adjusted based on the following equation: - I_out = 255 * gain * ((I_in / 255) ** gamma) + .. math:: + I_{\text{out}} = 255 \times \text{gain} \times \left(\frac{I_{\text{in}}}{255}\right)^{\gamma} + + See `Gamma Correction`_ for more details. - See https://en.wikipedia.org/wiki/Gamma_correction for more details. + .. _Gamma Correction: https://en.wikipedia.org/wiki/Gamma_correction Args: img (PIL Image): PIL Image to be adjusted. - gamma (float): Non negative real number. gamma larger than 1 make the - shadows darker, while gamma smaller than 1 make dark regions - lighter. + gamma (float): Non negative real number, same as :math:`\gamma` in the equation. + gamma larger than 1 make the shadows darker, + while gamma smaller than 1 make dark regions lighter. gain (float): The constant multiplier. """ if not _is_pil_image(img): @@ -573,11 +590,10 @@ def rotate(img, angle, resample=False, expand=False, center=None): Args: img (PIL Image): PIL Image to be rotated. - angle ({float, int}): In degrees degrees counter clockwise order. - resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional): - An optional resampling filter. - See http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#filters - If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST. + angle (float or int): In degrees degrees counter clockwise order. + resample (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional): + An optional resampling filter. See `filters`_ for more information. + If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``. expand (bool, optional): Optional expansion flag. If true, expands the output image to make it large enough to hold the entire rotated image. If false or omitted, make the output image the same size as the input image. @@ -585,6 +601,9 @@ def rotate(img, angle, resample=False, expand=False, center=None): center (2-tuple, optional): Optional center of rotation. Origin is the upper left corner. Default is the center of the image. + + .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters + """ if not _is_pil_image(img): @@ -633,14 +652,14 @@ def affine(img, angle, translate, scale, shear, resample=0, fillcolor=None): Args: img (PIL Image): PIL Image to be rotated. - angle ({float, int}): rotation angle in degrees between -180 and 180, clockwise direction. + angle (float or int): rotation angle in degrees between -180 and 180, clockwise direction. translate (list or tuple of integers): horizontal and vertical translations (post-rotation translation) scale (float): overall scale shear (float): shear angle value in degrees between -180 to 180, clockwise direction. - resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional): + resample (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional): An optional resampling filter. - See http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#filters - If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST. + See `filters`_ for more information. + If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``. fillcolor (int): Optional fill color for the area outside the transform in the output image. (Pillow>=5.0.0) """ if not _is_pil_image(img): @@ -665,9 +684,10 @@ def to_grayscale(img, num_output_channels=1): img (PIL Image): Image to be converted to grayscale. Returns: - PIL Image: Grayscale version of the image. - if num_output_channels == 1 : returned image is single channel - if num_output_channels == 3 : returned image is 3 channel with r == g == b + PIL Image: Grayscale version of the image. + if num_output_channels = 1 : returned image is single channel + + if num_output_channels = 3 : returned image is 3 channel with r = g = b """ if not _is_pil_image(img): raise TypeError('img should be PIL Image. Got {}'.format(type(img))) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 6475316074f..7c352be2f57 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -93,7 +93,7 @@ class ToPILImage(object): 3. If the input has 1 channel, the ``mode`` is determined by the data type (i,e, ``int``, ``float``, ``short``). - .. _PIL.Image mode: http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#modes + .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes """ def __init__(self, mode=None): self.mode = mode @@ -123,6 +123,9 @@ class Normalize(object): will normalize each channel of the input ``torch.*Tensor`` i.e. ``input[channel] = (input[channel] - mean[channel]) / std[channel]`` + .. note:: + This transform acts in-place, i.e., it mutates the input tensor. + Args: mean (sequence): Sequence of means for each channel. std (sequence): Sequence of standard deviations for each channel. @@ -227,17 +230,24 @@ class Pad(object): on left/right and top/bottom respectively. If a tuple of length 4 is provided this is the padding for the left, top, right and bottom borders respectively. - fill: Pixel fill value for constant fill. Default is 0. If a tuple of + fill (int or tuple): Pixel fill value for constant fill. Default is 0. If a tuple of length 3, it is used to fill R, G, B channels respectively. This value is only used when the padding_mode is constant - padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant. - constant: pads with a constant value, this value is specified with fill - edge: pads with the last value at the edge of the image - reflect: pads with reflection of image (without repeating the last value on the edge) - padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode + padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric. + Default is constant. + + - constant: pads with a constant value, this value is specified with fill + + - edge: pads with the last value at the edge of the image + + - reflect: pads with reflection of image without repeating the last value on the edge + + For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode will result in [3, 2, 1, 2, 3, 4, 3, 2] - symmetric: pads with reflection of image (repeating the last value on the edge) - padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode + + - symmetric: pads with reflection of image repeating the last value on the edge + + For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode will result in [2, 1, 1, 2, 3, 4, 4, 3] """ @@ -365,20 +375,42 @@ class RandomCrop(object): int instead of sequence like (h, w), a square crop (size, size) is made. padding (int or sequence, optional): Optional padding on each border - of the image. Default is 0, i.e no padding. If a sequence of length + of the image. Default is None, i.e no padding. If a sequence of length 4 is provided, it is used to pad left, top, right, bottom borders - respectively. + respectively. If a sequence of length 2 is provided, it is used to + pad left/right, top/bottom borders, respectively. pad_if_needed (boolean): It will pad the image if smaller than the desired size to avoid raising an exception. + fill: Pixel fill value for constant fill. Default is 0. If a tuple of + length 3, it is used to fill R, G, B channels respectively. + This value is only used when the padding_mode is constant + padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant. + + - constant: pads with a constant value, this value is specified with fill + + - edge: pads with the last value on the edge of the image + + - reflect: pads with reflection of image (without repeating the last value on the edge) + + padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode + will result in [3, 2, 1, 2, 3, 4, 3, 2] + + - symmetric: pads with reflection of image (repeating the last value on the edge) + + padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode + will result in [2, 1, 1, 2, 3, 4, 4, 3] + """ - def __init__(self, size, padding=0, pad_if_needed=False): + def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant'): if isinstance(size, numbers.Number): self.size = (int(size), int(size)) else: self.size = size self.padding = padding self.pad_if_needed = pad_if_needed + self.fill = fill + self.padding_mode = padding_mode @staticmethod def get_params(img, output_size): @@ -408,15 +440,15 @@ def __call__(self, img): Returns: PIL Image: Cropped image. """ - if self.padding > 0: - img = F.pad(img, self.padding) + if self.padding is not None: + img = F.pad(img, self.padding, self.fill, self.padding_mode) # pad the width if needed if self.pad_if_needed and img.size[0] < self.size[1]: - img = F.pad(img, (int((1 + self.size[1] - img.size[0]) / 2), 0)) + img = F.pad(img, (self.size[1] - img.size[0], 0), self.fill, self.padding_mode) # pad the height if needed if self.pad_if_needed and img.size[1] < self.size[0]: - img = F.pad(img, (0, int((1 + self.size[0] - img.size[1]) / 2))) + img = F.pad(img, (0, self.size[0] - img.size[1]), self.fill, self.padding_mode) i, j, h, w = self.get_params(img, self.size) @@ -655,7 +687,7 @@ class LinearTransformation(object): original shape. Applications: - - whitening: zero-center the data, compute the data covariance matrix + - whitening: zero-center the data, compute the data covariance matrix [D x D] with np.dot(X.T, X), perform SVD on this matrix and pass it as transformation_matrix. @@ -696,20 +728,44 @@ class ColorJitter(object): """Randomly change the brightness, contrast and saturation of an image. Args: - brightness (float): How much to jitter brightness. brightness_factor - is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]. - contrast (float): How much to jitter contrast. contrast_factor - is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]. - saturation (float): How much to jitter saturation. saturation_factor - is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]. - hue(float): How much to jitter hue. hue_factor is chosen uniformly from - [-hue, hue]. Should be >=0 and <= 0.5. + brightness (float or tuple of float (min, max)): How much to jitter brightness. + brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness] + or the given [min, max]. Should be non negative numbers. + contrast (float or tuple of float (min, max)): How much to jitter contrast. + contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast] + or the given [min, max]. Should be non negative numbers. + saturation (float or tuple of float (min, max)): How much to jitter saturation. + saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation] + or the given [min, max]. Should be non negative numbers. + hue (float or tuple of float (min, max)): How much to jitter hue. + hue_factor is chosen uniformly from [-hue, hue] or the given [min, max]. + Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5. """ def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): - self.brightness = brightness - self.contrast = contrast - self.saturation = saturation - self.hue = hue + self.brightness = self._check_input(brightness, 'brightness') + self.contrast = self._check_input(contrast, 'contrast') + self.saturation = self._check_input(saturation, 'saturation') + self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5), + clip_first_on_zero=False) + + def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True): + if isinstance(value, numbers.Number): + if value < 0: + raise ValueError("If {} is a single number, it must be non negative.".format(name)) + value = [center - value, center + value] + if clip_first_on_zero: + value[0] = max(value[0], 0) + elif isinstance(value, (tuple, list)) and len(value) == 2: + if not bound[0] <= value[0] <= value[1] <= bound[1]: + raise ValueError("{} values should be between {}".format(name, bound)) + else: + raise TypeError("{} should be a single number or a list/tuple with lenght 2.".format(name)) + + # if value is 0 or (1., 1.) for brightness/contrast/saturation + # or (0., 0.) for hue, do nothing + if value[0] == value[1] == center: + value = None + return value @staticmethod def get_params(brightness, contrast, saturation, hue): @@ -722,20 +778,21 @@ def get_params(brightness, contrast, saturation, hue): saturation in a random order. """ transforms = [] - if brightness > 0: - brightness_factor = random.uniform(max(0, 1 - brightness), 1 + brightness) + + if brightness is not None: + brightness_factor = random.uniform(brightness[0], brightness[1]) transforms.append(Lambda(lambda img: F.adjust_brightness(img, brightness_factor))) - if contrast > 0: - contrast_factor = random.uniform(max(0, 1 - contrast), 1 + contrast) + if contrast is not None: + contrast_factor = random.uniform(contrast[0], contrast[1]) transforms.append(Lambda(lambda img: F.adjust_contrast(img, contrast_factor))) - if saturation > 0: - saturation_factor = random.uniform(max(0, 1 - saturation), 1 + saturation) + if saturation is not None: + saturation_factor = random.uniform(saturation[0], saturation[1]) transforms.append(Lambda(lambda img: F.adjust_saturation(img, saturation_factor))) - if hue > 0: - hue_factor = random.uniform(-hue, hue) + if hue is not None: + hue_factor = random.uniform(hue[0], hue[1]) transforms.append(Lambda(lambda img: F.adjust_hue(img, hue_factor))) random.shuffle(transforms) @@ -772,8 +829,7 @@ class RandomRotation(object): If degrees is a number instead of sequence like (min, max), the range of degrees will be (-degrees, +degrees). resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional): - An optional resampling filter. - See http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#filters + An optional resampling filter. See `filters`_ for more information. If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST. expand (bool, optional): Optional expansion flag. If true, expands the output to make it large enough to hold the entire rotated image. @@ -782,6 +838,9 @@ class RandomRotation(object): center (2-tuple, optional): Optional center of rotation. Origin is the upper left corner. Default is the center of the image. + + .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters + """ def __init__(self, degrees, resample=False, expand=False, center=None): @@ -837,7 +896,7 @@ class RandomAffine(object): Args: degrees (sequence or float or int): Range of degrees to select from. If degrees is a number instead of sequence like (min, max), the range of degrees - will be (-degrees, +degrees). Set to 0 to desactivate rotations. + will be (-degrees, +degrees). Set to 0 to deactivate rotations. translate (tuple, optional): tuple of maximum absolute fraction for horizontal and vertical translations. For example translate=(a, b), then horizontal shift is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is @@ -848,10 +907,12 @@ class RandomAffine(object): If degrees is a number instead of sequence like (min, max), the range of degrees will be (-degrees, +degrees). Will not apply shear by default resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional): - An optional resampling filter. - See http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#filters + An optional resampling filter. See `filters`_ for more information. If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST. fillcolor (int): Optional fill color for the area outside the transform in the output image. (Pillow>=5.0.0) + + .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters + """ def __init__(self, degrees, translate=None, scale=None, shear=None, resample=False, fillcolor=0): diff --git a/torchvision/utils.py b/torchvision/utils.py index ab78104d265..e277d0c7253 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -58,7 +58,7 @@ def norm_range(t, range): if range is not None: norm_ip(t, range[0], range[1]) else: - norm_ip(t, t.min(), t.max()) + norm_ip(t, float(t.min()), float(t.max())) if scale_each is True: for t in tensor: # loop over mini-batch dimension