
Commit 993f693

Merge branch 'main' of github.com:pytorch/vision into cutmix-mixup
2 parents: 5e02675 + 8071c17

23 files changed: +547 -355 lines

docs/source/transforms.rst

Lines changed: 0 additions & 1 deletion
@@ -234,7 +234,6 @@ Conversion
     v2.PILToTensor
     v2.ToImageTensor
     ConvertImageDtype
-    v2.ConvertDtype
     v2.ConvertImageDtype
     v2.ToDtype
     v2.ConvertBoundingBoxFormat
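For context on the entry removed above: v2.ConvertDtype is dropped from the Conversion listing while v2.ConvertImageDtype and v2.ToDtype stay. Below is a minimal sketch of the surviving v2.ToDtype API, hedged on the prototype v2 namespace this commit uses elsewhere (see the presets.py change further down); the tensor shape and values are placeholders.

    import torch
    import torchvision.transforms.v2 as T
    from torchvision import datapoints

    # Example values only: convert image datapoints to float32 (scaled to [0, 1]) and
    # masks to int64, mirroring the dict form used in references/segmentation/presets.py.
    to_dtype = T.ToDtype(dtype={datapoints.Image: torch.float32, datapoints.Mask: torch.int64, "others": None}, scale=True)
    img = datapoints.Image(torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8))
    out = to_dtype(img)  # expected: float32 Image with values in [0, 1]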

gallery/plot_transforms_v2_e2e.py

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ def show(sample):
     image, target = sample
     if isinstance(image, PIL.Image.Image):
         image = F.to_image_tensor(image)
-        image = F.convert_dtype(image, torch.uint8)
+        image = F.to_dtype(image, torch.uint8, scale=True)
     annotated_image = draw_bounding_boxes(image, target["boxes"], colors="yellow", width=3)

     fig, ax = plt.subplots()
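A standalone sketch of the replacement call, assuming the prototype functional namespace the gallery script imports (torchvision.transforms.v2.functional as F); the input image is a placeholder.

    import PIL.Image
    import torch
    from torchvision.transforms.v2 import functional as F

    pil_img = PIL.Image.new("RGB", (64, 64), color="red")  # placeholder input
    image = F.to_image_tensor(pil_img)                      # uint8 image tensor, as in the script
    image = F.to_dtype(image, torch.uint8, scale=True)      # replaces the old F.convert_dtype call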

references/detection/coco_utils.py

Lines changed: 8 additions & 29 deletions
@@ -1,4 +1,3 @@
-import copy
 import os

 import torch
@@ -7,25 +6,6 @@
 import transforms as T
 from pycocotools import mask as coco_mask
 from pycocotools.coco import COCO
-from torchvision.datasets import wrap_dataset_for_transforms_v2
-
-
-class FilterAndRemapCocoCategories:
-    def __init__(self, categories, remap=True):
-        self.categories = categories
-        self.remap = remap
-
-    def __call__(self, image, target):
-        anno = target["annotations"]
-        anno = [obj for obj in anno if obj["category_id"] in self.categories]
-        if not self.remap:
-            target["annotations"] = anno
-            return image, target
-        anno = copy.deepcopy(anno)
-        for obj in anno:
-            obj["category_id"] = self.categories.index(obj["category_id"])
-        target["annotations"] = anno
-        return image, target


 def convert_coco_poly_to_mask(segmentations, height, width):
@@ -219,7 +199,7 @@ def __getitem__(self, idx):
         return img, target


-def get_coco(root, image_set, transforms, mode="instances", use_v2=False):
+def get_coco(root, image_set, transforms, mode="instances", use_v2=False, with_masks=False):
     anno_file_template = "{}_{}2017.json"
     PATHS = {
         "train": ("train2017", os.path.join("annotations", anno_file_template.format(mode, "train"))),
@@ -232,10 +212,15 @@ def get_coco(root, image_set, transforms, mode="instances", use_v2=False):
     ann_file = os.path.join(root, ann_file)

     if use_v2:
+        from torchvision.datasets import wrap_dataset_for_transforms_v2
+
         dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)
-        # TODO: need to update target_keys to handle masks for segmentation!
-        dataset = wrap_dataset_for_transforms_v2(dataset, target_keys={"boxes", "labels", "image_id"})
+        target_keys = ["boxes", "labels", "image_id"]
+        if with_masks:
+            target_keys += ["masks"]
+        dataset = wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys)
     else:
+        # TODO: handle with_masks for V1?
         t = [ConvertCocoPolysToMask()]
         if transforms is not None:
             t.append(transforms)
@@ -249,9 +234,3 @@ def get_coco(root, image_set, transforms, mode="instances", use_v2=False):
     # dataset = torch.utils.data.Subset(dataset, [i for i in range(500)])

     return dataset
-
-
-def get_coco_kp(root, image_set, transforms, use_v2=False):
-    if use_v2:
-        raise ValueError("KeyPoints aren't supported by transforms V2 yet.")
-    return get_coco(root, image_set, transforms, mode="person_keypoints")
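A hedged usage sketch of the updated helper, assuming it is run from references/detection (so coco_utils is importable) and that COCO lives under the placeholder path /data/coco.

    from coco_utils import get_coco

    dataset = get_coco(
        root="/data/coco",           # placeholder path
        image_set="val",
        transforms=None,             # a v2 transform pipeline would normally go here
        mode="instances",
        use_v2=True,
        with_masks=True,             # adds "masks" to target_keys before wrapping
    )
    img, target = dataset[0]
    print(sorted(target.keys()))     # expected: ['boxes', 'image_id', 'labels', 'masks']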

references/detection/train.py

Lines changed: 23 additions & 6 deletions
@@ -28,7 +28,7 @@
 import torchvision.models.detection
 import torchvision.models.detection.mask_rcnn
 import utils
-from coco_utils import get_coco, get_coco_kp
+from coco_utils import get_coco
 from engine import evaluate, train_one_epoch
 from group_by_aspect_ratio import create_aspect_ratio_groups, GroupedBatchSampler
 from torchvision.transforms import InterpolationMode
@@ -42,10 +42,16 @@ def copypaste_collate_fn(batch):

 def get_dataset(is_train, args):
     image_set = "train" if is_train else "val"
-    paths = {"coco": (args.data_path, get_coco, 91), "coco_kp": (args.data_path, get_coco_kp, 2)}
-    p, ds_fn, num_classes = paths[args.dataset]
-
-    ds = ds_fn(p, image_set=image_set, transforms=get_transform(is_train, args), use_v2=args.use_v2)
+    num_classes, mode = {"coco": (91, "instances"), "coco_kp": (2, "person_keypoints")}[args.dataset]
+    with_masks = "mask" in args.model
+    ds = get_coco(
+        root=args.data_path,
+        image_set=image_set,
+        transforms=get_transform(is_train, args),
+        mode=mode,
+        use_v2=args.use_v2,
+        with_masks=with_masks,
+    )
     return ds, num_classes


@@ -68,7 +74,12 @@ def get_args_parser(add_help=True):
     parser = argparse.ArgumentParser(description="PyTorch Detection Training", add_help=add_help)

     parser.add_argument("--data-path", default="/datasets01/COCO/022719/", type=str, help="dataset path")
-    parser.add_argument("--dataset", default="coco", type=str, help="dataset name")
+    parser.add_argument(
+        "--dataset",
+        default="coco",
+        type=str,
+        help="dataset name. Use coco for object detection and instance segmentation and coco_kp for Keypoint detection",
+    )
     parser.add_argument("--model", default="maskrcnn_resnet50_fpn", type=str, help="model name")
     parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
     parser.add_argument(
@@ -171,6 +182,12 @@ def get_args_parser(add_help=True):
 def main(args):
     if args.backend.lower() == "datapoint" and not args.use_v2:
         raise ValueError("Use --use-v2 if you want to use the datapoint backend.")
+    if args.dataset not in ("coco", "coco_kp"):
+        raise ValueError(f"Dataset should be coco or coco_kp, got {args.dataset}")
+    if "keypoint" in args.model and args.dataset != "coco_kp":
+        raise ValueError("Oops, if you want Keypoint detection, set --dataset coco_kp")
+    if args.dataset == "coco_kp" and args.use_v2:
+        raise ValueError("KeyPoint detection doesn't support V2 transforms yet")

     if args.output_dir:
         utils.mkdir(args.output_dir)
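To make the new get_dataset wiring concrete, here is a small runnable sketch that reproduces its argument resolution; the Namespace fields mirror the flags touched in this diff and the data path is a placeholder.

    import argparse

    args = argparse.Namespace(
        data_path="/data/coco",          # placeholder
        dataset="coco",
        model="maskrcnn_resnet50_fpn",
        use_v2=True,
    )

    # Same resolution logic as the rewritten get_dataset():
    num_classes, mode = {"coco": (91, "instances"), "coco_kp": (2, "person_keypoints")}[args.dataset]
    with_masks = "mask" in args.model     # mask models keep instance masks in the targets
    print(num_classes, mode, with_masks)  # 91 instances True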

references/segmentation/coco_utils.py

Lines changed: 15 additions & 9 deletions
@@ -68,11 +68,6 @@ def _has_valid_annotation(anno):
         # if more than 1k pixels occupied in the image
         return sum(obj["area"] for obj in anno) > 1000

-    if not isinstance(dataset, torchvision.datasets.CocoDetection):
-        raise TypeError(
-            f"This function expects dataset of type torchvision.datasets.CocoDetection, instead got {type(dataset)}"
-        )
-
     ids = []
     for ds_idx, img_id in enumerate(dataset.ids):
         ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
@@ -86,21 +81,32 @@ def _has_valid_annotation(anno):
     return dataset


-def get_coco(root, image_set, transforms):
+def get_coco(root, image_set, transforms, use_v2=False):
     PATHS = {
         "train": ("train2017", os.path.join("annotations", "instances_train2017.json")),
         "val": ("val2017", os.path.join("annotations", "instances_val2017.json")),
         # "train": ("val2017", os.path.join("annotations", "instances_val2017.json"))
     }
     CAT_LIST = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4, 1, 64, 20, 63, 7, 72]

-    transforms = Compose([FilterAndRemapCocoCategories(CAT_LIST, remap=True), ConvertCocoPolysToMask(), transforms])
-
     img_folder, ann_file = PATHS[image_set]
     img_folder = os.path.join(root, img_folder)
     ann_file = os.path.join(root, ann_file)

-    dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)
+    # The 2 "Compose" below achieve the same thing: converting coco detection
+    # samples into segmentation-compatible samples. They just do it with
+    # slightly different implementations. We could refactor and unify, but
+    # keeping them separate helps keeping the v2 version clean
+    if use_v2:
+        import v2_extras
+        from torchvision.datasets import wrap_dataset_for_transforms_v2
+
+        transforms = Compose([v2_extras.CocoDetectionToVOCSegmentation(), transforms])
+        dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)
+        dataset = wrap_dataset_for_transforms_v2(dataset, target_keys={"masks", "labels"})
+    else:
+        transforms = Compose([FilterAndRemapCocoCategories(CAT_LIST, remap=True), ConvertCocoPolysToMask(), transforms])
+        dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)

     if image_set == "train":
         dataset = _coco_remove_images_without_annotations(dataset, CAT_LIST)
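A hedged sketch of how the new v2 branch might be exercised, assuming it is run from references/segmentation (so coco_utils, presets and v2_extras are importable), a placeholder COCO path, and the example sizes used by the segmentation references.

    from coco_utils import get_coco
    from presets import SegmentationPresetTrain

    transforms = SegmentationPresetTrain(base_size=520, crop_size=480, backend="pil", use_v2=True)
    dataset = get_coco("/data/coco", image_set="train", transforms=transforms, use_v2=True)
    img, target = dataset[0]  # target should now be a VOC-style segmentation mask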

references/segmentation/presets.py

Lines changed: 90 additions & 23 deletions
@@ -1,39 +1,106 @@
+from collections import defaultdict
+
 import torch
-import transforms as T
+
+
+def get_modules(use_v2):
+    # We need a protected import to avoid the V2 warning in case just V1 is used
+    if use_v2:
+        import torchvision.datapoints
+        import torchvision.transforms.v2
+        import v2_extras
+
+        return torchvision.transforms.v2, torchvision.datapoints, v2_extras
+    else:
+        import transforms
+
+        return transforms, None, None


 class SegmentationPresetTrain:
-    def __init__(self, *, base_size, crop_size, hflip_prob=0.5, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
-        min_size = int(0.5 * base_size)
-        max_size = int(2.0 * base_size)
+    def __init__(
+        self,
+        *,
+        base_size,
+        crop_size,
+        hflip_prob=0.5,
+        mean=(0.485, 0.456, 0.406),
+        std=(0.229, 0.224, 0.225),
+        backend="pil",
+        use_v2=False,
+    ):
+        T, datapoints, v2_extras = get_modules(use_v2)
+
+        transforms = []
+        backend = backend.lower()
+        if backend == "datapoint":
+            transforms.append(T.ToImageTensor())
+        elif backend == "tensor":
+            transforms.append(T.PILToTensor())
+        elif backend != "pil":
+            raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}")
+
+        transforms += [T.RandomResize(min_size=int(0.5 * base_size), max_size=int(2.0 * base_size))]

-        trans = [T.RandomResize(min_size, max_size)]
         if hflip_prob > 0:
-            trans.append(T.RandomHorizontalFlip(hflip_prob))
-        trans.extend(
-            [
-                T.RandomCrop(crop_size),
-                T.PILToTensor(),
-                T.ConvertImageDtype(torch.float),
-                T.Normalize(mean=mean, std=std),
+            transforms += [T.RandomHorizontalFlip(hflip_prob)]
+
+        if use_v2:
+            # We need a custom pad transform here, since the padding we want to perform here is fundamentally
+            # different from the padding in `RandomCrop` if `pad_if_needed=True`.
+            transforms += [v2_extras.PadIfSmaller(crop_size, fill=defaultdict(lambda: 0, {datapoints.Mask: 255}))]
+
+        transforms += [T.RandomCrop(crop_size)]
+
+        if backend == "pil":
+            transforms += [T.PILToTensor()]
+
+        if use_v2:
+            img_type = datapoints.Image if backend == "datapoint" else torch.Tensor
+            transforms += [
+                T.ToDtype(dtype={img_type: torch.float32, datapoints.Mask: torch.int64, "others": None}, scale=True)
             ]
-        )
-        self.transforms = T.Compose(trans)
+        else:
+            # No need to explicitly convert masks as they're magically int64 already
+            transforms += [T.ConvertImageDtype(torch.float)]
+
+        transforms += [T.Normalize(mean=mean, std=std)]
+
+        self.transforms = T.Compose(transforms)

     def __call__(self, img, target):
         return self.transforms(img, target)


 class SegmentationPresetEval:
-    def __init__(self, *, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
-        self.transforms = T.Compose(
-            [
-                T.RandomResize(base_size, base_size),
-                T.PILToTensor(),
-                T.ConvertImageDtype(torch.float),
-                T.Normalize(mean=mean, std=std),
-            ]
-        )
+    def __init__(
+        self, *, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), backend="pil", use_v2=False
+    ):
+        T, _, _ = get_modules(use_v2)
+
+        transforms = []
+        backend = backend.lower()
+        if backend == "tensor":
+            transforms += [T.PILToTensor()]
+        elif backend == "datapoint":
+            transforms += [T.ToImageTensor()]
+        elif backend != "pil":
+            raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}")
+
+        if use_v2:
+            transforms += [T.Resize(size=(base_size, base_size))]
+        else:
+            transforms += [T.RandomResize(min_size=base_size, max_size=base_size)]
+
+        if backend == "pil":
+            # Note: we could just convert to pure tensors even in v2?
+            transforms += [T.ToImageTensor() if use_v2 else T.PILToTensor()]
+
+        transforms += [
+            T.ConvertImageDtype(torch.float),
+            T.Normalize(mean=mean, std=std),
+        ]
+        self.transforms = T.Compose(transforms)

     def __call__(self, img, target):
         return self.transforms(img, target)
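Finally, a hedged sketch of driving the reworked presets directly, assuming it is run from references/segmentation (so presets and v2_extras import); the inputs are synthetic, the sizes are example values, and the mask is wrapped as a datapoints.Mask to mimic what wrap_dataset_for_transforms_v2 would provide.

    import PIL.Image
    import torch
    from torchvision import datapoints
    from presets import SegmentationPresetTrain

    img = PIL.Image.new("RGB", (640, 480))                            # synthetic image
    mask = datapoints.Mask(torch.zeros(480, 640, dtype=torch.uint8))  # synthetic mask

    train_tf = SegmentationPresetTrain(base_size=520, crop_size=480, backend="pil", use_v2=True)
    out_img, out_mask = train_tf(img, mask)
    print(out_img.dtype, out_mask.dtype)  # expected: torch.float32 torch.int64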
