diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py
index 38a565310d0..c070c5c1d61 100644
--- a/test/test_functional_tensor.py
+++ b/test/test_functional_tensor.py
@@ -552,24 +552,25 @@ def _test_affine_translations(self, tensor, pil_img, scripted_affine):

     def _test_affine_all_ops(self, tensor, pil_img, scripted_affine):
         # 4) Test rotation + translation + scale + shear
         test_configs = [
-            (45, [5, 6], 1.0, [0.0, 0.0]),
-            (33, (5, -4), 1.0, [0.0, 0.0]),
-            (45, [-5, 4], 1.2, [0.0, 0.0]),
-            (33, (-4, -8), 2.0, [0.0, 0.0]),
-            (85, (10, -10), 0.7, [0.0, 0.0]),
-            (0, [0, 0], 1.0, [35.0, ]),
-            (-25, [0, 0], 1.2, [0.0, 15.0]),
-            (-45, [-10, 0], 0.7, [2.0, 5.0]),
-            (-45, [-10, -10], 1.2, [4.0, 5.0]),
-            (-90, [0, 0], 1.0, [0.0, 0.0]),
+            (45.5, [5, 6], 1.0, [0.0, 0.0], None),
+            (33, (5, -4), 1.0, [0.0, 0.0], [0, 0, 0]),
+            (45, [-5, 4], 1.2, [0.0, 0.0], (1, 2, 3)),
+            (33, (-4, -8), 2.0, [0.0, 0.0], [255, 255, 255]),
+            (85, (10, -10), 0.7, [0.0, 0.0], [1, ]),
+            (0, [0, 0], 1.0, [35.0, ], (2.0, )),
+            (-25, [0, 0], 1.2, [0.0, 15.0], None),
+            (-45, [-10, 0], 0.7, [2.0, 5.0], None),
+            (-45, [-10, -10], 1.2, [4.0, 5.0], None),
+            (-90, [0, 0], 1.0, [0.0, 0.0], None),
         ]
         for r in [NEAREST, ]:
-            for a, t, s, sh in test_configs:
-                out_pil_img = F.affine(pil_img, angle=a, translate=t, scale=s, shear=sh, interpolation=r)
+            for a, t, s, sh, f in test_configs:
+                f_pil = int(f[0]) if f is not None and len(f) == 1 else f
+                out_pil_img = F.affine(pil_img, angle=a, translate=t, scale=s, shear=sh, interpolation=r, fill=f_pil)
                 out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
                 for fn in [F.affine, scripted_affine]:
-                    out_tensor = fn(tensor, angle=a, translate=t, scale=s, shear=sh, interpolation=r).cpu()
+                    out_tensor = fn(tensor, angle=a, translate=t, scale=s, shear=sh, interpolation=r, fill=f).cpu()

                     if out_tensor.dtype != torch.uint8:
                         out_tensor = out_tensor.to(torch.uint8)
@@ -582,7 +583,7 @@ def _test_affine_all_ops(self, tensor, pil_img, scripted_affine):
                         ratio_diff_pixels,
                         tol,
                         msg="{}: {}\n{} vs \n{}".format(
-                            (r, a, t, s, sh), ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
+                            (r, a, t, s, sh, f), ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
                         )
                     )
@@ -643,35 +644,36 @@ def _test_rotate_all_options(self, tensor, pil_img, scripted_rotate, centers):
         for a in range(-180, 180, 17):
             for e in [True, False]:
                 for c in centers:
-
-                    out_pil_img = F.rotate(pil_img, angle=a, interpolation=r, expand=e, center=c)
-                    out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
-                    for fn in [F.rotate, scripted_rotate]:
-                        out_tensor = fn(tensor, angle=a, interpolation=r, expand=e, center=c).cpu()
-
-                        if out_tensor.dtype != torch.uint8:
-                            out_tensor = out_tensor.to(torch.uint8)
-
-                        self.assertEqual(
-                            out_tensor.shape,
-                            out_pil_tensor.shape,
-                            msg="{}: {} vs {}".format(
-                                (img_size, r, dt, a, e, c), out_tensor.shape, out_pil_tensor.shape
-                            )
-                        )
-                        num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
-                        ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
-                        # Tolerance : less than 3% of different pixels
-                        self.assertLess(
-                            ratio_diff_pixels,
-                            0.03,
-                            msg="{}: {}\n{} vs \n{}".format(
-                                (img_size, r, dt, a, e, c),
+                    for f in [None, [0, 0, 0], (1, 2, 3), [255, 255, 255], [1, ], (2.0, )]:
+                        f_pil = int(f[0]) if f is not None and len(f) == 1 else f
+                        out_pil_img = F.rotate(pil_img, angle=a, interpolation=r, expand=e, center=c, fill=f_pil)
+                        out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
+                        for fn in [F.rotate, scripted_rotate]:
+                            out_tensor = fn(tensor, angle=a, interpolation=r, expand=e, center=c, fill=f).cpu()
+
+                            if out_tensor.dtype != torch.uint8:
+                                out_tensor = out_tensor.to(torch.uint8)
+
+                            self.assertEqual(
+                                out_tensor.shape,
+                                out_pil_tensor.shape,
+                                msg="{}: {} vs {}".format(
+                                    (img_size, r, dt, a, e, c), out_tensor.shape, out_pil_tensor.shape
+                                ))
+
+                            num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
+                            ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
+                            # Tolerance : less than 3% of different pixels
+                            self.assertLess(
                                 ratio_diff_pixels,
-                                out_tensor[0, :7, :7],
-                                out_pil_tensor[0, :7, :7]
+                                0.03,
+                                msg="{}: {}\n{} vs \n{}".format(
+                                    (img_size, r, dt, a, e, c, f),
+                                    ratio_diff_pixels,
+                                    out_tensor[0, :7, :7],
+                                    out_pil_tensor[0, :7, :7]
+                                )
                             )
-                        )

     def test_rotate(self):
         # Tests on square image
@@ -721,30 +723,33 @@ def test_rotate(self):

     def _test_perspective(self, tensor, pil_img, scripted_transform, test_configs):
         dt = tensor.dtype

-        for r in [NEAREST, ]:
-            for spoints, epoints in test_configs:
-                out_pil_img = F.perspective(pil_img, startpoints=spoints, endpoints=epoints, interpolation=r)
-                out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
+        for f in [None, [0, 0, 0], [1, 2, 3], [255, 255, 255], [1, ], (2.0, )]:
+            for r in [NEAREST, ]:
+                for spoints, epoints in test_configs:
+                    f_pil = int(f[0]) if f is not None and len(f) == 1 else f
+                    out_pil_img = F.perspective(pil_img, startpoints=spoints, endpoints=epoints, interpolation=r,
+                                                fill=f_pil)
+                    out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))

-                for fn in [F.perspective, scripted_transform]:
-                    out_tensor = fn(tensor, startpoints=spoints, endpoints=epoints, interpolation=r).cpu()
+                    for fn in [F.perspective, scripted_transform]:
+                        out_tensor = fn(tensor, startpoints=spoints, endpoints=epoints, interpolation=r, fill=f).cpu()

-                    if out_tensor.dtype != torch.uint8:
-                        out_tensor = out_tensor.to(torch.uint8)
+                        if out_tensor.dtype != torch.uint8:
+                            out_tensor = out_tensor.to(torch.uint8)

-                    num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
-                    ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
-                    # Tolerance : less than 5% of different pixels
-                    self.assertLess(
-                        ratio_diff_pixels,
-                        0.05,
-                        msg="{}: {}\n{} vs \n{}".format(
-                            (r, dt, spoints, epoints),
+                        num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
+                        ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
+                        # Tolerance : less than 5% of different pixels
+                        self.assertLess(
                             ratio_diff_pixels,
-                            out_tensor[0, :7, :7],
-                            out_pil_tensor[0, :7, :7]
+                            0.05,
+                            msg="{}: {}\n{} vs \n{}".format(
+                                (f, r, dt, spoints, epoints),
+                                ratio_diff_pixels,
+                                out_tensor[0, :7, :7],
+                                out_pil_tensor[0, :7, :7]
+                            )
                         )
-                    )

     def test_perspective(self):
diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py
index c5c4a7f09e0..aff492b41d6 100644
--- a/test/test_transforms_tensor.py
+++ b/test/test_transforms_tensor.py
@@ -349,14 +349,15 @@ def test_random_affine(self):
         for translate in [(0.1, 0.2), [0.2, 0.1]]:
             for degrees in [45, 35.0, (-45, 45), [-90.0, 90.0]]:
                 for interpolation in [NEAREST, BILINEAR]:
-                    transform = T.RandomAffine(
-                        degrees=degrees, translate=translate,
-                        scale=scale, shear=shear, interpolation=interpolation
-                    )
-                    s_transform = torch.jit.script(transform)
+                    for fill in [85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1]:
+                        transform = T.RandomAffine(
+                            degrees=degrees, translate=translate,
+                            scale=scale, shear=shear, interpolation=interpolation, fill=fill
+                        )
+                        s_transform = torch.jit.script(transform)

-                    self._test_transform_vs_scripted(transform, s_transform, tensor)
-                    self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
+                        self._test_transform_vs_scripted(transform, s_transform, tensor)
+                        self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)

         with get_tmp_dir() as tmp_dir:
             s_transform.save(os.path.join(tmp_dir, "t_random_affine.pt"))
@@ -369,13 +370,14 @@ def test_random_rotate(self):
         for expand in [True, False]:
             for degrees in [45, 35.0, (-45, 45), [-90.0, 90.0]]:
                 for interpolation in [NEAREST, BILINEAR]:
-                    transform = T.RandomRotation(
-                        degrees=degrees, interpolation=interpolation, expand=expand, center=center
-                    )
-                    s_transform = torch.jit.script(transform)
+                    for fill in [85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1]:
+                        transform = T.RandomRotation(
+                            degrees=degrees, interpolation=interpolation, expand=expand, center=center, fill=fill
+                        )
+                        s_transform = torch.jit.script(transform)

-                    self._test_transform_vs_scripted(transform, s_transform, tensor)
-                    self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
+                        self._test_transform_vs_scripted(transform, s_transform, tensor)
+                        self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)

         with get_tmp_dir() as tmp_dir:
             s_transform.save(os.path.join(tmp_dir, "t_random_rotate.pt"))
@@ -386,14 +388,16 @@ def test_random_perspective(self):

         for distortion_scale in np.linspace(0.1, 1.0, num=20):
             for interpolation in [NEAREST, BILINEAR]:
-                transform = T.RandomPerspective(
-                    distortion_scale=distortion_scale,
-                    interpolation=interpolation
-                )
-                s_transform = torch.jit.script(transform)
+                for fill in [85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1]:
+                    transform = T.RandomPerspective(
+                        distortion_scale=distortion_scale,
+                        interpolation=interpolation,
+                        fill=fill
+                    )
+                    s_transform = torch.jit.script(transform)

-                self._test_transform_vs_scripted(transform, s_transform, tensor)
-                self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
+                    self._test_transform_vs_scripted(transform, s_transform, tensor)
+                    self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)

         with get_tmp_dir() as tmp_dir:
             s_transform.save(os.path.join(tmp_dir, "t_perspective.pt"))
diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py
index 3f6548357d5..72baf021f9d 100644
--- a/torchvision/transforms/functional.py
+++ b/torchvision/transforms/functional.py
@@ -557,7 +557,7 @@ def perspective(
     startpoints: List[List[int]],
     endpoints: List[List[int]],
     interpolation: InterpolationMode = InterpolationMode.BILINEAR,
-    fill: Optional[int] = None
+    fill: Optional[List[float]] = None
 ) -> Tensor:
     """Perform perspective transform of the given image.
     The image can be a PIL Image or a Tensor, in which case it is expected
@@ -573,10 +573,12 @@ def perspective(
             :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
             If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
             For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
-        fill (n-tuple or int or float): Pixel fill value for the area outside the rotated
+        fill (sequence or int or float, optional): Pixel fill value for the area outside the transformed
             image. If int or float, the value is used for all bands respectively.
-            This option is only available for ``pillow>=5.0.0``. This option is not supported for Tensor
-            input. Fill value for the area outside the transform in the output image is always 0.
+            This option is supported for PIL Image and Tensor inputs.
+            In torchscript mode, a single int/float value is not supported, please use a tuple
+            or list of length 1: ``[value, ]``.
+            If input is PIL Image, this option is only available for ``Pillow>=5.0.0``.

     Returns:
         PIL Image or Tensor: transformed Image.
@@ -871,7 +873,7 @@ def _get_inverse_affine_matrix(
 def rotate(
         img: Tensor, angle: float, interpolation: InterpolationMode = InterpolationMode.NEAREST,
         expand: bool = False, center: Optional[List[int]] = None,
-        fill: Optional[int] = None, resample: Optional[int] = None
+        fill: Optional[List[float]] = None, resample: Optional[int] = None
 ) -> Tensor:
     """Rotate the image by angle.
     The image can be a PIL Image or a Tensor, in which case it is expected
@@ -890,13 +892,12 @@ def rotate(
             Note that the expand flag assumes rotation around the center and no translation.
         center (list or tuple, optional): Optional center of rotation. Origin is the upper left corner.
             Default is the center of the image.
-        fill (n-tuple or int or float): Pixel fill value for area outside the rotated
+        fill (sequence or int or float, optional): Pixel fill value for the area outside the transformed
             image. If int or float, the value is used for all bands respectively.
-            Defaults to 0 for all bands. This option is only available for ``pillow>=5.2.0``.
-            This option is not supported for Tensor input. Fill value for the area outside the transform in the output
-            image is always 0.
-        resample (int, optional): deprecated argument and will be removed since v0.10.0.
-            Please use `arg`:interpolation: instead.
+            This option is supported for PIL Image and Tensor inputs.
+            In torchscript mode, a single int/float value is not supported, please use a tuple
+            or list of length 1: ``[value, ]``.
+            If input is PIL Image, this option is only available for ``Pillow>=5.2.0``.

     Returns:
         PIL Image or Tensor: Rotated image.
@@ -945,8 +946,8 @@ def rotate(

 def affine(
         img: Tensor, angle: float, translate: List[int], scale: float, shear: List[float],
-        interpolation: InterpolationMode = InterpolationMode.NEAREST, fill: Optional[int] = None,
-        resample: Optional[int] = None, fillcolor: Optional[int] = None
+        interpolation: InterpolationMode = InterpolationMode.NEAREST, fill: Optional[List[float]] = None,
+        resample: Optional[int] = None, fillcolor: Optional[List[float]] = None
 ) -> Tensor:
     """Apply affine transformation on the image keeping image center invariant.
     The image can be a PIL Image or a Tensor, in which case it is expected
@@ -964,10 +965,13 @@ def affine(
             :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
             If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
             For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
-        fill (int): Optional fill color for the area outside the transform in the output image (Pillow>=5.0.0).
-            This option is not supported for Tensor input. Fill value for the area outside the transform in the output
-            image is always 0.
-        fillcolor (tuple or int, optional): deprecated argument and will be removed since v0.10.0.
+        fill (sequence or int or float, optional): Pixel fill value for the area outside the transformed
+            image. If int or float, the value is used for all bands respectively.
+            This option is supported for PIL Image and Tensor inputs.
+            In torchscript mode, a single int/float value is not supported, please use a tuple
+            or list of length 1: ``[value, ]``.
+            If input is PIL Image, this option is only available for ``Pillow>=5.0.0``.
+        fillcolor (sequence or int or float, optional): deprecated argument and will be removed since v0.10.0.
             Please use `arg`:fill: instead.
         resample (int, optional): deprecated argument and will be removed since v0.10.0.
             Please use `arg`:interpolation: instead.
diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py
index 7e3989f0288..51d83f0fd63 100644
--- a/torchvision/transforms/functional_pil.py
+++ b/torchvision/transforms/functional_pil.py
@@ -465,10 +465,13 @@ def _parse_fill(fill, img, min_pil_version, name="fillcolor"):
         fill = 0
     if isinstance(fill, (int, float)) and num_bands > 1:
         fill = tuple([fill] * num_bands)
-    if not isinstance(fill, (int, float)) and len(fill) != num_bands:
-        msg = ("The number of elements in 'fill' does not match the number of "
-               "bands of the image ({} != {})")
-        raise ValueError(msg.format(len(fill), num_bands))
+    if isinstance(fill, (list, tuple)):
+        if len(fill) != num_bands:
+            msg = ("The number of elements in 'fill' does not match the number of "
+                   "bands of the image ({} != {})")
+            raise ValueError(msg.format(len(fill), num_bands))
+
+        fill = tuple(fill)

     return {name: fill}

diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py
index 4f3e72a62ce..d21e2d6220e 100644
--- a/torchvision/transforms/functional_tensor.py
+++ b/torchvision/transforms/functional_tensor.py
@@ -835,7 +835,7 @@ def _assert_grid_transform_inputs(
     img: Tensor,
     matrix: Optional[List[float]],
     interpolation: str,
-    fill: Optional[int],
+    fill: Optional[List[float]],
     supported_interpolation_modes: List[str],
     coeffs: Optional[List[float]] = None,
 ):
@@ -851,8 +851,15 @@ def _assert_grid_transform_inputs(
     if coeffs is not None and len(coeffs) != 8:
         raise ValueError("Argument coeffs should have 8 float values")

-    if fill is not None and not (isinstance(fill, (int, float)) and fill == 0):
-        warnings.warn("Argument fill is not supported for Tensor input. Fill value is zero")
Fill value is zero") + if fill is not None and not isinstance(fill, (int, float, tuple, list)): + warnings.warn("Argument fill should be either int, float, tuple or list") + + # Check fill + num_channels = _get_image_num_channels(img) + if isinstance(fill, (tuple, list)) and (len(fill) > 1 and len(fill) != num_channels): + msg = ("The number of elements in 'fill' cannot broadcast to match the number of " + "channels of the image ({} != {})") + raise ValueError(msg.format(len(fill), num_channels)) if interpolation not in supported_interpolation_modes: raise ValueError("Interpolation mode '{}' is unsupported with Tensor input".format(interpolation)) @@ -887,15 +894,34 @@ def _cast_squeeze_out(img: Tensor, need_cast: bool, need_squeeze: bool, out_dtyp return img -def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str) -> Tensor: +def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str, fill: Optional[List[float]]) -> Tensor: img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [grid.dtype, ]) if img.shape[0] > 1: # Apply same grid to a batch of images grid = grid.expand(img.shape[0], grid.shape[1], grid.shape[2], grid.shape[3]) + + # Append a dummy mask for customized fill colors, should be faster than grid_sample() twice + if fill is not None: + dummy = torch.ones((img.shape[0], 1, img.shape[2], img.shape[3]), dtype=img.dtype, device=img.device) + img = torch.cat((img, dummy), dim=1) + img = grid_sample(img, grid, mode=mode, padding_mode="zeros", align_corners=False) + # Fill with required color + if fill is not None: + mask = img[:, -1:, :, :] # N * 1 * H * W + img = img[:, :-1, :, :] # N * C * H * W + mask = mask.expand_as(img) + len_fill = len(fill) if isinstance(fill, (tuple, list)) else 1 + fill_img = torch.tensor(fill, dtype=img.dtype, device=img.device).view(1, len_fill, 1, 1).expand_as(img) + if mode == 'nearest': + mask = mask < 0.5 + img[mask] = fill_img[mask] + else: # 'bilinear' + img = img * mask + (1.0 - mask) * fill_img + img = _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype) return img @@ -923,7 +949,7 @@ def _gen_affine_grid( def affine( - img: Tensor, matrix: List[float], interpolation: str = "nearest", fill: Optional[int] = None + img: Tensor, matrix: List[float], interpolation: str = "nearest", fill: Optional[List[float]] = None ) -> Tensor: """PRIVATE METHOD. Apply affine transformation on the Tensor image keeping image center invariant. @@ -936,8 +962,8 @@ def affine( img (Tensor): image to be rotated. matrix (list of floats): list of 6 float values representing inverse matrix for affine transformation. interpolation (str): An optional resampling filter. Default is "nearest". Other supported values: "bilinear". - fill (int, optional): this option is not supported for Tensor input. Fill value for the area outside the - transform in the output image is always 0. + fill (sequence or int or float, optional): Optional fill value, default None. + If None, fill with 0. Returns: Tensor: Transformed image. 
@@ -949,7 +975,7 @@ def affine(
     shape = img.shape
     # grid will be generated on the same device as theta and img
     grid = _gen_affine_grid(theta, w=shape[-1], h=shape[-2], ow=shape[-1], oh=shape[-2])
-    return _apply_grid_transform(img, grid, interpolation)
+    return _apply_grid_transform(img, grid, interpolation, fill=fill)


 def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]:
@@ -979,7 +1005,7 @@ def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]
 def rotate(
     img: Tensor, matrix: List[float], interpolation: str = "nearest",
-    expand: bool = False, fill: Optional[int] = None
+    expand: bool = False, fill: Optional[List[float]] = None
 ) -> Tensor:
     """PRIVATE METHOD. Rotate the Tensor image by angle.

@@ -997,8 +1023,8 @@ def rotate(
             If true, expands the output image to make it large enough to hold the entire rotated image.
             If false or omitted, make the output image the same size as the input image.
             Note that the expand flag assumes rotation around the center and no translation.
-        fill (n-tuple or int or float): this option is not supported for Tensor input.
-            Fill value for the area outside the transform in the output image is always 0.
+        fill (sequence or int or float, optional): Optional fill value, default None.
+            If None, fill with 0.

     Returns:
         Tensor: Rotated image.
@@ -1013,7 +1039,8 @@ def rotate(
     theta = torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3)
     # grid will be generated on the same device as theta and img
     grid = _gen_affine_grid(theta, w=w, h=h, ow=ow, oh=oh)
-    return _apply_grid_transform(img, grid, interpolation)
+
+    return _apply_grid_transform(img, grid, interpolation, fill=fill)


 def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, device: torch.device):
@@ -1050,7 +1077,7 @@ def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype,


 def perspective(
-    img: Tensor, perspective_coeffs: List[float], interpolation: str = "bilinear", fill: Optional[int] = None
+    img: Tensor, perspective_coeffs: List[float], interpolation: str = "bilinear", fill: Optional[List[float]] = None
 ) -> Tensor:
     """PRIVATE METHOD. Perform perspective transform of the given Tensor image.

@@ -1063,8 +1090,8 @@ def perspective(
     Args:
         img (Tensor): Image to be transformed.
         perspective_coeffs (list of float): perspective transformation coefficients.
         interpolation (str): Interpolation type. Default, "bilinear".
-        fill (n-tuple or int or float): this option is not supported for Tensor input. Fill value for the area
-            outside the transform in the output image is always 0.
+        fill (sequence or int or float, optional): Optional fill value, default None.
+            If None, fill with 0.

     Returns:
         Tensor: transformed image.
@@ -1084,7 +1111,7 @@ def perspective(
     ow, oh = img.shape[-1], img.shape[-2]
     dtype = img.dtype if torch.is_floating_point(img) else torch.float32
     grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh, dtype=dtype, device=img.device)
-    return _apply_grid_transform(img, grid, interpolation)
+    return _apply_grid_transform(img, grid, interpolation, fill=fill)


 def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> Tensor:
diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py
index 25886d59a60..3b159fd3f22 100644
--- a/torchvision/transforms/transforms.py
+++ b/torchvision/transforms/transforms.py
@@ -667,10 +667,10 @@ class RandomPerspective(torch.nn.Module):
             :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
             If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
             For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
-        fill (n-tuple or int or float): Pixel fill value for area outside the rotated
-            image. If int or float, the value is used for all bands respectively. Default is 0.
-            This option is only available for ``pillow>=5.0.0``. This option is not supported for Tensor
-            input. Fill value for the area outside the transform in the output image is always 0.
+        fill (sequence or int or float, optional): Pixel fill value for the area outside the transformed
+            image. If int or float, the value is used for all bands respectively.
+            This option is supported for PIL Image and Tensor inputs.
+            If input is PIL Image, this option is only available for ``Pillow>=5.0.0``.
     """

     def __init__(self, distortion_scale=0.5, p=0.5, interpolation=InterpolationMode.BILINEAR, fill=0):
@@ -697,10 +697,18 @@ def forward(self, img):
         Returns:
             PIL Image or Tensor: Randomly transformed image.
         """
+
+        fill = self.fill
+        if isinstance(img, Tensor):
+            if isinstance(fill, (int, float)):
+                fill = [float(fill)] * F._get_image_num_channels(img)
+            else:
+                fill = [float(f) for f in fill]
+
         if torch.rand(1) < self.p:
             width, height = F._get_image_size(img)
             startpoints, endpoints = self.get_params(width, height, self.distortion_scale)
-            return F.perspective(img, startpoints, endpoints, self.interpolation, self.fill)
+            return F.perspective(img, startpoints, endpoints, self.interpolation, fill)
         return img

     @staticmethod
@@ -1157,11 +1165,10 @@ class RandomRotation(torch.nn.Module):
             Note that the expand flag assumes rotation around the center and no translation.
         center (list or tuple, optional): Optional center of rotation, (x, y). Origin is the upper left corner.
             Default is the center of the image.
-        fill (n-tuple or int or float): Pixel fill value for area outside the rotated
+        fill (sequence or int or float, optional): Pixel fill value for the area outside the rotated
             image. If int or float, the value is used for all bands respectively.
-            Defaults to 0 for all bands. This option is only available for Pillow>=5.2.0.
-            This option is not supported for Tensor input. Fill value for the area outside the transform in the output
-            image is always 0.
+            This option is supported for PIL Image and Tensor inputs.
+            If input is PIL Image, this option is only available for ``Pillow>=5.2.0``.
         resample (int, optional): deprecated argument and will be removed since v0.10.0.
             Please use `arg`:interpolation: instead.

@@ -1216,8 +1223,15 @@ def forward(self, img):
         Returns:
             PIL Image or Tensor: Rotated image.
         """
+        fill = self.fill
+        if isinstance(img, Tensor):
+            if isinstance(fill, (int, float)):
+                fill = [float(fill)] * F._get_image_num_channels(img)
+            else:
+                fill = [float(f) for f in fill]
         angle = self.get_params(self.degrees)
-        return F.rotate(img, angle, self.interpolation, self.expand, self.center, self.fill)
+
+        return F.rotate(img, angle, self.interpolation, self.expand, self.center, fill)

     def __repr__(self):
         interpolate_str = self.interpolation.value
@@ -1257,10 +1271,11 @@ class RandomAffine(torch.nn.Module):
             :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
             If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
             For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
-        fill (tuple or int): Optional fill color (Tuple for RGB Image and int for grayscale) for the area
-            outside the transform in the output image (Pillow>=5.0.0). This option is not supported for Tensor
-            input. Fill value for the area outside the transform in the output image is always 0.
-        fillcolor (tuple or int, optional): deprecated argument and will be removed since v0.10.0.
+        fill (sequence or int or float, optional): Pixel fill value for the area outside the transformed
+            image. If int or float, the value is used for all bands respectively.
+            This option is supported for PIL Image and Tensor inputs.
+            If input is PIL Image, this option is only available for ``Pillow>=5.0.0``.
+        fillcolor (sequence or int or float, optional): deprecated argument and will be removed since v0.10.0.
             Please use `arg`:fill: instead.
         resample (int, optional): deprecated argument and will be removed since v0.10.0.
             Please use `arg`:interpolation: instead.
@@ -1363,11 +1378,18 @@ def forward(self, img):
         Returns:
             PIL Image or Tensor: Affine transformed image.
         """
+        fill = self.fill
+        if isinstance(img, Tensor):
+            if isinstance(fill, (int, float)):
+                fill = [float(fill)] * F._get_image_num_channels(img)
+            else:
+                fill = [float(f) for f in fill]

         img_size = F._get_image_size(img)

         ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img_size)
-        return F.affine(img, *ret, interpolation=self.interpolation, fill=self.fill)
+
+        return F.affine(img, *ret, interpolation=self.interpolation, fill=fill)

     def __repr__(self):
         s = '{name}(degrees={degrees}'
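
Reviewer notes (illustrative, not part of the diff):

1) A minimal sketch of how the unified fill argument behaves after this change. The tensor shape and values are arbitrary, and the scripted call uses floats to satisfy the Optional[List[float]] annotation:

    import torch
    import torchvision.transforms.functional as F
    from torchvision.transforms import InterpolationMode

    # Arbitrary 3-channel uint8 image (CHW layout).
    img = torch.randint(0, 256, (3, 64, 64), dtype=torch.uint8)

    # Eager mode: a scalar fill broadcasts to all channels; a sequence must
    # have a single element or match the channel count.
    out = F.rotate(img, angle=30.0, interpolation=InterpolationMode.NEAREST, fill=[255, 255, 255])

    # Torchscript mode: a bare int/float is rejected by the List[float]
    # annotation, so single values must be wrapped in a list: [value, ].
    scripted_rotate = torch.jit.script(F.rotate)
    out_scripted = scripted_rotate(img, angle=30.0, interpolation=InterpolationMode.NEAREST, fill=[128.0, ])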
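2) The key idea in the _apply_grid_transform change: instead of calling grid_sample() twice (once for the image, once for a validity mask), the diff appends a channel of ones so a single warp produces both, then selects (nearest) or blends (bilinear) the fill color wherever the warped mask falls below one. A standalone sketch of the same technique; the helper name is hypothetical and a float NCHW input is assumed:

    import torch
    from torch.nn.functional import grid_sample

    def warp_with_fill(img: torch.Tensor, grid: torch.Tensor, fill: list, mode: str = "bilinear") -> torch.Tensor:
        # img: float tensor of shape (N, C, H, W); len(fill) must be C or 1.
        n, c, h, w = img.shape
        # The ones channel is warped together with the image; wherever the grid
        # samples outside the input, padding_mode="zeros" pulls it toward 0.
        ones = torch.ones((n, 1, h, w), dtype=img.dtype, device=img.device)
        out = grid_sample(torch.cat((img, ones), dim=1), grid, mode=mode,
                          padding_mode="zeros", align_corners=False)
        out, mask = out[:, :-1], out[:, -1:].expand(-1, c, -1, -1)
        fill_img = torch.tensor(fill, dtype=img.dtype, device=img.device).view(1, len(fill), 1, 1).expand_as(out)
        if mode == "nearest":
            # The mask is exactly 0 or 1, so a hard selection suffices.
            return torch.where(mask < 0.5, fill_img, out)
        # Bilinear sampling yields fractional mask values at borders, so blend.
        return out * mask + (1.0 - mask) * fill_img

Besides saving a second grid_sample call, this guarantees the mask is interpolated exactly like the image, so the fill blends seamlessly at the border.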
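3) The scalar-to-per-channel fill conversion is now repeated verbatim in RandomPerspective.forward, RandomRotation.forward and RandomAffine.forward; scripting constraints are presumably why it is inlined rather than shared. For reference, the eager-mode equivalent of that pattern as a hypothetical helper (the None passthrough is added here for completeness):

    from typing import List, Optional, Union

    def _normalize_fill(fill: Union[int, float, List[float], None], num_channels: int) -> Optional[List[float]]:
        # A scalar broadcasts to one float per channel so the tensor kernels
        # always receive List[float]; sequences are cast element-wise.
        if fill is None:
            return None
        if isinstance(fill, (int, float)):
            return [float(fill)] * num_channels
        return [float(f) for f in fill]

Doing the conversion per call, against the actual input, rather than once in __init__ keeps a single transform instance usable with images of different channel counts.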