"""
Paper: Visual Attention Network
Link: https://arxiv.org/abs/2202.09741
"""
import torch
import torch.nn as nn


class LKA(nn.Module):
    """Large Kernel Attention: approximates a large-kernel (21x21 in the
    paper) depthwise convolution via a cheap three-stage decomposition."""

    def __init__(self, dim):
        super().__init__()
        # 5x5 depthwise conv for local spatial context.
        self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
        # 7x7 depthwise conv with dilation 3 for long-range spatial context.
        self.conv_spatial = nn.Conv2d(dim, dim, 7, stride=1, padding=9, groups=dim, dilation=3)
        # 1x1 conv for channel mixing.
        self.conv1 = nn.Conv2d(dim, dim, 1)

    def forward(self, x):
        u = x.clone()
        attn = self.conv0(x)
        attn = self.conv_spatial(attn)
        attn = self.conv1(attn)
        # Use the resulting map as an element-wise gate on the input.
        return u * attn

class VanAttention(nn.Module):
    """VAN attention block: 1x1 projection, GELU, LKA spatial gating,
    1x1 projection, and a residual connection."""

    def __init__(self, d_model):
        super().__init__()
        self.dim = d_model
        self.proj_1 = nn.Conv2d(d_model, d_model, 1)
        self.activation = nn.GELU()
        self.spatial_gating_unit = LKA(d_model)
        self.proj_2 = nn.Conv2d(d_model, d_model, 1)

    def forward(self, x, H=14, W=14):
        # x is a (B, N, C) token sequence; fold it into a (B, C, H, W) feature map.
        B, N, C = x.shape
        x = x.permute(0, 2, 1).reshape(B, C, H, W)
        shortcut = x.clone()
        x = self.proj_1(x)
        x = self.activation(x)
        x = self.spatial_gating_unit(x)
        x = self.proj_2(x)
        x = x + shortcut
        # Flatten back to a (B, N, C) token sequence.
        return x.reshape(B, C, N).permute(0, 2, 1)

if __name__ == '__main__':
    dim = 768
    H = W = 14
    model = VanAttention(d_model=dim)
    from utils import measure_flops_params, measure_throughput_cpu, measure_throughput_gpu
    x = torch.randn(1, H * W, dim)
    measure_flops_params(model, x)
    measure_throughput_cpu(model)
    measure_throughput_gpu(model)
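    # Minimal, dependency-free sanity check (an addition, not part of the
    # original benchmark script; assumes only torch): the block should
    # preserve the (B, N, C) token shape.
    y = model(x, H=H, W=W)
    assert y.shape == x.shape, f"expected {tuple(x.shape)}, got {tuple(y.shape)}"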