diff --git a/references/detection/coco_utils.py b/references/detection/coco_utils.py
index 7cf19d39dc9..07c98a67ca2 100644
--- a/references/detection/coco_utils.py
+++ b/references/detection/coco_utils.py
@@ -6,7 +6,6 @@
 import transforms as T
 from pycocotools import mask as coco_mask
 from pycocotools.coco import COCO
-from torchvision.datasets import wrap_dataset_for_transforms_v2
 
 
 def convert_coco_poly_to_mask(segmentations, height, width):
@@ -213,6 +212,8 @@ def get_coco(root, image_set, transforms, mode="instances", use_v2=False, with_m
     ann_file = os.path.join(root, ann_file)
 
     if use_v2:
+        from torchvision.datasets import wrap_dataset_for_transforms_v2
+
         dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)
         target_keys = ["boxes", "labels", "image_id"]
         if with_masks:
diff --git a/references/segmentation/coco_utils.py b/references/segmentation/coco_utils.py
index e02434012f1..6a15dbefb52 100644
--- a/references/segmentation/coco_utils.py
+++ b/references/segmentation/coco_utils.py
@@ -68,11 +68,6 @@ def _has_valid_annotation(anno):
         # if more than 1k pixels occupied in the image
         return sum(obj["area"] for obj in anno) > 1000
 
-    if not isinstance(dataset, torchvision.datasets.CocoDetection):
-        raise TypeError(
-            f"This function expects dataset of type torchvision.datasets.CocoDetection, instead got {type(dataset)}"
-        )
-
     ids = []
     for ds_idx, img_id in enumerate(dataset.ids):
         ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
@@ -86,7 +81,7 @@ def _has_valid_annotation(anno):
     return dataset
 
 
-def get_coco(root, image_set, transforms):
+def get_coco(root, image_set, transforms, use_v2=False):
     PATHS = {
         "train": ("train2017", os.path.join("annotations", "instances_train2017.json")),
         "val": ("val2017", os.path.join("annotations", "instances_val2017.json")),
@@ -94,13 +89,24 @@ def _has_valid_annotation(anno):
     }
     CAT_LIST = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4, 1, 64, 20, 63, 7, 72]
 
-    transforms = Compose([FilterAndRemapCocoCategories(CAT_LIST, remap=True), ConvertCocoPolysToMask(), transforms])
-
     img_folder, ann_file = PATHS[image_set]
     img_folder = os.path.join(root, img_folder)
     ann_file = os.path.join(root, ann_file)
 
-    dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)
+    # The 2 "Compose" below achieve the same thing: converting coco detection
+    # samples into segmentation-compatible samples. They just do it with
+    # slightly different implementations. We could refactor and unify, but
+    # keeping them separate helps keep the v2 version clean
+    if use_v2:
+        import v2_extras
+        from torchvision.datasets import wrap_dataset_for_transforms_v2
+
+        transforms = Compose([v2_extras.CocoDetectionToVOCSegmentation(), transforms])
+        dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)
+        dataset = wrap_dataset_for_transforms_v2(dataset, target_keys={"masks", "labels"})
+    else:
+        transforms = Compose([FilterAndRemapCocoCategories(CAT_LIST, remap=True), ConvertCocoPolysToMask(), transforms])
+        dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)
 
     if image_set == "train":
         dataset = _coco_remove_images_without_annotations(dataset, CAT_LIST)
diff --git a/references/segmentation/presets.py b/references/segmentation/presets.py
index ed02ae660e4..abb70d8d0db 100644
--- a/references/segmentation/presets.py
+++ b/references/segmentation/presets.py
@@ -1,39 +1,106 @@
+from collections import defaultdict
+
 import torch
-import transforms as T
+
+
+def get_modules(use_v2):
+    # We need a protected import to avoid the V2 warning in case only V1 is used
+    if use_v2:
+        import torchvision.datapoints
+        import torchvision.transforms.v2
+        import v2_extras
+
+        return torchvision.transforms.v2, torchvision.datapoints, v2_extras
+    else:
+        import transforms
+
+        return transforms, None, None
 
 
 class SegmentationPresetTrain:
-    def __init__(self, *, base_size, crop_size, hflip_prob=0.5, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
-        min_size = int(0.5 * base_size)
-        max_size = int(2.0 * base_size)
+    def __init__(
+        self,
+        *,
+        base_size,
+        crop_size,
+        hflip_prob=0.5,
+        mean=(0.485, 0.456, 0.406),
+        std=(0.229, 0.224, 0.225),
+        backend="pil",
+        use_v2=False,
+    ):
+        T, datapoints, v2_extras = get_modules(use_v2)
+
+        transforms = []
+        backend = backend.lower()
+        if backend == "datapoint":
+            transforms.append(T.ToImageTensor())
+        elif backend == "tensor":
+            transforms.append(T.PILToTensor())
+        elif backend != "pil":
+            raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}")
+
+        transforms += [T.RandomResize(min_size=int(0.5 * base_size), max_size=int(2.0 * base_size))]
 
-        trans = [T.RandomResize(min_size, max_size)]
         if hflip_prob > 0:
-            trans.append(T.RandomHorizontalFlip(hflip_prob))
-        trans.extend(
-            [
-                T.RandomCrop(crop_size),
-                T.PILToTensor(),
-                T.ConvertImageDtype(torch.float),
-                T.Normalize(mean=mean, std=std),
+            transforms += [T.RandomHorizontalFlip(hflip_prob)]
+
+        if use_v2:
+            # We need a custom pad transform here, since the padding we want to perform is fundamentally
+            # different from the padding done by `RandomCrop` when `pad_if_needed=True`.
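+            # (The `fill` dict pads everything with 0 except `Mask` inputs, which are padded with 255, the value
+            # that the reference criterion ignores via `ignore_index=255`, so padded pixels never affect the loss.)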
+            transforms += [v2_extras.PadIfSmaller(crop_size, fill=defaultdict(lambda: 0, {datapoints.Mask: 255}))]
+
+        transforms += [T.RandomCrop(crop_size)]
+
+        if backend == "pil":
+            transforms += [T.PILToTensor()]
+
+        if use_v2:
+            img_type = datapoints.Image if backend == "datapoint" else torch.Tensor
+            transforms += [
+                T.ToDtype(dtype={img_type: torch.float32, datapoints.Mask: torch.int64, "others": None}, scale=True)
             ]
-        )
-        self.transforms = T.Compose(trans)
+        else:
+            # No need to explicitly convert masks as they're magically int64 already
+            transforms += [T.ConvertImageDtype(torch.float)]
+
+        transforms += [T.Normalize(mean=mean, std=std)]
+
+        self.transforms = T.Compose(transforms)
 
     def __call__(self, img, target):
         return self.transforms(img, target)
 
 
 class SegmentationPresetEval:
-    def __init__(self, *, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
-        self.transforms = T.Compose(
-            [
-                T.RandomResize(base_size, base_size),
-                T.PILToTensor(),
-                T.ConvertImageDtype(torch.float),
-                T.Normalize(mean=mean, std=std),
-            ]
-        )
+    def __init__(
+        self, *, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), backend="pil", use_v2=False
+    ):
+        T, _, _ = get_modules(use_v2)
+
+        transforms = []
+        backend = backend.lower()
+        if backend == "tensor":
+            transforms += [T.PILToTensor()]
+        elif backend == "datapoint":
+            transforms += [T.ToImageTensor()]
+        elif backend != "pil":
+            raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}")
+
+        if use_v2:
+            transforms += [T.Resize(size=(base_size, base_size))]
+        else:
+            transforms += [T.RandomResize(min_size=base_size, max_size=base_size)]
+
+        if backend == "pil":
+            # Note: we could just convert to pure tensors even in v2?
+            transforms += [T.ToImageTensor() if use_v2 else T.PILToTensor()]
+
+        transforms += [
+            T.ConvertImageDtype(torch.float),
+            T.Normalize(mean=mean, std=std),
+        ]
+        self.transforms = T.Compose(transforms)
 
     def __call__(self, img, target):
         return self.transforms(img, target)
diff --git a/references/segmentation/train.py b/references/segmentation/train.py
index 1aa72a9fe38..7ca4bd1c592 100644
--- a/references/segmentation/train.py
+++ b/references/segmentation/train.py
@@ -14,24 +14,30 @@
 from torchvision.transforms import functional as F, InterpolationMode
 
 
-def get_dataset(dir_path, name, image_set, transform):
+def get_dataset(args, is_train):
     def sbd(*args, **kwargs):
+        kwargs.pop("use_v2")
         return torchvision.datasets.SBDataset(*args, mode="segmentation", **kwargs)
 
+    def voc(*args, **kwargs):
+        kwargs.pop("use_v2")
+        return torchvision.datasets.VOCSegmentation(*args, **kwargs)
+
     paths = {
-        "voc": (dir_path, torchvision.datasets.VOCSegmentation, 21),
-        "voc_aug": (dir_path, sbd, 21),
-        "coco": (dir_path, get_coco, 21),
+        "voc": (args.data_path, voc, 21),
+        "voc_aug": (args.data_path, sbd, 21),
+        "coco": (args.data_path, get_coco, 21),
     }
-    p, ds_fn, num_classes = paths[name]
+    p, ds_fn, num_classes = paths[args.dataset]
 
-    ds = ds_fn(p, image_set=image_set, transforms=transform)
+    image_set = "train" if is_train else "val"
+    ds = ds_fn(p, image_set=image_set, transforms=get_transform(is_train, args), use_v2=args.use_v2)
     return ds, num_classes
 
 
-def get_transform(train, args):
-    if train:
-        return presets.SegmentationPresetTrain(base_size=520, crop_size=480)
+def get_transform(is_train, args):
+    if is_train:
+        return presets.SegmentationPresetTrain(base_size=520, crop_size=480, backend=args.backend, use_v2=args.use_v2)
     elif args.weights and args.test_only:
         weights = torchvision.models.get_weight(args.weights)
         trans = weights.transforms()
@@ -44,7 +50,7 @@ def preprocessing(img, target):
 
         return preprocessing
     else:
-        return presets.SegmentationPresetEval(base_size=520)
+        return presets.SegmentationPresetEval(base_size=520, backend=args.backend, use_v2=args.use_v2)
 
 
 def criterion(inputs, target):
@@ -120,6 +126,12 @@ def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, devi
 
 
 def main(args):
+    if args.backend.lower() != "pil" and not args.use_v2:
+        # TODO: Support tensor backend in V1?
+        raise ValueError("Use --use-v2 if you want to use the datapoint or tensor backend.")
+    if args.use_v2 and args.dataset != "coco":
+        raise ValueError("v2 is only supported for the coco dataset for now.")
+
     if args.output_dir:
         utils.mkdir(args.output_dir)
 
@@ -134,8 +146,8 @@ def main(args):
     else:
         torch.backends.cudnn.benchmark = True
 
-    dataset, num_classes = get_dataset(args.data_path, args.dataset, "train", get_transform(True, args))
-    dataset_test, _ = get_dataset(args.data_path, args.dataset, "val", get_transform(False, args))
+    dataset, num_classes = get_dataset(args, is_train=True)
+    dataset_test, _ = get_dataset(args, is_train=False)
 
     if args.distributed:
         train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
@@ -307,6 +319,8 @@ def get_args_parser(add_help=True):
     # Mixed precision training parameters
     parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")
 
+    parser.add_argument("--backend", default="PIL", type=str.lower, help="PIL, tensor or datapoint - case insensitive")
+    parser.add_argument("--use-v2", action="store_true", help="Use V2 transforms")
     return parser
 
 
diff --git a/references/segmentation/transforms.py b/references/segmentation/transforms.py
index 518048db2fa..2b3e79b1461 100644
--- a/references/segmentation/transforms.py
+++ b/references/segmentation/transforms.py
@@ -35,7 +35,7 @@ def __init__(self, min_size, max_size=None):
 
     def __call__(self, image, target):
         size = random.randint(self.min_size, self.max_size)
-        image = F.resize(image, size)
+        image = F.resize(image, size, antialias=True)
         target = F.resize(target, size, interpolation=T.InterpolationMode.NEAREST)
         return image, target
 
diff --git a/references/segmentation/utils.py b/references/segmentation/utils.py
index 4ea24db83ed..cb200f23d76 100644
--- a/references/segmentation/utils.py
+++ b/references/segmentation/utils.py
@@ -267,9 +267,9 @@ def init_distributed_mode(args):
         args.rank = int(os.environ["RANK"])
         args.world_size = int(os.environ["WORLD_SIZE"])
         args.gpu = int(os.environ["LOCAL_RANK"])
-    elif "SLURM_PROCID" in os.environ:
-        args.rank = int(os.environ["SLURM_PROCID"])
-        args.gpu = args.rank % torch.cuda.device_count()
+    # elif "SLURM_PROCID" in os.environ:
+    #     args.rank = int(os.environ["SLURM_PROCID"])
+    #     args.gpu = args.rank % torch.cuda.device_count()
     elif hasattr(args, "rank"):
         pass
     else:
diff --git a/references/segmentation/v2_extras.py b/references/segmentation/v2_extras.py
new file mode 100644
index 00000000000..c69827c22e7
--- /dev/null
+++ b/references/segmentation/v2_extras.py
@@ -0,0 +1,83 @@
+"""This file only exists to be lazy-imported and avoid V2-related import warnings when just using V1."""
+import torch
+from torchvision import datapoints
+from torchvision.transforms import v2
+
+
+class PadIfSmaller(v2.Transform):
+    def __init__(self, size, fill=0):
+        super().__init__()
+        self.size = size
+        self.fill = v2._geometry._setup_fill_arg(fill)
+
+    def _get_params(self, sample):
+        _, height, width = v2.utils.query_chw(sample)
+        padding = [0, 0, max(self.size - width, 0), max(self.size - height, 0)]
+        needs_padding = any(padding)
+        return dict(padding=padding, needs_padding=needs_padding)
+
+    def _transform(self, inpt, params):
+        if not params["needs_padding"]:
+            return inpt
+
+        fill = self.fill[type(inpt)]
+        fill = v2._utils._convert_fill_arg(fill)
+
+        return v2.functional.pad(inpt, padding=params["padding"], fill=fill)
+
+
+class CocoDetectionToVOCSegmentation(v2.Transform):
+    """Turn samples from datasets.CocoDetection into the same format as VOCSegmentation.
+
+    This is achieved in two steps:
+
+    1. COCO differentiates between 91 categories while VOC only supports 21, including background for both. Fortunately,
+       the COCO categories are a superset of the VOC ones and thus can be mapped. Instances of the 70 categories not
+       present in VOC are dropped and replaced by background.
+    2. COCO only offers detection masks, i.e. a (N, H, W) bool-ish tensor, where the truthy values in each individual
+       mask denote the instance. However, a segmentation mask is a (H, W) integer tensor (typically torch.uint8), where
+       the value of each pixel denotes the category it belongs to. The detection masks are merged into one segmentation
+       mask while pixels that belong to multiple detection masks are marked as invalid.
+    """
+
+    COCO_TO_VOC_LABEL_MAP = dict(
+        zip(
+            [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4, 1, 64, 20, 63, 7, 72],
+            range(21),
+        )
+    )
+    INVALID_VALUE = 255
+
+    def _coco_detection_masks_to_voc_segmentation_mask(self, target):
+        if "masks" not in target:
+            return None
+
+        instance_masks, instance_labels_coco = target["masks"], target["labels"]
+
+        valid_labels_voc = [
+            (idx, label_voc)
+            for idx, label_coco in enumerate(instance_labels_coco.tolist())
+            if (label_voc := self.COCO_TO_VOC_LABEL_MAP.get(label_coco)) is not None
+        ]
+
+        if not valid_labels_voc:
+            return None
+
+        valid_voc_category_idcs, instance_labels_voc = zip(*valid_labels_voc)
+
+        instance_masks = instance_masks[list(valid_voc_category_idcs)].to(torch.uint8)
+        instance_labels_voc = torch.tensor(instance_labels_voc, dtype=torch.uint8)
+
+        # Calling `.max()` on the stacked detection masks works fine to separate background from foreground as long as
+        # there is at most a single instance per pixel. Overlapping instances will be filtered out in the next step.
+        segmentation_mask, _ = (instance_masks * instance_labels_voc.reshape(-1, 1, 1)).max(dim=0)
+        segmentation_mask[instance_masks.sum(dim=0) > 1] = self.INVALID_VALUE
+
+        return segmentation_mask
+
+    def forward(self, image, target):
+        segmentation_mask = self._coco_detection_masks_to_voc_segmentation_mask(target)
+        if segmentation_mask is None:
+            segmentation_mask = torch.zeros(v2.functional.get_spatial_size(image), dtype=torch.uint8)
+
+        return image, datapoints.Mask(segmentation_mask)
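For context, a minimal sketch of how the pieces above are meant to fit together, roughly what `train.py` now does for `--dataset coco --use-v2 --backend datapoint`. It assumes it is run from `references/segmentation` with pycocotools and the COCO 2017 data available; the COCO root path is a placeholder and the training loop is omitted:

    import presets
    from coco_utils import get_coco

    # Build the v2 training preset: datapoint backend, random resize/flip/crop,
    # dtype conversion and normalization.
    transforms = presets.SegmentationPresetTrain(
        base_size=520, crop_size=480, backend="datapoint", use_v2=True
    )

    # get_coco prepends v2_extras.CocoDetectionToVOCSegmentation and wraps the dataset with
    # wrap_dataset_for_transforms_v2, so each sample comes back as an (image, mask) pair where
    # the mask is a datapoints.Mask with VOC labels and 255 marking invalid or padded pixels.
    dataset = get_coco("/path/to/coco", image_set="train", transforms=transforms, use_v2=True)

    image, mask = dataset[0]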