1
1
from torch .utils .data import Dataset
2
+ from .vocab import WordVocab
2
3
import tqdm
3
4
import random
4
5
import argparse
@@ -14,7 +15,7 @@ def __init__(self, corpus_path, vocab, seq_len, encoding="utf-8"):
14
15
with open (corpus_path , "r" , encoding = encoding ) as f :
15
16
for line in tqdm .tqdm (f , desc = "Loading Dataset" ):
16
17
t1 , t2 , t1_l , t2_l , is_next = line [:- 1 ].split ("\t " )
17
- t1_l , t2_l = [[int ( i ) for i in label .split (", " )] for label in [t1_l , t2_l ]]
18
+ t1_l , t2_l = [[token for token in label .split (" " )] for label in [t1_l , t2_l ]]
18
19
is_next = int (is_next )
19
20
self .datas .append ({
20
21
"t1" : t1 ,
@@ -29,12 +30,14 @@ def __len__(self):
29
30
30
31
def __getitem__(self, item):
    """Encode one pre-built example into id sequences for tensor collation.

    Returns a dict of torch tensors: token-id sequences for both sentences,
    their per-token label sequences, and the is-next flag.
    """
    sample = self.datas[item]

    # [CLS] tag = SOS tag, [SEP] tag = EOS tag
    t1 = self.vocab.to_seq(sample["t1"], with_sos=True, with_eos=True)
    t2 = self.vocab.to_seq(sample["t2"], with_eos=True)

    t1_label = self.vocab.to_seq(sample["t1_label"])
    t2_label = self.vocab.to_seq(sample["t2_label"])

    # NOTE(review): sequences are not padded/truncated to a fixed seq_len here —
    # assumes downstream collation handles variable lengths; confirm against the
    # DataLoader's collate_fn.
    output = {"t1": t1, "t2": t2,
              "t1_label": t1_label, "t2_label": t2_label,
              "is_next": sample["is_next"]}

    return {key: torch.tensor(value) for key, value in output.items()}
@@ -79,38 +82,18 @@ def random_word(self, sentence):
79
82
def random_sent(self, index):
    """Pick a follow-up sentence for the next-sentence-prediction task.

    With probability 0.5 returns the true second sentence of the pair at
    *index* with label 1 (isNext); otherwise returns the second sentence of
    a uniformly random pair with label 0 (isNotNext).
    """
    if random.random() > 0.5:
        return self.datas[index][1], 1
    negative = random.randrange(len(self.datas))
    return self.datas[negative][1], 0
85
88
86
89
def __getitem__(self, index):
    """Build one raw training example (sentence pair + labels).

    Draws the first sentence of the pair at *index*, samples a second
    sentence via random_sent, and applies random_word to both — presumably
    token masking/corruption with per-token labels; verify in random_word.
    """
    t1 = self.datas[index][0]
    t2, is_next_label = self.random_sent(index)

    t1_random, t1_label = self.random_word(t1)
    t2_random, t2_label = self.random_word(t2)

    return {"t1_random": t1_random, "t2_random": t2_random,
            "t1_label": t1_label, "t2_label": t2_label,
            "is_next": is_next_label}
94
97
95
-
96
- if __name__ == "__main__" :
97
- from .vocab import WordVocab
98
-
99
- parser = argparse .ArgumentParser ()
100
- parser .add_argument ("-v" , "--vocab_path" , required = True , type = str )
101
- parser .add_argument ("-c" , "--corpus_path" , required = True , type = str )
102
- parser .add_argument ("-e" , "--encoding" , default = "utf-8" , type = str )
103
- parser .add_argument ("-o" , "--output_path" , required = True , type = str )
104
- args = parser .parse_args ()
105
-
106
- word_vocab = WordVocab .load_vocab (args .vocab_path )
107
- builder = BERTDatasetCreator (corpus_path = args .corpus_path , vocab = word_vocab , seq_len = None , encoding = args .encoding )
108
-
109
- with open (args .output_path , 'w' , encoding = args .encoding ) as f :
110
- for index in tqdm .tqdm (range (len (builder )), desc = "Building Dataset" , total = len (builder )):
111
- data = builder [index ]
112
- output_form = "%s\t %s\t %s\t %d\n "
113
- t1_text , t2_text = [" " .join (t ) for t in [data ["t1_random" ], data ["t2_random" ]]]
114
- t1_label , t2_label = ["," .join ([str (i ) for i in label ]) for label in [data ["t1_label" ], data ["t2_label" ]]]
115
- output = output_form % (t1_text , t2_text , t1_label , t2_label , data ["is_next" ])
116
- f .write (output_form )
98
def __len__(self):
    """Number of examples loaded into memory."""
    return len(self.datas)
0 commit comments