Skip to content

Commit

Permalink
port tests for RandomPhotometricDistort (#7973)
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier authored Sep 26, 2023
1 parent ace9221 commit 997384c
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 1 deletion.
1 change: 0 additions & 1 deletion test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,6 @@ class TestSmoke:
(transforms.RandomEqualize(p=1.0), None),
(transforms.RandomInvert(p=1.0), None),
(transforms.RandomChannelPermutation(), None),
(transforms.RandomPhotometricDistort(p=1.0), None),
(transforms.RandomPosterize(bits=4, p=1.0), None),
(transforms.RandomSolarize(threshold=0.5, p=1.0), None),
(transforms.CenterCrop([16, 16]), None),
Expand Down
25 changes: 25 additions & 0 deletions test/test_transforms_v2_refactored.py
Original file line number Diff line number Diff line change
Expand Up @@ -4040,3 +4040,28 @@ def test_transform_params_correctness(self, side_range, make_input, device):
assert 0 <= padding[1] <= (side_range[1] - 1) * height
assert 0 <= padding[2] <= (side_range[1] - 1) * width
assert 0 <= padding[3] <= (side_range[1] - 1) * height


class TestRandomPhotometricDistort:
# Tests are light because this largely relies on the already tested
# `adjust_{brightness,contrast,saturation,hue}` and `permute_channels` kernels.

@pytest.mark.parametrize(
"make_input",
[make_image_tensor, make_image_pil, make_image, make_video],
)
@pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_transform(self, make_input, dtype, device):
if make_input is make_image_pil and not (dtype is torch.uint8 and device == "cpu"):
pytest.skip(
"PIL image tests with parametrization other than dtype=torch.uint8 and device='cpu' "
"will degenerate to that anyway."
)

check_transform(
transforms.RandomPhotometricDistort(
brightness=(0.3, 0.4), contrast=(0.5, 0.6), saturation=(0.7, 0.8), hue=(-0.1, 0.2), p=1
),
make_input(dtype=dtype, device=device),
)

0 comments on commit 997384c

Please sign in to comment.