Skip to content

Commit

Permalink
cleanup v2 tests (#7812)
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier authored Aug 9, 2023
1 parent 5d8d61a commit 6b02079
Showing 1 changed file with 25 additions and 85 deletions.
110 changes: 25 additions & 85 deletions test/test_transforms_v2_refactored.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,15 +173,7 @@ def _check_dispatcher_scripted_smoke(dispatcher, input, *args, **kwargs):
dispatcher_scripted(input.as_subclass(torch.Tensor), *args, **kwargs)


def check_dispatcher(
dispatcher,
# TODO: remove this parameter
kernel,
input,
*args,
check_scripted_smoke=True,
**kwargs,
):
def check_dispatcher(dispatcher, input, *args, check_scripted_smoke=True, **kwargs):
unknown_input = object()
with pytest.raises(TypeError, match=re.escape(str(type(unknown_input)))):
dispatcher(unknown_input, *args, **kwargs)
Expand Down Expand Up @@ -516,20 +508,12 @@ def test_kernel_video(self):

@pytest.mark.parametrize("size", OUTPUT_SIZES)
@pytest.mark.parametrize(
("kernel", "make_input"),
[
(F.resize_image_tensor, make_image_tensor),
(F.resize_image_pil, make_image_pil),
(F.resize_image_tensor, make_image),
(F.resize_bounding_boxes, make_bounding_box),
(F.resize_mask, make_segmentation_mask),
(F.resize_video, make_video),
],
"make_input",
[make_image_tensor, make_image_pil, make_image, make_bounding_box, make_segmentation_mask, make_video],
)
def test_dispatcher(self, size, kernel, make_input):
def test_dispatcher(self, size, make_input):
check_dispatcher(
F.resize,
kernel,
make_input(self.INPUT_SIZE),
size=size,
antialias=True,
Expand Down Expand Up @@ -805,18 +789,11 @@ def test_kernel_video(self):
check_kernel(F.horizontal_flip_video, make_video())

@pytest.mark.parametrize(
("kernel", "make_input"),
[
(F.horizontal_flip_image_tensor, make_image_tensor),
(F.horizontal_flip_image_pil, make_image_pil),
(F.horizontal_flip_image_tensor, make_image),
(F.horizontal_flip_bounding_boxes, make_bounding_box),
(F.horizontal_flip_mask, make_segmentation_mask),
(F.horizontal_flip_video, make_video),
],
"make_input",
[make_image_tensor, make_image_pil, make_image, make_bounding_box, make_segmentation_mask, make_video],
)
def test_dispatcher(self, kernel, make_input):
check_dispatcher(F.horizontal_flip, kernel, make_input())
def test_dispatcher(self, make_input):
check_dispatcher(F.horizontal_flip, make_input())

@pytest.mark.parametrize(
("kernel", "input_type"),
Expand Down Expand Up @@ -988,18 +965,11 @@ def test_kernel_video(self):
self._check_kernel(F.affine_video, make_video())

@pytest.mark.parametrize(
("kernel", "make_input"),
[
(F.affine_image_tensor, make_image_tensor),
(F.affine_image_pil, make_image_pil),
(F.affine_image_tensor, make_image),
(F.affine_bounding_boxes, make_bounding_box),
(F.affine_mask, make_segmentation_mask),
(F.affine_video, make_video),
],
"make_input",
[make_image_tensor, make_image_pil, make_image, make_bounding_box, make_segmentation_mask, make_video],
)
def test_dispatcher(self, kernel, make_input):
check_dispatcher(F.affine, kernel, make_input(), **self._MINIMAL_AFFINE_KWARGS)
def test_dispatcher(self, make_input):
check_dispatcher(F.affine, make_input(), **self._MINIMAL_AFFINE_KWARGS)

@pytest.mark.parametrize(
("kernel", "input_type"),
Expand Down Expand Up @@ -1284,18 +1254,11 @@ def test_kernel_video(self):
check_kernel(F.vertical_flip_video, make_video())

@pytest.mark.parametrize(
("kernel", "make_input"),
[
(F.vertical_flip_image_tensor, make_image_tensor),
(F.vertical_flip_image_pil, make_image_pil),
(F.vertical_flip_image_tensor, make_image),
(F.vertical_flip_bounding_boxes, make_bounding_box),
(F.vertical_flip_mask, make_segmentation_mask),
(F.vertical_flip_video, make_video),
],
"make_input",
[make_image_tensor, make_image_pil, make_image, make_bounding_box, make_segmentation_mask, make_video],
)
def test_dispatcher(self, kernel, make_input):
check_dispatcher(F.vertical_flip, kernel, make_input())
def test_dispatcher(self, make_input):
check_dispatcher(F.vertical_flip, make_input())

@pytest.mark.parametrize(
("kernel", "input_type"),
Expand Down Expand Up @@ -1441,18 +1404,11 @@ def test_kernel_video(self):
check_kernel(F.rotate_video, make_video(), **self._MINIMAL_AFFINE_KWARGS)

@pytest.mark.parametrize(
("kernel", "make_input"),
[
(F.rotate_image_tensor, make_image_tensor),
(F.rotate_image_pil, make_image_pil),
(F.rotate_image_tensor, make_image),
(F.rotate_bounding_boxes, make_bounding_box),
(F.rotate_mask, make_segmentation_mask),
(F.rotate_video, make_video),
],
"make_input",
[make_image_tensor, make_image_pil, make_image, make_bounding_box, make_segmentation_mask, make_video],
)
def test_dispatcher(self, kernel, make_input):
check_dispatcher(F.rotate, kernel, make_input(), **self._MINIMAL_AFFINE_KWARGS)
def test_dispatcher(self, make_input):
check_dispatcher(F.rotate, make_input(), **self._MINIMAL_AFFINE_KWARGS)

@pytest.mark.parametrize(
("kernel", "input_type"),
Expand Down Expand Up @@ -1711,22 +1667,14 @@ def test_kernel(self, kernel, make_input, input_dtype, output_dtype, device, sca
scale=scale,
)

@pytest.mark.parametrize(
("kernel", "make_input"),
[
(F.to_dtype_image_tensor, make_image_tensor),
(F.to_dtype_image_tensor, make_image),
(F.to_dtype_video, make_video),
],
)
@pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_video])
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8])
@pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8])
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("scale", (True, False))
def test_dispatcher(self, kernel, make_input, input_dtype, output_dtype, device, scale):
def test_dispatcher(self, make_input, input_dtype, output_dtype, device, scale):
check_dispatcher(
F.to_dtype,
kernel,
make_input(dtype=input_dtype, device=device),
dtype=output_dtype,
scale=scale,
Expand Down Expand Up @@ -1890,17 +1838,9 @@ class TestAdjustBrightness:
def test_kernel(self, kernel, make_input, dtype, device):
check_kernel(kernel, make_input(dtype=dtype, device=device), brightness_factor=self._DEFAULT_BRIGHTNESS_FACTOR)

@pytest.mark.parametrize(
("kernel", "make_input"),
[
(F.adjust_brightness_image_tensor, make_image_tensor),
(F.adjust_brightness_image_pil, make_image_pil),
(F.adjust_brightness_image_tensor, make_image),
(F.adjust_brightness_video, make_video),
],
)
def test_dispatcher(self, kernel, make_input):
check_dispatcher(F.adjust_brightness, kernel, make_input(), brightness_factor=self._DEFAULT_BRIGHTNESS_FACTOR)
@pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video])
def test_dispatcher(self, make_input):
check_dispatcher(F.adjust_brightness, make_input(), brightness_factor=self._DEFAULT_BRIGHTNESS_FACTOR)

@pytest.mark.parametrize(
("kernel", "input_type"),
Expand Down

0 comments on commit 6b02079

Please sign in to comment.