diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index d23930e7313..48604ec287b 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -23,6 +23,8 @@ def _create_data(self, height=3, width=3, channels=3): def compareTensorToPIL(self, tensor, pil_image, msg=None): pil_tensor = torch.as_tensor(np.array(pil_image).transpose((2, 0, 1))) + if msg is None: + msg = "tensor:\n{} \ndid not equal PIL tensor:\n{}".format(tensor, pil_tensor) self.assertTrue(tensor.equal(pil_tensor), msg) def approxEqualTensorToPIL(self, tensor, pil_image, tol=1e-5, msg=None): @@ -293,6 +295,33 @@ def test_pad(self): with self.assertRaises(ValueError, msg="Padding can not be negative for symmetric padding_mode"): F_t.pad(tensor, (-2, -3), padding_mode="symmetric") + def test_adjust_gamma(self): + script_fn = torch.jit.script(F_t.adjust_gamma) + tensor, pil_img = self._create_data(26, 36) + + for dt in [torch.float64, torch.float32, None]: + + if dt is not None: + tensor = F.convert_image_dtype(tensor, dt) + + gammas = [0.8, 1.0, 1.2] + gains = [0.7, 1.0, 1.3] + for gamma, gain in zip(gammas, gains): + + adjusted_tensor = F_t.adjust_gamma(tensor, gamma, gain) + adjusted_pil = F_pil.adjust_gamma(pil_img, gamma, gain) + scripted_result = script_fn(tensor, gamma, gain) + self.assertEqual(adjusted_tensor.dtype, scripted_result.dtype) + self.assertEqual(adjusted_tensor.size()[1:], adjusted_pil.size[::-1]) + + rbg_tensor = adjusted_tensor + if adjusted_tensor.dtype != torch.uint8: + rbg_tensor = F.convert_image_dtype(adjusted_tensor, torch.uint8) + + self.compareTensorToPIL(rbg_tensor, adjusted_pil) + + self.assertTrue(adjusted_tensor.equal(scripted_result)) + def test_resize(self): script_fn = torch.jit.script(F_t.resize) tensor, pil_img = self._create_data(26, 36) diff --git a/test/test_transforms.py b/test/test_transforms.py index b0eb844fcf8..19caefcd788 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1179,14 +1179,14 @@ def test_adjust_gamma(self): # test 1 y_pil = F.adjust_gamma(x_pil, 0.5) y_np = np.array(y_pil) - y_ans = [0, 35, 57, 117, 185, 240, 97, 45, 244, 151, 255, 15] + y_ans = [0, 35, 57, 117, 186, 241, 97, 45, 245, 152, 255, 16] y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) self.assertTrue(np.allclose(y_np, y_ans)) # test 2 y_pil = F.adjust_gamma(x_pil, 2) y_np = np.array(y_pil) - y_ans = [0, 0, 0, 11, 71, 200, 5, 0, 214, 31, 255, 0] + y_ans = [0, 0, 0, 11, 71, 201, 5, 0, 215, 31, 255, 0] y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) self.assertTrue(np.allclose(y_np, y_ans)) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 801df42a187..659ea88b84a 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -161,8 +161,14 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) - msg = f"The cast from {image.dtype} to {dtype} cannot be performed safely." raise RuntimeError(msg) + # https://github.com/pytorch/vision/pull/2078#issuecomment-612045321 + # For data in the range 0-1, (float * 255).to(uint) is only 255 + # when float is exactly 1.0. + # `max + 1 - epsilon` provides more evenly distributed mapping of + # ranges of floats to ints. eps = 1e-3 - return image.mul(torch.iinfo(dtype).max + 1 - eps).to(dtype) + result = image.mul(torch.iinfo(dtype).max + 1 - eps) + return result.to(dtype) else: # int to float if dtype.is_floating_point: @@ -760,7 +766,7 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor: raise TypeError('img should be PIL Image. Got {}'.format(type(img))) -def adjust_gamma(img, gamma, gain=1): +def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor: r"""Perform gamma correction on an image. Also known as Power Law Transform. Intensities in RGB mode are adjusted @@ -774,26 +780,18 @@ def adjust_gamma(img, gamma, gain=1): .. _Gamma Correction: https://en.wikipedia.org/wiki/Gamma_correction Args: - img (PIL Image): PIL Image to be adjusted. + img (PIL Image or Tensor): PIL Image to be adjusted. gamma (float): Non negative real number, same as :math:`\gamma` in the equation. gamma larger than 1 make the shadows darker, while gamma smaller than 1 make dark regions lighter. gain (float): The constant multiplier. + Returns: + PIL Image or Tensor: Gamma correction adjusted image. """ - if not F_pil._is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) - - if gamma < 0: - raise ValueError('Gamma should be a non-negative real number') - - input_mode = img.mode - img = img.convert('RGB') - - gamma_map = [255 * gain * pow(ele / 255., gamma) for ele in range(256)] * 3 - img = img.point(gamma_map) # use PIL's point-function to accelerate this part + if not isinstance(img, torch.Tensor): + return F_pil.adjust_gamma(img, gamma, gain) - img = img.convert(input_mode) - return img + return F_t.adjust_gamma(img, gamma, gain) def rotate(img, angle, resample=False, expand=False, center=None, fill=None): diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py index 994988ce1f6..753497f9b2d 100644 --- a/torchvision/transforms/functional_pil.py +++ b/torchvision/transforms/functional_pil.py @@ -164,6 +164,42 @@ def adjust_hue(img, hue_factor): return img +@torch.jit.unused +def adjust_gamma(img, gamma, gain=1): + r"""Perform gamma correction on an image. + + Also known as Power Law Transform. Intensities in RGB mode are adjusted + based on the following equation: + + .. math:: + I_{\text{out}} = 255 \times \text{gain} \times \left(\frac{I_{\text{in}}}{255}\right)^{\gamma} + + See `Gamma Correction`_ for more details. + + .. _Gamma Correction: https://en.wikipedia.org/wiki/Gamma_correction + + Args: + img (PIL Image): PIL Image to be adjusted. + gamma (float): Non negative real number, same as :math:`\gamma` in the equation. + gamma larger than 1 make the shadows darker, + while gamma smaller than 1 make dark regions lighter. + gain (float): The constant multiplier. + """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + if gamma < 0: + raise ValueError('Gamma should be a non-negative real number') + + input_mode = img.mode + img = img.convert('RGB') + gamma_map = [(255 + 1 - 1e-3) * gain * pow(ele / 255., gamma) for ele in range(256)] * 3 + img = img.point(gamma_map) # use PIL's point-function to accelerate this part + + img = img.convert(input_mode) + return img + + @torch.jit.unused def pad(img, padding, fill=0, padding_mode="constant"): r"""Pad the given PIL.Image on all sides with the given "pad" value. diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 59cf6bc2764..b446fe37567 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -194,6 +194,47 @@ def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor: return _blend(img, rgb_to_grayscale(img), saturation_factor) +def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor: + r"""Adjust gamma of an RGB image. + + Also known as Power Law Transform. Intensities in RGB mode are adjusted + based on the following equation: + + .. math:: + `I_{\text{out}} = 255 \times \text{gain} \times \left(\frac{I_{\text{in}}}{255}\right)^{\gamma}` + + See `Gamma Correction`_ for more details. + + .. _Gamma Correction: https://en.wikipedia.org/wiki/Gamma_correction + + Args: + img (Tensor): Tensor of RBG values to be adjusted. + gamma (float): Non negative real number, same as :math:`\gamma` in the equation. + gamma larger than 1 make the shadows darker, + while gamma smaller than 1 make dark regions lighter. + gain (float): The constant multiplier. + """ + + if not isinstance(img, torch.Tensor): + raise TypeError('img should be a Tensor. Got {}'.format(type(img))) + + if gamma < 0: + raise ValueError('Gamma should be a non-negative real number') + + result = img + dtype = img.dtype + if not torch.is_floating_point(img): + result = result / 255.0 + + result = (gain * result ** gamma).clamp(0, 1) + + if result.dtype != dtype: + eps = 1e-3 + result = (255 + 1.0 - eps) * result + result = result.to(dtype) + return result + + def center_crop(img: Tensor, output_size: BroadcastingList2[int]) -> Tensor: """Crop the Image Tensor and resize it to desired size.