Update transformer block code to run with test code brought in
insop committed Sep 19, 2020
1 parent 9506c44 commit a06f9ed
Showing 3 changed files with 89 additions and 41 deletions.
21 changes: 15 additions & 6 deletions README.md
@@ -1,34 +1,43 @@
# Implement Transformer (TODO reference) with C

Implement Transformer with C to understand the internal operations, so it is intentionally implemented for readability rather than optimality
Implement Transformer with C and Python for educational purposes. The code is written for readability.

# Build

```
cd sml
cd lib/sml
make
cd ..
cd src
cd src/c
make
```

# Run

## C version
```
cd src
cd src/c
./transformer
```

## Python version
```
cd src/python
python ./experiments/classify.py --random-seed=1234 --num-epochs=1 --tiny
```

# TODO:
- have a main Makefile
- main Makefile to build the library and the C executable
- add config load
- add trained weight load
- add Python code to generate test vectors and use them to test the C code (a rough sketch follows this list)
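
A rough sketch of what that test-vector step could look like; the script name, output file, and text format here are assumptions for illustration and are not part of this commit:

```
# gen_test_vector.py (hypothetical): dump a random input and the PyTorch
# output of SelfAttention to a text file that the C code can read back
# and compare against.
import torch
from transformer_simple import SelfAttention  # run from src/python

torch.manual_seed(1234)

model = SelfAttention(dim_emb=4, dim_internal=4)
model.train(False)

x = torch.randn(1, 2, 4)                  # [batch, seq, emb]
with torch.no_grad():
    y = model(x)

with open("test_vectors.txt", "w") as f:
    for name, t in [("input", x), ("output", y)]:
        f.write(f"{name} {list(t.shape)}\n")
        f.write(" ".join(f"{v:.6f}" for v in t.flatten().tolist()) + "\n")
```

The C side would then parse `test_vectors.txt`, run the same input through its own attention routine, and compare the outputs element-wise within a small tolerance.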

# Reference:
- SML: small math library, http://www.bios.unc.edu/distrib/bios235/sml/
- Transformer tutorial: http://jalammar.github.io/illustrated-transformer/
- Transformer tutorial2: http://peterbloem.nl/blog/transformers
- Python Transformer implementation: http://peterbloem.nl/blog/transformers
5 changes: 3 additions & 2 deletions src/c/Makefile
@@ -4,9 +4,10 @@ CC = gcc
RM = rm

ROOT=../../
CFLAGS=-I${ROOT}/sml -g
LFLAGS=-L${ROOT}/sml ${ROOT}sml/sml.lib
LIB=${ROOT}/lib

CFLAGS=-I${LIB}/sml -g
LFLAGS=-L${LIB}/sml ${LIB}/sml/sml.lib

#*****************************************************************

104 changes: 71 additions & 33 deletions src/python/transformer_simple.py
@@ -4,42 +4,45 @@

import random, math

from util import d, here

class SelfAttention(nn.Module):
def __init__(self, emb, dim_internal, heads=8, mask=False, dropout=0.0):
def __init__(self, dim_emb, dim_internal, heads=8, mask=False, dropout=0.0):
"""
:param emb: embedding dimension
:param dim_emb: embedding dimension
:param dim_internal: dimension of internal representation
:param heads: number of attention heads
:param mask: whether to apply masking
"""
super().__init__()

self.emb = emb
self.dim_emb = dim_emb
self.dim_internal = dim_internal
self.heads = heads
self.mask = mask

self.toqueries = nn.Linear(emb, dim_internal)
self.tokeys = nn.Linear(emb, dim_internal)
self.tovalues = nn.Linear(emb, dim_internal)
self.toqueries = nn.Linear(dim_emb, dim_internal)
self.tokeys = nn.Linear(dim_emb, dim_internal)
self.tovalues = nn.Linear(dim_emb, dim_internal)

self.kSqrt_dim_emb = math.sqrt(self.dim_emb)

def forward(self, x):
# single batch
# [batch size, seq, embedding]

# seq, embedding
t, e = x.size()
b, t, e = x.size()

assert e == self.emb, f'Input embedding ({e}) should match the layer embedding ({self.emb})'
assert e == self.dim_emb, f'Input embedding ({e}) should match the layer embedding ({self.dim_emb})'

queries = self.toqueries(x)
keys = self.tokeys(x)
values = self.tovalues(x)

dot = torch.matmul(queries, keys.transpose(0, 1)) / math.sqrt(self.emb)
dot = torch.matmul(queries, keys.transpose(-2, -1)) / self.kSqrt_dim_emb

# row-wise softmax over the attention scores
p_attn = F.softmax(dot, dim=1)
p_attn = F.softmax(dot, dim=2)

z = torch.matmul(p_attn, values)

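As a usage note for the shape change above (2D `[seq, emb]` inputs become batched 3D `[batch, seq, emb]` inputs), a minimal sketch with illustrative sizes:

```
# sketch: SelfAttention now expects a batched input [batch, seq, emb];
# sizes here are illustrative only.
import torch
from transformer_simple import SelfAttention  # run from src/python

attn = SelfAttention(dim_emb=4, dim_internal=3)
x = torch.randn(2, 5, 4)   # [batch=2, seq=5, emb=4]
z = attn(x)                # expected shape [2, 5, 3]: one context vector per position
```
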
@@ -48,9 +51,9 @@ def forward(self, x):


class MultiHeadAttention(nn.Module):
def __init__(self, n_seq, emb, dim_internal, heads=8, mask=False, dropout=0.0):
def __init__(self, n_seq, dim_emb, dim_internal, heads=8, mask=False, dropout=0.0):
"""
:param emb: embedding dimension
:param dim_emb: embedding dimension
:param dim_internal: dimension of internal representation
:param heads: number of attention heads
:param mask: whether to apply masking
@@ -59,16 +62,13 @@ def __init__(self, n_seq, emb, dim_internal, heads=8, mask=False, dropout=0.0):
super().__init__()

self.n_seq = n_seq
self.emb = emb
self.dim_emb = dim_emb
self.heads = heads
self.mask = mask

self.toqueries = nn.Linear(emb, dim_internal)
self.tokeys = nn.Linear(emb, dim_internal)

self.attentions = nn.ModuleList([SelfAttention(emb, dim_internal, heads, mask, dropout) \
self.attentions = nn.ModuleList([SelfAttention(dim_emb, dim_internal, heads, mask, dropout) \
for _ in range(heads)])
self.w_o = nn.ModuleList([nn.Linear(dim_internal, emb) \
self.w_o = nn.ModuleList([nn.Linear(dim_internal, dim_emb) \
for _ in range(heads)])

self.layer_norm = nn.LayerNorm(self.n_seq, eps=1e-6)
@@ -81,31 +81,26 @@ def forward(self, x):
z = attention(x)
output += w_o(z)

#output = F.dropout(output, p=dropout)
# residual
#output = output + x

#output = self.layer_norm(output)

return output


class TransformerBlock(nn.Module):
def __init__(self, n_seq, emb, dim_internal, heads=8, mask=False, ff_hidden_mult=4, dropout=0.0):
def __init__(self, n_seq, dim_emb, dim_internal, heads=8, mask=False, ff_hidden_mult=4, dropout=0.0):
"""
:param ff_hidden_mult: feed-forward hidden size as a multiple of the embedding dimension
"""
super().__init__()

self.mha = MultiHeadAttention(n_seq=n_seq, emb=emb, dim_internal=dim_internal, heads=heads, mask=mask, dropout=dropout)
self.mha = MultiHeadAttention(n_seq=n_seq, dim_emb=dim_emb, dim_internal=dim_internal, heads=heads, mask=mask, dropout=dropout)
self.mask = mask

self.norm1 = nn.LayerNorm(emb)
self.norm2 = nn.LayerNorm(emb)
self.norm1 = nn.LayerNorm(dim_emb)
self.norm2 = nn.LayerNorm(dim_emb)

self.ff = nn.Sequential(
nn.Linear(emb, ff_hidden_mult * emb),
nn.Linear(dim_emb, ff_hidden_mult * dim_emb),
nn.ReLU(),
nn.Linear(ff_hidden_mult * emb, emb)
nn.Linear(ff_hidden_mult * dim_emb, dim_emb)
)
self.do = nn.Dropout(dropout)

@@ -128,6 +123,45 @@ def forward(self, x):

return out

## classifier built on the transformer blocks
class TransformerSimpleClassify(nn.Module):
def __init__(self, n_seq, dim_emb, dim_internal, num_tokens, num_classes, max_pool=True, heads=8, depth=6, mask=False, ff_hidden_mult=4, dropout=0.0):
super().__init__()

self.num_tokens = num_tokens
self.max_pool = max_pool
self.depth = depth

self.token_embedding = nn.Embedding(embedding_dim=dim_emb, num_embeddings=num_tokens)
self.pos_embedding = nn.Embedding(embedding_dim=dim_emb, num_embeddings=n_seq)

trfm_blocks = [TransformerBlock(n_seq=n_seq, dim_emb=dim_emb, dim_internal=dim_emb, heads=heads) \
for _ in range(depth)]

self.trfm_blocks = nn.Sequential(*trfm_blocks)

self.toprobs = nn.Linear(dim_emb, num_classes)

self.do = nn.Dropout(dropout)

def forward(self, x):
tokens = self.token_embedding(x)

b, t, e = tokens.size()

positions = self.pos_embedding(torch.arange(t, device=d()))[None, :, :].expand(b, t, e)
x = tokens + positions

x = self.do(x)

x = self.trfm_blocks(x)

x = x.max(dim=1)[0] if self.max_pool else x.mean(dim=1) # pool in sequence direction

x = self.toprobs(x)

return F.log_softmax(x, dim=1)
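
A minimal usage sketch for the new classifier; the token IDs and sizes are illustrative, and it assumes a CPU-only setup so that the input tensor and the positional embeddings created via `util.d()` end up on the same device:

```
# sketch: TransformerSimpleClassify takes integer token IDs of shape
# [batch, seq] and returns log-probabilities over num_classes.
import torch
from transformer_simple import TransformerSimpleClassify  # run from src/python

clf = TransformerSimpleClassify(n_seq=2, dim_emb=4, dim_internal=4,
                                num_tokens=10, num_classes=2)
clf.train(False)

tokens = torch.randint(0, 10, (1, 2))    # [batch=1, seq=2] token IDs
with torch.no_grad():
    log_probs = clf(tokens)              # expected shape [1, 2], log-probabilities
```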


if __name__ == "__main__":

@@ -146,13 +180,17 @@ def forward(self, x):
model = TransformerBlock(2, 4, 3)
print(model)

model_tb = TransformerSimpleClassify(2, 4, 4, 10, 2)
print(model_tb)

"""
opt = torch.optim.Adam(lr=arg.lr, params=model.parameters())
sch = torch.optim.lr_scheduler.LambdaLR(opt, lambda i: min(i / (arg.lr_warmup / arg.batch_size), 1.0))
"""

x = torch.ones([2,4])
n_batch = 1
x = torch.ones([n_batch, 2, 4])
# forward pass
model.train(False)

