diff --git a/references/detection/coco_utils.py b/references/detection/coco_utils.py
index 893c51d99a3..ceab6a4f493 100644
--- a/references/detection/coco_utils.py
+++ b/references/detection/coco_utils.py
@@ -207,7 +207,6 @@ def get_coco(root, image_set, transforms, mode="instances"):
     img_folder = os.path.join(root, img_folder)
     ann_file = os.path.join(root, ann_file)
 
-
     dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)
     dataset = wrap_dataset_for_transforms_v2(dataset)
 
diff --git a/references/detection/presets.py b/references/detection/presets.py
index bd7d12de7fe..20e134f6cfa 100644
--- a/references/detection/presets.py
+++ b/references/detection/presets.py
@@ -1,11 +1,12 @@
 from collections import defaultdict
 
 import torch
-import transforms as reference_transforms
 import torchvision
+import transforms as reference_transforms
+
 torchvision.disable_beta_transforms_warning()
-from torchvision import datapoints
 import torchvision.transforms.v2 as T
+from torchvision import datapoints
 
 
 # TODO: Should we provide a transforms that filters-out keys?
@@ -64,7 +65,9 @@ def __init__(self, *, data_augmentation, hflip_prob=0.5, mean=(123.0, 117.0, 104
         transforms += [
             T.ConvertImageDtype(torch.float),
             T.ConvertBoundingBoxFormat(datapoints.BoundingBoxFormat.XYXY),
-            T.SanitizeBoundingBoxes(labels_getter=lambda sample: sample[1]["labels"])  # TODO: sad it's not the default!
+            T.SanitizeBoundingBoxes(
+                labels_getter=lambda sample: sample[1]["labels"]
+            ),  # TODO: sad it's not the default!
         ]
 
         super().__init__(transforms)
@@ -78,9 +81,8 @@ def __init__(self, backend="pil"):
         backend = backend.lower()
         if backend == "tensor":
             transforms.append(T.PILToTensor())
-        else: # for datapoint **and** PIL
+        else:  # for datapoint **and** PIL
             transforms.append(T.ToImageTensor())
-
         transforms.append(T.ConvertImageDtype(torch.float))
 
         super().__init__(transforms)
diff --git a/references/segmentation/presets.py b/references/segmentation/presets.py
index ed02ae660e4..92a0908ac8a 100644
--- a/references/segmentation/presets.py
+++ b/references/segmentation/presets.py
@@ -1,39 +1,82 @@
+from collections import defaultdict
+
 import torch
-import transforms as T
+import torchvision
+
+torchvision.disable_beta_transforms_warning()
+import torchvision.transforms.v2 as T
+from torchvision import datapoints
+from transforms import PadIfSmaller, WrapIntoFeatures
+
+
+class SegmentationPresetTrain(T.Compose):
+    def __init__(
+        self,
+        *,
+        base_size,
+        crop_size,
+        hflip_prob=0.5,
+        mean=(0.485, 0.456, 0.406),
+        std=(0.229, 0.224, 0.225),
+        backend="pil",
+    ):
+
+        transforms = []
+        transforms.append(WrapIntoFeatures())
 
 
-class SegmentationPresetTrain:
-    def __init__(self, *, base_size, crop_size, hflip_prob=0.5, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
-        min_size = int(0.5 * base_size)
-        max_size = int(2.0 * base_size)
+        backend = backend.lower()
+        if backend == "datapoint":
+            transforms.append(T.ToImageTensor())
+        elif backend == "tensor":
+            transforms.append(T.PILToTensor())
+        else:
+            assert backend == "pil"
+
+        transforms.append(T.RandomResize(min_size=int(0.5 * base_size), max_size=int(2.0 * base_size), antialias=True))
 
-        trans = [T.RandomResize(min_size, max_size)]
         if hflip_prob > 0:
-            trans.append(T.RandomHorizontalFlip(hflip_prob))
-        trans.extend(
-            [
-                T.RandomCrop(crop_size),
-                T.PILToTensor(),
-                T.ConvertImageDtype(torch.float),
-                T.Normalize(mean=mean, std=std),
-            ]
-        )
-        self.transforms = T.Compose(trans)
-
-    def __call__(self, img, target):
-        return self.transforms(img, target)
-
-
-class SegmentationPresetEval:
-    def __init__(self, *, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
-        self.transforms = T.Compose(
-            [
-                T.RandomResize(base_size, base_size),
-                T.PILToTensor(),
-                T.ConvertImageDtype(torch.float),
-                T.Normalize(mean=mean, std=std),
-            ]
-        )
-
-    def __call__(self, img, target):
-        return self.transforms(img, target)
+            transforms.append(T.RandomHorizontalFlip(hflip_prob))
+
+        transforms += [
+            # We need a custom pad transform here, since the padding we want to perform here is fundamentally
+            # different from the padding in `RandomCrop` if `pad_if_needed=True`.
+            PadIfSmaller(crop_size, fill=defaultdict(lambda: 0, {datapoints.Mask: 255})),
+            T.RandomCrop(crop_size),
+        ]
+
+        if backend == "pil":
+            transforms.append(T.ToImageTensor())
+
+        transforms += [
+            T.ConvertImageDtype(torch.float),
+            T.Normalize(mean=mean, std=std),
+        ]
+
+        super().__init__(transforms)
+
+
+class SegmentationPresetEval(T.Compose):
+    def __init__(self, *, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), backend="pil"):
+        transforms = []
+
+        transforms.append(WrapIntoFeatures())
+
+        backend = backend.lower()
+        if backend == "datapoint":
+            transforms.append(T.ToImageTensor())
+        elif backend == "tensor":
+            transforms.append(T.PILToTensor())
+        else:
+            assert backend == "pil"
+
+        transforms.append(T.Resize(base_size, antialias=True))
+
+        if backend == "pil":
+            transforms.append(T.ToImageTensor())
+
+        transforms += [
+            T.ConvertImageDtype(torch.float),
+            T.Normalize(mean=mean, std=std),
+        ]
+        super().__init__(transforms)
diff --git a/references/segmentation/train.py b/references/segmentation/train.py
index 1aa72a9fe38..b06bd9ae985 100644
--- a/references/segmentation/train.py
+++ b/references/segmentation/train.py
@@ -31,7 +31,7 @@ def sbd(*args, **kwargs):
 
 def get_transform(train, args):
     if train:
-        return presets.SegmentationPresetTrain(base_size=520, crop_size=480)
+        return presets.SegmentationPresetTrain(base_size=520, crop_size=480, backend=args.backend)
     elif args.weights and args.test_only:
         weights = torchvision.models.get_weight(args.weights)
         trans = weights.transforms()
@@ -44,7 +44,7 @@ def preprocessing(img, target):
 
         return preprocessing
     else:
-        return presets.SegmentationPresetEval(base_size=520)
+        return presets.SegmentationPresetEval(base_size=520, backend=args.backend)
 
 
 def criterion(inputs, target):
@@ -306,6 +306,7 @@ def get_args_parser(add_help=True):
 
     # Mixed precision training parameters
     parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")
+    parser.add_argument("--backend", default="PIL", type=str, help="PIL, tensor or datapoint - case insensitive")
 
     return parser
 
diff --git a/references/segmentation/transforms.py b/references/segmentation/transforms.py
index 518048db2fa..cea3668247e 100644
--- a/references/segmentation/transforms.py
+++ b/references/segmentation/transforms.py
@@ -2,10 +2,41 @@
 
 import numpy as np
 import torch
-from torchvision import transforms as T
+import torchvision.transforms.v2 as PT
+import torchvision.transforms.v2.functional as PF
+from torchvision import datapoints, transforms as T
 from torchvision.transforms import functional as F
 
 
+class WrapIntoFeatures(PT.Transform):
+    def forward(self, sample):
+        image, mask = sample
+        # return PF.to_image_tensor(image), datapoints.Mask(PF.pil_to_tensor(mask).squeeze(0), dtype=torch.int64)
+        return image, datapoints.Mask(PF.pil_to_tensor(mask).squeeze(0), dtype=torch.int64)
+
+
+class PadIfSmaller(PT.Transform):
+    def __init__(self, size, fill=0):
+        super().__init__()
+        self.size = size
+        self.fill = PT._geometry._setup_fill_arg(fill)
+
+    def _get_params(self, sample):
+        _, height, width = PT.utils.query_chw(sample)
+        padding = [0, 0, max(self.size - width, 0), max(self.size - height, 0)]
+        needs_padding = any(padding)
+        return dict(padding=padding, needs_padding=needs_padding)
+
+    def _transform(self, inpt, params):
+        if not params["needs_padding"]:
+            return inpt
+
+        fill = self.fill[type(inpt)]
+        fill = PT._utils._convert_fill_arg(fill)
+
+        return PF.pad(inpt, padding=params["padding"], fill=fill)
+
+
 def pad_if_smaller(img, size, fill=0):
     min_size = min(img.size)
     if min_size < size:
diff --git a/torchvision/datapoints/_dataset_wrapper.py b/torchvision/datapoints/_dataset_wrapper.py
index 396aae54da0..51d2ac64a0f 100644
--- a/torchvision/datapoints/_dataset_wrapper.py
+++ b/torchvision/datapoints/_dataset_wrapper.py
@@ -253,8 +253,6 @@ def wrapper(idx, sample):
                 len(batched_target["keypoints"]), -1, 3
             )
 
-
-
         return image, batched_target
 
     return wrapper
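
A minimal usage sketch (not part of the patch above) of the detection-side change: coco_utils.py now feeds the plain CocoDetection dataset through wrap_dataset_for_transforms_v2 so that the v2 transforms receive datapoints instead of raw COCO annotation dicts. The paths below are placeholders, and the import location of the wrapper is an assumption based on torchvision/datapoints/_dataset_wrapper.py from this branch.

    import torchvision

    torchvision.disable_beta_transforms_warning()
    from torchvision.datapoints import wrap_dataset_for_transforms_v2  # assumed export location

    # Placeholder paths; any COCO-style image folder plus instances annotation file fits here.
    dataset = torchvision.datasets.CocoDetection("path/to/train2017", "path/to/instances_train2017.json")
    dataset = wrap_dataset_for_transforms_v2(dataset)

    image, target = dataset[0]
    # `target` is now a dict of tensors/datapoints; e.g. its "labels" entry is what
    # references/detection/presets.py reads via `labels_getter=lambda sample: sample[1]["labels"]`.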
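
In the same spirit, a hedged end-to-end sketch of the reworked segmentation presets, assuming it is run from references/segmentation/ (so that the patched presets.py and transforms.py are importable) against a torchvision build that ships the beta transforms.v2 API; the dummy image and mask sizes are invented for illustration.

    from PIL import Image

    import presets  # references/segmentation/presets.py as patched above

    # Dummy inputs standing in for one (image, segmentation mask) pair from a VOC-style dataset.
    image = Image.new("RGB", (500, 375))
    mask = Image.new("L", (500, 375))

    train_tf = presets.SegmentationPresetTrain(base_size=520, crop_size=480, backend="pil")
    img_t, mask_t = train_tf(image, mask)
    # Expected: a float image tensor of shape (3, 480, 480) after RandomCrop, and an int64
    # Mask datapoint of shape (480, 480) produced by WrapIntoFeatures.

    eval_tf = presets.SegmentationPresetEval(base_size=520, backend="tensor")
    img_e, mask_e = eval_tf(image, mask)  # shorter side resized to 520, image normalized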
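
Finally, the fill argument handed to PadIfSmaller above is a per-type mapping rather than a scalar: image-like inputs fall back to 0, while Mask inputs are padded with 255, the ignore index used by the segmentation training loss. A small illustration of that lookup, assuming the beta datapoints namespace is available as in the imports above:

    from collections import defaultdict

    import torch
    import torchvision

    torchvision.disable_beta_transforms_warning()
    from torchvision import datapoints

    fill = defaultdict(lambda: 0, {datapoints.Mask: 255})

    image = datapoints.Image(torch.zeros(3, 4, 4))
    mask = datapoints.Mask(torch.zeros(4, 4, dtype=torch.int64))

    # PadIfSmaller._transform resolves the pad value via `self.fill[type(inpt)]`:
    print(fill[type(image)])  # 0   -> images are padded with zeros (default factory)
    print(fill[type(mask)])   # 255 -> masks are padded with the ignore index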