[fbsync] remove spatial_size (#7734)
Reviewed By: matteobettini

Differential Revision: D48642265

fbshipit-source-id: 123d2a3157d4536ea9ac25e0192d54307b31ea1e
NicolasHug authored and facebook-github-bot committed Aug 25, 2023
1 parent ae428c4 commit 7804725
Showing 29 changed files with 440 additions and 491 deletions.
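In short, the commit renames the `spatial_size` argument/attribute of `datapoints.BoundingBoxes` to `canvas_size` and the functional helper `F.get_spatial_size` to `F.get_size`. A minimal sketch of the change, assuming a torchvision build from around this commit where `torchvision.datapoints` and `transforms.v2` are available (the image tensor below is made up for illustration):

```python
import torch
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F

image = datapoints.Image(torch.rand(3, 480, 640))  # hypothetical image

# Before this commit:
#   datapoints.BoundingBoxes(..., spatial_size=F.get_spatial_size(image))
# After this commit:
boxes = datapoints.BoundingBoxes(
    [[17, 16, 344, 479]],
    format=datapoints.BoundingBoxFormat.XYXY,
    canvas_size=F.get_size(image),  # (H, W) of the reference image
)
print(boxes.canvas_size)
```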
6 changes: 3 additions & 3 deletions gallery/plot_datapoints.py
@@ -80,7 +80,7 @@
# corresponding image alongside the actual values:

bounding_box = datapoints.BoundingBoxes(
[17, 16, 344, 495], format=datapoints.BoundingBoxFormat.XYXY, spatial_size=image.shape[-2:]
[17, 16, 344, 495], format=datapoints.BoundingBoxFormat.XYXY, canvas_size=image.shape[-2:]
)
print(bounding_box)

@@ -108,7 +108,7 @@ def __getitem__(self, item):
target["boxes"] = datapoints.BoundingBoxes(
boxes,
format=datapoints.BoundingBoxFormat.XYXY,
spatial_size=F.get_spatial_size(img),
canvas_size=F.get_size(img),
)
target["labels"] = labels
target["masks"] = datapoints.Mask(masks)
@@ -129,7 +129,7 @@ def __call__(self, img, target):
target["boxes"] = datapoints.BoundingBoxes(
target["boxes"],
format=datapoints.BoundingBoxFormat.XYXY,
spatial_size=F.get_spatial_size(img),
canvas_size=F.get_size(img),
)
target["masks"] = datapoints.Mask(target["masks"])
return img, target
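The dataset wrappers above follow the same pattern: targets are re-wrapped so the boxes carry the canvas size of the image they belong to. A self-contained sketch of that wrapper; the class name and the dummy data are illustrative, not part of the gallery:

```python
import torch
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F


class WrapTargetIntoDatapoints:
    """Illustrative transform: wrap plain-tensor targets into datapoints."""

    def __call__(self, img, target):
        target["boxes"] = datapoints.BoundingBoxes(
            target["boxes"],
            format=datapoints.BoundingBoxFormat.XYXY,
            canvas_size=F.get_size(img),  # taken from the image, not inferred from the boxes
        )
        target["masks"] = datapoints.Mask(target["masks"])
        return img, target


# Usage with made-up data:
img = torch.rand(3, 224, 224)
target = {
    "boxes": torch.tensor([[10.0, 10.0, 50.0, 60.0]]),
    "masks": torch.zeros(1, 224, 224, dtype=torch.bool),
}
img, target = WrapTargetIntoDatapoints()(img, target)
```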
2 changes: 1 addition & 1 deletion gallery/plot_transforms_v2.py
@@ -30,7 +30,7 @@ def load_data():
masks = datapoints.Mask(merged_masks == labels.view(-1, 1, 1))

bounding_boxes = datapoints.BoundingBoxes(
masks_to_boxes(masks), format=datapoints.BoundingBoxFormat.XYXY, spatial_size=image.shape[-2:]
masks_to_boxes(masks), format=datapoints.BoundingBoxFormat.XYXY, canvas_size=image.shape[-2:]
)

return path, image, bounding_boxes, masks, labels
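In the transforms-v2 gallery, the boxes are derived from instance masks and then wrapped with the masks' own height/width as the canvas. A small sketch with toy masks; the mask contents are invented, only the wiring mirrors the code above:

```python
import torch
from torchvision import datapoints
from torchvision.ops import masks_to_boxes

# Two toy 8x8 instance masks.
m = torch.zeros(2, 8, 8, dtype=torch.bool)
m[0, 1:4, 2:6] = True
m[1, 5:8, 0:3] = True
masks = datapoints.Mask(m)

bounding_boxes = datapoints.BoundingBoxes(
    masks_to_boxes(masks),  # (N, 4) XYXY boxes, one per mask
    format=datapoints.BoundingBoxFormat.XYXY,
    canvas_size=masks.shape[-2:],  # the masks' (H, W) becomes the boxes' canvas
)
```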
86 changes: 26 additions & 60 deletions test/common_utils.py
@@ -412,7 +412,7 @@ def load(self, device="cpu"):
)


def _parse_spatial_size(size, *, name="size"):
def _parse_canvas_size(size, *, name="size"):
if size == "random":
raise ValueError("This should never happen")
elif isinstance(size, int) and size > 0:
@@ -467,12 +467,13 @@ def load(self, device):

@dataclasses.dataclass
class ImageLoader(TensorLoader):
spatial_size: Tuple[int, int] = dataclasses.field(init=False)
canvas_size: Tuple[int, int] = dataclasses.field(init=False)
num_channels: int = dataclasses.field(init=False)
memory_format: torch.memory_format = torch.contiguous_format
canvas_size: Tuple[int, int] = dataclasses.field(init=False)

def __post_init__(self):
self.spatial_size = self.shape[-2:]
self.canvas_size = self.canvas_size = self.shape[-2:]
self.num_channels = self.shape[-3]

def load(self, device):
@@ -538,7 +539,7 @@ def make_image_loader(
):
if not constant_alpha:
raise ValueError("This should never happen")
size = _parse_spatial_size(size)
size = _parse_canvas_size(size)
num_channels = get_num_channels(color_space)

def fn(shape, dtype, device, memory_format):
@@ -578,7 +579,7 @@ def make_image_loaders(
def make_image_loader_for_interpolation(
size=(233, 147), *, color_space="RGB", dtype=torch.uint8, memory_format=torch.contiguous_format
):
size = _parse_spatial_size(size)
size = _parse_canvas_size(size)
num_channels = get_num_channels(color_space)

def fn(shape, dtype, device, memory_format):
@@ -623,43 +624,20 @@ def make_image_loaders_for_interpolation(
class BoundingBoxesLoader(TensorLoader):
format: datapoints.BoundingBoxFormat
spatial_size: Tuple[int, int]
canvas_size: Tuple[int, int] = dataclasses.field(init=False)

def __post_init__(self):
self.canvas_size = self.spatial_size


def make_bounding_box(
size=None,
canvas_size=DEFAULT_SIZE,
*,
format=datapoints.BoundingBoxFormat.XYXY,
spatial_size=None,
batch_dims=(),
dtype=None,
device="cpu",
):
"""
size: Size of the actual bounding box, i.e.
- (box[3] - box[1], box[2] - box[0]) for XYXY
- (H, W) for XYWH and CXCYWH
spatial_size: Size of the reference object, e.g. an image. Corresponds to the .spatial_size attribute on
returned datapoints.BoundingBoxes
To generate a valid joint sample, you need to set spatial_size here to the same value as size on the other maker
functions, e.g.
.. code::
image = make_image=(size=size)
bounding_boxes = make_bounding_box(spatial_size=size)
assert F.get_spatial_size(bounding_boxes) == F.get_spatial_size(image)
For convenience, if both size and spatial_size are omitted, spatial_size defaults to the same value as size for all
other maker functions, e.g.
.. code::
image = make_image=()
bounding_boxes = make_bounding_box()
assert F.get_spatial_size(bounding_boxes) == F.get_spatial_size(image)
"""

def sample_position(values, max_value):
# We cannot use torch.randint directly here, because it only allows integer scalars as values for low and high.
# However, if we have batch_dims, we need tensors as limits.
@@ -668,28 +646,16 @@ def sample_position(values, max_value):
if isinstance(format, str):
format = datapoints.BoundingBoxFormat[format]

if spatial_size is None:
if size is None:
spatial_size = DEFAULT_SIZE
else:
height, width = size
height_margin, width_margin = torch.randint(10, (2,)).tolist()
spatial_size = (height + height_margin, width + width_margin)

dtype = dtype or torch.float32

if any(dim == 0 for dim in batch_dims):
return datapoints.BoundingBoxes(
torch.empty(*batch_dims, 4, dtype=dtype, device=device), format=format, spatial_size=spatial_size
torch.empty(*batch_dims, 4, dtype=dtype, device=device), format=format, canvas_size=canvas_size
)

if size is None:
h, w = [torch.randint(1, s, batch_dims) for s in spatial_size]
else:
h, w = [torch.full(batch_dims, s, dtype=torch.int) for s in size]

y = sample_position(h, spatial_size[0])
x = sample_position(w, spatial_size[1])
h, w = [torch.randint(1, c, batch_dims) for c in canvas_size]
y = sample_position(h, canvas_size[0])
x = sample_position(w, canvas_size[1])

if format is datapoints.BoundingBoxFormat.XYWH:
parts = (x, y, w, h)
@@ -706,37 +672,37 @@ def sample_position(values, max_value):
raise ValueError(f"Format {format} is not supported")

return datapoints.BoundingBoxes(
torch.stack(parts, dim=-1).to(dtype=dtype, device=device), format=format, spatial_size=spatial_size
torch.stack(parts, dim=-1).to(dtype=dtype, device=device), format=format, canvas_size=canvas_size
)


def make_bounding_box_loader(*, extra_dims=(), format, spatial_size=DEFAULT_PORTRAIT_SPATIAL_SIZE, dtype=torch.float32):
def make_bounding_box_loader(*, extra_dims=(), format, canvas_size=DEFAULT_PORTRAIT_SPATIAL_SIZE, dtype=torch.float32):
if isinstance(format, str):
format = datapoints.BoundingBoxFormat[format]

spatial_size = _parse_spatial_size(spatial_size, name="spatial_size")
canvas_size = _parse_canvas_size(canvas_size, name="canvas_size")

def fn(shape, dtype, device):
*batch_dims, num_coordinates = shape
if num_coordinates != 4:
raise pytest.UsageError()

return make_bounding_box(
format=format, spatial_size=spatial_size, batch_dims=batch_dims, dtype=dtype, device=device
format=format, canvas_size=canvas_size, batch_dims=batch_dims, dtype=dtype, device=device
)

return BoundingBoxesLoader(fn, shape=(*extra_dims, 4), dtype=dtype, format=format, spatial_size=spatial_size)
return BoundingBoxesLoader(fn, shape=(*extra_dims, 4), dtype=dtype, format=format, spatial_size=canvas_size)


def make_bounding_box_loaders(
*,
extra_dims=DEFAULT_EXTRA_DIMS,
formats=tuple(datapoints.BoundingBoxFormat),
spatial_size=DEFAULT_PORTRAIT_SPATIAL_SIZE,
canvas_size=DEFAULT_PORTRAIT_SPATIAL_SIZE,
dtypes=(torch.float32, torch.float64, torch.int64),
):
for params in combinations_grid(extra_dims=extra_dims, format=formats, dtype=dtypes):
yield make_bounding_box_loader(**params, spatial_size=spatial_size)
yield make_bounding_box_loader(**params, canvas_size=canvas_size)


make_bounding_boxes = from_loaders(make_bounding_box_loaders)
@@ -761,7 +727,7 @@ def make_detection_mask(size=DEFAULT_SIZE, *, num_objects=5, batch_dims=(), dtyp

def make_detection_mask_loader(size=DEFAULT_PORTRAIT_SPATIAL_SIZE, *, num_objects=5, extra_dims=(), dtype=torch.uint8):
# This produces "detection" masks, i.e. `(*, N, H, W)`, where `N` denotes the number of objects
size = _parse_spatial_size(size)
size = _parse_canvas_size(size)

def fn(shape, dtype, device):
*batch_dims, num_objects, height, width = shape
@@ -802,15 +768,15 @@ def make_segmentation_mask_loader(
size=DEFAULT_PORTRAIT_SPATIAL_SIZE, *, num_categories=10, extra_dims=(), dtype=torch.uint8
):
# This produces "segmentation" masks, i.e. `(*, H, W)`, where the category is encoded in the values
spatial_size = _parse_spatial_size(size)
canvas_size = _parse_canvas_size(size)

def fn(shape, dtype, device):
*batch_dims, height, width = shape
return make_segmentation_mask(
(height, width), num_categories=num_categories, batch_dims=batch_dims, dtype=dtype, device=device
)

return MaskLoader(fn, shape=(*extra_dims, *spatial_size), dtype=dtype)
return MaskLoader(fn, shape=(*extra_dims, *canvas_size), dtype=dtype)


def make_segmentation_mask_loaders(
Expand Down Expand Up @@ -860,7 +826,7 @@ def make_video_loader(
extra_dims=(),
dtype=torch.uint8,
):
size = _parse_spatial_size(size)
size = _parse_canvas_size(size)

def fn(shape, dtype, device, memory_format):
*batch_dims, num_frames, _, height, width = shape
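With the `size` parameter and its docstring gone, `make_bounding_box` now only samples random boxes that fit inside a given `canvas_size`. A self-contained sketch of that sampling for XYXY output; the helper name and the position-sampling details are simplified stand-ins, not the exact test code:

```python
import torch
from torchvision import datapoints


def make_random_boxes(canvas_size=(32, 32), batch_dims=(4,), dtype=torch.float32, device="cpu"):
    H, W = canvas_size
    # Box heights/widths strictly smaller than the canvas...
    h = torch.randint(1, H, batch_dims)
    w = torch.randint(1, W, batch_dims)
    # ...and top-left corners chosen so every box stays on the canvas.
    y = (torch.rand(batch_dims) * (H - h)).floor().to(torch.int64)
    x = (torch.rand(batch_dims) * (W - w)).floor().to(torch.int64)
    parts = (x, y, x + w, y + h)  # XYXY
    return datapoints.BoundingBoxes(
        torch.stack(parts, dim=-1).to(dtype=dtype, device=device),
        format=datapoints.BoundingBoxFormat.XYXY,
        canvas_size=canvas_size,
    )


boxes = make_random_boxes()
assert boxes.shape == (4, 4) and boxes.canvas_size == (32, 32)
```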
4 changes: 2 additions & 2 deletions test/test_datapoints.py
@@ -27,7 +27,7 @@ def test_mask_instance(data):
"format", ["XYXY", "CXCYWH", datapoints.BoundingBoxFormat.XYXY, datapoints.BoundingBoxFormat.XYWH]
)
def test_bbox_instance(data, format):
bboxes = datapoints.BoundingBoxes(data, format=format, spatial_size=(32, 32))
bboxes = datapoints.BoundingBoxes(data, format=format, canvas_size=(32, 32))
assert isinstance(bboxes, torch.Tensor)
assert bboxes.ndim == 2 and bboxes.shape[1] == 4
if isinstance(format, str):
@@ -164,7 +164,7 @@ def test_wrap_like():
[
datapoints.Image(torch.rand(3, 16, 16)),
datapoints.Video(torch.rand(2, 3, 16, 16)),
datapoints.BoundingBoxes([0.0, 1.0, 2.0, 3.0], format=datapoints.BoundingBoxFormat.XYXY, spatial_size=(10, 10)),
datapoints.BoundingBoxes([0.0, 1.0, 2.0, 3.0], format=datapoints.BoundingBoxFormat.XYXY, canvas_size=(10, 10)),
datapoints.Mask(torch.randint(0, 256, (16, 16), dtype=torch.uint8)),
],
)
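The datapoint tests above only change the keyword; the resulting object is still a tensor subclass whose metadata now lives on `.canvas_size`. A quick sketch of what `test_bbox_instance` exercises; the normalization of string formats to the enum is my reading of the test, not spelled out in this hunk:

```python
import torch
from torchvision import datapoints

data = torch.randint(0, 32, size=(5, 4), dtype=torch.int64)
bboxes = datapoints.BoundingBoxes(data, format="XYXY", canvas_size=(32, 32))

assert isinstance(bboxes, torch.Tensor)  # BoundingBoxes is a Tensor subclass
assert bboxes.ndim == 2 and bboxes.shape[1] == 4
assert bboxes.format is datapoints.BoundingBoxFormat.XYXY  # assumed: strings map to the enum
assert bboxes.canvas_size == (32, 32)  # renamed from .spatial_size
```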
34 changes: 17 additions & 17 deletions test/test_prototype_transforms.py
@@ -164,7 +164,7 @@ def test__copy_paste(self, label_type):
labels = torch.nn.functional.one_hot(labels, num_classes=5)
target = {
"boxes": BoundingBoxes(
torch.tensor([[2.0, 3.0, 8.0, 9.0], [20.0, 20.0, 30.0, 30.0]]), format="XYXY", spatial_size=(32, 32)
torch.tensor([[2.0, 3.0, 8.0, 9.0], [20.0, 20.0, 30.0, 30.0]]), format="XYXY", canvas_size=(32, 32)
),
"masks": Mask(masks),
"labels": label_type(labels),
@@ -179,7 +179,7 @@ def test__copy_paste(self, label_type):
paste_labels = torch.nn.functional.one_hot(paste_labels, num_classes=5)
paste_target = {
"boxes": BoundingBoxes(
torch.tensor([[12.0, 13.0, 19.0, 18.0], [1.0, 15.0, 8.0, 19.0]]), format="XYXY", spatial_size=(32, 32)
torch.tensor([[12.0, 13.0, 19.0, 18.0], [1.0, 15.0, 8.0, 19.0]]), format="XYXY", canvas_size=(32, 32)
),
"masks": Mask(paste_masks),
"labels": label_type(paste_labels),
@@ -210,13 +210,13 @@ class TestFixedSizeCrop:
def test__get_params(self, mocker):
crop_size = (7, 7)
batch_shape = (10,)
spatial_size = (11, 5)
canvas_size = (11, 5)

transform = transforms.FixedSizeCrop(size=crop_size)

flat_inputs = [
make_image(size=spatial_size, color_space="RGB"),
make_bounding_box(format=BoundingBoxFormat.XYXY, spatial_size=spatial_size, batch_dims=batch_shape),
make_image(size=canvas_size, color_space="RGB"),
make_bounding_box(format=BoundingBoxFormat.XYXY, canvas_size=canvas_size, batch_dims=batch_shape),
]
params = transform._get_params(flat_inputs)

@@ -295,7 +295,7 @@ def test__transform(self, mocker, needs):

def test__transform_culling(self, mocker):
batch_size = 10
spatial_size = (10, 10)
canvas_size = (10, 10)

is_valid = torch.randint(0, 2, (batch_size,), dtype=torch.bool)
mocker.patch(
@@ -304,17 +304,17 @@ def test__transform_culling(self, mocker):
needs_crop=True,
top=0,
left=0,
height=spatial_size[0],
width=spatial_size[1],
height=canvas_size[0],
width=canvas_size[1],
is_valid=is_valid,
needs_pad=False,
),
)

bounding_boxes = make_bounding_box(
format=BoundingBoxFormat.XYXY, spatial_size=spatial_size, batch_dims=(batch_size,)
format=BoundingBoxFormat.XYXY, canvas_size=canvas_size, batch_dims=(batch_size,)
)
masks = make_detection_mask(size=spatial_size, batch_dims=(batch_size,))
masks = make_detection_mask(size=canvas_size, batch_dims=(batch_size,))
labels = make_label(extra_dims=(batch_size,))

transform = transforms.FixedSizeCrop((-1, -1))
@@ -334,23 +334,23 @@

def test__transform_bounding_boxes_clamping(self, mocker):
batch_size = 3
spatial_size = (10, 10)
canvas_size = (10, 10)

mocker.patch(
"torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params",
return_value=dict(
needs_crop=True,
top=0,
left=0,
height=spatial_size[0],
width=spatial_size[1],
height=canvas_size[0],
width=canvas_size[1],
is_valid=torch.full((batch_size,), fill_value=True),
needs_pad=False,
),
)

bounding_boxes = make_bounding_box(
format=BoundingBoxFormat.XYXY, spatial_size=spatial_size, batch_dims=(batch_size,)
format=BoundingBoxFormat.XYXY, canvas_size=canvas_size, batch_dims=(batch_size,)
)
mock = mocker.patch("torchvision.prototype.transforms._geometry.F.clamp_bounding_boxes")

@@ -496,7 +496,7 @@ def make_datapoints():

pil_image = to_image_pil(make_image(size=size, color_space="RGB"))
target = {
"boxes": make_bounding_box(spatial_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
"boxes": make_bounding_box(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
"labels": make_label(extra_dims=(num_objects,), categories=80),
"masks": make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long),
}
@@ -505,7 +505,7 @@ def make_datapoints():

tensor_image = torch.Tensor(make_image(size=size, color_space="RGB"))
target = {
"boxes": make_bounding_box(spatial_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
"boxes": make_bounding_box(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
"labels": make_label(extra_dims=(num_objects,), categories=80),
"masks": make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long),
}
@@ -514,7 +514,7 @@ def make_datapoints():

datapoint_image = make_image(size=size, color_space="RGB")
target = {
"boxes": make_bounding_box(spatial_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
"boxes": make_bounding_box(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
"labels": make_label(extra_dims=(num_objects,), categories=80),
"masks": make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long),
}
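`make_datapoints` builds the same target dictionary for PIL, tensor, and datapoint images; only the boxes' keyword changed. A stand-alone sketch of such a sample; the helper name, sizes, and dummy values are illustrative, not the test's own code:

```python
import torch
from torchvision import datapoints


def make_detection_sample(size=(600, 800), num_objects=4):
    H, W = size
    image = datapoints.Image(torch.rand(3, H, W))
    target = {
        "boxes": datapoints.BoundingBoxes(
            torch.tensor([[0.0, 0.0, 10.0, 10.0]] * num_objects),
            format=datapoints.BoundingBoxFormat.XYXY,
            canvas_size=size,  # formerly spatial_size
        ),
        "labels": torch.randint(0, 80, (num_objects,)),
        "masks": datapoints.Mask(torch.zeros(num_objects, H, W, dtype=torch.long)),
    }
    return image, target


image, target = make_detection_sample()
```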