
Commit 132dde1

Fixing training environment
1 parent 60ff154 commit 132dde1

12 files changed: +244 −51 lines changed

.gitignore
+1

@@ -1,4 +1,5 @@
 data/
+output/
 
 # Created by .ignore support plugin (hsz.mobi)
 ### Python template

README.md
+115 −5

@@ -8,13 +8,123 @@ Pytorch implementation of Google AI's 2018 BERT, with simple annotation

## Introduction

-Currently WIP, with very high speed :)
-But it might be takes some days to validate my code
Google AI's BERT paper reports amazing results on various NLP tasks (new state of the art on 11 NLP tasks),
including outperforming the human F1 score on the SQuAD v1.1 QA task.
The paper shows that a Transformer (self-attention) based encoder, trained with a suitable
language-model pre-training method, is a powerful alternative to previous language models.
More importantly, it shows that this pre-trained language model can be transferred
to any NLP task without designing a task-specific model architecture.

-If you have any comment or question about my code, please leave it to issue.
-I'll reply back as soon as possible.
This amazing result may well be recorded in NLP history,
and I expect many follow-up papers about BERT to be published very soon.

-Thank you
This repo is an implementation of BERT. The code is simple and easy to understand quickly.
Some of the code is based on [The Annotated Transformer](http://nlp.seas.harvard.edu/2018/04/03/attention.html).


## Language Model Pre-training

In the paper, the authors present two new language-model training tasks:
the "masked language model" and "predict next sentence".


### Masked Language Model

> Original Paper : 3.3.1 Task #1: Masked LM

```
Input Sequence : The man went to [MASK] store with [MASK] dog
Target Sequence : the his
```

#### Rules:
A random 15% of the input tokens are changed, according to the sub-rules below (a minimal sketch follows the list):

1. 80% of the selected tokens become the `[MASK]` token.
2. 10% of the selected tokens become a `[RANDOM]` token (another word from the vocabulary).
3. 10% of the selected tokens remain unchanged, but still need to be predicted.
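
For illustration, here is a minimal sketch of these masking rules in plain Python; `mask_tokens`, its arguments, and the example sentence are hypothetical and not part of this repo's API:

```python
import random

def mask_tokens(tokens, vocab, mask_token="[MASK]", mask_prob=0.15):
    """Apply the 15% / 80-10-10 masking rules to a list of token strings."""
    masked, targets = [], []
    for token in tokens:
        if random.random() < mask_prob:          # 15% of input tokens are selected
            targets.append(token)                # the original token is the prediction target
            roll = random.random()
            if roll < 0.8:                       # 80% of selected tokens -> [MASK]
                masked.append(mask_token)
            elif roll < 0.9:                     # 10% -> a random word from the vocabulary
                masked.append(random.choice(vocab))
            else:                                # 10% -> unchanged, but still predicted
                masked.append(token)
        else:
            masked.append(token)
            targets.append(None)                 # not a prediction target
    return masked, targets

tokens = "the man went to the store with his dog".split()
print(mask_tokens(tokens, vocab=tokens))
```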

### Predict Next Sentence

> Original Paper : 3.3.2 Task #2: Next Sentence Prediction

```
Input : [CLS] the man went to the store [SEP] he bought a gallon of milk [SEP]
Label : IsNext

Input : [CLS] the man heading to the store [SEP] penguin [MASK] are flight ##less birds [SEP]
Label : NotNext
```

"Can the second sentence be read as a natural continuation of the first?"

The goal is understanding the relationship between two text sentences, which is
not directly captured by language modeling.

#### Rules:

1. 50% of the time, the second sentence is the actual next (continuous) sentence (see the sketch below).
2. 50% of the time, the second sentence is an unrelated sentence sampled at random.
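
A minimal sketch of this 50/50 pair sampling, assuming the corpus is simply a list of consecutive sentences (names are hypothetical; a real implementation would also avoid accidentally drawing the true next sentence in the random branch):

```python
import random

def sample_sentence_pair(corpus):
    """Return (sentence_a, sentence_b, is_next) with a 50/50 IsNext / NotNext split."""
    i = random.randrange(len(corpus) - 1)
    if random.random() < 0.5:
        return corpus[i], corpus[i + 1], 1        # 50%: the actual next sentence (IsNext)
    return corpus[i], random.choice(corpus), 0    # 50%: a randomly drawn sentence (NotNext)

corpus = ["the man went to the store", "he bought a gallon of milk",
          "penguins are flightless birds", "they live in the southern hemisphere"]
print(sample_sentence_pair(corpus))
```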

## Usage

### 1. Building vocab based on your corpus
```shell
python build_vocab.py -c data/corpus.small -o data/corpus.small.vocab
```
```shell
usage: build_vocab.py [-h] -c CORPUS_PATH -o OUTPUT_PATH [-s VOCAB_SIZE]
                      [-e ENCODING] [-m MIN_FREQ]

optional arguments:
  -h, --help            show this help message and exit
  -c CORPUS_PATH, --corpus_path CORPUS_PATH
  -o OUTPUT_PATH, --output_path OUTPUT_PATH
  -s VOCAB_SIZE, --vocab_size VOCAB_SIZE
  -e ENCODING, --encoding ENCODING
  -m MIN_FREQ, --min_freq MIN_FREQ
```

### 2. Building BERT train dataset with your corpus
```shell
python build_dataset.py -c data/corpus.small -v data/corpus.small.vocab -o data/dataset.small
```

```shell
usage: build_dataset.py [-h] -v VOCAB_PATH -c CORPUS_PATH [-e ENCODING] -o
                        OUTPUT_PATH

optional arguments:
  -h, --help            show this help message and exit
  -v VOCAB_PATH, --vocab_path VOCAB_PATH
  -c CORPUS_PATH, --corpus_path CORPUS_PATH
  -e ENCODING, --encoding ENCODING
  -o OUTPUT_PATH, --output_path OUTPUT_PATH
```

### 3. Train your own BERT model
```shell
python train.py -d data/dataset.small -v data/corpus.small.vocab -o output/
```
```shell
usage: train.py [-h] -d TRAIN_DATASET [-t TEST_DATASET] -v VOCAB_PATH -o
                OUTPUT_DIR [-hs HIDDEN] [-n LAYERS] [-a ATTN_HEADS]
                [-s SEQ_LEN] [-b BATCH_SIZE] [-e EPOCHS]

optional arguments:
  -h, --help            show this help message and exit
  -d TRAIN_DATASET, --train_dataset TRAIN_DATASET
  -t TEST_DATASET, --test_dataset TEST_DATASET
  -v VOCAB_PATH, --vocab_path VOCAB_PATH
  -o OUTPUT_DIR, --output_dir OUTPUT_DIR
  -hs HIDDEN, --hidden HIDDEN
  -n LAYERS, --layers LAYERS
  -a ATTN_HEADS, --attn_heads ATTN_HEADS
  -s SEQ_LEN, --seq_len SEQ_LEN
  -b BATCH_SIZE, --batch_size BATCH_SIZE
  -e EPOCHS, --epochs EPOCHS
```


## Author

dataset/__init__.py
+1 −1

@@ -1,2 +1,2 @@
-from .dataset import BERTDataset, build_dataset
+from .dataset import BERTDataset
 from .vocab import WordVocab

dataset/dataset.py
+13 −13

@@ -1,8 +1,6 @@
 from torch.utils.data import Dataset
-from .vocab import WordVocab
 import tqdm
 import random
-import argparse
 import torch
 
 
@@ -17,13 +15,7 @@ def __init__(self, corpus_path, vocab, seq_len, encoding="utf-8"):
             t1, t2, t1_l, t2_l, is_next = line[:-1].split("\t")
             t1_l, t2_l = [[token for token in label.split(" ")] for label in [t1_l, t2_l]]
             is_next = int(is_next)
-            self.datas.append({
-                "t1": t1,
-                "t2": t2,
-                "t1_label": t1_l,
-                "t2_label": t2_l,
-                "is_next": is_next
-            })
+            self.datas.append({"t1": t1, "t2": t2, "t1_label": t1_l, "t2_label": t2_l, "is_next": is_next})
 
     def __len__(self):
         return len(self.datas)
@@ -33,11 +25,19 @@ def __getitem__(self, item):
         t1 = self.vocab.to_seq(self.datas[item]["t1"], with_sos=True, with_eos=True)
         t2 = self.vocab.to_seq(self.datas[item]["t2"], with_eos=True)
 
-        t1_label = self.vocab.to_seq(self.datas[item]["t1_label"])
-        t2_label = self.vocab.to_seq(self.datas[item]["t2_label"])
+        t1_label = [0] + self.vocab.to_seq(self.datas[item]["t1_label"]) + [0]
+        t2_label = self.vocab.to_seq(self.datas[item]["t2_label"]) + [0]
 
-        output = {"t1": t1, "t2": t2,
-                  "t1_label": t1_label, "t2_label": t2_label,
+        segment_label = ([1 for _ in range(len(t1))] + [2 for _ in range(len(t2))])[:self.seq_len]
+        bert_input = (t1 + t2)[:self.seq_len]
+        bert_label = (t1_label + t2_label)[:self.seq_len]
+
+        padding = [self.vocab.pad_index for _ in range(self.seq_len - len(t1) - len(t2))]
+        bert_input.extend(padding), bert_label.extend(padding), segment_label.extend(padding)
+
+        output = {"bert_input": bert_input,
+                  "bert_label": bert_label,
+                  "segment_label": segment_label,
                   "is_next": self.datas[item]["is_next"]}
 
         return {key: torch.tensor(value) for key, value in output.items()}

model/bert.py
+2 −6

@@ -20,15 +20,11 @@ def __init__(self, vocab_size, hidden=768, n_layers=12, attn_heads=12, dropout=0
                               dropout=dropout)
              for _ in range(n_layers)])
 
-    def forward(self, x):
+    def forward(self, x, segment_info):
         mask = (x > 0).unsqueeze(1).repeat(1, x.size(1), 1)
 
-        # sequence -> embedding : (batch_size, seq_len) -> (batch_size, seq_len, embed_size)
-        x = self.embedding(x)
+        x = self.embedding(x, segment_info)
 
-        # embedding through the transformer self-attention
-        # embedding (batch_size, seq_len, embed_size = hidden) -> transformer_output (batch_size, seq_len, hidden)
-        # loop transformer (batch_size, seq_len, hidden) -> transformer_output (batch_size, seq_len, hidden)
         for transformer in self.transformer_blocks:
             x = transformer.forward(x, mask)
 
model/embedding/bert.py
+2 −2

@@ -11,5 +11,5 @@ def __init__(self, vocab_size, embed_size, dropout=0.1):
         self.position = PositionalEmbedding(self.token.embedding_dim, dropout=dropout)
         self.segment = SegmentEmbedding(embed_size=self.token.embedding_dim)
 
-    def forward(self, sequence):
-        return self.position(self.token(sequence))
+    def forward(self, sequence, segment_label):
+        return self.position(self.token(sequence)) + self.segment(segment_label)
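
The input representation is now the elementwise sum of a token/position term and a segment embedding. Below is a self-contained sketch of this summed-embedding pattern; the class and tensor names are illustrative, and the positional term here is a plain learned `nn.Embedding` rather than the repo's `PositionalEmbedding`, so treat it as an approximation of the idea, not the repo's exact module:

```python
import torch
import torch.nn as nn

# Sketch of a summed BERT-style input embedding: token + position + segment.
class ToyBERTEmbedding(nn.Module):
    def __init__(self, vocab_size, embed_size, max_len=512):
        super().__init__()
        self.token = nn.Embedding(vocab_size, embed_size, padding_idx=0)
        self.position = nn.Embedding(max_len, embed_size)
        self.segment = nn.Embedding(3, embed_size, padding_idx=0)  # 0 = pad, 1 = sentence A, 2 = sentence B

    def forward(self, sequence, segment_label):
        positions = torch.arange(sequence.size(1), device=sequence.device)
        return self.token(sequence) + self.position(positions) + self.segment(segment_label)

tokens = torch.randint(1, 100, (2, 10))            # (batch, seq_len) token ids
segments = torch.tensor([[1] * 6 + [2] * 4] * 2)   # sentence A / sentence B labels
embedding = ToyBERTEmbedding(vocab_size=100, embed_size=32)
print(embedding(tokens, segments).shape)           # torch.Size([2, 10, 32])
```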

model/embedding/segment.py
+1 −1

@@ -3,4 +3,4 @@
 
 class SegmentEmbedding(nn.Embedding):
     def __init__(self, embed_size=512):
-        super().__init__(2, embed_size)
+        super().__init__(3, embed_size, padding_idx=0)
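
Presumably index 0 stays reserved for padding (hence `padding_idx=0`), while indices 1 and 2 mark the first and second sentence, matching the segment labels built in `dataset/dataset.py` above.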

model/embedding/token.py
+1 −1

@@ -3,4 +3,4 @@
 
 class TokenEmbedding(nn.Embedding):
     def __init__(self, vocab_size, embed_size=512):
-        super().__init__(vocab_size, embed_size)
+        super().__init__(vocab_size, embed_size, padding_idx=0)

model/language_model.py
+10 −10

@@ -5,29 +5,29 @@
 class BERTLM(nn.Module):
     def __init__(self, bert: BERT, vocab_size):
         super().__init__()
-        self.next_sentence = BERTNextSentence(bert)
-        self.mask_lm = BERTMaskLM(bert, vocab_size)
+        self.bert = bert
+        self.next_sentence = BERTNextSentence(self.bert.hidden)
+        self.mask_lm = BERTMaskLM(self.bert.hidden, vocab_size)
 
-    def forward(self, x):
+    def forward(self, x, segment_label):
+        x = self.bert(x, segment_label)
         return self.next_sentence(x), self.mask_lm(x)
 
 
 class BERTNextSentence(nn.Module):
-    def __init__(self, bert: BERT):
+    def __init__(self, hidden):
         super().__init__()
-        self.bert = bert
-        self.linear = nn.Linear(self.bert.hidden, 2)
+        self.linear = nn.Linear(hidden, 2)
         self.softmax = nn.LogSoftmax(dim=-1)
 
     def forward(self, x):
-        return self.softmax(self.linear(x))
+        return self.softmax(self.linear(x[:, 0]))
 
 
 class BERTMaskLM(nn.Module):
-    def __init__(self, bert: BERT, vocab_size):
+    def __init__(self, hidden, vocab_size):
         super().__init__()
-        self.bert = bert
-        self.linear = nn.Linear(self.bert.hidden, vocab_size)
+        self.linear = nn.Linear(hidden, vocab_size)
         self.softmax = nn.LogSoftmax(dim=-1)
 
     def forward(self, x):

train.py
+33 −9

@@ -1,20 +1,44 @@
 import argparse
-from dataset.dataset import BERTDataset, WordVocab
+
 from torch.utils.data import DataLoader
-from model import BERT, BERTLM
+
+from model import BERT
+from trainer import BERTTrainer
+from dataset import BERTDataset, WordVocab
 
 parser = argparse.ArgumentParser()
-parser.add_argument("-d", "--dataset_path", required=True, type=str)
+
+parser.add_argument("-d", "--train_dataset", required=True, type=str)
+parser.add_argument("-t", "--test_dataset", type=str, default=None)
 parser.add_argument("-v", "--vocab_path", required=True, type=str)
+parser.add_argument("-o", "--output_dir", required=True, type=str)
+
+parser.add_argument("-hs", "--hidden", type=int, default=128)
+parser.add_argument("-n", "--layers", type=int, default=2)
+parser.add_argument("-a", "--attn_heads", type=int, default=4)
+parser.add_argument("-s", "--seq_len", type=int, default=10)
+
+parser.add_argument("-b", "--batch_size", type=int, default=64)
+parser.add_argument("-e", "--epochs", type=int, default=10)
+
 args = parser.parse_args()
 
 vocab = WordVocab.load_vocab(args.vocab_path)
-dataset = BERTDataset(args.dataset_path, vocab, seq_len=10)
-data_loader = DataLoader(dataset, batch_size=16)
 
-bert = BERT(len(vocab), hidden=128, n_layers=2, attn_heads=4)
+print("Loading Train Dataset", args.train_dataset)
+train_dataset = BERTDataset(args.train_dataset, vocab, seq_len=args.seq_len)
+print("Loading Test Dataset", args.test_dataset)
+test_dataset = BERTDataset(args.test_dataset, vocab, seq_len=args.seq_len) if args.test_dataset is not None else None
+
+train_data_loader = DataLoader(train_dataset, batch_size=16)
+test_data_loader = DataLoader(test_dataset) if test_dataset is not None else None
+
+bert = BERT(len(vocab), hidden=args.hidden, n_layers=args.layers, attn_heads=args.attn_heads)
+trainer = BERTTrainer(bert, len(vocab), train_dataloader=train_data_loader, test_dataloader=test_data_loader)
 
+for epoch in range(args.epochs):
+    trainer.train(epoch)
+    trainer.save(args.output_dir, epoch)
 
-for data in data_loader:
-    x = model.forward(data["t1"])
-    print(x.size())
+    if test_data_loader is not None:
+        trainer.test(epoch)

trainer/__init__.py
+1

@@ -0,0 +1 @@
+from .pretrain import BERTTrainer

0 commit comments