Skip to content

Commit

Permalink
add reference test for normalize_image_tensor (#7119)
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier authored Jan 23, 2023
1 parent d2d448c commit c206a47
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 1 deletion.
18 changes: 18 additions & 0 deletions test/prototype_transforms_kernel_infos.py
Original file line number Diff line number Diff line change
Expand Up @@ -2232,6 +2232,22 @@ def sample_inputs_normalize_image_tensor():
yield ArgsKwargs(image_loader, mean=mean, std=std)


def reference_normalize_image_tensor(image, mean, std, inplace=False):
mean = torch.tensor(mean).view(-1, 1, 1)
std = torch.tensor(std).view(-1, 1, 1)

sub = torch.Tensor.sub_ if inplace else torch.Tensor.sub
return sub(image, mean).div_(std)


def reference_inputs_normalize_image_tensor():
yield ArgsKwargs(
make_image_loader(size=(32, 32), color_space=datapoints.ColorSpace.RGB, extra_dims=[1]),
mean=[0.5, 0.5, 0.5],
std=[1.0, 1.0, 1.0],
)


def sample_inputs_normalize_video():
mean, std = _NORMALIZE_MEANS_STDS[0]
for video_loader in make_video_loaders(
Expand All @@ -2246,6 +2262,8 @@ def sample_inputs_normalize_video():
F.normalize_image_tensor,
kernel_name="normalize_image_tensor",
sample_inputs_fn=sample_inputs_normalize_image_tensor,
reference_fn=reference_normalize_image_tensor,
reference_inputs_fn=reference_inputs_normalize_image_tensor,
test_marks=[
xfail_jit_python_scalar_arg("mean"),
xfail_jit_python_scalar_arg("std"),
Expand Down
23 changes: 22 additions & 1 deletion test/test_prototype_transforms_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@

import torchvision.prototype.transforms.utils
from common_utils import cache, cpu_and_gpu, needs_cuda, set_rng_seed
from prototype_common_utils import assert_close, make_bounding_boxes, parametrized_error_message
from prototype_common_utils import (
assert_close,
DEFAULT_SQUARE_SPATIAL_SIZE,
make_bounding_boxes,
parametrized_error_message,
)
from prototype_transforms_dispatcher_infos import DISPATCHER_INFOS
from prototype_transforms_kernel_infos import KERNEL_INFOS
from torch.utils._pytree import tree_map
Expand Down Expand Up @@ -538,6 +543,22 @@ def test_convert_dtype_image_tensor_dtype_and_device(info, args_kwargs, device):
assert output.device == input.device


@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("num_channels", [1, 3])
def test_normalize_image_tensor_stats(device, num_channels):
stats = pytest.importorskip("scipy.stats", reason="SciPy is not available")

def assert_samples_from_standard_normal(t):
p_value = stats.kstest(t.flatten(), cdf="norm", args=(0, 1)).pvalue
return p_value > 1e-4

image = torch.rand(num_channels, DEFAULT_SQUARE_SPATIAL_SIZE, DEFAULT_SQUARE_SPATIAL_SIZE)
mean = image.mean(dim=(1, 2)).tolist()
std = image.std(dim=(1, 2)).tolist()

assert_samples_from_standard_normal(F.normalize_image_tensor(image, mean, std))


# TODO: All correctness checks below this line should be ported to be references on a `KernelInfo` in
# `prototype_transforms_kernel_infos.py`

Expand Down

0 comments on commit c206a47

Please sign in to comment.