Skip to content

Commit

Permalink
[fbsync] Let Normalize() and RandomPhotometricDistort return datapoin…
Browse files Browse the repository at this point in the history
…ts instead of tensors (#7113)

Reviewed By: YosuaMichael

Differential Revision: D42706907

fbshipit-source-id: a7b7487ab8563f8a43a0ebb84df19579ccd35fe1
  • Loading branch information
NicolasHug authored and facebook-github-bot committed Jan 24, 2023
1 parent 94ecbbc commit 140a480
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 31 deletions.
1 change: 0 additions & 1 deletion test/prototype_transforms_dispatcher_infos.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,6 @@ def fill_sequence_needs_broadcast(args_kwargs):
datapoints.Video: F.normalize_video,
},
test_marks=[
skip_dispatch_feature,
xfail_jit_python_scalar_arg("mean"),
xfail_jit_python_scalar_arg("std"),
],
Expand Down
14 changes: 1 addition & 13 deletions test/test_prototype_transforms_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

import torchvision.prototype.transforms.utils
from common_utils import cache, cpu_and_gpu, needs_cuda, set_rng_seed
from prototype_common_utils import assert_close, make_bounding_boxes, make_image, parametrized_error_message
from prototype_common_utils import assert_close, make_bounding_boxes, parametrized_error_message
from prototype_transforms_dispatcher_infos import DISPATCHER_INFOS
from prototype_transforms_kernel_infos import KERNEL_INFOS
from torch.utils._pytree import tree_map
Expand Down Expand Up @@ -1185,18 +1185,6 @@ def test_correctness_gaussian_blur_image_tensor(device, spatial_size, dt, ksize,
torch.testing.assert_close(out, true_out, rtol=0.0, atol=1.0, msg=f"{ksize}, {sigma}")


def test_normalize_output_type():
    """Check the output type and values of ``F.normalize``.

    Two inputs are exercised: a random plain tensor and the object returned by
    ``make_image`` with an RGB color space. In both cases the output must be
    exactly a ``torch.Tensor`` and must equal the input shifted by ``-0.5``
    (mean 0.5, std 1.0 on each of the three channels).
    """

    def check(sample):
        # Fresh mean/std lists on every call, exactly as in two separate call sites.
        result = F.normalize(sample, mean=[0.5, 0.5, 0.5], std=[1.0, 1.0, 1.0])
        # Deliberately an exact type comparison: a tensor subclass must not pass.
        assert type(result) is torch.Tensor
        torch.testing.assert_close(sample - 0.5, result)

    check(torch.rand(1, 3, 32, 32))
    check(make_image(color_space=datapoints.ColorSpace.RGB))


@pytest.mark.parametrize(
"inpt",
[
Expand Down
4 changes: 4 additions & 0 deletions torchvision/prototype/datapoints/_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,10 @@ def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = N
)
return Image.wrap_like(self, output)

def normalize(self, mean: List[float], std: List[float], inplace: bool = False) -> Image:
    """Normalize this image with the given ``mean`` and ``std``.

    The datapoint is first viewed as a plain ``torch.Tensor``, then handed to
    ``normalize_image_tensor`` together with ``mean``, ``std`` and ``inplace``.
    The result is wrapped back into an ``Image`` via ``Image.wrap_like``.
    """
    plain = self.as_subclass(torch.Tensor)
    normalized = self._F.normalize_image_tensor(plain, mean=mean, std=std, inplace=inplace)
    return Image.wrap_like(self, normalized)


ImageType = Union[torch.Tensor, PIL.Image.Image, Image]
ImageTypeJIT = torch.Tensor
Expand Down
4 changes: 4 additions & 0 deletions torchvision/prototype/datapoints/_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,10 @@ def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = N
output = self._F.gaussian_blur_video(self.as_subclass(torch.Tensor), kernel_size=kernel_size, sigma=sigma)
return Video.wrap_like(self, output)

def normalize(self, mean: List[float], std: List[float], inplace: bool = False) -> Video:
    """Normalize this video with the given ``mean`` and ``std``.

    The datapoint is first viewed as a plain ``torch.Tensor``, then handed to
    ``normalize_video`` together with ``mean``, ``std`` and ``inplace``.
    The result is wrapped back into a ``Video`` via ``Video.wrap_like``.
    """
    plain = self.as_subclass(torch.Tensor)
    normalized = self._F.normalize_video(plain, mean=mean, std=std, inplace=inplace)
    return Video.wrap_like(self, normalized)


VideoType = Union[torch.Tensor, Video]
VideoTypeJIT = torch.Tensor
Expand Down
10 changes: 5 additions & 5 deletions torchvision/prototype/transforms/_color.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return output


# TODO: This class seems to be untested
class RandomPhotometricDistort(Transform):
_transformed_types = (
datapoints.Image,
Expand Down Expand Up @@ -119,15 +120,14 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
def _permute_channels(
    self, inpt: Union[datapoints.ImageType, datapoints.VideoType], permutation: torch.Tensor
) -> Union[datapoints.ImageType, datapoints.VideoType]:
    """Reorder the channels of ``inpt`` according to ``permutation``.

    Args:
        inpt: The image or video to permute. A ``PIL.Image.Image`` is converted
            to a tensor first and converted back to PIL afterwards.
        permutation: Index tensor applied to the third-from-last dimension
            (``inpt[..., permutation, :, :]``).

    Returns:
        The channel-permuted input. PIL input yields a PIL image; any other
        input yields whatever the indexing operation returns.
        NOTE(review): whether a datapoint input keeps its datapoint type here
        depends on how indexing is dispatched for datapoints -- confirm.
    """
    # Remember the caller's object: ``inpt`` is rebound to a tensor for PIL
    # input, so a later ``isinstance(inpt, PIL.Image.Image)`` check could never
    # be true and the result would not be converted back.
    orig_inpt = inpt
    if isinstance(orig_inpt, PIL.Image.Image):
        inpt = F.pil_to_tensor(inpt)

    output = inpt[..., permutation, :, :]

    if isinstance(orig_inpt, PIL.Image.Image):
        output = F.to_image_pil(output)

    return output
Expand Down
20 changes: 8 additions & 12 deletions torchvision/prototype/transforms/functional/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,18 +60,14 @@ def normalize(
) -> torch.Tensor:
if not torch.jit.is_scripting():
_log_api_usage_once(normalize)

if isinstance(inpt, (datapoints.Image, datapoints.Video)):
inpt = inpt.as_subclass(torch.Tensor)
elif not is_simple_tensor(inpt):
raise TypeError(
f"Input can either be a plain tensor or an `Image` or `Video` datapoint, "
f"but got {type(inpt)} instead."
)

# Image or Video type should not be retained after normalization due to unknown data range
# Thus we return Tensor for input Image
return normalize_image_tensor(inpt, mean=mean, std=std, inplace=inplace)
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return normalize_image_tensor(inpt, mean=mean, std=std, inplace=inplace)
elif isinstance(inpt, (datapoints.Image, datapoints.Video)):
return inpt.normalize(mean=mean, std=std, inplace=inplace)
else:
raise TypeError(
f"Input can either be a plain tensor or an `Image` or `Video` datapoint, " f"but got {type(inpt)} instead."
)


def _get_gaussian_kernel1d(kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
Expand Down

0 comments on commit 140a480

Please sign in to comment.