Add segmentation
NicolasHug committed Feb 22, 2023
1 parent c00a181 commit 5147d8b
Showing 6 changed files with 119 additions and 45 deletions.
1 change: 0 additions & 1 deletion references/detection/coco_utils.py
@@ -207,7 +207,6 @@ def get_coco(root, image_set, transforms, mode="instances"):
img_folder = os.path.join(root, img_folder)
ann_file = os.path.join(root, ann_file)


dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)
dataset = wrap_dataset_for_transforms_v2(dataset)

12 changes: 7 additions & 5 deletions references/detection/presets.py
@@ -1,11 +1,12 @@
from collections import defaultdict

import torch
import transforms as reference_transforms
import torchvision
import transforms as reference_transforms

torchvision.disable_beta_transforms_warning()
from torchvision import datapoints
import torchvision.transforms.v2 as T
from torchvision import datapoints


# TODO: Should we provide a transforms that filters-out keys?
@@ -64,7 +65,9 @@ def __init__(self, *, data_augmentation, hflip_prob=0.5, mean=(123.0, 117.0, 104
transforms += [
T.ConvertImageDtype(torch.float),
T.ConvertBoundingBoxFormat(datapoints.BoundingBoxFormat.XYXY),
T.SanitizeBoundingBoxes(labels_getter=lambda sample: sample[1]["labels"]) # TODO: sad it's not the default!
T.SanitizeBoundingBoxes(
labels_getter=lambda sample: sample[1]["labels"]
), # TODO: sad it's not the default!
]

super().__init__(transforms)
@@ -78,9 +81,8 @@ def __init__(self, backend="pil"):
backend = backend.lower()
if backend == "tensor":
transforms.append(T.PILToTensor())
else: # for datapoint **and** PIL
else: # for datapoint **and** PIL
transforms.append(T.ToImageTensor())


transforms.append(T.ConvertImageDtype(torch.float))
super().__init__(transforms)
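
For context on the SanitizeBoundingBoxes call above: the detection samples are (image, target) pairs rather than flat dicts, so the transform has to be told where the labels live. A purely illustrative sketch of what the lambda picks out (the sample layout below is a stand-in, not real dataset output):

# The detection references pass samples as (image, target) tuples, so the labels
# that SanitizeBoundingBoxes should filter alongside the boxes sit at sample[1]["labels"].
sample = ("<image>", {"boxes": "<boxes>", "labels": [7, 12, 3]})
labels_getter = lambda sample: sample[1]["labels"]
assert labels_getter(sample) == [7, 12, 3]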
111 changes: 77 additions & 34 deletions references/segmentation/presets.py
@@ -1,39 +1,82 @@
from collections import defaultdict

import torch
import transforms as T
import torchvision

torchvision.disable_beta_transforms_warning()
import torchvision.transforms.v2 as T
from torchvision import datapoints
from transforms import PadIfSmaller, WrapIntoFeatures


class SegmentationPresetTrain(T.Compose):
def __init__(
self,
*,
base_size,
crop_size,
hflip_prob=0.5,
mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225),
backend="pil",
):

transforms = []

transforms.append(WrapIntoFeatures())

class SegmentationPresetTrain:
def __init__(self, *, base_size, crop_size, hflip_prob=0.5, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
min_size = int(0.5 * base_size)
max_size = int(2.0 * base_size)
backend = backend.lower()
if backend == "datapoint":
transforms.append(T.ToImageTensor())
elif backend == "tensor":
transforms.append(T.PILToTensor())
else:
assert backend == "pil"

transforms.append(T.RandomResize(min_size=int(0.5 * base_size), max_size=int(2.0 * base_size), antialias=True))

trans = [T.RandomResize(min_size, max_size)]
if hflip_prob > 0:
trans.append(T.RandomHorizontalFlip(hflip_prob))
trans.extend(
[
T.RandomCrop(crop_size),
T.PILToTensor(),
T.ConvertImageDtype(torch.float),
T.Normalize(mean=mean, std=std),
]
)
self.transforms = T.Compose(trans)

def __call__(self, img, target):
return self.transforms(img, target)


class SegmentationPresetEval:
def __init__(self, *, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
self.transforms = T.Compose(
[
T.RandomResize(base_size, base_size),
T.PILToTensor(),
T.ConvertImageDtype(torch.float),
T.Normalize(mean=mean, std=std),
]
)

def __call__(self, img, target):
return self.transforms(img, target)
transforms.append(T.RandomHorizontalFlip(hflip_prob))

transforms += [
# We need a custom pad transform here, since the padding we want to perform here is fundamentally
# different from the padding in `RandomCrop` if `pad_if_needed=True`.
PadIfSmaller(crop_size, fill=defaultdict(lambda: 0, {datapoints.Mask: 255})),
T.RandomCrop(crop_size),
]

if backend == "pil":
transforms.append(T.ToImageTensor())

transforms += [
T.ConvertImageDtype(torch.float),
T.Normalize(mean=mean, std=std),
]

super().__init__(transforms)


class SegmentationPresetEval(T.Compose):
def __init__(self, *, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), backend="pil"):
transforms = []

transforms.append(WrapIntoFeatures())

backend = backend.lower()
if backend == "datapoint":
transforms.append(T.ToImageTensor())
elif backend == "tensor":
transforms.append(T.PILToTensor())
else:
assert backend == "pil"

transforms.append(T.Resize(base_size, antialias=True))

if backend == "pil":
transforms.append(T.ToImageTensor())

transforms += [
T.ConvertImageDtype(torch.float),
T.Normalize(mean=mean, std=std),
]
super().__init__(transforms)
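
A minimal usage sketch of the new segmentation presets (assuming the torchvision v2 beta API at the time of this commit; the toy PIL inputs stand in for a dataset sample):

from PIL import Image

import presets  # references/segmentation/presets.py as shown above

img = Image.new("RGB", (640, 480))
mask = Image.new("L", (640, 480))

# backend="pil" keeps PIL inputs until ToImageTensor runs near the end of the pipeline.
train_tf = presets.SegmentationPresetTrain(base_size=520, crop_size=480, backend="pil")
img_t, mask_t = train_tf(img, mask)  # normalized float image tensor + int64 datapoints.Mask, both spatially 480x480

eval_tf = presets.SegmentationPresetEval(base_size=520, backend="pil")
img_e, mask_e = eval_tf(img, mask)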
5 changes: 3 additions & 2 deletions references/segmentation/train.py
@@ -31,7 +31,7 @@ def sbd(*args, **kwargs):

def get_transform(train, args):
if train:
return presets.SegmentationPresetTrain(base_size=520, crop_size=480)
return presets.SegmentationPresetTrain(base_size=520, crop_size=480, backend=args.backend)
elif args.weights and args.test_only:
weights = torchvision.models.get_weight(args.weights)
trans = weights.transforms()
@@ -44,7 +44,7 @@ def preprocessing(img, target):

return preprocessing
else:
return presets.SegmentationPresetEval(base_size=520)
return presets.SegmentationPresetEval(base_size=520, backend=args.backend)


def criterion(inputs, target):
@@ -306,6 +306,7 @@ def get_args_parser(add_help=True):

# Mixed precision training parameters
parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")
parser.add_argument("--backend", default="PIL", type=str, help="PIL, tensor or datapoint - case insensitive")

return parser

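A sketch of how the parsed --backend value reaches the presets through get_transform (the Namespace below is only a stand-in for the real parsed arguments and carries just the fields get_transform reads):

import argparse

from train import get_transform  # references/segmentation/train.py

args = argparse.Namespace(backend="datapoint", weights=None, test_only=False)
train_tf = get_transform(train=True, args=args)   # SegmentationPresetTrain(..., backend="datapoint")
eval_tf = get_transform(train=False, args=args)   # SegmentationPresetEval(..., backend="datapoint")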
33 changes: 32 additions & 1 deletion references/segmentation/transforms.py
@@ -2,10 +2,41 @@

import numpy as np
import torch
from torchvision import transforms as T
import torchvision.transforms.v2 as PT
import torchvision.transforms.v2.functional as PF
from torchvision import datapoints, transforms as T
from torchvision.transforms import functional as F


class WrapIntoFeatures(PT.Transform):
def forward(self, sample):
image, mask = sample
# return PF.to_image_tensor(image), datapoints.Mask(PF.pil_to_tensor(mask).squeeze(0), dtype=torch.int64)
return image, datapoints.Mask(PF.pil_to_tensor(mask).squeeze(0), dtype=torch.int64)


class PadIfSmaller(PT.Transform):
def __init__(self, size, fill=0):
super().__init__()
self.size = size
self.fill = PT._geometry._setup_fill_arg(fill)

def _get_params(self, sample):
_, height, width = PT.utils.query_chw(sample)
padding = [0, 0, max(self.size - width, 0), max(self.size - height, 0)]
needs_padding = any(padding)
return dict(padding=padding, needs_padding=needs_padding)

def _transform(self, inpt, params):
if not params["needs_padding"]:
return inpt

fill = self.fill[type(inpt)]
fill = PT._utils._convert_fill_arg(fill)

return PF.pad(inpt, padding=params["padding"], fill=fill)


def pad_if_smaller(img, size, fill=0):
min_size = min(img.size)
if min_size < size:
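A minimal sketch of the two new helpers in action (again assuming the v2 beta API of early 2023): wrap a PIL (image, mask) pair into datapoints, then pad it so that a 480-pixel crop is always possible, with masks padded by the ignore value 255:

from collections import defaultdict

from PIL import Image
import torchvision

torchvision.disable_beta_transforms_warning()
from torchvision import datapoints

from transforms import PadIfSmaller, WrapIntoFeatures  # the helpers added above

sample = (Image.new("RGB", (300, 300)), Image.new("L", (300, 300)))

sample = WrapIntoFeatures()(sample)  # the mask becomes an int64 datapoints.Mask
pad = PadIfSmaller(480, fill=defaultdict(lambda: 0, {datapoints.Mask: 255}))
image, mask = pad(sample)            # both are padded on the right/bottom up to 480x480
print(type(mask), mask.shape)        # Mask, torch.Size([480, 480])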
2 changes: 0 additions & 2 deletions torchvision/datapoints/_dataset_wrapper.py
@@ -253,8 +253,6 @@ def wrapper(idx, sample):
len(batched_target["keypoints"]), -1, 3
)



return image, batched_target

return wrapper