Skip to content

Commit

Permalink
Minor improvements on functional.
Browse files Browse the repository at this point in the history
  • Loading branch information
datumbox committed Oct 25, 2022
1 parent 0d7807d commit 090a0d7
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 22 deletions.
12 changes: 1 addition & 11 deletions torchvision/prototype/transforms/functional/_color.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def _hsv_to_rgb(img: torch.Tensor) -> torch.Tensor:
h, s, v = img.unbind(dim=-3)
h6 = h * 6
i = torch.floor(h6)
f = (h6) - i
f = h6 - i
i = i.to(dtype=torch.int32)

p = (v * (1.0 - s)).clamp_(0.0, 1.0)
Expand All @@ -210,9 +210,6 @@ def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Ten
if not (-0.5 <= hue_factor <= 0.5):
raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].")

if not (isinstance(image, torch.Tensor)):
raise TypeError("Input img should be Tensor image")

c = get_num_channels_image_tensor(image)

if c not in [1, 3]:
Expand Down Expand Up @@ -258,9 +255,6 @@ def adjust_hue(inpt: features.InputTypeJIT, hue_factor: float) -> features.Input


def adjust_gamma_image_tensor(image: torch.Tensor, gamma: float, gain: float = 1.0) -> torch.Tensor:
if not (isinstance(image, torch.Tensor)):
raise TypeError("Input img should be Tensor image")

if gamma < 0:
raise ValueError("Gamma should be a non-negative real number")

Expand Down Expand Up @@ -337,10 +331,6 @@ def solarize(inpt: features.InputTypeJIT, threshold: float) -> features.InputTyp


def autocontrast_image_tensor(image: torch.Tensor) -> torch.Tensor:

if not (isinstance(image, torch.Tensor)):
raise TypeError("Input img should be Tensor image")

c = get_num_channels_image_tensor(image)

if c not in [1, 3]:
Expand Down
10 changes: 3 additions & 7 deletions torchvision/prototype/transforms/functional/_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,12 +183,8 @@ def clamp_bounding_box(
return convert_format_bounding_box(xyxy_boxes, BoundingBoxFormat.XYXY, format)


def _split_alpha(image: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
return image[..., :-1, :, :], image[..., -1:, :, :]


def _strip_alpha(image: torch.Tensor) -> torch.Tensor:
image, alpha = _split_alpha(image)
image, alpha = torch.tensor_split(image, indices=(-1,), dim=-3)
if not torch.all(alpha == _FT._max_value(alpha.dtype)):
raise RuntimeError(
"Stripping the alpha channel if it contains values other than the max value is not supported."
Expand Down Expand Up @@ -237,7 +233,7 @@ def convert_color_space_image_tensor(
elif old_color_space == ColorSpace.GRAY_ALPHA and new_color_space == ColorSpace.RGB:
return _gray_to_rgb(_strip_alpha(image))
elif old_color_space == ColorSpace.GRAY_ALPHA and new_color_space == ColorSpace.RGB_ALPHA:
image, alpha = _split_alpha(image)
image, alpha = torch.tensor_split(image, indices=(-1,), dim=-3)
return _add_alpha(_gray_to_rgb(image), alpha)
elif old_color_space == ColorSpace.RGB and new_color_space == ColorSpace.GRAY:
return _rgb_to_gray(image)
Expand All @@ -248,7 +244,7 @@ def convert_color_space_image_tensor(
elif old_color_space == ColorSpace.RGB_ALPHA and new_color_space == ColorSpace.GRAY:
return _rgb_to_gray(_strip_alpha(image))
elif old_color_space == ColorSpace.RGB_ALPHA and new_color_space == ColorSpace.GRAY_ALPHA:
image, alpha = _split_alpha(image)
image, alpha = torch.tensor_split(image, indices=(-1,), dim=-3)
return _add_alpha(_rgb_to_gray(image), alpha)
elif old_color_space == ColorSpace.RGB_ALPHA and new_color_space == ColorSpace.RGB:
return _strip_alpha(image)
Expand Down
8 changes: 4 additions & 4 deletions torchvision/prototype/transforms/functional/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,18 +69,18 @@ def normalize(
return normalize_image_tensor(inpt, mean=mean, std=std, inplace=inplace)


def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> torch.Tensor:
def _get_gaussian_kernel1d(kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
lim = (kernel_size - 1) / (2 * math.sqrt(2) * sigma)
x = torch.linspace(-lim, lim, steps=kernel_size)
x = torch.linspace(-lim, lim, steps=kernel_size, dtype=dtype, device=device)
kernel1d = torch.softmax(-x.pow_(2), dim=0)
return kernel1d


def _get_gaussian_kernel2d(
kernel_size: List[int], sigma: List[float], dtype: torch.dtype, device: torch.device
) -> torch.Tensor:
kernel1d_x = _get_gaussian_kernel1d(kernel_size[0], sigma[0]).to(device, dtype=dtype)
kernel1d_y = _get_gaussian_kernel1d(kernel_size[1], sigma[1]).to(device, dtype=dtype)
kernel1d_x = _get_gaussian_kernel1d(kernel_size[0], sigma[0], dtype, device)
kernel1d_y = _get_gaussian_kernel1d(kernel_size[1], sigma[1], dtype, device)
kernel2d = kernel1d_y.unsqueeze(-1) * kernel1d_x
return kernel2d

Expand Down

0 comments on commit 090a0d7

Please sign in to comment.