Skip to content

Adjust hue accepts torch tensor #2300

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 27 commits into from
Jun 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
3119328
Adjust hue
Jun 4, 2020
08a70b2
Adjust hue acceps torch.tensor uint8
Jun 4, 2020
d88086f
Merge branch 'adjust_hue' of github.com:vikramtankasali/vision into a…
Jun 8, 2020
149d5a3
Adjust hue amend
Jun 8, 2020
2e82f62
Merge branch 'adjust_hue' of github.com:vikramtankasali/vision into a…
Jun 8, 2020
0155071
remove commented code
Jun 8, 2020
fdd5cd7
Remove commented code
Jun 8, 2020
666c8cb
Remove commented code
Jun 8, 2020
2701e04
Merge branch 'adjust_hue' of github.com:vikramtankasali/vision into a…
Jun 8, 2020
b68f504
Add support for [0, 1] input image.
Jun 8, 2020
cde0048
Add suport for [0, 1] imput image.
Jun 8, 2020
1962355
Change comment.
Jun 8, 2020
cd63405
Fix casting.
Jun 8, 2020
2424009
Batch equalities.
Jun 8, 2020
a9d350b
Batch equlaities
Jun 8, 2020
9c78791
Try alternative
Jun 8, 2020
a63d923
Merge branch 'adjust_hue' of github.com:vikramtankasali/vision into a…
vikramtankasali Jun 9, 2020
40f5c94
Compare with colorsys. Fix divide by zero.
vikramtankasali Jun 9, 2020
b72c424
Merge branch 'adjust_hue' of github.com:vikramtankasali/vision into a…
vikramtankasali Jun 11, 2020
40766e0
Fix nits
vikramtankasali Jun 11, 2020
a70eb18
Merge branch 'adjust_hue' of github.com:vikramtankasali/vision into a…
vikramtankasali Jun 11, 2020
442c873
Fix nit
vikramtankasali Jun 11, 2020
7e7581b
Merge branch 'adjust_hue' of github.com:vikramtankasali/vision into a…
vikramtankasali Jun 11, 2020
07bb5c7
Merge branch 'adjust_hue' of github.com:vikramtankasali/vision into a…
vikramtankasali Jun 11, 2020
5427564
Merge branch 'adjust_hue' of github.com:vikramtankasali/vision into a…
vikramtankasali Jun 11, 2020
168a4c4
Fixed pylint
vikramtankasali Jun 11, 2020
87622eb
Merge branch 'adjust_hue' of github.com:vikramtankasali/vision into a…
vikramtankasali Jun 11, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions test/test_functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np
import unittest
import random
import colorsys
from torch.jit.annotations import Optional, List, BroadcastingList2, Tuple


Expand Down Expand Up @@ -56,6 +57,45 @@ def test_crop(self):
cropped_img_script = script_crop(img_tensor, top, left, height, width)
self.assertTrue(torch.equal(img_cropped, cropped_img_script))

def test_hsv2rgb(self):
shape = (3, 100, 150)
for _ in range(20):
img = torch.rand(*shape, dtype=torch.float)
ft_img = F_t._hsv2rgb(img).permute(1, 2, 0).flatten(0, 1)

h, s, v, = img.unbind(0)
h = h.flatten().numpy()
s = s.flatten().numpy()
v = v.flatten().numpy()

rgb = []
for h1, s1, v1 in zip(h, s, v):
rgb.append(colorsys.hsv_to_rgb(h1, s1, v1))

colorsys_img = torch.tensor(rgb, dtype=torch.float32)
max_diff = (ft_img - colorsys_img).abs().max()
self.assertLess(max_diff, 1e-5)

def test_rgb2hsv(self):
shape = (3, 150, 100)
for _ in range(20):
img = torch.rand(*shape, dtype=torch.float)
ft_hsv_img = F_t._rgb2hsv(img).permute(1, 2, 0).flatten(0, 1)

r, g, b, = img.unbind(0)
r = r.flatten().numpy()
g = g.flatten().numpy()
b = b.flatten().numpy()

hsv = []
for r1, g1, b1 in zip(r, g, b):
hsv.append(colorsys.rgb_to_hsv(r1, g1, b1))

colorsys_img = torch.tensor(hsv, dtype=torch.float32)

max_diff = (colorsys_img - ft_hsv_img).abs().max()
self.assertLess(max_diff, 1e-5)

def test_adjustments(self):
script_adjust_brightness = torch.jit.script(F_t.adjust_brightness)
script_adjust_contrast = torch.jit.script(F_t.adjust_contrast)
Expand Down
92 changes: 92 additions & 0 deletions torchvision/transforms/functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,54 @@ def adjust_contrast(img, contrast_factor):
return _blend(img, mean, contrast_factor)


def adjust_hue(img, hue_factor):
"""Adjust hue of an image.

The image hue is adjusted by converting the image to HSV and
cyclically shifting the intensities in the hue channel (H).
The image is then converted back to original image mode.

`hue_factor` is the amount of shift in H channel and must be in the
interval `[-0.5, 0.5]`.

See `Hue`_ for more details.

.. _Hue: https://en.wikipedia.org/wiki/Hue

Args:
img (Tensor): Image to be adjusted. Image type is either uint8 or float.
hue_factor (float): How much to shift the hue channel. Should be in
[-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
HSV space in positive and negative direction respectively.
0 means no shift. Therefore, both -0.5 and 0.5 will give an image
with complementary colors while 0 gives the original image.

Returns:
Tensor: Hue adjusted image.
"""
if not(-0.5 <= hue_factor <= 0.5):
raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor))

if not _is_tensor_a_torch_image(img):
raise TypeError('tensor is not a torch image.')

orig_dtype = img.dtype
if img.dtype == torch.uint8:
img = img.to(dtype=torch.float32) / 255.0

img = _rgb2hsv(img)
h, s, v = img.unbind(0)
h += hue_factor
h = h % 1.0
img = torch.stack((h, s, v))
img_hue_adj = _hsv2rgb(img)

if orig_dtype == torch.uint8:
img_hue_adj = (img_hue_adj * 255.0).to(dtype=orig_dtype)

return img_hue_adj


def adjust_saturation(img, saturation_factor):
# type: (Tensor, float) -> Tensor
"""Adjust color saturation of an RGB image.
Expand Down Expand Up @@ -236,3 +284,47 @@ def _blend(img1, img2, ratio):
# type: (Tensor, Tensor, float) -> Tensor
bound = 1 if img1.dtype in [torch.half, torch.float32, torch.float64] else 255
return (ratio * img1 + (1 - ratio) * img2).clamp(0, bound).to(img1.dtype)


def _rgb2hsv(img):
r, g, b = img.unbind(0)

maxc, _ = torch.max(img, dim=0)
minc, _ = torch.min(img, dim=0)

cr = maxc - minc
s = cr / maxc
rc = (maxc - r) / cr
gc = (maxc - g) / cr
bc = (maxc - b) / cr

t = (maxc != minc)
s = t * s
hr = (maxc == r) * (bc - gc)
hg = ((maxc == g) & (maxc != r)) * (2.0 + rc - bc)
hb = ((maxc != g) & (maxc != r)) * (4.0 + gc - rc)
h = (hr + hg + hb)
h = t * h
h = torch.fmod((h / 6.0 + 1.0), 1.0)
return torch.stack((h, s, maxc))


def _hsv2rgb(img):
h, s, v = img.unbind(0)
i = torch.floor(h * 6.0)
f = (h * 6.0) - i
i = i.to(dtype=torch.int32)

p = torch.clamp((v * (1.0 - s)), 0.0, 1.0)
q = torch.clamp((v * (1.0 - s * f)), 0.0, 1.0)
t = torch.clamp((v * (1.0 - s * (1.0 - f))), 0.0, 1.0)
i = i % 6

mask = i == torch.arange(6)[:, None, None]

a1 = torch.stack((v, q, p, p, t, v))
a2 = torch.stack((t, v, v, q, p, p))
a3 = torch.stack((p, p, t, v, v, q))
a4 = torch.stack((a1, a2, a3))

return torch.einsum("ijk, xijk -> xjk", mask.to(dtype=img.dtype), a4)