
Commit 0e2240c

Fixing more tests.
datumbox committed Oct 10, 2022
1 parent 07e7e25 commit 0e2240c
Showing 6 changed files with 123 additions and 123 deletions.
50 changes: 25 additions & 25 deletions test/prototype_common_utils.py
@@ -184,13 +184,13 @@ def load(self, device="cpu"):
return args, kwargs


DEFAULT_SQUARE_IMAGE_SIZE = 15
DEFAULT_LANDSCAPE_IMAGE_SIZE = (7, 33)
DEFAULT_PORTRAIT_IMAGE_SIZE = (31, 9)
DEFAULT_IMAGE_SIZES = (DEFAULT_LANDSCAPE_IMAGE_SIZE, DEFAULT_PORTRAIT_IMAGE_SIZE, DEFAULT_SQUARE_IMAGE_SIZE, "random")
DEFAULT_SQUARE_SPATIAL_SIZE = 15
DEFAULT_LANDSCAPE_SPATIAL_SIZE = (7, 33)
DEFAULT_PORTRAIT_SPATIAL_SIZE = (31, 9)
DEFAULT_SPATIAL_SIZES = (DEFAULT_LANDSCAPE_SPATIAL_SIZE, DEFAULT_PORTRAIT_SPATIAL_SIZE, DEFAULT_SQUARE_SPATIAL_SIZE, "random")


def _parse_image_size(size, *, name="size"):
def _parse_spatial_size(size, *, name="size"):
if size == "random":
return tuple(torch.randint(15, 33, (2,)).tolist())
elif isinstance(size, int) and size > 0:
@@ -246,11 +246,11 @@ def load(self, device):
@dataclasses.dataclass
class ImageLoader(TensorLoader):
color_space: features.ColorSpace
image_size: Tuple[int, int] = dataclasses.field(init=False)
spatial_size: Tuple[int, int] = dataclasses.field(init=False)
num_channels: int = dataclasses.field(init=False)

def __post_init__(self):
self.image_size = self.shape[-2:]
self.spatial_size = self.shape[-2:]
self.num_channels = self.shape[-3]


@@ -277,7 +277,7 @@ def make_image_loader(
dtype=torch.float32,
constant_alpha=True,
):
size = _parse_image_size(size)
size = _parse_spatial_size(size)
num_channels = get_num_channels(color_space)

def fn(shape, dtype, device):
@@ -295,7 +295,7 @@ def fn(shape, dtype, device):

def make_image_loaders(
*,
sizes=DEFAULT_IMAGE_SIZES,
sizes=DEFAULT_SPATIAL_SIZES,
color_spaces=(
features.ColorSpace.GRAY,
features.ColorSpace.GRAY_ALPHA,
@@ -316,7 +316,7 @@ def make_image_loaders(
@dataclasses.dataclass
class BoundingBoxLoader(TensorLoader):
format: features.BoundingBoxFormat
image_size: Tuple[int, int]
spatial_size: Tuple[int, int]


def randint_with_tensor_bounds(arg1, arg2=None, **kwargs):
@@ -331,7 +331,7 @@ def randint_with_tensor_bounds(arg1, arg2=None, **kwargs):
).reshape(low.shape)


def make_bounding_box_loader(*, extra_dims=(), format, image_size="random", dtype=torch.float32):
def make_bounding_box_loader(*, extra_dims=(), format, spatial_size="random", dtype=torch.float32):
if isinstance(format, str):
format = features.BoundingBoxFormat[format]
if format not in {
@@ -341,7 +341,7 @@ def make_bounding_box_loader(*, extra_dims=(), format, image_size="random", dtype=torch.float32):
}:
raise pytest.UsageError(f"Can't make bounding box in format {format}")

image_size = _parse_image_size(image_size, name="image_size")
spatial_size = _parse_spatial_size(spatial_size, name="spatial_size")

def fn(shape, dtype, device):
*extra_dims, num_coordinates = shape
@@ -350,10 +350,10 @@ def fn(shape, dtype, device):

if any(dim == 0 for dim in extra_dims):
return features.BoundingBox(
torch.empty(*extra_dims, 4, dtype=dtype, device=device), format=format, spatial_size=image_size
torch.empty(*extra_dims, 4, dtype=dtype, device=device), format=format, spatial_size=spatial_size
)

height, width = image_size
height, width = spatial_size

if format == features.BoundingBoxFormat.XYXY:
x1 = torch.randint(0, width // 2, extra_dims)
@@ -375,10 +375,10 @@ def fn(shape, dtype, device):
parts = (cx, cy, w, h)

return features.BoundingBox(
torch.stack(parts, dim=-1).to(dtype=dtype, device=device), format=format, spatial_size=image_size
torch.stack(parts, dim=-1).to(dtype=dtype, device=device), format=format, spatial_size=spatial_size
)

return BoundingBoxLoader(fn, shape=(*extra_dims, 4), dtype=dtype, format=format, image_size=image_size)
return BoundingBoxLoader(fn, shape=(*extra_dims, 4), dtype=dtype, format=format, spatial_size=spatial_size)


make_bounding_box = from_loader(make_bounding_box_loader)
@@ -388,11 +388,11 @@ def make_bounding_box_loaders(
*,
extra_dims=DEFAULT_EXTRA_DIMS,
formats=tuple(features.BoundingBoxFormat),
image_size="random",
spatial_size="random",
dtypes=(torch.float32, torch.int64),
):
for params in combinations_grid(extra_dims=extra_dims, format=formats, dtype=dtypes):
yield make_bounding_box_loader(**params, image_size=image_size)
yield make_bounding_box_loader(**params, spatial_size=spatial_size)


make_bounding_boxes = from_loaders(make_bounding_box_loaders)
@@ -475,7 +475,7 @@ class MaskLoader(TensorLoader):

def make_detection_mask_loader(size="random", *, num_objects="random", extra_dims=(), dtype=torch.uint8):
# This produces "detection" masks, i.e. `(*, N, H, W)`, where `N` denotes the number of objects
size = _parse_image_size(size)
size = _parse_spatial_size(size)
num_objects = int(torch.randint(1, 11, ())) if num_objects == "random" else num_objects

def fn(shape, dtype, device):
@@ -489,7 +489,7 @@ def fn(shape, dtype, device):


def make_detection_mask_loaders(
sizes=DEFAULT_IMAGE_SIZES,
sizes=DEFAULT_SPATIAL_SIZES,
num_objects=(1, 0, "random"),
extra_dims=DEFAULT_EXTRA_DIMS,
dtypes=(torch.uint8,),
@@ -503,7 +503,7 @@ def make_detection_mask_loaders(

def make_segmentation_mask_loader(size="random", *, num_categories="random", extra_dims=(), dtype=torch.uint8):
# This produces "segmentation" masks, i.e. `(*, H, W)`, where the category is encoded in the values
size = _parse_image_size(size)
size = _parse_spatial_size(size)
num_categories = int(torch.randint(1, 11, ())) if num_categories == "random" else num_categories

def fn(shape, dtype, device):
@@ -518,7 +518,7 @@ def fn(shape, dtype, device):

def make_segmentation_mask_loaders(
*,
sizes=DEFAULT_IMAGE_SIZES,
sizes=DEFAULT_SPATIAL_SIZES,
num_categories=(1, 2, "random"),
extra_dims=DEFAULT_EXTRA_DIMS,
dtypes=(torch.uint8,),
@@ -532,7 +532,7 @@ def make_segmentation_mask_loaders(

def make_mask_loaders(
*,
sizes=DEFAULT_IMAGE_SIZES,
sizes=DEFAULT_SPATIAL_SIZES,
num_objects=(1, 0, "random"),
num_categories=(1, 2, "random"),
extra_dims=DEFAULT_EXTRA_DIMS,
@@ -559,7 +559,7 @@ def make_video_loader(
extra_dims=(),
dtype=torch.uint8,
):
size = _parse_image_size(size)
size = _parse_spatial_size(size)
num_frames = int(torch.randint(1, 5, ())) if num_frames == "random" else num_frames

def fn(shape, dtype, device):
@@ -576,7 +576,7 @@ def fn(shape, dtype, device):

def make_video_loaders(
*,
sizes=DEFAULT_IMAGE_SIZES,
sizes=DEFAULT_SPATIAL_SIZES,
color_spaces=(
features.ColorSpace.GRAY,
features.ColorSpace.RGB,
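
For orientation, a minimal usage sketch of the renamed helpers in test/prototype_common_utils.py (illustrative only, not part of the commit; the import path and the concrete sizes are assumptions, while the keyword and attribute names follow the diff above):

    # Sketch only -- assumes the helpers keep the signatures shown above and that
    # the module is importable as `prototype_common_utils` from within test/.
    from torchvision.prototype import features
    from prototype_common_utils import (
        DEFAULT_SPATIAL_SIZES,
        make_bounding_box_loader,
        make_image_loaders,
    )

    # Bounding box loaders now take `spatial_size` instead of `image_size`;
    # a (height, width) tuple is assumed to be passed through unchanged.
    bbox_loader = make_bounding_box_loader(
        format=features.BoundingBoxFormat.XYXY,
        spatial_size=(31, 9),
    )
    print(bbox_loader.spatial_size)  # (31, 9)

    # Image loaders expose the matching `spatial_size` attribute, derived in
    # __post_init__ from the last two shape dimensions.
    for image_loader in make_image_loaders(sizes=DEFAULT_SPATIAL_SIZES):
        assert image_loader.spatial_size == image_loader.shape[-2:]
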
48 changes: 24 additions & 24 deletions test/prototype_transforms_kernel_infos.py
@@ -145,7 +145,7 @@ def sample_inputs_horizontal_flip_bounding_box():
formats=[features.BoundingBoxFormat.XYXY], dtypes=[torch.float32]
):
yield ArgsKwargs(
bounding_box_loader, format=bounding_box_loader.format, image_size=bounding_box_loader.spatial_size
bounding_box_loader, format=bounding_box_loader.format, spatial_size=bounding_box_loader.spatial_size
)


@@ -185,9 +185,9 @@ def sample_inputs_horizontal_flip_video():
)


def _get_resize_sizes(image_size):
height, width = image_size
length = max(image_size)
def _get_resize_sizes(spatial_size):
height, width = spatial_size
length = max(spatial_size)
yield length
yield [length]
yield (length,)
@@ -252,7 +252,7 @@ def reference_inputs_resize_image_tensor():
def sample_inputs_resize_bounding_box():
for bounding_box_loader in make_bounding_box_loaders():
for size in _get_resize_sizes(bounding_box_loader.spatial_size):
yield ArgsKwargs(bounding_box_loader, size=size, image_size=bounding_box_loader.spatial_size)
yield ArgsKwargs(bounding_box_loader, size=size, spatial_size=bounding_box_loader.spatial_size)


def sample_inputs_resize_mask():
@@ -394,7 +394,7 @@ def sample_inputs_affine_bounding_box():
yield ArgsKwargs(
bounding_box_loader,
format=bounding_box_loader.format,
image_size=bounding_box_loader.spatial_size,
spatial_size=bounding_box_loader.spatial_size,
**affine_params,
)

@@ -422,9 +422,9 @@ def _compute_affine_matrix(angle, translate, scale, shear, center):
return true_matrix


def reference_affine_bounding_box(bounding_box, *, format, image_size, angle, translate, scale, shear, center=None):
def reference_affine_bounding_box(bounding_box, *, format, spatial_size, angle, translate, scale, shear, center=None):
if center is None:
center = [s * 0.5 for s in image_size[::-1]]
center = [s * 0.5 for s in spatial_size[::-1]]

def transform(bbox):
affine_matrix = _compute_affine_matrix(angle, translate, scale, shear, center)
@@ -473,7 +473,7 @@ def reference_inputs_affine_bounding_box():
yield ArgsKwargs(
bounding_box_loader,
format=bounding_box_loader.format,
image_size=bounding_box_loader.spatial_size,
spatial_size=bounding_box_loader.spatial_size,
**affine_kwargs,
)

@@ -650,7 +650,7 @@ def sample_inputs_vertical_flip_bounding_box():
formats=[features.BoundingBoxFormat.XYXY], dtypes=[torch.float32]
):
yield ArgsKwargs(
bounding_box_loader, format=bounding_box_loader.format, image_size=bounding_box_loader.spatial_size
bounding_box_loader, format=bounding_box_loader.format, spatial_size=bounding_box_loader.spatial_size
)


@@ -729,7 +729,7 @@ def sample_inputs_rotate_bounding_box():
yield ArgsKwargs(
bounding_box_loader,
format=bounding_box_loader.format,
image_size=bounding_box_loader.spatial_size,
spatial_size=bounding_box_loader.spatial_size,
angle=_ROTATE_ANGLES[0],
)

@@ -1001,7 +1001,7 @@ def sample_inputs_pad_bounding_box():
yield ArgsKwargs(
bounding_box_loader,
format=bounding_box_loader.format,
image_size=bounding_box_loader.spatial_size,
spatial_size=bounding_box_loader.spatial_size,
padding=padding,
padding_mode="constant",
)
@@ -1131,8 +1131,8 @@ def sample_inputs_perspective_video():
)


def _get_elastic_displacement(image_size):
return torch.rand(1, *image_size, 2)
def _get_elastic_displacement(spatial_size):
return torch.rand(1, *spatial_size, 2)


def sample_inputs_elastic_image_tensor():
@@ -1212,7 +1212,7 @@ def sample_inputs_elastic_video():
)


_CENTER_CROP_IMAGE_SIZES = [(16, 16), (7, 33), (31, 9)]
_CENTER_CROP_SPATIAL_SIZES = [(16, 16), (7, 33), (31, 9)]
_CENTER_CROP_OUTPUT_SIZES = [[4, 3], [42, 70], [4], 3, (5, 2), (6,)]


@@ -1231,7 +1231,7 @@ def sample_inputs_center_crop_image_tensor():

def reference_inputs_center_crop_image_tensor():
for image_loader, output_size in itertools.product(
make_image_loaders(sizes=_CENTER_CROP_IMAGE_SIZES, extra_dims=[()]), _CENTER_CROP_OUTPUT_SIZES
make_image_loaders(sizes=_CENTER_CROP_SPATIAL_SIZES, extra_dims=[()]), _CENTER_CROP_OUTPUT_SIZES
):
yield ArgsKwargs(image_loader, output_size=output_size)

@@ -1241,7 +1241,7 @@ def sample_inputs_center_crop_bounding_box():
yield ArgsKwargs(
bounding_box_loader,
format=bounding_box_loader.format,
image_size=bounding_box_loader.spatial_size,
spatial_size=bounding_box_loader.spatial_size,
output_size=output_size,
)

@@ -1254,7 +1254,7 @@ def sample_inputs_center_crop_mask():

def reference_inputs_center_crop_mask():
for mask_loader, output_size in itertools.product(
make_mask_loaders(sizes=_CENTER_CROP_IMAGE_SIZES, extra_dims=[()], num_objects=[1]), _CENTER_CROP_OUTPUT_SIZES
make_mask_loaders(sizes=_CENTER_CROP_SPATIAL_SIZES, extra_dims=[()], num_objects=[1]), _CENTER_CROP_OUTPUT_SIZES
):
yield ArgsKwargs(mask_loader, output_size=output_size)

@@ -1820,7 +1820,7 @@ def sample_inputs_adjust_saturation_video():
def sample_inputs_clamp_bounding_box():
for bounding_box_loader in make_bounding_box_loaders():
yield ArgsKwargs(
bounding_box_loader, format=bounding_box_loader.format, image_size=bounding_box_loader.spatial_size
bounding_box_loader, format=bounding_box_loader.format, spatial_size=bounding_box_loader.spatial_size
)


@@ -1834,7 +1834,7 @@ def sample_inputs_clamp_bounding_box():
_FIVE_TEN_CROP_SIZES = [7, (6,), [5], (6, 5), [7, 6]]


def _get_five_ten_crop_image_size(size):
def _get_five_ten_crop_spatial_size(size):
if isinstance(size, int):
crop_height = crop_width = size
elif len(size) == 1:
@@ -1847,28 +1847,28 @@ def _get_five_ten_crop_image_size(size):
def sample_inputs_five_crop_image_tensor():
for size in _FIVE_TEN_CROP_SIZES:
for image_loader in make_image_loaders(
sizes=[_get_five_ten_crop_image_size(size)], color_spaces=[features.ColorSpace.RGB], dtypes=[torch.float32]
sizes=[_get_five_ten_crop_spatial_size(size)], color_spaces=[features.ColorSpace.RGB], dtypes=[torch.float32]
):
yield ArgsKwargs(image_loader, size=size)


def reference_inputs_five_crop_image_tensor():
for size in _FIVE_TEN_CROP_SIZES:
for image_loader in make_image_loaders(sizes=[_get_five_ten_crop_image_size(size)], extra_dims=[()]):
for image_loader in make_image_loaders(sizes=[_get_five_ten_crop_spatial_size(size)], extra_dims=[()]):
yield ArgsKwargs(image_loader, size=size)


def sample_inputs_ten_crop_image_tensor():
for size, vertical_flip in itertools.product(_FIVE_TEN_CROP_SIZES, [False, True]):
for image_loader in make_image_loaders(
sizes=[_get_five_ten_crop_image_size(size)], color_spaces=[features.ColorSpace.RGB], dtypes=[torch.float32]
sizes=[_get_five_ten_crop_spatial_size(size)], color_spaces=[features.ColorSpace.RGB], dtypes=[torch.float32]
):
yield ArgsKwargs(image_loader, size=size, vertical_flip=vertical_flip)


def reference_inputs_ten_crop_image_tensor():
for size, vertical_flip in itertools.product(_FIVE_TEN_CROP_SIZES, [False, True]):
for image_loader in make_image_loaders(sizes=[_get_five_ten_crop_image_size(size)], extra_dims=[()]):
for image_loader in make_image_loaders(sizes=[_get_five_ten_crop_spatial_size(size)], extra_dims=[()]):
yield ArgsKwargs(image_loader, size=size, vertical_flip=vertical_flip)


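As a small worked example of the renamed parameter in reference_affine_bounding_box (from the hunk above), spatial_size is a (height, width) pair and the default center is built in (x, y) order by reversing it:

    # Default-center computation as shown in the diff:
    # center = [s * 0.5 for s in spatial_size[::-1]]
    spatial_size = (31, 9)  # (height, width), e.g. DEFAULT_PORTRAIT_SPATIAL_SIZE
    center = [s * 0.5 for s in spatial_size[::-1]]
    print(center)  # [4.5, 15.5] -> x = width / 2, y = height / 2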
