def _rgb2hsv(img):
    """Convert an RGB tensor to HSV without producing NaN on gray pixels.

    Args:
        img (Tensor): RGB data laid out as ``(..., 3, H, W)``. Inputs with
            fewer than three dimensions are read channel-first (channels on
            dim 0). The arithmetic below assumes a floating-point tensor --
            TODO confirm against callers.

    Returns:
        Tensor: same shape as ``img``; the channel axis holds ``(h, s, v)``.
        ``h`` is wrapped into ``[0, 1)``, ``v`` is the per-pixel channel
        maximum, and both ``h`` and ``s`` are 0 wherever ``r == g == b``.
    """
    # Image tensors keep channels third from last, so leading batch
    # dimensions are supported. Lower-rank inputs fall back to dim 0.
    channel_dim = -3 if img.dim() >= 3 else 0

    r, g, b = img.unbind(channel_dim)

    maxc = torch.max(img, dim=channel_dim).values
    minc = torch.min(img, dim=channel_dim).values

    # Achromatic pixels (`maxc == minc`) would divide by zero below:
    #   + the S channel divides by `maxc`, which for non-negative input is
    #     zero only if `maxc == minc`;
    #   + the H channel divides by `maxc - minc`.
    #
    # Rather than overwriting NaN afterwards, prevent it from occurring, so
    # that no NaN can end up in a buffer saved for a backward pass.
    eqc = maxc == minc
    unit = maxc.new_ones(())

    cr = maxc - minc
    # `eqc` implies `cr == 0`, so dividing by 1 instead still gives `s == 0`.
    s = cr / torch.where(eqc, unit, maxc)

    # `eqc` implies `maxc == r == g == b`, so every numerator below is zero
    # and `rc == gc == bc == 0` for any non-zero divisor. Only the
    # `maxc == r` mask is active for such pixels, which yields `h == 0`.
    # Replacing the zero divisor with 1 is therefore safe.
    cr_divisor = torch.where(eqc, unit, cr)
    rc = (maxc - r) / cr_divisor
    gc = (maxc - g) / cr_divisor
    bc = (maxc - b) / cr_divisor

    # Exactly one of the three masks is set per pixel: red wins ties, then
    # green, then blue.
    hr = (maxc == r) * (bc - gc)
    hg = ((maxc == g) & (maxc != r)) * (2.0 + rc - bc)
    hb = ((maxc != g) & (maxc != r)) * (4.0 + gc - rc)
    h = hr + hg + hb
    # Scale the sextant value to a fraction of a turn and wrap into [0, 1).
    h = torch.fmod(h / 6.0 + 1.0, 1.0)

    return torch.stack((h, s, maxc), dim=channel_dim)