
Commit 7b145dc: "Adding none memory loading" (adds an on_memory=False path that streams the corpus from disk instead of holding every line in RAM)
Parent commit: 913c43a

3 files changed: +42, -9 lines

bert_pytorch/dataset/dataset.py (+40, -2)
@@ -8,8 +8,11 @@ class BERTDataset(Dataset):
     def __init__(self, corpus_path, vocab, seq_len, encoding="utf-8", corpus_lines=None, on_memory=True):
         self.vocab = vocab
         self.seq_len = seq_len
+
         self.on_memory = on_memory
         self.corpus_lines = corpus_lines
+        self.corpus_path = corpus_path
+        self.encoding = encoding
 
         with open(corpus_path, "r", encoding=encoding) as f:
             if self.corpus_lines is None and not on_memory:
@@ -21,6 +24,13 @@ def __init__(self, corpus_path, vocab, seq_len, encoding="utf-8", corpus_lines=None, on_memory=True):
                               for line in tqdm.tqdm(f, desc="Loading Dataset", total=corpus_lines)]
                 self.corpus_lines = len(self.lines)
 
+        if not on_memory:
+            self.file = open(corpus_path, "r", encoding=encoding)
+            self.random_file = open(corpus_path, "r", encoding=encoding)
+
+            for _ in range(random.randint(self.corpus_lines if self.corpus_lines < 1000 else 1000)):
+                self.random_file.__next__()
+
     def __len__(self):
         return self.corpus_lines
 
@@ -78,8 +88,36 @@ def random_word(self, sentence):
         return tokens, output_label
 
     def random_sent(self, index):
+        t1, t2 = self.get_corpus_line(index)
+
         # output_text, label(isNotNext:0, isNext:1)
         if random.random() > 0.5:
-            return self.datas[index][0], self.datas[index][1], 1
+            return t1, t2, 1
+        else:
+            return t1, self.get_random_line(), 0
+
+    def get_corpus_line(self, item):
+        if self.on_memory:
+            return self.lines[item][0], self.lines[item][1]
         else:
-            return self.datas[index][0], self.datas[random.randrange(len(self.datas))][1], 0
+            line = self.file.__next__()
+            if line is None:
+                self.file.close()
+                self.file = open(self.corpus_path, "r", encoding=self.encoding)
+                line = self.file.__next__()
+
+            t1, t2 = line[:-1].split("\t")
+            return t1, t2
+
+    def get_random_line(self):
+        if self.on_memory:
+            return self.lines[random.randrange(len(self.lines))][1]
+
+        line = self.file.__next__()
+        if line is None:
+            self.file.close()
+            self.file = open(self.corpus_path, "r", encoding=self.encoding)
+            for _ in range(random.randint(self.corpus_lines if self.corpus_lines < 1000 else 1000)):
+                self.random_file.__next__()
+        line = self.random_file.__next__()
+        return line[:-1].split("\t")[1]
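Two caveats in the streaming code above, noted editorially rather than fixed in the diff: Python's random.randint requires two arguments, so the single-argument calls here raise TypeError, and file iterators signal end-of-file by raising StopIteration rather than returning None, so the "if line is None:" reopen branches can never execute. A minimal corrected sketch of get_random_line under those constraints (the StopIteration-based wraparound and the max(1, ...) guard are editorial assumptions, not the committed code):

import random

def get_random_line(self):
    # Hedged drop-in sketch for BERTDataset.get_random_line; not the committed code.
    if self.on_memory:
        return self.lines[random.randrange(len(self.lines))][1]

    # Advance the dedicated handle a bounded random number of lines,
    # wrapping around at EOF. File iterators raise StopIteration at EOF;
    # they never return None.
    skips = random.randrange(max(1, min(self.corpus_lines, 1000))) + 1
    line = None
    for _ in range(skips):
        try:
            line = next(self.random_file)
        except StopIteration:
            self.random_file.close()
            self.random_file = open(self.corpus_path, "r", encoding=self.encoding)
            line = next(self.random_file)
    return line[:-1].split("\t")[1]

With that, on_memory=False trades true random sampling for bounded forward skips over a second file handle, which keeps memory flat on corpora too large to cache.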

bert_pytorch/trainer/pretrain.py (+1, -5)
@@ -3,8 +3,6 @@
 from torch.optim import Adam
 from torch.utils.data import DataLoader
 
-from encoding.parallel import DataParallelModel, DataParallelCriterion
-
 from ..model import BERTLM, BERT
 
 import tqdm
@@ -49,7 +47,7 @@ def __init__(self, bert: BERT, vocab_size: int,
         # Distributed GPU training if CUDA can detect more than 1 GPU
         if with_cuda and torch.cuda.device_count() > 1:
             print("Using %d GPUS for BERT" % torch.cuda.device_count())
-            self.model = DataParallelModel(self.model, device_ids=cuda_devices)
+            self.model = nn.DataParallel(self.model, device_ids=cuda_devices)
 
         # Setting the train and test data loader
         self.train_data = train_dataloader
@@ -60,8 +58,6 @@ def __init__(self, bert: BERT, vocab_size: int,
 
         # Using Negative Log Likelihood Loss function for predicting the masked_token
         self.criterion = nn.NLLLoss(ignore_index=0)
-        if with_cuda and torch.cuda.device_count() > 0:
-            self.criterion = DataParallelCriterion(self.criterion, device_ids=cuda_devices)
 
         self.log_freq = log_freq
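Swapping encoding.parallel for the stock torch.nn.DataParallel also explains why the criterion wrapper disappears: DataParallel scatters each input batch across the visible GPUs and gathers the outputs back onto the primary device, so a plain loss can be applied to the gathered result. A minimal sketch of the pattern, with nn.Linear standing in for BERTLM (an illustrative assumption, not the repository's model):

import torch
import torch.nn as nn

model = nn.Linear(16, 4)  # stand-in for BERTLM; any nn.Module behaves the same

if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    # forward() scatters inputs across GPUs and gathers outputs onto
    # device_ids[0], so downstream code sees one ordinary tensor.
    model = nn.DataParallel(model).cuda()

# Plain loss on the gathered outputs; no DataParallelCriterion needed.
criterion = nn.NLLLoss(ignore_index=0)

The same swap is what lets requirements.txt below drop torch-encodin (a misspelling of torch-encoding, the package that provided DataParallelModel and DataParallelCriterion).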

requirements.txt (+1, -2)

@@ -1,4 +1,3 @@
 tqdm
 numpy
-torch>=0.4.0
-torch-encodin
+torch>=0.4.0
