Add torchscriptable adjust_gamma transform #2459

Merged: 3 commits merged on Jul 17, 2020
Changes from 2 commits
29 changes: 29 additions & 0 deletions test/test_functional_tensor.py
@@ -23,6 +23,8 @@ def _create_data(self, height=3, width=3, channels=3):

def compareTensorToPIL(self, tensor, pil_image, msg=None):
pil_tensor = torch.as_tensor(np.array(pil_image).transpose((2, 0, 1)))
if msg is None:
msg = "tensor:\n{} \ndid not equal PIL tensor:\n{}".format(tensor, pil_tensor)
self.assertTrue(tensor.equal(pil_tensor), msg)

def approxEqualTensorToPIL(self, tensor, pil_image, tol=1e-5, msg=None):
@@ -293,6 +295,33 @@ def test_pad(self):
with self.assertRaises(ValueError, msg="Padding can not be negative for symmetric padding_mode"):
F_t.pad(tensor, (-2, -3), padding_mode="symmetric")

def test_adjust_gamma(self):
script_fn = torch.jit.script(F_t.adjust_gamma)
tensor, pil_img = self._create_data(26, 36)

for dt in [torch.float64, torch.float32, None]:

if dt is not None:
tensor = F.convert_image_dtype(tensor, dt)

gammas = [0.8, 1.0, 1.2]
gains = [0.7, 1.0, 1.3]
for gamma, gain in zip(gammas, gains):

adjusted_tensor = F_t.adjust_gamma(tensor, gamma, gain)
adjusted_pil = F_pil.adjust_gamma(pil_img, gamma, gain)
scripted_result = script_fn(tensor, gamma, gain)
self.assertEqual(adjusted_tensor.dtype, scripted_result.dtype)
self.assertEqual(adjusted_tensor.size()[1:], adjusted_pil.size[::-1])

rgb_tensor = adjusted_tensor
if adjusted_tensor.dtype != torch.uint8:
rgb_tensor = F.convert_image_dtype(adjusted_tensor, torch.uint8)

self.compareTensorToPIL(rgb_tensor, adjusted_pil)

self.assertTrue(adjusted_tensor.equal(scripted_result))

def test_resize(self):
script_fn = torch.jit.script(F_t.resize)
tensor, pil_img = self._create_data(26, 36)
4 changes: 2 additions & 2 deletions test/test_transforms.py
@@ -1179,14 +1179,14 @@ def test_adjust_gamma(self):
# test 1
y_pil = F.adjust_gamma(x_pil, 0.5)
y_np = np.array(y_pil)
y_ans = [0, 35, 57, 117, 185, 240, 97, 45, 244, 151, 255, 15]
y_ans = [0, 35, 57, 117, 186, 241, 97, 45, 245, 152, 255, 16]
y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
self.assertTrue(np.allclose(y_np, y_ans))

# test 2
y_pil = F.adjust_gamma(x_pil, 2)
y_np = np.array(y_pil)
y_ans = [0, 0, 0, 11, 71, 200, 5, 0, 214, 31, 255, 0]
y_ans = [0, 0, 0, 11, 71, 201, 5, 0, 215, 31, 255, 0]
y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
self.assertTrue(np.allclose(y_np, y_ans))
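The expected values above shift by at most one 8-bit step because the gamma lookup now uses a multiplier of 255 + 1 - 1e-3 instead of 255, matching convert_image_dtype. A minimal sketch of the arithmetic (hypothetical input intensity, not taken from the test fixture; assumes truncation toward zero when the result is stored as uint8):

# Sketch: the larger multiplier nudges some outputs up by one step.
gamma, x = 0.5, 200
old = int(255 * (x / 255) ** gamma)               # 225
new = int((255 + 1 - 1e-3) * (x / 255) ** gamma)  # 226
print(old, new)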

30 changes: 14 additions & 16 deletions torchvision/transforms/functional.py
@@ -161,8 +161,14 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -
msg = f"The cast from {image.dtype} to {dtype} cannot be performed safely."
raise RuntimeError(msg)

# https://github.com/pytorch/vision/pull/2078#issuecomment-612045321
# For data in the range 0-1, (float * 255).to(uint) is only 255
# when float is exactly 1.0.
# `max + 1 - epsilon` provides more evenly distributed mapping of
# ranges of floats to ints.
eps = 1e-3
return image.mul(torch.iinfo(dtype).max + 1 - eps).to(dtype)
Contributor Author:

this was giving results that didn't match PIL

Member:

Hum, we had quite a lot of discussion about this behavior in #2078 (comment). I believe if we make the multiplication go to dtype.max, we will end up with a non-uniform distribution over the last values.

@pmeier thoughts?

Collaborator:

@fmassa is right about the intention. I think this boils down to

  • do we want it "right" or
  • do we want it compatible with other packages.

I'm in favor of the former (hence my implementation), but I can see why the latter is also feasible.

Contributor Author:

I think my issue was internal consistency, not a difference between us and PIL. That comment thread is helpful -- we certainly expect more than one floating point value to map to 255.

I was able to fix it by just making the gamma adjustment consistent with convert_image_dtype.

This raises another issue though -- I'm not sure if it's still OK to express the equation for adjust_gamma as 255 * gain * (img/255) ** gamma in the docs, where in reality it's 255.999 * gain..... I want to be accurate but also don't want to be unnecessarily confusing, and the doc does say "based on."
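A minimal sketch of the mapping behavior discussed above (not part of the diff; assumes only torch):

import torch

# For floats in [0, 1], naive scaling maps only exactly 1.0 to 255.
x = torch.tensor([0.0, 0.5, 0.999, 1.0])
naive = (x * 255).to(torch.uint8)             # tensor([  0, 127, 254, 255])

# Scaling by max + 1 - eps spreads the float range more evenly over 0..255.
eps = 1e-3
even = (x * (255 + 1 - eps)).to(torch.uint8)  # tensor([  0, 127, 255, 255])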

result = image.mul(torch.iinfo(dtype).max + 1 - eps)
return result.to(dtype)
else:
# int to float
if dtype.is_floating_point:
@@ -760,7 +766,7 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))


def adjust_gamma(img, gamma, gain=1):
def adjust_gamma(img, gamma: float, gain: float = 1):
r"""Perform gamma correction on an image.

Also known as Power Law Transform. Intensities in RGB mode are adjusted
@@ -774,26 +780,18 @@ def adjust_gamma(img, gamma, gain=1):
.. _Gamma Correction: https://en.wikipedia.org/wiki/Gamma_correction

Args:
img (PIL Image): PIL Image to be adjusted.
img (PIL Image or Tensor): PIL Image to be adjusted.
gamma (float): Non negative real number, same as :math:`\gamma` in the equation.
gamma larger than 1 makes the shadows darker,
while gamma smaller than 1 makes dark regions lighter.
gain (float): The constant multiplier.
Returns:
PIL Image or Tensor: Gamma correction adjusted image.
"""
if not F_pil._is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

if gamma < 0:
raise ValueError('Gamma should be a non-negative real number')

input_mode = img.mode
img = img.convert('RGB')

gamma_map = [255 * gain * pow(ele / 255., gamma) for ele in range(256)] * 3
img = img.point(gamma_map) # use PIL's point-function to accelerate this part
if not isinstance(img, torch.Tensor):
return F_pil.adjust_gamma(img, gamma, gain)

img = img.convert(input_mode)
return img
return F_t.adjust_gamma(img, gamma, gain)
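With this change the public F.adjust_gamma dispatches on the input type. A minimal usage sketch (not part of the diff; assumes a torchvision build containing this PR):

import torch
from PIL import Image
import torchvision.transforms.functional as F

pil_img = Image.new("RGB", (8, 8), color=(200, 100, 50))
tensor_img = torch.randint(0, 256, (3, 8, 8), dtype=torch.uint8)

out_pil = F.adjust_gamma(pil_img, gamma=1.5, gain=1.0)        # routed to F_pil.adjust_gamma
out_tensor = F.adjust_gamma(tensor_img, gamma=1.5, gain=1.0)  # routed to F_t.adjust_gamma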


def rotate(img, angle, resample=False, expand=False, center=None, fill=None):
36 changes: 36 additions & 0 deletions torchvision/transforms/functional_pil.py
@@ -164,6 +164,42 @@ def adjust_hue(img, hue_factor):
return img


@torch.jit.unused
def adjust_gamma(img, gamma, gain=1):
r"""Perform gamma correction on an image.

Also known as Power Law Transform. Intensities in RGB mode are adjusted
based on the following equation:

.. math::
I_{\text{out}} = 255 \times \text{gain} \times \left(\frac{I_{\text{in}}}{255}\right)^{\gamma}

See `Gamma Correction`_ for more details.

.. _Gamma Correction: https://en.wikipedia.org/wiki/Gamma_correction

Args:
img (PIL Image): PIL Image to be adjusted.
gamma (float): Non negative real number, same as :math:`\gamma` in the equation.
gamma larger than 1 makes the shadows darker,
while gamma smaller than 1 makes dark regions lighter.
gain (float): The constant multiplier.
"""
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

if gamma < 0:
raise ValueError('Gamma should be a non-negative real number')

input_mode = img.mode
img = img.convert('RGB')
gamma_map = [(255 + 1 - 1e-3) * gain * pow(ele / 255., gamma) for ele in range(256)] * 3
img = img.point(gamma_map) # use PIL's point-function to accelerate this part

img = img.convert(input_mode)
return img
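A small arithmetic sketch of how the (255 + 1 - 1e-3) factor behaves across the range (not part of the diff; assumes the lookup-table values are truncated to integers):

gamma, gain = 1.2, 1.0
for ele in (0, 128, 255):
    value = (255 + 1 - 1e-3) * gain * pow(ele / 255., gamma)
    print(ele, int(value))  # 0 -> 0, 128 -> 111, 255 -> 255 (top of range stays at 255)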


@torch.jit.unused
def pad(img, padding, fill=0, padding_mode="constant"):
r"""Pad the given PIL.Image on all sides with the given "pad" value.
41 changes: 41 additions & 0 deletions torchvision/transforms/functional_tensor.py
@@ -194,6 +194,47 @@ def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
return _blend(img, rgb_to_grayscale(img), saturation_factor)


def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
r"""Adjust gamma of an RGB image.

Also known as Power Law Transform. Intensities in RGB mode are adjusted
based on the following equation:

.. math::
I_{\text{out}} = 255 \times \text{gain} \times \left(\frac{I_{\text{in}}}{255}\right)^{\gamma}

See `Gamma Correction`_ for more details.

.. _Gamma Correction: https://en.wikipedia.org/wiki/Gamma_correction

Args:
img (Tensor): Tensor of RGB values to be adjusted.
gamma (float): Non negative real number, same as :math:`\gamma` in the equation.
gamma larger than 1 makes the shadows darker,
while gamma smaller than 1 makes dark regions lighter.
gain (float): The constant multiplier.
"""

if not isinstance(img, torch.Tensor):
raise TypeError('img should be a Tensor. Got {}'.format(type(img)))

if gamma < 0:
raise ValueError('Gamma should be a non-negative real number')

result = img
dtype = img.dtype
if not torch.is_floating_point(img):
result = result / 255.0

result = (gain * result ** gamma).clamp(0, 1)

if result.dtype != dtype:
eps = 1e-3
result = (255 + 1.0 - eps) * result
result = result.to(dtype)
return result
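Because the tensor implementation avoids PIL and uses only type-annotated tensor ops, it can be scripted directly. A minimal sketch (not part of the diff; assumes a torchvision build containing this PR):

import torch
import torchvision.transforms.functional_tensor as F_t

scripted = torch.jit.script(F_t.adjust_gamma)

img = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)
out_eager = F_t.adjust_gamma(img, 1.2, 0.9)
out_scripted = scripted(img, 1.2, 0.9)
assert out_eager.equal(out_scripted)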


def center_crop(img: Tensor, output_size: BroadcastingList2[int]) -> Tensor:
"""Crop the Image Tensor and resize it to desired size.
