Commit 826d0ba

Merge pull request #1 from brightmart/master: merge update from upstream
2 parents 68e2fcf + a01c5ab, commit 826d0ba

37 files changed: +91,648 / -180 lines

README.md

+64 -13
@@ -1,6 +1,15 @@
Text Classification
-------------------------------------------------------------------------
The purpose of this repository is to explore text classification methods in NLP with deep learning.

UPDATE:

1. <a href='https://github.com/brightmart/ai_law'>Apply AI to law cases (AI_LAW): predict the names of crimes (accusations) and the relevant articles given the facts of law cases</a> has been released.

2. <a href='https://github.com/brightmart/nlu_sim'>The sentence similarity project has been released</a>; check it out if you like.

3. If you want to try a model now, go to the folder 'a02_TextCNN' and run 'python -u p7_TextCNN_train.py'; it will use sample data to train a model and print the loss and F1 score periodically.

It has all kinds of baseline models for text classification.

@@ -20,7 +29,9 @@ we implement two memory network. one is dynamic memory network. previously it re

The second memory network we implemented is the Recurrent Entity Network (tracking the state of the world). It keeps blocks of key-value pairs as memory and runs them in parallel, which achieves a new state of the art. It can be used to model question answering with context (or history): for example, you can let the model read some sentences (as context) and ask a question (as query), then ask the model to predict an answer; if you feed the same story as the query, it can do a classification task.

If you need some sample data and word embeddings pretrained with word2vec, you can find them in the closed issues, such as <a href="https://github.com/brightmart/text_classification/issues/3">issue 3</a>.

You can also find some sample data in the folder "data". It contains two files: 'sample_single_label.txt' with 50k examples with a single label, and 'sample_multiple_label.txt' with 20k examples with multiple labels. The input and label(s) are separated by " __label__".
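
For illustration, here is a minimal sketch of splitting one line of this format into tokens and labels (the line itself is a made-up example):

```python
# the text and its label(s) are separated by "__label__"
line = "w1 w2 w3 w4 w5 __label__19 __label__7"           # made-up example line
parts = line.strip().split("__label__")
tokens = parts[0].split()                                # ['w1', 'w2', 'w3', 'w4', 'w5']
labels = [p.strip() for p in parts[1:] if p.strip()]     # ['19', '7']
```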

If you want to know more about text classification datasets, or the tasks these models can be used for, one option is:
https://biendata.com/competition/zhihu/
@@ -39,9 +50,10 @@ Models:

8) Dynamic Memory Network
9) EntityNetwork: tracking the state of the world
10) Ensemble models
11) Boosting:

For a single model, stack identical models together: each layer is a model, and the result is based on the logits added together. The only connection between layers is the labels' weights: the front layer's prediction error rate for each label becomes the weight for the next layer, so labels with a high error rate get a big weight. Later layers therefore pay more attention to those mis-predicted labels and try to fix the mistakes of the former layer. As a result, we get a much stronger model.

check a00_boosting/boosting.py

and other models:

@@ -70,21 +82,21 @@ Training| 10m | 2h |10h | 2h | 2h |3h |3h |5h

Notice:

`m` stands for **minutes**; `h` stands for **hours**.

`HierAtteNet` means Hierarchical Attention Network;

`Seq2seqAttn` means Seq2seq with attention;

`DynamicMemory` means DynamicMemoryNetwork;

`Transformer` stands for the model from 'Attention Is All You Need'.

Usage:
-------------------------------------------------------------------------------------------------------
1) the model is in `xxx_model.py`
2) run `python xxx_train.py` to train the model
3) run `python xxx_predict.py` to do inference (test)

Each model has a test method under the model class. You can run the test method first to check whether the model works properly.

@@ -94,7 +106,9 @@ Environment:

-------------------------------------------------------------------------------------------------------
python 2.7 + tensorflow 1.1

(tensorflow 1.2, 1.3 and 1.4 also work; most models should also work fine with other tensorflow versions, since we use very few features bound to a specific version; if you use python 3.5, it will be fine as long as you change the print / try-except syntax)

The TextCNN model has already been ported to python 3.6.

-------------------------------------------------------------------------

@@ -104,6 +118,19 @@ Some util function is in data_util.py;

typical input is like: "x1 x2 x3 x4 x5 __label__ 323434", where 'x1, x2, ...' are words and '323434' is the label;
it has a function to load a pretrained word embedding and assign it to the model, where the word embedding is pretrained with word2vec or fastText.

Pretrained Word Embedding:
-------------------------------------------------------------------------------------------------------
if word2vec.load does not work, you can load a pretrained word embedding (especially a Chinese word embedding) with the following lines:

    import gensim
    from gensim.models import KeyedVectors

    word2vec_model = KeyedVectors.load_word2vec_format(word2vec_model_path, binary=True, unicode_errors='ignore')

or you can set the "use pretrained word embedding" flag to False to disable loading the word embedding.
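
For context, here is a minimal sketch of how an embedding matrix built from the loaded gensim model could be assigned to a TensorFlow 1.x embedding variable; the function name, the index-to-word dict and the embedding size are assumptions for illustration, not the exact code used in this repository:

```python
import numpy as np
import tensorflow as tf

def assign_pretrained_embedding(sess, embedding_var, index2word, word2vec_model, embed_size=100):
    """Build an embedding matrix from word2vec_model and assign it to embedding_var (hypothetical helper)."""
    vocab_size = len(index2word)
    embedding_matrix = np.random.uniform(-0.1, 0.1, (vocab_size, embed_size)).astype(np.float32)
    for i, word in index2word.items():
        if word in word2vec_model:                 # keep the random init for out-of-vocabulary words
            embedding_matrix[i] = word2vec_model[word]
    sess.run(tf.assign(embedding_var, embedding_matrix))  # overwrite the model's embedding variable
```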

Models Detail:
-------------------------------------------------------------------------

@@ -119,6 +146,7 @@ result: performance is as good as paper, speed also very fast.

check: p5_fastTextB_model.py

![alt text](https://github.com/brightmart/text_classification/blob/master/images/fastText.JPG)
-------------------------------------------------------------------------

2.TextCNN:
@@ -141,15 +169,25 @@ Thirdly, we will concatenate scalars to form final features. It is a fixed-size

Finally, we use a linear layer to project these features onto the pre-defined labels.

![alt text](https://github.com/brightmart/text_classification/blob/master/images/TextCNN.JPG)
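
As a rough illustration of this structure (convolutions with several filter sizes, max-pooling over time, concatenation, then a linear projection to the labels), here is a minimal TensorFlow 1.x sketch; the names and hyper-parameters are assumptions, not the exact code in the TextCNN model file:

```python
import tensorflow as tf

def text_cnn_logits(embedded_words, filter_sizes, num_filters, num_classes):
    """embedded_words: [batch, sentence_len, embed_size]; returns logits: [batch, num_classes]."""
    sentence_len = embedded_words.get_shape()[1].value
    embed_size = embedded_words.get_shape()[2].value
    x = tf.expand_dims(embedded_words, -1)                        # [batch, len, embed, 1]
    pooled = []
    for filter_size in filter_sizes:
        with tf.variable_scope("conv-%d" % filter_size):
            w = tf.get_variable("W", [filter_size, embed_size, 1, num_filters])
            conv = tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding="VALID")
            h = tf.nn.relu(conv)                                  # [batch, len-f+1, 1, num_filters]
            p = tf.nn.max_pool(h, ksize=[1, sentence_len - filter_size + 1, 1, 1],
                               strides=[1, 1, 1, 1], padding="VALID")
            pooled.append(p)                                      # each: [batch, 1, 1, num_filters]
    features = tf.reshape(tf.concat(pooled, 3), [-1, num_filters * len(filter_sizes)])
    logits = tf.layers.dense(features, num_classes)               # linear projection to the labels
    return logits
```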

-------------------------------------------------------------------------

3.TextRNN
-------------
Structure v1: embedding ---> bi-directional lstm ---> concat output ---> average ---> softmax layer

check: p8_TextRNN_model.py

![alt text](https://github.com/brightmart/text_classification/blob/master/images/bi-directionalRNN.JPG)

Structure v2: embedding --> bi-directional lstm --> dropout --> concat output --> lstm --> dropout --> FC layer --> softmax layer

check: p8_TextRNN_model_multilayer.py

![alt text](https://github.com/brightmart/text_classification/blob/master/images/emojifier-v2.png)
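
Here is a minimal TensorFlow 1.x sketch of structure v1 (bi-directional LSTM, concatenate the two directions, average over time, then a linear layer whose logits feed a softmax); names and sizes are assumptions, not the exact code in p8_TextRNN_model.py:

```python
import tensorflow as tf

def text_rnn_logits(embedded_words, hidden_size, num_classes):
    """embedded_words: [batch, sentence_len, embed_size]; returns logits: [batch, num_classes]."""
    cell_fw = tf.contrib.rnn.BasicLSTMCell(hidden_size)
    cell_bw = tf.contrib.rnn.BasicLSTMCell(hidden_size)
    (out_fw, out_bw), _ = tf.nn.bidirectional_dynamic_rnn(cell_fw, cell_bw, embedded_words, dtype=tf.float32)
    output = tf.concat([out_fw, out_bw], axis=2)     # [batch, sentence_len, 2*hidden_size]
    feature = tf.reduce_mean(output, axis=1)         # average over time steps
    logits = tf.layers.dense(feature, num_classes)   # softmax is applied later, inside the loss
    return logits
```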

-------------------------------------------------------------------------

@@ -205,6 +243,7 @@ for left side context, it use a recurrent structure, a no-linearity transfrom of

check: p71_TextRCNN_model.py

![alt text](https://github.com/brightmart/text_classification/blob/master/images/RCNN.JPG)

-------------------------------------------------------------------------

@@ -226,6 +265,8 @@ Structure:

5) FC+Softmax

![alt text](https://github.com/brightmart/text_classification/blob/master/images/HAN.JPG)

In NLP, text classification can be done for a single sentence, but it can also be applied to multiple sentences; we may call that document classification. Words form sentences, and sentences form documents; in this situation, there may be an intrinsic hierarchical structure. So how can we model this kind of task? Are all parts of a document equally relevant? And how do we determine which parts are more important than others?

It has two unique features:
@@ -254,6 +295,8 @@ In my training data, for each example, i have four parts. each part has same len

check: p1_HierarchicalAttention_model.py

for attentive attention you can check <a href='https://github.com/brightmart/text_classification/issues/55'>attentive attention</a>

-------------------------------------------------------------------------

9.Seq2seq with attention
@@ -264,6 +307,8 @@ I.Structure:

1) embedding 2) bi-GRU to get a rich representation of the source sentence (forward & backward) 3) decoder with attention.

![alt text](https://github.com/brightmart/text_classification/blob/master/images/seq2seqAttention.JPG)

II.Input of data:

there are three kinds of inputs: 1) encoder inputs, which is a sentence; 2) decoder inputs, which is a list of labels with fixed length; 3) target labels, which is also a list of labels.
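
A made-up example of these three inputs for a tiny batch; the token ids and the GO/EOS convention are assumptions for illustration, not necessarily the exact convention used in this repository:

```python
# hypothetical ids for the special GO / EOS tokens, batch of size 2, sentence_len=5, label length=3
GO_ID, EOS_ID = 1, 2
encoder_inputs = [[4, 9, 2, 7, 0],        # a sentence, as word indices (0 = padding)
                  [5, 5, 1, 0, 0]]
decoder_inputs = [[GO_ID, 12, 33],        # labels list with fixed length, shifted right by a GO token
                  [GO_ID, 40, 8]]
target_labels  = [[12, 33, EOS_ID],       # the same labels, ending with an EOS token
                  [40, 8, EOS_ID]]
```
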
@@ -308,6 +353,8 @@ For every building blocks, we include a test function in the each file below, an

Sequence to sequence with attention is a typical model for sequence generation problems such as translation and dialogue systems. Most of the time it uses an RNN as the building block for these tasks. Until recently, people also applied convolutional neural networks to sequence-to-sequence problems. The Transformer, however, performs these tasks relying solely on the attention mechanism; it is fast and achieves a new state-of-the-art result.

![alt text](https://github.com/brightmart/text_classification/blob/master/images/attention_is_all_you_need.JPG)

It also has two main parts: encoder and decoder. Below is the description from the paper:

Encoder:
@@ -365,6 +412,8 @@ b. get weighted sum of hidden state using possibility distribution.

c. non-linearity transform of the query and the hidden states to get the predicted label.

![alt text](https://github.com/brightmart/text_classification/blob/master/images/EntityNet.JPG)
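
A minimal sketch of the answer module described in steps b and c (a probability distribution over the hidden states, their weighted sum, then a non-linear transform together with the query); the names and the exact way query and memory are combined are assumptions:

```python
import tensorflow as tf

def answer_module(query, hidden_states, num_classes):
    """query: [batch, dim]; hidden_states: [batch, num_blocks, dim]; returns logits: [batch, num_classes]."""
    scores = tf.reduce_sum(hidden_states * tf.expand_dims(query, 1), axis=2)   # [batch, num_blocks]
    p = tf.nn.softmax(scores)                                                  # possibility distribution over blocks
    u = tf.reduce_sum(hidden_states * tf.expand_dims(p, 2), axis=1)            # weighted sum of hidden states
    dim = query.get_shape()[-1].value
    hidden = tf.nn.relu(tf.layers.dense(tf.concat([query, u], axis=1), dim))   # non-linear transform of query + memory
    logits = tf.layers.dense(hidden, num_classes)                              # project to the labels
    return logits
```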

Main takeaways from this model:

1) use blocks of keys and values, which are independent from each other, so the model can be run in parallel.
@@ -391,6 +440,8 @@ Outlook of Model:

4.Answer Module: generate an answer from the final memory vector.

![alt text](https://github.com/brightmart/text_classification/blob/master/images/DMN.JPG)

Detail:

1.Input Module:

a00_boosting/a08_boosting.py

+69
@@ -0,0 +1,69 @@
# -*- coding: utf-8 -*-
import sys
reload(sys)                      # python 2: reset the default encoding
sys.setdefaultencoding('utf8')
import numpy as np               # needed for np.argmax / np.ones below
import tensorflow as tf

# main process for boosting:
# 1.compute label weights after each epoch using validation data.
# 2.get weights for each batch during the training process.
# 3.compute loss using cross entropy with weights.

# 1.compute label weights after each epoch using validation data.
def compute_labels_weights(weights_label, logits, labels):
    """
    compute weights for labels in the current batch, and update weights_label (a dict)
    :param weights_label: a dict mapping label -> (times_seen, times_correct)
    :param logits: [None, label_size]
    :param labels: [None,]
    :return: updated weights_label
    """
    labels_predict = np.argmax(logits, axis=1)  # predicted label for each example
    for i in range(len(labels)):
        label = labels[i]
        label_predict = labels_predict[i]
        weight = weights_label.get(label, None)
        if weight is None:                      # first time this label is seen
            if label_predict == label:
                weights_label[label] = (1, 1)
            else:
                weights_label[label] = (1, 0)
        else:
            number, correct = weight
            number = number + 1
            if label_predict == label:
                correct = correct + 1
            weights_label[label] = (number, correct)
    return weights_label

# 2.get weights for each batch during the training process
def get_weights_for_current_batch(answer_list, weights_dict):
    """
    get weights for the current batch
    :param answer_list: a numpy array containing the labels of a batch
    :param weights_dict: a dict that contains the accuracy of each label
    :return: a list of weights, one per example in the batch
    """
    weights_list_batch = list(np.ones((len(answer_list))))
    answer_list = list(answer_list)
    for i, label in enumerate(answer_list):
        acc = weights_dict[label]                               # validation accuracy of this label
        weights_list_batch[i] = min(1.5, 1.0 / (acc + 0.001))   # low accuracy --> big weight, capped at 1.5
    #if np.random.choice(200) == 0:  # print something from time to time
    #    print("weights_list_batch:", weights_list_batch)
    return weights_list_batch

# 3.compute loss using cross entropy with weights
def loss(logits, labels, weights):
    loss = tf.losses.sparse_softmax_cross_entropy(labels, logits, weights=weights)
    return loss

#######################################################################
# util function
def get_weights_label_as_standard_dict(weights_label):
    weights_dict = {}
    for k, v in weights_label.items():
        count, correct = v
        weights_dict[k] = float(correct) / float(count)
    return weights_dict
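
A minimal usage sketch of how the three steps above fit together in a training loop; `sess`, `model`, `train_batches` and `valid_batches` are hypothetical placeholders, not objects defined in this file:

```python
import numpy as np

num_epochs = 10                                        # hypothetical config value
weights_label = {}                                     # label -> (times_seen, times_correct), built on validation data
for epoch in range(num_epochs):
    weights_dict = get_weights_label_as_standard_dict(weights_label)   # label -> accuracy
    for batch_x, batch_y in train_batches:             # hypothetical iterator over training batches
        if weights_dict:                               # 2. weights for the current batch
            batch_weights = get_weights_for_current_batch(batch_y, weights_dict)
        else:
            batch_weights = np.ones(len(batch_y))      # first epoch: uniform weights
        feed = {model.input_x: batch_x, model.input_y: batch_y, model.weights: batch_weights}
        sess.run(model.train_op, feed_dict=feed)       # 3. the model's loss uses the weighted cross entropy
    weights_label = {}                                 # 1. recompute label weights on validation data after each epoch
    for valid_x, valid_y in valid_batches:
        logits = sess.run(model.logits, feed_dict={model.input_x: valid_x})
        weights_label = compute_labels_weights(weights_label, logits, valid_y)
```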

a02_TextCNN/__init__.py

Whitespace-only changes.
Binary files not shown.

a02_TextCNN/data_util.py

+129
@@ -0,0 +1,129 @@
# -*- coding: utf-8 -*-
import codecs
import random
import numpy as np
from tflearn.data_utils import pad_sequences
from collections import Counter
import os
import pickle

PAD_ID = 0
UNK_ID = 1
_PAD = "_PAD"
_UNK = "UNK"


def load_data_multilabel(training_data_path, vocab_word2index, vocab_label2index, sentence_len, training_portion=0.95):
    """
    convert data to indices using the word2index dicts.
    :param training_data_path: path to the training file
    :param vocab_word2index: dict mapping word -> index
    :param vocab_label2index: dict mapping label -> index
    :param sentence_len: maximum sentence length (shorter sentences are padded)
    :param training_portion: fraction of the data used for training
    :return: (train, test), each a tuple of (X, Y)
    """
    file_object = codecs.open(training_data_path, mode='r', encoding='utf-8')
    lines = file_object.readlines()
    random.shuffle(lines)
    label_size = len(vocab_label2index)
    X = []
    Y = []
    for i, line in enumerate(lines):
        raw_list = line.strip().split("__label__")
        input_list = raw_list[0].strip().split(" ")
        input_list = [x.strip().replace(" ", "") for x in input_list if x != '']
        x = [vocab_word2index.get(x, UNK_ID) for x in input_list]
        label_list = raw_list[1:]
        label_list = [l.strip().replace(" ", "") for l in label_list if l != '']
        label_list = [vocab_label2index[label] for label in label_list]
        y = transform_multilabel_as_multihot(label_list, label_size)
        X.append(x)
        Y.append(y)
    X = pad_sequences(X, maxlen=sentence_len, value=0.)  # padding to max length
    number_examples = len(lines)
    training_number = int(training_portion * number_examples)
    train = (X[0:training_number], Y[0:training_number])
    valid_number = min(1000, number_examples - training_number)
    test = (X[training_number + 1:training_number + valid_number + 1], Y[training_number + 1:training_number + valid_number + 1])
    return train, test


def transform_multilabel_as_multihot(label_list, label_size):
    """
    convert to multi-hot style
    :param label_list: e.g. [0,1,4]; 4 means the 4th position holds a true value (indicated by '1')
    :param label_size: e.g. 199
    :return: e.g. [1,1,0,0,1,0,...]
    """
    result = np.zeros(label_size)
    # set those locations to 1, everything else stays 0.
    result[label_list] = 1
    return result


# build word and label vocabularies from the training data, together with their index mappings
def create_vocabulary(training_data_path, vocab_size, name_scope='cnn'):
    """
    create vocabulary
    :param training_data_path: path to the training file
    :param vocab_size: maximum number of words to keep
    :param name_scope: used to name the cache folder
    :return: (vocabulary_word2index, vocabulary_index2word, vocabulary_label2index, vocabulary_index2label)
    """
    cache_vocabulary_label_pik = 'cache' + "_" + name_scope  # path to save the cache
    if not os.path.isdir(cache_vocabulary_label_pik):  # create the folder if it does not exist.
        os.makedirs(cache_vocabulary_label_pik)

    # if the cache exists, load it; otherwise create it.
    cache_path = cache_vocabulary_label_pik + "/" + 'vocab_label.pik'
    print("cache_path:", cache_path, "file_exists:", os.path.exists(cache_path))
    if os.path.exists(cache_path):
        with open(cache_path, 'rb') as data_f:
            return pickle.load(data_f)
    else:
        vocabulary_word2index = {}
        vocabulary_index2word = {}
        vocabulary_word2index[_PAD] = PAD_ID
        vocabulary_index2word[PAD_ID] = _PAD
        vocabulary_word2index[_UNK] = UNK_ID
        vocabulary_index2word[UNK_ID] = _UNK

        vocabulary_label2index = {}
        vocabulary_index2label = {}

        # 1.load raw data
        file_object = codecs.open(training_data_path, mode='r', encoding='utf-8')
        lines = file_object.readlines()
        # 2.loop over each line, update the counters
        c_inputs = Counter()
        c_labels = Counter()
        for line in lines:
            raw_list = line.strip().split("__label__")

            input_list = raw_list[0].strip().split(" ")
            input_list = [x.strip().replace(" ", "") for x in input_list if x != '']
            label_list = [l.strip().replace(" ", "") for l in raw_list[1:] if l != '']
            c_inputs.update(input_list)
            c_labels.update(label_list)
        # keep the most frequent words
        vocab_list = c_inputs.most_common(vocab_size)
        label_list = c_labels.most_common()
        # put those words into the dicts
        for i, tuplee in enumerate(vocab_list):
            word, _ = tuplee
            vocabulary_word2index[word] = i + 2
            vocabulary_index2word[i + 2] = word

        for i, tuplee in enumerate(label_list):
            label, _ = tuplee
            label = str(label)
            vocabulary_label2index[label] = i
            vocabulary_index2label[i] = label

        # save to the file system if the vocabulary cache does not exist yet.
        if not os.path.exists(cache_path):
            with open(cache_path, 'ab') as data_f:
                pickle.dump((vocabulary_word2index, vocabulary_index2word, vocabulary_label2index, vocabulary_index2label), data_f)
    return vocabulary_word2index, vocabulary_index2word, vocabulary_label2index, vocabulary_index2label

#training_data_path = '../data/sample_multiple_label3.txt'
#vocab_size = 100
#create_vocabulary(training_data_path, vocab_size)
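
A minimal usage sketch of the two main functions above; the file path and the sizes are placeholders:

```python
# build the vocabularies (cached under ./cache_cnn), then load the multi-label data
word2index, index2word, label2index, index2label = create_vocabulary(
    "../data/sample_multiple_label.txt", vocab_size=50000, name_scope='cnn')

train, test = load_data_multilabel(
    "../data/sample_multiple_label.txt", word2index, label2index, sentence_len=200)
train_X, train_Y = train
print("train examples:", len(train_X), "label size:", len(label2index))
```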

a02_TextCNN/other_experiement/__init__.py

Whitespace-only changes.
