Skip to content

Commit

Permalink
Expose transforms.v2 utils for writing custom transforms. (pytorch#…
Browse files Browse the repository at this point in the history
…8670)

Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
  • Loading branch information
venkatram-dev and NicolasHug committed Oct 3, 2024
1 parent 2d8a288 commit d163c95
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 0 deletions.
47 changes: 47 additions & 0 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -6169,3 +6169,50 @@ def test_transform_sequence_len_error(self, quality):
def test_transform_invalid_quality_error(self, quality):
with pytest.raises(ValueError, match="quality must be an integer from 1 to 100"):
transforms.JPEG(quality=quality)


class TestUtils:
# TODO: Still need to test has_all, has_any, check_type and get_bouding_boxes
@pytest.mark.parametrize(
"make_input1", [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask]
)
@pytest.mark.parametrize(
"make_input2", [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask]
)
@pytest.mark.parametrize("query", [transforms.query_size, transforms.query_chw])
def test_query_size_and_query_chw(self, make_input1, make_input2, query):
size = (32, 64)
input1 = make_input1(size)
input2 = make_input2(size)

if query is transforms.query_chw and not any(
transforms.check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video))
for inpt in (input1, input2)
):
return

expected = size if query is transforms.query_size else ((3,) + size)
assert query([input1, input2]) == expected

@pytest.mark.parametrize(
"make_input1", [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask]
)
@pytest.mark.parametrize(
"make_input2", [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask]
)
@pytest.mark.parametrize("query", [transforms.query_size, transforms.query_chw])
def test_different_sizes(self, make_input1, make_input2, query):
input1 = make_input1((10, 10))
input2 = make_input2((20, 20))
if query is transforms.query_chw and not all(
transforms.check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video))
for inpt in (input1, input2)
):
return
with pytest.raises(ValueError, match="Found multiple"):
query([input1, input2])

@pytest.mark.parametrize("query", [transforms.query_size, transforms.query_chw])
def test_no_valid_input(self, query):
with pytest.raises(TypeError, match="No image"):
query(["blah"])
1 change: 1 addition & 0 deletions torchvision/transforms/v2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,5 +55,6 @@
)
from ._temporal import UniformTemporalSubsample
from ._type_conversion import PILToTensor, ToImage, ToPILImage, ToPureTensor
from ._utils import check_type, get_bounding_boxes, has_all, has_any, query_chw, query_size

from ._deprecated import ToTensor # usort: skip

0 comments on commit d163c95

Please sign in to comment.