1919import numpy as np
2020import io
2121
22+ import tokenization
2223
2324class SimNetProcessor (object ):
2425 def __init__ (self , args , vocab ):
@@ -27,6 +28,10 @@ def __init__(self, args, vocab):
2728 self .vocab = vocab
2829 self .valid_label = np .array ([])
2930 self .test_label = np .array ([])
31+ if args .tokenizer :
32+ self .tokenizer = getattr (tokenization , args .tokenizer )()
33+ else :
34+ self .tokenizer = None
3035
3136 def get_reader (self , mode , epoch = 0 ):
3237 """
@@ -48,6 +53,12 @@ def reader_with_pairwise():
4853 logging .warning (
4954 "line not match format in test file" )
5055 continue
56+
57+ # tokenize
58+ if self .tokenizer :
59+ query = self .tokenizer .tokenize (query )
60+ title = self .tokenizer .tokenize (title )
61+
5162 query = [
5263 self .vocab [word ] for word in query .split (" " )
5364 if word in self .vocab
@@ -71,6 +82,12 @@ def reader_with_pairwise():
7182 logging .warning (
7283 "line not match format in test file" )
7384 continue
85+
86+ # tokenize
87+ if self .tokenizer :
88+ query = self .tokenizer .tokenize (query )
89+ title = self .tokenizer .tokenize (title )
90+
7491 query = [
7592 self .vocab [word ] for word in query .split (" " )
7693 if word in self .vocab
@@ -95,6 +112,12 @@ def reader_with_pairwise():
95112 logging .warning (
96113 "line not match format in test file" )
97114 continue
115+ # tokenize
116+ if self .tokenizer :
117+ query = self .tokenizer .tokenize (query )
118+ pos_title = self .tokenizer .tokenize (pos_title )
119+ neg_title = self .tokenizer .tokenize (neg_title )
120+
98121 query = [
99122 self .vocab [word ] for word in query .split (" " )
100123 if word in self .vocab
@@ -130,6 +153,12 @@ def reader_with_pointwise():
130153 logging .warning (
131154 "line not match format in test file" )
132155 continue
156+
157+ # tokenize
158+ if self .tokenizer :
159+ query = self .tokenizer .tokenize (query )
160+ title = self .tokenizer .tokenize (title )
161+
133162 query = [
134163 self .vocab [word ] for word in query .split (" " )
135164 if word in self .vocab
@@ -153,6 +182,12 @@ def reader_with_pointwise():
153182 logging .warning (
154183 "line not match format in test file" )
155184 continue
185+
186+ # tokenize
187+ if self .tokenizer :
188+ query = self .tokenizer .tokenize (query )
189+ title = self .tokenizer .tokenize (title )
190+
156191 query = [
157192 self .vocab [word ] for word in query .split (" " )
158193 if word in self .vocab
@@ -178,6 +213,12 @@ def reader_with_pointwise():
178213 logging .warning (
179214 "line not match format in test file" )
180215 continue
216+
217+ # tokenize
218+ if self .tokenizer :
219+ query = self .tokenizer .tokenize (query )
220+ title = self .tokenizer .tokenize (title )
221+
181222 query = [
182223 self .vocab [word ] for word in query .split (" " )
183224 if word in self .vocab
@@ -208,6 +249,10 @@ def get_infer_reader(self):
208249 if len (query ) == 0 or len (title ) == 0 :
209250 logging .warning ("line not match format in test file" )
210251 continue
252+ # tokenize
253+ if self .tokenizer :
254+ query = self .tokenizer .tokenize (query )
255+ title = self .tokenizer .tokenize (title )
211256 query = [
212257 self .vocab [word ] for word in query .split (" " )
213258 if word in self .vocab
0 commit comments