Skip to content

Commit 317e5e0

Browse files
committed
add new model
1 parent 75d28c7 commit 317e5e0

File tree

11 files changed

+555
-235
lines changed

11 files changed

+555
-235
lines changed

.gitignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
1-
/datasets
1+
/datasets
2+
.ipynb_checkpoints/
3+
__pycache__/

README.md

Lines changed: 117 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -1,84 +1,117 @@
1-
## Text-CNN
2-
3-
使用 PyTorch 实现 [Convolutional Neural Networks for Sentence Classification](https://arxiv.org/pdf/1408.5882.pdf) 中提出的文本分类方法。
4-
5-
## 数据集
6-
7-
此处使用的数据集来自 [text-classification-cnn-rnn](https://github.com/gaussic/text-classification-cnn-rnn) 作者整理的数据集。下载链接:https://pan.baidu.com/s/1hugrfRu 密码: qfud
8-
9-
该数据集共包含 10 个类别,每个类别有 6500 条数据。类别如下:
10-
11-
```
12-
'体育', '财经', '房产', '家居', '教育', '科技', '时尚', '时政', '游戏', '娱乐'
13-
```
14-
15-
数据集划分如下:
16-
17-
- 训练集: 5000 * 10
18-
- 验证集: 500 * 10
19-
- 测试集: 1000 * 10
20-
21-
## 运行方法
22-
23-
下载数据集,并解压至 `datasets` 目录下,在 `main.py` 中做适当调整,然后运行:
24-
25-
```
26-
$ python main.py
27-
```
28-
29-
运行结果:
30-
31-
```
32-
2019-05-24 20:45:03,204 - using device: cuda:7
33-
2019-05-24 20:45:03,205 - load and preprocess data...
34-
2019-05-24 20:45:15,800 - training...
35-
2019-05-24 20:45:30,872 - epoch: 1 - loss: 0.06 acc: 0.65 - val_loss: 0.03 val_acc: 0.75
36-
2019-05-24 20:45:41,568 - epoch: 2 - loss: 0.05 acc: 0.80 - val_loss: 0.03 val_acc: 0.77
37-
2019-05-24 20:45:52,137 - epoch: 3 - loss: 0.05 acc: 0.82 - val_loss: 0.03 val_acc: 0.82
38-
2019-05-24 20:46:02,975 - epoch: 4 - loss: 0.05 acc: 0.83 - val_loss: 0.03 val_acc: 0.78
39-
2019-05-24 20:46:13,769 - epoch: 5 - loss: 0.05 acc: 0.83 - val_loss: 0.03 val_acc: 0.82
40-
2019-05-24 20:46:24,514 - epoch: 6 - loss: 0.05 acc: 0.87 - val_loss: 0.02 val_acc: 0.90
41-
2019-05-24 20:46:35,237 - epoch: 7 - loss: 0.05 acc: 0.92 - val_loss: 0.02 val_acc: 0.90
42-
2019-05-24 20:46:45,801 - epoch: 8 - loss: 0.05 acc: 0.93 - val_loss: 0.02 val_acc: 0.91
43-
2019-05-24 20:46:56,050 - epoch: 9 - loss: 0.05 acc: 0.93 - val_loss: 0.02 val_acc: 0.93
44-
2019-05-24 20:47:06,771 - epoch: 10 - loss: 0.05 acc: 0.94 - val_loss: 0.02 val_acc: 0.94
45-
2019-05-24 20:47:07,000 - predicting...
46-
2019-05-24 20:47:07,435 - test - acc: 0.9326
47-
```
48-
49-
这里并没有对文本进行过多的预处理,比如去除特殊符号,停用词等。另外直接采用了字作为特征,对于中文文本分类,感觉分词已经没有必要了。
50-
51-
我使用 [FastText](https://fasttext.cc/) 对该数据集进行了分类,发现分类准确度能轻松达到 99% 以上。这也表明,对于长文本分类问题,词袋模型就足够了。深度模型,可能更适合于一些复杂的场景,比如词与词之间关系较大时。
52-
53-
```
54-
F1-Score : 0.999400 Precision : 0.999800 Recall : 0.999000 __label__0
55-
F1-Score : 0.995690 Precision : 0.997991 Recall : 0.993400 __label__5
56-
F1-Score : 0.996396 Precision : 0.997395 Recall : 0.995400 __label__1
57-
F1-Score : 0.998701 Precision : 0.998003 Recall : 0.999400 __label__2
58-
F1-Score : 0.999000 Precision : 0.999400 Recall : 0.998600 __label__3
59-
F1-Score : 0.983119 Precision : 0.987884 Recall : 0.978400 __label__8
60-
F1-Score : 0.997598 Precision : 0.998397 Recall : 0.996800 __label__9
61-
F1-Score : 0.985344 Precision : 0.975873 Recall : 0.995000 __label__4
62-
F1-Score : 0.996898 Precision : 0.997597 Recall : 0.996200 __label__6
63-
F1-Score : 0.998700 Precision : 0.998800 Recall : 0.998600 __label__7
64-
N 50000
65-
P@1 0.995
66-
R@1 0.995
67-
```
68-
69-
## 配置
70-
71-
```python
72-
class_num=10 # 类别数
73-
embed_num=5000 # 字典大小
74-
embed_dim=64 # 字向量维度
75-
kernel_num=128 # 卷积核数量
76-
kernel_size_list=[3,4,5] # 卷积核尺寸
77-
dropout=0.5 # 置 0 的概率
78-
```
79-
80-
## Text CNN 模型
81-
82-
![image](https://user-images.githubusercontent.com/7794103/58327903-63a30180-7e63-11e9-9c82-acc55c8e0b21.png)
83-
84-
该模型的基本思想是对输入序列先做 Embedding,而后使用不同窗口大小的 1D Conv 提取特征,经过 MaxPooing1D 后 一个卷积核得到一个标量,最后全部拼接起来,得到一个向量,然后使用全连接层加 softmax 进行分类。
1+
## Text-Classification
2+
3+
使用 PyTorch 实现了以下几种文本分类模型:
4+
5+
#### Text-CNN
6+
7+
- 目录:[cnn](./cnn)
8+
- 论文:[Convolutional Neural Networks for Sentence Classification](https://arxiv.org/pdf/1408.5882.pdf)
9+
10+
#### Text-RCNN
11+
12+
- 目录:[rcnn](./rcnn)
13+
- 论文: [Recurrent Convolutional Neural Networks for Text Classification](https://www.aaai.org/ocs/index.php/AAAI/AAAI15/paper/view/9745/9552)
14+
15+
#### RNN-Attention
16+
17+
- 目录:[rnn-attention](./rnn-attention)
18+
- 论文: [Hierarchical Attention Networks for Document Classification](https://www.aclweb.org/anthology/N16-1174) - 简化版实现。
19+
20+
## 数据集
21+
22+
此处使用的数据集来自 [text-classification-cnn-rnn](https://github.com/gaussic/text-classification-cnn-rnn) 作者整理的数据集。下载链接:https://pan.baidu.com/s/1hugrfRu 密码: qfud
23+
24+
该数据集共包含 10 个类别,每个类别有 6500 条数据。类别如下:
25+
26+
```
27+
'体育', '财经', '房产', '家居', '教育', '科技', '时尚', '时政', '游戏', '娱乐'
28+
```
29+
30+
数据集划分如下:
31+
32+
- 训练集: 5000 * 10
33+
- 验证集: 500 * 10
34+
- 测试集: 1000 * 10
35+
36+
## 运行方法
37+
38+
**1. 下载数据集**
39+
40+
下载数据集并解压至 `datasets` 目录下。
41+
42+
**2. 配置参数**
43+
44+
`mian.py` 中做适当调整,然后运行:
45+
46+
```
47+
$ python main.py
48+
```
49+
50+
## 运行结果:
51+
52+
这里并没有对文本进行过多的预处理,比如去除特殊符号,停用词等。另外直接采用了字作为特征,对于中文文本分类,感觉分词已经没有必要了。
53+
54+
以下都是用默认参数跑出来的结果,实验使用的 GPU 为 Tesla V100,如果要用 CPU 跑建议减少数据量,并限制文本长度。
55+
56+
### Text-CNN
57+
58+
```
59+
2019-05-24 20:45:30,872 - epoch: 1 - loss: 0.06 acc: 0.65 - val_loss: 0.03 val_acc: 0.75
60+
2019-05-24 20:45:41,568 - epoch: 2 - loss: 0.05 acc: 0.80 - val_loss: 0.03 val_acc: 0.77
61+
2019-05-24 20:45:52,137 - epoch: 3 - loss: 0.05 acc: 0.82 - val_loss: 0.03 val_acc: 0.82
62+
2019-05-24 20:46:02,975 - epoch: 4 - loss: 0.05 acc: 0.83 - val_loss: 0.03 val_acc: 0.78
63+
2019-05-24 20:46:13,769 - epoch: 5 - loss: 0.05 acc: 0.83 - val_loss: 0.03 val_acc: 0.82
64+
2019-05-24 20:46:24,514 - epoch: 6 - loss: 0.05 acc: 0.87 - val_loss: 0.02 val_acc: 0.90
65+
2019-05-24 20:46:35,237 - epoch: 7 - loss: 0.05 acc: 0.92 - val_loss: 0.02 val_acc: 0.90
66+
2019-05-24 20:46:45,801 - epoch: 8 - loss: 0.05 acc: 0.93 - val_loss: 0.02 val_acc: 0.91
67+
2019-05-24 20:46:56,050 - epoch: 9 - loss: 0.05 acc: 0.93 - val_loss: 0.02 val_acc: 0.93
68+
2019-05-24 20:47:06,771 - epoch: 10 - loss: 0.05 acc: 0.94 - val_loss: 0.02 val_acc: 0.94
69+
70+
2019-05-24 20:47:07,435 - test - acc: 0.9326
71+
```
72+
73+
### Text-RCNN
74+
75+
```
76+
2019-05-26 12:40:35,331 - epoch 1 - loss: 0.02 acc: 0.81 - val_loss: 0.00 val_acc: 0.90
77+
2019-05-26 12:42:10,316 - epoch 2 - loss: 0.01 acc: 0.94 - val_loss: 0.01 val_acc: 0.90
78+
2019-05-26 12:43:42,279 - epoch 3 - loss: 0.01 acc: 0.95 - val_loss: 0.00 val_acc: 0.93
79+
2019-05-26 12:45:14,370 - epoch 4 - loss: 0.00 acc: 0.96 - val_loss: 0.00 val_acc: 0.91
80+
2019-05-26 12:46:46,713 - epoch 5 - loss: 0.00 acc: 0.96 - val_loss: 0.00 val_acc: 0.94
81+
82+
2019-05-26 12:46:51,099 - test - acc: 0.95
83+
```
84+
85+
相对 CNN 而言,RCNN 训练花费时间更多,RCNN 训练一个 epoch 可以让 CNN 训练 10 个 epoch。另外 RCNN 需要的 epoch 数相对较少,这里第一个 epoch 结束后,验证集上就达到了 90% 的准确度。
86+
87+
### RNN-Attention
88+
89+
```
90+
2019-05-26 12:55:42,786 - epoch 1 - loss: 0.03 acc: 0.66 - val_loss: 0.01 val_acc: 0.80
91+
2019-05-26 12:57:04,999 - epoch 2 - loss: 0.01 acc: 0.87 - val_loss: 0.01 val_acc: 0.84
92+
2019-05-26 12:58:36,714 - epoch 3 - loss: 0.01 acc: 0.91 - val_loss: 0.01 val_acc: 0.88
93+
2019-05-26 13:00:08,892 - epoch 4 - loss: 0.01 acc: 0.93 - val_loss: 0.01 val_acc: 0.89
94+
2019-05-26 13:01:41,746 - epoch 5 - loss: 0.01 acc: 0.94 - val_loss: 0.00 val_acc: 0.92
95+
96+
2019-05-26 13:01:47,011 - test - acc: 0.9212
97+
```
98+
99+
### FastText
100+
101+
另外,我使用 [FastText](https://fasttext.cc/) 对该数据集进行了分类,发现分类准确度能轻松达到 99% 以上。这也表明,对于长文本分类问题,词袋模型就足够了。深度模型,在此简单任务上并没有优势。
102+
103+
```
104+
F1-Score : 0.999400 Precision : 0.999800 Recall : 0.999000 __label__0
105+
F1-Score : 0.995690 Precision : 0.997991 Recall : 0.993400 __label__5
106+
F1-Score : 0.996396 Precision : 0.997395 Recall : 0.995400 __label__1
107+
F1-Score : 0.998701 Precision : 0.998003 Recall : 0.999400 __label__2
108+
F1-Score : 0.999000 Precision : 0.999400 Recall : 0.998600 __label__3
109+
F1-Score : 0.983119 Precision : 0.987884 Recall : 0.978400 __label__8
110+
F1-Score : 0.997598 Precision : 0.998397 Recall : 0.996800 __label__9
111+
F1-Score : 0.985344 Precision : 0.975873 Recall : 0.995000 __label__4
112+
F1-Score : 0.996898 Precision : 0.997597 Recall : 0.996200 __label__6
113+
F1-Score : 0.998700 Precision : 0.998800 Recall : 0.998600 __label__7
114+
N 50000
115+
P@1 0.995
116+
R@1 0.995
117+
```

cnn/README.md

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
## Text-CNN
2+
3+
- 论文:[Convolutional Neural Networks for Sentence Classification](https://arxiv.org/pdf/1408.5882.pdf)
4+
5+
## 配置
6+
7+
```python
8+
class_num=10 # 类别数
9+
embed_num=5000 # 需要等于字典大小
10+
embed_dim=64 # 字向量维度
11+
kernel_num=128 # 卷积核数量
12+
kernel_size_list=[3,4,5] # 卷积核尺寸
13+
dropout=0.5 # 置 0 的概率
14+
```
15+
16+
## 基本原理
17+
18+
![image](https://user-images.githubusercontent.com/7794103/58327903-63a30180-7e63-11e9-9c82-acc55c8e0b21.png)
19+
20+
该模型的基本思想是对输入序列先做 Embedding,而后使用不同窗口大小的 1D Conv 提取特征,经过 MaxPooing1D 后 一个卷积核得到一个标量,最后全部拼接起来,得到一个向量,然后使用全连接层加 softmax 进行分类。

model.py renamed to cnn/model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,15 @@
77
class TextCNN(nn.Module):
88
def __init__(self,
99
class_num=None,
10-
embed_num=None,
11-
embed_dim=100,
10+
embed_size=None,
11+
embed_dim=64,
1212
kernel_num=128,
1313
kernel_size_list=(3,4,5),
1414
dropout=0.5):
1515

1616
super(TextCNN, self).__init__()
1717

18-
self.embedding = nn.Embedding(embed_num, embed_dim)
18+
self.embedding = nn.Embedding(embed_size, embed_dim)
1919

2020
self.conv1d_list = nn.ModuleList([
2121
nn.Conv1d(embed_dim, kernel_num, kernel_size)

data.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
PAD_WORD = '<PAD>'
99
UNK_WORD = '<UNK>'
1010

11+
# 文档最大长度限制
1112
DOCUMENT_MAX_LENGTH = 500
1213

1314
CATEGIRY_LIST = ['体育', '财经', '房产', '家居', '教育', '科技', '时尚', '时政', '游戏', '娱乐']

0 commit comments

Comments
 (0)