From dfba94d17396a509d1334448032fba01cf468edb Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 19 Sep 2023 10:59:16 +0200 Subject: [PATCH] port tests for RandomPhotometricDistort --- test/test_transforms_v2.py | 1 - test/test_transforms_v2_refactored.py | 22 ++++++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 5ab35fc873b..0d6727a7dd6 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -123,7 +123,6 @@ class TestSmoke: (transforms.RandomGrayscale(p=1.0), None), (transforms.RandomInvert(p=1.0), None), (transforms.RandomChannelPermutation(), None), - (transforms.RandomPhotometricDistort(p=1.0), None), (transforms.RandomPosterize(bits=4, p=1.0), None), (transforms.RandomSolarize(threshold=0.5, p=1.0), None), (transforms.CenterCrop([16, 16]), None), diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index 59d30d482e2..cb4c674cafb 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -3945,3 +3945,25 @@ def test_transform_correctness(self, brightness, contrast, saturation, hue): mae = (actual.float() - expected.float()).abs().mean() assert mae < 2 + + +class TestRandomPhotometricDistort: + @pytest.mark.parametrize( + "make_input", + [make_image_tensor, make_image_pil, make_image, make_video], + ) + @pytest.mark.parametrize("dtype", [torch.uint8, torch.float32]) + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_transform(self, make_input, dtype, device): + if make_input is make_image_pil and not (dtype is torch.uint8 and device == "cpu"): + pytest.skip( + "PIL image tests with parametrization other than dtype=torch.uint8 and device='cpu' " + "will degenerate to that anyway." + ) + + check_transform( + transforms.RandomPhotometricDistort( + brightness=(0.3, 0.4), contrast=(0.5, 0.6), saturation=(0.7, 0.8), hue=(-0.1, 0.2), p=1 + ), + make_input(dtype=dtype, device=device), + )