LoRa_model.py

import math

import torch
import torch.nn as nn


class LowRankAttention(nn.Module):
    def __init__(self, dim, rank):
        super(LowRankAttention, self).__init__()
        self.rank = rank
        self.Wq = nn.Linear(dim, rank, bias=False)
        self.Wk = nn.Linear(dim, rank, bias=False)
        self.Wv = nn.Linear(dim, rank, bias=False)
        self.Wo = nn.Linear(rank, dim, bias=False)

    def forward(self, q, k, v):
        Q = self.Wq(q)
        K = self.Wk(k)
        V = self.Wv(v)
        # Compute the attention scores in the low-rank projection space
        A = torch.bmm(Q, K.transpose(-2, -1)) / (self.rank ** 0.5)
        # Softmax along the key dimension
        A = torch.softmax(A, dim=-1)
        # Compute the attention-weighted values
        AV = torch.bmm(A, V)
        # Apply the output layer to the attention-weighted values
        out = self.Wo(AV)
        return out
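

# Shapes: LowRankAttention takes q, k, v of shape (batch, seq_len, dim) and returns
# (batch, seq_len, dim). All four projections have rank `rank`, so the block holds
# 4 * dim * rank parameters instead of the 4 * dim**2 of a full-rank attention block.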


class LowRankTransformerLayer(nn.Module):
    def __init__(self, dim, rank, dropout=0.2):
        super(LowRankTransformerLayer, self).__init__()
        self.attention = LowRankAttention(dim, rank)
        self.norm1 = nn.LayerNorm(dim)
        self.dropout1 = nn.Dropout(dropout)
        self.feedforward = nn.Sequential(
            nn.Linear(dim, dim * 3),
            nn.GELU(),
            nn.Linear(dim * 3, dim)
        )
        self.norm2 = nn.LayerNorm(dim)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        # Compute the self-attention layer
        attention_out = self.attention(x, x, x)
        # Add residual connection and normalize
        x = self.norm1(x + self.dropout1(attention_out))
        # Feed-forward layer
        ff_out = self.feedforward(x)
        # Add residual connection and normalize
        x = self.norm2(x + self.dropout2(ff_out))
        return x


class LowRankTransformer(nn.Module):
    def __init__(self, vocab_size, num_layers, dim, rank, num_heads, dropout=0.2):
        super(LowRankTransformer, self).__init__()
        self.layers = nn.ModuleList([LowRankTransformerLayer(dim, rank, dropout) for _ in range(num_layers)])
        self.num_layers = num_layers
        self.dim = dim
        self.rank = rank
        self.num_heads = num_heads
        # Positional embedding table; vocab_size also serves as the maximum sequence length here
        self.pos_embedding = nn.Embedding(vocab_size, dim)
        self.dropout = nn.Dropout(dropout)
        # init all weights
        ## from karpathy
        self.apply(self._init_weights)
        # apply special scaled init to the residual projections, per GPT-2 paper
        for pn, p in self.named_parameters():
            if pn.endswith('Wo.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * num_layers))
        # report number of parameters
        print("number of parameters: %d" % (sum(p.nelement() for p in self.parameters()),))

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, x):
        # x is expected to be already-embedded input of shape (batch, seq_len, dim)
        # Add positional embeddings for positions 0 .. seq_len - 1
        positions = torch.arange(x.size(1), device=x.device)
        x = x + self.pos_embedding(positions)
        # Apply dropout
        x = self.dropout(x)
        # Apply the transformer layers
        for layer in self.layers:
            x = layer(x)
        return x
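

if __name__ == "__main__":
    # Minimal smoke-test sketch; the hyperparameters below are illustrative only.
    # The model takes pre-embedded inputs of shape (batch, seq_len, dim), since no
    # token-embedding layer is defined above, and num_heads is currently unused.
    torch.manual_seed(0)
    model = LowRankTransformer(vocab_size=1024, num_layers=2, dim=64, rank=16, num_heads=1)
    x = torch.randn(4, 32, 64)  # (batch, seq_len, dim)
    out = model(x)
    print(out.shape)  # torch.Size([4, 32, 64])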