-
Notifications
You must be signed in to change notification settings - Fork 0
/
modeling_VPT.py
435 lines (362 loc) · 17 KB
/
modeling_VPT.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
from functools import partial
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
from timm.models.registry import register_model
import torch.utils.checkpoint as checkpoint
import math
from functools import partial, reduce
from operator import mul
def _cfg(url='', **kwargs):
    """Build a default model configuration dictionary.

    Args:
        url: Value stored under the ``'url'`` key.
        **kwargs: Extra entries. On a key collision they replace the default
            value; new keys are appended after the defaults.

    Returns:
        A freshly created ``dict`` holding the defaults merged with ``kwargs``.
    """
    config = {
        'url': url,
        'num_classes': 400,
        'input_size': (3, 224, 224),
        'pool_size': None,
        'crop_pct': .9,
        'interpolation': 'bicubic',
        'mean': (0.5, 0.5, 0.5),
        'std': (0.5, 0.5, 0.5),
    }
    config.update(kwargs)
    return config
class DropPath(nn.Module):
    """Stochastic-depth layer (drop path), applied per sample.

    Intended for the main path of residual blocks. The work is delegated to
    the imported ``drop_path`` function, which is given the stored
    probability and this module's ``training`` flag.
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        # Stored unchanged and handed to ``drop_path`` on every forward call.
        self.drop_prob = drop_prob

    def forward(self, x):
        """Return ``drop_path(x, self.drop_prob, self.training)``."""
        return drop_path(x, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        """Show the drop probability in the module's ``repr``."""
        return f'p={self.drop_prob}'
class Mlp(nn.Module):
    """Feed-forward block: linear -> activation -> linear -> dropout.

    Args:
        in_features: Size of the last input dimension.
        hidden_features: Width of the hidden layer; falls back to
            ``in_features`` when falsy.
        out_features: Size of the last output dimension; falls back to
            ``in_features`` when falsy.
        act_layer: Zero-argument factory for the activation module.
        drop: Dropout probability applied after the second linear layer.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features if out_features else in_features
        hidden_features = hidden_features if hidden_features else in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Apply the two linear layers with one dropout at the very end."""
        hidden = self.act(self.fc1(x))
        # No dropout between the activation and fc2: the source had it
        # commented out, with a note citing the original BERT implementation.
        return self.drop(self.fc2(hidden))
class Attention(nn.Module):
    """Multi-head self-attention with optional extra prompt keys/values.

    Queries, keys and values come from a single bias-free linear layer.
    When ``qkv_bias`` is true, separate learnable biases are kept for the
    queries and values only; the key bias is a constant zero vector.

    Args:
        dim: Size of the last input dimension.
        num_heads: Number of attention heads.
        qkv_bias: Whether to create the learnable query/value biases.
        qk_scale: Query scaling factor; when falsy, ``head_dim ** -0.5``
            is used.
        attn_drop: Dropout probability on the attention weights.
        proj_drop: Dropout probability after the output projection.
        attn_head_dim: Per-head width; defaults to ``dim // num_heads``.
    """

    def __init__(
            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
            proj_drop=0., attn_head_dim=None):
        super().__init__()
        self.num_heads = num_heads
        default_head_dim = dim // num_heads
        head_dim = default_head_dim if attn_head_dim is None else attn_head_dim
        all_head_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
        self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) if qkv_bias else None
        self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) if qkv_bias else None

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, p_k, p_v, scale):
        """Attend over ``x`` and, if given, additional prompt keys/values.

        Args:
            x: Tensor unpacked as ``(B, N, C)``.
            p_k: Prompt keys or ``None``; expanded to
                ``(B, num_heads, -1, -1)`` and appended to the keys along
                the token dimension.
            p_v: Prompt values or ``None``; handled like ``p_k`` for the
                values. Prompts are used only when both are given.
            scale: Optional prompt re-scaling factors; ``scale[0]`` is used
                for the keys and ``scale[1]`` for the values, each clamped
                from below at 1. Ignored unless both prompts are given.

        Returns:
            Tensor with the same leading ``(B, N)`` dimensions as ``x``.
        """
        B, N, C = x.shape

        if self.q_bias is None:
            qkv_bias = None
        else:
            # Keys get a fixed zero bias; only q and v biases are learned.
            zero_k_bias = torch.zeros_like(self.v_bias, requires_grad=False)
            qkv_bias = torch.cat((self.q_bias, zero_k_bias, self.v_bias))

        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
        # (B, N, 3, heads, head_dim) -> (3, B, heads, N, head_dim)
        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        has_prompts = p_k is not None and p_v is not None

        # Prompt re-parameterisation: per-token factors, never below 1.
        if has_prompts and scale is not None:
            scale_k = scale[0].view(-1, 1)
            scale_v = scale[1].view(-1, 1)
            p_k = p_k * torch.maximum(scale_k, torch.ones_like(scale_k))
            p_v = p_v * torch.maximum(scale_v, torch.ones_like(scale_v))

        # Append the prompts after the regular tokens.
        if has_prompts:
            prompt_k = p_k.expand(B, self.num_heads, -1, -1)
            prompt_v = p_v.expand(B, self.num_heads, -1, -1)
            k = torch.cat((k, prompt_k), dim=2)
            v = torch.cat((v, prompt_v), dim=2)

        q = q * self.scale
        weights = (q @ k.transpose(-2, -1)).softmax(dim=-1)
        weights = self.attn_drop(weights)

        out = (weights @ v).transpose(1, 2).reshape(B, N, -1)
        out = self.proj(out)
        return self.proj_drop(out)
class Block(nn.Module):
    """Pre-norm transformer block with prompt-aware attention.

    Each forward pass applies ``norm1 -> attention`` and ``norm2 -> MLP``,
    each wrapped in a residual connection with optional drop path and an
    optional learnable per-channel scale (``gamma_1`` / ``gamma_2``).

    Args:
        dim: Channel width of the block.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        qkv_bias: Forwarded to ``Attention``.
        qk_scale: Forwarded to ``Attention``.
        drop: Dropout probability for the attention projection and the MLP.
        attn_drop: Dropout probability on the attention weights.
        drop_path: Drop-path probability; values ``<= 0`` disable it.
        init_values: Initial value of the residual scales. ``None`` or a
            non-positive value disables the scales.
        act_layer: Activation factory for the MLP.
        norm_layer: Normalisation factory, called with ``dim``.
        attn_head_dim: Forwarded to ``Attention``.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 attn_head_dim=None):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        # The default ``init_values=None`` must not reach the ``> 0``
        # comparison: ordering ``None`` against a number raises TypeError.
        if init_values is not None and init_values > 0:
            self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
            self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
        else:
            self.gamma_1, self.gamma_2 = None, None

    def forward(self, x, p_k, p_v, scale):
        """Run the block.

        Args:
            x: Input tokens.
            p_k: Prompt keys passed through to the attention, or ``None``.
            p_v: Prompt values passed through to the attention, or ``None``.
            scale: Prompt scaling factors passed through to the attention,
                or ``None``.

        Returns:
            The tokens after the attention and MLP residual updates.
        """
        attn_out = self.attn(self.norm1(x), p_k, p_v, scale)
        if self.gamma_1 is not None:
            attn_out = self.gamma_1 * attn_out
        x = x + self.drop_path(attn_out)

        mlp_out = self.mlp(self.norm2(x))
        if self.gamma_2 is not None:
            mlp_out = self.gamma_2 * mlp_out
        x = x + self.drop_path(mlp_out)
        return x
class PatchEmbed(nn.Module):
    """Video to patch embedding using a non-overlapping 3D convolution.

    The convolution kernel and stride are both
    ``(tubelet_size, patch_height, patch_width)``.

    Args:
        img_size: Expected frame size; converted with ``to_2tuple``.
        patch_size: Spatial patch size; converted with ``to_2tuple``.
        in_chans: Number of input channels.
        embed_dim: Number of output channels of the projection.
        num_frames: Number of frames used to compute ``num_patches``.
        tubelet_size: Temporal extent of one patch (cast to ``int``).
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, num_frames=16, tubelet_size=2):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.tubelet_size = int(tubelet_size)

        patches_along_dim1 = img_size[1] // patch_size[1]
        patches_along_dim0 = img_size[0] // patch_size[0]
        temporal_patches = num_frames // self.tubelet_size

        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = patches_along_dim1 * patches_along_dim0 * temporal_patches

        tubelet_shape = (self.tubelet_size, patch_size[0], patch_size[1])
        self.proj = nn.Conv3d(
            in_channels=in_chans,
            out_channels=embed_dim,
            kernel_size=tubelet_shape,
            stride=tubelet_shape,
        )

    def forward(self, x, **kwargs):
        """Project a 5-D input to a sequence of patch embeddings.

        Args:
            x: Tensor unpacked as ``(B, C, T, H, W)``; ``H`` and ``W`` must
                equal the configured ``img_size``.
            **kwargs: Accepted and ignored.

        Returns:
            The projection output with all dimensions from index 2 onward
            flattened and then swapped with dimension 1.
        """
        _, _, _, H, W = x.shape
        # The spatial size is fixed at construction time.
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        projected = self.proj(x)
        return projected.flatten(2).transpose(1, 2)
# Sine-cosine position encoding, after
# https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/Models.py#L31
def get_sinusoid_encoding_table(n_position, d_hid):
    """Build a fixed sinusoidal position-encoding table.

    Column ``j`` of row ``pos`` starts as
    ``pos / 10000 ** (2 * (j // 2) / d_hid)``; even columns are then
    replaced by their sine and odd columns by their cosine.

    Args:
        n_position: Number of positions (rows).
        d_hid: Encoding width (columns).

    Returns:
        A float tensor of shape ``(1, n_position, d_hid)`` that does not
        require gradients.
    """
    # TODO: make it with torch instead of numpy
    # The denominators depend only on the column, so compute them once.
    denominators = [np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
    angles = np.array(
        [[pos_i / denom for denom in denominators] for pos_i in range(n_position)]
    )

    even_cols = angles[:, 0::2]
    odd_cols = angles[:, 1::2]
    angles[:, 0::2] = np.sin(even_cols)  # dim 2i
    angles[:, 1::2] = np.cos(odd_cols)  # dim 2i+1

    table = torch.tensor(angles, dtype=torch.float, requires_grad=False)
    return table.unsqueeze(0)
class PromptedVisionTransformer(nn.Module):
    """Video Vision Transformer with optional key/value prompt tuning.

    The input is embedded by ``PatchEmbed``, combined with a position
    embedding and passed through ``depth`` transformer blocks. When
    ``transfer_type == 'prompt'``, the blocks with index in
    ``[prompt_start, prompt_end]`` also receive learnable prompt keys and
    values (and optional prompt scales). For every other transfer type no
    block receives prompts.

    Args:
        img_size, patch_size, in_chans, embed_dim: Forwarded to ``PatchEmbed``.
        num_classes: Output size of the head; ``<= 0`` uses ``nn.Identity``.
        depth: Number of transformer blocks.
        num_heads, mlp_ratio, qkv_bias, qk_scale: Forwarded to each ``Block``.
        fc_drop_rate: Dropout probability before the head (``> 0`` enables it).
        drop_rate: Dropout after the position embedding and inside blocks.
        attn_drop_rate: Attention dropout inside each block.
        drop_path_rate: Final drop-path rate; per-block rates are linearly
            spaced from 0 to this value.
        norm_layer: Normalisation factory.
        init_values: Initial residual scale forwarded to each ``Block``.
        use_learnable_pos_emb: Use a parameter instead of the sinusoid table.
        init_scale: Factor multiplied into the head's weight and bias.
        all_frames: Number of input frames (``num_frames`` of ``PatchEmbed``).
        tubelet_size: Temporal patch size.
        prompt_num_tokens: Number of prompt tokens per prompted block.
        prompt_start, prompt_end: Inclusive range of prompted block indices.
        prompt_dropout: Dropout probability applied to the prompts.
        prompt_reparam: Create learnable per-token prompt scales.
        prompt_init: Prompt initialisation scheme; only ``"random"``.
        transfer_type: ``'prompt'`` enables prompts; any other value
            (e.g. ``'end2end'``, ``'linear'``) runs without them.
        use_checkpoint: Run each block through gradient checkpointing.
        use_mean_pooling: Mean-pool tokens and apply ``head_norm``; otherwise
            normalise all tokens and return token 0.

    Raises:
        ValueError: If ``transfer_type == 'prompt'`` and ``prompt_init`` is
            not ``"random"``.
    """

    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 in_chans=3,
                 num_classes=1000,
                 embed_dim=768,
                 depth=12,
                 num_heads=12,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 fc_drop_rate=0.,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 norm_layer=nn.LayerNorm,
                 init_values=0.,
                 use_learnable_pos_emb=False,
                 init_scale=0.,
                 all_frames=16,
                 tubelet_size=2,
                 prompt_num_tokens=5,
                 prompt_start=0,
                 prompt_end=11,
                 prompt_dropout=0.0,
                 prompt_reparam=False,
                 prompt_init="random",
                 transfer_type='end2end',
                 use_checkpoint=False,
                 use_mean_pooling=True):
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.tubelet_size = tubelet_size
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            num_frames=all_frames, tubelet_size=self.tubelet_size)
        num_patches = self.patch_embed.num_patches
        self.use_checkpoint = use_checkpoint
        self.transfer_type = transfer_type

        if use_learnable_pos_emb:
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
        else:
            # Fixed sine-cosine table, kept as a plain tensor attribute.
            self.pos_embed = get_sinusoid_encoding_table(num_patches, embed_dim)

        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                init_values=init_values)
            for i in range(depth)])
        self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
        self.head_norm = norm_layer(embed_dim) if use_mean_pooling else None
        self.head_dropout = nn.Dropout(p=fc_drop_rate) if fc_drop_rate > 0 else nn.Identity()
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        if use_learnable_pos_emb:
            trunc_normal_(self.pos_embed, std=.02)

        # With num_classes <= 0 the head is nn.Identity and has no
        # weight/bias, so head initialisation must be skipped.
        head_is_linear = isinstance(self.head, nn.Linear)
        if head_is_linear:
            trunc_normal_(self.head.weight, std=.02)
        self.apply(self._init_weights)
        if head_is_linear:
            self.head.weight.data.mul_(init_scale)
            self.head.bias.data.mul_(init_scale)

        if transfer_type == 'prompt':
            # Prompt Tuning
            self.num_tokens = prompt_num_tokens
            self.prompt_start = prompt_start
            self.prompt_end = prompt_end
            self.prompt_init = prompt_init
            self.prompt_dropout = nn.Dropout(prompt_dropout)
            self.prompt_reparam = prompt_reparam
            assert self.prompt_start >= 0 and self.prompt_start < depth, "prompt start should be in range [0, depth-1]"
            assert self.prompt_end >= 0 and self.prompt_end < depth, "prompt end should be in range [0, depth-1]"
            assert self.prompt_end >= self.prompt_start, "prompt end should be larger than (or equal) prompt start"

            if self.prompt_init != "random":
                raise ValueError("Other initiation scheme is not supported")

            num_prompted_layers = self.prompt_end - self.prompt_start + 1
            head_dim = self.embed_dim // num_heads
            val = math.sqrt(6. / float(3 * reduce(mul, self.patch_embed.patch_size, 1) + head_dim))  # noqa

            # One (num_tokens, head_dim) prompt per prompted layer.
            self.prompt_embeddings_k = nn.Parameter(torch.zeros(
                num_prompted_layers, self.num_tokens, head_dim))
            self.prompt_embeddings_v = nn.Parameter(torch.zeros(
                num_prompted_layers, self.num_tokens, head_dim))

            if self.prompt_reparam:
                # Index 0 of the middle dimension scales keys, index 1 values.
                self.prompt_scale = nn.Parameter(torch.ones(
                    num_prompted_layers, 2, self.num_tokens))
                print("Using reparameterization trick for prompt tuning")
            else:
                self.prompt_scale = None

            nn.init.uniform_(self.prompt_embeddings_k.data, -val, val)
            nn.init.uniform_(self.prompt_embeddings_v.data, -val, val)
        else:
            # 'linear', 'end2end' and any other mode use no prompts. An
            # empty index range makes forward_features skip the prompt
            # branch instead of failing on missing attributes.
            self.prompt_start = depth
            self.prompt_end = depth - 1
            self.prompt_scale = None

    def _init_weights(self, m):
        """Initialise ``nn.Linear`` and ``nn.LayerNorm`` submodules."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def get_num_layers(self):
        """Return the number of transformer blocks."""
        return len(self.blocks)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return parameter names to exclude from weight decay."""
        return {'pos_embed', 'cls_token'}

    def get_classifier(self):
        """Return the classification head module."""
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        """Replace the head with a new one for ``num_classes`` outputs.

        ``global_pool`` is accepted but not used.
        """
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        """Compute the pooled feature vector for an input batch.

        Args:
            x: Input passed to ``self.patch_embed``.

        Returns:
            ``head_norm`` of the token mean when mean pooling is enabled,
            otherwise token 0 of the normalised sequence.
        """
        x = self.patch_embed(x)
        batch_size = x.size(0)
        if self.pos_embed is not None:
            # NOTE(review): the embedding is detached here even when it is
            # a learnable parameter, so it receives no gradient — confirm
            # this is intended.
            x = x + self.pos_embed.expand(batch_size, -1, -1).type_as(x).to(x.device).clone().detach()
        x = self.pos_drop(x)

        for i, blk in enumerate(self.blocks):
            if self.prompt_start <= i <= self.prompt_end:
                layer_idx = i - self.prompt_start
                p_k = self.prompt_dropout(self.prompt_embeddings_k[layer_idx])
                p_v = self.prompt_dropout(self.prompt_embeddings_v[layer_idx])
                scale = self.prompt_scale[layer_idx] if self.prompt_scale is not None else None
            else:
                p_k = p_v = scale = None

            if self.use_checkpoint:
                # checkpoint() needs the callable plus its arguments; it
                # invokes the block itself during forward and recomputation.
                x = checkpoint.checkpoint(blk, x, p_k, p_v, scale)
            else:
                x = blk(x, p_k, p_v, scale)

        x = self.norm(x)
        if self.head_norm is not None:
            return self.head_norm(x.mean(1))
        return x[:, 0]

    def forward(self, x):
        """Return the head output for input ``x``."""
        x = self.forward_features(x)
        x = self.head(self.head_dropout(x))
        return x
@register_model
def vit_small_patch16_224(pretrained=False, **kwargs):
    """Build the patch-16 model with embed_dim=384, depth=12 and 6 heads.

    ``pretrained`` is accepted but not used here. Extra keyword arguments
    are forwarded to ``PromptedVisionTransformer``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    model = PromptedVisionTransformer(
        patch_size=16,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        **kwargs,
    )
    model.default_cfg = _cfg()
    return model
@register_model
def vit_base_patch16_224(pretrained=False, **kwargs):
    """Build the patch-16 model with embed_dim=768, depth=12 and 12 heads.

    ``pretrained`` is accepted but not used here. Extra keyword arguments
    are forwarded to ``PromptedVisionTransformer``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    model = PromptedVisionTransformer(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        **kwargs,
    )
    model.default_cfg = _cfg()
    return model
@register_model
def vit_base_patch16_384(pretrained=False, **kwargs):
    """Build the 384-pixel patch-16 model with embed_dim=768, depth=12, 12 heads.

    ``pretrained`` is accepted but not used here. Extra keyword arguments
    are forwarded to ``PromptedVisionTransformer``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    model = PromptedVisionTransformer(
        img_size=384,
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        **kwargs,
    )
    model.default_cfg = _cfg()
    return model
@register_model
def vit_large_patch16_224(pretrained=False, **kwargs):
    """Build the patch-16 model with embed_dim=1024, depth=24 and 16 heads.

    ``pretrained`` is accepted but not used here. Extra keyword arguments
    are forwarded to ``PromptedVisionTransformer``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    model = PromptedVisionTransformer(
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        **kwargs,
    )
    model.default_cfg = _cfg()
    return model
@register_model
def vit_large_patch16_384(pretrained=False, **kwargs):
    """Build the 384-pixel patch-16 model with embed_dim=1024, depth=24, 16 heads.

    ``pretrained`` is accepted but not used here. Extra keyword arguments
    are forwarded to ``PromptedVisionTransformer``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    model = PromptedVisionTransformer(
        img_size=384,
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        **kwargs,
    )
    model.default_cfg = _cfg()
    return model
@register_model
def vit_large_patch16_512(pretrained=False, **kwargs):
    """Build the 512-pixel patch-16 model with embed_dim=1024, depth=24, 16 heads.

    ``pretrained`` is accepted but not used here. Extra keyword arguments
    are forwarded to ``PromptedVisionTransformer``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    model = PromptedVisionTransformer(
        img_size=512,
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        **kwargs,
    )
    model.default_cfg = _cfg()
    return model
@register_model
def vit_huge_patch16_224(pretrained=False, **kwargs):
    """Build the patch-16 model with embed_dim=1280, depth=32 and 16 heads.

    ``pretrained`` is accepted but not used here. Extra keyword arguments
    are forwarded to ``PromptedVisionTransformer``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    model = PromptedVisionTransformer(
        patch_size=16,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        **kwargs,
    )
    model.default_cfg = _cfg()
    return model