diff --git a/packaging/windows/internal/cuda_install.bat b/packaging/windows/internal/cuda_install.bat index aff5f834a99..15ea785a793 100644 --- a/packaging/windows/internal/cuda_install.bat +++ b/packaging/windows/internal/cuda_install.bat @@ -82,7 +82,7 @@ if not exist "%SRC_DIR%\temp_build\cuda_10.2.89_441.22_win10.exe" ( ) if not exist "%SRC_DIR%\temp_build\cudnn-10.2-windows10-x64-v7.6.5.32.zip" ( - curl -k -L https://ossci-windows.s3.amazonaws.com/cudnn-10.1-windows10-x64-v7.6.4.38.zip --output "%SRC_DIR%\temp_build\cudnn-10.2-windows10-x64-v7.6.5.32.zip" + curl -k -L https://ossci-windows.s3.amazonaws.com/cudnn-10.2-windows10-x64-v7.6.5.32.zip --output "%SRC_DIR%\temp_build\cudnn-10.2-windows10-x64-v7.6.5.32.zip" if errorlevel 1 exit /b 1 set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-10.2-windows10-x64-v7.6.5.32.zip" ) diff --git a/references/classification/train_quantization.py b/references/classification/train_quantization.py index 511d2ba1adb..a054dbda91c 100644 --- a/references/classification/train_quantization.py +++ b/references/classification/train_quantization.py @@ -51,12 +51,16 @@ def main(args): print("Creating model", args.model) # when training quantized models, we always start from a pre-trained fp32 reference model model = torchvision.models.quantization.__dict__[args.model](pretrained=True, quantize=args.test_only) + model.to(device) if not (args.test_only or args.post_training_quantize): model.fuse_model() model.qconfig = torch.quantization.get_default_qat_qconfig(args.backend) torch.quantization.prepare_qat(model, inplace=True) + if args.distributed and args.sync_bn: + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + optimizer = torch.optim.SGD( model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) @@ -65,8 +69,6 @@ def main(args): step_size=args.lr_step_size, gamma=args.lr_gamma) - model.to(device) - criterion = nn.CrossEntropyLoss() model_without_ddp = model if args.distributed: @@ -224,6 +226,12 @@ def parse_args(): It also serializes the transforms", action="store_true", ) + parser.add_argument( + "--sync-bn", + dest="sync_bn", + help="Use sync batch norm", + action="store_true", + ) parser.add_argument( "--test-only", dest="test_only", diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index e7d058e8da2..1a8c77c827f 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -6,6 +6,7 @@ import numpy as np import unittest import random +import colorsys from torch.jit.annotations import Optional, List, BroadcastingList2, Tuple @@ -56,6 +57,45 @@ def test_crop(self): cropped_img_script = script_crop(img_tensor, top, left, height, width) self.assertTrue(torch.equal(img_cropped, cropped_img_script)) + def test_hsv2rgb(self): + shape = (3, 100, 150) + for _ in range(20): + img = torch.rand(*shape, dtype=torch.float) + ft_img = F_t._hsv2rgb(img).permute(1, 2, 0).flatten(0, 1) + + h, s, v, = img.unbind(0) + h = h.flatten().numpy() + s = s.flatten().numpy() + v = v.flatten().numpy() + + rgb = [] + for h1, s1, v1 in zip(h, s, v): + rgb.append(colorsys.hsv_to_rgb(h1, s1, v1)) + + colorsys_img = torch.tensor(rgb, dtype=torch.float32) + max_diff = (ft_img - colorsys_img).abs().max() + self.assertLess(max_diff, 1e-5) + + def test_rgb2hsv(self): + shape = (3, 150, 100) + for _ in range(20): + img = torch.rand(*shape, dtype=torch.float) + ft_hsv_img = F_t._rgb2hsv(img).permute(1, 2, 0).flatten(0, 1) + + r, g, b, = img.unbind(0) + r = r.flatten().numpy() + g = g.flatten().numpy() + b = b.flatten().numpy() + + hsv = [] + for r1, g1, b1 in zip(r, g, b): + hsv.append(colorsys.rgb_to_hsv(r1, g1, b1)) + + colorsys_img = torch.tensor(hsv, dtype=torch.float32) + + max_diff = (colorsys_img - ft_hsv_img).abs().max() + self.assertLess(max_diff, 1e-5) + def test_adjustments(self): script_adjust_brightness = torch.jit.script(F_t.adjust_brightness) script_adjust_contrast = torch.jit.script(F_t.adjust_contrast) @@ -97,6 +137,23 @@ def test_adjustments(self): self.assertLess(max_diff_scripted, 5 / 255 + 1e-5) self.assertTrue(torch.equal(img, img_clone)) + # test for class interface + f = transforms.ColorJitter(brightness=factor.item()) + scripted_fn = torch.jit.script(f) + scripted_fn(img) + + f = transforms.ColorJitter(contrast=factor.item()) + scripted_fn = torch.jit.script(f) + scripted_fn(img) + + f = transforms.ColorJitter(saturation=factor.item()) + scripted_fn = torch.jit.script(f) + scripted_fn(img) + + f = transforms.ColorJitter(brightness=1) + scripted_fn = torch.jit.script(f) + scripted_fn(img) + def test_rgb_to_grayscale(self): script_rgb_to_grayscale = torch.jit.script(F_t.rgb_to_grayscale) img_tensor = torch.randint(0, 255, (3, 16, 16), dtype=torch.uint8) diff --git a/test/test_onnx.py b/test/test_onnx.py index cc88f954762..2ea58ba6bc6 100644 --- a/test/test_onnx.py +++ b/test/test_onnx.py @@ -446,39 +446,22 @@ def test_heatmaps_to_keypoints(self): assert torch.all(out2[1].eq(out_trace2[1])) def test_keypoint_rcnn(self): - class KeyPointRCNN(torch.nn.Module): - def __init__(self): - super(KeyPointRCNN, self).__init__() - self.model = models.detection.keypoint_rcnn.keypointrcnn_resnet50_fpn( - pretrained=True, min_size=200, max_size=300) - - def forward(self, images): - output = self.model(images) - # TODO: The keypoints_scores require the use of Argmax that is updated in ONNX. - # For now we are testing all the output of KeypointRCNN except keypoints_scores. - # Enable When Argmax is updated in ONNX Runtime. - return output[0]['boxes'], output[0]['labels'], output[0]['scores'], output[0]['keypoints'] - images, test_images = self.get_test_images() - # TODO: - # Enable test for dummy_image (no detection) once issue is - # _onnx_heatmaps_to_keypoints_loop for empty heatmaps is fixed - # dummy_images = [torch.ones(3, 100, 100) * 0.3] - model = KeyPointRCNN() + dummy_images = [torch.ones(3, 100, 100) * 0.3] + model = models.detection.keypoint_rcnn.keypointrcnn_resnet50_fpn(pretrained=True, min_size=200, max_size=300) model.eval() model(images) - self.run_model(model, [(images,), (test_images,)], + self.run_model(model, [(images,), (test_images,), (dummy_images,)], + input_names=["images_tensors"], + output_names=["outputs1", "outputs2", "outputs3", "outputs4"], + dynamic_axes={"images_tensors": [0, 1, 2, 3]}, + tolerate_small_mismatch=True) + + self.run_model(model, [(dummy_images,), (test_images,)], input_names=["images_tensors"], output_names=["outputs1", "outputs2", "outputs3", "outputs4"], dynamic_axes={"images_tensors": [0, 1, 2, 3]}, tolerate_small_mismatch=True) - # TODO: enable this test once dynamic model export is fixed - # Test exported model for an image with no detections on other images - # self.run_model(model, [(dummy_images,), (test_images,)], - # input_names=["images_tensors"], - # output_names=["outputs1", "outputs2", "outputs3", "outputs4"], - # dynamic_axes={"images_tensors": [0, 1, 2, 3]}, - # tolerate_small_mismatch=True) if __name__ == '__main__': diff --git a/test/test_ops.py b/test/test_ops.py index ffaf23d80b5..2e3107f8d7e 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -393,6 +393,10 @@ def test_nms(self): keep_ref = self.reference_nms(boxes, scores, iou) keep = ops.nms(boxes, scores, iou) self.assertTrue(torch.allclose(keep, keep_ref), err_msg.format(iou)) + self.assertRaises(RuntimeError, ops.nms, torch.rand(4), torch.rand(3), 0.5) + self.assertRaises(RuntimeError, ops.nms, torch.rand(3, 5), torch.rand(3), 0.5) + self.assertRaises(RuntimeError, ops.nms, torch.rand(3, 4), torch.rand(3, 2), 0.5) + self.assertRaises(RuntimeError, ops.nms, torch.rand(3, 4), torch.rand(4), 0.5) @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable") def test_nms_cuda(self): diff --git a/test/test_transforms.py b/test/test_transforms.py index 45871276073..8423bf99ee3 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -23,6 +23,22 @@ os.path.dirname(os.path.abspath(__file__)), 'assets', 'grace_hopper_517x606.jpg') +def cycle_over(objs): + objs = list(objs) + for idx, obj in enumerate(objs): + yield obj, objs[:idx] + objs[idx + 1:] + + +def int_dtypes(): + yield from iter( + (torch.uint8, torch.int8, torch.int16, torch.short, torch.int32, torch.int, torch.int64, torch.long,) + ) + + +def float_dtypes(): + yield from iter((torch.float32, torch.float, torch.float64, torch.double)) + + class Tester(unittest.TestCase): def test_crop(self): @@ -299,13 +315,22 @@ def test_pad(self): width = random.randint(10, 32) * 2 img = torch.ones(3, height, width) padding = random.randint(1, 20) + fill = random.randint(1, 50) result = transforms.Compose([ transforms.ToPILImage(), - transforms.Pad(padding), + transforms.Pad(padding, fill=fill), transforms.ToTensor(), ])(img) self.assertEqual(result.size(1), height + 2 * padding) self.assertEqual(result.size(2), width + 2 * padding) + # check that all elements in the padded region correspond + # to the pad value + fill_v = fill / 255 + eps = 1e-5 + self.assertTrue((result[:, :padding, :] - fill_v).abs().max() < eps) + self.assertTrue((result[:, :, :padding] - fill_v).abs().max() < eps) + self.assertRaises(ValueError, transforms.Pad(padding, fill=(1, 2)), + transforms.ToPILImage()(img)) def test_pad_with_tuple_of_pad_values(self): height = random.randint(10, 32) * 2 @@ -501,6 +526,100 @@ def test_to_tensor(self): output = trans(img) self.assertTrue(np.allclose(input_data.numpy(), output.numpy())) + def test_convert_image_dtype_float_to_float(self): + for input_dtype, output_dtypes in cycle_over(float_dtypes()): + input_image = torch.tensor((0.0, 1.0), dtype=input_dtype) + for output_dtype in output_dtypes: + with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype): + transform = transforms.ConvertImageDtype(output_dtype) + output_image = transform(input_image) + + actual_min, actual_max = output_image.tolist() + desired_min, desired_max = 0.0, 1.0 + + self.assertAlmostEqual(actual_min, desired_min) + self.assertAlmostEqual(actual_max, desired_max) + + def test_convert_image_dtype_float_to_int(self): + for input_dtype in float_dtypes(): + input_image = torch.tensor((0.0, 1.0), dtype=input_dtype) + for output_dtype in int_dtypes(): + with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype): + transform = transforms.ConvertImageDtype(output_dtype) + + if (input_dtype == torch.float32 and output_dtype in (torch.int32, torch.int64)) or ( + input_dtype == torch.float64 and output_dtype == torch.int64 + ): + with self.assertRaises(RuntimeError): + transform(input_image) + else: + output_image = transform(input_image) + + actual_min, actual_max = output_image.tolist() + desired_min, desired_max = 0, torch.iinfo(output_dtype).max + + self.assertEqual(actual_min, desired_min) + self.assertEqual(actual_max, desired_max) + + def test_convert_image_dtype_int_to_float(self): + for input_dtype in int_dtypes(): + input_image = torch.tensor((0, torch.iinfo(input_dtype).max), dtype=input_dtype) + for output_dtype in float_dtypes(): + with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype): + transform = transforms.ConvertImageDtype(output_dtype) + output_image = transform(input_image) + + actual_min, actual_max = output_image.tolist() + desired_min, desired_max = 0.0, 1.0 + + self.assertAlmostEqual(actual_min, desired_min) + self.assertGreaterEqual(actual_min, desired_min) + self.assertAlmostEqual(actual_max, desired_max) + self.assertLessEqual(actual_max, desired_max) + + def test_convert_image_dtype_int_to_int(self): + for input_dtype, output_dtypes in cycle_over(int_dtypes()): + input_max = torch.iinfo(input_dtype).max + input_image = torch.tensor((0, input_max), dtype=input_dtype) + for output_dtype in output_dtypes: + output_max = torch.iinfo(output_dtype).max + + with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype): + transform = transforms.ConvertImageDtype(output_dtype) + output_image = transform(input_image) + + actual_min, actual_max = output_image.tolist() + desired_min, desired_max = 0, output_max + + # see https://github.com/pytorch/vision/pull/2078#issuecomment-641036236 for details + if input_max >= output_max: + error_term = 0 + else: + error_term = 1 - (torch.iinfo(output_dtype).max + 1) // (torch.iinfo(input_dtype).max + 1) + + self.assertEqual(actual_min, desired_min) + self.assertEqual(actual_max, desired_max + error_term) + + def test_convert_image_dtype_int_to_int_consistency(self): + for input_dtype, output_dtypes in cycle_over(int_dtypes()): + input_max = torch.iinfo(input_dtype).max + input_image = torch.tensor((0, input_max), dtype=input_dtype) + for output_dtype in output_dtypes: + output_max = torch.iinfo(output_dtype).max + if output_max <= input_max: + continue + + with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype): + transform = transforms.ConvertImageDtype(output_dtype) + inverse_transfrom = transforms.ConvertImageDtype(input_dtype) + output_image = inverse_transfrom(transform(input_image)) + + actual_min, actual_max = output_image.tolist() + desired_min, desired_max = 0, input_max + + self.assertEqual(actual_min, desired_min) + self.assertEqual(actual_max, desired_max) + @unittest.skipIf(accimage is None, 'accimage not available') def test_accimage_to_tensor(self): trans = transforms.ToTensor() diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py new file mode 100644 index 00000000000..7791dd8b4f9 --- /dev/null +++ b/test/test_transforms_tensor.py @@ -0,0 +1,70 @@ +import torch +from torchvision import transforms as T +from torchvision.transforms import functional as F +from PIL import Image + +import numpy as np + +import unittest + + +class Tester(unittest.TestCase): + def _create_data(self, height=3, width=3, channels=3): + tensor = torch.randint(0, 255, (channels, height, width), dtype=torch.uint8) + pil_img = Image.fromarray(tensor.permute(1, 2, 0).contiguous().numpy()) + return tensor, pil_img + + def compareTensorToPIL(self, tensor, pil_image): + pil_tensor = torch.as_tensor(np.array(pil_image).transpose((2, 0, 1))) + self.assertTrue(tensor.equal(pil_tensor)) + + def _test_flip(self, func, method): + tensor, pil_img = self._create_data() + flip_tensor = getattr(F, func)(tensor) + flip_pil_img = getattr(F, func)(pil_img) + self.compareTensorToPIL(flip_tensor, flip_pil_img) + + scripted_fn = torch.jit.script(getattr(F, func)) + flip_tensor_script = scripted_fn(tensor) + self.assertTrue(flip_tensor.equal(flip_tensor_script)) + + # test for class interface + f = getattr(T, method)() + scripted_fn = torch.jit.script(f) + scripted_fn(tensor) + + def test_random_horizontal_flip(self): + self._test_flip('hflip', 'RandomHorizontalFlip') + + def test_random_vertical_flip(self): + self._test_flip('vflip', 'RandomVerticalFlip') + + def test_adjustments(self): + fns = ['adjust_brightness', 'adjust_contrast', 'adjust_saturation'] + for _ in range(20): + factor = 3 * torch.rand(1).item() + tensor, _ = self._create_data() + pil_img = T.ToPILImage()(tensor) + + for func in fns: + adjusted_tensor = getattr(F, func)(tensor, factor) + adjusted_pil_img = getattr(F, func)(pil_img, factor) + + adjusted_pil_tensor = T.ToTensor()(adjusted_pil_img) + scripted_fn = torch.jit.script(getattr(F, func)) + adjusted_tensor_script = scripted_fn(tensor, factor) + + if not tensor.dtype.is_floating_point: + adjusted_tensor = adjusted_tensor.to(torch.float) / 255 + adjusted_tensor_script = adjusted_tensor_script.to(torch.float) / 255 + + # F uses uint8 and F_t uses float, so there is a small + # difference in values caused by (at most 5) truncations. + max_diff = (adjusted_tensor - adjusted_pil_tensor).abs().max() + max_diff_scripted = (adjusted_tensor - adjusted_tensor_script).abs().max() + self.assertLess(max_diff, 5 / 255 + 1e-5) + self.assertLess(max_diff_scripted, 5 / 255 + 1e-5) + + +if __name__ == '__main__': + unittest.main() diff --git a/torchvision/csrc/cpu/video_reader/VideoReader.cpp b/torchvision/csrc/cpu/video_reader/VideoReader.cpp index 57801930926..3a184716b4d 100644 --- a/torchvision/csrc/cpu/video_reader/VideoReader.cpp +++ b/torchvision/csrc/cpu/video_reader/VideoReader.cpp @@ -311,7 +311,7 @@ torch::List readVideo( videoFrame = torch::zeros( {numVideoFrames, outHeight, outWidth, numChannels}, torch::kByte); expectedWrittenBytes = - numVideoFrames * outHeight * outWidth * numChannels; + (size_t)numVideoFrames * outHeight * outWidth * numChannels; } videoFramePts = torch::zeros({numVideoFrames}, torch::kLong); diff --git a/torchvision/csrc/nms.h b/torchvision/csrc/nms.h index bc7bec1bbfe..3c2faba8353 100644 --- a/torchvision/csrc/nms.h +++ b/torchvision/csrc/nms.h @@ -12,6 +12,24 @@ at::Tensor nms( const at::Tensor& dets, const at::Tensor& scores, const double iou_threshold) { + TORCH_CHECK( + dets.dim() == 2, "boxes should be a 2d tensor, got ", dets.dim(), "D"); + TORCH_CHECK( + dets.size(1) == 4, + "boxes should have 4 elements in dimension 1, got ", + dets.size(1)); + TORCH_CHECK( + scores.dim() == 1, + "scores should be a 1d tensor, got ", + scores.dim(), + "D"); + TORCH_CHECK( + dets.size(0) == scores.size(0), + "boxes and scores should have same number of elements in ", + "dimension 0, got ", + dets.size(0), + " and ", + scores.size(0)); if (dets.is_cuda()) { #if defined(WITH_CUDA) if (dets.numel() == 0) { diff --git a/torchvision/models/detection/keypoint_rcnn.py b/torchvision/models/detection/keypoint_rcnn.py index ad19af8cb72..438ef225c91 100644 --- a/torchvision/models/detection/keypoint_rcnn.py +++ b/torchvision/models/detection/keypoint_rcnn.py @@ -1,8 +1,6 @@ import torch from torch import nn -from torchvision.ops import misc as misc_nn_ops - from torchvision.ops import MultiScaleRoIAlign from ..utils import load_state_dict_from_url @@ -253,10 +251,9 @@ def __init__(self, in_channels, num_keypoints): def forward(self, x): x = self.kps_score_lowres(x) - x = misc_nn_ops.interpolate( - x, scale_factor=float(self.up_scale), mode="bilinear", align_corners=False + return torch.nn.functional.interpolate( + x, scale_factor=float(self.up_scale), mode="bilinear", align_corners=False, recompute_scale_factor=False ) - return x model_urls = { diff --git a/torchvision/models/detection/roi_heads.py b/torchvision/models/detection/roi_heads.py index 244186cf2dc..19cc15a8cc0 100644 --- a/torchvision/models/detection/roi_heads.py +++ b/torchvision/models/detection/roi_heads.py @@ -5,7 +5,6 @@ from torch import nn, Tensor from torchvision.ops import boxes as box_ops -from torchvision.ops import misc as misc_nn_ops from torchvision.ops import roi_align @@ -175,8 +174,8 @@ def _onnx_heatmaps_to_keypoints(maps, maps_i, roi_map_width, roi_map_height, width_correction = widths_i / roi_map_width height_correction = heights_i / roi_map_height - roi_map = torch.nn.functional.interpolate( - maps_i[None], size=(int(roi_map_height), int(roi_map_width)), mode='bicubic', align_corners=False)[0] + roi_map = F.interpolate( + maps_i[:, None], size=(int(roi_map_height), int(roi_map_width)), mode='bicubic', align_corners=False)[:, 0] w = torch.scalar_tensor(roi_map.size(2), dtype=torch.int64) pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1) @@ -197,8 +196,12 @@ def _onnx_heatmaps_to_keypoints(maps, maps_i, roi_map_width, roi_map_height, xy_preds_i_2.to(dtype=torch.float32)], 0) # TODO: simplify when indexing without rank will be supported by ONNX + base = num_keypoints * num_keypoints + num_keypoints + 1 + ind = torch.arange(num_keypoints) + ind = ind.to(dtype=torch.int64) * base end_scores_i = roi_map.index_select(1, y_int.to(dtype=torch.int64)) \ - .index_select(2, x_int.to(dtype=torch.int64))[:num_keypoints, 0, 0] + .index_select(2, x_int.to(dtype=torch.int64)).view(-1).index_select(0, ind.to(dtype=torch.int64)) + return xy_preds_i, end_scores_i @@ -256,8 +259,8 @@ def heatmaps_to_keypoints(maps, rois): roi_map_height = int(heights_ceil[i].item()) width_correction = widths[i] / roi_map_width height_correction = heights[i] / roi_map_height - roi_map = torch.nn.functional.interpolate( - maps[i][None], size=(roi_map_height, roi_map_width), mode='bicubic', align_corners=False)[0] + roi_map = F.interpolate( + maps[i][:, None], size=(roi_map_height, roi_map_width), mode='bicubic', align_corners=False)[:, 0] # roi_map_probs = scores_to_probs(roi_map.copy()) w = roi_map.shape[2] pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1) @@ -392,7 +395,7 @@ def paste_mask_in_image(mask, box, im_h, im_w): mask = mask.expand((1, 1, -1, -1)) # Resize mask - mask = misc_nn_ops.interpolate(mask, size=(h, w), mode='bilinear', align_corners=False) + mask = F.interpolate(mask, size=(h, w), mode='bilinear', align_corners=False) mask = mask[0][0] im_mask = torch.zeros((im_h, im_w), dtype=mask.dtype, device=mask.device) @@ -420,7 +423,7 @@ def _onnx_paste_mask_in_image(mask, box, im_h, im_w): mask = mask.expand((1, 1, mask.size(0), mask.size(1))) # Resize mask - mask = torch.nn.functional.interpolate(mask, size=(int(h), int(w)), mode='bilinear', align_corners=False) + mask = F.interpolate(mask, size=(int(h), int(w)), mode='bilinear', align_corners=False) mask = mask[0][0] x_0 = torch.max(torch.cat((box[0].unsqueeze(0), zero))) diff --git a/torchvision/models/detection/transform.py b/torchvision/models/detection/transform.py index 29a2e07d5a2..5564866c571 100644 --- a/torchvision/models/detection/transform.py +++ b/torchvision/models/detection/transform.py @@ -2,10 +2,10 @@ import math import torch from torch import nn, Tensor +from torch.nn import functional as F import torchvision from torch.jit.annotations import List, Tuple, Dict, Optional -from torchvision.ops import misc as misc_nn_ops from .image_list import ImageList from .roi_heads import paste_masks_in_image @@ -28,7 +28,7 @@ def _resize_image_and_masks_onnx(image, self_min_size, self_max_size, target): if "masks" in target: mask = target["masks"] - mask = misc_nn_ops.interpolate(mask[None].float(), scale_factor=scale_factor)[0].byte() + mask = F.interpolate(mask[:, None].float(), scale_factor=scale_factor)[:, 0].byte() target["masks"] = mask return image, target @@ -50,7 +50,7 @@ def _resize_image_and_masks(image, self_min_size, self_max_size, target): if "masks" in target: mask = target["masks"] - mask = misc_nn_ops.interpolate(mask[None].float(), scale_factor=scale_factor)[0].byte() + mask = F.interpolate(mask[:, None].float(), scale_factor=scale_factor)[:, 0].byte() target["masks"] = mask return image, target diff --git a/torchvision/models/inception.py b/torchvision/models/inception.py index cc89418b146..1be9ef2741a 100644 --- a/torchvision/models/inception.py +++ b/torchvision/models/inception.py @@ -90,8 +90,10 @@ def __init__(self, num_classes=1000, aux_logits=True, transform_input=False, self.Conv2d_1a_3x3 = conv_block(3, 32, kernel_size=3, stride=2) self.Conv2d_2a_3x3 = conv_block(32, 32, kernel_size=3) self.Conv2d_2b_3x3 = conv_block(32, 64, kernel_size=3, padding=1) + self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2) self.Conv2d_3b_1x1 = conv_block(64, 80, kernel_size=1) self.Conv2d_4a_3x3 = conv_block(80, 192, kernel_size=3) + self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2) self.Mixed_5b = inception_a(192, pool_features=32) self.Mixed_5c = inception_a(256, pool_features=64) self.Mixed_5d = inception_a(288, pool_features=64) @@ -105,6 +107,8 @@ def __init__(self, num_classes=1000, aux_logits=True, transform_input=False, self.Mixed_7a = inception_d(768) self.Mixed_7b = inception_e(1280) self.Mixed_7c = inception_e(2048) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.dropout = nn.Dropout() self.fc = nn.Linear(2048, num_classes) if init_weights: for m in self.modules(): @@ -136,13 +140,13 @@ def _forward(self, x): # N x 32 x 147 x 147 x = self.Conv2d_2b_3x3(x) # N x 64 x 147 x 147 - x = F.max_pool2d(x, kernel_size=3, stride=2) + x = self.maxpool1(x) # N x 64 x 73 x 73 x = self.Conv2d_3b_1x1(x) # N x 80 x 73 x 73 x = self.Conv2d_4a_3x3(x) # N x 192 x 71 x 71 - x = F.max_pool2d(x, kernel_size=3, stride=2) + x = self.maxpool2(x) # N x 192 x 35 x 35 x = self.Mixed_5b(x) # N x 256 x 35 x 35 @@ -173,9 +177,9 @@ def _forward(self, x): x = self.Mixed_7c(x) # N x 2048 x 8 x 8 # Adaptive average pooling - x = F.adaptive_avg_pool2d(x, (1, 1)) + x = self.avgpool(x) # N x 2048 x 1 x 1 - x = F.dropout(x, training=self.training) + x = self.dropout(x) # N x 2048 x 1 x 1 x = torch.flatten(x, 1) # N x 2048 diff --git a/torchvision/ops/misc.py b/torchvision/ops/misc.py index ccc82e63cf2..61fab3edd7a 100644 --- a/torchvision/ops/misc.py +++ b/torchvision/ops/misc.py @@ -1,7 +1,3 @@ -from collections import OrderedDict -from torch.jit.annotations import Optional, List -from torch import Tensor - """ helper class that supports empty tensors on some nn functions. @@ -12,10 +8,8 @@ is implemented """ -import math import warnings import torch -from torchvision.ops import _new_empty_tensor class Conv2d(torch.nn.Conv2d): @@ -42,51 +36,7 @@ def __init__(self, *args, **kwargs): "removed in future versions, use torch.nn.BatchNorm2d instead.", FutureWarning) -def _check_size_scale_factor(dim, size, scale_factor): - # type: (int, Optional[List[int]], Optional[float]) -> None - if size is None and scale_factor is None: - raise ValueError("either size or scale_factor should be defined") - if size is not None and scale_factor is not None: - raise ValueError("only one of size or scale_factor should be defined") - if scale_factor is not None: - if isinstance(scale_factor, (list, tuple)): - if len(scale_factor) != dim: - raise ValueError( - "scale_factor shape must match input shape. " - "Input is {}D, scale_factor size is {}".format(dim, len(scale_factor)) - ) - - -def _output_size(dim, input, size, scale_factor): - # type: (int, Tensor, Optional[List[int]], Optional[float]) -> List[int] - assert dim == 2 - _check_size_scale_factor(dim, size, scale_factor) - if size is not None: - return size - # if dim is not 2 or scale_factor is iterable use _ntuple instead of concat - assert scale_factor is not None and isinstance(scale_factor, (int, float)) - scale_factors = [scale_factor, scale_factor] - # math.floor might return float in py2.7 - return [ - int(math.floor(input.size(i + 2) * scale_factors[i])) for i in range(dim) - ] - - -def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None): - # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor - """ - Equivalent to nn.functional.interpolate, but with support for empty batch sizes. - This will eventually be supported natively by PyTorch, and this - class can go away. - """ - if input.numel() > 0: - return torch.nn.functional.interpolate( - input, size, scale_factor, mode, align_corners - ) - - output_shape = _output_size(2, input, size, scale_factor) - output_shape = list(input.shape[:-2]) + list(output_shape) - return _new_empty_tensor(input, output_shape) +interpolate = torch.nn.functional.interpolate # This is not in nn diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 7f22fc51391..e49ff063dc8 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -1,4 +1,5 @@ import torch +from torch import Tensor import math from PIL import Image, ImageOps, ImageEnhance, __version__ as PILLOW_VERSION try: @@ -11,6 +12,9 @@ from collections.abc import Sequence, Iterable import warnings +from . import functional_pil as F_pil +from . import functional_tensor as F_t + def _is_pil_image(img): if accimage is not None: @@ -109,6 +113,65 @@ def pil_to_tensor(pic): return img +def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor: + """Convert a tensor image to the given ``dtype`` and scale the values accordingly + + Args: + image (torch.Tensor): Image to be converted + dtype (torch.dtype): Desired data type of the output + + Returns: + (torch.Tensor): Converted image + + .. note:: + + When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly. + If converted back and forth, this mismatch has no effect. + + Raises: + RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as + well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to + overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range + of the integer ``dtype``. + """ + if image.dtype == dtype: + return image + + if image.dtype.is_floating_point: + # float to float + if dtype.is_floating_point: + return image.to(dtype) + + # float to int + if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or ( + image.dtype == torch.float64 and dtype == torch.int64 + ): + msg = f"The cast from {image.dtype} to {dtype} cannot be performed safely." + raise RuntimeError(msg) + + eps = 1e-3 + return image.mul(torch.iinfo(dtype).max + 1 - eps).to(dtype) + else: + # int to float + if dtype.is_floating_point: + max = torch.iinfo(image.dtype).max + image = image.to(dtype) + return image / max + + # int to int + input_max = torch.iinfo(image.dtype).max + output_max = torch.iinfo(dtype).max + + if input_max > output_max: + factor = (input_max + 1) // (output_max + 1) + image = image // factor + return image.to(dtype) + else: + factor = (output_max + 1) // (input_max + 1) + image = image.to(dtype) + return image * factor + + def to_pil_image(pic, mode=None): """Convert a tensor or an ndarray to PIL Image. @@ -329,6 +392,12 @@ def pad(img, padding, fill=0, padding_mode='constant'): 'Padding mode should be either constant, edge, reflect or symmetric' if padding_mode == 'constant': + if isinstance(fill, numbers.Number): + fill = (fill,) * len(img.getbands()) + if len(fill) != len(img.getbands()): + raise ValueError('fill should have the same number of elements ' + 'as the number of channels in the image ' + '({}), got {} instead'.format(len(img.getbands()), len(fill))) if img.mode == 'P': palette = img.getpalette() image = ImageOps.expand(img, border=padding, fill=fill) @@ -428,19 +497,22 @@ def resized_crop(img, top, left, height, width, size, interpolation=Image.BILINE return img -def hflip(img): - """Horizontally flip the given PIL Image. +def hflip(img: Tensor) -> Tensor: + """Horizontally flip the given PIL Image or torch Tensor. Args: - img (PIL Image): Image to be flipped. + img (PIL Image or Torch Tensor): Image to be flipped. If img + is a Tensor, it is expected to be in [..., H, W] format, + where ... means it can have an arbitrary number of trailing + dimensions. Returns: PIL Image: Horizontally flipped image. """ - if not _is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + if not isinstance(img, torch.Tensor): + return F_pil.hflip(img) - return img.transpose(Image.FLIP_LEFT_RIGHT) + return F_t.hflip(img) def _parse_fill(fill, img, min_pil_version): @@ -530,19 +602,22 @@ def perspective(img, startpoints, endpoints, interpolation=Image.BICUBIC, fill=N return img.transform(img.size, Image.PERSPECTIVE, coeffs, interpolation, **opts) -def vflip(img): - """Vertically flip the given PIL Image. +def vflip(img: Tensor) -> Tensor: + """Vertically flip the given PIL Image or torch Tensor. Args: - img (PIL Image): Image to be flipped. + img (PIL Image or Torch Tensor): Image to be flipped. If img + is a Tensor, it is expected to be in [..., H, W] format, + where ... means it can have an arbitrary number of trailing + dimensions. Returns: PIL Image: Vertically flipped image. """ - if not _is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + if not isinstance(img, torch.Tensor): + return F_pil.vflip(img) - return img.transpose(Image.FLIP_TOP_BOTTOM) + return F_t.vflip(img) def five_crop(img, size): @@ -617,67 +692,61 @@ def ten_crop(img, size, vertical_flip=False): return first_five + second_five -def adjust_brightness(img, brightness_factor): +def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor: """Adjust brightness of an Image. Args: - img (PIL Image): PIL Image to be adjusted. + img (PIL Image or Torch Tensor): Image to be adjusted. brightness_factor (float): How much to adjust the brightness. Can be any non negative number. 0 gives a black image, 1 gives the original image while 2 increases the brightness by a factor of 2. Returns: - PIL Image: Brightness adjusted image. + PIL Image or Torch Tensor: Brightness adjusted image. """ - if not _is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + if not isinstance(img, torch.Tensor): + return F_pil.adjust_brightness(img, brightness_factor) - enhancer = ImageEnhance.Brightness(img) - img = enhancer.enhance(brightness_factor) - return img + return F_t.adjust_brightness(img, brightness_factor) -def adjust_contrast(img, contrast_factor): +def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor: """Adjust contrast of an Image. Args: - img (PIL Image): PIL Image to be adjusted. + img (PIL Image or Torch Tensor): Image to be adjusted. contrast_factor (float): How much to adjust the contrast. Can be any non negative number. 0 gives a solid gray image, 1 gives the original image while 2 increases the contrast by a factor of 2. Returns: - PIL Image: Contrast adjusted image. + PIL Image or Torch Tensor: Contrast adjusted image. """ - if not _is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + if not isinstance(img, torch.Tensor): + return F_pil.adjust_contrast(img, contrast_factor) - enhancer = ImageEnhance.Contrast(img) - img = enhancer.enhance(contrast_factor) - return img + return F_t.adjust_contrast(img, contrast_factor) -def adjust_saturation(img, saturation_factor): +def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor: """Adjust color saturation of an image. Args: - img (PIL Image): PIL Image to be adjusted. + img (PIL Image or Torch Tensor): Image to be adjusted. saturation_factor (float): How much to adjust the saturation. 0 will give a black and white image, 1 will give the original image while 2 will enhance the saturation by a factor of 2. Returns: - PIL Image: Saturation adjusted image. + PIL Image or Torch Tensor: Saturation adjusted image. """ - if not _is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + if not isinstance(img, torch.Tensor): + return F_pil.adjust_saturation(img, saturation_factor) - enhancer = ImageEnhance.Color(img) - img = enhancer.enhance(saturation_factor) - return img + return F_t.adjust_saturation(img, saturation_factor) -def adjust_hue(img, hue_factor): +def adjust_hue(img: Tensor, hue_factor: float) -> Tensor: """Adjust hue of an image. The image hue is adjusted by converting the image to HSV and @@ -702,26 +771,10 @@ def adjust_hue(img, hue_factor): Returns: PIL Image: Hue adjusted image. """ - if not(-0.5 <= hue_factor <= 0.5): - raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor)) - - if not _is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) - - input_mode = img.mode - if input_mode in {'L', '1', 'I', 'F'}: - return img - - h, s, v = img.convert('HSV').split() - - np_h = np.array(h, dtype=np.uint8) - # uint8 addition take cares of rotation across boundaries - with np.errstate(over='ignore'): - np_h += np.uint8(hue_factor * 255) - h = Image.fromarray(np_h, 'L') + if not isinstance(img, torch.Tensor): + return F_pil.adjust_hue(img, hue_factor) - img = Image.merge('HSV', (h, s, v)).convert(input_mode) - return img + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) def adjust_gamma(img, gamma, gain=1): diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py new file mode 100644 index 00000000000..84e27e79040 --- /dev/null +++ b/torchvision/transforms/functional_pil.py @@ -0,0 +1,154 @@ +import torch +try: + import accimage +except ImportError: + accimage = None +from PIL import Image, ImageOps, ImageEnhance, __version__ as PILLOW_VERSION +import numpy as np + + +@torch.jit.unused +def _is_pil_image(img): + if accimage is not None: + return isinstance(img, (Image.Image, accimage.Image)) + else: + return isinstance(img, Image.Image) + + +@torch.jit.unused +def hflip(img): + """Horizontally flip the given PIL Image. + + Args: + img (PIL Image): Image to be flipped. + + Returns: + PIL Image: Horizontally flipped image. + """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + return img.transpose(Image.FLIP_LEFT_RIGHT) + + +@torch.jit.unused +def vflip(img): + """Vertically flip the given PIL Image. + + Args: + img (PIL Image): Image to be flipped. + + Returns: + PIL Image: Vertically flipped image. + """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + return img.transpose(Image.FLIP_TOP_BOTTOM) + + +@torch.jit.unused +def adjust_brightness(img, brightness_factor): + """Adjust brightness of an RGB image. + + Args: + img (PIL Image): Image to be adjusted. + brightness_factor (float): How much to adjust the brightness. Can be + any non negative number. 0 gives a black image, 1 gives the + original image while 2 increases the brightness by a factor of 2. + + Returns: + PIL Image: Brightness adjusted image. + """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + enhancer = ImageEnhance.Brightness(img) + img = enhancer.enhance(brightness_factor) + return img + + +@torch.jit.unused +def adjust_contrast(img, contrast_factor): + """Adjust contrast of an Image. + Args: + img (PIL Image): PIL Image to be adjusted. + contrast_factor (float): How much to adjust the contrast. Can be any + non negative number. 0 gives a solid gray image, 1 gives the + original image while 2 increases the contrast by a factor of 2. + Returns: + PIL Image: Contrast adjusted image. + """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + enhancer = ImageEnhance.Contrast(img) + img = enhancer.enhance(contrast_factor) + return img + + +@torch.jit.unused +def adjust_saturation(img, saturation_factor): + """Adjust color saturation of an image. + Args: + img (PIL Image): PIL Image to be adjusted. + saturation_factor (float): How much to adjust the saturation. 0 will + give a black and white image, 1 will give the original image while + 2 will enhance the saturation by a factor of 2. + Returns: + PIL Image: Saturation adjusted image. + """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + enhancer = ImageEnhance.Color(img) + img = enhancer.enhance(saturation_factor) + return img + + +@torch.jit.unused +def adjust_hue(img, hue_factor): + """Adjust hue of an image. + + The image hue is adjusted by converting the image to HSV and + cyclically shifting the intensities in the hue channel (H). + The image is then converted back to original image mode. + + `hue_factor` is the amount of shift in H channel and must be in the + interval `[-0.5, 0.5]`. + + See `Hue`_ for more details. + + .. _Hue: https://en.wikipedia.org/wiki/Hue + + Args: + img (PIL Image): PIL Image to be adjusted. + hue_factor (float): How much to shift the hue channel. Should be in + [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in + HSV space in positive and negative direction respectively. + 0 means no shift. Therefore, both -0.5 and 0.5 will give an image + with complementary colors while 0 gives the original image. + + Returns: + PIL Image: Hue adjusted image. + """ + if not(-0.5 <= hue_factor <= 0.5): + raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor)) + + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + input_mode = img.mode + if input_mode in {'L', '1', 'I', 'F'}: + return img + + h, s, v = img.convert('HSV').split() + + np_h = np.array(h, dtype=np.uint8) + # uint8 addition take cares of rotation across boundaries + with np.errstate(over='ignore'): + np_h += np.uint8(hue_factor * 255) + h = Image.fromarray(np_h, 'L') + + img = Image.merge('HSV', (h, s, v)).convert(input_mode) + return img diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index b81deed6d43..89440701d17 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -1,11 +1,10 @@ import torch -import torchvision.transforms.functional as F from torch import Tensor from torch.jit.annotations import Optional, List, BroadcastingList2, Tuple def _is_tensor_a_torch_image(input): - return len(input.shape) == 3 + return input.ndim >= 2 def vflip(img): @@ -119,6 +118,54 @@ def adjust_contrast(img, contrast_factor): return _blend(img, mean, contrast_factor) +def adjust_hue(img, hue_factor): + """Adjust hue of an image. + + The image hue is adjusted by converting the image to HSV and + cyclically shifting the intensities in the hue channel (H). + The image is then converted back to original image mode. + + `hue_factor` is the amount of shift in H channel and must be in the + interval `[-0.5, 0.5]`. + + See `Hue`_ for more details. + + .. _Hue: https://en.wikipedia.org/wiki/Hue + + Args: + img (Tensor): Image to be adjusted. Image type is either uint8 or float. + hue_factor (float): How much to shift the hue channel. Should be in + [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in + HSV space in positive and negative direction respectively. + 0 means no shift. Therefore, both -0.5 and 0.5 will give an image + with complementary colors while 0 gives the original image. + + Returns: + Tensor: Hue adjusted image. + """ + if not(-0.5 <= hue_factor <= 0.5): + raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor)) + + if not _is_tensor_a_torch_image(img): + raise TypeError('tensor is not a torch image.') + + orig_dtype = img.dtype + if img.dtype == torch.uint8: + img = img.to(dtype=torch.float32) / 255.0 + + img = _rgb2hsv(img) + h, s, v = img.unbind(0) + h += hue_factor + h = h % 1.0 + img = torch.stack((h, s, v)) + img_hue_adj = _hsv2rgb(img) + + if orig_dtype == torch.uint8: + img_hue_adj = (img_hue_adj * 255.0).to(dtype=orig_dtype) + + return img_hue_adj + + def adjust_saturation(img, saturation_factor): # type: (Tensor, float) -> Tensor """Adjust color saturation of an RGB image. @@ -236,3 +283,47 @@ def _blend(img1, img2, ratio): # type: (Tensor, Tensor, float) -> Tensor bound = 1 if img1.dtype in [torch.half, torch.float32, torch.float64] else 255 return (ratio * img1 + (1 - ratio) * img2).clamp(0, bound).to(img1.dtype) + + +def _rgb2hsv(img): + r, g, b = img.unbind(0) + + maxc, _ = torch.max(img, dim=0) + minc, _ = torch.min(img, dim=0) + + cr = maxc - minc + s = cr / maxc + rc = (maxc - r) / cr + gc = (maxc - g) / cr + bc = (maxc - b) / cr + + t = (maxc != minc) + s = t * s + hr = (maxc == r) * (bc - gc) + hg = ((maxc == g) & (maxc != r)) * (2.0 + rc - bc) + hb = ((maxc != g) & (maxc != r)) * (4.0 + gc - rc) + h = (hr + hg + hb) + h = t * h + h = torch.fmod((h / 6.0 + 1.0), 1.0) + return torch.stack((h, s, maxc)) + + +def _hsv2rgb(img): + h, s, v = img.unbind(0) + i = torch.floor(h * 6.0) + f = (h * 6.0) - i + i = i.to(dtype=torch.int32) + + p = torch.clamp((v * (1.0 - s)), 0.0, 1.0) + q = torch.clamp((v * (1.0 - s * f)), 0.0, 1.0) + t = torch.clamp((v * (1.0 - s * (1.0 - f))), 0.0, 1.0) + i = i % 6 + + mask = i == torch.arange(6)[:, None, None] + + a1 = torch.stack((v, q, p, p, t, v)) + a2 = torch.stack((t, v, v, q, p, p)) + a3 = torch.stack((p, p, t, v, v, q)) + a4 = torch.stack((a1, a2, a3)) + + return torch.einsum("ijk, xijk -> xjk", mask.to(dtype=img.dtype), a4) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 49fac26e395..d54aa5099f2 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -15,10 +15,10 @@ from . import functional as F -__all__ = ["Compose", "ToTensor", "PILToTensor", "ToPILImage", "Normalize", "Resize", "Scale", "CenterCrop", "Pad", - "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "RandomHorizontalFlip", - "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation", - "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale", +__all__ = ["Compose", "ToTensor", "PILToTensor", "ConvertImageDtype", "ToPILImage", "Normalize", "Resize", "Scale", + "CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", + "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", + "LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale", "RandomPerspective", "RandomErasing"] _pil_interpolation_to_str = { @@ -115,6 +115,31 @@ def __repr__(self): return self.__class__.__name__ + '()' +class ConvertImageDtype(object): + """Convert a tensor image to the given ``dtype`` and scale the values accordingly + + Args: + dtype (torch.dtype): Desired data type of the output + + .. note:: + + When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly. + If converted back and forth, this mismatch has no effect. + + Raises: + RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as + well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to + overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range + of the integer ``dtype``. + """ + + def __init__(self, dtype: torch.dtype) -> None: + self.dtype = dtype + + def __call__(self, image: torch.Tensor) -> torch.Tensor: + return F.convert_image_dtype(image, self.dtype) + + class ToPILImage(object): """Convert a tensor or an ndarray to PIL Image. @@ -500,25 +525,29 @@ def __repr__(self): return self.__class__.__name__ + '(size={0}, padding={1})'.format(self.size, self.padding) -class RandomHorizontalFlip(object): - """Horizontally flip the given PIL Image randomly with a given probability. +class RandomHorizontalFlip(torch.nn.Module): + """Horizontally flip the given image randomly with a given probability. + The image can be a PIL Image or a torch Tensor, in which case it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading + dimensions Args: p (float): probability of the image being flipped. Default value is 0.5 """ def __init__(self, p=0.5): + super().__init__() self.p = p - def __call__(self, img): + def forward(self, img): """ Args: - img (PIL Image): Image to be flipped. + img (PIL Image or Tensor): Image to be flipped. Returns: - PIL Image: Randomly flipped image. + PIL Image or Tensor: Randomly flipped image. """ - if random.random() < self.p: + if torch.rand(1) < self.p: return F.hflip(img) return img @@ -526,25 +555,29 @@ def __repr__(self): return self.__class__.__name__ + '(p={})'.format(self.p) -class RandomVerticalFlip(object): +class RandomVerticalFlip(torch.nn.Module): """Vertically flip the given PIL Image randomly with a given probability. + The image can be a PIL Image or a torch Tensor, in which case it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading + dimensions Args: p (float): probability of the image being flipped. Default value is 0.5 """ def __init__(self, p=0.5): + super().__init__() self.p = p - def __call__(self, img): + def forward(self, img): """ Args: - img (PIL Image): Image to be flipped. + img (PIL Image or Tensor): Image to be flipped. Returns: - PIL Image: Randomly flipped image. + PIL Image or Tensor: Randomly flipped image. """ - if random.random() < self.p: + if torch.rand(1) < self.p: return F.vflip(img) return img @@ -857,7 +890,7 @@ def __repr__(self): return format_string -class ColorJitter(object): +class ColorJitter(torch.nn.Module): """Randomly change the brightness, contrast and saturation of an image. Args: @@ -874,20 +907,23 @@ class ColorJitter(object): hue_factor is chosen uniformly from [-hue, hue] or the given [min, max]. Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5. """ + def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): + super().__init__() self.brightness = self._check_input(brightness, 'brightness') self.contrast = self._check_input(contrast, 'contrast') self.saturation = self._check_input(saturation, 'saturation') self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5), clip_first_on_zero=False) + @torch.jit.unused def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True): if isinstance(value, numbers.Number): if value < 0: raise ValueError("If {} is a single number, it must be non negative.".format(name)) - value = [center - value, center + value] + value = [center - float(value), center + float(value)] if clip_first_on_zero: - value[0] = max(value[0], 0) + value[0] = max(value[0], 0.0) elif isinstance(value, (tuple, list)) and len(value) == 2: if not bound[0] <= value[0] <= value[1] <= bound[1]: raise ValueError("{} values should be between {}".format(name, bound)) @@ -901,6 +937,7 @@ def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_firs return value @staticmethod + @torch.jit.unused def get_params(brightness, contrast, saturation, hue): """Get a randomized transform to be applied on image. @@ -933,17 +970,37 @@ def get_params(brightness, contrast, saturation, hue): return transform - def __call__(self, img): + def forward(self, img): """ Args: - img (PIL Image): Input image. + img (PIL Image or Tensor): Input image. Returns: - PIL Image: Color jittered image. + PIL Image or Tensor: Color jittered image. """ - transform = self.get_params(self.brightness, self.contrast, - self.saturation, self.hue) - return transform(img) + fn_idx = torch.randperm(4) + for fn_id in fn_idx: + if fn_id == 0 and self.brightness is not None: + brightness = self.brightness + brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item() + img = F.adjust_brightness(img, brightness_factor) + + if fn_id == 1 and self.contrast is not None: + contrast = self.contrast + contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item() + img = F.adjust_contrast(img, contrast_factor) + + if fn_id == 2 and self.saturation is not None: + saturation = self.saturation + saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item() + img = F.adjust_saturation(img, saturation_factor) + + if fn_id == 3 and self.hue is not None: + hue = self.hue + hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item() + img = F.adjust_hue(img, hue_factor) + + return img def __repr__(self): format_string = self.__class__.__name__ + '('