@@ -8,8 +8,11 @@ class BERTDataset(Dataset):
8
8
def __init__ (self , corpus_path , vocab , seq_len , encoding = "utf-8" , corpus_lines = None , on_memory = True ):
9
9
self .vocab = vocab
10
10
self .seq_len = seq_len
11
+
11
12
self .on_memory = on_memory
12
13
self .corpus_lines = corpus_lines
14
+ self .corpus_path = corpus_path
15
+ self .encoding = encoding
13
16
14
17
with open (corpus_path , "r" , encoding = encoding ) as f :
15
18
if self .corpus_lines is None and not on_memory :
@@ -21,6 +24,13 @@ def __init__(self, corpus_path, vocab, seq_len, encoding="utf-8", corpus_lines=N
21
24
for line in tqdm .tqdm (f , desc = "Loading Dataset" , total = corpus_lines )]
22
25
self .corpus_lines = len (self .lines )
23
26
27
+ if not on_memory :
28
+ self .file = open (corpus_path , "r" , encoding = encoding )
29
+ self .random_file = open (corpus_path , "r" , encoding = encoding )
30
+
31
+ for _ in range (random .randint (self .corpus_lines if self .corpus_lines < 1000 else 1000 )):
32
+ self .random_file .__next__ ()
33
+
24
34
def __len__(self):
    """Return the number of corpus lines recorded in ``self.corpus_lines``."""
    line_count = self.corpus_lines
    return line_count
26
36
@@ -78,8 +88,36 @@ def random_word(self, sentence):
78
88
return tokens , output_label
79
89
80
90
def random_sent(self, index):
    """Build one sentence pair plus its next-sentence label.

    The pair for ``index`` is fetched through ``get_corpus_line``. With a
    random draw above 0.5 the pair is returned unchanged with label 1;
    otherwise the second element is replaced by ``get_random_line()`` and
    the label is 0.

    Returns:
        A ``(first, second, label)`` tuple.
    """
    first, second = self.get_corpus_line(index)

    # output_text, label(isNotNext:0, isNext:1)
    keep_pair = random.random() > 0.5
    if not keep_pair:
        # Only draw a replacement when the true pair is being discarded.
        second = self.get_random_line()
    return first, second, int(keep_pair)
98
+
99
def get_corpus_line(self, item):
    """Return the two tab-separated sentences of one corpus line.

    In on-memory mode the pair is looked up as ``self.lines[item]``.
    Otherwise ``item`` is ignored and the next line is read sequentially
    from ``self.file``; when the end of the file is reached the file is
    reopened and reading restarts from the first line.

    Args:
        item: Index into ``self.lines`` (used only in on-memory mode).

    Returns:
        A ``(t1, t2)`` tuple of strings.

    Raises:
        ValueError: If the corpus file yields no line even after rewinding,
            or if the line does not split into exactly two fields.
    """
    if self.on_memory:
        return self.lines[item][0], self.lines[item][1]

    # File iterators signal exhaustion with StopIteration, never with None,
    # so ask next() for an explicit sentinel to detect end of file.
    line = next(self.file, None)
    if line is None:
        self.file.close()
        self.file = open(self.corpus_path, "r", encoding=self.encoding)
        line = next(self.file, None)
    if line is None:
        raise ValueError("corpus file %r contains no lines" % (self.corpus_path,))

    # Strip only the newline: the last line of a file may not have one, and
    # slicing off the final character would then lose real data.
    t1, t2 = line.rstrip("\n").split("\t")
    return t1, t2
111
+
112
def get_random_line(self):
    """Return the second tab-separated field of a randomly chosen line.

    In on-memory mode a random entry of ``self.lines`` is used. Otherwise
    the next line of ``self.random_file`` is taken; when that reader hits
    the end of the file it is reopened and advanced by a random number of
    lines (fewer than ``min(self.corpus_lines, 1000)``) before reading.

    Returns:
        The second field of the selected line.

    Raises:
        ValueError: If the corpus file yields no line even after rewinding.
    """
    if self.on_memory:
        return self.lines[random.randrange(len(self.lines))][1]

    # Use the dedicated random reader; reading from self.file here would
    # consume lines that get_corpus_line is supposed to return.
    line = next(self.random_file, None)
    if line is None:
        self.random_file.close()
        self.random_file = open(self.corpus_path, "r", encoding=self.encoding)

        # randrange keeps the skip strictly below the line count so at
        # least one line is left to read after skipping.
        limit = min(self.corpus_lines or 0, 1000)
        skip = random.randrange(limit) if limit > 0 else 0
        for _ in range(skip):
            next(self.random_file, None)
        line = next(self.random_file, None)

        if line is None:
            # corpus_lines overstated the real length; fall back to line 1.
            self.random_file.seek(0)
            line = next(self.random_file, None)
    if line is None:
        raise ValueError("corpus file %r contains no lines" % (self.corpus_path,))

    # Strip only the newline so an unterminated last line keeps its data.
    return line.rstrip("\n").split("\t")[1]
0 commit comments