Skip to content

Commit 54da5db

Browse files
vikramtankasaliVikram Mukunda Rao Tankasali
andauthored
Adjust hue accepts torch tensor (#2300)
* Adjust hue * Adjust hue acceps torch.tensor uint8 Co-authored-by: Vikram Mukunda Rao Tankasali <vikramtankasali@devvm765.lla0.facebook.com>
1 parent 747f406 commit 54da5db

File tree

2 files changed

+132
-0
lines changed

2 files changed

+132
-0
lines changed

test/test_functional_tensor.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import numpy as np
77
import unittest
88
import random
9+
import colorsys
910
from torch.jit.annotations import Optional, List, BroadcastingList2, Tuple
1011

1112

@@ -56,6 +57,45 @@ def test_crop(self):
5657
cropped_img_script = script_crop(img_tensor, top, left, height, width)
5758
self.assertTrue(torch.equal(img_cropped, cropped_img_script))
5859

60+
def test_hsv2rgb(self):
61+
shape = (3, 100, 150)
62+
for _ in range(20):
63+
img = torch.rand(*shape, dtype=torch.float)
64+
ft_img = F_t._hsv2rgb(img).permute(1, 2, 0).flatten(0, 1)
65+
66+
h, s, v, = img.unbind(0)
67+
h = h.flatten().numpy()
68+
s = s.flatten().numpy()
69+
v = v.flatten().numpy()
70+
71+
rgb = []
72+
for h1, s1, v1 in zip(h, s, v):
73+
rgb.append(colorsys.hsv_to_rgb(h1, s1, v1))
74+
75+
colorsys_img = torch.tensor(rgb, dtype=torch.float32)
76+
max_diff = (ft_img - colorsys_img).abs().max()
77+
self.assertLess(max_diff, 1e-5)
78+
79+
def test_rgb2hsv(self):
80+
shape = (3, 150, 100)
81+
for _ in range(20):
82+
img = torch.rand(*shape, dtype=torch.float)
83+
ft_hsv_img = F_t._rgb2hsv(img).permute(1, 2, 0).flatten(0, 1)
84+
85+
r, g, b, = img.unbind(0)
86+
r = r.flatten().numpy()
87+
g = g.flatten().numpy()
88+
b = b.flatten().numpy()
89+
90+
hsv = []
91+
for r1, g1, b1 in zip(r, g, b):
92+
hsv.append(colorsys.rgb_to_hsv(r1, g1, b1))
93+
94+
colorsys_img = torch.tensor(hsv, dtype=torch.float32)
95+
96+
max_diff = (colorsys_img - ft_hsv_img).abs().max()
97+
self.assertLess(max_diff, 1e-5)
98+
5999
def test_adjustments(self):
60100
script_adjust_brightness = torch.jit.script(F_t.adjust_brightness)
61101
script_adjust_contrast = torch.jit.script(F_t.adjust_contrast)

torchvision/transforms/functional_tensor.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,54 @@ def adjust_contrast(img, contrast_factor):
118118
return _blend(img, mean, contrast_factor)
119119

120120

121+
def adjust_hue(img, hue_factor):
122+
"""Adjust hue of an image.
123+
124+
The image hue is adjusted by converting the image to HSV and
125+
cyclically shifting the intensities in the hue channel (H).
126+
The image is then converted back to original image mode.
127+
128+
`hue_factor` is the amount of shift in H channel and must be in the
129+
interval `[-0.5, 0.5]`.
130+
131+
See `Hue`_ for more details.
132+
133+
.. _Hue: https://en.wikipedia.org/wiki/Hue
134+
135+
Args:
136+
img (Tensor): Image to be adjusted. Image type is either uint8 or float.
137+
hue_factor (float): How much to shift the hue channel. Should be in
138+
[-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
139+
HSV space in positive and negative direction respectively.
140+
0 means no shift. Therefore, both -0.5 and 0.5 will give an image
141+
with complementary colors while 0 gives the original image.
142+
143+
Returns:
144+
Tensor: Hue adjusted image.
145+
"""
146+
if not(-0.5 <= hue_factor <= 0.5):
147+
raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor))
148+
149+
if not _is_tensor_a_torch_image(img):
150+
raise TypeError('tensor is not a torch image.')
151+
152+
orig_dtype = img.dtype
153+
if img.dtype == torch.uint8:
154+
img = img.to(dtype=torch.float32) / 255.0
155+
156+
img = _rgb2hsv(img)
157+
h, s, v = img.unbind(0)
158+
h += hue_factor
159+
h = h % 1.0
160+
img = torch.stack((h, s, v))
161+
img_hue_adj = _hsv2rgb(img)
162+
163+
if orig_dtype == torch.uint8:
164+
img_hue_adj = (img_hue_adj * 255.0).to(dtype=orig_dtype)
165+
166+
return img_hue_adj
167+
168+
121169
def adjust_saturation(img, saturation_factor):
122170
# type: (Tensor, float) -> Tensor
123171
"""Adjust color saturation of an RGB image.
@@ -235,3 +283,47 @@ def _blend(img1, img2, ratio):
235283
# type: (Tensor, Tensor, float) -> Tensor
236284
bound = 1 if img1.dtype in [torch.half, torch.float32, torch.float64] else 255
237285
return (ratio * img1 + (1 - ratio) * img2).clamp(0, bound).to(img1.dtype)
286+
287+
288+
def _rgb2hsv(img):
289+
r, g, b = img.unbind(0)
290+
291+
maxc, _ = torch.max(img, dim=0)
292+
minc, _ = torch.min(img, dim=0)
293+
294+
cr = maxc - minc
295+
s = cr / maxc
296+
rc = (maxc - r) / cr
297+
gc = (maxc - g) / cr
298+
bc = (maxc - b) / cr
299+
300+
t = (maxc != minc)
301+
s = t * s
302+
hr = (maxc == r) * (bc - gc)
303+
hg = ((maxc == g) & (maxc != r)) * (2.0 + rc - bc)
304+
hb = ((maxc != g) & (maxc != r)) * (4.0 + gc - rc)
305+
h = (hr + hg + hb)
306+
h = t * h
307+
h = torch.fmod((h / 6.0 + 1.0), 1.0)
308+
return torch.stack((h, s, maxc))
309+
310+
311+
def _hsv2rgb(img):
312+
h, s, v = img.unbind(0)
313+
i = torch.floor(h * 6.0)
314+
f = (h * 6.0) - i
315+
i = i.to(dtype=torch.int32)
316+
317+
p = torch.clamp((v * (1.0 - s)), 0.0, 1.0)
318+
q = torch.clamp((v * (1.0 - s * f)), 0.0, 1.0)
319+
t = torch.clamp((v * (1.0 - s * (1.0 - f))), 0.0, 1.0)
320+
i = i % 6
321+
322+
mask = i == torch.arange(6)[:, None, None]
323+
324+
a1 = torch.stack((v, q, p, p, t, v))
325+
a2 = torch.stack((t, v, v, q, p, p))
326+
a3 = torch.stack((p, p, t, v, v, q))
327+
a4 = torch.stack((a1, a2, a3))
328+
329+
return torch.einsum("ijk, xijk -> xjk", mask.to(dtype=img.dtype), a4)

0 commit comments

Comments
 (0)