-
Notifications
You must be signed in to change notification settings - Fork 4
/
transforms.py
60 lines (50 loc) · 1.86 KB
/
transforms.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
#############################################
## Artemis ##
## Copyright (c) 2022-present NAVER Corp. ##
## CC BY-NC-SA 4.0 ##
#############################################
import torchvision.transforms as transforms
from PIL import ImageOps
def get_transform(opt, phase):
    """
    Build the image preprocessing pipeline for a given phase.

    Args:
        opt: options object; only `opt.crop_size` is read here.
        phase: 'train' enables data augmentation (random horizontal flip and
            random crop); any other value uses a deterministic center crop.

    Returns:
        A MyTransforms instance wrapping the composed transformations.
    """
    # shared first steps: square padding, then resizing
    leading = [PadSquare(), transforms.Resize(256)]

    # cropping strategy depends on the phase
    if phase == 'train':
        cropping = [
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(opt.crop_size),
        ]
    else:  # evaluate
        cropping = [transforms.CenterCrop(opt.crop_size)]

    # shared last steps: tensor conversion, then normalization
    trailing = [transforms.ToTensor(), Normalizer()]

    return MyTransforms(leading + cropping + trailing)
class PadSquare(object):
    """
    Make an image square by adding white borders along its shorter dimension.
    The padding is split evenly between the two opposite sides; when the
    difference is odd, the extra pixel goes to the bottom (wide image) or to
    the right (tall image).
    Input & Output: PIL image. An already square image is returned as is.
    """

    def __call__(self, img):
        width, height = img.size
        if width == height:
            # nothing to pad
            return img

        gap = abs(width - height)
        first_half = gap // 2
        second_half = gap - first_half

        # border tuple follows the (left, top, right, bottom) order
        if width > height:
            border = (0, first_half, 0, second_half)
        else:
            border = (first_half, 0, second_half, 0)

        return ImageOps.expand(img, border, (255, 255, 255))
def Normalizer():
    """
    Build the normalization transform using the per-channel mean and std of
    the ImageNet pixels.

    Returns:
        A `transforms.Normalize` instance (in `get_transform`, it is placed
        right after `transforms.ToTensor()`).
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    return transforms.Normalize(mean=imagenet_mean, std=imagenet_std)
class MyTransforms(object):
    """
    Callable wrapper around a list of transformations, chained with
    `transforms.Compose`.
    """

    def __init__(self, trfs_list):
        # the composed pipeline is kept in `self.transform`
        self.transform = transforms.Compose(trfs_list)

    def __call__(self, x):
        # delegate to the composed pipeline and hand back its output
        return self.transform(x)