cleanup spatial_size -> canvas_size #7783

Merged · 4 commits · Aug 3, 2023
test/common_utils.py (30 changes: 15 additions & 15 deletions)
@@ -423,7 +423,7 @@ def load(self, device="cpu"):
 )


-def _parse_size(size, *, name="size"):
+def _parse_canvas_size(size, *, name="size"):
     if size == "random":
         raise ValueError("This should never happen")
     elif isinstance(size, int) and size > 0:
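For context, the renamed helper is tiny; the hunk cuts off before its tail, so in the sketch below the sequence branch and the error message are assumptions rather than the exact torchvision code:

```python
import pytest


# Sketch of the renamed helper; only the first branches are visible in the
# hunk above, so the tuple branch and the error message are assumed here.
def _parse_canvas_size(size, *, name="size"):
    if size == "random":
        raise ValueError("This should never happen")
    elif isinstance(size, int) and size > 0:
        # a bare positive int is treated as a square canvas
        return (size, size)
    elif isinstance(size, (list, tuple)) and len(size) == 2:
        return tuple(size)
    else:
        raise pytest.UsageError(f"'{name}' can be a positive int or a pair of ints, got {size!r}")
```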
@@ -478,13 +478,13 @@ def load(self, device):

 @dataclasses.dataclass
 class ImageLoader(TensorLoader):
-    spatial_size: Tuple[int, int] = dataclasses.field(init=False)
+    canvas_size: Tuple[int, int] = dataclasses.field(init=False)
     num_channels: int = dataclasses.field(init=False)
     memory_format: torch.memory_format = torch.contiguous_format
-    canvas_size: Tuple[int, int] = dataclasses.field(init=False)

     def __post_init__(self):
-        self.spatial_size = self.canvas_size = self.shape[-2:]
+        self.canvas_size = self.shape[-2:]
         self.num_channels = self.shape[-3]

     def load(self, device):
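The effect of the ImageLoader change is easiest to see in isolation. A self-contained sketch (the base class and unrelated fields are simplified stand-ins here) of how `__post_init__` now derives a single `canvas_size` from the trailing `(H, W)` of `shape`:

```python
import dataclasses
from typing import Tuple

import torch


@dataclasses.dataclass
class ImageLoader:
    shape: Tuple[int, ...]
    canvas_size: Tuple[int, int] = dataclasses.field(init=False)
    num_channels: int = dataclasses.field(init=False)
    memory_format: torch.memory_format = torch.contiguous_format

    def __post_init__(self):
        # shape is (*batch, C, H, W): the canvas is the last two dimensions
        self.canvas_size = self.shape[-2:]
        self.num_channels = self.shape[-3]


loader = ImageLoader(shape=(3, 480, 640))
assert loader.canvas_size == (480, 640)
assert loader.num_channels == 3
```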
@@ -550,7 +550,7 @@ def make_image_loader(
 ):
     if not constant_alpha:
         raise ValueError("This should never happen")
-    size = _parse_size(size)
+    size = _parse_canvas_size(size)
     num_channels = get_num_channels(color_space)

     def fn(shape, dtype, device, memory_format):
@@ -590,7 +590,7 @@ def make_image_loaders(
 def make_image_loader_for_interpolation(
     size=(233, 147), *, color_space="RGB", dtype=torch.uint8, memory_format=torch.contiguous_format
 ):
-    size = _parse_size(size)
+    size = _parse_canvas_size(size)
     num_channels = get_num_channels(color_space)

     def fn(shape, dtype, device, memory_format):
@@ -687,33 +687,33 @@ def sample_position(values, max_value):
 )


-def make_bounding_box_loader(*, extra_dims=(), format, spatial_size=DEFAULT_PORTRAIT_SPATIAL_SIZE, dtype=torch.float32):
+def make_bounding_box_loader(*, extra_dims=(), format, canvas_size=DEFAULT_PORTRAIT_SPATIAL_SIZE, dtype=torch.float32):
     if isinstance(format, str):
         format = datapoints.BoundingBoxFormat[format]

-    spatial_size = _parse_size(spatial_size, name="canvas_size")
+    canvas_size = _parse_canvas_size(canvas_size, name="canvas_size")

     def fn(shape, dtype, device):
         *batch_dims, num_coordinates = shape
         if num_coordinates != 4:
             raise pytest.UsageError()

         return make_bounding_box(
-            format=format, canvas_size=spatial_size, batch_dims=batch_dims, dtype=dtype, device=device
+            format=format, canvas_size=canvas_size, batch_dims=batch_dims, dtype=dtype, device=device
         )

-    return BoundingBoxesLoader(fn, shape=(*extra_dims, 4), dtype=dtype, format=format, spatial_size=spatial_size)
+    return BoundingBoxesLoader(fn, shape=(*extra_dims, 4), dtype=dtype, format=format, spatial_size=canvas_size)


 def make_bounding_box_loaders(
     *,
     extra_dims=DEFAULT_EXTRA_DIMS,
     formats=tuple(datapoints.BoundingBoxFormat),
-    spatial_size=DEFAULT_PORTRAIT_SPATIAL_SIZE,
+    canvas_size=DEFAULT_PORTRAIT_SPATIAL_SIZE,
     dtypes=(torch.float32, torch.float64, torch.int64),
 ):
     for params in combinations_grid(extra_dims=extra_dims, format=formats, dtype=dtypes):
-        yield make_bounding_box_loader(**params, spatial_size=spatial_size)
+        yield make_bounding_box_loader(**params, canvas_size=canvas_size)


 make_bounding_boxes = from_loaders(make_bounding_box_loaders)
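Note that the `spatial_size=` keyword passed to `BoundingBoxesLoader` is deliberately left alone, presumably because that loader class still exposes the old attribute name. The loaders above are driven by `combinations_grid`, which expands keyword tuples into one kwargs dict per combination; a minimal re-implementation, assuming the usual cartesian-product semantics (the real helper lives elsewhere in common_utils.py):

```python
import itertools


def combinations_grid(**kwargs):
    # one dict per element of the cartesian product of all value tuples
    return [dict(zip(kwargs.keys(), values)) for values in itertools.product(*kwargs.values())]


grid = combinations_grid(format=("XYXY", "CXCYWH"), dtype=("float32", "int64"))
assert len(grid) == 4
assert grid[0] == {"format": "XYXY", "dtype": "float32"}
```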
@@ -738,7 +738,7 @@ def make_detection_mask(size=DEFAULT_SIZE, *, num_objects=5, batch_dims=(), dtyp

 def make_detection_mask_loader(size=DEFAULT_PORTRAIT_SPATIAL_SIZE, *, num_objects=5, extra_dims=(), dtype=torch.uint8):
     # This produces "detection" masks, i.e. `(*, N, H, W)`, where `N` denotes the number of objects
-    size = _parse_size(size)
+    size = _parse_canvas_size(size)

     def fn(shape, dtype, device):
         *batch_dims, num_objects, height, width = shape
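The comment in this hunk pins down the detection-mask layout; a quick shape check of that convention in plain torch, without the test helpers:

```python
import torch

# "detection" masks: one binary (H, W) mask per object, stacked along N
num_objects, height, width = 5, 32, 32
masks = torch.randint(0, 2, (num_objects, height, width), dtype=torch.uint8)
assert masks.shape == (num_objects, height, width)  # (N, H, W)
```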
@@ -779,15 +779,15 @@ def make_segmentation_mask_loader(
     size=DEFAULT_PORTRAIT_SPATIAL_SIZE, *, num_categories=10, extra_dims=(), dtype=torch.uint8
 ):
     # This produces "segmentation" masks, i.e. `(*, H, W)`, where the category is encoded in the values
-    size = _parse_size(size)
+    canvas_size = _parse_canvas_size(size)

     def fn(shape, dtype, device):
         *batch_dims, height, width = shape
         return make_segmentation_mask(
             (height, width), num_categories=num_categories, batch_dims=batch_dims, dtype=dtype, device=device
         )

-    return MaskLoader(fn, shape=(*extra_dims, *size), dtype=dtype)
+    return MaskLoader(fn, shape=(*extra_dims, *canvas_size), dtype=dtype)


 def make_segmentation_mask_loaders(
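By contrast, the segmentation layout encodes the category in the pixel values, so no object dimension is needed; again a plain-torch check of the convention stated in the comment:

```python
import torch

# "segmentation" masks: a single (H, W) map whose values are category ids
num_categories, height, width = 10, 32, 32
mask = torch.randint(0, num_categories, (height, width), dtype=torch.uint8)
assert mask.shape == (height, width)
assert int(mask.max()) < num_categories
```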
@@ -841,7 +841,7 @@ def make_video_loader(
     extra_dims=(),
     dtype=torch.uint8,
 ):
-    size = _parse_size(size)
+    size = _parse_canvas_size(size)

     def fn(shape, dtype, device, memory_format):
         *batch_dims, num_frames, _, height, width = shape
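The unpacking in the video hunk implies the `(*batch, T, C, H, W)` convention, so the canvas is again the trailing two dimensions; for example:

```python
import torch

num_frames, channels, height, width = 4, 3, 24, 32
video = torch.zeros(2, num_frames, channels, height, width, dtype=torch.uint8)
assert video.shape[-2:] == (height, width)  # the (H, W) canvas
```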
test/test_transforms_v2_functional.py (2 changes: 1 addition & 1 deletion)
@@ -884,7 +884,7 @@ def _compute_expected_bbox(bbox, pcoeffs_):
     pcoeffs = _get_perspective_coeffs(startpoints, endpoints)
     inv_pcoeffs = _get_perspective_coeffs(endpoints, startpoints)

-    for bboxes in make_bounding_boxes(spatial_size=canvas_size, extra_dims=((4,),)):
+    for bboxes in make_bounding_boxes(canvas_size=canvas_size, extra_dims=((4,),)):
         bboxes = bboxes.to(device)

         output_bboxes = F.perspective_bounding_boxes(
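At call sites like this one, only the keyword changes. As a generic aside (this PR is a hard rename, not a deprecation), a staged migration of such a keyword could accept both names for one release; `make_example` below is a hypothetical stand-in, not a torchvision function:

```python
import warnings


def make_example(*, canvas_size=None, spatial_size=None):
    # hypothetical shim: accept the old keyword, warn, and forward it
    if spatial_size is not None:
        warnings.warn("spatial_size is deprecated; use canvas_size", DeprecationWarning)
        canvas_size = spatial_size
    return canvas_size or (32, 32)


assert make_example(canvas_size=(16, 16)) == (16, 16)
```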
Expand Down