Skip to content

Commit 96e475f

Browse files
datumboxfacebook-github-bot
authored andcommitted
[fbsync] extract common utils for prototype transform tests (#6552)
Reviewed By: YosuaMichael Differential Revision: D39426992 fbshipit-source-id: 3cb21ec8f82b5ccbd5184dbbed62fd07c84f65cd
1 parent 065c970 commit 96e475f

File tree

5 files changed

+246
-243
lines changed

5 files changed

+246
-243
lines changed

test/prototype_common_utils.py

Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
import functools
2+
import itertools
3+
4+
import PIL.Image
5+
import pytest
6+
7+
import torch
8+
import torch.testing
9+
from torch.nn.functional import one_hot
10+
from torch.testing._comparison import assert_equal as _assert_equal, TensorLikePair
11+
from torchvision.prototype import features
12+
from torchvision.prototype.transforms.functional import to_image_tensor
13+
from torchvision.transforms.functional_tensor import _max_value as get_max_value
14+
15+
16+
class ImagePair(TensorLikePair):
17+
def _process_inputs(self, actual, expected, *, id, allow_subclasses):
18+
return super()._process_inputs(
19+
*[to_image_tensor(input) if isinstance(input, PIL.Image.Image) else input for input in [actual, expected]],
20+
id=id,
21+
allow_subclasses=allow_subclasses,
22+
)
23+
24+
25+
assert_equal = functools.partial(_assert_equal, pair_types=[ImagePair], rtol=0, atol=0)
26+
27+
28+
class ArgsKwargs:
29+
def __init__(self, *args, **kwargs):
30+
self.args = args
31+
self.kwargs = kwargs
32+
33+
def __iter__(self):
34+
yield self.args
35+
yield self.kwargs
36+
37+
def __str__(self):
38+
def short_repr(obj, max=20):
39+
repr_ = repr(obj)
40+
if len(repr_) <= max:
41+
return repr_
42+
43+
return f"{repr_[:max//2]}...{repr_[-(max//2-3):]}"
44+
45+
return ", ".join(
46+
itertools.chain(
47+
[short_repr(arg) for arg in self.args],
48+
[f"{param}={short_repr(kwarg)}" for param, kwarg in self.kwargs.items()],
49+
)
50+
)
51+
52+
53+
make_tensor = functools.partial(torch.testing.make_tensor, device="cpu")
54+
55+
56+
def make_image(size=None, *, color_space, extra_dims=(), dtype=torch.float32, constant_alpha=True):
57+
size = size or torch.randint(16, 33, (2,)).tolist()
58+
59+
try:
60+
num_channels = {
61+
features.ColorSpace.GRAY: 1,
62+
features.ColorSpace.GRAY_ALPHA: 2,
63+
features.ColorSpace.RGB: 3,
64+
features.ColorSpace.RGB_ALPHA: 4,
65+
}[color_space]
66+
except KeyError as error:
67+
raise pytest.UsageError() from error
68+
69+
shape = (*extra_dims, num_channels, *size)
70+
max_value = get_max_value(dtype)
71+
data = make_tensor(shape, low=0, high=max_value, dtype=dtype)
72+
if color_space in {features.ColorSpace.GRAY_ALPHA, features.ColorSpace.RGB_ALPHA} and constant_alpha:
73+
data[..., -1, :, :] = max_value
74+
return features.Image(data, color_space=color_space)
75+
76+
77+
make_grayscale_image = functools.partial(make_image, color_space=features.ColorSpace.GRAY)
78+
make_rgb_image = functools.partial(make_image, color_space=features.ColorSpace.RGB)
79+
80+
81+
def make_images(
82+
sizes=((16, 16), (7, 33), (31, 9)),
83+
color_spaces=(
84+
features.ColorSpace.GRAY,
85+
features.ColorSpace.GRAY_ALPHA,
86+
features.ColorSpace.RGB,
87+
features.ColorSpace.RGB_ALPHA,
88+
),
89+
dtypes=(torch.float32, torch.uint8),
90+
extra_dims=((), (0,), (4,), (2, 3), (5, 0), (0, 5)),
91+
):
92+
for size, color_space, dtype in itertools.product(sizes, color_spaces, dtypes):
93+
yield make_image(size, color_space=color_space, dtype=dtype)
94+
95+
for color_space, dtype, extra_dims_ in itertools.product(color_spaces, dtypes, extra_dims):
96+
yield make_image(size=sizes[0], color_space=color_space, extra_dims=extra_dims_, dtype=dtype)
97+
98+
99+
def randint_with_tensor_bounds(arg1, arg2=None, **kwargs):
100+
low, high = torch.broadcast_tensors(
101+
*[torch.as_tensor(arg) for arg in ((0, arg1) if arg2 is None else (arg1, arg2))]
102+
)
103+
return torch.stack(
104+
[
105+
torch.randint(low_scalar, high_scalar, (), **kwargs)
106+
for low_scalar, high_scalar in zip(low.flatten().tolist(), high.flatten().tolist())
107+
]
108+
).reshape(low.shape)
109+
110+
111+
def make_bounding_box(*, format, image_size=(32, 32), extra_dims=(), dtype=torch.int64):
112+
if isinstance(format, str):
113+
format = features.BoundingBoxFormat[format]
114+
115+
if any(dim == 0 for dim in extra_dims):
116+
return features.BoundingBox(torch.empty(*extra_dims, 4), format=format, image_size=image_size)
117+
118+
height, width = image_size
119+
120+
if format == features.BoundingBoxFormat.XYXY:
121+
x1 = torch.randint(0, width // 2, extra_dims)
122+
y1 = torch.randint(0, height // 2, extra_dims)
123+
x2 = randint_with_tensor_bounds(x1 + 1, width - x1) + x1
124+
y2 = randint_with_tensor_bounds(y1 + 1, height - y1) + y1
125+
parts = (x1, y1, x2, y2)
126+
elif format == features.BoundingBoxFormat.XYWH:
127+
x = torch.randint(0, width // 2, extra_dims)
128+
y = torch.randint(0, height // 2, extra_dims)
129+
w = randint_with_tensor_bounds(1, width - x)
130+
h = randint_with_tensor_bounds(1, height - y)
131+
parts = (x, y, w, h)
132+
elif format == features.BoundingBoxFormat.CXCYWH:
133+
cx = torch.randint(1, width - 1, ())
134+
cy = torch.randint(1, height - 1, ())
135+
w = randint_with_tensor_bounds(1, torch.minimum(cx, width - cx) + 1)
136+
h = randint_with_tensor_bounds(1, torch.minimum(cy, height - cy) + 1)
137+
parts = (cx, cy, w, h)
138+
else:
139+
raise pytest.UsageError()
140+
141+
return features.BoundingBox(torch.stack(parts, dim=-1).to(dtype), format=format, image_size=image_size)
142+
143+
144+
make_xyxy_bounding_box = functools.partial(make_bounding_box, format=features.BoundingBoxFormat.XYXY)
145+
146+
147+
def make_bounding_boxes(
148+
formats=(features.BoundingBoxFormat.XYXY, features.BoundingBoxFormat.XYWH, features.BoundingBoxFormat.CXCYWH),
149+
image_sizes=((32, 32),),
150+
dtypes=(torch.int64, torch.float32),
151+
extra_dims=((0,), (), (4,), (2, 3), (5, 0), (0, 5)),
152+
):
153+
for format, image_size, dtype in itertools.product(formats, image_sizes, dtypes):
154+
yield make_bounding_box(format=format, image_size=image_size, dtype=dtype)
155+
156+
for format, extra_dims_ in itertools.product(formats, extra_dims):
157+
yield make_bounding_box(format=format, extra_dims=extra_dims_)
158+
159+
160+
def make_label(size=(), *, categories=("category0", "category1")):
161+
return features.Label(torch.randint(0, len(categories) if categories else 10, size), categories=categories)
162+
163+
164+
def make_one_hot_label(*args, **kwargs):
165+
label = make_label(*args, **kwargs)
166+
return features.OneHotLabel(one_hot(label, num_classes=len(label.categories)), categories=label.categories)
167+
168+
169+
def make_one_hot_labels(
170+
*,
171+
num_categories=(1, 2, 10),
172+
extra_dims=((), (0,), (4,), (2, 3), (5, 0), (0, 5)),
173+
):
174+
for num_categories_ in num_categories:
175+
yield make_one_hot_label(categories=[f"category{idx}" for idx in range(num_categories_)])
176+
177+
for extra_dims_ in extra_dims:
178+
yield make_one_hot_label(extra_dims_)
179+
180+
181+
def make_segmentation_mask(size=None, *, num_objects=None, extra_dims=(), dtype=torch.uint8):
182+
size = size if size is not None else torch.randint(16, 33, (2,)).tolist()
183+
num_objects = num_objects if num_objects is not None else int(torch.randint(1, 11, ()))
184+
shape = (*extra_dims, num_objects, *size)
185+
data = make_tensor(shape, low=0, high=2, dtype=dtype)
186+
return features.SegmentationMask(data)
187+
188+
189+
def make_segmentation_masks(
190+
sizes=((16, 16), (7, 33), (31, 9)),
191+
dtypes=(torch.uint8,),
192+
extra_dims=((), (0,), (4,), (2, 3), (5, 0), (0, 5)),
193+
num_objects=(1, 0, 10),
194+
):
195+
for size, dtype, extra_dims_ in itertools.product(sizes, dtypes, extra_dims):
196+
yield make_segmentation_mask(size=size, dtype=dtype, extra_dims=extra_dims_)
197+
198+
for dtype, extra_dims_, num_objects_ in itertools.product(dtypes, extra_dims, num_objects):
199+
yield make_segmentation_mask(size=sizes[0], num_objects=num_objects_, dtype=dtype, extra_dims=extra_dims_)

test/test_prototype_transforms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import pytest
88
import torch
99
from common_utils import assert_equal, cpu_and_gpu
10-
from test_prototype_transforms_functional import (
10+
from prototype_common_utils import (
1111
make_bounding_box,
1212
make_bounding_boxes,
1313
make_image,

test/test_prototype_transforms_consistency.py

Lines changed: 2 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,61 +1,21 @@
11
import enum
2-
import functools
32
import inspect
4-
import itertools
53

64
import numpy as np
75
import PIL.Image
86
import pytest
97

108
import torch
11-
from test_prototype_transforms_functional import make_images
12-
from torch.testing._comparison import assert_equal as _assert_equal, TensorLikePair
9+
from prototype_common_utils import ArgsKwargs, assert_equal, make_images
1310
from torchvision import transforms as legacy_transforms
1411
from torchvision._utils import sequence_to_str
1512
from torchvision.prototype import features, transforms as prototype_transforms
16-
from torchvision.prototype.transforms.functional import to_image_pil, to_image_tensor
17-
18-
19-
class ImagePair(TensorLikePair):
20-
def _process_inputs(self, actual, expected, *, id, allow_subclasses):
21-
return super()._process_inputs(
22-
*[to_image_tensor(input) if isinstance(input, PIL.Image.Image) else input for input in [actual, expected]],
23-
id=id,
24-
allow_subclasses=allow_subclasses,
25-
)
26-
27-
28-
assert_equal = functools.partial(_assert_equal, pair_types=[ImagePair], rtol=0, atol=0)
13+
from torchvision.prototype.transforms.functional import to_image_pil
2914

3015

3116
DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=[features.ColorSpace.RGB], extra_dims=[(4,)])
3217

3318

34-
class ArgsKwargs:
35-
def __init__(self, *args, **kwargs):
36-
self.args = args
37-
self.kwargs = kwargs
38-
39-
def __iter__(self):
40-
yield self.args
41-
yield self.kwargs
42-
43-
def __str__(self):
44-
def short_repr(obj, max=20):
45-
repr_ = repr(obj)
46-
if len(repr_) <= max:
47-
return repr_
48-
49-
return f"{repr_[:max//2]}...{repr_[-(max//2-3):]}"
50-
51-
return ", ".join(
52-
itertools.chain(
53-
[short_repr(arg) for arg in self.args],
54-
[f"{param}={short_repr(kwarg)}" for param, kwarg in self.kwargs.items()],
55-
)
56-
)
57-
58-
5919
class ConsistencyConfig:
6020
def __init__(
6121
self,

0 commit comments

Comments
 (0)