Skip to content

Commit 75e6a06

Browse files
committed
Update
1 parent 51474d0 commit 75e6a06

File tree

1 file changed: +11 additions, −10 deletions

ReutersTextCategorizationTransformerDemoKeras.py

Lines changed: 11 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -13,11 +13,8 @@
 from glob import glob
 from scipy.sparse import csc_matrix, csr_matrix, dok_array
 
-relevance_threshold = 0.5
-profile_threshold = 2.5
-
 @jit(nopython=True)
-def count_tokens(X_indices, X_indptr, word_profile_data, word_profile_indices, word_profile_indptr, feature_map):
+def count_tokens(X_indices, X_indptr, word_profile_data, word_profile_indices, word_profile_indptr, feature_map, relevance_threshold, profile_threshold):
     document_vector = np.zeros(word_profile_indptr.shape[0]-1)
     target_word_profile = np.zeros(word_profile_indptr.shape[0]-1)
     target_word_refined_profile = np.zeros(word_profile_indptr.shape[0]-1)
@@ -82,7 +79,7 @@ def count_tokens(X_indices, X_indptr, word_profile_data, word_profile_indices, w
     return global_token_count
 
 @jit(nopython=True)
-def embed_X(X_indices, X_indptr, word_profile_data, word_profile_indices, word_profile_indptr, feature_map, token_count):
+def embed_X(X_indices, X_indptr, word_profile_data, word_profile_indices, word_profile_indptr, feature_map, token_count, relevance_threshold, profile_threshold):
     document_vector = np.zeros(word_profile_indptr.shape[0]-1)
     target_word_profile = np.zeros(word_profile_indptr.shape[0]-1)
     target_word_refined_profile = np.zeros(word_profile_indptr.shape[0]-1)
@@ -146,7 +143,8 @@ def embed_X(X_indices, X_indptr, word_profile_data, word_profile_indices, word_p
         global_token_count += 1
         X_embedded_indptr[row+1] = global_token_count
 
-        print(row, X_indptr.shape[0]-1, document_token_count)
+        if row % 100 == 0:
+            print(row, X_indptr.shape[0]-1, document_token_count)
 
         #print(row, X_indptr[row], X_indptr[row+1], X_indptr[row+1] - X_indptr[row], document_token_count)
     return (X_embedded_data, X_embedded_indices, X_embedded_indptr)
@@ -166,6 +164,9 @@ def embed_X(X_indices, X_indptr, word_profile_data, word_profile_indices, word_p
 parser.add_argument("--features", default=5000, type=int)
 parser.add_argument("--reuters-num-words", default=10000, type=int)
 parser.add_argument("--reuters-index-from", default=2, type=int)
+parser.add_argument("--relevance_threshold", default=0.25, type=float)
+parser.add_argument("--profile_threshold", default=0.5, type=float)
+
 args = parser.parse_args()
 
 
@@ -233,13 +234,13 @@ def embed_X(X_indices, X_indptr, word_profile_data, word_profile_indices, word_p
     feature_map[i] = 0
 
 # Counts number of tokens in the augmented dataset to allocate memory for sparse data structure
-token_count = count_tokens(X_train.indices, X_train.indptr, word_profile.data, word_profile.indices, word_profile.indptr, feature_map)
-(X_train_embedded_data, X_train_embedded_indices, X_train_embedded_indptr) = embed_X(X_train.indices, X_train.indptr, word_profile.data, word_profile.indices, word_profile.indptr, feature_map, token_count)
+token_count = count_tokens(X_train.indices, X_train.indptr, word_profile.data, word_profile.indices, word_profile.indptr, feature_map, args.relevance_threshold, args.profile_threshold)
+(X_train_embedded_data, X_train_embedded_indices, X_train_embedded_indptr) = embed_X(X_train.indices, X_train.indptr, word_profile.data, word_profile.indices, word_profile.indptr, feature_map, token_count, args.relevance_threshold, args.profile_threshold)
 X_train_embedded = csr_matrix((X_train_embedded_data, X_train_embedded_indices, X_train_embedded_indptr))
 
 # Counts number of tokens in the augmented dataset to allocate memory for sparse data structure
-token_count = count_tokens(X_test.indices, X_test.indptr, word_profile.data, word_profile.indices, word_profile.indptr, feature_map)
-(X_test_embedded_data, X_test_embedded_indices, X_test_embedded_indptr) = embed_X(X_test.indices, X_test.indptr, word_profile.data, word_profile.indices, word_profile.indptr, feature_map, token_count)
+token_count = count_tokens(X_test.indices, X_test.indptr, word_profile.data, word_profile.indices, word_profile.indptr, feature_map, args.relevance_threshold, args.profile_threshold)
+(X_test_embedded_data, X_test_embedded_indices, X_test_embedded_indptr) = embed_X(X_test.indices, X_test.indptr, word_profile.data, word_profile.indices, word_profile.indptr, feature_map, token_count, args.relevance_threshold, args.profile_threshold)
 X_test_embedded = csr_matrix((X_test_embedded_data, X_test_embedded_indices, X_test_embedded_indptr))
 
 

0 commit comments

Comments
 (0)