-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvig.py
219 lines (188 loc) · 8.74 KB
/
vig.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
# 2022.10.31-Changed for building ViG model
# Huawei Technologies Co., Ltd. <foss@huawei.com>
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Sequential as Seq
from gcn_lib import Grapher, act_layer
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.helpers import load_pretrained
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from timm.models.registry import register_model
def _cfg(url='', **kwargs):
    """Return a default-config dict for a model entry.

    The dict starts from ImageNet-style defaults (1000 classes, 3x224x224
    input, bicubic interpolation, crop ratio 0.9) and is then updated with
    ``kwargs``, so any keyword either overrides a default or adds a new key.

    Args:
        url: Value stored under the ``'url'`` key.
        **kwargs: Extra or overriding config entries.

    Returns:
        A new ``dict`` holding the merged configuration.
    """
    cfg = dict(
        url=url,
        num_classes=1000,
        input_size=(3, 224, 224),
        pool_size=None,
        crop_pct=.9,
        interpolation='bicubic',
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        first_conv='patch_embed.proj',
        classifier='head',
    )
    # Caller-supplied entries win over the defaults above.
    cfg.update(kwargs)
    return cfg
# Default configs keyed by name; the model factories in this file attach the
# 'gnn_patch16_224' entry to the models they build.
default_cfgs = {
    'gnn_patch16_224': _cfg(
        input_size=(3, 224, 224),
        crop_pct=0.9,
        # Overrides the ImageNet mean/std defaults supplied by _cfg.
        mean=(0.5, 0.5, 0.5),
        std=(0.5, 0.5, 0.5),
    ),
}
class FFN(nn.Module):
    """Residual feed-forward block built from 1x1 convolutions.

    Structure: ``Conv2d(1x1) -> BatchNorm2d -> activation -> Conv2d(1x1) ->
    BatchNorm2d``, passed through ``drop_path`` and added to the block input.
    Because of the residual sum, the output must be shape-compatible with the
    input (the defaults make the channel counts equal).

    Args:
        in_features: Number of input channels.
        hidden_features: Channels between the two convolutions; falls back to
            ``in_features`` when falsy.
        out_features: Number of output channels; falls back to
            ``in_features`` when falsy.
        act: Activation name forwarded to ``act_layer``.
        drop_path: Rate handed to ``DropPath``; values ``<= 0`` use
            ``nn.Identity`` instead.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act='relu', drop_path=0.0):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features

        conv_in = nn.Conv2d(in_features, hidden_features, 1, stride=1, padding=0)
        self.fc1 = nn.Sequential(conv_in, nn.BatchNorm2d(hidden_features))

        self.act = act_layer(act)

        conv_out = nn.Conv2d(hidden_features, out_features, 1, stride=1, padding=0)
        self.fc2 = nn.Sequential(conv_out, nn.BatchNorm2d(out_features))

        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x):
        """Apply the feed-forward branch and add the input back as a residual."""
        branch = self.fc2(self.act(self.fc1(x)))
        return self.drop_path(branch) + x
class Stem(nn.Module):
    """ Image to Visual Word Embedding
    Overlap: https://arxiv.org/pdf/2106.13797.pdf

    Four 3x3 stride-2 convolutions (each followed by batch norm and an
    activation) widen the channels through ``out_dim//8``, ``out_dim//4``,
    ``out_dim//2`` and ``out_dim``; a final 3x3 stride-1 convolution plus
    batch norm (no activation) closes the stack.

    Args:
        img_size: Accepted for interface compatibility; not read by this class.
        in_dim: Number of input image channels.
        out_dim: Number of output embedding channels.
        act: Activation name forwarded to ``act_layer``.
    """

    def __init__(self, img_size=224, in_dim=3, out_dim=768, act='relu'):
        super().__init__()
        stage_widths = (out_dim // 8, out_dim // 4, out_dim // 2, out_dim)

        layers = []
        prev_width = in_dim
        for width in stage_widths:
            # Downsampling stage: conv (stride 2) -> batch norm -> activation.
            layers.append(nn.Conv2d(prev_width, width, 3, stride=2, padding=1))
            layers.append(nn.BatchNorm2d(width))
            layers.append(act_layer(act))
            prev_width = width

        # Resolution-preserving output projection, left without an activation.
        layers.append(nn.Conv2d(out_dim, out_dim, 3, stride=1, padding=1))
        layers.append(nn.BatchNorm2d(out_dim))

        self.convs = nn.Sequential(*layers)

    def forward(self, x):
        """Run the convolutional stack on ``x`` and return the embedding map."""
        return self.convs(x)
class DeepGCN(torch.nn.Module):
    """ViG classifier: stem, a stack of (Grapher, FFN) blocks, and a conv head.

    The stem output is summed with a learnable positional embedding of shape
    ``(1, n_filters, 14, 14)``, run through ``n_blocks`` blocks, globally
    average-pooled and classified by a 1x1-convolution head.

    Args:
        opt: Object exposing the hyper-parameters read below: ``n_filters``,
            ``k``, ``act``, ``norm``, ``bias``, ``epsilon``,
            ``use_stochastic``, ``conv``, ``n_blocks``, ``drop_path``,
            ``use_dilation``, ``dropout`` and ``n_classes``.
    """

    def __init__(self, opt):
        super().__init__()
        channels = opt.n_filters
        k = opt.k
        act = opt.act
        norm = opt.norm
        bias = opt.bias
        epsilon = opt.epsilon
        stochastic = opt.use_stochastic
        conv = opt.conv
        self.n_blocks = opt.n_blocks
        drop_path = opt.drop_path

        # Side length of the feature map the positional embedding is sized
        # for; the dilation cap below uses the matching cell count (196).
        grid_size = 14
        num_nodes = grid_size * grid_size

        self.stem = Stem(out_dim=channels, act=act)

        dpr = [x.item() for x in torch.linspace(0, drop_path, self.n_blocks)]  # stochastic depth decay rule
        print('dpr', dpr)
        num_knn = [int(x.item()) for x in torch.linspace(k, 2 * k, self.n_blocks)]  # number of knn's k
        print('num_knn', num_knn)
        max_dilation = num_nodes // max(num_knn)

        self.pos_embed = nn.Parameter(torch.zeros(1, channels, grid_size, grid_size))

        # Both dilation modes build the same blocks; only the dilation value
        # handed to Grapher differs, so it is computed per block here.
        use_dilation = opt.use_dilation
        blocks = []
        for i in range(self.n_blocks):
            dilation = min(i // 4 + 1, max_dilation) if use_dilation else 1
            blocks.append(Seq(
                # NOTE(review): the positional literal 1 is a Grapher argument
                # whose meaning is not visible in this file - confirm in gcn_lib.
                Grapher(channels, num_knn[i], dilation, conv, act, norm,
                        bias, stochastic, epsilon, 1, drop_path=dpr[i]),
                FFN(channels, channels * 4, act=act, drop_path=dpr[i]),
            ))
        self.backbone = Seq(*blocks)

        self.prediction = Seq(nn.Conv2d(channels, 1024, 1, bias=True),
                              nn.BatchNorm2d(1024),
                              act_layer(act),
                              nn.Dropout(opt.dropout),
                              nn.Conv2d(1024, opt.n_classes, 1, bias=True))
        self.model_init()

    def model_init(self):
        """Re-initialise every ``Conv2d``: Kaiming-normal weights, zero biases.

        Also marks the touched weights and biases as trainable.
        """
        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
                m.weight.requires_grad = True
                if m.bias is not None:
                    m.bias.data.zero_()
                    m.bias.requires_grad = True

    def forward(self, inputs):
        """Classify a batch of images.

        Args:
            inputs: Image tensor fed to the stem. The stem output must be
                broadcast-compatible with ``pos_embed`` (14x14 spatially),
                which for the four stride-2 stem convolutions corresponds to
                224x224 inputs.

        Returns:
            The head output with its two trailing singleton spatial
            dimensions squeezed away.
        """
        x = self.stem(inputs) + self.pos_embed
        for block in self.backbone:
            x = block(x)
        x = F.adaptive_avg_pool2d(x, 1)
        return self.prediction(x).squeeze(-1).squeeze(-1)
@register_model
def vig_ti_224_gelu(pretrained=False, **kwargs):
    """Build the ViG model configured with 12 blocks, 192 channels and GELU.

    Args:
        pretrained: Accepted but never read in this function; no weights are
            loaded here.
        **kwargs: ``num_classes``, ``drop_path_rate``, ``drop_rate`` and
            ``num_knn`` are consumed; any other keyword is silently ignored.

    Returns:
        A ``DeepGCN`` instance whose ``default_cfg`` is
        ``default_cfgs['gnn_patch16_224']``.
    """

    class OptInit:
        """Plain attribute holder for the hyper-parameters DeepGCN reads."""

        def __init__(self, num_classes=1000, drop_path_rate=0.0, drop_rate=0.0, num_knn=9, **_ignored):
            # Backbone size.
            self.n_blocks = 12  # number of basic blocks in the backbone
            self.n_filters = 192  # number of channels of deep features

            # Graph construction.
            self.k = num_knn  # neighbor num (default:9)
            self.conv = 'mr'  # graph conv layer {edge, mr}
            self.use_dilation = True  # use dilated knn or not
            self.use_stochastic = False  # stochastic for gcn, True or False
            self.epsilon = 0.2  # stochastic epsilon for gcn

            # Layer options.
            self.act = 'gelu'  # activation layer {relu, prelu, leakyrelu, gelu, hswish}
            self.norm = 'batch'  # batch or instance normalization {batch, instance}
            self.bias = True  # bias of conv layer True or False

            # Head and regularisation.
            self.n_classes = num_classes  # Dimension of out_channels
            self.dropout = drop_rate  # dropout rate
            self.drop_path = drop_path_rate

    net = DeepGCN(OptInit(**kwargs))
    net.default_cfg = default_cfgs['gnn_patch16_224']
    return net
@register_model
def vig_s_224_gelu(pretrained=False, **kwargs):
    """Build the ViG model configured with 16 blocks, 320 channels and GELU.

    Args:
        pretrained: Accepted but never read in this function; no weights are
            loaded here.
        **kwargs: ``num_classes``, ``drop_path_rate``, ``drop_rate`` and
            ``num_knn`` are consumed; any other keyword is silently ignored.

    Returns:
        A ``DeepGCN`` instance whose ``default_cfg`` is
        ``default_cfgs['gnn_patch16_224']``.
    """

    class OptInit:
        """Plain attribute holder for the hyper-parameters DeepGCN reads."""

        def __init__(self, num_classes=1000, drop_path_rate=0.0, drop_rate=0.0, num_knn=9, **_ignored):
            # Backbone size.
            self.n_blocks = 16  # number of basic blocks in the backbone
            self.n_filters = 320  # number of channels of deep features

            # Graph construction.
            self.k = num_knn  # neighbor num (default:9)
            self.conv = 'mr'  # graph conv layer {edge, mr}
            self.use_dilation = True  # use dilated knn or not
            self.use_stochastic = False  # stochastic for gcn, True or False
            self.epsilon = 0.2  # stochastic epsilon for gcn

            # Layer options.
            self.act = 'gelu'  # activation layer {relu, prelu, leakyrelu, gelu, hswish}
            self.norm = 'batch'  # batch or instance normalization {batch, instance}
            self.bias = True  # bias of conv layer True or False

            # Head and regularisation.
            self.n_classes = num_classes  # Dimension of out_channels
            self.dropout = drop_rate  # dropout rate
            self.drop_path = drop_path_rate

    net = DeepGCN(OptInit(**kwargs))
    net.default_cfg = default_cfgs['gnn_patch16_224']
    return net
@register_model
def vig_b_224_gelu(pretrained=False, **kwargs):
    """Build the ViG model configured with 16 blocks, 640 channels and GELU.

    Args:
        pretrained: Accepted but never read in this function; no weights are
            loaded here.
        **kwargs: ``num_classes``, ``drop_path_rate``, ``drop_rate`` and
            ``num_knn`` are consumed; any other keyword is silently ignored.

    Returns:
        A ``DeepGCN`` instance whose ``default_cfg`` is
        ``default_cfgs['gnn_patch16_224']``.
    """

    class OptInit:
        """Plain attribute holder for the hyper-parameters DeepGCN reads."""

        def __init__(self, num_classes=1000, drop_path_rate=0.0, drop_rate=0.0, num_knn=9, **_ignored):
            # Backbone size.
            self.n_blocks = 16  # number of basic blocks in the backbone
            self.n_filters = 640  # number of channels of deep features

            # Graph construction.
            self.k = num_knn  # neighbor num (default:9)
            self.conv = 'mr'  # graph conv layer {edge, mr}
            self.use_dilation = True  # use dilated knn or not
            self.use_stochastic = False  # stochastic for gcn, True or False
            self.epsilon = 0.2  # stochastic epsilon for gcn

            # Layer options.
            self.act = 'gelu'  # activation layer {relu, prelu, leakyrelu, gelu, hswish}
            self.norm = 'batch'  # batch or instance normalization {batch, instance}
            self.bias = True  # bias of conv layer True or False

            # Head and regularisation.
            self.n_classes = num_classes  # Dimension of out_channels
            self.dropout = drop_rate  # dropout rate
            self.drop_path = drop_path_rate

    net = DeepGCN(OptInit(**kwargs))
    net.default_cfg = default_cfgs['gnn_patch16_224']
    return net