Skip to content

Commit

Permalink
add transform noop test
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier committed Jun 28, 2023
1 parent 5dd4e53 commit 95c80d2
Showing 1 changed file with 15 additions and 1 deletion.
16 changes: 15 additions & 1 deletion test/test_transforms_v2_refactored.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,7 +804,7 @@ def test_dispatcher_signature(self, kernel, input_type):
def test_transform(self, input_type, device):
input = self._make_input(input_type, device=device)

check_transform(transforms.RandomHorizontalFlip, input)
check_transform(transforms.RandomHorizontalFlip, input, p=1)

@pytest.mark.parametrize(
"fn", [F.horizontal_flip, transform_cls_to_functional(transforms.RandomHorizontalFlip, p=1)]
Expand Down Expand Up @@ -846,3 +846,17 @@ def test_bounding_box_correctness(self, format, fn):
expected = self._reference_horizontal_flip_bounding_box(bounding_box)

torch.testing.assert_close(actual, expected)

@pytest.mark.parametrize(
"input_type",
[torch.Tensor, PIL.Image.Image, datapoints.Image, datapoints.BoundingBox, datapoints.Mask, datapoints.Video],
)
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_transform_noop(self, input_type, device):
input = self._make_input(input_type, device=device)

transform = transforms.RandomHorizontalFlip(p=0)

output = transform(input)

assert_equal(output, input)

0 comments on commit 95c80d2

Please sign in to comment.