
Commit 60ff154

Fixing bugs & implement training process
1 parent 4799c6e commit 60ff154

18 files changed: +180 -75 lines

.gitignore (+2)

@@ -1,3 +1,5 @@
+data/
+
 # Created by .ignore support plugin (hsz.mobi)
 ### Python template
 # Byte-compiled / optimized / DLL files

build_dataset.py (+23)

@@ -0,0 +1,23 @@
+from dataset.dataset import BERTDatasetCreator
+from dataset import WordVocab
+import argparse
+import tqdm
+
+parser = argparse.ArgumentParser()
+parser.add_argument("-v", "--vocab_path", required=True, type=str)
+parser.add_argument("-c", "--corpus_path", required=True, type=str)
+parser.add_argument("-e", "--encoding", default="utf-8", type=str)
+parser.add_argument("-o", "--output_path", required=True, type=str)
+args = parser.parse_args()
+
+word_vocab = WordVocab.load_vocab(args.vocab_path)
+builder = BERTDatasetCreator(corpus_path=args.corpus_path, vocab=word_vocab, seq_len=None, encoding=args.encoding)
+
+with open(args.output_path, 'w', encoding=args.encoding) as f:
+    for index in tqdm.tqdm(range(len(builder)), desc="Building Dataset", total=len(builder)):
+        data = builder[index]
+        output_form = "%s\t%s\t%s\t%s\t%d\n"
+        t1_text, t2_text = [" ".join(t) for t in [data["t1_random"], data["t2_random"]]]
+        t1_label, t2_label = [" ".join([str(i) for i in label]) for label in [data["t1_label"], data["t2_label"]]]
+        output = output_form % (t1_text, t2_text, t1_label, t2_label, data["is_next"])
+        f.write(output)
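
A hedged usage sketch for this script (the paths are illustrative, not part of the commit); it assumes a vocab pickle produced by build_vocab.py, shown next. Each output row is one tab-separated training example:

    python build_dataset.py -v data/vocab.pkl -c data/corpus.txt -o data/dataset.tsv

    # row layout: masked sentence A <TAB> masked sentence B <TAB> labels for A <TAB> labels for B <TAB> is_next (0 or 1)
    # tokens and labels are space-separated within each field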

build_vocab.py (+15)

@@ -0,0 +1,15 @@
+import argparse
+from dataset import WordVocab
+
+parser = argparse.ArgumentParser()
+parser.add_argument("-c", "--corpus_path", required=True, type=str)
+parser.add_argument("-o", "--output_path", required=True, type=str)
+parser.add_argument("-s", "--vocab_size", type=int, default=None)
+parser.add_argument("-e", "--encoding", type=str, default="utf-8")
+parser.add_argument("-m", "--min_freq", type=int, default=1)
+args = parser.parse_args()
+
+with open(args.corpus_path, "r", encoding=args.encoding) as f:
+    vocab = WordVocab(f, max_size=args.vocab_size, min_freq=args.min_freq)
+
+vocab.save_vocab(args.output_path)
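
A hedged invocation sketch; the corpus path, output path, and vocab size are illustrative:

    python build_vocab.py -c data/corpus.txt -o data/vocab.pkl -s 30000 -m 1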

dataset/__init__.py (+2)

@@ -0,0 +1,2 @@
+from .dataset import BERTDataset, build_dataset
+from .vocab import WordVocab

dataset/dataset.py (+13 -30)

@@ -1,4 +1,5 @@
 from torch.utils.data import Dataset
+from .vocab import WordVocab
 import tqdm
 import random
 import argparse
@@ -14,7 +15,7 @@ def __init__(self, corpus_path, vocab, seq_len, encoding="utf-8"):
         with open(corpus_path, "r", encoding=encoding) as f:
             for line in tqdm.tqdm(f, desc="Loading Dataset"):
                 t1, t2, t1_l, t2_l, is_next = line[:-1].split("\t")
-                t1_l, t2_l = [[int(i) for i in label.split(",")] for label in [t1_l, t2_l]]
+                t1_l, t2_l = [[token for token in label.split(" ")] for label in [t1_l, t2_l]]
                 is_next = int(is_next)
                 self.datas.append({
                     "t1": t1,
@@ -29,12 +30,14 @@ def __len__(self):
 
     def __getitem__(self, item):
         # [CLS] tag = SOS tag, [SEP] tag = EOS tag
-        t1, t1_len = self.vocab.to_seq(self.datas[item]["t1"], seq_len=self.seq_len, with_sos=True, with_eos=True)
-        t2, t2_len = self.vocab.to_seq(self.datas[item]["t2"], seq_len=self.seq_len, with_eos=True)
+        t1 = self.vocab.to_seq(self.datas[item]["t1"], with_sos=True, with_eos=True)
+        t2 = self.vocab.to_seq(self.datas[item]["t2"], with_eos=True)
+
+        t1_label = self.vocab.to_seq(self.datas[item]["t1_label"])
+        t2_label = self.vocab.to_seq(self.datas[item]["t2_label"])
 
         output = {"t1": t1, "t2": t2,
-                  "t1_len": t1_len, "t2_len": t2_len,
-                  "t1_label": self.datas[item]["t1_label"], "t2_label": self.datas[item]["t2_label"],
+                  "t1_label": t1_label, "t2_label": t2_label,
                   "is_next": self.datas[item]["is_next"]}
 
         return {key: torch.tensor(value) for key, value in output.items()}
@@ -79,38 +82,18 @@ def random_word(self, sentence):
     def random_sent(self, index):
         # output_text, label(isNotNext:0, isNext:1)
         if random.random() > 0.5:
-            return self.datas[index][2], 1
+            return self.datas[index][1], 1
         else:
-            return self.datas[random.randrange(len(self.datas))][2], 0
+            return self.datas[random.randrange(len(self.datas))][1], 0
 
     def __getitem__(self, index):
-        t1, (t2, is_next_label) = self.datas[index], self.random_sent(index)
+        t1, (t2, is_next_label) = self.datas[index][0], self.random_sent(index)
         t1_random, t1_label = self.random_word(t1)
         t2_random, t2_label = self.random_word(t2)
 
         return {"t1_random": t1_random, "t2_random": t2_random,
                 "t1_label": t1_label, "t2_label": t2_label,
                 "is_next": is_next_label}
 
-
-if __name__ == "__main__":
-    from .vocab import WordVocab
-
-    parser = argparse.ArgumentParser()
-    parser.add_argument("-v", "--vocab_path", required=True, type=str)
-    parser.add_argument("-c", "--corpus_path", required=True, type=str)
-    parser.add_argument("-e", "--encoding", default="utf-8", type=str)
-    parser.add_argument("-o", "--output_path", required=True, type=str)
-    args = parser.parse_args()
-
-    word_vocab = WordVocab.load_vocab(args.vocab_path)
-    builder = BERTDatasetCreator(corpus_path=args.corpus_path, vocab=word_vocab, seq_len=None, encoding=args.encoding)
-
-    with open(args.output_path, 'w', encoding=args.encoding) as f:
-        for index in tqdm.tqdm(range(len(builder)), desc="Building Dataset", total=len(builder)):
-            data = builder[index]
-            output_form = "%s\t%s\t%s\t%d\n"
-            t1_text, t2_text = [" ".join(t) for t in [data["t1_random"], data["t2_random"]]]
-            t1_label, t2_label = [",".join([str(i) for i in label]) for label in [data["t1_label"], data["t2_label"]]]
-            output = output_form % (t1_text, t2_text, t1_label, t2_label, data["is_next"])
-            f.write(output_form)
+    def __len__(self):
+        return len(self.datas)
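
Net effect of these changes on the data format: labels are now space-separated token strings (matching the " ".join in build_dataset.py) instead of comma-separated integers, and BERTDataset re-encodes both sentences and labels through vocab.to_seq before wrapping each field in a tensor. A pre-built corpus line is therefore expected to look roughly like this (tokens illustrative):

    masked tokens of sentence A <TAB> masked tokens of sentence B <TAB> label tokens for A <TAB> label tokens for B <TAB> 1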

dataset/vocab.py (+9 -17)

@@ -121,13 +121,20 @@ def __init__(self, texts, max_size=None, min_freq=1):
         print("Building Vocab")
         counter = Counter()
         for line in tqdm.tqdm(texts):
-            words = line.replace("\n", "").replace("\t", "").split()
+            if isinstance(line, list):
+                words = line
+            else:
+                words = line.replace("\n", "").replace("\t", "").split()
+
             for word in words:
                 counter[word] += 1
         super().__init__(counter, max_size=max_size, min_freq=min_freq)
 
     def to_seq(self, sentence, seq_len=None, with_eos=False, with_sos=False, with_len=False):
-        seq = [self.stoi.get(word, self.unk_index) for word in sentence.split()]
+        if isinstance(sentence, str):
+            sentence = sentence.split()
+
+        seq = [self.stoi.get(word, self.unk_index) for word in sentence]
 
         if with_eos:
             seq += [self.eos_index]  # this would be index 1
@@ -158,18 +165,3 @@ def from_seq(self, seq, join=False, with_pad=False):
     def load_vocab(vocab_path: str) -> 'WordVocab':
         with open(vocab_path, "rb") as f:
             return pickle.load(f)
-
-
-if __name__ == "__main__":
-    import argparse
-
-    parser = argparse.ArgumentParser()
-    parser.add_argument("-c", "--corpus_path", required=True, type=str)
-    parser.add_argument("-o", "--output_path", required=True, type=str)
-    parser.add_argument("-s", "--vocab_size", type=int, default=None)
-    parser.add_argument("-e", "--encoding", type=str, default="utf-8")
-    parser.add_argument("-m", "--min_freq", type=int, default=1)
-    args = parser.parse_args()
-
-    with open(args.corpus_path, "r", encoding=args.encoding) as f:
-        vocab = WordVocab(f, max_size=args.vocab_size, min_freq=args.min_freq)
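
With the two isinstance checks, WordVocab and to_seq now accept either raw strings or pre-tokenized lists. A minimal sketch (assuming the rest of the class as defined earlier in this file):

    vocab = WordVocab(["hello world hello", ["hello", "again"]])   # mixed str / list input
    vocab.to_seq("hello world")              # a string is split on whitespace
    vocab.to_seq(["hello", "world"])         # a token list is used as-is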

model/__init__.py (+2)

@@ -0,0 +1,2 @@
+from .bert import BERT
+from .language_model import BERTLM

model/attention/multi_head.py (+3 -2)

@@ -12,7 +12,8 @@ def __init__(self, h, d_model, dropout=0.1):
         self.d_k = d_model // h
         self.h = h
 
-        self.linear_layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(4)])
+        self.linear_layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(3)])
+        self.output_linear = nn.Linear(d_model, d_model)
         self.attention = Attention()
 
         self.attn = None
@@ -37,4 +38,4 @@ def forward(self, query, key, value, mask=None):
         # 3) "Concat" using a view and apply a final linear.
         x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k)
 
-        return self.linears[-1](x)
+        return self.output_linear(x)
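
Context for the fix: forward previously indexed self.linears[-1], an attribute this module never defines, so the final projection failed. The commit keeps three projection layers (query, key, value) plus a dedicated output_linear. A sketch of how the three projections are typically consumed in the annotated-transformer style this module follows (variable names assumed):

    query, key, value = [linear(t).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
                         for linear, t in zip(self.linear_layers, (query, key, value))]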

model/bert.py (+12 -3)

@@ -5,22 +5,31 @@
 
 
 class BERT(nn.Module):
-    def __init__(self, embedding: BERTEmbedding, hidden=768, n_layers=12, attn_heads=12, dropout=0.1):
+    def __init__(self, vocab_size, hidden=768, n_layers=12, attn_heads=12, dropout=0.1):
         super().__init__()
         self.hidden = hidden
         self.n_layers = n_layers
         self.attn_heads = attn_heads
         self.feed_forward_hidden = hidden * 4
 
-        self.embedding: BERTEmbedding = embedding
+        self.embedding = BERTEmbedding(vocab_size=vocab_size, embed_size=hidden)
         self.transformer_blocks = nn.ModuleList(
             [TransformerBlock(hidden=hidden,
                               attn_heads=attn_heads,
                               feed_forward_hidden=hidden * 4,
                               dropout=dropout)
              for _ in range(n_layers)])
 
-    def forward(self, x, mask=None):
+    def forward(self, x):
+        mask = (x > 0).unsqueeze(1).repeat(1, x.size(1), 1)
+
+        # sequence -> embedding : (batch_size, seq_len) -> (batch_size, seq_len, embed_size)
+        x = self.embedding(x)
+
+        # embedding through the transformer self-attention
+        # embedding (batch_size, seq_len, embed_size = hidden) -> transformer_output (batch_size, seq_len, hidden)
+        # loop transformer (batch_size, seq_len, hidden) -> transformer_output (batch_size, seq_len, hidden)
         for transformer in self.transformer_blocks:
             x = transformer.forward(x, mask)
+
         return x
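
BERT now owns its embedding (built from vocab_size) and derives the attention mask from padding internally, so callers pass only token ids. A minimal sketch of the mask construction, assuming index 0 is the padding id:

    import torch

    x = torch.randint(1, 100, (16, 10))                  # (batch_size, seq_len) token ids
    x[:, 7:] = 0                                         # simulate padding
    mask = (x > 0).unsqueeze(1).repeat(1, x.size(1), 1)  # (16, 10, 10); True where the key position is real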

model/embedding/__init__.py (+1 -1)

@@ -1 +1 @@
-from .bert_embedding import BERTEmbedding
+from .bert import BERTEmbedding

model/embedding/bert.py (+15)

@@ -0,0 +1,15 @@
+import torch.nn as nn
+from .token import TokenEmbedding
+from .position import PositionalEmbedding
+from .segment import SegmentEmbedding
+
+
+class BERTEmbedding(nn.Module):
+    def __init__(self, vocab_size, embed_size, dropout=0.1):
+        super().__init__()
+        self.token = TokenEmbedding(vocab_size=vocab_size, embed_size=embed_size)
+        self.position = PositionalEmbedding(self.token.embedding_dim, dropout=dropout)
+        self.segment = SegmentEmbedding(embed_size=self.token.embedding_dim)
+
+    def forward(self, sequence):
+        return self.position(self.token(sequence))

model/embedding/bert_embedding.py (-15)

This file was deleted.

model/embedding/position.py (+3 -3)

@@ -11,11 +11,11 @@ def __init__(self, d_model, dropout, max_len=512):
         self.dropout = nn.Dropout(p=dropout)
 
         # Compute the positional encodings once in log space.
-        pe = torch.zeros(max_len, d_model)
+        pe = torch.zeros(max_len, d_model).float()
         pe.require_grad = False
 
-        position = torch.arange(0, max_len).unsqueeze(1)
-        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
+        position = torch.arange(0, max_len).float().unsqueeze(1)
+        div_term = (torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)).float().exp()
 
         pe[:, 0::2] = torch.sin(position * div_term)
         pe[:, 1::2] = torch.cos(position * div_term)
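
The added .float() calls only pin the dtype; the values are the standard sinusoidal encodings, with div_term computing the 1/10000^{2i/d_model} factor in log space:

    PE_{(pos,\,2i)}   = \sin\left( pos / 10000^{2i/d_{model}} \right)
    PE_{(pos,\,2i+1)} = \cos\left( pos / 10000^{2i/d_{model}} \right)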

model/language_model.py (+34)

@@ -0,0 +1,34 @@
+from .bert import BERT
+import torch.nn as nn
+
+
+class BERTLM(nn.Module):
+    def __init__(self, bert: BERT, vocab_size):
+        super().__init__()
+        self.next_sentence = BERTNextSentence(bert)
+        self.mask_lm = BERTMaskLM(bert, vocab_size)
+
+    def forward(self, x):
+        return self.next_sentence(x), self.mask_lm(x)
+
+
+class BERTNextSentence(nn.Module):
+    def __init__(self, bert: BERT):
+        super().__init__()
+        self.bert = bert
+        self.linear = nn.Linear(self.bert.hidden, 2)
+        self.softmax = nn.LogSoftmax(dim=-1)
+
+    def forward(self, x):
+        return self.softmax(self.linear(x))
+
+
+class BERTMaskLM(nn.Module):
+    def __init__(self, bert: BERT, vocab_size):
+        super().__init__()
+        self.bert = bert
+        self.linear = nn.Linear(self.bert.hidden, vocab_size)
+        self.softmax = nn.LogSoftmax(dim=-1)
+
+    def forward(self, x):
+        return self.softmax(self.linear(x))
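
As written, both heads apply only a linear layer and LogSoftmax to their input, so the x fed to BERTLM is presumably the hidden-state tensor produced by BERT rather than raw token ids. A hedged usage sketch (variable names assumed):

    hidden = bert(tokens)                            # (batch, seq_len, hidden), from model/bert.py
    next_sent_logp, mask_lm_logp = bert_lm(hidden)   # bert_lm is a BERTLM instance
    # next_sent_logp: (batch, seq_len, 2) log-probabilities
    # mask_lm_logp:   (batch, seq_len, vocab_size) log-probabilities, ready for nn.NLLLoss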

model/transformer.py (+2 -4)

@@ -13,8 +13,6 @@ def __init__(self, hidden, attn_heads, feed_forward_hidden, dropout):
         self.output_sublayer = SublayerConnection(size=hidden, dropout=dropout)
 
     def forward(self, x, mask):
-        _x = self.attention.forward(x, x, x, mask=mask)
-        x = self.input_sublayer(x, _x)
-        _x = self.feed_forward(x)
-        x = self.output_sublayer(x, _x)
+        x = self.input_sublayer(x, lambda _x: self.attention.forward(_x, _x, _x, mask=mask))
+        x = self.output_sublayer(x, self.feed_forward)
         return x
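
The sub-layers are now passed as callables so SublayerConnection can apply the residual connection and layer norm around them itself, instead of receiving an already-computed output. Judging from the constructor call SublayerConnection(size=hidden, dropout=dropout), the wrapper likely follows the usual annotated-transformer pattern; a sketch under that assumption, not the repository's file:

    import torch.nn as nn

    class SublayerConnection(nn.Module):
        def __init__(self, size, dropout):
            super().__init__()
            self.norm = nn.LayerNorm(size)
            self.dropout = nn.Dropout(dropout)

        def forward(self, x, sublayer):
            # pre-norm residual: wrap any callable sub-layer (attention or feed-forward)
            return x + self.dropout(sublayer(self.norm(x)))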

train.py (+20)

@@ -0,0 +1,20 @@
+import argparse
+from dataset.dataset import BERTDataset, WordVocab
+from torch.utils.data import DataLoader
+from model import BERT, BERTLM
+
+parser = argparse.ArgumentParser()
+parser.add_argument("-d", "--dataset_path", required=True, type=str)
+parser.add_argument("-v", "--vocab_path", required=True, type=str)
+args = parser.parse_args()
+
+vocab = WordVocab.load_vocab(args.vocab_path)
+dataset = BERTDataset(args.dataset_path, vocab, seq_len=10)
+data_loader = DataLoader(dataset, batch_size=16)
+
+bert = BERT(len(vocab), hidden=128, n_layers=2, attn_heads=4)
+
+
+for data in data_loader:
+    x = model.forward(data["t1"])
+    print(x.size())
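
Note that the loop references model, which this script never defines; the instance constructed above is named bert, so the smoke test presumably means:

    for data in data_loader:
        x = bert.forward(data["t1"])
        print(x.size())   # expected: (batch_size, seq_len, 128)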

trainer/__init__.py

Whitespace-only changes.

trainer/pretrain.py (+24)

@@ -0,0 +1,24 @@
+import torch.nn as nn
+from torch.optim import Adam
+from model import BERTLM, BERT
+
+
+class BERTTrainer:
+    def __init__(self, bert: BERT, vocab_size, train_dataloader, test_dataloader=None):
+        self.bert = bert
+        self.lm = BERTLM(bert, vocab_size)
+
+        self.train_data = train_dataloader
+        self.test_data = test_dataloader
+
+        self.optim = Adam(self.lm.parameters())
+        self.criterion = nn.NLLLoss()
+
+    def train(self, epoch):
+        self.iteration(epoch, self.train_data)
+
+    def test(self, epoch):
+        self.iteration(epoch, self.test_data, train=False)
+
+    def iteration(self, epoch, data_loader, train=True):
+        pass
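
iteration is left as a stub in this commit. Purely as a sketch of where it appears headed, given the Adam optimizer, the NLLLoss criterion, and the batch keys produced by dataset.py (an assumption, not the repository's implementation):

    def iteration(self, epoch, data_loader, train=True):
        for data in data_loader:
            hidden = self.bert(data["t1"])                  # (batch, seq_len, hidden)
            next_sent_logp, mask_lm_logp = self.lm(hidden)

            # next-sentence loss from the first ([CLS]/SOS) position
            loss = self.criterion(next_sent_logp[:, 0], data["is_next"])
            # a masked-LM term over mask_lm_logp and the label ids would be added here

            if train:
                self.optim.zero_grad()
                loss.backward()
                self.optim.step()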
