-
Notifications
You must be signed in to change notification settings - Fork 7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
c00a181
commit 5147d8b
Showing
6 changed files
with
119 additions
and
45 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,39 +1,82 @@ | ||
from collections import defaultdict | ||
|
||
import torch | ||
import transforms as T | ||
import torchvision | ||
|
||
torchvision.disable_beta_transforms_warning() | ||
import torchvision.transforms.v2 as T | ||
from torchvision import datapoints | ||
from transforms import PadIfSmaller, WrapIntoFeatures | ||
|
||
|
||
class SegmentationPresetTrain(T.Compose): | ||
def __init__( | ||
self, | ||
*, | ||
base_size, | ||
crop_size, | ||
hflip_prob=0.5, | ||
mean=(0.485, 0.456, 0.406), | ||
std=(0.229, 0.224, 0.225), | ||
backend="pil", | ||
): | ||
|
||
transforms = [] | ||
|
||
transforms.append(WrapIntoFeatures()) | ||
|
||
class SegmentationPresetTrain: | ||
def __init__(self, *, base_size, crop_size, hflip_prob=0.5, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)): | ||
min_size = int(0.5 * base_size) | ||
max_size = int(2.0 * base_size) | ||
backend = backend.lower() | ||
if backend == "datapoint": | ||
transforms.append(T.ToImageTensor()) | ||
elif backend == "tensor": | ||
transforms.append(T.PILToTensor()) | ||
else: | ||
assert backend == "pil" | ||
|
||
transforms.append(T.RandomResize(min_size=int(0.5 * base_size), max_size=int(2.0 * base_size), antialias=True)) | ||
|
||
trans = [T.RandomResize(min_size, max_size)] | ||
if hflip_prob > 0: | ||
trans.append(T.RandomHorizontalFlip(hflip_prob)) | ||
trans.extend( | ||
[ | ||
T.RandomCrop(crop_size), | ||
T.PILToTensor(), | ||
T.ConvertImageDtype(torch.float), | ||
T.Normalize(mean=mean, std=std), | ||
] | ||
) | ||
self.transforms = T.Compose(trans) | ||
|
||
def __call__(self, img, target): | ||
return self.transforms(img, target) | ||
|
||
|
||
class SegmentationPresetEval: | ||
def __init__(self, *, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)): | ||
self.transforms = T.Compose( | ||
[ | ||
T.RandomResize(base_size, base_size), | ||
T.PILToTensor(), | ||
T.ConvertImageDtype(torch.float), | ||
T.Normalize(mean=mean, std=std), | ||
] | ||
) | ||
|
||
def __call__(self, img, target): | ||
return self.transforms(img, target) | ||
transforms.append(T.RandomHorizontalFlip(hflip_prob)) | ||
|
||
transforms += [ | ||
# We need a custom pad transform here, since the padding we want to perform here is fundamentally | ||
# different from the padding in `RandomCrop` if `pad_if_needed=True`. | ||
PadIfSmaller(crop_size, fill=defaultdict(lambda: 0, {datapoints.Mask: 255})), | ||
T.RandomCrop(crop_size), | ||
] | ||
|
||
if backend == "pil": | ||
transforms.append(T.ToImageTensor()) | ||
|
||
transforms += [ | ||
T.ConvertImageDtype(torch.float), | ||
T.Normalize(mean=mean, std=std), | ||
] | ||
|
||
super().__init__(transforms) | ||
|
||
|
||
class SegmentationPresetEval(T.Compose): | ||
def __init__(self, *, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), backend="pil"): | ||
transforms = [] | ||
|
||
transforms.append(WrapIntoFeatures()) | ||
|
||
backend = backend.lower() | ||
if backend == "datapoint": | ||
transforms.append(T.ToImageTensor()) | ||
elif backend == "tensor": | ||
transforms.append(T.PILToTensor()) | ||
else: | ||
assert backend == "pil" | ||
|
||
transforms.append(T.Resize(base_size, antialias=True)) | ||
|
||
if backend == "pil": | ||
transforms.append(T.ToImageTensor()) | ||
|
||
transforms += [ | ||
T.ConvertImageDtype(torch.float), | ||
T.Normalize(mean=mean, std=std), | ||
] | ||
super().__init__(transforms) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters