diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index fa04d5deb0c..9028b304c1b 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -173,15 +173,7 @@ def _check_dispatcher_scripted_smoke(dispatcher, input, *args, **kwargs): dispatcher_scripted(input.as_subclass(torch.Tensor), *args, **kwargs) -def check_dispatcher( - dispatcher, - # TODO: remove this parameter - kernel, - input, - *args, - check_scripted_smoke=True, - **kwargs, -): +def check_dispatcher(dispatcher, input, *args, check_scripted_smoke=True, **kwargs): unknown_input = object() with pytest.raises(TypeError, match=re.escape(str(type(unknown_input)))): dispatcher(unknown_input, *args, **kwargs) @@ -516,20 +508,12 @@ def test_kernel_video(self): @pytest.mark.parametrize("size", OUTPUT_SIZES) @pytest.mark.parametrize( - ("kernel", "make_input"), - [ - (F.resize_image_tensor, make_image_tensor), - (F.resize_image_pil, make_image_pil), - (F.resize_image_tensor, make_image), - (F.resize_bounding_boxes, make_bounding_box), - (F.resize_mask, make_segmentation_mask), - (F.resize_video, make_video), - ], + "make_input", + [make_image_tensor, make_image_pil, make_image, make_bounding_box, make_segmentation_mask, make_video], ) - def test_dispatcher(self, size, kernel, make_input): + def test_dispatcher(self, size, make_input): check_dispatcher( F.resize, - kernel, make_input(self.INPUT_SIZE), size=size, antialias=True, @@ -805,18 +789,11 @@ def test_kernel_video(self): check_kernel(F.horizontal_flip_video, make_video()) @pytest.mark.parametrize( - ("kernel", "make_input"), - [ - (F.horizontal_flip_image_tensor, make_image_tensor), - (F.horizontal_flip_image_pil, make_image_pil), - (F.horizontal_flip_image_tensor, make_image), - (F.horizontal_flip_bounding_boxes, make_bounding_box), - (F.horizontal_flip_mask, make_segmentation_mask), - (F.horizontal_flip_video, make_video), - ], + "make_input", + [make_image_tensor, make_image_pil, make_image, make_bounding_box, make_segmentation_mask, make_video], ) - def test_dispatcher(self, kernel, make_input): - check_dispatcher(F.horizontal_flip, kernel, make_input()) + def test_dispatcher(self, make_input): + check_dispatcher(F.horizontal_flip, make_input()) @pytest.mark.parametrize( ("kernel", "input_type"), @@ -988,18 +965,11 @@ def test_kernel_video(self): self._check_kernel(F.affine_video, make_video()) @pytest.mark.parametrize( - ("kernel", "make_input"), - [ - (F.affine_image_tensor, make_image_tensor), - (F.affine_image_pil, make_image_pil), - (F.affine_image_tensor, make_image), - (F.affine_bounding_boxes, make_bounding_box), - (F.affine_mask, make_segmentation_mask), - (F.affine_video, make_video), - ], + "make_input", + [make_image_tensor, make_image_pil, make_image, make_bounding_box, make_segmentation_mask, make_video], ) - def test_dispatcher(self, kernel, make_input): - check_dispatcher(F.affine, kernel, make_input(), **self._MINIMAL_AFFINE_KWARGS) + def test_dispatcher(self, make_input): + check_dispatcher(F.affine, make_input(), **self._MINIMAL_AFFINE_KWARGS) @pytest.mark.parametrize( ("kernel", "input_type"), @@ -1284,18 +1254,11 @@ def test_kernel_video(self): check_kernel(F.vertical_flip_video, make_video()) @pytest.mark.parametrize( - ("kernel", "make_input"), - [ - (F.vertical_flip_image_tensor, make_image_tensor), - (F.vertical_flip_image_pil, make_image_pil), - (F.vertical_flip_image_tensor, make_image), - (F.vertical_flip_bounding_boxes, make_bounding_box), - (F.vertical_flip_mask, make_segmentation_mask), - (F.vertical_flip_video, make_video), - ], + "make_input", + [make_image_tensor, make_image_pil, make_image, make_bounding_box, make_segmentation_mask, make_video], ) - def test_dispatcher(self, kernel, make_input): - check_dispatcher(F.vertical_flip, kernel, make_input()) + def test_dispatcher(self, make_input): + check_dispatcher(F.vertical_flip, make_input()) @pytest.mark.parametrize( ("kernel", "input_type"), @@ -1441,18 +1404,11 @@ def test_kernel_video(self): check_kernel(F.rotate_video, make_video(), **self._MINIMAL_AFFINE_KWARGS) @pytest.mark.parametrize( - ("kernel", "make_input"), - [ - (F.rotate_image_tensor, make_image_tensor), - (F.rotate_image_pil, make_image_pil), - (F.rotate_image_tensor, make_image), - (F.rotate_bounding_boxes, make_bounding_box), - (F.rotate_mask, make_segmentation_mask), - (F.rotate_video, make_video), - ], + "make_input", + [make_image_tensor, make_image_pil, make_image, make_bounding_box, make_segmentation_mask, make_video], ) - def test_dispatcher(self, kernel, make_input): - check_dispatcher(F.rotate, kernel, make_input(), **self._MINIMAL_AFFINE_KWARGS) + def test_dispatcher(self, make_input): + check_dispatcher(F.rotate, make_input(), **self._MINIMAL_AFFINE_KWARGS) @pytest.mark.parametrize( ("kernel", "input_type"), @@ -1711,22 +1667,14 @@ def test_kernel(self, kernel, make_input, input_dtype, output_dtype, device, sca scale=scale, ) - @pytest.mark.parametrize( - ("kernel", "make_input"), - [ - (F.to_dtype_image_tensor, make_image_tensor), - (F.to_dtype_image_tensor, make_image), - (F.to_dtype_video, make_video), - ], - ) + @pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_video]) @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8]) @pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8]) @pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("scale", (True, False)) - def test_dispatcher(self, kernel, make_input, input_dtype, output_dtype, device, scale): + def test_dispatcher(self, make_input, input_dtype, output_dtype, device, scale): check_dispatcher( F.to_dtype, - kernel, make_input(dtype=input_dtype, device=device), dtype=output_dtype, scale=scale, @@ -1890,17 +1838,9 @@ class TestAdjustBrightness: def test_kernel(self, kernel, make_input, dtype, device): check_kernel(kernel, make_input(dtype=dtype, device=device), brightness_factor=self._DEFAULT_BRIGHTNESS_FACTOR) - @pytest.mark.parametrize( - ("kernel", "make_input"), - [ - (F.adjust_brightness_image_tensor, make_image_tensor), - (F.adjust_brightness_image_pil, make_image_pil), - (F.adjust_brightness_image_tensor, make_image), - (F.adjust_brightness_video, make_video), - ], - ) - def test_dispatcher(self, kernel, make_input): - check_dispatcher(F.adjust_brightness, kernel, make_input(), brightness_factor=self._DEFAULT_BRIGHTNESS_FACTOR) + @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video]) + def test_dispatcher(self, make_input): + check_dispatcher(F.adjust_brightness, make_input(), brightness_factor=self._DEFAULT_BRIGHTNESS_FACTOR) @pytest.mark.parametrize( ("kernel", "input_type"),