-
Notifications
You must be signed in to change notification settings - Fork 0
/
xca.py
66 lines (51 loc) · 2.01 KB
/
xca.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
"""
Paper: XCiT: Cross-Covariance Image Transformers
Link: https://arxiv.org/abs/2106.09681
"""
import torch
import torch.nn as nn
class XCA(nn.Module):
    """Cross-Covariance Attention (XCA) operation from XCiT.

    Channels (not tokens) are updated as a weighted sum of the value
    channels. Within each head, queries and keys are L2-normalised along the
    token axis. The resulting cross-covariance matrix ``Q^T K`` of size
    ``d_h x d_h`` is scaled by a learnable per-head temperature,
    softmax-normalised and used as the mixing weights. The attention map
    therefore has size ``d_h x d_h`` regardless of the number of tokens.

    Args:
        dim: Channel dimension C of the input tokens. Must be divisible by
            ``num_heads``.
        num_heads: Number of heads; each head covers ``dim // num_heads``
            channels.
        qkv_bias: Whether the fused Q/K/V linear projection has a bias.
        qk_scale: Accepted for signature compatibility but ignored. The
            attention logits are scaled by the learnable ``temperature``
            parameter instead.
        attn_drop: Dropout probability applied to the attention map.
        proj_drop: Dropout probability applied after the output projection.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``dim``.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        # Fail fast with a clear message instead of an opaque reshape error
        # (or ZeroDivisionError) on the first forward pass.
        if num_heads <= 0 or dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by a positive num_heads ({num_heads})"
            )
        self.dim = dim
        self.num_heads = num_heads
        # One learnable logit scale per head; shape (heads, 1, 1) broadcasts
        # over the (B, heads, d_h, d_h) attention logits.
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, H=14, W=14):
        """Apply cross-covariance attention to a batch of token sequences.

        Args:
            x: Tensor of shape (B, N, C) with ``C == dim``.
            H: Unused by this module; kept so existing callers still work.
            W: Unused by this module; kept so existing callers still work.

        Returns:
            Tensor of shape (B, N, C).
        """
        B, N, C = x.shape
        head_dim = C // self.num_heads

        # (B, N, 3*C) -> (B, N, 3, heads, d_h) -> (3, B, heads, d_h, N):
        # channels become the row axis of each head, tokens the column axis.
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, head_dim).permute(2, 0, 3, 4, 1)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        # L2-normalise every channel over the token axis, so that q @ k^T
        # holds cosine similarities between channels: (B, heads, d_h, d_h).
        q = nn.functional.normalize(q, dim=-1)
        k = nn.functional.normalize(k, dim=-1)

        attn = (q @ k.transpose(-2, -1)) * self.temperature
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # (B, heads, d_h, N) -> (B, N, heads, d_h) -> (B, N, C)
        x = (attn @ v).permute(0, 3, 1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the names of parameters that should be excluded from weight decay."""
        return {'temperature'}
if __name__ == '__main__':
    # Benchmark a single XCA layer with 768 channels and 12 heads on a
    # 14x14 grid of tokens.
    dim = 768
    num_heads = 12
    H = W = 14
    model = XCA(dim=dim, num_heads=num_heads, qkv_bias=True)

    # Project-local benchmarking helpers. They are imported inside the
    # __main__ guard, so importing this module does not require them.
    from utils import measure_flops_params, measure_throughput_cpu, measure_throughput_gpu

    # FLOPs/params are measured on a single sample (batch size 1) of H*W tokens.
    x = torch.randn(1, H * W, dim)
    measure_flops_params(model, x)
    measure_throughput_cpu(model)
    measure_throughput_gpu(model)