-
Notifications
You must be signed in to change notification settings - Fork 12
/
augmentations.py
189 lines (160 loc) · 6.5 KB
/
augmentations.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
import operator
import random

import torch
from torchvision import transforms

from histaugan.model import MD_multi
# ------------
# different types of augmentations used in the paper
# ------------
class RandomRotate90:
    """Rotate an image by an angle picked at random from a fixed set.

    :angles: sequence of candidate angles in degrees (e.g. [0, 90, 180, 270]);
        a new angle is drawn with ``random.choice`` on every call
    """

    def __init__(self, angles):
        # Stored as given (no copy), so the caller's sequence is the candidate set.
        self.angles = angles

    def __call__(self, x):
        """Return ``x`` rotated by one angle chosen from ``self.angles``."""
        return transforms.functional.rotate(x, random.choice(self.angles))
def _make_pipeline(*, blur_and_jitter=False, hue_saturation=None):
    """Build one training-time augmentation pipeline.

    Every pipeline starts with random horizontal/vertical flips and a random
    multiple-of-90-degree rotation, and ends with normalization followed by
    random erasing. Fresh transform instances are created on every call.

    :blur_and_jitter: if True, insert a Gaussian blur (applied with p=0.25) and
        a brightness/contrast jitter (applied with p=0.5) after the geometric
        transforms
    :hue_saturation: if not None, strength used for both the saturation and the
        hue jitter (applied with p=0.5), inserted after the brightness/contrast
        jitter
    """
    steps = [
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        RandomRotate90([0, 90, 180, 270]),
    ]
    if blur_and_jitter:
        steps.append(
            transforms.RandomApply((transforms.GaussianBlur(3), ), p=0.25))
        steps.append(transforms.RandomApply(
            (transforms.ColorJitter(brightness=0.1, contrast=0.1), ), p=0.5))
    if hue_saturation is not None:
        steps.append(transforms.RandomApply(
            (transforms.ColorJitter(saturation=hue_saturation,
                                    hue=hue_saturation), ),
            p=0.5))
    # (x - 0.5) / 0.5 maps inputs in [0, 1] (assumed) to [-1, 1]; erasing is
    # applied afterwards, on the normalized tensor.
    steps.extend([
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        transforms.RandomErasing(),
    ])
    return transforms.Compose(steps)


# geometric augmentations (random flips + 90-degree rotations) + random erasing
geom_augmentations = _make_pipeline()
# geometric augmentations + brightness/contrast jitter + Gaussian blur + random erasing
basic_augmentations = _make_pipeline(blur_and_jitter=True)
# same as geometric augmentations (the stain-color change itself presumably
# comes from generate_hist_augs)
gan_augmentations = _make_pipeline()
# basic augmentations + hue/saturation jitter
color_augmentations = _make_pipeline(blur_and_jitter=True, hue_saturation=0.5)
# basic augmentations + light hue/saturation jitter
color_augmentations_light = _make_pipeline(blur_and_jitter=True,
                                           hue_saturation=0.1)
# normalization only, no random transforms
no_augmentations = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
def normalization(center):
    """Return the normalization statistics recorded for one center.

    :center: int in range(5) selecting the center
    :returns: tuple ``(mean, std)`` of two 3-element lists (presumably
        per-channel RGB statistics of that center's images — confirm)
    :raises ValueError: if ``center`` is not in range(5)
    """
    # Explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``, which would let e.g. center=-1 silently return the
    # statistics of center 4 via negative list indexing.
    if center not in range(5):
        raise ValueError('center not valid, should be in range(5)')
    mean = [
        [0.6710, 0.5327, 0.6448],
        [0.6475, 0.5139, 0.6222],
        [0.7875, 0.6251, 0.7567],
        [0.4120, 0.3270, 0.3959],
        [0.7324, 0.5814, 0.7038],
    ]
    std = [
        [0.2083, 0.2294, 0.1771],
        [0.2060, 0.2261, 0.1754],
        [0.2585, 0.2679, 0.2269],
        [0.2605, 0.2414, 0.2394],
        [0.2269, 0.2450, 0.1950],
    ]
    return mean[center], std[center]
# Options for the HistAuGAN model: fixed values standing in for the
# command-line arguments used during training (nothing is parsed here).
class Args:
    """Plain namespace of HistAuGAN options, read through the ``opts`` instance."""
    concat = 1  # NOTE(review): semantics defined by the model code — confirm
    crop_size = 216  # only used as an argument for training
    dis_norm = None  # discriminator option; presumably the normalization type — confirm
    dis_scale = 3  # discriminator option; presumably the number of scales — confirm
    dis_spectral_norm = False  # discriminator option; presumably toggles spectral norm
    dataroot = 'data'  # presumably the root directory of the training data
    gpu = 1  # NOTE(review): GPU selection as interpreted by the model code — confirm
    input_dim = 3  # presumably the number of image channels
    nThreads = 4  # presumably the number of data-loading workers
    num_domains = 5  # number of domains; this file elsewhere uses 5-way one-hot domain vectors
    nz = 8  # presumably the attribute-code size; matches the 8-element latent stats in this file
    # NOTE(review): an alternative setting was ``resume = False``; presumably a
    # falsy value means "do not load a checkpoint" — confirm in the model code.
    resume = 'gan_weights.pth'  # presumably the path of the pretrained HistAuGAN weights
opts = Args()
# Per-domain statistics of HistAuGAN's 8-dimensional attribute latent code:
# one float32 tensor of shape (8,) per domain (5 domains). Presumably meant to
# be passed to generate_hist_augs as ``stats=(mean_domains, std_domains)``.
mean_domains = [torch.tensor(row) for row in (
    [0.3020, -2.6476, -0.9849, -0.7820, -0.2746, 0.3361, 0.1694, -1.2148],
    [0.1453, -1.2400, -0.9484, 0.9697, -2.0775, 0.7676, -0.5224, -0.2945],
    [2.1067, -1.8572, 0.0055, 1.2214, -2.9363, 2.0249, -0.4593, -0.9771],
    [0.8378, -2.1174, -0.6531, 0.2986, -1.3629, -0.1237, -0.3486, -1.0716],
    [1.6073, 1.9633, -0.3130, -1.9242, -0.9673, 2.4990, -2.2023, -1.4109],
)]
std_domains = [torch.tensor(row) for row in (
    [0.6550, 1.5427, 0.5444, 0.7254, 0.6701, 1.0214, 0.6245, 0.6886],
    [0.4143, 0.6543, 0.5891, 0.4592, 0.8944, 0.7046, 0.4441, 0.3668],
    [0.5576, 0.7634, 0.7875, 0.5220, 0.7943, 0.8918, 0.6000, 0.5018],
    [0.4157, 0.4104, 0.5158, 0.3498, 0.2365, 0.3612, 0.3375, 0.4214],
    [0.6154, 0.3440, 0.7032, 0.6220, 0.4496, 0.6488, 0.4886, 0.2989],
)]
def _as_domain_index(domain, num_domains=5):
    """Return ``domain`` as a Python int if it is an integer-like value in
    ``range(num_domains)``, otherwise ``None``.

    Integer-likes are anything accepted by ``operator.index``: Python ints,
    NumPy integer scalars, single-element integer tensors. ``None``, floats and
    multi-element tensors (such as a (1, 5) one-hot domain vector) give ``None``.
    """
    try:
        index = operator.index(domain)
    except TypeError:
        return None
    return index if 0 <= index < num_domains else None


def generate_hist_augs(img, img_domain, model, z_content=None, same_attribute=False, new_domain=None, stats=None, device=torch.device('cpu')):
    """
    Generates a new stain color for the input image img.
    :img: input image of shape (3, 216, 216) [type: torch.Tensor]; it is mapped
        with (img - 0.5) * 2 before encoding, so presumably in range [0, 1]
    :img_domain: int in range(5)
    :model: HistAuGAN model (uses model.enc_c, model.enc_a.forward, model.gen)
    :z_content: content encoding, if None this will be computed from img
    :same_attribute: [type: bool] indicates whether the attribute encoding of img or a randomly generated attribute are used
    :new_domain: either an integer-like in range(5) (int, NumPy integer,
        single-element integer tensor) or torch.Tensor of shape (1, 5);
        anything else selects a random target domain
    :stats: (mean, std dev) of the latent space of HistAuGAN, each indexable by
        domain; only used when same_attribute is False and new_domain is an
        integer-like in range(5)
    :device: torch.device to map the tensors to
    :returns: generated image with the batch dimension removed, detached from
        the autograd graph
    """
    # Resolve the target domain once so the attribute sampling and the domain
    # vector below always agree (and tensors never hit `in range(5)`, which
    # raises for multi-element tensors).
    domain_index = _as_domain_index(new_domain)

    # compute content vector
    if z_content is None:
        z_content = model.enc_c(img.sub(0.5).mul(2).unsqueeze(0))

    # compute attribute
    if same_attribute:
        # Reparameterization: sample from the posterior N(mu, exp(logvar)) of img.
        mu, logvar = model.enc_a.forward(
            img.sub(0.5).mul(2).unsqueeze(0),
            torch.eye(5)[img_domain].unsqueeze(0).to(device))
        std = logvar.mul(0.5).exp_().to(device)
        eps = torch.randn((std.size(0), std.size(1))).to(device)
        z_attr = eps.mul(std).add_(mu)
    elif stats is not None and domain_index is not None:
        # Sample from the target domain's latent distribution.
        z_attr = (torch.randn((1, 8)) * stats[1][domain_index]
                  + stats[0][domain_index]).to(device)
    else:
        z_attr = torch.randn((1, 8)).to(device)

    # determine new domain vector
    if domain_index is not None:
        new_domain = torch.eye(5)[domain_index].unsqueeze(0).to(device)
    elif isinstance(new_domain, torch.Tensor) and new_domain.shape == (1, 5):
        new_domain = new_domain.to(device)
    else:
        # torch RNG (not NumPy, which this module does not import) so that
        # torch.manual_seed controls every random draw in this function.
        random_index = int(torch.randint(5, (1,)).item())
        new_domain = torch.eye(5)[random_index].unsqueeze(0).to(device)

    # generate new histology image with same content as img
    out = model.gen(z_content, z_attr, new_domain).detach().squeeze(0)  # in range [-1, 1]
    return out